1use async_trait::async_trait;
32use serde::{Deserialize, Serialize};
33use thiserror::Error;
34use tracing::info;
35
36#[derive(Debug, Error)]
40pub enum EmbeddingError {
41 #[error("HTTP error: {0}")]
42 Http(String),
43
44 #[error("Response parse error: {0}")]
45 Parse(String),
46
47 #[error("Shape error: {0}")]
48 Shape(String),
49
50 #[error("Provider not available: {0}")]
51 ProviderUnavailable(String),
52
53 #[error("Provider initialization error: {0}")]
54 Provider(String),
55}
56
57fn build_http_client(timeout: std::time::Duration) -> Result<reqwest::Client, EmbeddingError> {
62 reqwest::Client::builder()
63 .timeout(timeout)
64 .build()
65 .map_err(|e| EmbeddingError::Provider(format!("Failed to create HTTP client: {e}")))
66}
67
68pub fn deterministic_fallback_embedding(seed: &str, dimensions: usize) -> Vec<f32> {
73 if dimensions == 0 {
74 return Vec::new();
75 }
76
77 let mut state: u64 = 0xcbf29ce484222325;
79 for b in seed.as_bytes() {
80 state ^= u64::from(*b);
81 state = state.wrapping_mul(0x100000001b3);
82 }
83 if state == 0 {
84 state = 1;
85 }
86
87 let mut out = Vec::with_capacity(dimensions);
88 for _ in 0..dimensions {
89 state ^= state >> 12;
91 state ^= state << 25;
92 state ^= state >> 27;
93 let r = state.wrapping_mul(0x2545f4914f6cdd1d);
94 let unit = (r as f64 / u64::MAX as f64) as f32;
95 out.push(unit * 2.0 - 1.0);
96 }
97
98 normalize_or_unit(out)
99}
100
101pub fn sanitize_embedding(candidate: Vec<f32>, dimensions: usize, seed: &str) -> Vec<f32> {
109 if dimensions == 0 {
110 return Vec::new();
111 }
112 if candidate.len() != dimensions || candidate.iter().any(|x| !x.is_finite()) {
113 return deterministic_fallback_embedding(seed, dimensions);
114 }
115
116 let norm_sq: f32 = candidate.iter().map(|x| x * x).sum();
117 if !norm_sq.is_finite() || norm_sq <= 1e-12 {
118 return deterministic_fallback_embedding(seed, dimensions);
119 }
120
121 let normalized = normalize_or_unit(candidate);
122 if normalized.iter().all(|x| x.is_finite()) {
123 normalized
124 } else {
125 deterministic_fallback_embedding(seed, dimensions)
126 }
127}
128
/// Scales `vector` to unit L2 length.
///
/// An empty vector is returned unchanged. If the squared norm is non-finite
/// (any NaN/infinite component) or `<= 1e-12`, the first standard basis vector
/// `[1, 0, …, 0]` of the same length is returned instead.
fn normalize_or_unit(mut vector: Vec<f32>) -> Vec<f32> {
    if vector.is_empty() {
        return vector;
    }

    // Accumulate in f64. An f32 sum of squares overflows to +inf once the
    // squared magnitude exceeds f32::MAX (~3.4e38), which would wrongly
    // collapse a perfectly finite direction to the basis vector. It also
    // loses precision when summing thousands of dimensions. Squares of f32
    // values cannot overflow f64, so this is non-finite only for NaN/inf inputs.
    let norm_sq: f64 = vector
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum();
    if !norm_sq.is_finite() || norm_sq <= 1e-12 {
        let mut unit = vec![0.0_f32; vector.len()];
        unit[0] = 1.0;
        return unit;
    }

    let norm = norm_sq.sqrt();
    for v in &mut vector {
        // |v| <= norm, so the quotient is in [-1, 1] and fits f32 exactly enough.
        *v = (f64::from(*v) / norm) as f32;
    }
    vector
}
147
148#[async_trait]
155pub trait EmbeddingProvider: Send + Sync + std::fmt::Debug {
156 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError>;
158
159 fn provider_name(&self) -> &str;
162}
163
164#[derive(Debug)]
171pub struct OllamaProvider {
172 client: reqwest::Client,
173 base_url: String,
174 model: String,
175}
176
177#[derive(Serialize)]
178struct OllamaEmbedRequest<'a> {
179 model: &'a str,
180 input: Vec<&'a str>,
181}
182
183#[derive(Deserialize)]
184struct OllamaEmbedResponse {
185 embeddings: Vec<Vec<f32>>,
186}
187
188impl OllamaProvider {
189 pub fn new(base_url: &str, model: &str) -> Result<Self, EmbeddingError> {
190 let client = build_http_client(brain::timeouts::EMBEDDING_OLLAMA)?;
192 Ok(Self {
193 client,
194 base_url: base_url.trim_end_matches('/').to_string(),
195 model: model.to_string(),
196 })
197 }
198
199 pub async fn health_check(&self) -> bool {
201 let url = format!("{}/api/tags", self.base_url);
202 self.client
203 .get(&url)
204 .send()
205 .await
206 .map(|r| r.status().is_success())
207 .unwrap_or(false)
208 }
209}
210
211#[async_trait]
212impl EmbeddingProvider for OllamaProvider {
213 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
214 if texts.is_empty() {
215 return Ok(Vec::new());
216 }
217 let url = format!("{}/api/embed", self.base_url);
218 let resp = self
219 .client
220 .post(&url)
221 .json(&OllamaEmbedRequest {
222 model: &self.model,
223 input: texts.to_vec(),
224 })
225 .send()
226 .await
227 .map_err(|e| EmbeddingError::Http(format!("Request failed: {e}")))?;
228
229 if !resp.status().is_success() {
230 let status = resp.status();
231 let body = resp.text().await.unwrap_or_default();
232 return Err(EmbeddingError::Http(format!("HTTP {status}: {body}")));
233 }
234
235 let parsed: OllamaEmbedResponse = resp
236 .json()
237 .await
238 .map_err(|e| EmbeddingError::Parse(format!("Failed to parse Ollama response: {e}")))?;
239
240 if parsed.embeddings.len() != texts.len() {
241 return Err(EmbeddingError::Shape(format!(
242 "Expected {} embeddings, got {}",
243 texts.len(),
244 parsed.embeddings.len()
245 )));
246 }
247 Ok(parsed.embeddings)
248 }
249
250 fn provider_name(&self) -> &str {
251 "ollama"
252 }
253}
254
255#[derive(Debug)]
262pub struct OpenAIProvider {
263 client: reqwest::Client,
264 base_url: String,
265 model: String,
266 api_key: String,
267}
268
269#[derive(Serialize)]
270struct OpenAIEmbedRequest<'a> {
271 model: &'a str,
272 input: Vec<&'a str>,
273}
274
275#[derive(Deserialize)]
276struct OpenAIEmbedResponse {
277 data: Vec<OpenAIEmbedData>,
278}
279
280#[derive(Deserialize)]
281struct OpenAIEmbedData {
282 embedding: Vec<f32>,
283 index: usize,
284}
285
286impl OpenAIProvider {
287 pub fn new(base_url: &str, model: &str, api_key: &str) -> Result<Self, EmbeddingError> {
288 let client = build_http_client(brain::timeouts::EMBEDDING_OPENAI)?;
289 Ok(Self {
290 client,
291 base_url: base_url.trim_end_matches('/').to_string(),
292 model: model.to_string(),
293 api_key: api_key.to_string(),
294 })
295 }
296}
297
298#[async_trait]
299impl EmbeddingProvider for OpenAIProvider {
300 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
301 if texts.is_empty() {
302 return Ok(Vec::new());
303 }
304 let url = format!("{}/embeddings", self.base_url);
305 let resp = self
306 .client
307 .post(&url)
308 .bearer_auth(&self.api_key)
309 .json(&OpenAIEmbedRequest {
310 model: &self.model,
311 input: texts.to_vec(),
312 })
313 .send()
314 .await
315 .map_err(|e| EmbeddingError::Http(format!("Request failed: {e}")))?;
316
317 if !resp.status().is_success() {
318 let status = resp.status();
319 let body = resp.text().await.unwrap_or_default();
320 return Err(EmbeddingError::Http(format!("HTTP {status}: {body}")));
321 }
322
323 let mut parsed: OpenAIEmbedResponse = resp
324 .json()
325 .await
326 .map_err(|e| EmbeddingError::Parse(format!("Failed to parse OpenAI response: {e}")))?;
327
328 if parsed.data.len() != texts.len() {
329 return Err(EmbeddingError::Shape(format!(
330 "Expected {} embeddings, got {}",
331 texts.len(),
332 parsed.data.len()
333 )));
334 }
335 parsed.data.sort_by_key(|d| d.index);
337 Ok(parsed.data.into_iter().map(|d| d.embedding).collect())
338 }
339
340 fn provider_name(&self) -> &str {
341 "openai"
342 }
343}
344
345pub struct Embedder {
354 inner: Box<dyn EmbeddingProvider>,
355}
356
357impl std::fmt::Debug for Embedder {
358 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
359 write!(f, "Embedder({})", self.inner.provider_name())
360 }
361}
362
363impl Embedder {
364 pub fn new(inner: Box<dyn EmbeddingProvider>) -> Self {
367 Self { inner }
368 }
369
370 pub fn for_ollama(base_url: &str, model: &str) -> Result<Self, EmbeddingError> {
372 info!(model, "Embedding provider: Ollama");
373 Ok(Self::new(Box::new(OllamaProvider::new(base_url, model)?)))
374 }
375
376 pub fn for_openai(base_url: &str, model: &str, api_key: &str) -> Result<Self, EmbeddingError> {
378 info!(model, base_url, "Embedding provider: OpenAI-compatible");
379 Ok(Self::new(Box::new(OpenAIProvider::new(
380 base_url, model, api_key,
381 )?)))
382 }
383
384 pub fn from_config(
389 provider: &str,
390 base_url: &str,
391 model: &str,
392 api_key: &str,
393 ) -> Result<Option<Self>, EmbeddingError> {
394 match provider {
395 "openai" => Ok(Some(Self::for_openai(base_url, model, api_key)?)),
396 _ => Ok(Some(Self::for_ollama(base_url, model)?)),
397 }
398 }
399
400 pub async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
402 let mut batch = self.embed_batch(&[text]).await?;
403 Ok(batch.remove(0))
404 }
405
406 pub async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
408 self.inner.embed_batch(texts).await
409 }
410
411 pub fn provider_name(&self) -> &str {
413 self.inner.provider_name()
414 }
415}
416
// Unit tests for this module. HTTP behavior is exercised against a local
// `mockito` server; the one live-Ollama test is `#[ignore]`d by default.
#[cfg(test)]
mod tests {
    use super::*;

    // The constructor stores the model and base URL as given.
    #[test]
    fn test_ollama_provider_new() {
        let p = OllamaProvider::new("http://localhost:11434", "nomic-embed-text").unwrap();
        assert_eq!(p.model, "nomic-embed-text");
        assert_eq!(p.base_url, "http://localhost:11434");
    }

    // A trailing '/' is stripped so endpoint paths can be appended directly.
    #[test]
    fn test_ollama_provider_trims_trailing_slash() {
        let p = OllamaProvider::new("http://localhost:11434/", "nomic-embed-text").unwrap();
        assert_eq!(p.base_url, "http://localhost:11434");
    }

    #[test]
    fn test_openai_provider_new() {
        let p = OpenAIProvider::new(
            "https://api.openai.com/v1",
            "text-embedding-3-small",
            "sk-x",
        )
        .unwrap();
        assert_eq!(p.model, "text-embedding-3-small");
        assert_eq!(p.base_url, "https://api.openai.com/v1");
    }

    // `Embedder` reports the name of whichever provider it wraps.
    #[test]
    fn test_embedder_provider_name() {
        let e = Embedder::for_ollama("http://localhost:11434", "nomic-embed-text").unwrap();
        assert_eq!(e.provider_name(), "ollama");

        let e2 = Embedder::for_openai("https://api.openai.com/v1", "text-embedding-3-small", "k")
            .unwrap();
        assert_eq!(e2.provider_name(), "openai");
    }

    // Run manually (`cargo test -- --ignored`) with a local Ollama server.
    #[tokio::test]
    #[ignore = "Requires Ollama server running locally with nomic-embed-text"]
    async fn test_ollama_embed_live() {
        let e = Embedder::for_ollama("http://localhost:11434", "nomic-embed-text").unwrap();
        let v = e.embed("Hello, world!").await.unwrap();
        assert_eq!(v.len(), 768, "nomic-embed-text produces 768-dim vectors");
    }

    // Same seed -> identical vector; different seed -> different vector;
    // output is unit length.
    #[test]
    fn test_deterministic_fallback_embedding_is_stable_and_normalized() {
        let a = deterministic_fallback_embedding("remember rust", 16);
        let b = deterministic_fallback_embedding("remember rust", 16);
        let c = deterministic_fallback_embedding("remember bun", 16);

        assert_eq!(a.len(), 16);
        assert_eq!(a, b, "same seed must produce same fallback vector");
        assert_ne!(a, c, "different seeds should produce different vectors");

        let norm = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "fallback vector must be normalized"
        );
    }

    // Zero, NaN, and wrong-length candidates are all replaced by a finite
    // vector of the requested length.
    #[test]
    fn test_sanitize_embedding_rejects_invalid_inputs() {
        let zero = vec![0.0_f32; 8];
        let nan = vec![f32::NAN; 8];
        let wrong = vec![0.1_f32; 4];

        let a = sanitize_embedding(zero, 8, "seed-a");
        let b = sanitize_embedding(nan, 8, "seed-b");
        let c = sanitize_embedding(wrong, 8, "seed-c");

        assert_eq!(a.len(), 8);
        assert_eq!(b.len(), 8);
        assert_eq!(c.len(), 8);
        assert!(a.iter().all(|x| x.is_finite()));
        assert!(b.iter().all(|x| x.is_finite()));
        assert!(c.iter().all(|x| x.is_finite()));
    }

    // Happy path: the provider returns the server's vector unchanged
    // (no normalization happens at this layer).
    #[tokio::test]
    async fn test_ollama_embed_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/api/embed")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"embeddings": [[0.1, 0.2, 0.3, 0.4]]}"#)
            .create_async()
            .await;

        let embedder = Embedder::for_ollama(&server.url(), "test-model").unwrap();
        let v = embedder.embed("hello world").await.unwrap();
        assert_eq!(v, vec![0.1, 0.2, 0.3, 0.4]);
        mock.assert_async().await;
    }

    // A non-2xx status maps to `EmbeddingError::Http`.
    #[tokio::test]
    async fn test_ollama_embed_500_error_returns_http_error() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("POST", "/api/embed")
            .with_status(500)
            .with_body("server error")
            .create_async()
            .await;

        let embedder = Embedder::for_ollama(&server.url(), "test-model").unwrap();
        let err = embedder.embed("hello").await.unwrap_err();
        assert!(
            matches!(err, EmbeddingError::Http(_)),
            "expected Http error, got {err:?}"
        );
    }

    // A 200 response with a non-JSON body maps to `EmbeddingError::Parse`.
    #[tokio::test]
    async fn test_ollama_embed_malformed_json() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("POST", "/api/embed")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("not json at all")
            .create_async()
            .await;

        let embedder = Embedder::for_ollama(&server.url(), "test-model").unwrap();
        let err = embedder.embed("hello").await.unwrap_err();
        assert!(
            matches!(err, EmbeddingError::Parse(_)),
            "expected Parse error, got {err:?}"
        );
    }

    // Two inputs but one returned embedding -> `EmbeddingError::Shape`.
    #[tokio::test]
    async fn test_ollama_embed_shape_mismatch() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("POST", "/api/embed")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"embeddings": [[0.1, 0.2]]}"#)
            .create_async()
            .await;

        let embedder = Embedder::for_ollama(&server.url(), "test-model").unwrap();
        let err = embedder
            .embed_batch(&["first text", "second text"])
            .await
            .unwrap_err();
        assert!(
            matches!(err, EmbeddingError::Shape(_)),
            "expected Shape error, got {err:?}"
        );
    }

    // The mock only matches when the bearer token is sent, so
    // `mock.assert_async()` also verifies authentication.
    #[tokio::test]
    async fn test_openai_embed_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/embeddings")
            .match_header("authorization", "Bearer test-key")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{
                    "data": [
                        {"embedding": [0.9, 0.8, 0.7], "index": 0}
                    ]
                }"#,
            )
            .create_async()
            .await;

        let embedder =
            Embedder::for_openai(&server.url(), "text-embedding-3-small", "test-key").unwrap();
        let v = embedder.embed("hello").await.unwrap();
        assert_eq!(v, vec![0.9, 0.8, 0.7]);
        mock.assert_async().await;
    }

    // Items returned out of order are re-sorted by their `index` field.
    #[tokio::test]
    async fn test_openai_embed_reorders_by_index() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("POST", "/embeddings")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{
                    "data": [
                        {"embedding": [0.2], "index": 1},
                        {"embedding": [0.1], "index": 0}
                    ]
                }"#,
            )
            .create_async()
            .await;

        let embedder = Embedder::for_openai(&server.url(), "model", "key").unwrap();
        let batch = embedder.embed_batch(&["a", "b"]).await.unwrap();
        assert_eq!(batch[0], vec![0.1]);
        assert_eq!(batch[1], vec![0.2]);
    }
}