// cognee_vector/brute_force_vector_db.rs
//! Pure-Rust, in-memory, brute-force `VectorDB` implementation.
//!
//! Similarity search is a linear O(n) scan over every stored vector. Used as:
//! - the Android default (LanceDB + Arrow do not cross-compile cleanly there);
//! - the `vector_db_url = ":memory:"` escape hatch on every target (handy
//!   for tests and ephemeral cognify runs);
//! - a test fixture in place of `MockVectorDB`, which is only available
//!   behind the testing feature.
//!
//! **No persistence.** All data is lost when the process exits. For durable
//! storage, prefer the default `LanceDbAdapter` (on non-Android targets) or
//! `vector_db_provider = "pgvector"`.
//!
//! **Memory:** O(n × dim). The raw `f32` vectors alone take ~6 GB for
//! 1M × 1536-dim (4 bytes per component), which is a practical soft cap;
//! beyond that, pgvector (or the closed `cognee-vector-qdrant`) is the
//! correct choice.
//!
//! **Distance metric:** every collection uses cosine similarity
//! (higher = more similar). The `VectorDB` trait's
//! `create_collection(data_type, field_name, dimension)` does not carry
//! a `DistanceMetric`, and per-collection metric plumbing is beyond the
//! scope of task T5.
23use std::collections::HashMap;
24use std::sync::Arc;
25
26use async_trait::async_trait;
27use tokio::sync::RwLock;
28use uuid::Uuid;
29
30use crate::error::{VectorDBError, VectorDBResult};
31use crate::models::{SearchResult, VectorPoint};
32use crate::vector_db_trait::VectorDB;
33
34#[derive(Debug)]
35struct Collection {
36    dimension: usize,
37    points: Vec<VectorPoint>,
38}
39
40/// In-memory brute-force vector database.
41///
42/// All collections are held in a single `tokio::sync::RwLock`. Cloning
43/// the struct shares the same underlying storage (`Arc`-backed), so it
44/// is safe to hand out to multiple async tasks.
45#[derive(Debug, Clone, Default)]
46pub struct BruteForceVectorDB {
47    collections: Arc<RwLock<HashMap<String, Collection>>>,
48}
49
50impl BruteForceVectorDB {
51    /// Construct an empty in-memory vector database.
52    pub fn new() -> Self {
53        Self {
54            collections: Arc::new(RwLock::new(HashMap::new())),
55        }
56    }
57
58    /// Mirror `MockVectorDB::collection_key`: `"{data_type}_{field_name}"`.
59    fn key(data_type: &str, field_name: &str) -> String {
60        format!("{data_type}_{field_name}")
61    }
62}
63
/// Cosine similarity of `a` and `b`, clamped to `[-1.0, 1.0]`. Higher = more similar.
///
/// Returns `0.0` when either input has zero magnitude (the angle is
/// undefined there). Only an exact zero denominator is special-cased, so the
/// score is scale-invariant. Vectors with small but non-zero norms are scored
/// correctly; a floor such as `f32::EPSILON` (~1.2e-7) would silently shrink
/// them towards zero.
///
/// NaN components are deliberately not special-cased: they propagate to a
/// NaN score (matches `MockVectorDB`). The clamp only absorbs rounding that
/// would otherwise push identical or opposite vectors slightly past ±1.
///
/// Inputs must have equal length. This is checked in debug builds only; in
/// release builds the tail of the longer slice is ignored.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "cosine_similarity inputs must match");
    // Single pass, same left-to-right accumulation order as an index loop,
    // but free of per-element bounds checks.
    let (dot, na, nb) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(dot, na, nb), (&x, &y)| {
            (dot + x * y, na + x * x, nb + y * y)
        });
    let denom = na.sqrt() * nb.sqrt();
    // `NaN == 0.0` is false, so NaN inputs fall through and yield NaN.
    if denom == 0.0 {
        return 0.0;
    }
    (dot / denom).clamp(-1.0, 1.0)
}
81
82#[async_trait]
83impl VectorDB for BruteForceVectorDB {
84    async fn create_collection(
85        &self,
86        data_type: &str,
87        field_name: &str,
88        dimension: usize,
89    ) -> VectorDBResult<()> {
90        let key = Self::key(data_type, field_name);
91        let mut g = self.collections.write().await;
92        if g.contains_key(&key) {
93            return Err(VectorDBError::CollectionExists(key));
94        }
95        g.insert(
96            key,
97            Collection {
98                dimension,
99                points: Vec::new(),
100            },
101        );
102        Ok(())
103    }
104
105    async fn has_collection(&self, data_type: &str, field_name: &str) -> VectorDBResult<bool> {
106        let g = self.collections.read().await;
107        Ok(g.contains_key(&Self::key(data_type, field_name)))
108    }
109
110    async fn index_points(
111        &self,
112        data_type: &str,
113        field_name: &str,
114        points: &[VectorPoint],
115    ) -> VectorDBResult<()> {
116        if points.is_empty() {
117            return Ok(());
118        }
119        let key = Self::key(data_type, field_name);
120        let mut g = self.collections.write().await;
121        let coll = g
122            .get_mut(&key)
123            .ok_or_else(|| VectorDBError::CollectionNotFound(key.clone()))?;
124
125        // Validate dimensions before mutating storage.
126        for p in points {
127            if p.vector.len() != coll.dimension {
128                return Err(VectorDBError::DimensionMismatch {
129                    collection: key.clone(),
130                    expected: coll.dimension,
131                    actual: p.vector.len(),
132                });
133            }
134        }
135
136        // Upsert by id: replace existing, otherwise append.
137        for p in points {
138            if let Some(existing) = coll.points.iter_mut().find(|x| x.id == p.id) {
139                *existing = p.clone();
140            } else {
141                coll.points.push(p.clone());
142            }
143        }
144        Ok(())
145    }
146
147    async fn search_similar(
148        &self,
149        data_type: &str,
150        field_name: &str,
151        query_vector: &[f32],
152        top_k: usize,
153    ) -> VectorDBResult<Vec<SearchResult>> {
154        let key = Self::key(data_type, field_name);
155        // Score under the read guard, then drop it before sorting + result
156        // construction so we never hold the lock across the (synchronous,
157        // but still post-await) sort/truncate step.
158        let mut scored: Vec<(Uuid, f32, HashMap<String, serde_json::Value>)> = {
159            let g = self.collections.read().await;
160            let coll = g
161                .get(&key)
162                .ok_or_else(|| VectorDBError::CollectionNotFound(key.clone()))?;
163            if query_vector.len() != coll.dimension {
164                return Err(VectorDBError::DimensionMismatch {
165                    collection: key.clone(),
166                    expected: coll.dimension,
167                    actual: query_vector.len(),
168                });
169            }
170            coll.points
171                .iter()
172                .map(|p| {
173                    (
174                        p.id,
175                        cosine_similarity(&p.vector, query_vector),
176                        p.metadata.clone(),
177                    )
178                })
179                .collect()
180        };
181
182        // Higher score first (descending). `total_cmp` orders NaN deterministically.
183        scored.sort_by(|a, b| b.1.total_cmp(&a.1));
184        scored.truncate(top_k);
185        Ok(scored
186            .into_iter()
187            .map(|(id, score, metadata)| SearchResult {
188                id,
189                score,
190                metadata,
191            })
192            .collect())
193    }
194
195    async fn delete_collection(&self, data_type: &str, field_name: &str) -> VectorDBResult<()> {
196        let mut g = self.collections.write().await;
197        g.remove(&Self::key(data_type, field_name));
198        Ok(())
199    }
200
201    async fn delete_points(
202        &self,
203        data_type: &str,
204        field_name: &str,
205        point_ids: &[Uuid],
206    ) -> VectorDBResult<()> {
207        let key = Self::key(data_type, field_name);
208        let mut g = self.collections.write().await;
209        let coll = g
210            .get_mut(&key)
211            .ok_or_else(|| VectorDBError::CollectionNotFound(key.clone()))?;
212        coll.points.retain(|p| !point_ids.contains(&p.id));
213        Ok(())
214    }
215
216    async fn collection_size(&self, data_type: &str, field_name: &str) -> VectorDBResult<usize> {
217        let key = Self::key(data_type, field_name);
218        let g = self.collections.read().await;
219        let coll = g
220            .get(&key)
221            .ok_or_else(|| VectorDBError::CollectionNotFound(key.clone()))?;
222        Ok(coll.points.len())
223    }
224
225    async fn list_collections(&self) -> VectorDBResult<Vec<(String, String)>> {
226        let g = self.collections.read().await;
227        let mut out: Vec<(String, String)> = g
228            .keys()
229            .filter_map(|k| {
230                k.split_once('_')
231                    .map(|(a, b)| (a.to_string(), b.to_string()))
232            })
233            .collect();
234        out.sort();
235        Ok(out)
236    }
237}
238
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable"
)]
mod tests {
    use super::*;
    use std::collections::HashMap as Hm;
    use uuid::Uuid;

    /// Metadata-free point whose id is `Uuid::from_u128(id_seed)`.
    fn point(id_seed: u128, v: Vec<f32>) -> VectorPoint {
        VectorPoint {
            id: Uuid::from_u128(id_seed),
            vector: v,
            metadata: Hm::new(),
        }
    }

    /// Fresh database with collection `("T", "f")` of width `dim`, pre-loaded
    /// with `points`.
    async fn seeded(dim: usize, points: &[VectorPoint]) -> BruteForceVectorDB {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", dim).await.unwrap();
        db.index_points("T", "f", points).await.unwrap();
        db
    }

    #[tokio::test]
    async fn create_then_has_collection() {
        let db = BruteForceVectorDB::new();
        assert!(!db.has_collection("T", "f").await.unwrap());
        db.create_collection("T", "f", 4).await.unwrap();
        assert!(db.has_collection("T", "f").await.unwrap());
    }

    #[tokio::test]
    async fn create_duplicate_returns_exists() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 4).await.unwrap();
        let err = db.create_collection("T", "f", 4).await.unwrap_err();
        assert!(
            matches!(err, VectorDBError::CollectionExists(ref k) if k == "T_f"),
            "expected CollectionExists, got {err:?}",
        );
    }

    #[tokio::test]
    async fn index_dim_mismatch_returns_error() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 3).await.unwrap();
        let p = point(1, vec![1.0, 2.0]); // dim 2, expected 3
        let err = db.index_points("T", "f", &[p]).await.unwrap_err();
        assert!(
            matches!(
                err,
                VectorDBError::DimensionMismatch {
                    expected: 3,
                    actual: 2,
                    ..
                }
            ),
            "expected DimensionMismatch 3 vs 2, got {err:?}",
        );
    }

    #[tokio::test]
    async fn index_dim_mismatch_leaves_collection_untouched() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 2).await.unwrap();
        // First point is valid, second is not: the batch must be rejected whole.
        let batch = [point(1, vec![1.0, 0.0]), point(2, vec![1.0, 0.0, 0.0])];
        let err = db.index_points("T", "f", &batch).await.unwrap_err();
        assert!(
            matches!(err, VectorDBError::DimensionMismatch { .. }),
            "expected DimensionMismatch, got {err:?}",
        );
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 0);
    }

    #[tokio::test]
    async fn index_unknown_collection_errors() {
        let db = BruteForceVectorDB::new();
        let err = db
            .index_points("T", "f", &[point(1, vec![1.0, 0.0])])
            .await
            .unwrap_err();
        assert!(
            matches!(err, VectorDBError::CollectionNotFound(ref k) if k == "T_f"),
            "expected CollectionNotFound, got {err:?}",
        );
    }

    #[tokio::test]
    async fn index_replaces_by_id() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 2).await.unwrap();
        let p_v1 = point(1, vec![1.0, 0.0]);
        let p_v2 = point(1, vec![0.0, 1.0]); // same id, new vector
        db.index_points("T", "f", &[p_v1]).await.unwrap();
        db.index_points("T", "f", &[p_v2]).await.unwrap();
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 1);

        // Query for [0,1] → the upserted vector should score 1.0.
        let results = db.search_similar("T", "f", &[0.0, 1.0], 1).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(
            (results[0].score - 1.0).abs() < 1e-5,
            "upserted vector should score 1.0, got {}",
            results[0].score
        );
    }

    #[tokio::test]
    async fn index_batch_with_duplicate_ids_keeps_last() {
        let db = seeded(2, &[point(1, vec![1.0, 0.0]), point(1, vec![0.0, 1.0])]).await;
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 1);
        let results = db.search_similar("T", "f", &[0.0, 1.0], 1).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(
            (results[0].score - 1.0).abs() < 1e-5,
            "later duplicate should win, got score {}",
            results[0].score
        );
    }

    #[tokio::test]
    async fn search_ranks_descending() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 3).await.unwrap();
        let a = point(1, vec![1.0, 0.0, 0.0]);
        let b = point(2, vec![0.0, 1.0, 0.0]);
        let c = point(3, vec![0.0, 0.0, 1.0]);
        db.index_points("T", "f", &[a, b, c]).await.unwrap();

        let results = db
            .search_similar("T", "f", &[1.0, 0.0, 0.0], 3)
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, Uuid::from_u128(1), "A should rank first");
        // Tail order for ties (both 0.0) is not asserted; just check the
        // ranking is descending and that A wins.
        assert!(results[0].score >= results[1].score);
        assert!(results[1].score >= results[2].score);
        assert!(
            (results[0].score - 1.0).abs() < 1e-5,
            "self-similarity should be ~1.0, got {}",
            results[0].score
        );
    }

    #[tokio::test]
    async fn search_top_k_keeps_best_matches_in_order() {
        let db = seeded(
            2,
            &[
                point(1, vec![0.0, 1.0]),  // 0.0
                point(2, vec![1.0, 1.0]),  // ~0.707
                point(3, vec![1.0, 0.0]),  // 1.0
                point(4, vec![-1.0, 0.0]), // -1.0
                point(5, vec![1.0, 0.1]),  // ~0.995
            ],
        )
        .await;
        let results = db.search_similar("T", "f", &[1.0, 0.0], 2).await.unwrap();
        let ids: Vec<Uuid> = results.iter().map(|r| r.id).collect();
        assert_eq!(ids, [Uuid::from_u128(3), Uuid::from_u128(5)]);
    }

    #[tokio::test]
    async fn search_top_k_zero_returns_empty() {
        let db = seeded(2, &[point(1, vec![1.0, 0.0]), point(2, vec![0.0, 1.0])]).await;
        let results = db.search_similar("T", "f", &[1.0, 0.0], 0).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn search_top_k_larger_than_collection_returns_all() {
        let db = seeded(2, &[point(1, vec![0.0, 1.0]), point(2, vec![1.0, 0.0])]).await;
        let results = db.search_similar("T", "f", &[1.0, 0.0], 10).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, Uuid::from_u128(2));
        assert_eq!(results[1].id, Uuid::from_u128(1));
    }

    #[tokio::test]
    async fn search_returns_stored_metadata() {
        let mut p = point(7, vec![1.0, 0.0]);
        p.metadata.insert(
            "text".to_string(),
            serde_json::Value::String("hello".to_string()),
        );
        let db = seeded(2, &[p]).await;
        let results = db.search_similar("T", "f", &[1.0, 0.0], 1).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, Uuid::from_u128(7));
        assert_eq!(
            results[0].metadata.get("text"),
            Some(&serde_json::Value::String("hello".to_string()))
        );
    }

    #[tokio::test]
    async fn search_empty_collection_returns_empty() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 3).await.unwrap();
        let results = db
            .search_similar("T", "f", &[1.0, 0.0, 0.0], 5)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn search_query_dim_mismatch_returns_error() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 3).await.unwrap();
        let err = db
            .search_similar("T", "f", &[1.0, 0.0], 5)
            .await
            .unwrap_err();
        assert!(
            matches!(
                err,
                VectorDBError::DimensionMismatch {
                    expected: 3,
                    actual: 2,
                    ..
                }
            ),
            "expected DimensionMismatch, got {err:?}",
        );
    }

    #[tokio::test]
    async fn search_and_delete_points_unknown_collection_error() {
        let db = BruteForceVectorDB::new();
        let err = db.search_similar("T", "f", &[1.0], 1).await.unwrap_err();
        assert!(
            matches!(err, VectorDBError::CollectionNotFound(_)),
            "expected CollectionNotFound from search, got {err:?}",
        );
        let err = db
            .delete_points("T", "f", &[Uuid::from_u128(1)])
            .await
            .unwrap_err();
        assert!(
            matches!(err, VectorDBError::CollectionNotFound(_)),
            "expected CollectionNotFound from delete_points, got {err:?}",
        );
    }

    #[tokio::test]
    async fn delete_points_removes_matching_ids() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 2).await.unwrap();
        let a = point(1, vec![1.0, 0.0]);
        let b = point(2, vec![0.0, 1.0]);
        let c = point(3, vec![1.0, 1.0]);
        db.index_points("T", "f", &[a, b, c]).await.unwrap();
        db.delete_points("T", "f", &[Uuid::from_u128(1), Uuid::from_u128(3)])
            .await
            .unwrap();
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 1);
    }

    #[tokio::test]
    async fn delete_collection_is_idempotent() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 2).await.unwrap();
        db.delete_collection("T", "f").await.unwrap();
        // Deleting again should not error.
        db.delete_collection("T", "f").await.unwrap();
        assert!(!db.has_collection("T", "f").await.unwrap());
    }

    #[tokio::test]
    async fn list_collections_returns_pairs() {
        let db = BruteForceVectorDB::new();
        let empty = db.list_collections().await.unwrap();
        assert!(empty.is_empty());

        db.create_collection("DocumentChunk", "text", 3)
            .await
            .unwrap();
        db.create_collection("Entity", "name", 3).await.unwrap();

        let pairs = db.list_collections().await.unwrap();
        assert_eq!(pairs.len(), 2);
        assert!(pairs.contains(&("DocumentChunk".to_string(), "text".to_string())));
        assert!(pairs.contains(&("Entity".to_string(), "name".to_string())));
    }

    #[tokio::test]
    async fn collection_size_after_upsert() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 2).await.unwrap();
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 0);
        db.index_points(
            "T",
            "f",
            &[point(1, vec![1.0, 0.0]), point(2, vec![0.0, 1.0])],
        )
        .await
        .unwrap();
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 2);
        // Re-upsert same id 1; size stays 2.
        db.index_points("T", "f", &[point(1, vec![0.5, 0.5])])
            .await
            .unwrap();
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 2);
    }

    #[tokio::test]
    async fn collection_size_unknown_collection_errors() {
        let db = BruteForceVectorDB::new();
        let err = db.collection_size("T", "f").await.unwrap_err();
        assert!(
            matches!(err, VectorDBError::CollectionNotFound(_)),
            "expected CollectionNotFound, got {err:?}",
        );
    }

    #[tokio::test]
    async fn clones_share_storage() {
        let db = BruteForceVectorDB::new();
        let handle = db.clone();
        handle.create_collection("T", "f", 2).await.unwrap();
        assert!(db.has_collection("T", "f").await.unwrap());
    }
}