// traitclaw_rag/embedding.rs
1//! Embedding-based vector retrieval for RAG pipelines.
2//!
3//! Provides the [`EmbeddingProvider`] trait and [`EmbeddingRetriever`] —
4//! an in-memory cosine-similarity retriever backed by any embedding model.
5//!
6//! # Example
7//!
8//! ```rust,no_run
9//! use traitclaw_rag::embedding::{EmbeddingProvider, EmbeddingRetriever};
10//! use traitclaw_rag::{Document, Retriever};
11//! use async_trait::async_trait;
12//!
13//! struct MyEmbedder;
14//!
15//! #[async_trait]
16//! impl EmbeddingProvider for MyEmbedder {
17//!     async fn embed(&self, texts: &[&str]) -> traitclaw_core::Result<Vec<Vec<f64>>> {
18//!         // Return dummy vectors of dimension 3 for each text
19//!         Ok(texts.iter().map(|_| vec![0.1, 0.2, 0.3]).collect())
20//!     }
21//! }
22//!
23//! # async fn example() -> traitclaw_core::Result<()> {
24//! let mut retriever = EmbeddingRetriever::new(MyEmbedder);
25//! retriever.add_documents(vec![
26//!     Document::new("doc1", "Rust systems programming"),
27//!     Document::new("doc2", "Python data science"),
28//! ]).await?;
29//!
30//! let results = retriever.retrieve("Rust", 1).await?;
31//! assert_eq!(results.len(), 1);
32//! # Ok(())
33//! # }
34//! ```
35
36use async_trait::async_trait;
37use traitclaw_core::{Error, Result};
38
39use crate::{Document, Retriever};
40
/// Async trait for computing text embeddings.
///
/// Implement this to integrate any embedding model (OpenAI, Cohere, local, etc.).
/// [`EmbeddingRetriever`] relies on `embed` returning exactly one embedding per
/// input text, in the same order as `texts`.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync + 'static {
    /// Compute embeddings for `texts`.
    ///
    /// Returns a vector of embeddings — one per input text.
    /// All embeddings must have the same dimensionality.
    ///
    /// # Errors
    ///
    /// Returns an error if the embedding backend fails.
    async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f64>>>;
}
52
/// Stored entry: embedding vector + original document.
struct VectorEntry {
    /// Embedding of `document.content`, produced by the retriever's provider.
    embedding: Vec<f64>,
    /// The original document; cloned back to callers on retrieval.
    document: Document,
}
58
/// In-memory vector retriever using cosine similarity search.
///
/// Stores document embeddings and retrieves the top-k most similar documents
/// for a query, optionally filtered by a minimum similarity threshold.
pub struct EmbeddingRetriever<P: EmbeddingProvider> {
    /// Embedding model used for both documents and queries.
    provider: P,
    /// Flat store of (embedding, document) pairs; searched linearly.
    store: Vec<VectorEntry>,
    /// Minimum cosine similarity for a result to be returned (default 0.0).
    similarity_threshold: f64,
}
68
69impl<P: EmbeddingProvider> EmbeddingRetriever<P> {
70    /// Create a new retriever backed by the given [`EmbeddingProvider`].
71    #[must_use]
72    pub fn new(provider: P) -> Self {
73        Self {
74            provider,
75            store: Vec::new(),
76            similarity_threshold: 0.0,
77        }
78    }
79
80    /// Set the minimum cosine similarity required to include a result.
81    ///
82    /// Results with similarity below this threshold are excluded.
83    ///
84    /// # Example
85    ///
86    /// ```rust,no_run
87    /// # use traitclaw_rag::embedding::{EmbeddingProvider, EmbeddingRetriever};
88    /// # struct Dummy;
89    /// # #[async_trait::async_trait]
90    /// # impl EmbeddingProvider for Dummy {
91    /// #     async fn embed(&self, texts: &[&str]) -> traitclaw_core::Result<Vec<Vec<f64>>> {
92    /// #         Ok(vec![vec![0.0]; texts.len()])
93    /// #     }
94    /// # }
95    /// let retriever = EmbeddingRetriever::new(Dummy).with_similarity_threshold(0.7);
96    /// ```
97    #[must_use]
98    pub fn with_similarity_threshold(mut self, threshold: f64) -> Self {
99        self.similarity_threshold = threshold;
100        self
101    }
102
103    /// Embed and store documents in the in-memory vector store.
104    ///
105    /// Calls `embed()` exactly once with all document texts.
106    ///
107    /// # Errors
108    ///
109    /// Returns an error if the embedding provider fails or returns the wrong
110    /// number of embeddings.
111    pub async fn add_documents(&mut self, documents: Vec<Document>) -> Result<()> {
112        if documents.is_empty() {
113            return Ok(());
114        }
115
116        let texts: Vec<&str> = documents.iter().map(|d| d.content.as_str()).collect();
117        let embeddings = self.provider.embed(&texts).await?;
118
119        if embeddings.len() != documents.len() {
120            return Err(Error::Runtime(format!(
121                "EmbeddingProvider returned {} embeddings for {} documents",
122                embeddings.len(),
123                documents.len()
124            )));
125        }
126
127        for (doc, emb) in documents.into_iter().zip(embeddings) {
128            self.store.push(VectorEntry {
129                embedding: emb,
130                document: doc,
131            });
132        }
133
134        Ok(())
135    }
136
137    /// Number of stored documents.
138    #[must_use]
139    pub fn len(&self) -> usize {
140        self.store.len()
141    }
142
143    /// Whether the vector store is empty.
144    #[must_use]
145    pub fn is_empty(&self) -> bool {
146        self.store.is_empty()
147    }
148}
149
150#[async_trait]
151impl<P: EmbeddingProvider> Retriever for EmbeddingRetriever<P> {
152    /// Embed `query`, compute cosine similarity with all stored docs, return top-k.
153    async fn retrieve(&self, query: &str, limit: usize) -> Result<Vec<Document>> {
154        if self.store.is_empty() {
155            return Ok(Vec::new());
156        }
157
158        let query_embs = self.provider.embed(&[query]).await?;
159        let query_emb = query_embs
160            .into_iter()
161            .next()
162            .ok_or_else(|| Error::Runtime("EmbeddingProvider returned empty for query".into()))?;
163
164        let mut scored: Vec<(f64, &Document)> = self
165            .store
166            .iter()
167            .map(|entry| {
168                let sim = cosine_similarity(&query_emb, &entry.embedding);
169                (sim, &entry.document)
170            })
171            .filter(|(sim, _)| *sim >= self.similarity_threshold)
172            .collect();
173
174        // Sort by similarity descending
175        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
176        scored.truncate(limit);
177
178        let results = scored
179            .into_iter()
180            .map(|(sim, doc)| {
181                let mut d = doc.clone();
182                d.score = sim;
183                d
184            })
185            .collect();
186
187        Ok(results)
188    }
189}
190
/// Compute cosine similarity between two vectors.
///
/// Returns 0.0 when the similarity is undefined: mismatched lengths, empty
/// inputs, or either vector having zero magnitude.
fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
    // Mismatched or empty inputs have no meaningful similarity.
    // (The spurious `clippy::cast_precision_loss` allow was removed — this
    // function contains no numeric casts.)
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let mag_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
    let mag_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();

    // Guard against division by zero for zero-magnitude vectors.
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    dot / (mag_a * mag_b)
}
210
211// ─────────────────────────────────────────────────────────────────────────────
212// Test helper: Counting embedder
213// ─────────────────────────────────────────────────────────────────────────────
214
#[cfg(test)]
pub(crate) mod test_helpers {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use super::*;

    /// Tracking embedder: counts embed() calls and returns deterministic vectors.
    pub struct CountingEmbedder {
        pub call_count: Arc<AtomicUsize>,
        #[allow(dead_code)]
        pub dim: usize,
    }

    impl CountingEmbedder {
        pub fn new(dim: usize) -> Self {
            Self {
                call_count: Arc::new(AtomicUsize::new(0)),
                dim,
            }
        }
    }

    #[async_trait]
    impl EmbeddingProvider for CountingEmbedder {
        async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f64>>> {
            self.call_count.fetch_add(1, Ordering::Relaxed);
            // Deterministic but text-dependent: derive each vector from the
            // text's character count so different texts get different vectors.
            let mut vectors = Vec::with_capacity(texts.len());
            for text in texts {
                let base = (text.len() % 10) as f64 / 10.0;
                vectors.push(vec![base, 1.0 - base, 0.5]);
            }
            Ok(vectors)
        }
    }

    /// Simple embedder that uses specific vectors for known texts.
    pub struct FixedEmbedder(pub Vec<Vec<f64>>);

    #[async_trait]
    impl EmbeddingProvider for FixedEmbedder {
        async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f64>>> {
            // Hand out the fixed vectors in order, cycling when there are
            // more texts than vectors.
            let mut vectors = Vec::with_capacity(texts.len());
            for (idx, _) in texts.iter().enumerate() {
                vectors.push(self.0[idx % self.0.len()].clone());
            }
            Ok(vectors)
        }
    }
}
268
269// ─────────────────────────────────────────────────────────────────────────────
270// Tests
271// ─────────────────────────────────────────────────────────────────────────────
272
#[cfg(test)]
mod tests {
    use std::sync::atomic::Ordering;
    use std::sync::Arc;

    use super::test_helpers::*;
    use super::*;
    use crate::Document;

    /// Build `n` documents with ids `doc0..doc{n-1}` and distinct content.
    fn make_docs(n: usize) -> Vec<Document> {
        (0..n)
            .map(|i| Document::new(format!("doc{i}"), format!("document content {i}")))
            .collect()
    }

    #[tokio::test]
    async fn test_add_documents_calls_embed_once() {
        // AC #9: add_documents calls embed() exactly once with all texts
        let embedder = CountingEmbedder::new(3);
        let count = embedder.call_count.clone();
        let mut retriever = EmbeddingRetriever::new(embedder);
        retriever.add_documents(make_docs(10)).await.unwrap();

        assert_eq!(
            count.load(Ordering::Relaxed),
            1,
            "embed should be called exactly once"
        );
        assert_eq!(retriever.len(), 10);
    }

    #[tokio::test]
    async fn test_retrieve_returns_at_most_limit() {
        // AC #7: 10 docs → query returns ≤ limit results
        let mut retriever = EmbeddingRetriever::new(CountingEmbedder::new(3));
        retriever.add_documents(make_docs(10)).await.unwrap();

        let results = retriever.retrieve("content", 3).await.unwrap();
        assert!(
            results.len() <= 3,
            "expected ≤3 results, got {}",
            results.len()
        );
    }

    #[tokio::test]
    async fn test_retrieve_sorted_by_similarity_desc() {
        // AC #7: results sorted by similarity descending
        let mut retriever = EmbeddingRetriever::new(CountingEmbedder::new(3));
        retriever.add_documents(make_docs(5)).await.unwrap();

        let results = retriever.retrieve("query", 5).await.unwrap();
        // Check each adjacent pair: scores must be non-increasing.
        for window in results.windows(2) {
            assert!(
                window[0].score >= window[1].score,
                "results not sorted: {} < {}",
                window[0].score,
                window[1].score
            );
        }
    }

    #[tokio::test]
    async fn test_similarity_threshold_filters_results() {
        // AC #8: threshold 0.9 → fewer results than threshold 0.5
        // FixedEmbedder cycles these vectors; the query re-uses the first one,
        // so the three stored docs land at distinct similarity levels.
        let vecs = vec![
            vec![1.0, 0.0, 0.0], // identical to query → sim = 1.0
            vec![0.0, 1.0, 0.0], // orthogonal → sim = 0.0
            vec![0.7, 0.7, 0.0], // partial match → sim ≈ 0.49
        ];

        let mut retriever_low =
            EmbeddingRetriever::new(FixedEmbedder(vecs.clone())).with_similarity_threshold(0.0);
        retriever_low.add_documents(make_docs(3)).await.unwrap();
        let results_low = retriever_low.retrieve("doc", 10).await.unwrap();

        let mut retriever_high =
            EmbeddingRetriever::new(FixedEmbedder(vecs.clone())).with_similarity_threshold(0.95);
        retriever_high.add_documents(make_docs(3)).await.unwrap();
        let results_high = retriever_high.retrieve("doc", 10).await.unwrap();

        // High threshold → fewer results
        assert!(
            results_high.len() < results_low.len() || results_high.len() <= 1,
            "high threshold should filter more: low={}, high={}",
            results_low.len(),
            results_high.len()
        );
    }

    #[tokio::test]
    async fn test_empty_store_returns_empty() {
        // Retrieval against an empty store must succeed with no results.
        let retriever = EmbeddingRetriever::new(CountingEmbedder::new(3));
        let results = retriever.retrieve("any query", 10).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_add_empty_documents() {
        // Adding an empty batch is a no-op and must not call the provider.
        let mut retriever = EmbeddingRetriever::new(CountingEmbedder::new(3));
        retriever.add_documents(vec![]).await.unwrap();
        assert!(retriever.is_empty());
    }

    #[test]
    fn test_cosine_similarity_identical() {
        // A vector compared with itself has similarity 1 (up to rounding).
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        // Perpendicular vectors have similarity 0.
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-9);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        // Zero-magnitude input is defined to yield 0 rather than NaN.
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 0.0];
        assert!(cosine_similarity(&a, &b).abs() < f64::EPSILON);
    }

    #[test]
    fn test_embedding_retriever_is_retriever_trait_object() {
        // Can be used as Arc<dyn Retriever>
        let r = EmbeddingRetriever::new(CountingEmbedder::new(3));
        let _: Arc<dyn Retriever> = Arc::new(r);
    }
}
405}