1use std::collections::HashMap;
21use std::path::PathBuf;
22use std::sync::Arc;
23
24use arrow_array::{
25 Array, FixedSizeBinaryArray, FixedSizeListArray, Float32Array, RecordBatch, StringArray,
26 types::Float32Type,
27};
28use arrow_schema::{DataType, Field, Schema, SchemaRef};
29use async_trait::async_trait;
30use futures::TryStreamExt;
31use lancedb::{
32 DistanceType, connect,
33 connection::Connection,
34 query::{ExecutableQuery, QueryBase},
35};
36use tokio::sync::RwLock;
37use uuid::Uuid;
38
39use crate::error::{VectorDBError, VectorDBResult};
40use crate::models::{SearchResult, VectorPoint};
41use crate::vector_db_trait::VectorDB;
42
/// Derives the LanceDB table name for a `(data_type, field_name)` pair.
///
/// The two parts are joined with a single underscore, e.g. `("Chunk", "text")`
/// becomes `"Chunk_text"`.
fn collection_name(data_type: &str, field_name: &str) -> String {
    [data_type, field_name].join("_")
}
46
47fn map_lance_err(e: lancedb::Error) -> VectorDBError {
48 VectorDBError::StorageError(format!("lancedb: {e}"))
49}
50
51fn dimension_from_schema(schema: &SchemaRef) -> Option<usize> {
54 schema.field_with_name("vector").ok().and_then(|f| {
55 if let DataType::FixedSizeList(_, dim) = f.data_type() {
56 usize::try_from(*dim).ok()
57 } else {
58 None
59 }
60 })
61}
62
63fn build_schema(dimension: usize) -> SchemaRef {
64 let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
65 Arc::new(Schema::new(vec![
66 Field::new("id", DataType::FixedSizeBinary(16), false),
67 Field::new(
68 "vector",
69 DataType::FixedSizeList(vector_field, dimension as i32),
70 false,
71 ),
72 Field::new("metadata", DataType::Utf8, false),
73 ]))
74}
75
76fn points_to_batch(
77 schema: SchemaRef,
78 dimension: usize,
79 collection: &str,
80 points: &[VectorPoint],
81) -> VectorDBResult<RecordBatch> {
82 if let Some(p) = points.iter().find(|p| p.vector.len() != dimension) {
83 return Err(VectorDBError::DimensionMismatch {
84 collection: collection.to_string(),
85 expected: dimension,
86 actual: p.vector.len(),
87 });
88 }
89
90 let id_array = FixedSizeBinaryArray::try_from_iter(points.iter().map(|p| *p.id.as_bytes()))
91 .map_err(|e| VectorDBError::StorageError(format!("id column build: {e}")))?;
92
93 let vector_array = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
94 points
95 .iter()
96 .map(|p| Some(p.vector.iter().map(|v| Some(*v)).collect::<Vec<_>>())),
97 dimension as i32,
98 );
99
100 let metadata_array = StringArray::from(
101 points
102 .iter()
103 .map(|p| serde_json::to_string(&p.metadata))
104 .collect::<Result<Vec<_>, _>>()?,
105 );
106
107 RecordBatch::try_new(
108 schema,
109 vec![
110 Arc::new(id_array),
111 Arc::new(vector_array),
112 Arc::new(metadata_array),
113 ],
114 )
115 .map_err(|e| VectorDBError::StorageError(format!("record batch build: {e}")))
116}
117
118fn search_results_from_batches(batches: Vec<RecordBatch>) -> VectorDBResult<Vec<SearchResult>> {
119 let mut out = Vec::new();
120 for batch in batches {
121 let id_col = batch
122 .column_by_name("id")
123 .ok_or_else(|| VectorDBError::StorageError("missing id column".to_string()))?
124 .as_any()
125 .downcast_ref::<FixedSizeBinaryArray>()
126 .ok_or_else(|| VectorDBError::StorageError("id column type mismatch".to_string()))?;
127
128 let metadata_col = batch
129 .column_by_name("metadata")
130 .ok_or_else(|| VectorDBError::StorageError("missing metadata column".to_string()))?
131 .as_any()
132 .downcast_ref::<StringArray>()
133 .ok_or_else(|| {
134 VectorDBError::StorageError("metadata column type mismatch".to_string())
135 })?;
136
137 let distance_col = batch
141 .column_by_name("_distance")
142 .ok_or_else(|| VectorDBError::StorageError("missing _distance column".to_string()))?
143 .as_any()
144 .downcast_ref::<Float32Array>()
145 .ok_or_else(|| {
146 VectorDBError::StorageError("_distance column type mismatch".to_string())
147 })?;
148
149 for row in 0..batch.num_rows() {
150 let id_bytes = id_col.value(row);
151 let id = Uuid::from_slice(id_bytes)
152 .map_err(|e| VectorDBError::StorageError(format!("id is not a valid UUID: {e}")))?;
153
154 let metadata: HashMap<String, serde_json::Value> =
155 serde_json::from_str(metadata_col.value(row))?;
156
157 let distance = distance_col.value(row).max(0.0);
159 let score = (1.0 - distance).clamp(-1.0, 1.0);
160
161 out.push(SearchResult {
162 id,
163 score,
164 metadata,
165 });
166 }
167 }
168 Ok(out)
169}
170
/// [`VectorDB`] implementation backed by a LanceDB connection.
pub struct LanceDbAdapter {
    /// Connection obtained from `lancedb::connect`; every table operation goes through it.
    connection: Connection,
    /// Vector dimension per table name, recorded when a collection is created or when its
    /// schema is first read, so later calls can skip the schema lookup.
    dimensions: Arc<RwLock<HashMap<String, usize>>>,
}
178
179impl LanceDbAdapter {
180 pub async fn new(path: PathBuf) -> VectorDBResult<Self> {
182 if let Some(parent) = path.parent()
183 && !parent.as_os_str().is_empty()
184 {
185 std::fs::create_dir_all(parent)?;
186 }
187 let uri = path.to_str().ok_or_else(|| {
188 VectorDBError::StorageError(format!("lancedb path is not valid UTF-8: {path:?}"))
189 })?;
190 let connection = connect(uri).execute().await.map_err(map_lance_err)?;
191 Ok(Self {
192 connection,
193 dimensions: Arc::new(RwLock::new(HashMap::new())),
194 })
195 }
196
197 async fn cached_dimension(&self, table_name: &str) -> Option<usize> {
198 self.dimensions.read().await.get(table_name).copied()
199 }
200
201 async fn resolved_dimension(&self, table_name: &str) -> VectorDBResult<usize> {
202 if let Some(dim) = self.cached_dimension(table_name).await {
203 return Ok(dim);
204 }
205 let table = self
206 .connection
207 .open_table(table_name)
208 .execute()
209 .await
210 .map_err(|e| match e {
211 lancedb::Error::TableNotFound { .. } => {
212 VectorDBError::CollectionNotFound(table_name.to_string())
213 }
214 other => map_lance_err(other),
215 })?;
216 let schema = table.schema().await.map_err(map_lance_err)?;
217 let dim = dimension_from_schema(&schema).ok_or_else(|| {
218 VectorDBError::StorageError(format!(
219 "table '{table_name}' has no FixedSizeList<Float32, _> vector column"
220 ))
221 })?;
222 self.dimensions
223 .write()
224 .await
225 .insert(table_name.to_string(), dim);
226 Ok(dim)
227 }
228}
229
230#[async_trait]
231impl VectorDB for LanceDbAdapter {
232 async fn create_collection(
233 &self,
234 data_type: &str,
235 field_name: &str,
236 dimension: usize,
237 ) -> VectorDBResult<()> {
238 let name = collection_name(data_type, field_name);
239 if self.has_collection(data_type, field_name).await? {
240 return Ok(());
242 }
243 let schema = build_schema(dimension);
244 self.connection
245 .create_empty_table(&name, schema)
246 .execute()
247 .await
248 .map_err(map_lance_err)?;
249 self.dimensions.write().await.insert(name, dimension);
250 Ok(())
251 }
252
253 async fn has_collection(&self, data_type: &str, field_name: &str) -> VectorDBResult<bool> {
254 let target = collection_name(data_type, field_name);
255 let names = self
256 .connection
257 .table_names()
258 .execute()
259 .await
260 .map_err(map_lance_err)?;
261 Ok(names.iter().any(|n| n == &target))
262 }
263
264 async fn index_points(
265 &self,
266 data_type: &str,
267 field_name: &str,
268 points: &[VectorPoint],
269 ) -> VectorDBResult<()> {
270 if points.is_empty() {
271 return Ok(());
272 }
273 let name = collection_name(data_type, field_name);
274 let dimension = self.resolved_dimension(&name).await?;
275 let schema = build_schema(dimension);
276 let batch = points_to_batch(schema.clone(), dimension, &name, points)?;
277 let table = self
278 .connection
279 .open_table(&name)
280 .execute()
281 .await
282 .map_err(map_lance_err)?;
283 let id_values: Vec<String> = points
285 .iter()
286 .map(|p| {
287 let bytes = p.id.as_bytes();
288 let hex: String = bytes.iter().map(|b| format!("{b:02X}")).collect();
290 format!("X'{hex}'")
291 })
292 .collect();
293 if !id_values.is_empty() {
294 let predicate = format!("id IN ({})", id_values.join(", "));
295 table
296 .delete(predicate.as_str())
297 .await
298 .map_err(map_lance_err)?;
299 }
300 let _ = schema; table
302 .add(vec![batch])
303 .execute()
304 .await
305 .map_err(map_lance_err)?;
306 Ok(())
307 }
308
309 async fn search_similar(
310 &self,
311 data_type: &str,
312 field_name: &str,
313 query_vector: &[f32],
314 top_k: usize,
315 ) -> VectorDBResult<Vec<SearchResult>> {
316 let name = collection_name(data_type, field_name);
317 let table = self
318 .connection
319 .open_table(&name)
320 .execute()
321 .await
322 .map_err(|e| match e {
323 lancedb::Error::TableNotFound { .. } => {
324 VectorDBError::CollectionNotFound(name.clone())
325 }
326 other => map_lance_err(other),
327 })?;
328 let stream = table
329 .query()
330 .limit(top_k)
331 .nearest_to(query_vector)
332 .map_err(map_lance_err)?
333 .distance_type(DistanceType::Cosine)
334 .execute()
335 .await
336 .map_err(map_lance_err)?;
337 let batches: Vec<RecordBatch> = stream.try_collect().await.map_err(map_lance_err)?;
338 search_results_from_batches(batches)
339 }
340
341 async fn delete_collection(&self, data_type: &str, field_name: &str) -> VectorDBResult<()> {
342 let name = collection_name(data_type, field_name);
343 match self.connection.drop_table(&name, &[]).await {
344 Ok(()) => {
345 self.dimensions.write().await.remove(&name);
346 Ok(())
347 }
348 Err(lancedb::Error::TableNotFound { .. }) => Ok(()),
349 Err(other) => Err(map_lance_err(other)),
350 }
351 }
352
353 async fn delete_points(
354 &self,
355 data_type: &str,
356 field_name: &str,
357 point_ids: &[Uuid],
358 ) -> VectorDBResult<()> {
359 if point_ids.is_empty() {
360 return Ok(());
361 }
362 let name = collection_name(data_type, field_name);
363 let table = self
364 .connection
365 .open_table(&name)
366 .execute()
367 .await
368 .map_err(|e| match e {
369 lancedb::Error::TableNotFound { .. } => {
370 VectorDBError::CollectionNotFound(name.clone())
371 }
372 other => map_lance_err(other),
373 })?;
374 let id_values: Vec<String> = point_ids
375 .iter()
376 .map(|id| {
377 let hex: String = id.as_bytes().iter().map(|b| format!("{b:02X}")).collect();
378 format!("X'{hex}'")
379 })
380 .collect();
381 let predicate = format!("id IN ({})", id_values.join(", "));
382 table
383 .delete(predicate.as_str())
384 .await
385 .map_err(map_lance_err)?;
386 Ok(())
387 }
388
389 async fn collection_size(&self, data_type: &str, field_name: &str) -> VectorDBResult<usize> {
390 let name = collection_name(data_type, field_name);
391 let table = self
392 .connection
393 .open_table(&name)
394 .execute()
395 .await
396 .map_err(|e| match e {
397 lancedb::Error::TableNotFound { .. } => {
398 VectorDBError::CollectionNotFound(name.clone())
399 }
400 other => map_lance_err(other),
401 })?;
402 table.count_rows(None).await.map_err(map_lance_err)
403 }
404
405 async fn list_collections(&self) -> VectorDBResult<Vec<(String, String)>> {
406 let names = self
407 .connection
408 .table_names()
409 .execute()
410 .await
411 .map_err(map_lance_err)?;
412 Ok(names
413 .into_iter()
414 .filter_map(|n| {
415 n.find('_')
416 .map(|i| (n[..i].to_string(), n[i + 1..].to_string()))
417 })
418 .collect())
419 }
420}
421
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        reason = "test code — panics are acceptable"
    )]
    use super::*;
    use serde_json::json;
    use tempfile::tempdir;

    /// Data type of the collection most tests operate on.
    const CHUNK: &str = "Chunk";
    /// Field name of the collection most tests operate on.
    const TEXT: &str = "text";

    /// Builds a point whose only metadata entry is `kind`.
    fn point(id: Uuid, vector: Vec<f32>, kind: &str) -> VectorPoint {
        VectorPoint::new(id, vector).with_metadata("kind", json!(kind))
    }

    /// Opens an adapter inside a new temporary directory. The directory guard
    /// is returned so the store outlives the test body.
    async fn fresh_adapter() -> (LanceDbAdapter, tempfile::TempDir) {
        let scratch = tempdir().unwrap();
        let store_path = scratch.path().join("store.lance");
        let adapter = LanceDbAdapter::new(store_path).await.unwrap();
        (adapter, scratch)
    }

    /// Like [`fresh_adapter`], with the `Chunk`/`text` collection already
    /// created at the requested dimension.
    async fn adapter_with_collection(dimension: usize) -> (LanceDbAdapter, tempfile::TempDir) {
        let (adapter, scratch) = fresh_adapter().await;
        adapter
            .create_collection(CHUNK, TEXT, dimension)
            .await
            .unwrap();
        (adapter, scratch)
    }

    #[tokio::test]
    async fn create_and_has_collection_roundtrip() {
        let (adapter, _scratch) = fresh_adapter().await;
        let exists_before = adapter.has_collection(CHUNK, TEXT).await.unwrap();
        assert!(!exists_before);

        adapter.create_collection(CHUNK, TEXT, 4).await.unwrap();
        let exists_after = adapter.has_collection(CHUNK, TEXT).await.unwrap();
        assert!(exists_after);

        // Creating the same collection again must succeed.
        adapter.create_collection(CHUNK, TEXT, 4).await.unwrap();
    }

    #[tokio::test]
    async fn index_and_search_finds_closest_point() {
        let (adapter, _scratch) = adapter_with_collection(3).await;

        let target = Uuid::new_v4();
        let other = Uuid::new_v4();
        let batch = vec![
            point(target, vec![1.0, 0.0, 0.0], "target"),
            point(other, vec![0.0, 1.0, 0.0], "other"),
        ];
        adapter.index_points(CHUNK, TEXT, &batch).await.unwrap();

        let hits = adapter
            .search_similar(CHUNK, TEXT, &[1.0, 0.0, 0.0], 2)
            .await
            .unwrap();
        assert_eq!(hits.len(), 2);

        let best = &hits[0];
        assert_eq!(best.id, target, "nearest point should be the target");
        assert_eq!(best.metadata.get("kind").unwrap(), &json!("target"));
        assert!(best.score > 0.99);
    }

    #[tokio::test]
    async fn collection_size_reports_row_count() {
        let (adapter, _scratch) = adapter_with_collection(2).await;
        let batch = vec![
            point(Uuid::new_v4(), vec![0.0, 1.0], "a"),
            point(Uuid::new_v4(), vec![1.0, 0.0], "b"),
        ];
        adapter.index_points(CHUNK, TEXT, &batch).await.unwrap();

        let size = adapter.collection_size(CHUNK, TEXT).await.unwrap();
        assert_eq!(size, 2);
    }

    #[tokio::test]
    async fn delete_points_removes_by_id() {
        let (adapter, _scratch) = adapter_with_collection(2).await;
        let survivor = Uuid::new_v4();
        let doomed = Uuid::new_v4();
        let batch = [
            point(survivor, vec![1.0, 0.0], "keep"),
            point(doomed, vec![0.0, 1.0], "drop"),
        ];
        adapter.index_points(CHUNK, TEXT, &batch).await.unwrap();

        adapter
            .delete_points(CHUNK, TEXT, &[doomed])
            .await
            .unwrap();

        let size = adapter.collection_size(CHUNK, TEXT).await.unwrap();
        assert_eq!(size, 1);

        let hits = adapter
            .search_similar(CHUNK, TEXT, &[0.0, 1.0], 5)
            .await
            .unwrap();
        assert!(hits.iter().all(|hit| hit.id != doomed));
    }

    #[tokio::test]
    async fn index_points_replaces_existing_id() {
        let (adapter, _scratch) = adapter_with_collection(2).await;
        let id = Uuid::new_v4();

        let first_version = [point(id, vec![1.0, 0.0], "v1")];
        adapter
            .index_points(CHUNK, TEXT, &first_version)
            .await
            .unwrap();
        let second_version = [point(id, vec![0.0, 1.0], "v2")];
        adapter
            .index_points(CHUNK, TEXT, &second_version)
            .await
            .unwrap();

        let size = adapter.collection_size(CHUNK, TEXT).await.unwrap();
        assert_eq!(size, 1);

        let hits = adapter
            .search_similar(CHUNK, TEXT, &[0.0, 1.0], 1)
            .await
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, id);
        assert_eq!(hits[0].metadata.get("kind").unwrap(), &json!("v2"));
    }

    #[tokio::test]
    async fn delete_collection_drops_table_and_is_idempotent() {
        let (adapter, _scratch) = adapter_with_collection(2).await;
        assert!(adapter.has_collection(CHUNK, TEXT).await.unwrap());

        adapter.delete_collection(CHUNK, TEXT).await.unwrap();
        assert!(!adapter.has_collection(CHUNK, TEXT).await.unwrap());

        // Deleting an already-missing collection must also succeed.
        adapter.delete_collection(CHUNK, TEXT).await.unwrap();
    }

    #[tokio::test]
    async fn list_and_prune_collections() {
        let (adapter, _scratch) = adapter_with_collection(2).await;
        adapter
            .create_collection("Entity", "name", 2)
            .await
            .unwrap();

        let mut listed: Vec<_> = adapter.list_collections().await.unwrap();
        listed.sort();
        let expected = vec![
            ("Chunk".to_string(), "text".to_string()),
            ("Entity".to_string(), "name".to_string()),
        ];
        assert_eq!(listed, expected);

        adapter.prune().await.unwrap();
        let remaining = adapter.list_collections().await.unwrap();
        assert_eq!(remaining.len(), 0);
    }

    #[tokio::test]
    async fn dimension_mismatch_returns_error() {
        let (adapter, _scratch) = adapter_with_collection(3).await;
        let too_short = [point(Uuid::new_v4(), vec![1.0, 0.0], "bad")];

        let err = adapter
            .index_points(CHUNK, TEXT, &too_short)
            .await
            .unwrap_err();

        let is_expected_mismatch = matches!(
            err,
            VectorDBError::DimensionMismatch {
                expected: 3,
                actual: 2,
                ..
            }
        );
        assert!(
            is_expected_mismatch,
            "expected DimensionMismatch, got {err:?}"
        );
    }

    #[tokio::test]
    async fn store_persists_across_reopen() {
        let scratch = tempdir().unwrap();
        let store_path = scratch.path().join("persist.lance");
        let id = Uuid::new_v4();

        // First session: write one point, then drop the adapter.
        {
            let writer = LanceDbAdapter::new(store_path.clone()).await.unwrap();
            writer.create_collection(CHUNK, TEXT, 2).await.unwrap();
            let batch = [point(id, vec![1.0, 0.0], "v1")];
            writer.index_points(CHUNK, TEXT, &batch).await.unwrap();
        }

        // Second session: the collection and the point must still be there.
        let reader = LanceDbAdapter::new(store_path).await.unwrap();
        assert!(reader.has_collection(CHUNK, TEXT).await.unwrap());

        let hits = reader
            .search_similar(CHUNK, TEXT, &[1.0, 0.0], 1)
            .await
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, id);
    }
}