1use std::path::Path;
2use std::sync::Arc;
3
4use arrow_array::{
5 Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
6 UInt64Array, cast::AsArray, types::Float32Type,
7};
8use arrow_schema::{DataType, Field, Schema};
9use futures::TryStreamExt;
10use lancedb::database::CreateTableMode;
11use lancedb::index::scalar::FullTextSearchQuery;
12use lancedb::index::{Index, scalar::FtsIndexBuilder};
13use lancedb::query::{ExecutableQuery, QueryBase, Select};
14use lancedb::{Connection, Error, Result, Table, connect};
15
/// Name of the table that stores one row per document (`document_id`, `content`).
const DOCUMENTS_TABLE_NAME: &str = "documents";
/// Name of the table that stores one row per chunk, including its vector.
const CHUNKS_TABLE_NAME: &str = "chunks";
/// Row-count threshold for the vector index: while the chunks table holds fewer
/// rows than this, `ensure_chunks_vector_index` skips index creation.
const MIN_VECTOR_INDEX_ROWS: usize = 256;
19
/// A full document as stored in the documents table.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DocumentRecord {
    /// Identifier used to match the document with its chunks and in upserts/deletes.
    pub document_id: String,
    /// Complete text content of the document.
    pub content: String,
}
25
/// One chunk of a document together with its embedding vector.
///
/// Derives `Clone`, `Debug` and `PartialEq` so it can be logged, copied and
/// compared like the other record types in this module (`Eq` is not derivable
/// because the vector holds `f32` values).
#[derive(Clone, Debug, PartialEq)]
pub struct Chunk {
    /// Identifier of the document this chunk belongs to.
    pub document_id: String,
    /// Position of the chunk within its document; merged on together with `document_id`.
    pub chunk_index: u64,
    /// Text of the chunk (the column used for keyword search).
    pub text: String,
    /// Embedding of the chunk; its length must equal the backend's `vector_dimensions`.
    pub vector: Vec<f32>,
}
32
/// One hit returned by `LanceDbBackend::vector_search`.
#[derive(Clone, Debug, PartialEq)]
pub struct ChunkSearchRecord {
    /// Identifier of the document the matching chunk belongs to.
    pub document_id: String,
    /// Text of the matching chunk.
    pub text: String,
    /// Value of the `_distance` column produced by the vector query.
    pub distance: f32,
}
39
/// One hit returned by `LanceDbBackend::keyword_search`.
#[derive(Clone, Debug, PartialEq)]
pub struct ChunkKeywordSearchRecord {
    /// Identifier of the document the matching chunk belongs to.
    pub document_id: String,
    /// Text of the matching chunk.
    pub text: String,
    /// Value of the `_score` column produced by the full-text query.
    pub score: f32,
}
46
/// Storage backend that keeps documents and their chunks in two LanceDB tables.
pub struct LanceDbBackend {
    /// Database connection used to create and open the tables.
    connection: Connection,
    /// Length of every chunk vector; used as the fixed-size-list width of the
    /// `vector` column and checked to be positive in `new`.
    vector_dimensions: i32,
}
51
52impl LanceDbBackend {
53 pub async fn new(path: impl AsRef<Path>, vector_dimensions: i32) -> Result<Self> {
54 if vector_dimensions <= 0 {
55 return Err(Error::InvalidInput {
56 message: "vector_dimensions must be greater than zero".to_string(),
57 });
58 }
59
60 let uri = path.as_ref().to_string_lossy();
61 let connection = connect(uri.as_ref()).execute().await?;
62
63 Ok(Self {
64 connection,
65 vector_dimensions,
66 })
67 }
68
69 pub async fn create_tables(&self) -> Result<()> {
70 self.create_documents_table().await?;
71 self.create_chunks_table().await?;
72 Ok(())
73 }
74
75 pub async fn create_documents_table(&self) -> Result<Table> {
76 self.connection
77 .create_empty_table(DOCUMENTS_TABLE_NAME, self.documents_schema())
78 .mode(CreateTableMode::exist_ok(|request| request))
79 .execute()
80 .await
81 }
82
83 pub async fn create_chunks_table(&self) -> Result<Table> {
84 self.connection
85 .create_empty_table(CHUNKS_TABLE_NAME, self.chunks_schema())
86 .mode(CreateTableMode::exist_ok(|request| request))
87 .execute()
88 .await
89 }
90
91 pub async fn insert_data(&self, documents: &[DocumentRecord], chunks: &[Chunk]) -> Result<()> {
92 let documents_batch = self.documents_batch(documents)?;
93 let chunks_batch = self.chunks_batch(chunks)?;
94
95 let documents_table = self
96 .connection
97 .open_table(DOCUMENTS_TABLE_NAME)
98 .execute()
99 .await?;
100 let chunks_table = self
101 .connection
102 .open_table(CHUNKS_TABLE_NAME)
103 .execute()
104 .await?;
105
106 chunks_table.add(chunks_batch).execute().await?;
107 if let Err(error) = documents_table.add(documents_batch).execute().await {
108 self.delete_chunks_for_documents(&chunks_table, documents)
109 .await;
110 return Err(error);
111 }
112 self.ensure_chunks_indices(&chunks_table).await?;
113
114 Ok(())
115 }
116
117 pub async fn upsert_data(&self, document: &DocumentRecord, chunks: &[Chunk]) -> Result<()> {
118 let documents_batch = self.documents_batch(std::slice::from_ref(document))?;
119 let chunks_batch = self.chunks_batch(chunks)?;
120 let previous_chunks = self.chunks_for_document(&document.document_id).await?;
121 let documents_table = self
122 .connection
123 .open_table(DOCUMENTS_TABLE_NAME)
124 .execute()
125 .await?;
126 let chunks_table = self
127 .connection
128 .open_table(CHUNKS_TABLE_NAME)
129 .execute()
130 .await?;
131
132 self.merge_replace_document_chunks(&chunks_table, &document.document_id, chunks_batch)
133 .await?;
134 if let Err(error) = self
135 .merge_upsert_documents(&documents_table, documents_batch)
136 .await
137 {
138 let _ = self
139 .restore_document_chunks(&chunks_table, &document.document_id, previous_chunks)
140 .await;
141 return Err(error);
142 }
143
144 self.ensure_chunks_indices(&chunks_table).await?;
145
146 Ok(())
147 }
148
149 pub async fn vector_search(
150 &self,
151 query_vector: Vec<f32>,
152 limit: usize,
153 ) -> Result<Vec<ChunkSearchRecord>> {
154 let table = self
155 .connection
156 .open_table(CHUNKS_TABLE_NAME)
157 .execute()
158 .await?;
159 let rows = table
160 .query()
161 .nearest_to(query_vector)?
162 .column("vector")
163 .limit(limit)
164 .select(Select::columns(&["document_id", "text", "_distance"]))
165 .execute()
166 .await?;
167 let batches = rows.try_collect::<Vec<_>>().await?;
168
169 Ok(chunk_search_records_from_batches(&batches))
170 }
171
172 pub async fn keyword_search(
173 &self,
174 query: String,
175 limit: usize,
176 ) -> Result<Vec<ChunkKeywordSearchRecord>> {
177 let table = self
178 .connection
179 .open_table(CHUNKS_TABLE_NAME)
180 .execute()
181 .await?;
182 let full_text_query = FullTextSearchQuery::new(query).with_column("text".to_string())?;
183 let rows = table
184 .query()
185 .full_text_search(full_text_query)
186 .limit(limit)
187 .select(Select::columns(&["document_id", "text", "_score"]))
188 .execute()
189 .await?;
190 let batches = rows.try_collect::<Vec<_>>().await?;
191
192 Ok(chunk_keyword_search_records_from_batches(&batches))
193 }
194
195 pub async fn list_documents(&self) -> Result<Vec<DocumentRecord>> {
196 let table = self
197 .connection
198 .open_table(DOCUMENTS_TABLE_NAME)
199 .execute()
200 .await?;
201 let rows = table
202 .query()
203 .select(Select::columns(&["document_id", "content"]))
204 .execute()
205 .await?;
206 let batches = rows.try_collect::<Vec<_>>().await?;
207 let mut documents = document_records_from_batches(&batches);
208
209 documents.sort_by(|left, right| left.document_id.cmp(&right.document_id));
210 Ok(documents)
211 }
212
213 pub async fn get_document(&self, document_id: &str) -> Result<Option<DocumentRecord>> {
214 let table = self
215 .connection
216 .open_table(DOCUMENTS_TABLE_NAME)
217 .execute()
218 .await?;
219 let rows = table
220 .query()
221 .only_if(document_id_predicate(document_id))
222 .select(Select::columns(&["document_id", "content"]))
223 .limit(1)
224 .execute()
225 .await?;
226 let batches = rows.try_collect::<Vec<_>>().await?;
227
228 Ok(document_records_from_batches(&batches).into_iter().next())
229 }
230
231 pub async fn delete_document(&self, document_id: &str) -> Result<()> {
232 let predicate = document_id_predicate(document_id);
233 let documents_table = self
234 .connection
235 .open_table(DOCUMENTS_TABLE_NAME)
236 .execute()
237 .await?;
238 let chunks_table = self
239 .connection
240 .open_table(CHUNKS_TABLE_NAME)
241 .execute()
242 .await?;
243
244 documents_table.delete(&predicate).await?;
245 chunks_table.delete(&predicate).await?;
246
247 Ok(())
248 }
249
250 pub fn connection(&self) -> &Connection {
251 &self.connection
252 }
253
254 pub fn vector_dimensions(&self) -> i32 {
255 self.vector_dimensions
256 }
257
258 async fn ensure_chunks_vector_index(&self, chunks_table: &Table) -> Result<()> {
259 let indices = chunks_table.list_indices().await?;
260 if indices
261 .iter()
262 .any(|index| index.columns == vec!["vector".to_string()])
263 {
264 return Ok(());
265 }
266 if chunks_table.count_rows(None).await? < MIN_VECTOR_INDEX_ROWS {
267 return Ok(());
268 }
269
270 chunks_table
271 .create_index(&["vector"], Index::Auto)
272 .execute()
273 .await?;
274 Ok(())
275 }
276
277 async fn ensure_chunks_keyword_index(&self, chunks_table: &Table) -> Result<()> {
278 let indices = chunks_table.list_indices().await?;
279 if indices
280 .iter()
281 .any(|index| index.columns == vec!["text".to_string()])
282 {
283 return Ok(());
284 }
285
286 chunks_table
287 .create_index(&["text"], Index::FTS(FtsIndexBuilder::default()))
288 .execute()
289 .await?;
290 Ok(())
291 }
292
293 fn documents_schema(&self) -> Arc<Schema> {
294 Arc::new(Schema::new(vec![
295 Field::new("document_id", DataType::Utf8, false),
296 Field::new("content", DataType::Utf8, false),
297 ]))
298 }
299
300 fn chunks_schema(&self) -> Arc<Schema> {
301 Arc::new(Schema::new(vec![
302 Field::new("document_id", DataType::Utf8, false),
303 Field::new("chunk_index", DataType::UInt64, false),
304 Field::new("text", DataType::Utf8, false),
305 Field::new(
306 "vector",
307 DataType::FixedSizeList(
308 Arc::new(Field::new("item", DataType::Float32, true)),
309 self.vector_dimensions,
310 ),
311 false,
312 ),
313 ]))
314 }
315
316 fn documents_batch(&self, data: &[DocumentRecord]) -> Result<RecordBatch> {
317 let document_id_values = Arc::new(StringArray::from_iter_values(
318 data.iter().map(|document| document.document_id.as_str()),
319 ));
320 let content_values = Arc::new(StringArray::from_iter_values(
321 data.iter().map(|document| document.content.as_str()),
322 ));
323
324 Ok(RecordBatch::try_new(
325 self.documents_schema(),
326 vec![document_id_values, content_values],
327 )?)
328 }
329
330 fn chunks_batch(&self, data: &[Chunk]) -> Result<RecordBatch> {
331 let expected_dimensions = self.vector_dimensions as usize;
332 for chunk in data {
333 if chunk.vector.len() != expected_dimensions {
334 return Err(Error::InvalidInput {
335 message: format!(
336 "chunk vector has dimension {}, expected {}",
337 chunk.vector.len(),
338 expected_dimensions
339 ),
340 });
341 }
342 }
343
344 let document_id_values = Arc::new(StringArray::from_iter_values(
345 data.iter().map(|chunk| chunk.document_id.as_str()),
346 ));
347 let chunk_index_values = Arc::new(UInt64Array::from_iter_values(
348 data.iter().map(|chunk| chunk.chunk_index),
349 ));
350 let text_values = Arc::new(StringArray::from_iter_values(
351 data.iter().map(|chunk| chunk.text.as_str()),
352 ));
353 let vector_values = Arc::new(
354 FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
355 data.iter()
356 .map(|chunk| Some(chunk.vector.iter().copied().map(Some))),
357 self.vector_dimensions,
358 ),
359 );
360
361 Ok(RecordBatch::try_new(
362 self.chunks_schema(),
363 vec![
364 document_id_values,
365 chunk_index_values,
366 text_values,
367 vector_values,
368 ],
369 )?)
370 }
371
372 async fn merge_upsert_documents(&self, table: &Table, batch: RecordBatch) -> Result<()> {
373 let mut merge = table.merge_insert(&["document_id"]);
374 merge
375 .when_matched_update_all(None)
376 .when_not_matched_insert_all();
377 merge.execute(record_batch_reader(batch)).await?;
378 Ok(())
379 }
380
381 async fn merge_replace_document_chunks(
382 &self,
383 table: &Table,
384 document_id: &str,
385 batch: RecordBatch,
386 ) -> Result<()> {
387 let mut merge = table.merge_insert(&["document_id", "chunk_index"]);
388 merge
389 .when_matched_update_all(None)
390 .when_not_matched_insert_all()
391 .when_not_matched_by_source_delete(Some(document_id_predicate(document_id)));
392 merge.execute(record_batch_reader(batch)).await?;
393 Ok(())
394 }
395
396 async fn restore_document_chunks(
397 &self,
398 table: &Table,
399 document_id: &str,
400 chunks: Vec<Chunk>,
401 ) -> Result<()> {
402 if chunks.is_empty() {
403 table.delete(&document_id_predicate(document_id)).await?;
404 return Ok(());
405 }
406
407 let batch = self.chunks_batch(&chunks)?;
408 self.merge_replace_document_chunks(table, document_id, batch)
409 .await
410 }
411
412 async fn chunks_for_document(&self, document_id: &str) -> Result<Vec<Chunk>> {
413 let table = self
414 .connection
415 .open_table(CHUNKS_TABLE_NAME)
416 .execute()
417 .await?;
418 let rows = table
419 .query()
420 .only_if(document_id_predicate(document_id))
421 .select(Select::columns(&[
422 "document_id",
423 "chunk_index",
424 "text",
425 "vector",
426 ]))
427 .execute()
428 .await?;
429 let batches = rows.try_collect::<Vec<_>>().await?;
430
431 Ok(chunks_from_batches(&batches))
432 }
433
434 async fn delete_chunks_for_documents(&self, table: &Table, documents: &[DocumentRecord]) {
435 for document in documents {
436 let _ = table
437 .delete(&document_id_predicate(&document.document_id))
438 .await;
439 }
440 }
441
442 async fn ensure_chunks_indices(&self, chunks_table: &Table) -> Result<()> {
443 self.ensure_chunks_vector_index(chunks_table).await?;
444 self.ensure_chunks_keyword_index(chunks_table).await
445 }
446}
447
448fn record_batch_reader(batch: RecordBatch) -> Box<dyn arrow_array::RecordBatchReader + Send> {
449 let schema = batch.schema();
450 Box::new(RecordBatchIterator::new(
451 vec![Ok(batch)].into_iter(),
452 schema,
453 ))
454}
455
456fn document_records_from_batches(batches: &[RecordBatch]) -> Vec<DocumentRecord> {
457 batches
458 .iter()
459 .flat_map(|batch| {
460 let document_ids = batch
461 .column_by_name("document_id")
462 .expect("documents query should include document_id")
463 .as_any()
464 .downcast_ref::<StringArray>()
465 .expect("document_id column should be Utf8");
466 let contents = batch
467 .column_by_name("content")
468 .expect("documents query should include content")
469 .as_any()
470 .downcast_ref::<StringArray>()
471 .expect("content column should be Utf8");
472
473 (0..batch.num_rows())
474 .map(|index| DocumentRecord {
475 document_id: document_ids.value(index).to_string(),
476 content: contents.value(index).to_string(),
477 })
478 .collect::<Vec<_>>()
479 })
480 .collect()
481}
482
483fn chunk_search_records_from_batches(batches: &[RecordBatch]) -> Vec<ChunkSearchRecord> {
484 batches
485 .iter()
486 .flat_map(|batch| {
487 let document_ids = batch
488 .column_by_name("document_id")
489 .expect("chunks query should include document_id")
490 .as_any()
491 .downcast_ref::<StringArray>()
492 .expect("document_id column should be Utf8");
493 let texts = batch
494 .column_by_name("text")
495 .expect("chunks query should include text")
496 .as_any()
497 .downcast_ref::<StringArray>()
498 .expect("text column should be Utf8");
499 let distances = batch
500 .column_by_name("_distance")
501 .expect("chunks query should include _distance")
502 .as_primitive::<Float32Type>();
503
504 (0..batch.num_rows())
505 .map(|index| ChunkSearchRecord {
506 document_id: document_ids.value(index).to_string(),
507 text: texts.value(index).to_string(),
508 distance: distances.value(index),
509 })
510 .collect::<Vec<_>>()
511 })
512 .collect()
513}
514
515fn chunk_keyword_search_records_from_batches(
516 batches: &[RecordBatch],
517) -> Vec<ChunkKeywordSearchRecord> {
518 batches
519 .iter()
520 .flat_map(|batch| {
521 let document_ids = batch
522 .column_by_name("document_id")
523 .expect("chunks query should include document_id")
524 .as_any()
525 .downcast_ref::<StringArray>()
526 .expect("document_id column should be Utf8");
527 let texts = batch
528 .column_by_name("text")
529 .expect("chunks query should include text")
530 .as_any()
531 .downcast_ref::<StringArray>()
532 .expect("text column should be Utf8");
533 let scores = batch
534 .column_by_name("_score")
535 .expect("chunks query should include _score")
536 .as_primitive::<Float32Type>();
537
538 (0..batch.num_rows())
539 .map(|index| ChunkKeywordSearchRecord {
540 document_id: document_ids.value(index).to_string(),
541 text: texts.value(index).to_string(),
542 score: scores.value(index),
543 })
544 .collect::<Vec<_>>()
545 })
546 .collect()
547}
548
549fn chunks_from_batches(batches: &[RecordBatch]) -> Vec<Chunk> {
550 batches
551 .iter()
552 .flat_map(|batch| {
553 let document_ids = batch
554 .column_by_name("document_id")
555 .expect("chunks query should include document_id")
556 .as_any()
557 .downcast_ref::<StringArray>()
558 .expect("document_id column should be Utf8");
559 let chunk_indices = batch
560 .column_by_name("chunk_index")
561 .expect("chunks query should include chunk_index")
562 .as_any()
563 .downcast_ref::<UInt64Array>()
564 .expect("chunk_index column should be UInt64");
565 let texts = batch
566 .column_by_name("text")
567 .expect("chunks query should include text")
568 .as_any()
569 .downcast_ref::<StringArray>()
570 .expect("text column should be Utf8");
571 let vectors = batch
572 .column_by_name("vector")
573 .expect("chunks query should include vector")
574 .as_any()
575 .downcast_ref::<FixedSizeListArray>()
576 .expect("vector column should be FixedSizeList");
577
578 (0..batch.num_rows())
579 .map(|index| Chunk {
580 document_id: document_ids.value(index).to_string(),
581 chunk_index: chunk_indices.value(index),
582 text: texts.value(index).to_string(),
583 vector: vectors
584 .value(index)
585 .as_any()
586 .downcast_ref::<Float32Array>()
587 .expect("vector item column should be Float32")
588 .values()
589 .to_vec(),
590 })
591 .collect::<Vec<_>>()
592 })
593 .collect()
594}
595
/// Builds the filter expression `document_id = '<id>'`, doubling every
/// single quote in `document_id` so the value stays inside the literal.
fn document_id_predicate(document_id: &str) -> String {
    let mut escaped = String::with_capacity(document_id.len());
    for character in document_id.chars() {
        if character == '\'' {
            escaped.push_str("''");
        } else {
            escaped.push(character);
        }
    }

    format!("document_id = '{}'", escaped)
}
599
#[cfg(test)]
mod tests {
    //! Integration-style tests that run the backend against a temporary
    //! on-disk database.

    use super::{CHUNKS_TABLE_NAME, Chunk, DOCUMENTS_TABLE_NAME, DocumentRecord, LanceDbBackend};
    use arrow_array::{StringArray, UInt64Array};
    use futures::TryStreamExt;
    use lancedb::query::ExecutableQuery;
    use lancedb::{Error, Table};
    use tempfile::TempDir;

    /// Builds a document record from borrowed strings.
    fn document(document_id: &str, content: &str) -> DocumentRecord {
        DocumentRecord {
            document_id: document_id.to_string(),
            content: content.to_string(),
        }
    }

    /// Builds a chunk from borrowed data.
    fn chunk(document_id: &str, chunk_index: u64, text: &str, vector: &[f32]) -> Chunk {
        Chunk {
            document_id: document_id.to_string(),
            chunk_index,
            text: text.to_string(),
            vector: vector.to_vec(),
        }
    }

    fn demo_document() -> DocumentRecord {
        document("demo-doc", "knight ranger priest rogue")
    }

    fn demo_chunks() -> Vec<Chunk> {
        vec![
            chunk("demo-doc", 0, "knight", &[0.9, 0.4, 0.8]),
            chunk("demo-doc", 1, "ranger", &[0.8, 0.4, 0.7]),
            chunk("demo-doc", 2, "priest", &[0.6, 0.2, 0.6]),
            chunk("demo-doc", 3, "rogue", &[0.7, 0.4, 0.7]),
        ]
    }

    fn five_dimensional_chunks() -> Vec<Chunk> {
        vec![
            chunk("five-dim-doc", 0, "mage", &[0.1, 0.2, 0.3, 0.4, 0.5]),
            chunk("five-dim-doc", 1, "paladin", &[0.5, 0.4, 0.3, 0.2, 0.1]),
        ]
    }

    /// Creates a backend in a fresh temporary directory with both tables
    /// created. The `TempDir` is returned so the directory outlives the test.
    async fn backend_with_tables(vector_dimensions: i32) -> (TempDir, LanceDbBackend) {
        let temp_dir = tempfile::tempdir().unwrap();
        let backend = LanceDbBackend::new(temp_dir.path(), vector_dimensions)
            .await
            .unwrap();
        backend.create_tables().await.unwrap();
        (temp_dir, backend)
    }

    /// Creates a three-dimensional backend pre-populated with the demo data.
    async fn backend_with_demo_data() -> (TempDir, LanceDbBackend) {
        let (temp_dir, backend) = backend_with_tables(3).await;
        backend
            .insert_data(&[demo_document()], &demo_chunks())
            .await
            .unwrap();
        (temp_dir, backend)
    }

    async fn open_table(backend: &LanceDbBackend, name: &str) -> Table {
        backend
            .connection()
            .open_table(name)
            .execute()
            .await
            .unwrap()
    }

    async fn row_count(backend: &LanceDbBackend, name: &str) -> usize {
        open_table(backend, name)
            .await
            .count_rows(None)
            .await
            .unwrap()
    }

    /// Reads all chunk texts ordered by `chunk_index`, so assertions do not
    /// depend on the order in which the table scan returns rows.
    async fn chunk_texts_by_index(backend: &LanceDbBackend) -> Vec<String> {
        let table = open_table(backend, CHUNKS_TABLE_NAME).await;
        let rows = table.query().execute().await.unwrap();
        let batches = rows.try_collect::<Vec<_>>().await.unwrap();

        let mut indexed_texts = Vec::new();
        for batch in &batches {
            let indices = batch
                .column_by_name("chunk_index")
                .unwrap()
                .as_any()
                .downcast_ref::<UInt64Array>()
                .unwrap();
            let texts = batch
                .column_by_name("text")
                .unwrap()
                .as_any()
                .downcast_ref::<StringArray>()
                .unwrap();
            for row in 0..batch.num_rows() {
                indexed_texts.push((indices.value(row), texts.value(row).to_string()));
            }
        }

        indexed_texts.sort();
        indexed_texts.into_iter().map(|(_, text)| text).collect()
    }

    #[tokio::test]
    async fn initializes_lancedb_from_local_path() {
        let temp_dir = tempfile::tempdir().unwrap();

        let backend = LanceDbBackend::new(temp_dir.path(), 3).await;

        assert!(backend.is_ok());
    }

    #[tokio::test]
    async fn rejects_non_positive_vector_dimensions() {
        let temp_dir = tempfile::tempdir().unwrap();

        for dimensions in [0, -1] {
            let result = LanceDbBackend::new(temp_dir.path(), dimensions).await;

            assert!(matches!(result, Err(Error::InvalidInput { .. })));
        }
    }

    #[tokio::test]
    async fn creates_empty_tables() {
        let (_temp_dir, backend) = backend_with_tables(5).await;

        assert_eq!(backend.vector_dimensions(), 5);
        assert_eq!(row_count(&backend, DOCUMENTS_TABLE_NAME).await, 0);
        assert_eq!(row_count(&backend, CHUNKS_TABLE_NAME).await, 0);
    }

    #[tokio::test]
    async fn inserts_demo_rows_into_tables() {
        let (_temp_dir, backend) = backend_with_demo_data().await;

        assert_eq!(row_count(&backend, DOCUMENTS_TABLE_NAME).await, 1);
        assert_eq!(row_count(&backend, CHUNKS_TABLE_NAME).await, 4);
    }

    #[tokio::test]
    async fn create_tables_is_idempotent() {
        let (_temp_dir, backend) = backend_with_demo_data().await;

        backend.create_tables().await.unwrap();

        assert_eq!(row_count(&backend, CHUNKS_TABLE_NAME).await, 4);
    }

    #[tokio::test]
    async fn inserts_matching_non_default_dimensions() {
        let (_temp_dir, backend) = backend_with_tables(5).await;

        backend
            .insert_data(
                &[document("five-dim-doc", "mage paladin")],
                &five_dimensional_chunks(),
            )
            .await
            .unwrap();

        assert_eq!(row_count(&backend, CHUNKS_TABLE_NAME).await, 2);
    }

    #[tokio::test]
    async fn rejects_mismatched_vector_dimensions() {
        let (_temp_dir, backend) = backend_with_tables(5).await;

        let error = backend
            .insert_data(&[demo_document()], &demo_chunks())
            .await
            .unwrap_err();

        assert!(matches!(error, Error::InvalidInput { .. }));
    }

    #[tokio::test]
    async fn vector_search_returns_ranked_chunk_records() {
        let (_temp_dir, backend) = backend_with_demo_data().await;

        let results = backend.vector_search(vec![0.9, 0.4, 0.8], 2).await.unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].document_id, "demo-doc");
        assert_eq!(results[0].text, "knight");
        assert_eq!(results[0].distance, 0.0);
    }

    #[tokio::test]
    async fn keyword_search_returns_ranked_chunk_records() {
        let (_temp_dir, backend) = backend_with_tables(3).await;
        backend
            .insert_data(
                &[document("search-doc", "rust database rust search ranger")],
                &[
                    chunk("search-doc", 0, "rust database rust search", &[0.1, 0.2, 0.3]),
                    chunk("search-doc", 1, "ranger path", &[0.4, 0.5, 0.6]),
                ],
            )
            .await
            .unwrap();

        let results = backend.keyword_search("rust".to_string(), 1).await.unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].document_id, "search-doc");
        assert_eq!(results[0].text, "rust database rust search");
        assert!(results[0].score > 0.0);
    }

    #[tokio::test]
    async fn keyword_search_returns_empty_for_missing_terms() {
        let (_temp_dir, backend) = backend_with_demo_data().await;

        let results = backend
            .keyword_search("warlock".to_string(), 10)
            .await
            .unwrap();

        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn upsert_replaces_rows_for_document_id() {
        let (_temp_dir, backend) = backend_with_demo_data().await;

        backend
            .upsert_data(
                &document("demo-doc", "replacement next"),
                &[
                    chunk("demo-doc", 0, "replacement", &[0.1, 0.2, 0.3]),
                    chunk("demo-doc", 1, "next", &[0.4, 0.5, 0.6]),
                ],
            )
            .await
            .unwrap();

        assert_eq!(row_count(&backend, CHUNKS_TABLE_NAME).await, 2);
        assert_eq!(
            chunk_texts_by_index(&backend).await,
            vec!["replacement".to_string(), "next".to_string()]
        );
        assert_eq!(
            backend.get_document("demo-doc").await.unwrap(),
            Some(document("demo-doc", "replacement next"))
        );
    }

    #[tokio::test]
    async fn upsert_preserves_repeated_chunk_text_by_index() {
        let (_temp_dir, backend) = backend_with_demo_data().await;

        backend
            .upsert_data(
                &document("demo-doc", "repeat repeat"),
                &[
                    chunk("demo-doc", 0, "repeat", &[0.1, 0.2, 0.3]),
                    chunk("demo-doc", 1, "repeat", &[0.4, 0.5, 0.6]),
                ],
            )
            .await
            .unwrap();

        assert_eq!(
            chunk_texts_by_index(&backend).await,
            vec!["repeat".to_string(), "repeat".to_string()]
        );
    }

    #[tokio::test]
    async fn upsert_escapes_document_id_predicate() {
        let quoted_id = "doc-'quoted'";
        let (_temp_dir, backend) = backend_with_tables(3).await;
        backend
            .insert_data(
                &[document(quoted_id, "old")],
                &[chunk(quoted_id, 0, "old", &[0.1, 0.2, 0.3])],
            )
            .await
            .unwrap();

        backend
            .upsert_data(
                &document(quoted_id, "new"),
                &[chunk(quoted_id, 0, "new", &[0.4, 0.5, 0.6])],
            )
            .await
            .unwrap();

        assert_eq!(row_count(&backend, CHUNKS_TABLE_NAME).await, 1);
        assert_eq!(
            backend.get_document(quoted_id).await.unwrap(),
            Some(document(quoted_id, "new"))
        );
    }

    #[tokio::test]
    async fn lists_documents_sorted_by_document_id() {
        let (_temp_dir, backend) = backend_with_tables(3).await;
        backend
            .insert_data(
                &[document("b-doc", "second"), document("a-doc", "first")],
                &[
                    chunk("b-doc", 0, "second", &[0.1, 0.2, 0.3]),
                    chunk("a-doc", 0, "first", &[0.4, 0.5, 0.6]),
                ],
            )
            .await
            .unwrap();

        let documents = backend.list_documents().await.unwrap();

        assert_eq!(
            documents,
            vec![document("a-doc", "first"), document("b-doc", "second")]
        );
    }

    #[tokio::test]
    async fn delete_document_removes_document_and_chunks() {
        let (_temp_dir, backend) = backend_with_demo_data().await;

        backend.delete_document("demo-doc").await.unwrap();

        assert_eq!(row_count(&backend, DOCUMENTS_TABLE_NAME).await, 0);
        assert_eq!(row_count(&backend, CHUNKS_TABLE_NAME).await, 0);
    }
}