Skip to main content

cognee_vector/
brute_force_vector_db.rs

1//! Pure-Rust in-memory brute-force `VectorDB` implementation.
2//!
3//! Linear-scan O(n) similarity search over all stored vectors. Used as:
4//! - the Android default (LanceDB + Arrow do not cross-compile cleanly there);
5//! - the `vector_db_url = ":memory:"` escape hatch on every target (handy
6//!   for tests and ephemeral cognify runs);
7//! - a test fixture in lieu of the testing-feature-gated `MockVectorDB`.
8//!
9//! **No persistence.** Data is lost on process restart. For durable storage
10//! prefer the default `LanceDbAdapter` (on non-Android targets) or
11//! `vector_db_provider="pgvector"`.
12//!
13//! **Memory:** O(n × dim). At ~6 GB for 1M × 1536-dim, this is a
14//! soft cap — beyond that, pgvector (or the closed `cognee-vector-qdrant`)
15//! is the correct choice.
16//!
17//! **Distance metric:** every collection uses cosine similarity
18//! (higher = more similar). The `VectorDB` trait's
19//! `create_collection(data_type, field_name, dimension)` does not carry
20//! a `DistanceMetric`; per-collection metric plumbing is beyond T5's
21//! scope.
22
23use std::collections::HashMap;
24use std::sync::Arc;
25
26use async_trait::async_trait;
27use tokio::sync::RwLock;
28use uuid::Uuid;
29
30use crate::error::{VectorDBError, VectorDBResult};
31use crate::models::{SearchResult, VectorPoint};
32use crate::vector_db_trait::VectorDB;
33
34#[derive(Debug)]
35struct Collection {
36    dimension: usize,
37    points: Vec<VectorPoint>,
38}
39
40/// In-memory brute-force vector database.
41///
42/// All collections are held in a single `tokio::sync::RwLock`. Cloning
43/// the struct shares the same underlying storage (`Arc`-backed), so it
44/// is safe to hand out to multiple async tasks.
45#[derive(Debug, Clone, Default)]
46pub struct BruteForceVectorDB {
47    collections: Arc<RwLock<HashMap<String, Collection>>>,
48}
49
50impl BruteForceVectorDB {
51    /// Construct an empty in-memory vector database.
52    pub fn new() -> Self {
53        Self {
54            collections: Arc::new(RwLock::new(HashMap::new())),
55        }
56    }
57
58    /// Mirror `MockVectorDB::collection_key`: `"{data_type}_{field_name}"`.
59    fn key(data_type: &str, field_name: &str) -> String {
60        format!("{data_type}_{field_name}")
61    }
62}
63
/// Cosine similarity of `a` and `b`, nominally in `[-1.0, 1.0]`.
/// Higher = more similar.
///
/// The denominator is clamped to at least `f32::EPSILON`, so a
/// zero-magnitude input yields `0.0` instead of dividing by zero. NaN is
/// deliberately not special-cased (matches `MockVectorDB`).
///
/// Both slices are expected to have the same length; this is checked with
/// `debug_assert_eq!` only. Callers validate dimensions before scoring.
///
/// # Panics
///
/// Panics if `b` is shorter than `a` (and, in debug builds, whenever the
/// lengths differ).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "cosine_similarity inputs must match");

    // One up-front bounds check instead of two per element inside the loop;
    // the pairwise iteration below can then run without indexing.
    let b = &b[..a.len()];

    let mut dot = 0.0f32;
    let mut norm_a_sq = 0.0f32;
    let mut norm_b_sq = 0.0f32;
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }

    dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt()).max(f32::EPSILON)
}
81
82#[async_trait]
83impl VectorDB for BruteForceVectorDB {
84    async fn create_collection(
85        &self,
86        data_type: &str,
87        field_name: &str,
88        dimension: usize,
89    ) -> VectorDBResult<()> {
90        let key = Self::key(data_type, field_name);
91        let mut g = self.collections.write().await;
92        if g.contains_key(&key) {
93            return Err(VectorDBError::CollectionExists(key));
94        }
95        g.insert(
96            key,
97            Collection {
98                dimension,
99                points: Vec::new(),
100            },
101        );
102        Ok(())
103    }
104
105    async fn has_collection(&self, data_type: &str, field_name: &str) -> VectorDBResult<bool> {
106        let g = self.collections.read().await;
107        Ok(g.contains_key(&Self::key(data_type, field_name)))
108    }
109
110    async fn index_points(
111        &self,
112        data_type: &str,
113        field_name: &str,
114        points: &[VectorPoint],
115    ) -> VectorDBResult<()> {
116        if points.is_empty() {
117            return Ok(());
118        }
119        let key = Self::key(data_type, field_name);
120        let mut g = self.collections.write().await;
121        let coll = g
122            .get_mut(&key)
123            .ok_or_else(|| VectorDBError::CollectionNotFound(key.clone()))?;
124
125        // Validate dimensions before mutating storage.
126        for p in points {
127            if p.vector.len() != coll.dimension {
128                return Err(VectorDBError::DimensionMismatch {
129                    collection: key.clone(),
130                    expected: coll.dimension,
131                    actual: p.vector.len(),
132                });
133            }
134        }
135
136        // Upsert by id: replace existing, otherwise append. On replace, union
137        // dataset membership so a content-addressed point indexed under several
138        // datasets stays retrievable for all of them (cross-dataset dedup).
139        for p in points {
140            if let Some(existing) = coll.points.iter_mut().find(|x| x.id == p.id) {
141                let mut merged = p.clone();
142                merged.merge_dataset_membership(existing);
143                *existing = merged;
144            } else {
145                coll.points.push(p.clone());
146            }
147        }
148        Ok(())
149    }
150
151    async fn search_similar(
152        &self,
153        data_type: &str,
154        field_name: &str,
155        query_vector: &[f32],
156        top_k: usize,
157    ) -> VectorDBResult<Vec<SearchResult>> {
158        let key = Self::key(data_type, field_name);
159        // Score under the read guard, then drop it before sorting + result
160        // construction so we never hold the lock across the (synchronous,
161        // but still post-await) sort/truncate step.
162        let mut scored: Vec<(Uuid, f32, HashMap<String, serde_json::Value>)> = {
163            let g = self.collections.read().await;
164            let coll = g
165                .get(&key)
166                .ok_or_else(|| VectorDBError::CollectionNotFound(key.clone()))?;
167            if query_vector.len() != coll.dimension {
168                return Err(VectorDBError::DimensionMismatch {
169                    collection: key.clone(),
170                    expected: coll.dimension,
171                    actual: query_vector.len(),
172                });
173            }
174            coll.points
175                .iter()
176                .map(|p| {
177                    (
178                        p.id,
179                        cosine_similarity(&p.vector, query_vector),
180                        p.metadata.clone(),
181                    )
182                })
183                .collect()
184        };
185
186        // Higher score first (descending). `total_cmp` orders NaN deterministically.
187        scored.sort_by(|a, b| b.1.total_cmp(&a.1));
188        scored.truncate(top_k);
189        Ok(scored
190            .into_iter()
191            .map(|(id, score, metadata)| SearchResult {
192                id,
193                score,
194                metadata,
195            })
196            .collect())
197    }
198
199    async fn delete_collection(&self, data_type: &str, field_name: &str) -> VectorDBResult<()> {
200        let mut g = self.collections.write().await;
201        g.remove(&Self::key(data_type, field_name));
202        Ok(())
203    }
204
205    async fn delete_points(
206        &self,
207        data_type: &str,
208        field_name: &str,
209        point_ids: &[Uuid],
210    ) -> VectorDBResult<()> {
211        let key = Self::key(data_type, field_name);
212        let mut g = self.collections.write().await;
213        let coll = g
214            .get_mut(&key)
215            .ok_or_else(|| VectorDBError::CollectionNotFound(key.clone()))?;
216        coll.points.retain(|p| !point_ids.contains(&p.id));
217        Ok(())
218    }
219
220    async fn collection_size(&self, data_type: &str, field_name: &str) -> VectorDBResult<usize> {
221        let key = Self::key(data_type, field_name);
222        let g = self.collections.read().await;
223        let coll = g
224            .get(&key)
225            .ok_or_else(|| VectorDBError::CollectionNotFound(key.clone()))?;
226        Ok(coll.points.len())
227    }
228
229    async fn list_collections(&self) -> VectorDBResult<Vec<(String, String)>> {
230        let g = self.collections.read().await;
231        let mut out: Vec<(String, String)> = g
232            .keys()
233            .filter_map(|k| {
234                k.split_once('_')
235                    .map(|(a, b)| (a.to_string(), b.to_string()))
236            })
237            .collect();
238        out.sort();
239        Ok(out)
240    }
241}
242
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable"
)]
mod tests {
    use super::*;
    use std::collections::HashMap as Hm;
    use uuid::Uuid;

    /// Builds a metadata-free point whose id is derived from `id_seed`.
    fn point(id_seed: u128, v: Vec<f32>) -> VectorPoint {
        VectorPoint {
            id: Uuid::from_u128(id_seed),
            vector: v,
            metadata: Hm::new(),
        }
    }

    /// A collection is reported only after it has been created.
    #[tokio::test]
    async fn create_then_has_collection() {
        let db = BruteForceVectorDB::new();
        assert!(!db.has_collection("T", "f").await.unwrap());
        db.create_collection("T", "f", 4).await.unwrap();
        assert!(db.has_collection("T", "f").await.unwrap());
    }

    /// Creating the same collection twice fails with the storage key.
    #[tokio::test]
    async fn create_duplicate_returns_exists() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 4).await.unwrap();
        let err = db.create_collection("T", "f", 4).await.unwrap_err();
        assert!(
            matches!(err, VectorDBError::CollectionExists(ref k) if k == "T_f"),
            "expected CollectionExists, got {err:?}",
        );
    }

    /// Indexing a vector of the wrong length is rejected.
    #[tokio::test]
    async fn index_dim_mismatch_returns_error() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 3).await.unwrap();
        let too_short = point(1, vec![1.0, 2.0]); // dim 2, expected 3
        let err = db.index_points("T", "f", &[too_short]).await.unwrap_err();
        assert!(
            matches!(
                err,
                VectorDBError::DimensionMismatch {
                    expected: 3,
                    actual: 2,
                    ..
                }
            ),
            "expected DimensionMismatch 3 vs 2, got {err:?}",
        );
    }

    /// Re-indexing an existing id replaces the stored vector in place.
    #[tokio::test]
    async fn index_replaces_by_id() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 2).await.unwrap();
        let first = point(1, vec![1.0, 0.0]);
        let second = point(1, vec![0.0, 1.0]); // same id, new vector
        db.index_points("T", "f", &[first]).await.unwrap();
        db.index_points("T", "f", &[second]).await.unwrap();
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 1);

        // Query for [0,1] → the upserted vector should score 1.0.
        let results = db.search_similar("T", "f", &[0.0, 1.0], 1).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(
            (results[0].score - 1.0).abs() < 1e-5,
            "upserted vector should score 1.0, got {}",
            results[0].score
        );
    }

    /// Results come back ordered from most to least similar.
    #[tokio::test]
    async fn search_ranks_descending() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 3).await.unwrap();
        let a = point(1, vec![1.0, 0.0, 0.0]);
        let b = point(2, vec![0.0, 1.0, 0.0]);
        let c = point(3, vec![0.0, 0.0, 1.0]);
        db.index_points("T", "f", &[a, b, c]).await.unwrap();

        let results = db
            .search_similar("T", "f", &[1.0, 0.0, 0.0], 3)
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, Uuid::from_u128(1), "A should rank first");
        // Tail order is implementation-defined for ties (both 0.0) under
        // total_cmp; just assert descending and that A wins.
        assert!(results[0].score >= results[1].score);
        assert!(results[1].score >= results[2].score);
        assert!(
            (results[0].score - 1.0).abs() < 1e-5,
            "self-similarity should be ~1.0, got {}",
            results[0].score
        );
    }

    /// Searching an existing but empty collection yields no results.
    #[tokio::test]
    async fn search_empty_collection_returns_empty() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 3).await.unwrap();
        let results = db
            .search_similar("T", "f", &[1.0, 0.0, 0.0], 5)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    /// A query vector of the wrong length is rejected.
    #[tokio::test]
    async fn search_query_dim_mismatch_returns_error() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 3).await.unwrap();
        let err = db
            .search_similar("T", "f", &[1.0, 0.0], 5)
            .await
            .unwrap_err();
        assert!(
            matches!(
                err,
                VectorDBError::DimensionMismatch {
                    expected: 3,
                    actual: 2,
                    ..
                }
            ),
            "expected DimensionMismatch, got {err:?}",
        );
    }

    /// Deleting by id removes exactly the listed points.
    #[tokio::test]
    async fn delete_points_removes_matching_ids() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 2).await.unwrap();
        let a = point(1, vec![1.0, 0.0]);
        let b = point(2, vec![0.0, 1.0]);
        let c = point(3, vec![1.0, 1.0]);
        db.index_points("T", "f", &[a, b, c]).await.unwrap();
        db.delete_points("T", "f", &[Uuid::from_u128(1), Uuid::from_u128(3)])
            .await
            .unwrap();
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 1);

        // The count alone would also pass if the wrong points were removed;
        // confirm the survivor is the one id that was not deleted.
        let survivors = db.search_similar("T", "f", &[0.0, 1.0], 3).await.unwrap();
        assert_eq!(survivors.len(), 1);
        assert_eq!(survivors[0].id, Uuid::from_u128(2));
    }

    /// Deleting a collection twice succeeds both times.
    #[tokio::test]
    async fn delete_collection_is_idempotent() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 2).await.unwrap();
        db.delete_collection("T", "f").await.unwrap();
        // Deleting again should not error.
        db.delete_collection("T", "f").await.unwrap();
        assert!(!db.has_collection("T", "f").await.unwrap());
    }

    /// Listing returns one `(data_type, field_name)` pair per collection.
    #[tokio::test]
    async fn list_collections_returns_pairs() {
        let db = BruteForceVectorDB::new();
        let empty = db.list_collections().await.unwrap();
        assert!(empty.is_empty());

        db.create_collection("DocumentChunk", "text", 3)
            .await
            .unwrap();
        db.create_collection("Entity", "name", 3).await.unwrap();

        let pairs = db.list_collections().await.unwrap();
        assert_eq!(pairs.len(), 2);
        assert!(pairs.contains(&("DocumentChunk".to_string(), "text".to_string())));
        assert!(pairs.contains(&("Entity".to_string(), "name".to_string())));
    }

    /// Size grows on insert and stays put when an existing id is re-indexed.
    #[tokio::test]
    async fn collection_size_after_upsert() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 2).await.unwrap();
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 0);
        db.index_points(
            "T",
            "f",
            &[point(1, vec![1.0, 0.0]), point(2, vec![0.0, 1.0])],
        )
        .await
        .unwrap();
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 2);
        // Re-upsert same id 1; size stays 2.
        db.index_points("T", "f", &[point(1, vec![0.5, 0.5])])
            .await
            .unwrap();
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 2);
    }

    /// Asking for the size of a missing collection is an error.
    #[tokio::test]
    async fn collection_size_unknown_collection_errors() {
        let db = BruteForceVectorDB::new();
        let err = db.collection_size("T", "f").await.unwrap_err();
        assert!(
            matches!(err, VectorDBError::CollectionNotFound(_)),
            "expected CollectionNotFound, got {err:?}",
        );
    }
}