Skip to main content

cognee_vector/
lancedb_adapter.rs

1//! Embedded LanceDB vector store — the OSS default on non-Android targets.
2//!
3//! Each `(data_type, field_name)` pair maps to one LanceDB table named
4//! `"{data_type}_{field_name}"` (matching `BruteForceVectorDB`'s naming so
5//! a backend switch keeps existing on-disk data discoverable). Tables hold
6//! three columns:
7//!
8//! | column     | Arrow type                        | semantics                |
9//! |------------|-----------------------------------|--------------------------|
10//! | `id`       | `FixedSizeBinary(16)`             | UUID bytes (primary key) |
11//! | `vector`   | `FixedSizeList<Float32, dim>`     | embedding                |
12//! | `metadata` | `Utf8`                            | JSON blob                |
13//!
14//! Persistence: the LanceDB `connect(uri).execute()` call points at a
15//! filesystem directory (defaults to `{system_root_directory}/databases/cognee.lancedb`,
16//! matching the Python SDK file layout — Python parity is intentional).
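//! Each write goes through LanceDB's transactional writer, so a crash mid-write
//! doesn't corrupt prior versions. Note that an upsert (`index_points`) is a
//! delete followed by an append — two separate commits — so a failure between
//! them can remove the old rows without writing their replacements.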
19
20use std::collections::HashMap;
21use std::path::PathBuf;
22use std::sync::Arc;
23
24use arrow_array::{
25    Array, FixedSizeBinaryArray, FixedSizeListArray, Float32Array, RecordBatch, StringArray,
26    types::Float32Type,
27};
28use arrow_schema::{DataType, Field, Schema, SchemaRef};
29use async_trait::async_trait;
30use futures::TryStreamExt;
31use lancedb::{
32    DistanceType, connect,
33    connection::Connection,
34    query::{ExecutableQuery, QueryBase},
35};
36use tokio::sync::RwLock;
37use uuid::Uuid;
38
39use crate::error::{VectorDBError, VectorDBResult};
40use crate::models::{SearchResult, VectorPoint};
41use crate::vector_db_trait::VectorDB;
42
/// Table name for a `(data_type, field_name)` collection: the two parts joined
/// by a single underscore, e.g. `("Chunk", "text")` → `"Chunk_text"`.
fn collection_name(data_type: &str, field_name: &str) -> String {
    [data_type, field_name].join("_")
}
46
47fn map_lance_err(e: lancedb::Error) -> VectorDBError {
48    VectorDBError::StorageError(format!("lancedb: {e}"))
49}
50
51/// Dimension of a `FixedSizeList<Float32, _>` field, or `None` if it's some
52/// other type. Used when opening a pre-existing table to recover the dim.
53fn dimension_from_schema(schema: &SchemaRef) -> Option<usize> {
54    schema.field_with_name("vector").ok().and_then(|f| {
55        if let DataType::FixedSizeList(_, dim) = f.data_type() {
56            usize::try_from(*dim).ok()
57        } else {
58            None
59        }
60    })
61}
62
63fn build_schema(dimension: usize) -> SchemaRef {
64    let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
65    Arc::new(Schema::new(vec![
66        Field::new("id", DataType::FixedSizeBinary(16), false),
67        Field::new(
68            "vector",
69            DataType::FixedSizeList(vector_field, dimension as i32),
70            false,
71        ),
72        Field::new("metadata", DataType::Utf8, false),
73    ]))
74}
75
76fn points_to_batch(
77    schema: SchemaRef,
78    dimension: usize,
79    collection: &str,
80    points: &[VectorPoint],
81) -> VectorDBResult<RecordBatch> {
82    if let Some(p) = points.iter().find(|p| p.vector.len() != dimension) {
83        return Err(VectorDBError::DimensionMismatch {
84            collection: collection.to_string(),
85            expected: dimension,
86            actual: p.vector.len(),
87        });
88    }
89
90    let id_array = FixedSizeBinaryArray::try_from_iter(points.iter().map(|p| *p.id.as_bytes()))
91        .map_err(|e| VectorDBError::StorageError(format!("id column build: {e}")))?;
92
93    let vector_array = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
94        points
95            .iter()
96            .map(|p| Some(p.vector.iter().map(|v| Some(*v)).collect::<Vec<_>>())),
97        dimension as i32,
98    );
99
100    let metadata_array = StringArray::from(
101        points
102            .iter()
103            .map(|p| serde_json::to_string(&p.metadata))
104            .collect::<Result<Vec<_>, _>>()?,
105    );
106
107    RecordBatch::try_new(
108        schema,
109        vec![
110            Arc::new(id_array),
111            Arc::new(vector_array),
112            Arc::new(metadata_array),
113        ],
114    )
115    .map_err(|e| VectorDBError::StorageError(format!("record batch build: {e}")))
116}
117
118fn search_results_from_batches(batches: Vec<RecordBatch>) -> VectorDBResult<Vec<SearchResult>> {
119    let mut out = Vec::new();
120    for batch in batches {
121        let id_col = batch
122            .column_by_name("id")
123            .ok_or_else(|| VectorDBError::StorageError("missing id column".to_string()))?
124            .as_any()
125            .downcast_ref::<FixedSizeBinaryArray>()
126            .ok_or_else(|| VectorDBError::StorageError("id column type mismatch".to_string()))?;
127
128        let metadata_col = batch
129            .column_by_name("metadata")
130            .ok_or_else(|| VectorDBError::StorageError("missing metadata column".to_string()))?
131            .as_any()
132            .downcast_ref::<StringArray>()
133            .ok_or_else(|| {
134                VectorDBError::StorageError("metadata column type mismatch".to_string())
135            })?;
136
137        // LanceDB's `nearest_to` appends a `_distance` column with the
138        // distance from the query (lower = closer for Cosine/L2). Convert
139        // distance back to a similarity score so callers can sort descending.
140        let distance_col = batch
141            .column_by_name("_distance")
142            .ok_or_else(|| VectorDBError::StorageError("missing _distance column".to_string()))?
143            .as_any()
144            .downcast_ref::<Float32Array>()
145            .ok_or_else(|| {
146                VectorDBError::StorageError("_distance column type mismatch".to_string())
147            })?;
148
149        for row in 0..batch.num_rows() {
150            let id_bytes = id_col.value(row);
151            let id = Uuid::from_slice(id_bytes)
152                .map_err(|e| VectorDBError::StorageError(format!("id is not a valid UUID: {e}")))?;
153
154            let metadata: HashMap<String, serde_json::Value> =
155                serde_json::from_str(metadata_col.value(row))?;
156
157            // Cosine distance is in [0, 2]; clamp + invert to similarity.
158            let distance = distance_col.value(row).max(0.0);
159            let score = (1.0 - distance).clamp(-1.0, 1.0);
160
161            out.push(SearchResult {
162                id,
163                score,
164                metadata,
165            });
166        }
167    }
168    Ok(out)
169}
170
171/// LanceDB-backed vector store.
172pub struct LanceDbAdapter {
173    connection: Connection,
174    /// Cached per-collection dimensions so we can rebuild Arrow schemas without
175    /// re-opening each table on every write/search call.
176    dimensions: Arc<RwLock<HashMap<String, usize>>>,
177}
178
179impl LanceDbAdapter {
180    /// Open (or create) a LanceDB store at the given filesystem path.
181    pub async fn new(path: PathBuf) -> VectorDBResult<Self> {
182        if let Some(parent) = path.parent()
183            && !parent.as_os_str().is_empty()
184        {
185            std::fs::create_dir_all(parent)?;
186        }
187        let uri = path.to_str().ok_or_else(|| {
188            VectorDBError::StorageError(format!("lancedb path is not valid UTF-8: {path:?}"))
189        })?;
190        let connection = connect(uri).execute().await.map_err(map_lance_err)?;
191        Ok(Self {
192            connection,
193            dimensions: Arc::new(RwLock::new(HashMap::new())),
194        })
195    }
196
197    async fn cached_dimension(&self, table_name: &str) -> Option<usize> {
198        self.dimensions.read().await.get(table_name).copied()
199    }
200
201    async fn resolved_dimension(&self, table_name: &str) -> VectorDBResult<usize> {
202        if let Some(dim) = self.cached_dimension(table_name).await {
203            return Ok(dim);
204        }
205        let table = self
206            .connection
207            .open_table(table_name)
208            .execute()
209            .await
210            .map_err(|e| match e {
211                lancedb::Error::TableNotFound { .. } => {
212                    VectorDBError::CollectionNotFound(table_name.to_string())
213                }
214                other => map_lance_err(other),
215            })?;
216        let schema = table.schema().await.map_err(map_lance_err)?;
217        let dim = dimension_from_schema(&schema).ok_or_else(|| {
218            VectorDBError::StorageError(format!(
219                "table '{table_name}' has no FixedSizeList<Float32, _> vector column"
220            ))
221        })?;
222        self.dimensions
223            .write()
224            .await
225            .insert(table_name.to_string(), dim);
226        Ok(dim)
227    }
228}
229
230#[async_trait]
231impl VectorDB for LanceDbAdapter {
232    async fn create_collection(
233        &self,
234        data_type: &str,
235        field_name: &str,
236        dimension: usize,
237    ) -> VectorDBResult<()> {
238        let name = collection_name(data_type, field_name);
239        if self.has_collection(data_type, field_name).await? {
240            // Idempotent: matches BruteForceVectorDB.create_collection semantics.
241            return Ok(());
242        }
243        let schema = build_schema(dimension);
244        self.connection
245            .create_empty_table(&name, schema)
246            .execute()
247            .await
248            .map_err(map_lance_err)?;
249        self.dimensions.write().await.insert(name, dimension);
250        Ok(())
251    }
252
253    async fn has_collection(&self, data_type: &str, field_name: &str) -> VectorDBResult<bool> {
254        let target = collection_name(data_type, field_name);
255        let names = self
256            .connection
257            .table_names()
258            .execute()
259            .await
260            .map_err(map_lance_err)?;
261        Ok(names.iter().any(|n| n == &target))
262    }
263
264    async fn index_points(
265        &self,
266        data_type: &str,
267        field_name: &str,
268        points: &[VectorPoint],
269    ) -> VectorDBResult<()> {
270        if points.is_empty() {
271            return Ok(());
272        }
273        let name = collection_name(data_type, field_name);
274        let dimension = self.resolved_dimension(&name).await?;
275        let schema = build_schema(dimension);
276        let batch = points_to_batch(schema.clone(), dimension, &name, points)?;
277        let table = self
278            .connection
279            .open_table(&name)
280            .execute()
281            .await
282            .map_err(map_lance_err)?;
283        // Upsert by id so re-indexing existing points replaces them.
284        let id_values: Vec<String> = points
285            .iter()
286            .map(|p| {
287                let bytes = p.id.as_bytes();
288                // SQL hex literal: X'…' over the 16 UUID bytes.
289                let hex: String = bytes.iter().map(|b| format!("{b:02X}")).collect();
290                format!("X'{hex}'")
291            })
292            .collect();
293        if !id_values.is_empty() {
294            let predicate = format!("id IN ({})", id_values.join(", "));
295            table
296                .delete(predicate.as_str())
297                .await
298                .map_err(map_lance_err)?;
299        }
300        let _ = schema; // schema lives on the RecordBatch; nothing else needs it.
301        table
302            .add(vec![batch])
303            .execute()
304            .await
305            .map_err(map_lance_err)?;
306        Ok(())
307    }
308
309    async fn search_similar(
310        &self,
311        data_type: &str,
312        field_name: &str,
313        query_vector: &[f32],
314        top_k: usize,
315    ) -> VectorDBResult<Vec<SearchResult>> {
316        let name = collection_name(data_type, field_name);
317        let table = self
318            .connection
319            .open_table(&name)
320            .execute()
321            .await
322            .map_err(|e| match e {
323                lancedb::Error::TableNotFound { .. } => {
324                    VectorDBError::CollectionNotFound(name.clone())
325                }
326                other => map_lance_err(other),
327            })?;
328        let stream = table
329            .query()
330            .limit(top_k)
331            .nearest_to(query_vector)
332            .map_err(map_lance_err)?
333            .distance_type(DistanceType::Cosine)
334            .execute()
335            .await
336            .map_err(map_lance_err)?;
337        let batches: Vec<RecordBatch> = stream.try_collect().await.map_err(map_lance_err)?;
338        search_results_from_batches(batches)
339    }
340
341    async fn delete_collection(&self, data_type: &str, field_name: &str) -> VectorDBResult<()> {
342        let name = collection_name(data_type, field_name);
343        match self.connection.drop_table(&name, &[]).await {
344            Ok(()) => {
345                self.dimensions.write().await.remove(&name);
346                Ok(())
347            }
348            Err(lancedb::Error::TableNotFound { .. }) => Ok(()),
349            Err(other) => Err(map_lance_err(other)),
350        }
351    }
352
353    async fn delete_points(
354        &self,
355        data_type: &str,
356        field_name: &str,
357        point_ids: &[Uuid],
358    ) -> VectorDBResult<()> {
359        if point_ids.is_empty() {
360            return Ok(());
361        }
362        let name = collection_name(data_type, field_name);
363        let table = self
364            .connection
365            .open_table(&name)
366            .execute()
367            .await
368            .map_err(|e| match e {
369                lancedb::Error::TableNotFound { .. } => {
370                    VectorDBError::CollectionNotFound(name.clone())
371                }
372                other => map_lance_err(other),
373            })?;
374        let id_values: Vec<String> = point_ids
375            .iter()
376            .map(|id| {
377                let hex: String = id.as_bytes().iter().map(|b| format!("{b:02X}")).collect();
378                format!("X'{hex}'")
379            })
380            .collect();
381        let predicate = format!("id IN ({})", id_values.join(", "));
382        table
383            .delete(predicate.as_str())
384            .await
385            .map_err(map_lance_err)?;
386        Ok(())
387    }
388
389    async fn collection_size(&self, data_type: &str, field_name: &str) -> VectorDBResult<usize> {
390        let name = collection_name(data_type, field_name);
391        let table = self
392            .connection
393            .open_table(&name)
394            .execute()
395            .await
396            .map_err(|e| match e {
397                lancedb::Error::TableNotFound { .. } => {
398                    VectorDBError::CollectionNotFound(name.clone())
399                }
400                other => map_lance_err(other),
401            })?;
402        table.count_rows(None).await.map_err(map_lance_err)
403    }
404
405    async fn list_collections(&self) -> VectorDBResult<Vec<(String, String)>> {
406        let names = self
407            .connection
408            .table_names()
409            .execute()
410            .await
411            .map_err(map_lance_err)?;
412        Ok(names
413            .into_iter()
414            .filter_map(|n| {
415                n.find('_')
416                    .map(|i| (n[..i].to_string(), n[i + 1..].to_string()))
417            })
418            .collect())
419    }
420}
421
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        reason = "test code — panics are acceptable"
    )]
    use super::*;
    use serde_json::json;
    use tempfile::tempdir;

    /// Point whose metadata carries a single `kind` tag, used by assertions to
    /// tell rows apart.
    fn point(id: Uuid, vector: Vec<f32>, kind: &str) -> VectorPoint {
        VectorPoint::new(id, vector).with_metadata("kind", json!(kind))
    }

    /// Adapter over a fresh temp directory. The returned `TempDir` guard must
    /// outlive the adapter, or the directory is deleted underneath it.
    async fn fresh_adapter() -> (LanceDbAdapter, tempfile::TempDir) {
        let dir = tempdir().unwrap();
        let path = dir.path().join("store.lance");
        let adapter = LanceDbAdapter::new(path).await.unwrap();
        (adapter, dir)
    }

    #[tokio::test]
    async fn create_and_has_collection_roundtrip() {
        let (adapter, _dir) = fresh_adapter().await;
        assert!(!adapter.has_collection("Chunk", "text").await.unwrap());
        adapter.create_collection("Chunk", "text", 4).await.unwrap();
        assert!(adapter.has_collection("Chunk", "text").await.unwrap());
        // Idempotent.
        adapter.create_collection("Chunk", "text", 4).await.unwrap();
    }

    #[tokio::test]
    async fn index_and_search_finds_closest_point() {
        let (adapter, _dir) = fresh_adapter().await;
        adapter.create_collection("Chunk", "text", 3).await.unwrap();

        let target = Uuid::new_v4();
        let other = Uuid::new_v4();
        let points = vec![
            point(target, vec![1.0, 0.0, 0.0], "target"),
            point(other, vec![0.0, 1.0, 0.0], "other"),
        ];
        adapter
            .index_points("Chunk", "text", &points)
            .await
            .unwrap();

        let results = adapter
            .search_similar("Chunk", "text", &[1.0, 0.0, 0.0], 2)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, target, "nearest point should be the target");
        assert_eq!(results[0].metadata.get("kind").unwrap(), &json!("target"));
        // Cosine distance from target to itself ~= 0 → score ~= 1.
        assert!(results[0].score > 0.99);
    }

    #[tokio::test]
    async fn collection_size_reports_row_count() {
        let (adapter, _dir) = fresh_adapter().await;
        adapter.create_collection("Chunk", "text", 2).await.unwrap();
        let points = vec![
            point(Uuid::new_v4(), vec![0.0, 1.0], "a"),
            point(Uuid::new_v4(), vec![1.0, 0.0], "b"),
        ];
        adapter
            .index_points("Chunk", "text", &points)
            .await
            .unwrap();
        assert_eq!(adapter.collection_size("Chunk", "text").await.unwrap(), 2);
    }

    #[tokio::test]
    async fn delete_points_removes_by_id() {
        let (adapter, _dir) = fresh_adapter().await;
        adapter.create_collection("Chunk", "text", 2).await.unwrap();
        let keep = Uuid::new_v4();
        let drop = Uuid::new_v4();
        adapter
            .index_points(
                "Chunk",
                "text",
                &[
                    point(keep, vec![1.0, 0.0], "keep"),
                    point(drop, vec![0.0, 1.0], "drop"),
                ],
            )
            .await
            .unwrap();

        adapter
            .delete_points("Chunk", "text", &[drop])
            .await
            .unwrap();

        assert_eq!(adapter.collection_size("Chunk", "text").await.unwrap(), 1);
        let results = adapter
            .search_similar("Chunk", "text", &[0.0, 1.0], 5)
            .await
            .unwrap();
        assert!(results.iter().all(|r| r.id != drop));
    }

    #[tokio::test]
    async fn index_points_replaces_existing_id() {
        let (adapter, _dir) = fresh_adapter().await;
        adapter.create_collection("Chunk", "text", 2).await.unwrap();
        let id = Uuid::new_v4();
        adapter
            .index_points("Chunk", "text", &[point(id, vec![1.0, 0.0], "v1")])
            .await
            .unwrap();
        adapter
            .index_points("Chunk", "text", &[point(id, vec![0.0, 1.0], "v2")])
            .await
            .unwrap();
        assert_eq!(adapter.collection_size("Chunk", "text").await.unwrap(), 1);

        let results = adapter
            .search_similar("Chunk", "text", &[0.0, 1.0], 1)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, id);
        assert_eq!(results[0].metadata.get("kind").unwrap(), &json!("v2"));
    }

    #[tokio::test]
    async fn delete_collection_drops_table_and_is_idempotent() {
        let (adapter, _dir) = fresh_adapter().await;
        adapter.create_collection("Chunk", "text", 2).await.unwrap();
        assert!(adapter.has_collection("Chunk", "text").await.unwrap());
        adapter.delete_collection("Chunk", "text").await.unwrap();
        assert!(!adapter.has_collection("Chunk", "text").await.unwrap());
        // Idempotent on a missing table.
        adapter.delete_collection("Chunk", "text").await.unwrap();
    }

    #[tokio::test]
    async fn list_and_prune_collections() {
        let (adapter, _dir) = fresh_adapter().await;
        adapter.create_collection("Chunk", "text", 2).await.unwrap();
        adapter
            .create_collection("Entity", "name", 2)
            .await
            .unwrap();

        let mut listed: Vec<_> = adapter.list_collections().await.unwrap();
        listed.sort();
        assert_eq!(
            listed,
            vec![
                ("Chunk".to_string(), "text".to_string()),
                ("Entity".to_string(), "name".to_string()),
            ]
        );

        adapter.prune().await.unwrap();
        assert_eq!(adapter.list_collections().await.unwrap().len(), 0);
    }

    #[tokio::test]
    async fn dimension_mismatch_returns_error() {
        let (adapter, _dir) = fresh_adapter().await;
        adapter.create_collection("Chunk", "text", 3).await.unwrap();
        let err = adapter
            .index_points(
                "Chunk",
                "text",
                &[point(Uuid::new_v4(), vec![1.0, 0.0], "bad")],
            )
            .await
            .unwrap_err();
        assert!(
            matches!(
                err,
                VectorDBError::DimensionMismatch {
                    expected: 3,
                    actual: 2,
                    ..
                }
            ),
            "expected DimensionMismatch, got {err:?}"
        );
    }

    #[tokio::test]
    async fn store_persists_across_reopen() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("persist.lance");
        let id = Uuid::new_v4();

        {
            let adapter = LanceDbAdapter::new(path.clone()).await.unwrap();
            adapter.create_collection("Chunk", "text", 2).await.unwrap();
            adapter
                .index_points("Chunk", "text", &[point(id, vec![1.0, 0.0], "v1")])
                .await
                .unwrap();
        }

        // Re-open at the same path; the table and row should still be there.
        let adapter = LanceDbAdapter::new(path).await.unwrap();
        assert!(adapter.has_collection("Chunk", "text").await.unwrap());
        let results = adapter
            .search_similar("Chunk", "text", &[1.0, 0.0], 1)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, id);
    }

    #[tokio::test]
    async fn empty_inputs_are_noops() {
        let (adapter, _dir) = fresh_adapter().await;
        adapter.create_collection("Chunk", "text", 2).await.unwrap();
        adapter.index_points("Chunk", "text", &[]).await.unwrap();
        adapter.delete_points("Chunk", "text", &[]).await.unwrap();
        assert_eq!(adapter.collection_size("Chunk", "text").await.unwrap(), 0);
    }

    #[tokio::test]
    async fn list_collections_splits_at_first_underscore() {
        let (adapter, _dir) = fresh_adapter().await;
        adapter
            .create_collection("Entity", "source_name", 2)
            .await
            .unwrap();
        assert_eq!(
            adapter.list_collections().await.unwrap(),
            vec![("Entity".to_string(), "source_name".to_string())]
        );
    }

    #[tokio::test]
    async fn missing_collection_reports_not_found() {
        let (adapter, _dir) = fresh_adapter().await;
        let err = adapter
            .search_similar("Chunk", "text", &[1.0, 0.0], 1)
            .await
            .unwrap_err();
        assert!(
            matches!(err, VectorDBError::CollectionNotFound(ref n) if n == "Chunk_text"),
            "expected CollectionNotFound, got {err:?}"
        );
        let err = adapter
            .collection_size("Chunk", "text")
            .await
            .unwrap_err();
        assert!(
            matches!(err, VectorDBError::CollectionNotFound(ref n) if n == "Chunk_text"),
            "expected CollectionNotFound, got {err:?}"
        );
    }

    #[tokio::test]
    async fn reopened_store_recovers_dimension_from_schema() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("dims.lance");
        {
            let adapter = LanceDbAdapter::new(path.clone()).await.unwrap();
            adapter.create_collection("Chunk", "text", 2).await.unwrap();
        }

        // A fresh adapter starts with an empty dimension cache, so this goes
        // through the schema-inspection path of `resolved_dimension`.
        let adapter = LanceDbAdapter::new(path).await.unwrap();
        let err = adapter
            .index_points(
                "Chunk",
                "text",
                &[point(Uuid::new_v4(), vec![1.0, 0.0, 0.0], "bad")],
            )
            .await
            .unwrap_err();
        assert!(
            matches!(
                err,
                VectorDBError::DimensionMismatch {
                    expected: 2,
                    actual: 3,
                    ..
                }
            ),
            "expected DimensionMismatch, got {err:?}"
        );
    }
}