1pub mod arrow_convert;
12
13use anyhow::{Context, Result};
14use arrow_array::{
15 Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
16 UInt32Array, types::Float32Type,
17};
18use arrow_schema::{DataType, Field, Schema};
19use futures::stream::TryStreamExt;
20use lancedb::Table;
21use lancedb::connection::Connection;
22use lancedb::query::{ExecutableQuery, QueryBase};
23use sha2::{Digest, Sha256};
24use std::collections::HashMap;
25use std::sync::{Arc, RwLock};
26
27use crate::bm25_search::{BM25Search, RrfScorer, SearchScorer};
28use crate::databases::traits::{
29 ChunkMetadata, DatabaseStats, SearchResult, StorageBackend, VectorDatabase,
30};
31use crate::databases::types::{FieldDef, Filter, Record, ScoredRecord};
32use crate::glob_utils;
33
34use arrow_convert::{
35 batch_to_records, extract_field_value, field_defs_to_schema, filter_to_sql, records_to_batch,
36};
37
/// Name of the LanceDB table holding code-chunk embeddings for RAG search.
const RAG_TABLE_NAME: &str = "code_embeddings";
40
/// LanceDB-backed vector store with per-root-path BM25 keyword indexes,
/// supporting pure vector search and hybrid (vector + keyword) search.
pub struct LanceDatabase {
    // Handle to the LanceDB database.
    connection: Connection,
    // Filesystem path the database was opened at; BM25 indexes live under it.
    db_path: String,
    // Name of the embeddings table (defaults to RAG_TABLE_NAME).
    rag_table_name: String,
    // BM25 keyword indexes keyed by a 16-hex-char hash of the root path.
    bm25_indexes: Arc<RwLock<HashMap<String, BM25Search>>>,
    // Strategy for fusing vector and keyword rankings (RRF by default).
    scorer: Arc<dyn SearchScorer>,
}
68
69impl LanceDatabase {
70 pub async fn new(db_path: impl Into<String>) -> Result<Self> {
75 let db_path = db_path.into();
76
77 if let Some(parent) = std::path::Path::new(&db_path).parent() {
78 std::fs::create_dir_all(parent).context("Failed to create database directory")?;
79 }
80
81 let connection = lancedb::connect(&db_path)
82 .execute()
83 .await
84 .context("Failed to connect to LanceDB")?;
85
86 Ok(Self {
87 connection,
88 db_path,
89 rag_table_name: RAG_TABLE_NAME.to_string(),
90 bm25_indexes: Arc::new(RwLock::new(HashMap::new())),
91 scorer: Arc::new(RrfScorer),
92 })
93 }
94
95 pub async fn with_default_path() -> Result<Self> {
97 let db_path = Self::default_lancedb_path();
98 Self::new(db_path).await
99 }
100
101 pub fn with_scorer(mut self, scorer: Arc<dyn SearchScorer>) -> Self {
103 self.scorer = scorer;
104 self
105 }
106
    /// Borrow the underlying LanceDB connection (e.g. for ad-hoc table access).
    pub fn connection(&self) -> &Connection {
        &self.connection
    }
111
    /// Filesystem path this database was opened at.
    pub fn db_path(&self) -> &str {
        &self.db_path
    }
116
    /// Report which optional backend features this database supports.
    /// LanceDB always supports vector search.
    pub fn capabilities(&self) -> crate::databases::BackendCapabilities {
        crate::databases::BackendCapabilities {
            vector_search: true,
        }
    }
123
    /// Platform-specific default database location as a (lossy UTF-8) string.
    pub fn default_lancedb_path() -> String {
        crate::paths::PlatformPaths::default_lancedb_path()
            .to_string_lossy()
            .to_string()
    }
130
131 fn hash_root_path(root_path: &str) -> String {
134 let mut hasher = Sha256::new();
135 hasher.update(root_path.as_bytes());
136 let result = hasher.finalize();
137 format!("{:x}", result)[..16].to_string()
138 }
139
140 fn bm25_path_for_root(&self, root_path: &str) -> String {
141 let hash = Self::hash_root_path(root_path);
142 format!("{}/bm25_{}", self.db_path, hash)
143 }
144
    /// Ensure a BM25 index exists (in memory and on disk) for `root_path`.
    ///
    /// Uses a double-checked locking pattern: a cheap read-lock probe first,
    /// then a write lock with a re-check so two racing callers don't both
    /// create the index. Errors if a lock is poisoned or index creation fails.
    fn get_or_create_bm25(&self, root_path: &str) -> Result<()> {
        let hash = Self::hash_root_path(root_path);

        // Fast path: index already present (read lock only).
        {
            let indexes = self.bm25_indexes.read().map_err(|e| {
                anyhow::anyhow!("Failed to acquire read lock on BM25 indexes: {}", e)
            })?;
            if indexes.contains_key(&hash) {
                return Ok(());
            }
        }

        let mut indexes = self
            .bm25_indexes
            .write()
            .map_err(|e| anyhow::anyhow!("Failed to acquire write lock on BM25 indexes: {}", e))?;

        // Re-check under the write lock: another thread may have created the
        // index between our read probe and acquiring the write lock.
        if indexes.contains_key(&hash) {
            return Ok(());
        }

        let bm25_path = self.bm25_path_for_root(root_path);
        tracing::info!(
            "Creating BM25 index for root path '{}' at: {}",
            root_path,
            bm25_path
        );

        let bm25_index = BM25Search::new(&bm25_path)
            .with_context(|| format!("Failed to initialize BM25 index for root: {}", root_path))?;

        indexes.insert(hash, bm25_index);
        Ok(())
    }
179
180 fn create_rag_schema(dimension: usize) -> Arc<Schema> {
181 Arc::new(Schema::new(vec![
182 Field::new(
183 "vector",
184 DataType::FixedSizeList(
185 Arc::new(Field::new("item", DataType::Float32, true)),
186 dimension as i32,
187 ),
188 false,
189 ),
190 Field::new("id", DataType::Utf8, false),
191 Field::new("file_path", DataType::Utf8, false),
192 Field::new("root_path", DataType::Utf8, true),
193 Field::new("start_line", DataType::UInt32, false),
194 Field::new("end_line", DataType::UInt32, false),
195 Field::new("language", DataType::Utf8, false),
196 Field::new("extension", DataType::Utf8, false),
197 Field::new("file_hash", DataType::Utf8, false),
198 Field::new("indexed_at", DataType::Utf8, false),
199 Field::new("content", DataType::Utf8, false),
200 Field::new("project", DataType::Utf8, true),
201 ]))
202 }
203
204 async fn get_rag_table(&self) -> Result<Table> {
205 self.connection
206 .open_table(&self.rag_table_name)
207 .execute()
208 .await
209 .context("Failed to open RAG table")
210 }
211
212 fn create_rag_record_batch(
213 embeddings: Vec<Vec<f32>>,
214 metadata: Vec<ChunkMetadata>,
215 contents: Vec<String>,
216 schema: Arc<Schema>,
217 ) -> Result<RecordBatch> {
218 let num_rows = embeddings.len();
219 let dimension = embeddings[0].len();
220
221 let vector_array = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
222 embeddings
223 .into_iter()
224 .map(|v| Some(v.into_iter().map(Some))),
225 dimension as i32,
226 );
227
228 let id_array = StringArray::from(
229 (0..num_rows)
230 .map(|i| format!("{}:{}", metadata[i].file_path, metadata[i].start_line))
231 .collect::<Vec<_>>(),
232 );
233 let file_path_array = StringArray::from(
234 metadata
235 .iter()
236 .map(|m| m.file_path.as_str())
237 .collect::<Vec<_>>(),
238 );
239 let root_path_array = StringArray::from(
240 metadata
241 .iter()
242 .map(|m| m.root_path.as_deref())
243 .collect::<Vec<_>>(),
244 );
245 let start_line_array = UInt32Array::from(
246 metadata
247 .iter()
248 .map(|m| m.start_line as u32)
249 .collect::<Vec<_>>(),
250 );
251 let end_line_array = UInt32Array::from(
252 metadata
253 .iter()
254 .map(|m| m.end_line as u32)
255 .collect::<Vec<_>>(),
256 );
257 let language_array = StringArray::from(
258 metadata
259 .iter()
260 .map(|m| m.language.as_deref().unwrap_or("Unknown"))
261 .collect::<Vec<_>>(),
262 );
263 let extension_array = StringArray::from(
264 metadata
265 .iter()
266 .map(|m| m.extension.as_deref().unwrap_or(""))
267 .collect::<Vec<_>>(),
268 );
269 let file_hash_array = StringArray::from(
270 metadata
271 .iter()
272 .map(|m| m.file_hash.as_str())
273 .collect::<Vec<_>>(),
274 );
275 let indexed_at_array = StringArray::from(
276 metadata
277 .iter()
278 .map(|m| m.indexed_at.to_string())
279 .collect::<Vec<_>>(),
280 );
281 let content_array =
282 StringArray::from(contents.iter().map(|s| s.as_str()).collect::<Vec<_>>());
283 let project_array = StringArray::from(
284 metadata
285 .iter()
286 .map(|m| m.project.as_deref())
287 .collect::<Vec<_>>(),
288 );
289
290 RecordBatch::try_new(
291 schema,
292 vec![
293 Arc::new(vector_array),
294 Arc::new(id_array),
295 Arc::new(file_path_array),
296 Arc::new(root_path_array),
297 Arc::new(start_line_array),
298 Arc::new(end_line_array),
299 Arc::new(language_array),
300 Arc::new(extension_array),
301 Arc::new(file_hash_array),
302 Arc::new(indexed_at_array),
303 Arc::new(content_array),
304 Arc::new(project_array),
305 ],
306 )
307 .context("Failed to create RecordBatch")
308 }
309}
310
311#[async_trait::async_trait]
314impl StorageBackend for LanceDatabase {
315 async fn ensure_table(&self, table_name: &str, schema: &[FieldDef]) -> Result<()> {
316 let table_names = self.connection.table_names().execute().await?;
317 if table_names.contains(&table_name.to_string()) {
318 return Ok(());
319 }
320
321 let arrow_schema = Arc::new(field_defs_to_schema(schema));
322 let batches = RecordBatchIterator::new(vec![], arrow_schema);
323 self.connection
324 .create_table(table_name, Box::new(batches))
325 .execute()
326 .await
327 .with_context(|| format!("Failed to create table '{table_name}'"))?;
328 Ok(())
329 }
330
331 async fn insert(&self, table_name: &str, records: Vec<Record>) -> Result<()> {
332 if records.is_empty() {
333 return Ok(());
334 }
335
336 let table = self
337 .connection
338 .open_table(table_name)
339 .execute()
340 .await
341 .with_context(|| format!("Failed to open table '{table_name}'"))?;
342
343 let batch = records_to_batch(&records)?;
344 let schema = batch.schema();
345 let batches = RecordBatchIterator::new(vec![Ok(batch)], schema);
346 table
347 .add(Box::new(batches))
348 .execute()
349 .await
350 .with_context(|| format!("Failed to insert into '{table_name}'"))?;
351 Ok(())
352 }
353
354 async fn query(
355 &self,
356 table_name: &str,
357 filter: Option<&Filter>,
358 limit: Option<usize>,
359 ) -> Result<Vec<Record>> {
360 let table = self
361 .connection
362 .open_table(table_name)
363 .execute()
364 .await
365 .with_context(|| format!("Failed to open table '{table_name}'"))?;
366
367 let mut q = table.query();
368 if let Some(f) = filter {
369 q = q.only_if(filter_to_sql(f));
370 }
371 if let Some(n) = limit {
372 q = q.limit(n);
373 }
374
375 let batches: Vec<RecordBatch> = q
376 .execute()
377 .await
378 .with_context(|| format!("Failed to query '{table_name}'"))?
379 .try_collect()
380 .await?;
381
382 let mut results = Vec::new();
383 for batch in &batches {
384 batch_to_records(batch, &mut results)?;
385 }
386 Ok(results)
387 }
388
389 async fn delete(&self, table_name: &str, filter: &Filter) -> Result<()> {
390 let table = self
391 .connection
392 .open_table(table_name)
393 .execute()
394 .await
395 .with_context(|| format!("Failed to open table '{table_name}'"))?;
396
397 table
398 .delete(&filter_to_sql(filter))
399 .await
400 .with_context(|| format!("Failed to delete from '{table_name}'"))?;
401 Ok(())
402 }
403
404 async fn count(&self, table_name: &str, filter: Option<&Filter>) -> Result<usize> {
405 let table = self
406 .connection
407 .open_table(table_name)
408 .execute()
409 .await
410 .with_context(|| format!("Failed to open table '{table_name}'"))?;
411
412 let mut q = table.query();
413 if let Some(f) = filter {
414 q = q.only_if(filter_to_sql(f));
415 }
416 let batches: Vec<RecordBatch> = q.execute().await?.try_collect().await?;
417 Ok(batches.iter().map(|b| b.num_rows()).sum())
418 }
419
420 async fn vector_search(
421 &self,
422 table_name: &str,
423 _vector_column: &str,
424 vector: Vec<f32>,
425 limit: usize,
426 filter: Option<&Filter>,
427 ) -> Result<Vec<ScoredRecord>> {
428 let table = self
429 .connection
430 .open_table(table_name)
431 .execute()
432 .await
433 .with_context(|| format!("Failed to open table '{table_name}'"))?;
434
435 let mut q = table.vector_search(vector)?;
436 q = q.limit(limit);
437 if let Some(f) = filter {
438 q = q.only_if(filter_to_sql(f));
439 }
440
441 let batches: Vec<RecordBatch> = q.execute().await?.try_collect().await?;
442
443 let mut results = Vec::new();
444 for batch in &batches {
445 let distance_col = batch
446 .column_by_name("_distance")
447 .and_then(|c| c.as_any().downcast_ref::<Float32Array>());
448
449 for row in 0..batch.num_rows() {
450 let mut record = Vec::new();
451 for (col_idx, field) in batch.schema().fields().iter().enumerate() {
452 if field.name() == "_distance" {
453 continue;
454 }
455 let val = extract_field_value(batch, col_idx, row, field)?;
456 record.push((field.name().clone(), val));
457 }
458
459 let distance = distance_col.map_or(0.0, |c| c.value(row));
460 let score = 1.0 / (1.0 + distance);
461
462 results.push(ScoredRecord { record, score });
463 }
464 }
465 Ok(results)
466 }
467}
468
469#[async_trait::async_trait]
472impl VectorDatabase for LanceDatabase {
473 async fn initialize(&self, dimension: usize) -> Result<()> {
474 tracing::info!(
475 "Initializing LanceDB with dimension {} at {}",
476 dimension,
477 self.db_path
478 );
479
480 let table_names = self
481 .connection
482 .table_names()
483 .execute()
484 .await
485 .context("Failed to list tables")?;
486
487 if table_names.contains(&self.rag_table_name) {
488 tracing::info!("Table '{}' already exists", self.rag_table_name);
489 return Ok(());
490 }
491
492 let schema = Self::create_rag_schema(dimension);
493 let empty_batch = RecordBatch::new_empty(schema.clone());
494 let batches =
495 RecordBatchIterator::new(vec![empty_batch].into_iter().map(Ok), schema.clone());
496
497 self.connection
498 .create_table(&self.rag_table_name, Box::new(batches))
499 .execute()
500 .await
501 .context("Failed to create table")?;
502
503 tracing::info!("Created table '{}'", self.rag_table_name);
504 Ok(())
505 }
506
507 async fn store_embeddings(
508 &self,
509 embeddings: Vec<Vec<f32>>,
510 metadata: Vec<ChunkMetadata>,
511 contents: Vec<String>,
512 root_path: &str,
513 ) -> Result<usize> {
514 if embeddings.is_empty() {
515 return Ok(0);
516 }
517
518 let dimension = embeddings[0].len();
519 let schema = Self::create_rag_schema(dimension);
520
521 let table = self.get_rag_table().await?;
522 let current_count = table.count_rows(None).await.unwrap_or(0) as u64;
523
524 let batch = Self::create_rag_record_batch(
525 embeddings,
526 metadata.clone(),
527 contents.clone(),
528 schema.clone(),
529 )?;
530 let count = batch.num_rows();
531
532 let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema);
533
534 table
535 .add(Box::new(batches))
536 .execute()
537 .await
538 .context("Failed to add records to table")?;
539
540 self.get_or_create_bm25(root_path)?;
541
542 let bm25_docs: Vec<_> = (0..count)
543 .map(|i| {
544 let id = current_count + i as u64;
545 (id, contents[i].clone(), metadata[i].file_path.clone())
546 })
547 .collect();
548
549 let hash = Self::hash_root_path(root_path);
550 let bm25_indexes = self
551 .bm25_indexes
552 .read()
553 .map_err(|e| anyhow::anyhow!("Failed to acquire BM25 read lock: {}", e))?;
554
555 if let Some(bm25) = bm25_indexes.get(&hash) {
556 bm25.add_documents(bm25_docs)
557 .context("Failed to add documents to BM25 index")?;
558 }
559 drop(bm25_indexes);
560
561 tracing::info!(
562 "Stored {} embeddings with BM25 indexing for root: {}",
563 count,
564 root_path
565 );
566 Ok(count)
567 }
568
    /// Semantic search over the RAG table, optionally fused with BM25 keyword
    /// search (`hybrid = true`).
    ///
    /// Vector distances become similarity scores via `1 / (1 + distance)`.
    /// `project` is filtered server-side in SQL; `root_path` is filtered
    /// client-side on the returned rows. In hybrid mode a result passes
    /// `min_score` if EITHER its vector score or its keyword score reaches
    /// the threshold, and the final score/order comes from `self.scorer`.
    async fn search(
        &self,
        query_vector: Vec<f32>,
        query_text: &str,
        limit: usize,
        min_score: f32,
        project: Option<String>,
        root_path: Option<String>,
        hybrid: bool,
    ) -> Result<Vec<SearchResult>> {
        let table = self.get_rag_table().await?;

        if hybrid {
            // Over-fetch so rank fusion still has enough candidates after
            // filtering and truncation to `limit`.
            let search_limit = limit * 3;

            let query = table
                .vector_search(query_vector)
                .context("Failed to create vector search")?
                .limit(search_limit);

            // NOTE(review): project_name is interpolated unescaped into the
            // SQL predicate — a name containing `'` would break the filter.
            let stream = if let Some(ref project_name) = project {
                query
                    .only_if(format!("project = '{}'", project_name))
                    .execute()
                    .await
                    .context("Failed to execute search")?
            } else {
                query.execute().await.context("Failed to execute search")?
            };

            let results: Vec<RecordBatch> = stream
                .try_collect()
                .await
                .context("Failed to collect search results")?;

            // Assign each vector hit an id equal to its row offset across the
            // result batches; original_scores keeps the per-id
            // (vector_score, keyword_score) pair for later reporting.
            let mut vector_results = Vec::new();
            let mut row_offset = 0u64;
            let mut original_scores: HashMap<u64, (f32, Option<f32>)> = HashMap::new();

            for batch in &results {
                let distance_array = batch
                    .column_by_name("_distance")
                    .context("Missing _distance column")?
                    .as_any()
                    .downcast_ref::<Float32Array>()
                    .context("Invalid _distance type")?;

                for i in 0..batch.num_rows() {
                    let distance = distance_array.value(i);
                    let score = 1.0 / (1.0 + distance);
                    let id = row_offset + i as u64;
                    vector_results.push((id, score));
                    original_scores.insert(id, (score, None));
                }
                row_offset += batch.num_rows() as u64;
            }

            let bm25_indexes = self
                .bm25_indexes
                .read()
                .map_err(|e| anyhow::anyhow!("Failed to acquire BM25 read lock: {}", e))?;

            // Query every per-root BM25 index with the raw query text.
            // NOTE(review): BM25 document ids are table row numbers assigned
            // at insert time, while the vector ids above are offsets into the
            // *search-result* batches — fusing the two assumes those id
            // spaces coincide. Ids with no matching result row are only
            // warned about below; worth confirming upstream.
            let mut all_bm25_results = Vec::new();
            for (root_hash, bm25) in bm25_indexes.iter() {
                tracing::debug!("Searching BM25 index for root hash: {}", root_hash);
                let bm25_results = bm25
                    .search(query_text, search_limit)
                    .context("Failed to search BM25 index")?;

                // Record (or attach) the keyword score for each hit id.
                for result in &bm25_results {
                    original_scores
                        .entry(result.id)
                        .and_modify(|e| e.1 = Some(result.score))
                        .or_insert((0.0, Some(result.score)));
                }

                all_bm25_results.extend(bm25_results);
            }
            drop(bm25_indexes);

            // Rank fusion (RRF by default) yields the final (id, score) list.
            let combined = self.scorer.fuse(vector_results, all_bm25_results, limit);

            let mut search_results = Vec::new();

            for (id, combined_score) in combined {
                let mut found = false;
                let mut batch_offset = 0u64;

                // Locate the batch/row this fused id refers to by walking the
                // result batches with a running row offset.
                for batch in &results {
                    if id >= batch_offset && id < batch_offset + batch.num_rows() as u64 {
                        let idx = (id - batch_offset) as usize;

                        let file_path_array = batch
                            .column_by_name("file_path")
                            .and_then(|c| c.as_any().downcast_ref::<StringArray>());
                        let root_path_array = batch
                            .column_by_name("root_path")
                            .and_then(|c| c.as_any().downcast_ref::<StringArray>());
                        let start_line_array = batch
                            .column_by_name("start_line")
                            .and_then(|c| c.as_any().downcast_ref::<UInt32Array>());
                        let end_line_array = batch
                            .column_by_name("end_line")
                            .and_then(|c| c.as_any().downcast_ref::<UInt32Array>());
                        let language_array = batch
                            .column_by_name("language")
                            .and_then(|c| c.as_any().downcast_ref::<StringArray>());
                        let content_array = batch
                            .column_by_name("content")
                            .and_then(|c| c.as_any().downcast_ref::<StringArray>());
                        let project_array = batch
                            .column_by_name("project")
                            .and_then(|c| c.as_any().downcast_ref::<StringArray>());
                        let indexed_at_array = batch
                            .column_by_name("indexed_at")
                            .and_then(|c| c.as_any().downcast_ref::<StringArray>());

                        // Rows with any missing/mis-typed required column are
                        // silently skipped (found stays false → warning).
                        if let (
                            Some(fp),
                            Some(rp),
                            Some(sl),
                            Some(el),
                            Some(lang),
                            Some(cont),
                            Some(proj),
                        ) = (
                            file_path_array,
                            root_path_array,
                            start_line_array,
                            end_line_array,
                            language_array,
                            content_array,
                            project_array,
                        ) {
                            let (vector_score, keyword_score) =
                                original_scores.get(&id).copied().unwrap_or((0.0, None));

                            // Either modality clearing min_score keeps the row.
                            let passes_filter = vector_score >= min_score
                                || keyword_score.is_some_and(|k| k >= min_score);

                            if passes_filter {
                                let result_root_path = if rp.is_null(idx) {
                                    None
                                } else {
                                    Some(rp.value(idx).to_string())
                                };

                                // Client-side root_path filter: drop rows from
                                // other roots but still mark the id as handled.
                                if let Some(ref filter_path) = root_path
                                    && result_root_path.as_ref() != Some(filter_path)
                                {
                                    found = true;
                                    break;
                                }

                                search_results.push(SearchResult {
                                    score: combined_score,
                                    vector_score,
                                    keyword_score,
                                    file_path: fp.value(idx).to_string(),
                                    root_path: result_root_path,
                                    start_line: sl.value(idx) as usize,
                                    end_line: el.value(idx) as usize,
                                    language: lang.value(idx).to_string(),
                                    content: cont.value(idx).to_string(),
                                    project: if proj.is_null(idx) {
                                        None
                                    } else {
                                        Some(proj.value(idx).to_string())
                                    },
                                    // indexed_at is stored as a string column;
                                    // unparsable values degrade to 0.
                                    indexed_at: indexed_at_array
                                        .and_then(|ia| ia.value(idx).parse::<i64>().ok())
                                        .unwrap_or(0),
                                });
                            }
                            found = true;
                            break;
                        }
                    }
                    batch_offset += batch.num_rows() as u64;
                }

                if !found {
                    tracing::warn!("Could not find result for RRF ID {}", id);
                }
            }

            Ok(search_results)
        } else {
            // Pure vector search: no over-fetch, no fusion.
            let query = table
                .vector_search(query_vector)
                .context("Failed to create vector search")?
                .limit(limit);

            let stream = if let Some(ref project_name) = project {
                query
                    .only_if(format!("project = '{}'", project_name))
                    .execute()
                    .await
                    .context("Failed to execute search")?
            } else {
                query.execute().await.context("Failed to execute search")?
            };

            let results: Vec<RecordBatch> = stream
                .try_collect()
                .await
                .context("Failed to collect search results")?;

            let mut search_results = Vec::new();

            for batch in results {
                let file_path_array = batch
                    .column_by_name("file_path")
                    .context("Missing file_path column")?
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .context("Invalid file_path type")?;

                let root_path_array = batch
                    .column_by_name("root_path")
                    .context("Missing root_path column")?
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .context("Invalid root_path type")?;

                let start_line_array = batch
                    .column_by_name("start_line")
                    .context("Missing start_line column")?
                    .as_any()
                    .downcast_ref::<UInt32Array>()
                    .context("Invalid start_line type")?;

                let end_line_array = batch
                    .column_by_name("end_line")
                    .context("Missing end_line column")?
                    .as_any()
                    .downcast_ref::<UInt32Array>()
                    .context("Invalid end_line type")?;

                let language_array = batch
                    .column_by_name("language")
                    .context("Missing language column")?
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .context("Invalid language type")?;

                let content_array = batch
                    .column_by_name("content")
                    .context("Missing content column")?
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .context("Invalid content type")?;

                let project_array = batch
                    .column_by_name("project")
                    .context("Missing project column")?
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .context("Invalid project type")?;

                let distance_array = batch
                    .column_by_name("_distance")
                    .context("Missing _distance column")?
                    .as_any()
                    .downcast_ref::<Float32Array>()
                    .context("Invalid _distance type")?;

                // Optional: older rows may lack a usable indexed_at column.
                let indexed_at_array = batch
                    .column_by_name("indexed_at")
                    .and_then(|c| c.as_any().downcast_ref::<StringArray>());

                for i in 0..batch.num_rows() {
                    let distance = distance_array.value(i);
                    let score = 1.0 / (1.0 + distance);

                    if score >= min_score {
                        let result_root_path = if root_path_array.is_null(i) {
                            None
                        } else {
                            Some(root_path_array.value(i).to_string())
                        };

                        // Client-side root_path filter.
                        if let Some(ref filter_path) = root_path
                            && result_root_path.as_ref() != Some(filter_path)
                        {
                            continue;
                        }

                        search_results.push(SearchResult {
                            score,
                            vector_score: score,
                            keyword_score: None,
                            file_path: file_path_array.value(i).to_string(),
                            root_path: result_root_path,
                            start_line: start_line_array.value(i) as usize,
                            end_line: end_line_array.value(i) as usize,
                            language: language_array.value(i).to_string(),
                            content: content_array.value(i).to_string(),
                            project: if project_array.is_null(i) {
                                None
                            } else {
                                Some(project_array.value(i).to_string())
                            },
                            indexed_at: indexed_at_array
                                .and_then(|ia| ia.value(i).parse::<i64>().ok())
                                .unwrap_or(0),
                        });
                    }
                }
            }

            Ok(search_results)
        }
    }
884
885 async fn search_filtered(
886 &self,
887 query_vector: Vec<f32>,
888 query_text: &str,
889 limit: usize,
890 min_score: f32,
891 project: Option<String>,
892 root_path: Option<String>,
893 hybrid: bool,
894 file_extensions: Vec<String>,
895 languages: Vec<String>,
896 path_patterns: Vec<String>,
897 ) -> Result<Vec<SearchResult>> {
898 let search_limit = limit * 3;
899
900 let mut results = self
901 .search(
902 query_vector,
903 query_text,
904 search_limit,
905 min_score,
906 project,
907 root_path,
908 hybrid,
909 )
910 .await?;
911
912 results.retain(|result| {
913 if !file_extensions.is_empty() {
914 let has_extension = file_extensions
915 .iter()
916 .any(|ext| result.file_path.ends_with(&format!(".{}", ext)));
917 if !has_extension {
918 return false;
919 }
920 }
921
922 if !languages.is_empty() && !languages.contains(&result.language) {
923 return false;
924 }
925
926 if !path_patterns.is_empty()
927 && !glob_utils::matches_any_pattern(&result.file_path, &path_patterns)
928 {
929 return false;
930 }
931
932 true
933 });
934
935 results.truncate(limit);
936 Ok(results)
937 }
938
939 async fn delete_by_file(&self, file_path: &str) -> Result<usize> {
940 {
941 let bm25_indexes = self
942 .bm25_indexes
943 .read()
944 .map_err(|e| anyhow::anyhow!("Failed to acquire BM25 read lock: {}", e))?;
945
946 for (root_hash, bm25) in bm25_indexes.iter() {
947 bm25.delete_by_file_path(file_path)
948 .context("Failed to delete from BM25 index")?;
949 tracing::debug!(
950 "Deleted BM25 entries for file: {} in index: {}",
951 file_path,
952 root_hash
953 );
954 }
955 }
956
957 let table = self.get_rag_table().await?;
958 let filter = format!("file_path = '{}'", file_path);
959 table
960 .delete(&filter)
961 .await
962 .context("Failed to delete records")?;
963
964 tracing::info!("Deleted embeddings for file: {}", file_path);
965 Ok(0)
966 }
967
968 async fn clear(&self) -> Result<()> {
969 self.connection
970 .drop_table(&self.rag_table_name, &[])
971 .await
972 .context("Failed to drop table")?;
973
974 let bm25_indexes = self
975 .bm25_indexes
976 .read()
977 .map_err(|e| anyhow::anyhow!("Failed to acquire BM25 read lock: {}", e))?;
978
979 for (root_hash, bm25) in bm25_indexes.iter() {
980 bm25.clear().context("Failed to clear BM25 index")?;
981 tracing::info!("Cleared BM25 index for root hash: {}", root_hash);
982 }
983 drop(bm25_indexes);
984
985 tracing::info!("Cleared all embeddings and all per-project BM25 indexes");
986 Ok(())
987 }
988
989 async fn get_statistics(&self) -> Result<DatabaseStats> {
990 let table = self.get_rag_table().await?;
991
992 let count_result = table
993 .count_rows(None)
994 .await
995 .context("Failed to count rows")?;
996
997 let stream = table
998 .query()
999 .select(lancedb::query::Select::Columns(vec![
1000 "language".to_string(),
1001 ]))
1002 .execute()
1003 .await
1004 .context("Failed to query languages")?;
1005
1006 let query_result: Vec<RecordBatch> = stream
1007 .try_collect()
1008 .await
1009 .context("Failed to collect language data")?;
1010
1011 let mut language_counts: HashMap<String, usize> = HashMap::new();
1012
1013 for batch in query_result {
1014 let language_array = batch
1015 .column_by_name("language")
1016 .context("Missing language column")?
1017 .as_any()
1018 .downcast_ref::<StringArray>()
1019 .context("Invalid language type")?;
1020
1021 for i in 0..batch.num_rows() {
1022 let language = language_array.value(i);
1023 *language_counts.entry(language.to_string()).or_insert(0) += 1;
1024 }
1025 }
1026
1027 let mut language_breakdown: Vec<(String, usize)> = language_counts.into_iter().collect();
1028 language_breakdown.sort_by(|a, b| b.1.cmp(&a.1));
1029
1030 Ok(DatabaseStats {
1031 total_points: count_result,
1032 total_vectors: count_result,
1033 language_breakdown,
1034 })
1035 }
1036
    /// No-op: this backend keeps no client-side write buffer — writes go
    /// through `Table::add` directly. NOTE(review): confirm LanceDB has no
    /// pending state that should be synced here.
    async fn flush(&self) -> Result<()> {
        Ok(())
    }
1040
1041 async fn count_by_root_path(&self, root_path: &str) -> Result<usize> {
1042 let table = self.get_rag_table().await?;
1043 let filter = format!("root_path = '{}'", root_path);
1044 let count = table
1045 .count_rows(Some(filter))
1046 .await
1047 .context("Failed to count rows by root path")?;
1048 Ok(count)
1049 }
1050
1051 async fn get_indexed_files(&self, root_path: &str) -> Result<Vec<String>> {
1052 let table = self.get_rag_table().await?;
1053 let filter = format!("root_path = '{}'", root_path);
1054 let stream = table
1055 .query()
1056 .only_if(filter)
1057 .select(lancedb::query::Select::Columns(vec![
1058 "file_path".to_string(),
1059 ]))
1060 .execute()
1061 .await
1062 .context("Failed to query indexed files")?;
1063
1064 let results: Vec<RecordBatch> = stream
1065 .try_collect()
1066 .await
1067 .context("Failed to collect file paths")?;
1068
1069 let mut file_paths = std::collections::HashSet::new();
1070 for batch in results {
1071 let file_path_array = batch
1072 .column_by_name("file_path")
1073 .context("Missing file_path column")?
1074 .as_any()
1075 .downcast_ref::<StringArray>()
1076 .context("Invalid file_path type")?;
1077
1078 for i in 0..batch.num_rows() {
1079 file_paths.insert(file_path_array.value(i).to_string());
1080 }
1081 }
1082
1083 Ok(file_paths.into_iter().collect())
1084 }
1085
    /// Like `search`, but also fetches the stored embedding vector for each
    /// hit, returned positionally aligned with the results.
    ///
    /// NOTE(review): this issues one follow-up query per result (N+1 pattern)
    /// keyed on (file_path, start_line), and the filter interpolates
    /// file_path unescaped — a path containing `'` would break the predicate.
    /// A result whose vector cannot be re-fetched gets an empty Vec.
    async fn search_with_embeddings(
        &self,
        query_vector: Vec<f32>,
        query_text: &str,
        limit: usize,
        min_score: f32,
        project: Option<String>,
        root_path: Option<String>,
        hybrid: bool,
    ) -> Result<(Vec<SearchResult>, Vec<Vec<f32>>)> {
        let results = self
            .search(
                query_vector,
                query_text,
                limit,
                min_score,
                project,
                root_path,
                hybrid,
            )
            .await?;

        if results.is_empty() {
            return Ok((results, Vec::new()));
        }

        let table = self.get_rag_table().await?;
        let mut embeddings = Vec::with_capacity(results.len());

        for result in &results {
            // (file_path, start_line) mirrors the synthetic id used at insert.
            let filter = format!(
                "file_path = '{}' AND start_line = {}",
                result.file_path, result.start_line
            );
            let stream = table
                .query()
                .only_if(filter)
                .select(lancedb::query::Select::Columns(vec!["vector".to_string()]))
                .limit(1)
                .execute()
                .await
                .context("Failed to query embedding vector")?;

            let batches: Vec<RecordBatch> = stream
                .try_collect()
                .await
                .context("Failed to collect embedding vector")?;

            // Take the first non-empty batch's first row as the vector.
            let mut found = false;
            for batch in &batches {
                if batch.num_rows() > 0
                    && let Some(vector_col) = batch.column_by_name("vector")
                    && let Some(fsl) = vector_col.as_any().downcast_ref::<FixedSizeListArray>()
                {
                    let values = fsl
                        .value(0)
                        .as_any()
                        .downcast_ref::<Float32Array>()
                        .map(|a| a.values().to_vec())
                        .unwrap_or_default();
                    embeddings.push(values);
                    found = true;
                    break;
                }
            }
            // Keep positional alignment with `results` even on a miss.
            if !found {
                embeddings.push(Vec::new());
            }
        }

        Ok((results, embeddings))
    }
1158}
1159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::databases::types::{FieldValue, Filter};
    use tempfile::TempDir;

    // Opening a database records the path it was opened at.
    #[tokio::test]
    async fn test_lance_database_new() {
        let temp = TempDir::new().unwrap();
        let db_path = temp.path().join("test.lance");
        let db = LanceDatabase::new(db_path.to_str().unwrap()).await.unwrap();
        assert_eq!(db.db_path(), db_path.to_str().unwrap());
    }

    // StorageBackend round-trip: create table, insert, query, count, delete.
    #[tokio::test]
    async fn test_lance_storage_backend_crud() {
        let temp = TempDir::new().unwrap();
        let db_path = temp.path().join("test.lance");
        let db = LanceDatabase::new(db_path.to_str().unwrap()).await.unwrap();

        let schema = vec![
            FieldDef::required("id", crate::databases::types::FieldType::Utf8),
            FieldDef::required("value", crate::databases::types::FieldType::Int64),
        ];
        db.ensure_table("test_table", &schema).await.unwrap();

        let records = vec![vec![
            ("id".to_string(), FieldValue::Utf8(Some("row1".to_string()))),
            ("value".to_string(), FieldValue::Int64(Some(42))),
        ]];
        db.insert("test_table", records).await.unwrap();

        let results = db.query("test_table", None, None).await.unwrap();
        assert_eq!(results.len(), 1);

        let count = db.count("test_table", None).await.unwrap();
        assert_eq!(count, 1);

        // Deleting the only row brings the count back to zero.
        db.delete(
            "test_table",
            &Filter::Eq("id".into(), FieldValue::Utf8(Some("row1".into()))),
        )
        .await
        .unwrap();

        let count = db.count("test_table", None).await.unwrap();
        assert_eq!(count, 0);
    }

    // Nearest-neighbour ordering: the exact-match vector ranks first and
    // scores come back in descending order.
    #[tokio::test]
    async fn test_lance_vector_search() {
        use crate::databases::types::FieldType;

        let temp = TempDir::new().unwrap();
        let db_path = temp.path().join("vec_search.lance");
        let db = LanceDatabase::new(db_path.to_str().unwrap()).await.unwrap();

        let dim = 4;
        let schema = vec![
            FieldDef::required("id", FieldType::Utf8),
            FieldDef::required("embedding", FieldType::Vector(dim)),
        ];
        db.ensure_table("vectors", &schema).await.unwrap();

        let records = vec![
            vec![
                ("id".to_string(), FieldValue::Utf8(Some("a".to_string()))),
                (
                    "embedding".to_string(),
                    FieldValue::Vector(vec![1.0, 0.0, 0.0, 0.0]),
                ),
            ],
            vec![
                ("id".to_string(), FieldValue::Utf8(Some("b".to_string()))),
                (
                    "embedding".to_string(),
                    FieldValue::Vector(vec![0.0, 1.0, 0.0, 0.0]),
                ),
            ],
            vec![
                ("id".to_string(), FieldValue::Utf8(Some("c".to_string()))),
                (
                    "embedding".to_string(),
                    FieldValue::Vector(vec![0.9, 0.1, 0.0, 0.0]),
                ),
            ],
        ];
        db.insert("vectors", records).await.unwrap();

        // Query with record "a"'s exact vector.
        let results = db
            .vector_search("vectors", "embedding", vec![1.0, 0.0, 0.0, 0.0], 3, None)
            .await
            .unwrap();

        assert!(!results.is_empty(), "vector_search should return results");
        let first_id = results[0]
            .record
            .iter()
            .find(|(n, _)| n == "id")
            .and_then(|(_, v)| v.as_str())
            .unwrap();
        assert_eq!(first_id, "a");

        for w in results.windows(2) {
            assert!(
                w[0].score >= w[1].score,
                "scores should be descending: {} >= {}",
                w[0].score,
                w[1].score
            );
        }
    }

    // The capabilities report advertises vector search.
    #[tokio::test]
    async fn test_lance_capabilities() {
        let temp = TempDir::new().unwrap();
        let db_path = temp.path().join("caps.lance");
        let db = LanceDatabase::new(db_path.to_str().unwrap()).await.unwrap();

        let caps = db.capabilities();
        assert!(
            caps.vector_search,
            "LanceDatabase should support vector search"
        );
    }

    // StorageBackend tables and the RAG table coexist on one connection:
    // initializing the RAG table must not disturb other tables.
    #[tokio::test]
    async fn test_lance_shared_connection() {
        use crate::databases::types::FieldType;

        let temp = TempDir::new().unwrap();
        let db_path = temp.path().join("shared.lance");
        let db = LanceDatabase::new(db_path.to_str().unwrap()).await.unwrap();

        let schema = vec![FieldDef::required("name", FieldType::Utf8)];
        db.ensure_table("store_table", &schema).await.unwrap();
        let records = vec![vec![(
            "name".to_string(),
            FieldValue::Utf8(Some("test".to_string())),
        )]];
        db.insert("store_table", records).await.unwrap();

        db.initialize(4).await.unwrap();

        let store_count = db.count("store_table", None).await.unwrap();
        assert_eq!(store_count, 1);

        // The freshly created RAG table is empty.
        let stats = db.get_statistics().await.unwrap();
        assert_eq!(stats.total_vectors, 0);
    }
}