Skip to main content

retrieval_kit/backends/
lancedb.rs

1use std::path::Path;
2use std::sync::Arc;
3
4use arrow_array::{
5    Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
6    UInt64Array, cast::AsArray, types::Float32Type,
7};
8use arrow_schema::{DataType, Field, Schema};
9use futures::TryStreamExt;
10use lancedb::database::CreateTableMode;
11use lancedb::index::scalar::FullTextSearchQuery;
12use lancedb::index::{Index, scalar::FtsIndexBuilder};
13use lancedb::query::{ExecutableQuery, QueryBase, Select};
14use lancedb::{Connection, Error, Result, Table, connect};
15
/// Name of the table holding one row per source document.
const DOCUMENTS_TABLE_NAME: &str = "documents";
/// Name of the table holding embedded, searchable document chunks.
const CHUNKS_TABLE_NAME: &str = "chunks";
/// Row count the chunks table must reach before a vector index is created on `vector`.
// NOTE(review): presumably chosen because ANN index training needs enough rows — TODO confirm.
const MIN_VECTOR_INDEX_ROWS: usize = 256;
19
/// A source document as stored in the documents table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DocumentRecord {
    /// Key used to upsert/delete the document and to associate its chunks.
    pub document_id: String,
    /// Full document text.
    pub content: String,
}
25
/// One embedded slice of a document, as stored in the chunks table.
///
/// Derives the same standard traits as the sibling record types so chunks can be cloned,
/// compared and debug-printed (previously none of these were available).
#[derive(Clone, Debug, PartialEq)]
pub struct Chunk {
    /// Document this chunk belongs to.
    pub document_id: String,
    /// Position of the chunk within its document; together with `document_id` it is the
    /// key used when upserting chunks.
    pub chunk_index: u64,
    /// Chunk text; the full-text index is built over this column.
    pub text: String,
    /// Embedding; its length must equal the backend's configured vector dimensions.
    pub vector: Vec<f32>,
}
32
/// A vector-search hit: the matching chunk's document and text plus its distance.
#[derive(Debug, PartialEq, Clone)]
pub struct ChunkSearchRecord {
    /// Document the matching chunk belongs to.
    pub document_id: String,
    /// Text of the matching chunk.
    pub text: String,
    /// `_distance` value LanceDB reported for this hit.
    pub distance: f32,
}
39
/// A full-text-search hit: the matching chunk's document and text plus its relevance score.
#[derive(Debug, PartialEq, Clone)]
pub struct ChunkKeywordSearchRecord {
    /// Document the matching chunk belongs to.
    pub document_id: String,
    /// Text of the matching chunk.
    pub text: String,
    /// `_score` value LanceDB reported for this hit.
    pub score: f32,
}
46
47pub struct LanceDbBackend {
48    connection: Connection,
49    vector_dimensions: i32,
50}
51
52impl LanceDbBackend {
53    pub async fn new(path: impl AsRef<Path>, vector_dimensions: i32) -> Result<Self> {
54        if vector_dimensions <= 0 {
55            return Err(Error::InvalidInput {
56                message: "vector_dimensions must be greater than zero".to_string(),
57            });
58        }
59
60        let uri = path.as_ref().to_string_lossy();
61        let connection = connect(uri.as_ref()).execute().await?;
62
63        Ok(Self {
64            connection,
65            vector_dimensions,
66        })
67    }
68
69    pub async fn create_tables(&self) -> Result<()> {
70        self.create_documents_table().await?;
71        self.create_chunks_table().await?;
72        Ok(())
73    }
74
75    pub async fn create_documents_table(&self) -> Result<Table> {
76        self.connection
77            .create_empty_table(DOCUMENTS_TABLE_NAME, self.documents_schema())
78            .mode(CreateTableMode::exist_ok(|request| request))
79            .execute()
80            .await
81    }
82
83    pub async fn create_chunks_table(&self) -> Result<Table> {
84        self.connection
85            .create_empty_table(CHUNKS_TABLE_NAME, self.chunks_schema())
86            .mode(CreateTableMode::exist_ok(|request| request))
87            .execute()
88            .await
89    }
90
91    pub async fn insert_data(&self, documents: &[DocumentRecord], chunks: &[Chunk]) -> Result<()> {
92        let documents_batch = self.documents_batch(documents)?;
93        let chunks_batch = self.chunks_batch(chunks)?;
94
95        let documents_table = self
96            .connection
97            .open_table(DOCUMENTS_TABLE_NAME)
98            .execute()
99            .await?;
100        let chunks_table = self
101            .connection
102            .open_table(CHUNKS_TABLE_NAME)
103            .execute()
104            .await?;
105
106        chunks_table.add(chunks_batch).execute().await?;
107        if let Err(error) = documents_table.add(documents_batch).execute().await {
108            self.delete_chunks_for_documents(&chunks_table, documents)
109                .await;
110            return Err(error);
111        }
112        self.ensure_chunks_indices(&chunks_table).await?;
113
114        Ok(())
115    }
116
117    pub async fn upsert_data(&self, document: &DocumentRecord, chunks: &[Chunk]) -> Result<()> {
118        let documents_batch = self.documents_batch(std::slice::from_ref(document))?;
119        let chunks_batch = self.chunks_batch(chunks)?;
120        let previous_chunks = self.chunks_for_document(&document.document_id).await?;
121        let documents_table = self
122            .connection
123            .open_table(DOCUMENTS_TABLE_NAME)
124            .execute()
125            .await?;
126        let chunks_table = self
127            .connection
128            .open_table(CHUNKS_TABLE_NAME)
129            .execute()
130            .await?;
131
132        self.merge_replace_document_chunks(&chunks_table, &document.document_id, chunks_batch)
133            .await?;
134        if let Err(error) = self
135            .merge_upsert_documents(&documents_table, documents_batch)
136            .await
137        {
138            let _ = self
139                .restore_document_chunks(&chunks_table, &document.document_id, previous_chunks)
140                .await;
141            return Err(error);
142        }
143
144        self.ensure_chunks_indices(&chunks_table).await?;
145
146        Ok(())
147    }
148
149    pub async fn vector_search(
150        &self,
151        query_vector: Vec<f32>,
152        limit: usize,
153    ) -> Result<Vec<ChunkSearchRecord>> {
154        let table = self
155            .connection
156            .open_table(CHUNKS_TABLE_NAME)
157            .execute()
158            .await?;
159        let rows = table
160            .query()
161            .nearest_to(query_vector)?
162            .column("vector")
163            .limit(limit)
164            .select(Select::columns(&["document_id", "text", "_distance"]))
165            .execute()
166            .await?;
167        let batches = rows.try_collect::<Vec<_>>().await?;
168
169        Ok(chunk_search_records_from_batches(&batches))
170    }
171
172    pub async fn keyword_search(
173        &self,
174        query: String,
175        limit: usize,
176    ) -> Result<Vec<ChunkKeywordSearchRecord>> {
177        let table = self
178            .connection
179            .open_table(CHUNKS_TABLE_NAME)
180            .execute()
181            .await?;
182        let full_text_query = FullTextSearchQuery::new(query).with_column("text".to_string())?;
183        let rows = table
184            .query()
185            .full_text_search(full_text_query)
186            .limit(limit)
187            .select(Select::columns(&["document_id", "text", "_score"]))
188            .execute()
189            .await?;
190        let batches = rows.try_collect::<Vec<_>>().await?;
191
192        Ok(chunk_keyword_search_records_from_batches(&batches))
193    }
194
195    pub async fn list_documents(&self) -> Result<Vec<DocumentRecord>> {
196        let table = self
197            .connection
198            .open_table(DOCUMENTS_TABLE_NAME)
199            .execute()
200            .await?;
201        let rows = table
202            .query()
203            .select(Select::columns(&["document_id", "content"]))
204            .execute()
205            .await?;
206        let batches = rows.try_collect::<Vec<_>>().await?;
207        let mut documents = document_records_from_batches(&batches);
208
209        documents.sort_by(|left, right| left.document_id.cmp(&right.document_id));
210        Ok(documents)
211    }
212
213    pub async fn get_document(&self, document_id: &str) -> Result<Option<DocumentRecord>> {
214        let table = self
215            .connection
216            .open_table(DOCUMENTS_TABLE_NAME)
217            .execute()
218            .await?;
219        let rows = table
220            .query()
221            .only_if(document_id_predicate(document_id))
222            .select(Select::columns(&["document_id", "content"]))
223            .limit(1)
224            .execute()
225            .await?;
226        let batches = rows.try_collect::<Vec<_>>().await?;
227
228        Ok(document_records_from_batches(&batches).into_iter().next())
229    }
230
231    pub async fn delete_document(&self, document_id: &str) -> Result<()> {
232        let predicate = document_id_predicate(document_id);
233        let documents_table = self
234            .connection
235            .open_table(DOCUMENTS_TABLE_NAME)
236            .execute()
237            .await?;
238        let chunks_table = self
239            .connection
240            .open_table(CHUNKS_TABLE_NAME)
241            .execute()
242            .await?;
243
244        documents_table.delete(&predicate).await?;
245        chunks_table.delete(&predicate).await?;
246
247        Ok(())
248    }
249
250    pub fn connection(&self) -> &Connection {
251        &self.connection
252    }
253
254    pub fn vector_dimensions(&self) -> i32 {
255        self.vector_dimensions
256    }
257
258    async fn ensure_chunks_vector_index(&self, chunks_table: &Table) -> Result<()> {
259        let indices = chunks_table.list_indices().await?;
260        if indices
261            .iter()
262            .any(|index| index.columns == vec!["vector".to_string()])
263        {
264            return Ok(());
265        }
266        if chunks_table.count_rows(None).await? < MIN_VECTOR_INDEX_ROWS {
267            return Ok(());
268        }
269
270        chunks_table
271            .create_index(&["vector"], Index::Auto)
272            .execute()
273            .await?;
274        Ok(())
275    }
276
277    async fn ensure_chunks_keyword_index(&self, chunks_table: &Table) -> Result<()> {
278        let indices = chunks_table.list_indices().await?;
279        if indices
280            .iter()
281            .any(|index| index.columns == vec!["text".to_string()])
282        {
283            return Ok(());
284        }
285
286        chunks_table
287            .create_index(&["text"], Index::FTS(FtsIndexBuilder::default()))
288            .execute()
289            .await?;
290        Ok(())
291    }
292
293    fn documents_schema(&self) -> Arc<Schema> {
294        Arc::new(Schema::new(vec![
295            Field::new("document_id", DataType::Utf8, false),
296            Field::new("content", DataType::Utf8, false),
297        ]))
298    }
299
300    fn chunks_schema(&self) -> Arc<Schema> {
301        Arc::new(Schema::new(vec![
302            Field::new("document_id", DataType::Utf8, false),
303            Field::new("chunk_index", DataType::UInt64, false),
304            Field::new("text", DataType::Utf8, false),
305            Field::new(
306                "vector",
307                DataType::FixedSizeList(
308                    Arc::new(Field::new("item", DataType::Float32, true)),
309                    self.vector_dimensions,
310                ),
311                false,
312            ),
313        ]))
314    }
315
316    fn documents_batch(&self, data: &[DocumentRecord]) -> Result<RecordBatch> {
317        let document_id_values = Arc::new(StringArray::from_iter_values(
318            data.iter().map(|document| document.document_id.as_str()),
319        ));
320        let content_values = Arc::new(StringArray::from_iter_values(
321            data.iter().map(|document| document.content.as_str()),
322        ));
323
324        Ok(RecordBatch::try_new(
325            self.documents_schema(),
326            vec![document_id_values, content_values],
327        )?)
328    }
329
330    fn chunks_batch(&self, data: &[Chunk]) -> Result<RecordBatch> {
331        let expected_dimensions = self.vector_dimensions as usize;
332        for chunk in data {
333            if chunk.vector.len() != expected_dimensions {
334                return Err(Error::InvalidInput {
335                    message: format!(
336                        "chunk vector has dimension {}, expected {}",
337                        chunk.vector.len(),
338                        expected_dimensions
339                    ),
340                });
341            }
342        }
343
344        let document_id_values = Arc::new(StringArray::from_iter_values(
345            data.iter().map(|chunk| chunk.document_id.as_str()),
346        ));
347        let chunk_index_values = Arc::new(UInt64Array::from_iter_values(
348            data.iter().map(|chunk| chunk.chunk_index),
349        ));
350        let text_values = Arc::new(StringArray::from_iter_values(
351            data.iter().map(|chunk| chunk.text.as_str()),
352        ));
353        let vector_values = Arc::new(
354            FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
355                data.iter()
356                    .map(|chunk| Some(chunk.vector.iter().copied().map(Some))),
357                self.vector_dimensions,
358            ),
359        );
360
361        Ok(RecordBatch::try_new(
362            self.chunks_schema(),
363            vec![
364                document_id_values,
365                chunk_index_values,
366                text_values,
367                vector_values,
368            ],
369        )?)
370    }
371
372    async fn merge_upsert_documents(&self, table: &Table, batch: RecordBatch) -> Result<()> {
373        let mut merge = table.merge_insert(&["document_id"]);
374        merge
375            .when_matched_update_all(None)
376            .when_not_matched_insert_all();
377        merge.execute(record_batch_reader(batch)).await?;
378        Ok(())
379    }
380
381    async fn merge_replace_document_chunks(
382        &self,
383        table: &Table,
384        document_id: &str,
385        batch: RecordBatch,
386    ) -> Result<()> {
387        let mut merge = table.merge_insert(&["document_id", "chunk_index"]);
388        merge
389            .when_matched_update_all(None)
390            .when_not_matched_insert_all()
391            .when_not_matched_by_source_delete(Some(document_id_predicate(document_id)));
392        merge.execute(record_batch_reader(batch)).await?;
393        Ok(())
394    }
395
396    async fn restore_document_chunks(
397        &self,
398        table: &Table,
399        document_id: &str,
400        chunks: Vec<Chunk>,
401    ) -> Result<()> {
402        if chunks.is_empty() {
403            table.delete(&document_id_predicate(document_id)).await?;
404            return Ok(());
405        }
406
407        let batch = self.chunks_batch(&chunks)?;
408        self.merge_replace_document_chunks(table, document_id, batch)
409            .await
410    }
411
412    async fn chunks_for_document(&self, document_id: &str) -> Result<Vec<Chunk>> {
413        let table = self
414            .connection
415            .open_table(CHUNKS_TABLE_NAME)
416            .execute()
417            .await?;
418        let rows = table
419            .query()
420            .only_if(document_id_predicate(document_id))
421            .select(Select::columns(&[
422                "document_id",
423                "chunk_index",
424                "text",
425                "vector",
426            ]))
427            .execute()
428            .await?;
429        let batches = rows.try_collect::<Vec<_>>().await?;
430
431        Ok(chunks_from_batches(&batches))
432    }
433
434    async fn delete_chunks_for_documents(&self, table: &Table, documents: &[DocumentRecord]) {
435        for document in documents {
436            let _ = table
437                .delete(&document_id_predicate(&document.document_id))
438                .await;
439        }
440    }
441
442    async fn ensure_chunks_indices(&self, chunks_table: &Table) -> Result<()> {
443        self.ensure_chunks_vector_index(chunks_table).await?;
444        self.ensure_chunks_keyword_index(chunks_table).await
445    }
446}
447
448fn record_batch_reader(batch: RecordBatch) -> Box<dyn arrow_array::RecordBatchReader + Send> {
449    let schema = batch.schema();
450    Box::new(RecordBatchIterator::new(
451        vec![Ok(batch)].into_iter(),
452        schema,
453    ))
454}
455
456fn document_records_from_batches(batches: &[RecordBatch]) -> Vec<DocumentRecord> {
457    batches
458        .iter()
459        .flat_map(|batch| {
460            let document_ids = batch
461                .column_by_name("document_id")
462                .expect("documents query should include document_id")
463                .as_any()
464                .downcast_ref::<StringArray>()
465                .expect("document_id column should be Utf8");
466            let contents = batch
467                .column_by_name("content")
468                .expect("documents query should include content")
469                .as_any()
470                .downcast_ref::<StringArray>()
471                .expect("content column should be Utf8");
472
473            (0..batch.num_rows())
474                .map(|index| DocumentRecord {
475                    document_id: document_ids.value(index).to_string(),
476                    content: contents.value(index).to_string(),
477                })
478                .collect::<Vec<_>>()
479        })
480        .collect()
481}
482
483fn chunk_search_records_from_batches(batches: &[RecordBatch]) -> Vec<ChunkSearchRecord> {
484    batches
485        .iter()
486        .flat_map(|batch| {
487            let document_ids = batch
488                .column_by_name("document_id")
489                .expect("chunks query should include document_id")
490                .as_any()
491                .downcast_ref::<StringArray>()
492                .expect("document_id column should be Utf8");
493            let texts = batch
494                .column_by_name("text")
495                .expect("chunks query should include text")
496                .as_any()
497                .downcast_ref::<StringArray>()
498                .expect("text column should be Utf8");
499            let distances = batch
500                .column_by_name("_distance")
501                .expect("chunks query should include _distance")
502                .as_primitive::<Float32Type>();
503
504            (0..batch.num_rows())
505                .map(|index| ChunkSearchRecord {
506                    document_id: document_ids.value(index).to_string(),
507                    text: texts.value(index).to_string(),
508                    distance: distances.value(index),
509                })
510                .collect::<Vec<_>>()
511        })
512        .collect()
513}
514
515fn chunk_keyword_search_records_from_batches(
516    batches: &[RecordBatch],
517) -> Vec<ChunkKeywordSearchRecord> {
518    batches
519        .iter()
520        .flat_map(|batch| {
521            let document_ids = batch
522                .column_by_name("document_id")
523                .expect("chunks query should include document_id")
524                .as_any()
525                .downcast_ref::<StringArray>()
526                .expect("document_id column should be Utf8");
527            let texts = batch
528                .column_by_name("text")
529                .expect("chunks query should include text")
530                .as_any()
531                .downcast_ref::<StringArray>()
532                .expect("text column should be Utf8");
533            let scores = batch
534                .column_by_name("_score")
535                .expect("chunks query should include _score")
536                .as_primitive::<Float32Type>();
537
538            (0..batch.num_rows())
539                .map(|index| ChunkKeywordSearchRecord {
540                    document_id: document_ids.value(index).to_string(),
541                    text: texts.value(index).to_string(),
542                    score: scores.value(index),
543                })
544                .collect::<Vec<_>>()
545        })
546        .collect()
547}
548
549fn chunks_from_batches(batches: &[RecordBatch]) -> Vec<Chunk> {
550    batches
551        .iter()
552        .flat_map(|batch| {
553            let document_ids = batch
554                .column_by_name("document_id")
555                .expect("chunks query should include document_id")
556                .as_any()
557                .downcast_ref::<StringArray>()
558                .expect("document_id column should be Utf8");
559            let chunk_indices = batch
560                .column_by_name("chunk_index")
561                .expect("chunks query should include chunk_index")
562                .as_any()
563                .downcast_ref::<UInt64Array>()
564                .expect("chunk_index column should be UInt64");
565            let texts = batch
566                .column_by_name("text")
567                .expect("chunks query should include text")
568                .as_any()
569                .downcast_ref::<StringArray>()
570                .expect("text column should be Utf8");
571            let vectors = batch
572                .column_by_name("vector")
573                .expect("chunks query should include vector")
574                .as_any()
575                .downcast_ref::<FixedSizeListArray>()
576                .expect("vector column should be FixedSizeList");
577
578            (0..batch.num_rows())
579                .map(|index| Chunk {
580                    document_id: document_ids.value(index).to_string(),
581                    chunk_index: chunk_indices.value(index),
582                    text: texts.value(index).to_string(),
583                    vector: vectors
584                        .value(index)
585                        .as_any()
586                        .downcast_ref::<Float32Array>()
587                        .expect("vector item column should be Float32")
588                        .values()
589                        .to_vec(),
590                })
591                .collect::<Vec<_>>()
592        })
593        .collect()
594}
595
/// Builds the SQL filter `document_id = '<id>'`, doubling any single quote in `document_id`
/// so the value cannot terminate the string literal early.
fn document_id_predicate(document_id: &str) -> String {
    let escaped = document_id.split('\'').collect::<Vec<_>>().join("''");
    format!("document_id = '{}'", escaped)
}
599
600#[cfg(test)]
601mod tests {
602    use super::{CHUNKS_TABLE_NAME, Chunk, DOCUMENTS_TABLE_NAME, DocumentRecord, LanceDbBackend};
603    use arrow_array::StringArray;
604    use futures::TryStreamExt;
605    use lancedb::Error;
606    use lancedb::query::ExecutableQuery;
607
608    fn demo_document() -> DocumentRecord {
609        DocumentRecord {
610            document_id: "demo-doc".to_string(),
611            content: "knight ranger priest rogue".to_string(),
612        }
613    }
614
615    fn demo_chunks() -> Vec<Chunk> {
616        vec![
617            Chunk {
618                document_id: "demo-doc".to_string(),
619                chunk_index: 0,
620                text: "knight".to_string(),
621                vector: vec![0.9, 0.4, 0.8],
622            },
623            Chunk {
624                document_id: "demo-doc".to_string(),
625                chunk_index: 1,
626                text: "ranger".to_string(),
627                vector: vec![0.8, 0.4, 0.7],
628            },
629            Chunk {
630                document_id: "demo-doc".to_string(),
631                chunk_index: 2,
632                text: "priest".to_string(),
633                vector: vec![0.6, 0.2, 0.6],
634            },
635            Chunk {
636                document_id: "demo-doc".to_string(),
637                chunk_index: 3,
638                text: "rogue".to_string(),
639                vector: vec![0.7, 0.4, 0.7],
640            },
641        ]
642    }
643
644    fn five_dimensional_chunks() -> Vec<Chunk> {
645        vec![
646            Chunk {
647                document_id: "five-dim-doc".to_string(),
648                chunk_index: 0,
649                text: "mage".to_string(),
650                vector: vec![0.1, 0.2, 0.3, 0.4, 0.5],
651            },
652            Chunk {
653                document_id: "five-dim-doc".to_string(),
654                chunk_index: 1,
655                text: "paladin".to_string(),
656                vector: vec![0.5, 0.4, 0.3, 0.2, 0.1],
657            },
658        ]
659    }
660
661    #[tokio::test]
662    async fn initializes_lancedb_from_local_path() {
663        let temp_dir = tempfile::tempdir().unwrap();
664
665        let backend = LanceDbBackend::new(temp_dir.path(), 3).await;
666
667        assert!(backend.is_ok());
668    }
669
670    #[tokio::test]
671    async fn creates_empty_tables() {
672        let temp_dir = tempfile::tempdir().unwrap();
673        let backend = LanceDbBackend::new(temp_dir.path(), 5).await.unwrap();
674
675        backend.create_tables().await.unwrap();
676
677        assert_eq!(backend.vector_dimensions(), 5);
678        let documents_table = backend
679            .connection()
680            .open_table(DOCUMENTS_TABLE_NAME)
681            .execute()
682            .await
683            .unwrap();
684        let chunks_table = backend
685            .connection()
686            .open_table(CHUNKS_TABLE_NAME)
687            .execute()
688            .await
689            .unwrap();
690        assert_eq!(documents_table.count_rows(None).await.unwrap(), 0);
691        assert_eq!(chunks_table.count_rows(None).await.unwrap(), 0);
692    }
693
694    #[tokio::test]
695    async fn inserts_demo_rows_into_tables() {
696        let temp_dir = tempfile::tempdir().unwrap();
697        let backend = LanceDbBackend::new(temp_dir.path(), 3).await.unwrap();
698
699        backend.create_tables().await.unwrap();
700        backend
701            .insert_data(&[demo_document()], &demo_chunks())
702            .await
703            .unwrap();
704
705        let documents_table = backend
706            .connection()
707            .open_table(DOCUMENTS_TABLE_NAME)
708            .execute()
709            .await
710            .unwrap();
711        let chunks_table = backend
712            .connection()
713            .open_table(CHUNKS_TABLE_NAME)
714            .execute()
715            .await
716            .unwrap();
717        assert_eq!(documents_table.count_rows(None).await.unwrap(), 1);
718        assert_eq!(chunks_table.count_rows(None).await.unwrap(), 4);
719    }
720
721    #[tokio::test]
722    async fn create_tables_is_idempotent() {
723        let temp_dir = tempfile::tempdir().unwrap();
724        let backend = LanceDbBackend::new(temp_dir.path(), 3).await.unwrap();
725
726        backend.create_tables().await.unwrap();
727        backend
728            .insert_data(&[demo_document()], &demo_chunks())
729            .await
730            .unwrap();
731        backend.create_tables().await.unwrap();
732
733        let chunks_table = backend
734            .connection()
735            .open_table(CHUNKS_TABLE_NAME)
736            .execute()
737            .await
738            .unwrap();
739        assert_eq!(chunks_table.count_rows(None).await.unwrap(), 4);
740    }
741
742    #[tokio::test]
743    async fn inserts_matching_non_default_dimensions() {
744        let temp_dir = tempfile::tempdir().unwrap();
745        let backend = LanceDbBackend::new(temp_dir.path(), 5).await.unwrap();
746
747        backend.create_tables().await.unwrap();
748        backend
749            .insert_data(
750                &[DocumentRecord {
751                    document_id: "five-dim-doc".to_string(),
752                    content: "mage paladin".to_string(),
753                }],
754                &five_dimensional_chunks(),
755            )
756            .await
757            .unwrap();
758
759        let chunks_table = backend
760            .connection()
761            .open_table(CHUNKS_TABLE_NAME)
762            .execute()
763            .await
764            .unwrap();
765        assert_eq!(chunks_table.count_rows(None).await.unwrap(), 2);
766    }
767
768    #[tokio::test]
769    async fn rejects_mismatched_vector_dimensions() {
770        let temp_dir = tempfile::tempdir().unwrap();
771        let backend = LanceDbBackend::new(temp_dir.path(), 5).await.unwrap();
772
773        backend.create_tables().await.unwrap();
774        let error = backend
775            .insert_data(&[demo_document()], &demo_chunks())
776            .await
777            .unwrap_err();
778
779        assert!(matches!(error, Error::InvalidInput { .. }));
780    }
781
782    #[tokio::test]
783    async fn vector_search_returns_ranked_chunk_records() {
784        let temp_dir = tempfile::tempdir().unwrap();
785        let backend = LanceDbBackend::new(temp_dir.path(), 3).await.unwrap();
786
787        backend.create_tables().await.unwrap();
788        backend
789            .insert_data(&[demo_document()], &demo_chunks())
790            .await
791            .unwrap();
792
793        let results = backend.vector_search(vec![0.9, 0.4, 0.8], 2).await.unwrap();
794
795        assert_eq!(results.len(), 2);
796        assert_eq!(results[0].document_id, "demo-doc");
797        assert_eq!(results[0].text, "knight");
798        assert_eq!(results[0].distance, 0.0);
799    }
800
801    #[tokio::test]
802    async fn keyword_search_returns_ranked_chunk_records() {
803        let temp_dir = tempfile::tempdir().unwrap();
804        let backend = LanceDbBackend::new(temp_dir.path(), 3).await.unwrap();
805
806        backend.create_tables().await.unwrap();
807        backend
808            .insert_data(
809                &[DocumentRecord {
810                    document_id: "search-doc".to_string(),
811                    content: "rust database rust search ranger".to_string(),
812                }],
813                &[
814                    Chunk {
815                        document_id: "search-doc".to_string(),
816                        chunk_index: 0,
817                        text: "rust database rust search".to_string(),
818                        vector: vec![0.1, 0.2, 0.3],
819                    },
820                    Chunk {
821                        document_id: "search-doc".to_string(),
822                        chunk_index: 1,
823                        text: "ranger path".to_string(),
824                        vector: vec![0.4, 0.5, 0.6],
825                    },
826                ],
827            )
828            .await
829            .unwrap();
830
831        let results = backend.keyword_search("rust".to_string(), 1).await.unwrap();
832
833        assert_eq!(results.len(), 1);
834        assert_eq!(results[0].document_id, "search-doc");
835        assert_eq!(results[0].text, "rust database rust search");
836        assert!(results[0].score > 0.0);
837    }
838
839    #[tokio::test]
840    async fn keyword_search_returns_empty_for_missing_terms() {
841        let temp_dir = tempfile::tempdir().unwrap();
842        let backend = LanceDbBackend::new(temp_dir.path(), 3).await.unwrap();
843
844        backend.create_tables().await.unwrap();
845        backend
846            .insert_data(&[demo_document()], &demo_chunks())
847            .await
848            .unwrap();
849
850        let results = backend
851            .keyword_search("warlock".to_string(), 10)
852            .await
853            .unwrap();
854
855        assert!(results.is_empty());
856    }
857
858    #[tokio::test]
859    async fn upsert_replaces_rows_for_document_id() {
860        let temp_dir = tempfile::tempdir().unwrap();
861        let backend = LanceDbBackend::new(temp_dir.path(), 3).await.unwrap();
862
863        backend.create_tables().await.unwrap();
864        backend
865            .insert_data(&[demo_document()], &demo_chunks())
866            .await
867            .unwrap();
868        backend
869            .upsert_data(
870                &DocumentRecord {
871                    document_id: "demo-doc".to_string(),
872                    content: "replacement next".to_string(),
873                },
874                &[
875                    Chunk {
876                        document_id: "demo-doc".to_string(),
877                        chunk_index: 0,
878                        text: "replacement".to_string(),
879                        vector: vec![0.1, 0.2, 0.3],
880                    },
881                    Chunk {
882                        document_id: "demo-doc".to_string(),
883                        chunk_index: 1,
884                        text: "next".to_string(),
885                        vector: vec![0.4, 0.5, 0.6],
886                    },
887                ],
888            )
889            .await
890            .unwrap();
891
892        let table = backend
893            .connection()
894            .open_table(CHUNKS_TABLE_NAME)
895            .execute()
896            .await
897            .unwrap();
898        assert_eq!(table.count_rows(None).await.unwrap(), 2);
899
900        let rows = table.query().execute().await.unwrap();
901        let batches = rows.try_collect::<Vec<_>>().await.unwrap();
902        let texts = batches
903            .iter()
904            .flat_map(|batch| {
905                batch
906                    .column_by_name("text")
907                    .unwrap()
908                    .as_any()
909                    .downcast_ref::<StringArray>()
910                    .unwrap()
911                    .iter()
912                    .flatten()
913                    .map(str::to_owned)
914                    .collect::<Vec<_>>()
915            })
916            .collect::<Vec<_>>();
917
918        assert_eq!(texts, vec!["replacement".to_string(), "next".to_string()]);
919        assert_eq!(
920            backend.get_document("demo-doc").await.unwrap(),
921            Some(DocumentRecord {
922                document_id: "demo-doc".to_string(),
923                content: "replacement next".to_string(),
924            })
925        );
926    }
927
928    #[tokio::test]
929    async fn upsert_preserves_repeated_chunk_text_by_index() {
930        let temp_dir = tempfile::tempdir().unwrap();
931        let backend = LanceDbBackend::new(temp_dir.path(), 3).await.unwrap();
932
933        backend.create_tables().await.unwrap();
934        backend
935            .insert_data(&[demo_document()], &demo_chunks())
936            .await
937            .unwrap();
938        backend
939            .upsert_data(
940                &DocumentRecord {
941                    document_id: "demo-doc".to_string(),
942                    content: "repeat repeat".to_string(),
943                },
944                &[
945                    Chunk {
946                        document_id: "demo-doc".to_string(),
947                        chunk_index: 0,
948                        text: "repeat".to_string(),
949                        vector: vec![0.1, 0.2, 0.3],
950                    },
951                    Chunk {
952                        document_id: "demo-doc".to_string(),
953                        chunk_index: 1,
954                        text: "repeat".to_string(),
955                        vector: vec![0.4, 0.5, 0.6],
956                    },
957                ],
958            )
959            .await
960            .unwrap();
961
962        let table = backend
963            .connection()
964            .open_table(CHUNKS_TABLE_NAME)
965            .execute()
966            .await
967            .unwrap();
968        let rows = table.query().execute().await.unwrap();
969        let batches = rows.try_collect::<Vec<_>>().await.unwrap();
970        let texts = batches
971            .iter()
972            .flat_map(|batch| {
973                batch
974                    .column_by_name("text")
975                    .unwrap()
976                    .as_any()
977                    .downcast_ref::<StringArray>()
978                    .unwrap()
979                    .iter()
980                    .flatten()
981                    .map(str::to_owned)
982                    .collect::<Vec<_>>()
983            })
984            .collect::<Vec<_>>();
985
986        assert_eq!(texts, vec!["repeat".to_string(), "repeat".to_string()]);
987    }
988
989    #[tokio::test]
990    async fn upsert_escapes_document_id_predicate() {
991        let temp_dir = tempfile::tempdir().unwrap();
992        let backend = LanceDbBackend::new(temp_dir.path(), 3).await.unwrap();
993
994        backend.create_tables().await.unwrap();
995        backend
996            .insert_data(
997                &[DocumentRecord {
998                    document_id: "doc-'quoted'".to_string(),
999                    content: "old".to_string(),
1000                }],
1001                &[Chunk {
1002                    document_id: "doc-'quoted'".to_string(),
1003                    chunk_index: 0,
1004                    text: "old".to_string(),
1005                    vector: vec![0.1, 0.2, 0.3],
1006                }],
1007            )
1008            .await
1009            .unwrap();
1010        backend
1011            .upsert_data(
1012                &DocumentRecord {
1013                    document_id: "doc-'quoted'".to_string(),
1014                    content: "new".to_string(),
1015                },
1016                &[Chunk {
1017                    document_id: "doc-'quoted'".to_string(),
1018                    chunk_index: 0,
1019                    text: "new".to_string(),
1020                    vector: vec![0.4, 0.5, 0.6],
1021                }],
1022            )
1023            .await
1024            .unwrap();
1025
1026        let table = backend
1027            .connection()
1028            .open_table(CHUNKS_TABLE_NAME)
1029            .execute()
1030            .await
1031            .unwrap();
1032        assert_eq!(table.count_rows(None).await.unwrap(), 1);
1033        assert_eq!(
1034            backend.get_document("doc-'quoted'").await.unwrap(),
1035            Some(DocumentRecord {
1036                document_id: "doc-'quoted'".to_string(),
1037                content: "new".to_string(),
1038            })
1039        );
1040    }
1041
1042    #[tokio::test]
1043    async fn lists_documents_sorted_by_document_id() {
1044        let temp_dir = tempfile::tempdir().unwrap();
1045        let backend = LanceDbBackend::new(temp_dir.path(), 3).await.unwrap();
1046
1047        backend.create_tables().await.unwrap();
1048        backend
1049            .insert_data(
1050                &[
1051                    DocumentRecord {
1052                        document_id: "b-doc".to_string(),
1053                        content: "second".to_string(),
1054                    },
1055                    DocumentRecord {
1056                        document_id: "a-doc".to_string(),
1057                        content: "first".to_string(),
1058                    },
1059                ],
1060                &[
1061                    Chunk {
1062                        document_id: "b-doc".to_string(),
1063                        chunk_index: 0,
1064                        text: "second".to_string(),
1065                        vector: vec![0.1, 0.2, 0.3],
1066                    },
1067                    Chunk {
1068                        document_id: "a-doc".to_string(),
1069                        chunk_index: 0,
1070                        text: "first".to_string(),
1071                        vector: vec![0.4, 0.5, 0.6],
1072                    },
1073                ],
1074            )
1075            .await
1076            .unwrap();
1077
1078        let documents = backend.list_documents().await.unwrap();
1079
1080        assert_eq!(
1081            documents,
1082            vec![
1083                DocumentRecord {
1084                    document_id: "a-doc".to_string(),
1085                    content: "first".to_string(),
1086                },
1087                DocumentRecord {
1088                    document_id: "b-doc".to_string(),
1089                    content: "second".to_string(),
1090                },
1091            ]
1092        );
1093    }
1094
1095    #[tokio::test]
1096    async fn delete_document_removes_document_and_chunks() {
1097        let temp_dir = tempfile::tempdir().unwrap();
1098        let backend = LanceDbBackend::new(temp_dir.path(), 3).await.unwrap();
1099
1100        backend.create_tables().await.unwrap();
1101        backend
1102            .insert_data(&[demo_document()], &demo_chunks())
1103            .await
1104            .unwrap();
1105
1106        backend.delete_document("demo-doc").await.unwrap();
1107
1108        let documents_table = backend
1109            .connection()
1110            .open_table(DOCUMENTS_TABLE_NAME)
1111            .execute()
1112            .await
1113            .unwrap();
1114        let chunks_table = backend
1115            .connection()
1116            .open_table(CHUNKS_TABLE_NAME)
1117            .execute()
1118            .await
1119            .unwrap();
1120        assert_eq!(documents_table.count_rows(None).await.unwrap(), 0);
1121        assert_eq!(chunks_table.count_rows(None).await.unwrap(), 0);
1122    }
1123}