Skip to main content

cognis_rag/vectorstore/
in_memory.rs

1//! In-process vector store. Linear scan for similarity search — fine
2//! for prototyping and small corpora (< ~10k docs); not for production
3//! at scale.
4
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use uuid::Uuid;
10
11use cognis_core::{CognisError, Result};
12
13use crate::distance::Distance;
14use crate::embeddings::Embeddings;
15
16use super::{SearchResult, VectorStore};
17
18#[derive(Clone)]
19struct StoredDoc {
20    id: String,
21    text: String,
22    vector: Vec<f32>,
23    metadata: HashMap<String, serde_json::Value>,
24}
25
26/// Linear-scan in-memory vector store.
27pub struct InMemoryVectorStore {
28    embedder: Arc<dyn Embeddings>,
29    distance: Distance,
30    docs: Vec<StoredDoc>,
31}
32
33impl InMemoryVectorStore {
34    /// New empty store with the given embedder and Cosine distance.
35    pub fn new(embedder: Arc<dyn Embeddings>) -> Self {
36        Self::with_distance(embedder, Distance::Cosine)
37    }
38
39    /// New empty store with explicit distance.
40    pub fn with_distance(embedder: Arc<dyn Embeddings>, distance: Distance) -> Self {
41        Self {
42            embedder,
43            distance,
44            docs: Vec::new(),
45        }
46    }
47
48    /// Currently configured distance metric.
49    pub fn distance(&self) -> Distance {
50        self.distance
51    }
52}
53
54#[async_trait]
55impl VectorStore for InMemoryVectorStore {
56    async fn add_texts(
57        &mut self,
58        texts: Vec<String>,
59        metadata: Option<Vec<HashMap<String, serde_json::Value>>>,
60    ) -> Result<Vec<String>> {
61        if texts.is_empty() {
62            return Ok(Vec::new());
63        }
64        let vectors = self.embedder.embed_documents(texts.clone()).await?;
65        self.add_vectors(vectors, texts, metadata).await
66    }
67
68    async fn add_vectors(
69        &mut self,
70        vectors: Vec<Vec<f32>>,
71        texts: Vec<String>,
72        metadata: Option<Vec<HashMap<String, serde_json::Value>>>,
73    ) -> Result<Vec<String>> {
74        if vectors.len() != texts.len() {
75            return Err(CognisError::Configuration(format!(
76                "vectors.len() ({}) must equal texts.len() ({})",
77                vectors.len(),
78                texts.len()
79            )));
80        }
81        if let Some(m) = &metadata {
82            if m.len() != texts.len() {
83                return Err(CognisError::Configuration(format!(
84                    "metadata.len() ({}) must equal texts.len() ({})",
85                    m.len(),
86                    texts.len()
87                )));
88            }
89        }
90
91        let mut ids = Vec::with_capacity(texts.len());
92        for (i, (text, vector)) in texts.into_iter().zip(vectors).enumerate() {
93            let id = Uuid::new_v4().to_string();
94            let md = metadata.as_ref().map(|m| m[i].clone()).unwrap_or_default();
95            ids.push(id.clone());
96            self.docs.push(StoredDoc {
97                id,
98                text,
99                vector,
100                metadata: md,
101            });
102        }
103        Ok(ids)
104    }
105
106    async fn similarity_search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
107        let qv = self.embedder.embed_query(query.to_string()).await?;
108        self.similarity_search_by_vector(qv, k).await
109    }
110
111    async fn similarity_search_by_vector(
112        &self,
113        query_vector: Vec<f32>,
114        k: usize,
115    ) -> Result<Vec<SearchResult>> {
116        if self.docs.is_empty() || k == 0 {
117            return Ok(Vec::new());
118        }
119
120        // Compute scores.
121        let mut scored: Vec<(f32, &StoredDoc)> = self
122            .docs
123            .iter()
124            .map(|d| (self.distance.similarity(&query_vector, &d.vector), d))
125            .collect();
126
127        // Sort descending by score.
128        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
129
130        Ok(scored
131            .into_iter()
132            .take(k)
133            .map(|(score, d)| SearchResult {
134                id: d.id.clone(),
135                text: d.text.clone(),
136                score,
137                metadata: d.metadata.clone(),
138            })
139            .collect())
140    }
141
142    async fn delete(&mut self, ids: Vec<String>) -> Result<()> {
143        let to_delete: std::collections::HashSet<String> = ids.into_iter().collect();
144        self.docs.retain(|d| !to_delete.contains(&d.id));
145        Ok(())
146    }
147
148    fn len(&self) -> usize {
149        self.docs.len()
150    }
151}
152
#[cfg(test)]
mod tests {
    //! Behavioral tests for [`InMemoryVectorStore`], driven by the
    //! deterministic `FakeEmbeddings` embedder.

    use super::*;
    use crate::embeddings::FakeEmbeddings;

    /// Builds a fake embedder producing vectors of dimension `dim`.
    fn fake_embedder(dim: usize) -> Arc<dyn Embeddings> {
        Arc::new(FakeEmbeddings::new(dim))
    }

    #[tokio::test]
    async fn add_texts_assigns_ids() {
        let mut store = InMemoryVectorStore::new(fake_embedder(8));
        let ids = store
            .add_texts(vec!["a".into(), "b".into(), "c".into()], None)
            .await
            .unwrap();
        assert_eq!(ids.len(), 3);
        assert_eq!(store.len(), 3);
        // IDs should be unique.
        let unique: std::collections::HashSet<_> = ids.iter().collect();
        assert_eq!(unique.len(), 3);
    }

    #[tokio::test]
    async fn search_returns_matches_in_order() {
        let mut store = InMemoryVectorStore::new(fake_embedder(8));
        store
            .add_texts(vec!["dog".into(), "cat".into(), "fish".into()], None)
            .await
            .unwrap();

        let results = store.similarity_search("dog", 2).await.unwrap();
        assert_eq!(results.len(), 2);
        // The exact-match query "dog" should be the top result (FakeEmbeddings
        // is deterministic, so the same input yields the same vector).
        assert_eq!(results[0].text, "dog");
    }

    #[tokio::test]
    async fn search_respects_k() {
        let mut store = InMemoryVectorStore::new(fake_embedder(8));
        store
            .add_texts((0..10).map(|i| format!("doc {i}")).collect(), None)
            .await
            .unwrap();
        let r1 = store.similarity_search("doc 5", 1).await.unwrap();
        let r5 = store.similarity_search("doc 5", 5).await.unwrap();
        assert_eq!(r1.len(), 1);
        assert_eq!(r5.len(), 5);
    }

    #[tokio::test]
    async fn metadata_roundtrip() {
        let mut store = InMemoryVectorStore::new(fake_embedder(8));
        let mut md = HashMap::new();
        md.insert("source".into(), serde_json::json!("wiki"));
        md.insert("year".into(), serde_json::json!(2024));
        store
            .add_texts(vec!["hello".into()], Some(vec![md.clone()]))
            .await
            .unwrap();
        let r = store.similarity_search("hello", 1).await.unwrap();
        assert_eq!(r[0].metadata.get("source").unwrap(), "wiki");
        assert_eq!(r[0].metadata.get("year").unwrap(), 2024);
    }

    // Two vectors for one text: the counts disagree, so nothing may be stored.
    #[tokio::test]
    async fn add_vectors_length_mismatch_errors() {
        let mut store = InMemoryVectorStore::new(fake_embedder(8));
        let err = store
            .add_vectors(vec![vec![0.1; 8], vec![0.2; 8]], vec!["one".into()], None)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("must equal"));
        assert_eq!(store.len(), 0);
    }

    // Vectors and texts agree, but metadata has one entry too many.
    #[tokio::test]
    async fn add_vectors_metadata_length_mismatch_errors() {
        let mut store = InMemoryVectorStore::new(fake_embedder(8));
        let err = store
            .add_vectors(
                vec![vec![0.1; 8]],
                vec!["one".into()],
                Some(vec![HashMap::new(), HashMap::new()]),
            )
            .await
            .unwrap_err();
        let message = err.to_string();
        assert!(message.contains("metadata.len()"));
        assert!(message.contains("must equal"));
        assert_eq!(store.len(), 0);
    }

    #[tokio::test]
    async fn delete_removes_docs() {
        let mut store = InMemoryVectorStore::new(fake_embedder(8));
        let ids = store
            .add_texts(vec!["a".into(), "b".into(), "c".into()], None)
            .await
            .unwrap();
        store.delete(vec![ids[1].clone()]).await.unwrap();
        assert_eq!(store.len(), 2);
        let r = store.similarity_search("b", 5).await.unwrap();
        // "b" has been deleted, so it shouldn't appear in results.
        assert!(!r.iter().any(|s| s.text == "b"));
    }

    #[tokio::test]
    async fn delete_unknown_ids_silent() {
        let mut store = InMemoryVectorStore::new(fake_embedder(8));
        store.add_texts(vec!["a".into()], None).await.unwrap();
        // No error even though the ID doesn't exist.
        store.delete(vec!["nonexistent".into()]).await.unwrap();
        assert_eq!(store.len(), 1);
    }

    #[tokio::test]
    async fn empty_store_search_returns_empty() {
        let store = InMemoryVectorStore::new(fake_embedder(8));
        let r = store.similarity_search("anything", 5).await.unwrap();
        assert!(r.is_empty());
    }

    #[tokio::test]
    async fn k_zero_returns_empty() {
        let mut store = InMemoryVectorStore::new(fake_embedder(8));
        store.add_texts(vec!["a".into()], None).await.unwrap();
        let r = store.similarity_search("a", 0).await.unwrap();
        assert!(r.is_empty());
    }
}
266}