1use std::collections::HashMap;
21use std::path::PathBuf;
22use std::sync::Arc;
23
24use arrow_array::{
25 Array, FixedSizeBinaryArray, FixedSizeListArray, Float32Array, RecordBatch, StringArray,
26 types::Float32Type,
27};
28use arrow_schema::{DataType, Field, Schema, SchemaRef};
29use async_trait::async_trait;
30use futures::TryStreamExt;
31use lancedb::{
32 DistanceType, connect,
33 connection::Connection,
34 query::{ExecutableQuery, QueryBase},
35};
36use tokio::sync::RwLock;
37use uuid::Uuid;
38
39use crate::error::{VectorDBError, VectorDBResult};
40use crate::models::{SearchResult, VectorPoint};
41use crate::vector_db_trait::VectorDB;
42
/// Builds the table name for a `(data_type, field_name)` pair by joining the
/// two parts with a single underscore.
fn collection_name(data_type: &str, field_name: &str) -> String {
    [data_type, field_name].join("_")
}
46
47fn map_lance_err(e: lancedb::Error) -> VectorDBError {
48 VectorDBError::StorageError(format!("lancedb: {e}"))
49}
50
51fn dimension_from_schema(schema: &SchemaRef) -> Option<usize> {
54 schema.field_with_name("vector").ok().and_then(|f| {
55 if let DataType::FixedSizeList(_, dim) = f.data_type() {
56 usize::try_from(*dim).ok()
57 } else {
58 None
59 }
60 })
61}
62
63fn build_schema(dimension: usize) -> SchemaRef {
64 let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
65 Arc::new(Schema::new(vec![
66 Field::new("id", DataType::FixedSizeBinary(16), false),
67 Field::new(
68 "vector",
69 DataType::FixedSizeList(vector_field, dimension as i32),
70 false,
71 ),
72 Field::new("metadata", DataType::Utf8, false),
73 ]))
74}
75
76fn points_to_batch(
77 schema: SchemaRef,
78 dimension: usize,
79 collection: &str,
80 points: &[VectorPoint],
81) -> VectorDBResult<RecordBatch> {
82 if let Some(p) = points.iter().find(|p| p.vector.len() != dimension) {
83 return Err(VectorDBError::DimensionMismatch {
84 collection: collection.to_string(),
85 expected: dimension,
86 actual: p.vector.len(),
87 });
88 }
89
90 let id_array = FixedSizeBinaryArray::try_from_iter(points.iter().map(|p| *p.id.as_bytes()))
91 .map_err(|e| VectorDBError::StorageError(format!("id column build: {e}")))?;
92
93 let vector_array = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
94 points
95 .iter()
96 .map(|p| Some(p.vector.iter().map(|v| Some(*v)).collect::<Vec<_>>())),
97 dimension as i32,
98 );
99
100 let metadata_array = StringArray::from(
101 points
102 .iter()
103 .map(|p| serde_json::to_string(&p.metadata))
104 .collect::<Result<Vec<_>, _>>()?,
105 );
106
107 RecordBatch::try_new(
108 schema,
109 vec![
110 Arc::new(id_array),
111 Arc::new(vector_array),
112 Arc::new(metadata_array),
113 ],
114 )
115 .map_err(|e| VectorDBError::StorageError(format!("record batch build: {e}")))
116}
117
118fn id_metadata_from_batches(
123 batches: Vec<RecordBatch>,
124) -> VectorDBResult<HashMap<Uuid, HashMap<String, serde_json::Value>>> {
125 let mut out = HashMap::new();
126 for batch in batches {
127 let id_col = batch
128 .column_by_name("id")
129 .ok_or_else(|| VectorDBError::StorageError("missing id column".to_string()))?
130 .as_any()
131 .downcast_ref::<FixedSizeBinaryArray>()
132 .ok_or_else(|| VectorDBError::StorageError("id column type mismatch".to_string()))?;
133 let metadata_col = batch
134 .column_by_name("metadata")
135 .ok_or_else(|| VectorDBError::StorageError("missing metadata column".to_string()))?
136 .as_any()
137 .downcast_ref::<StringArray>()
138 .ok_or_else(|| {
139 VectorDBError::StorageError("metadata column type mismatch".to_string())
140 })?;
141 for row in 0..batch.num_rows() {
142 let id = Uuid::from_slice(id_col.value(row))
143 .map_err(|e| VectorDBError::StorageError(format!("id is not a valid UUID: {e}")))?;
144 let metadata: HashMap<String, serde_json::Value> =
145 serde_json::from_str(metadata_col.value(row))?;
146 out.insert(id, metadata);
147 }
148 }
149 Ok(out)
150}
151
152fn search_results_from_batches(batches: Vec<RecordBatch>) -> VectorDBResult<Vec<SearchResult>> {
153 let mut out = Vec::new();
154 for batch in batches {
155 let id_col = batch
156 .column_by_name("id")
157 .ok_or_else(|| VectorDBError::StorageError("missing id column".to_string()))?
158 .as_any()
159 .downcast_ref::<FixedSizeBinaryArray>()
160 .ok_or_else(|| VectorDBError::StorageError("id column type mismatch".to_string()))?;
161
162 let metadata_col = batch
163 .column_by_name("metadata")
164 .ok_or_else(|| VectorDBError::StorageError("missing metadata column".to_string()))?
165 .as_any()
166 .downcast_ref::<StringArray>()
167 .ok_or_else(|| {
168 VectorDBError::StorageError("metadata column type mismatch".to_string())
169 })?;
170
171 let distance_col = batch
175 .column_by_name("_distance")
176 .ok_or_else(|| VectorDBError::StorageError("missing _distance column".to_string()))?
177 .as_any()
178 .downcast_ref::<Float32Array>()
179 .ok_or_else(|| {
180 VectorDBError::StorageError("_distance column type mismatch".to_string())
181 })?;
182
183 for row in 0..batch.num_rows() {
184 let id_bytes = id_col.value(row);
185 let id = Uuid::from_slice(id_bytes)
186 .map_err(|e| VectorDBError::StorageError(format!("id is not a valid UUID: {e}")))?;
187
188 let metadata: HashMap<String, serde_json::Value> =
189 serde_json::from_str(metadata_col.value(row))?;
190
191 let distance = distance_col.value(row).max(0.0);
193 let score = (1.0 - distance).clamp(-1.0, 1.0);
194
195 out.push(SearchResult {
196 id,
197 score,
198 metadata,
199 });
200 }
201 }
202 Ok(out)
203}
204
205pub struct LanceDbAdapter {
207 connection: Connection,
208 dimensions: Arc<RwLock<HashMap<String, usize>>>,
211}
212
213impl LanceDbAdapter {
214 pub async fn new(path: PathBuf) -> VectorDBResult<Self> {
216 if let Some(parent) = path.parent()
217 && !parent.as_os_str().is_empty()
218 {
219 std::fs::create_dir_all(parent)?;
220 }
221 let uri = path.to_str().ok_or_else(|| {
222 VectorDBError::StorageError(format!("lancedb path is not valid UTF-8: {path:?}"))
223 })?;
224 let connection = connect(uri).execute().await.map_err(map_lance_err)?;
225 Ok(Self {
226 connection,
227 dimensions: Arc::new(RwLock::new(HashMap::new())),
228 })
229 }
230
231 async fn cached_dimension(&self, table_name: &str) -> Option<usize> {
232 self.dimensions.read().await.get(table_name).copied()
233 }
234
235 async fn resolved_dimension(&self, table_name: &str) -> VectorDBResult<usize> {
236 if let Some(dim) = self.cached_dimension(table_name).await {
237 return Ok(dim);
238 }
239 let table = self
240 .connection
241 .open_table(table_name)
242 .execute()
243 .await
244 .map_err(|e| match e {
245 lancedb::Error::TableNotFound { .. } => {
246 VectorDBError::CollectionNotFound(table_name.to_string())
247 }
248 other => map_lance_err(other),
249 })?;
250 let schema = table.schema().await.map_err(map_lance_err)?;
251 let dim = dimension_from_schema(&schema).ok_or_else(|| {
252 VectorDBError::StorageError(format!(
253 "table '{table_name}' has no FixedSizeList<Float32, _> vector column"
254 ))
255 })?;
256 self.dimensions
257 .write()
258 .await
259 .insert(table_name.to_string(), dim);
260 Ok(dim)
261 }
262}
263
264#[async_trait]
265impl VectorDB for LanceDbAdapter {
266 async fn create_collection(
267 &self,
268 data_type: &str,
269 field_name: &str,
270 dimension: usize,
271 ) -> VectorDBResult<()> {
272 let name = collection_name(data_type, field_name);
273 if self.has_collection(data_type, field_name).await? {
274 return Ok(());
276 }
277 let schema = build_schema(dimension);
278 self.connection
279 .create_empty_table(&name, schema)
280 .execute()
281 .await
282 .map_err(map_lance_err)?;
283 self.dimensions.write().await.insert(name, dimension);
284 Ok(())
285 }
286
287 async fn has_collection(&self, data_type: &str, field_name: &str) -> VectorDBResult<bool> {
288 let target = collection_name(data_type, field_name);
289 let names = self
290 .connection
291 .table_names()
292 .execute()
293 .await
294 .map_err(map_lance_err)?;
295 Ok(names.iter().any(|n| n == &target))
296 }
297
298 async fn index_points(
299 &self,
300 data_type: &str,
301 field_name: &str,
302 points: &[VectorPoint],
303 ) -> VectorDBResult<()> {
304 if points.is_empty() {
305 return Ok(());
306 }
307 let name = collection_name(data_type, field_name);
308 let dimension = self.resolved_dimension(&name).await?;
309 let schema = build_schema(dimension);
310 let table = self
311 .connection
312 .open_table(&name)
313 .execute()
314 .await
315 .map_err(map_lance_err)?;
316 let id_values: Vec<String> = points
318 .iter()
319 .map(|p| {
320 let bytes = p.id.as_bytes();
321 let hex: String = bytes.iter().map(|b| format!("{b:02X}")).collect();
323 format!("X'{hex}'")
324 })
325 .collect();
326 let predicate = format!("id IN ({})", id_values.join(", "));
327
328 let existing = if id_values.is_empty() {
334 HashMap::new()
335 } else {
336 let stream = table
337 .query()
338 .only_if(predicate.clone())
339 .execute()
340 .await
341 .map_err(map_lance_err)?;
342 let batches: Vec<RecordBatch> = stream.try_collect().await.map_err(map_lance_err)?;
343 id_metadata_from_batches(batches)?
344 };
345 let merged_points: Vec<VectorPoint> = points
346 .iter()
347 .map(|p| {
348 let mut np = p.clone();
349 if let Some(prev_meta) = existing.get(&p.id) {
350 let prev = VectorPoint {
351 id: p.id,
352 vector: Vec::new(),
353 metadata: prev_meta.clone(),
354 };
355 np.merge_dataset_membership(&prev);
356 }
357 np
358 })
359 .collect();
360 let batch = points_to_batch(schema.clone(), dimension, &name, &merged_points)?;
361
362 if !id_values.is_empty() {
363 table
364 .delete(predicate.as_str())
365 .await
366 .map_err(map_lance_err)?;
367 }
368 let _ = schema; table
370 .add(vec![batch])
371 .execute()
372 .await
373 .map_err(map_lance_err)?;
374 Ok(())
375 }
376
377 async fn search_similar(
378 &self,
379 data_type: &str,
380 field_name: &str,
381 query_vector: &[f32],
382 top_k: usize,
383 ) -> VectorDBResult<Vec<SearchResult>> {
384 let name = collection_name(data_type, field_name);
385 let table = self
386 .connection
387 .open_table(&name)
388 .execute()
389 .await
390 .map_err(|e| match e {
391 lancedb::Error::TableNotFound { .. } => {
392 VectorDBError::CollectionNotFound(name.clone())
393 }
394 other => map_lance_err(other),
395 })?;
396 let stream = table
397 .query()
398 .limit(top_k)
399 .nearest_to(query_vector)
400 .map_err(map_lance_err)?
401 .distance_type(DistanceType::Cosine)
402 .execute()
403 .await
404 .map_err(map_lance_err)?;
405 let batches: Vec<RecordBatch> = stream.try_collect().await.map_err(map_lance_err)?;
406 search_results_from_batches(batches)
407 }
408
409 async fn delete_collection(&self, data_type: &str, field_name: &str) -> VectorDBResult<()> {
410 let name = collection_name(data_type, field_name);
411 match self.connection.drop_table(&name, &[]).await {
412 Ok(()) => {
413 self.dimensions.write().await.remove(&name);
414 Ok(())
415 }
416 Err(lancedb::Error::TableNotFound { .. }) => Ok(()),
417 Err(other) => Err(map_lance_err(other)),
418 }
419 }
420
421 async fn delete_points(
422 &self,
423 data_type: &str,
424 field_name: &str,
425 point_ids: &[Uuid],
426 ) -> VectorDBResult<()> {
427 if point_ids.is_empty() {
428 return Ok(());
429 }
430 let name = collection_name(data_type, field_name);
431 let table = self
432 .connection
433 .open_table(&name)
434 .execute()
435 .await
436 .map_err(|e| match e {
437 lancedb::Error::TableNotFound { .. } => {
438 VectorDBError::CollectionNotFound(name.clone())
439 }
440 other => map_lance_err(other),
441 })?;
442 let id_values: Vec<String> = point_ids
443 .iter()
444 .map(|id| {
445 let hex: String = id.as_bytes().iter().map(|b| format!("{b:02X}")).collect();
446 format!("X'{hex}'")
447 })
448 .collect();
449 let predicate = format!("id IN ({})", id_values.join(", "));
450 table
451 .delete(predicate.as_str())
452 .await
453 .map_err(map_lance_err)?;
454 Ok(())
455 }
456
457 async fn collection_size(&self, data_type: &str, field_name: &str) -> VectorDBResult<usize> {
458 let name = collection_name(data_type, field_name);
459 let table = self
460 .connection
461 .open_table(&name)
462 .execute()
463 .await
464 .map_err(|e| match e {
465 lancedb::Error::TableNotFound { .. } => {
466 VectorDBError::CollectionNotFound(name.clone())
467 }
468 other => map_lance_err(other),
469 })?;
470 table.count_rows(None).await.map_err(map_lance_err)
471 }
472
473 async fn list_collections(&self) -> VectorDBResult<Vec<(String, String)>> {
474 let names = self
475 .connection
476 .table_names()
477 .execute()
478 .await
479 .map_err(map_lance_err)?;
480 Ok(names
481 .into_iter()
482 .filter_map(|n| {
483 n.find('_')
484 .map(|i| (n[..i].to_string(), n[i + 1..].to_string()))
485 })
486 .collect())
487 }
488}
489
#[cfg(test)]
mod tests {
    //! Integration-style tests that run the adapter against a LanceDB store
    //! in a temporary directory.
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        reason = "test code — panics are acceptable"
    )]
    use super::*;
    use serde_json::json;
    use tempfile::tempdir;

    /// Builds a point whose metadata carries a single `kind` entry.
    fn point(id: Uuid, vector: Vec<f32>, kind: &str) -> VectorPoint {
        VectorPoint::new(id, vector).with_metadata("kind", json!(kind))
    }

    /// Opens an adapter on a fresh temporary store. The `TempDir` is returned
    /// so the caller keeps the directory alive for the test's duration.
    async fn fresh_adapter() -> (LanceDbAdapter, tempfile::TempDir) {
        let dir = tempdir().unwrap();
        let path = dir.path().join("store.lance");
        let adapter = LanceDbAdapter::new(path).await.unwrap();
        (adapter, dir)
    }

    #[tokio::test]
    async fn create_and_has_collection_roundtrip() {
        let (adapter, _dir) = fresh_adapter().await;
        assert!(!adapter.has_collection("Chunk", "text").await.unwrap());
        adapter.create_collection("Chunk", "text", 4).await.unwrap();
        assert!(adapter.has_collection("Chunk", "text").await.unwrap());
        // Creating an existing collection again must succeed.
        adapter.create_collection("Chunk", "text", 4).await.unwrap();
    }

    #[tokio::test]
    async fn index_and_search_finds_closest_point() {
        let (adapter, _dir) = fresh_adapter().await;
        adapter.create_collection("Chunk", "text", 3).await.unwrap();

        let target = Uuid::new_v4();
        let other = Uuid::new_v4();
        let points = vec![
            point(target, vec![1.0, 0.0, 0.0], "target"),
            point(other, vec![0.0, 1.0, 0.0], "other"),
        ];
        adapter
            .index_points("Chunk", "text", &points)
            .await
            .unwrap();

        let results = adapter
            .search_similar("Chunk", "text", &[1.0, 0.0, 0.0], 2)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, target, "nearest point should be the target");
        assert_eq!(results[0].metadata.get("kind").unwrap(), &json!("target"));
        assert!(results[0].score > 0.99);
    }

    #[tokio::test]
    async fn upsert_unions_dataset_membership_across_datasets() {
        let (adapter, _dir) = fresh_adapter().await;
        adapter
            .create_collection("DocumentChunk", "text", 3)
            .await
            .unwrap();

        // Same content id indexed from two different datasets.
        let content_id = Uuid::new_v5(&Uuid::NAMESPACE_OID, b"shared content");
        let dataset_a = Uuid::new_v4();
        let dataset_b = Uuid::new_v4();
        let vector = vec![1.0, 0.0, 0.0];

        let p_a = VectorPoint::new(content_id, vector.clone())
            .with_metadata("dataset_id", json!(dataset_a.to_string()));
        let p_b = VectorPoint::new(content_id, vector.clone())
            .with_metadata("dataset_id", json!(dataset_b.to_string()));

        adapter
            .index_points("DocumentChunk", "text", &[p_a])
            .await
            .unwrap();
        adapter
            .index_points("DocumentChunk", "text", &[p_b])
            .await
            .unwrap();

        assert_eq!(
            adapter
                .collection_size("DocumentChunk", "text")
                .await
                .unwrap(),
            1
        );
        let results = adapter
            .search_similar("DocumentChunk", "text", &vector, 5)
            .await
            .unwrap();
        // Fail with a clear message instead of an index-out-of-bounds panic.
        let first = results
            .first()
            .expect("search should return the upserted point");
        let members: Vec<String> = first
            .metadata
            .get(crate::DATASET_IDS_KEY)
            .and_then(|v| v.as_array())
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.as_str().map(str::to_string))
                    .collect()
            })
            .unwrap_or_default();
        assert!(
            members.contains(&dataset_a.to_string()) && members.contains(&dataset_b.to_string()),
            "expected both datasets in membership, got {members:?}"
        );
    }

    #[tokio::test]
    async fn collection_size_reports_row_count() {
        let (adapter, _dir) = fresh_adapter().await;
        adapter.create_collection("Chunk", "text", 2).await.unwrap();
        let points = vec![
            point(Uuid::new_v4(), vec![0.0, 1.0], "a"),
            point(Uuid::new_v4(), vec![1.0, 0.0], "b"),
        ];
        adapter
            .index_points("Chunk", "text", &points)
            .await
            .unwrap();
        assert_eq!(adapter.collection_size("Chunk", "text").await.unwrap(), 2);
    }

    #[tokio::test]
    async fn delete_points_removes_by_id() {
        let (adapter, _dir) = fresh_adapter().await;
        adapter.create_collection("Chunk", "text", 2).await.unwrap();
        let keep = Uuid::new_v4();
        // Named `removed` rather than `drop` to avoid shadowing the prelude fn.
        let removed = Uuid::new_v4();
        adapter
            .index_points(
                "Chunk",
                "text",
                &[
                    point(keep, vec![1.0, 0.0], "keep"),
                    point(removed, vec![0.0, 1.0], "drop"),
                ],
            )
            .await
            .unwrap();

        adapter
            .delete_points("Chunk", "text", &[removed])
            .await
            .unwrap();

        assert_eq!(adapter.collection_size("Chunk", "text").await.unwrap(), 1);
        let results = adapter
            .search_similar("Chunk", "text", &[0.0, 1.0], 5)
            .await
            .unwrap();
        assert!(results.iter().all(|r| r.id != removed));
        assert!(
            results.iter().any(|r| r.id == keep),
            "the point that was not deleted should still be searchable"
        );
    }

    #[tokio::test]
    async fn index_points_replaces_existing_id() {
        let (adapter, _dir) = fresh_adapter().await;
        adapter.create_collection("Chunk", "text", 2).await.unwrap();
        let id = Uuid::new_v4();
        adapter
            .index_points("Chunk", "text", &[point(id, vec![1.0, 0.0], "v1")])
            .await
            .unwrap();
        adapter
            .index_points("Chunk", "text", &[point(id, vec![0.0, 1.0], "v2")])
            .await
            .unwrap();
        assert_eq!(adapter.collection_size("Chunk", "text").await.unwrap(), 1);

        let results = adapter
            .search_similar("Chunk", "text", &[0.0, 1.0], 1)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, id);
        assert_eq!(results[0].metadata.get("kind").unwrap(), &json!("v2"));
    }

    #[tokio::test]
    async fn delete_collection_drops_table_and_is_idempotent() {
        let (adapter, _dir) = fresh_adapter().await;
        adapter.create_collection("Chunk", "text", 2).await.unwrap();
        assert!(adapter.has_collection("Chunk", "text").await.unwrap());
        adapter.delete_collection("Chunk", "text").await.unwrap();
        assert!(!adapter.has_collection("Chunk", "text").await.unwrap());
        // Deleting a collection that is already gone must succeed.
        adapter.delete_collection("Chunk", "text").await.unwrap();
    }

    #[tokio::test]
    async fn list_and_prune_collections() {
        let (adapter, _dir) = fresh_adapter().await;
        adapter.create_collection("Chunk", "text", 2).await.unwrap();
        adapter
            .create_collection("Entity", "name", 2)
            .await
            .unwrap();

        let mut listed: Vec<_> = adapter.list_collections().await.unwrap();
        listed.sort();
        assert_eq!(
            listed,
            vec![
                ("Chunk".to_string(), "text".to_string()),
                ("Entity".to_string(), "name".to_string()),
            ]
        );

        adapter.prune().await.unwrap();
        assert_eq!(adapter.list_collections().await.unwrap().len(), 0);
    }

    #[tokio::test]
    async fn dimension_mismatch_returns_error() {
        let (adapter, _dir) = fresh_adapter().await;
        adapter.create_collection("Chunk", "text", 3).await.unwrap();
        let err = adapter
            .index_points(
                "Chunk",
                "text",
                &[point(Uuid::new_v4(), vec![1.0, 0.0], "bad")],
            )
            .await
            .unwrap_err();
        assert!(
            matches!(
                err,
                VectorDBError::DimensionMismatch {
                    expected: 3,
                    actual: 2,
                    ..
                }
            ),
            "expected DimensionMismatch, got {err:?}"
        );
    }

    #[tokio::test]
    async fn store_persists_across_reopen() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("persist.lance");
        let id = Uuid::new_v4();

        // Write through a first adapter that goes out of scope before reopening.
        {
            let adapter = LanceDbAdapter::new(path.clone()).await.unwrap();
            adapter.create_collection("Chunk", "text", 2).await.unwrap();
            adapter
                .index_points("Chunk", "text", &[point(id, vec![1.0, 0.0], "v1")])
                .await
                .unwrap();
        }

        let adapter = LanceDbAdapter::new(path).await.unwrap();
        assert!(adapter.has_collection("Chunk", "text").await.unwrap());
        let results = adapter
            .search_similar("Chunk", "text", &[1.0, 0.0], 1)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, id);
    }
}
762}