// anamnesis_core/embedding.rs

//! `EmbeddingProvider` trait — the only seam between Anamnesis's RAG layer
//! and the underlying model runtime.
//!
//! ## Why a trait
//!
//! Anamnesis runs its **own** RAG stack (see `docs/BLUEPRINT.md §6.6.1`).
//! Source-system vectors (mem0, Hermes, …) are kept only as `provenance` and
//! never enter the retrieval path. This trait is therefore the *only* path by
//! which any vector ever reaches the index.
//!
//! ## Invariants every implementor must hold
//!
//! 1. **Stable `model_id`** — must be deterministic for the lifetime of a
//!    given (provider, model, version) tuple. The store uses
//!    `(content_hash, model_id)` as the embedding cache key; a drifting id
//!    silently invalidates the cache or, worse, mixes incompatible vectors.
//! 2. **Stable `dim`** — must match the size of every vector returned.
//! 3. **Deterministic normalization** — vectors returned by `embed_query` and
//!    `embed_batch` must be in the same numeric regime (e.g. both
//!    L2-normalized) so cosine similarity is meaningful.
//! 4. **No scheduling** — the trait itself does no IO scheduling; callers
//!    (the embedding worker) own the outer batching loop, retry, and
//!    concurrency.

24use async_trait::async_trait;
25use serde::{Deserialize, Serialize};
26
27use crate::error::Result;
28
29/// Stable identifier for an embedding model.
30///
31/// Format convention (not enforced, but recommended):
32/// `"<provider>:<model>:<version>"`, e.g. `"local:multilingual-e5-small:1"`.
33#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
34#[serde(transparent)]
35pub struct ModelId(pub String);
36
37impl ModelId {
38    /// Build a model id from provider + model name + version.
39    pub fn new(provider: &str, model: &str, version: u32) -> Self {
40        Self(format!("{provider}:{model}:{version}"))
41    }
42
43    /// Borrow as `&str`.
44    pub fn as_str(&self) -> &str {
45        &self.0
46    }
47}
48
49impl std::fmt::Display for ModelId {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        f.write_str(&self.0)
52    }
53}
54
/// Tells the provider which side of an asymmetric retrieval pair a text
/// belongs to.
///
/// Some models (e5 / bge, for example) expect different prefixes for search
/// queries than for indexed passages; a provider is free to ignore the hint.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EmbeddingTask {
    /// A user-issued search query.
    Query,
    /// A document or chunk destined for the index.
    Document,
}
64
65/// The only seam by which vectors enter the Anamnesis index.
66#[async_trait]
67pub trait EmbeddingProvider: Send + Sync {
68    /// Stable id — see invariants in the module docs.
69    fn model_id(&self) -> ModelId;
70
71    /// Vector dimensionality. Must match every vector returned.
72    fn dim(&self) -> u16;
73
74    /// Embed a single query string. Default impl forwards to `embed_batch`.
75    async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
76        let mut out = self.embed_batch(&[text], EmbeddingTask::Query).await?;
77        out.pop()
78            .ok_or_else(|| crate::error::Error::Other("provider returned no vector".into()))
79    }
80
81    /// Embed a batch of texts. The provider is responsible for chunking the
82    /// batch into model-friendly sizes; the worker calling this owns the
83    /// outer batching loop.
84    async fn embed_batch(&self, texts: &[&str], task: EmbeddingTask) -> Result<Vec<Vec<f32>>>;
85}
86
87#[cfg(test)]
88mod tests {
89    use super::*;
90
91    #[test]
92    fn model_id_format_is_stable() {
93        let id = ModelId::new("local", "multilingual-e5-small", 1);
94        assert_eq!(id.as_str(), "local:multilingual-e5-small:1");
95    }
96
97    #[test]
98    fn model_id_roundtrips_through_json() {
99        let id = ModelId::new("local", "bge-m3", 1);
100        let s = serde_json::to_string(&id).unwrap();
101        // serde_transparent → just the string, not an object
102        assert_eq!(s, "\"local:bge-m3:1\"");
103        let back: ModelId = serde_json::from_str(&s).unwrap();
104        assert_eq!(back, id);
105    }
106
107    /// Minimal in-memory provider used to lock the trait shape and prove the
108    /// default `embed_query` forwarding works.
109    struct FakeProvider {
110        id: ModelId,
111        dim: u16,
112    }
113
114    #[async_trait]
115    impl EmbeddingProvider for FakeProvider {
116        fn model_id(&self) -> ModelId {
117            self.id.clone()
118        }
119        fn dim(&self) -> u16 {
120            self.dim
121        }
122        async fn embed_batch(&self, texts: &[&str], _task: EmbeddingTask) -> Result<Vec<Vec<f32>>> {
123            // Deterministic dummy vector: length 4, filled with text length / 100.
124            Ok(texts
125                .iter()
126                .map(|t| {
127                    let v = (t.len() as f32) / 100.0;
128                    vec![v; self.dim as usize]
129                })
130                .collect())
131        }
132    }
133
134    #[tokio::test]
135    async fn default_embed_query_forwards_to_batch() {
136        let p = FakeProvider {
137            id: ModelId::new("test", "fake", 1),
138            dim: 4,
139        };
140        let v = p.embed_query("hello world").await.unwrap();
141        assert_eq!(v.len(), 4);
142        assert!((v[0] - 0.11).abs() < f32::EPSILON);
143    }
144
145    #[tokio::test]
146    async fn batch_returns_one_vector_per_input() {
147        let p = FakeProvider {
148            id: ModelId::new("test", "fake", 1),
149            dim: 4,
150        };
151        let v = p
152            .embed_batch(&["a", "bb", "ccc"], EmbeddingTask::Document)
153            .await
154            .unwrap();
155        assert_eq!(v.len(), 3);
156        assert!(v.iter().all(|row| row.len() == 4));
157    }
158
159    #[tokio::test]
160    async fn embed_query_propagates_empty_provider_result() {
161        struct Empty;
162        #[async_trait]
163        impl EmbeddingProvider for Empty {
164            fn model_id(&self) -> ModelId {
165                ModelId::new("test", "empty", 1)
166            }
167            fn dim(&self) -> u16 {
168                4
169            }
170            async fn embed_batch(
171                &self,
172                _texts: &[&str],
173                _task: EmbeddingTask,
174            ) -> Result<Vec<Vec<f32>>> {
175                Ok(vec![])
176            }
177        }
178        let err = Empty.embed_query("x").await.unwrap_err();
179        assert!(format!("{err}").contains("no vector"));
180    }
181}