Skip to main content

brainos_hippocampus/
embedding.rs

1//! Embedding pipeline — Ollama and OpenAI-compatible backends.
2//!
3//! The active provider is determined by `llm.provider` in Brain config:
4//! - `"ollama"` → calls `POST /api/embed` on the local Ollama server
//! - `"openai"` → calls `POST {base_url}/embeddings` on any OpenAI-compatible endpoint (so `base_url` must include the version prefix, e.g. `/v1`)
6//!
7//! The embedding model and output dimension are explicit config values:
8//!
9//! ```yaml
10//! llm:
11//!   provider: "ollama"
12//!   base_url: "http://localhost:11434"
13//!
14//! embedding:
15//!   model: "nomic-embed-text"   # must be pulled in Ollama / available via OpenAI
16//!   dimensions: 768              # must match the model's actual output size
17//! ```
18//!
19//! For OpenAI:
20//! ```yaml
21//! llm:
22//!   provider: "openai"
23//!   base_url: "https://api.openai.com/v1"
24//!   api_key: "sk-..."
25//!
26//! embedding:
27//!   model: "text-embedding-3-small"
28//!   dimensions: 1536
29//! ```
30
31use async_trait::async_trait;
32use serde::{Deserialize, Serialize};
33use thiserror::Error;
34use tracing::info;
35
36// ─── Errors ──────────────────────────────────────────────────────────────────
37
38/// Errors from the embedding pipeline.
39#[derive(Debug, Error)]
40pub enum EmbeddingError {
41    #[error("HTTP error: {0}")]
42    Http(String),
43
44    #[error("Response parse error: {0}")]
45    Parse(String),
46
47    #[error("Shape error: {0}")]
48    Shape(String),
49
50    #[error("Provider not available: {0}")]
51    ProviderUnavailable(String),
52
53    #[error("Provider initialization error: {0}")]
54    Provider(String),
55}
56
57/// Build a `reqwest::Client` with the given timeout, mapping construction
58/// failure to [`EmbeddingError::Provider`]. Shared by Ollama + OpenAI
59/// embedding providers so the timeout/error-mapping policy stays
60/// single-sourced.
61fn build_http_client(timeout: std::time::Duration) -> Result<reqwest::Client, EmbeddingError> {
62    reqwest::Client::builder()
63        .timeout(timeout)
64        .build()
65        .map_err(|e| EmbeddingError::Provider(format!("Failed to create HTTP client: {e}")))
66}
67
68/// Deterministically generate a non-zero fallback embedding and normalize it.
69///
70/// This is used when the embedding provider is unavailable or returns an invalid
71/// vector shape/value. The output is stable for the same `(seed, dimensions)`.
72pub fn deterministic_fallback_embedding(seed: &str, dimensions: usize) -> Vec<f32> {
73    if dimensions == 0 {
74        return Vec::new();
75    }
76
77    // FNV-1a 64-bit hash as deterministic PRNG seed.
78    let mut state: u64 = 0xcbf29ce484222325;
79    for b in seed.as_bytes() {
80        state ^= u64::from(*b);
81        state = state.wrapping_mul(0x100000001b3);
82    }
83    if state == 0 {
84        state = 1;
85    }
86
87    let mut out = Vec::with_capacity(dimensions);
88    for _ in 0..dimensions {
89        // xorshift64*
90        state ^= state >> 12;
91        state ^= state << 25;
92        state ^= state >> 27;
93        let r = state.wrapping_mul(0x2545f4914f6cdd1d);
94        let unit = (r as f64 / u64::MAX as f64) as f32;
95        out.push(unit * 2.0 - 1.0);
96    }
97
98    normalize_or_unit(out)
99}
100
101/// Validate and normalize an embedding vector, with deterministic fallback.
102///
103/// Conditions enforced:
104/// - exact `dimensions`
105/// - finite values only
106/// - non-zero norm
107/// - normalized output
108pub fn sanitize_embedding(candidate: Vec<f32>, dimensions: usize, seed: &str) -> Vec<f32> {
109    if dimensions == 0 {
110        return Vec::new();
111    }
112    if candidate.len() != dimensions || candidate.iter().any(|x| !x.is_finite()) {
113        return deterministic_fallback_embedding(seed, dimensions);
114    }
115
116    let norm_sq: f32 = candidate.iter().map(|x| x * x).sum();
117    if !norm_sq.is_finite() || norm_sq <= 1e-12 {
118        return deterministic_fallback_embedding(seed, dimensions);
119    }
120
121    let normalized = normalize_or_unit(candidate);
122    if normalized.iter().all(|x| x.is_finite()) {
123        normalized
124    } else {
125        deterministic_fallback_embedding(seed, dimensions)
126    }
127}
128
/// Scale `vector` to unit Euclidean length.
///
/// An empty input is returned unchanged. If the squared norm is non-finite
/// or at most `1e-12` (too degenerate to divide by), the result is instead
/// the first standard basis vector `[1, 0, …, 0]` of the same length.
fn normalize_or_unit(vector: Vec<f32>) -> Vec<f32> {
    if vector.is_empty() {
        return vector;
    }

    let norm_sq = vector.iter().map(|x| x * x).sum::<f32>();
    let degenerate = !norm_sq.is_finite() || norm_sq <= 1e-12;
    if degenerate {
        return (0..vector.len())
            .map(|i| if i == 0 { 1.0 } else { 0.0 })
            .collect();
    }

    let norm = norm_sq.sqrt();
    vector.into_iter().map(|x| x / norm).collect()
}
147
148// ─── EmbeddingProvider trait ─────────────────────────────────────────────────
149
150/// Pluggable embedding backend.
151///
152/// New providers (Voyage, Cohere, …) implement this trait and register in
153/// [`Embedder::from_config`] — no other call site changes.
154#[async_trait]
155pub trait EmbeddingProvider: Send + Sync + std::fmt::Debug {
156    /// Embed a batch of texts. Implementations preserve input order.
157    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError>;
158
159    /// Stable lower-case identifier used in logs and metrics
160    /// (e.g. `"ollama"`, `"openai"`).
161    fn provider_name(&self) -> &str;
162}
163
164// ─── Ollama Provider ─────────────────────────────────────────────────────────
165
166/// Ollama embedding provider — calls `POST /api/embed`.
167///
168/// Shares the same Ollama instance used for LLM inference.  Pull the model
169/// once with `ollama pull <model>` and it works immediately.
170#[derive(Debug)]
171pub struct OllamaProvider {
172    client: reqwest::Client,
173    base_url: String,
174    model: String,
175}
176
177#[derive(Serialize)]
178struct OllamaEmbedRequest<'a> {
179    model: &'a str,
180    input: Vec<&'a str>,
181}
182
183#[derive(Deserialize)]
184struct OllamaEmbedResponse {
185    embeddings: Vec<Vec<f32>>,
186}
187
188impl OllamaProvider {
189    pub fn new(base_url: &str, model: &str) -> Result<Self, EmbeddingError> {
190        // Ollama may need to load the model on first call — allow up to 120s
191        let client = build_http_client(brain::timeouts::EMBEDDING_OLLAMA)?;
192        Ok(Self {
193            client,
194            base_url: base_url.trim_end_matches('/').to_string(),
195            model: model.to_string(),
196        })
197    }
198
199    /// Check if the Ollama server is reachable.
200    pub async fn health_check(&self) -> bool {
201        let url = format!("{}/api/tags", self.base_url);
202        self.client
203            .get(&url)
204            .send()
205            .await
206            .map(|r| r.status().is_success())
207            .unwrap_or(false)
208    }
209}
210
211#[async_trait]
212impl EmbeddingProvider for OllamaProvider {
213    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
214        if texts.is_empty() {
215            return Ok(Vec::new());
216        }
217        let url = format!("{}/api/embed", self.base_url);
218        let resp = self
219            .client
220            .post(&url)
221            .json(&OllamaEmbedRequest {
222                model: &self.model,
223                input: texts.to_vec(),
224            })
225            .send()
226            .await
227            .map_err(|e| EmbeddingError::Http(format!("Request failed: {e}")))?;
228
229        if !resp.status().is_success() {
230            let status = resp.status();
231            let body = resp.text().await.unwrap_or_default();
232            return Err(EmbeddingError::Http(format!("HTTP {status}: {body}")));
233        }
234
235        let parsed: OllamaEmbedResponse = resp
236            .json()
237            .await
238            .map_err(|e| EmbeddingError::Parse(format!("Failed to parse Ollama response: {e}")))?;
239
240        if parsed.embeddings.len() != texts.len() {
241            return Err(EmbeddingError::Shape(format!(
242                "Expected {} embeddings, got {}",
243                texts.len(),
244                parsed.embeddings.len()
245            )));
246        }
247        Ok(parsed.embeddings)
248    }
249
250    fn provider_name(&self) -> &str {
251        "ollama"
252    }
253}
254
255// ─── OpenAI-compatible Provider ──────────────────────────────────────────────
256
257/// OpenAI-compatible embedding provider — calls `POST /v1/embeddings`.
258///
259/// Works with OpenAI, OpenRouter, Azure OpenAI, or any OpenAI-compatible
260/// local server (e.g. vLLM, LM Studio, Ollama in OpenAI-compat mode).
261#[derive(Debug)]
262pub struct OpenAIProvider {
263    client: reqwest::Client,
264    base_url: String,
265    model: String,
266    api_key: String,
267}
268
269#[derive(Serialize)]
270struct OpenAIEmbedRequest<'a> {
271    model: &'a str,
272    input: Vec<&'a str>,
273}
274
275#[derive(Deserialize)]
276struct OpenAIEmbedResponse {
277    data: Vec<OpenAIEmbedData>,
278}
279
280#[derive(Deserialize)]
281struct OpenAIEmbedData {
282    embedding: Vec<f32>,
283    index: usize,
284}
285
286impl OpenAIProvider {
287    pub fn new(base_url: &str, model: &str, api_key: &str) -> Result<Self, EmbeddingError> {
288        let client = build_http_client(brain::timeouts::EMBEDDING_OPENAI)?;
289        Ok(Self {
290            client,
291            base_url: base_url.trim_end_matches('/').to_string(),
292            model: model.to_string(),
293            api_key: api_key.to_string(),
294        })
295    }
296}
297
298#[async_trait]
299impl EmbeddingProvider for OpenAIProvider {
300    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
301        if texts.is_empty() {
302            return Ok(Vec::new());
303        }
304        let url = format!("{}/embeddings", self.base_url);
305        let resp = self
306            .client
307            .post(&url)
308            .bearer_auth(&self.api_key)
309            .json(&OpenAIEmbedRequest {
310                model: &self.model,
311                input: texts.to_vec(),
312            })
313            .send()
314            .await
315            .map_err(|e| EmbeddingError::Http(format!("Request failed: {e}")))?;
316
317        if !resp.status().is_success() {
318            let status = resp.status();
319            let body = resp.text().await.unwrap_or_default();
320            return Err(EmbeddingError::Http(format!("HTTP {status}: {body}")));
321        }
322
323        let mut parsed: OpenAIEmbedResponse = resp
324            .json()
325            .await
326            .map_err(|e| EmbeddingError::Parse(format!("Failed to parse OpenAI response: {e}")))?;
327
328        if parsed.data.len() != texts.len() {
329            return Err(EmbeddingError::Shape(format!(
330                "Expected {} embeddings, got {}",
331                texts.len(),
332                parsed.data.len()
333            )));
334        }
335        // Sort by index to guarantee order (OpenAI may reorder for batching)
336        parsed.data.sort_by_key(|d| d.index);
337        Ok(parsed.data.into_iter().map(|d| d.embedding).collect())
338    }
339
340    fn provider_name(&self) -> &str {
341        "openai"
342    }
343}
344
345// ─── Embedder ─────────────────────────────────────────────────────────────────
346
347/// Active embedding backend — owns a trait object so new providers can be
348/// added without touching this type.
349///
350/// Constructed once at startup via [`Embedder::for_ollama`], [`Embedder::for_openai`]
351/// or [`Embedder::from_config`], then shared (behind `tokio::sync::Mutex`)
352/// across the signal pipeline.
353pub struct Embedder {
354    inner: Box<dyn EmbeddingProvider>,
355}
356
357impl std::fmt::Debug for Embedder {
358    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
359        write!(f, "Embedder({})", self.inner.provider_name())
360    }
361}
362
363impl Embedder {
364    /// Wrap any provider impl. Lets crates outside this module register
365    /// custom backends (Voyage, Cohere, in-process test doubles, …).
366    pub fn new(inner: Box<dyn EmbeddingProvider>) -> Self {
367        Self { inner }
368    }
369
370    /// Create an Ollama-backed embedder.
371    pub fn for_ollama(base_url: &str, model: &str) -> Result<Self, EmbeddingError> {
372        info!(model, "Embedding provider: Ollama");
373        Ok(Self::new(Box::new(OllamaProvider::new(base_url, model)?)))
374    }
375
376    /// Create an OpenAI-compatible embedder.
377    pub fn for_openai(base_url: &str, model: &str, api_key: &str) -> Result<Self, EmbeddingError> {
378        info!(model, base_url, "Embedding provider: OpenAI-compatible");
379        Ok(Self::new(Box::new(OpenAIProvider::new(
380            base_url, model, api_key,
381        )?)))
382    }
383
384    /// Create an embedder from Brain config settings.
385    ///
386    /// Selects the appropriate backend (Ollama or OpenAI-compatible)
387    /// based on `provider`. Returns `None` if no provider is configured.
388    pub fn from_config(
389        provider: &str,
390        base_url: &str,
391        model: &str,
392        api_key: &str,
393    ) -> Result<Option<Self>, EmbeddingError> {
394        match provider {
395            "openai" => Ok(Some(Self::for_openai(base_url, model, api_key)?)),
396            _ => Ok(Some(Self::for_ollama(base_url, model)?)),
397        }
398    }
399
400    /// Embed a single text string.
401    pub async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
402        let mut batch = self.embed_batch(&[text]).await?;
403        Ok(batch.remove(0))
404    }
405
406    /// Embed a batch of texts for throughput.
407    pub async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
408        self.inner.embed_batch(texts).await
409    }
410
411    /// Provider name for logging.
412    pub fn provider_name(&self) -> &str {
413        self.inner.provider_name()
414    }
415}
416
417// ─── Tests ───────────────────────────────────────────────────────────────────
418
#[cfg(test)]
mod tests {
    use super::*;

    // ─── Construction ───────────────────────────────────────────────────────

    #[test]
    fn test_ollama_provider_new() {
        let p = OllamaProvider::new("http://localhost:11434", "nomic-embed-text").unwrap();
        assert_eq!(p.model, "nomic-embed-text");
        assert_eq!(p.base_url, "http://localhost:11434");
    }

    #[test]
    fn test_ollama_provider_trims_trailing_slash() {
        let p = OllamaProvider::new("http://localhost:11434/", "nomic-embed-text").unwrap();
        assert_eq!(p.base_url, "http://localhost:11434");
    }

    #[test]
    fn test_openai_provider_new() {
        let p = OpenAIProvider::new(
            "https://api.openai.com/v1",
            "text-embedding-3-small",
            "sk-x",
        )
        .unwrap();
        assert_eq!(p.model, "text-embedding-3-small");
        assert_eq!(p.base_url, "https://api.openai.com/v1");
    }

    #[test]
    fn test_openai_provider_trims_trailing_slash() {
        let p = OpenAIProvider::new("https://api.openai.com/v1/", "m", "k").unwrap();
        assert_eq!(p.base_url, "https://api.openai.com/v1");
    }

    #[test]
    fn test_embedder_provider_name() {
        let e = Embedder::for_ollama("http://localhost:11434", "nomic-embed-text").unwrap();
        assert_eq!(e.provider_name(), "ollama");

        let e2 = Embedder::for_openai("https://api.openai.com/v1", "text-embedding-3-small", "k")
            .unwrap();
        assert_eq!(e2.provider_name(), "openai");
    }

    /// Requires Ollama running locally with nomic-embed-text pulled.
    #[tokio::test]
    #[ignore = "Requires Ollama server running locally with nomic-embed-text"]
    async fn test_ollama_embed_live() {
        let e = Embedder::for_ollama("http://localhost:11434", "nomic-embed-text").unwrap();
        let v = e.embed("Hello, world!").await.unwrap();
        assert_eq!(v.len(), 768, "nomic-embed-text produces 768-dim vectors");
    }

    // ─── Vector helpers ─────────────────────────────────────────────────────

    #[test]
    fn test_deterministic_fallback_embedding_is_stable_and_normalized() {
        let a = deterministic_fallback_embedding("remember rust", 16);
        let b = deterministic_fallback_embedding("remember rust", 16);
        let c = deterministic_fallback_embedding("remember bun", 16);

        assert_eq!(a.len(), 16);
        assert_eq!(a, b, "same seed must produce same fallback vector");
        assert_ne!(a, c, "different seeds should produce different vectors");

        let norm = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "fallback vector must be normalized"
        );
    }

    #[test]
    fn test_sanitize_embedding_rejects_invalid_inputs() {
        let zero = vec![0.0_f32; 8];
        let nan = vec![f32::NAN; 8];
        let wrong = vec![0.1_f32; 4];

        let a = sanitize_embedding(zero, 8, "seed-a");
        let b = sanitize_embedding(nan, 8, "seed-b");
        let c = sanitize_embedding(wrong, 8, "seed-c");

        assert_eq!(a.len(), 8);
        assert_eq!(b.len(), 8);
        assert_eq!(c.len(), 8);
        assert!(a.iter().all(|x| x.is_finite()));
        assert!(b.iter().all(|x| x.is_finite()));
        assert!(c.iter().all(|x| x.is_finite()));

        // Every rejection must be exactly the seed's deterministic fallback.
        assert_eq!(a, deterministic_fallback_embedding("seed-a", 8));
        assert_eq!(b, deterministic_fallback_embedding("seed-b", 8));
        assert_eq!(c, deterministic_fallback_embedding("seed-c", 8));
    }

    #[test]
    fn test_sanitize_embedding_normalizes_valid_input() {
        let v = sanitize_embedding(vec![3.0, 4.0], 2, "unused-seed");
        assert_eq!(v.len(), 2);
        assert!((v[0] - 0.6).abs() < 1e-6);
        assert!((v[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn test_zero_dimensions_yield_empty_vectors() {
        assert!(sanitize_embedding(vec![1.0, 2.0], 0, "seed").is_empty());
        assert!(deterministic_fallback_embedding("seed", 0).is_empty());
    }

    #[test]
    fn test_normalize_or_unit_degenerate_inputs() {
        assert!(normalize_or_unit(Vec::new()).is_empty());
        assert_eq!(normalize_or_unit(vec![0.0; 3]), vec![1.0, 0.0, 0.0]);
        assert_eq!(normalize_or_unit(vec![f32::NAN, 1.0]), vec![1.0, 0.0]);
    }

    // ─── Mock HTTP tests ────────────────────────────────────────────────────

    #[tokio::test]
    async fn test_embed_batch_empty_input_skips_network() {
        // Nothing listens on port 9 here; empty batches must return before
        // any request is attempted.
        let ollama = Embedder::for_ollama("http://127.0.0.1:9", "m").unwrap();
        assert!(ollama.embed_batch(&[]).await.unwrap().is_empty());

        let openai = Embedder::for_openai("http://127.0.0.1:9", "m", "k").unwrap();
        assert!(openai.embed_batch(&[]).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_ollama_embed_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/api/embed")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"embeddings": [[0.1, 0.2, 0.3, 0.4]]}"#)
            .create_async()
            .await;

        let embedder = Embedder::for_ollama(&server.url(), "test-model").unwrap();
        let v = embedder.embed("hello world").await.unwrap();
        assert_eq!(v, vec![0.1, 0.2, 0.3, 0.4]);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_ollama_embed_500_error_returns_http_error() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("POST", "/api/embed")
            .with_status(500)
            .with_body("server error")
            .create_async()
            .await;

        let embedder = Embedder::for_ollama(&server.url(), "test-model").unwrap();
        let err = embedder.embed("hello").await.unwrap_err();
        assert!(
            matches!(err, EmbeddingError::Http(_)),
            "expected Http error, got {err:?}"
        );
    }

    #[tokio::test]
    async fn test_ollama_embed_malformed_json() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("POST", "/api/embed")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("not json at all")
            .create_async()
            .await;

        let embedder = Embedder::for_ollama(&server.url(), "test-model").unwrap();
        let err = embedder.embed("hello").await.unwrap_err();
        assert!(
            matches!(err, EmbeddingError::Parse(_)),
            "expected Parse error, got {err:?}"
        );
    }

    #[tokio::test]
    async fn test_ollama_embed_shape_mismatch() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("POST", "/api/embed")
            .with_status(200)
            .with_header("content-type", "application/json")
            // Request asks for 2 texts, server returns 1 embedding
            .with_body(r#"{"embeddings": [[0.1, 0.2]]}"#)
            .create_async()
            .await;

        let embedder = Embedder::for_ollama(&server.url(), "test-model").unwrap();
        let err = embedder
            .embed_batch(&["first text", "second text"])
            .await
            .unwrap_err();
        assert!(
            matches!(err, EmbeddingError::Shape(_)),
            "expected Shape error, got {err:?}"
        );
    }

    #[tokio::test]
    async fn test_openai_embed_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/embeddings")
            .match_header("authorization", "Bearer test-key")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{
                    "data": [
                        {"embedding": [0.9, 0.8, 0.7], "index": 0}
                    ]
                }"#,
            )
            .create_async()
            .await;

        let embedder =
            Embedder::for_openai(&server.url(), "text-embedding-3-small", "test-key").unwrap();
        let v = embedder.embed("hello").await.unwrap();
        assert_eq!(v, vec![0.9, 0.8, 0.7]);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_openai_embed_500_error_returns_http_error() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("POST", "/embeddings")
            .with_status(500)
            .with_body("server error")
            .create_async()
            .await;

        let embedder = Embedder::for_openai(&server.url(), "model", "key").unwrap();
        let err = embedder.embed("hello").await.unwrap_err();
        assert!(
            matches!(err, EmbeddingError::Http(_)),
            "expected Http error, got {err:?}"
        );
    }

    #[tokio::test]
    async fn test_openai_embed_count_mismatch() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("POST", "/embeddings")
            .with_status(200)
            .with_header("content-type", "application/json")
            // Request asks for 2 texts, server returns 1 item
            .with_body(r#"{"data": [{"embedding": [0.1], "index": 0}]}"#)
            .create_async()
            .await;

        let embedder = Embedder::for_openai(&server.url(), "model", "key").unwrap();
        let err = embedder.embed_batch(&["a", "b"]).await.unwrap_err();
        assert!(
            matches!(err, EmbeddingError::Shape(_)),
            "expected Shape error, got {err:?}"
        );
    }

    #[tokio::test]
    async fn test_openai_embed_reorders_by_index() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("POST", "/embeddings")
            .with_status(200)
            .with_header("content-type", "application/json")
            // Intentionally out of order
            .with_body(
                r#"{
                    "data": [
                        {"embedding": [0.2], "index": 1},
                        {"embedding": [0.1], "index": 0}
                    ]
                }"#,
            )
            .create_async()
            .await;

        let embedder = Embedder::for_openai(&server.url(), "model", "key").unwrap();
        let batch = embedder.embed_batch(&["a", "b"]).await.unwrap();
        assert_eq!(batch[0], vec![0.1]);
        assert_eq!(batch[1], vec![0.2]);
    }
}