1use crate::bm25_search::BM25Search;
12use crate::glob_utils;
13use crate::types::{ChunkMetadata, SearchResult};
14use crate::vector_db::{DatabaseStats, VectorDatabase};
15use anyhow::{Context, Result};
16use arrow_array::{
17 Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
18 UInt32Array, types::Float32Type,
19};
20use arrow_schema::{DataType, Field, Schema};
21use futures::stream::TryStreamExt;
22use lancedb::Table;
23use lancedb::connection::Connection;
24use lancedb::query::{ExecutableQuery, QueryBase};
25use sha2::{Digest, Sha256};
26use std::collections::HashMap;
27use std::sync::{Arc, RwLock};
28
/// Vector database backed by LanceDB, augmented with per-root-path BM25
/// keyword indexes so hybrid (vector + keyword) search is possible.
pub struct LanceVectorDB {
    /// Open LanceDB connection handle.
    connection: Connection,
    /// Name of the table holding code embeddings ("code_embeddings").
    table_name: String,
    /// Path/URI the connection was opened with; also the base directory
    /// under which per-root BM25 indexes are created (`{db_path}/bm25_{hash}`).
    db_path: String,
    /// BM25 indexes keyed by the 16-hex-char SHA-256 prefix of each root
    /// path (see `hash_root_path`). RwLock lets concurrent searches share
    /// read access while index creation takes the write lock.
    bm25_indexes: Arc<RwLock<HashMap<String, BM25Search>>>,
}
39
40impl LanceVectorDB {
41 pub async fn new() -> Result<Self> {
43 let db_path = Self::default_lancedb_path();
44 Self::with_path(&db_path).await
45 }
46
47 pub async fn with_path(db_path: &str) -> Result<Self> {
49 tracing::info!("Connecting to LanceDB at: {}", db_path);
50
51 let connection = lancedb::connect(db_path)
52 .execute()
53 .await
54 .context("Failed to connect to LanceDB")?;
55
56 let bm25_indexes = Arc::new(RwLock::new(HashMap::new()));
59
60 Ok(Self {
61 connection,
62 table_name: "code_embeddings".to_string(),
63 db_path: db_path.to_string(),
64 bm25_indexes,
65 })
66 }
67
68 pub fn default_lancedb_path() -> String {
70 crate::paths::PlatformPaths::default_lancedb_path()
71 .to_string_lossy()
72 .to_string()
73 }
74
75 fn hash_root_path(root_path: &str) -> String {
77 let mut hasher = Sha256::new();
78 hasher.update(root_path.as_bytes());
79 let result = hasher.finalize();
80 format!("{:x}", result)[..16].to_string()
82 }
83
84 fn bm25_path_for_root(&self, root_path: &str) -> String {
86 let hash = Self::hash_root_path(root_path);
87 format!("{}/bm25_{}", self.db_path, hash)
88 }
89
90 fn get_or_create_bm25(&self, root_path: &str) -> Result<()> {
92 let hash = Self::hash_root_path(root_path);
93
94 {
96 let indexes = self.bm25_indexes.read().map_err(|e| {
97 anyhow::anyhow!("Failed to acquire read lock on BM25 indexes: {}", e)
98 })?;
99 if indexes.contains_key(&hash) {
100 return Ok(()); }
102 }
103
104 let mut indexes = self
106 .bm25_indexes
107 .write()
108 .map_err(|e| anyhow::anyhow!("Failed to acquire write lock on BM25 indexes: {}", e))?;
109
110 if indexes.contains_key(&hash) {
112 return Ok(());
113 }
114
115 let bm25_path = self.bm25_path_for_root(root_path);
116 tracing::info!(
117 "Creating BM25 index for root path '{}' at: {}",
118 root_path,
119 bm25_path
120 );
121
122 let bm25_index = BM25Search::new(&bm25_path)
123 .with_context(|| format!("Failed to initialize BM25 index for root: {}", root_path))?;
124
125 indexes.insert(hash, bm25_index);
126
127 Ok(())
128 }
129
130 fn create_schema(dimension: usize) -> Arc<Schema> {
132 Arc::new(Schema::new(vec![
133 Field::new(
134 "vector",
135 DataType::FixedSizeList(
136 Arc::new(Field::new("item", DataType::Float32, true)),
137 dimension as i32,
138 ),
139 false,
140 ),
141 Field::new("id", DataType::Utf8, false),
142 Field::new("file_path", DataType::Utf8, false),
143 Field::new("root_path", DataType::Utf8, true),
144 Field::new("start_line", DataType::UInt32, false),
145 Field::new("end_line", DataType::UInt32, false),
146 Field::new("language", DataType::Utf8, false),
147 Field::new("extension", DataType::Utf8, false),
148 Field::new("file_hash", DataType::Utf8, false),
149 Field::new("indexed_at", DataType::Utf8, false),
150 Field::new("content", DataType::Utf8, false),
151 Field::new("project", DataType::Utf8, true),
152 ]))
153 }
154
155 async fn get_table(&self) -> Result<Table> {
157 self.connection
158 .open_table(&self.table_name)
159 .execute()
160 .await
161 .context("Failed to open table")
162 }
163
164 fn create_record_batch(
166 embeddings: Vec<Vec<f32>>,
167 metadata: Vec<ChunkMetadata>,
168 contents: Vec<String>,
169 schema: Arc<Schema>,
170 ) -> Result<RecordBatch> {
171 let num_rows = embeddings.len();
172 let dimension = embeddings[0].len();
173
174 let vector_array = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
176 embeddings
177 .into_iter()
178 .map(|v| Some(v.into_iter().map(Some))),
179 dimension as i32,
180 );
181
182 let id_array = StringArray::from(
184 (0..num_rows)
185 .map(|i| format!("{}:{}", metadata[i].file_path, metadata[i].start_line))
186 .collect::<Vec<_>>(),
187 );
188 let file_path_array = StringArray::from(
189 metadata
190 .iter()
191 .map(|m| m.file_path.as_str())
192 .collect::<Vec<_>>(),
193 );
194 let root_path_array = StringArray::from(
195 metadata
196 .iter()
197 .map(|m| m.root_path.as_deref())
198 .collect::<Vec<_>>(),
199 );
200 let start_line_array = UInt32Array::from(
201 metadata
202 .iter()
203 .map(|m| m.start_line as u32)
204 .collect::<Vec<_>>(),
205 );
206 let end_line_array = UInt32Array::from(
207 metadata
208 .iter()
209 .map(|m| m.end_line as u32)
210 .collect::<Vec<_>>(),
211 );
212 let language_array = StringArray::from(
213 metadata
214 .iter()
215 .map(|m| m.language.as_deref().unwrap_or("Unknown"))
216 .collect::<Vec<_>>(),
217 );
218 let extension_array = StringArray::from(
219 metadata
220 .iter()
221 .map(|m| m.extension.as_deref().unwrap_or(""))
222 .collect::<Vec<_>>(),
223 );
224 let file_hash_array = StringArray::from(
225 metadata
226 .iter()
227 .map(|m| m.file_hash.as_str())
228 .collect::<Vec<_>>(),
229 );
230 let indexed_at_array = StringArray::from(
231 metadata
232 .iter()
233 .map(|m| m.indexed_at.to_string())
234 .collect::<Vec<_>>(),
235 );
236 let content_array =
237 StringArray::from(contents.iter().map(|s| s.as_str()).collect::<Vec<_>>());
238 let project_array = StringArray::from(
239 metadata
240 .iter()
241 .map(|m| m.project.as_deref())
242 .collect::<Vec<_>>(),
243 );
244
245 RecordBatch::try_new(
246 schema,
247 vec![
248 Arc::new(vector_array),
249 Arc::new(id_array),
250 Arc::new(file_path_array),
251 Arc::new(root_path_array),
252 Arc::new(start_line_array),
253 Arc::new(end_line_array),
254 Arc::new(language_array),
255 Arc::new(extension_array),
256 Arc::new(file_hash_array),
257 Arc::new(indexed_at_array),
258 Arc::new(content_array),
259 Arc::new(project_array),
260 ],
261 )
262 .context("Failed to create RecordBatch")
263 }
264}
265
266#[async_trait::async_trait]
267impl VectorDatabase for LanceVectorDB {
268 async fn initialize(&self, dimension: usize) -> Result<()> {
269 tracing::info!(
270 "Initializing LanceDB with dimension {} at {}",
271 dimension,
272 self.db_path
273 );
274
275 let table_names = self
277 .connection
278 .table_names()
279 .execute()
280 .await
281 .context("Failed to list tables")?;
282
283 if table_names.contains(&self.table_name) {
284 tracing::info!("Table '{}' already exists", self.table_name);
285 return Ok(());
286 }
287
288 let schema = Self::create_schema(dimension);
290
291 let empty_batch = RecordBatch::new_empty(schema.clone());
293
294 let batches =
296 RecordBatchIterator::new(vec![empty_batch].into_iter().map(Ok), schema.clone());
297
298 self.connection
299 .create_table(&self.table_name, Box::new(batches))
300 .execute()
301 .await
302 .context("Failed to create table")?;
303
304 tracing::info!("Created table '{}'", self.table_name);
305 Ok(())
306 }
307
308 async fn store_embeddings(
309 &self,
310 embeddings: Vec<Vec<f32>>,
311 metadata: Vec<ChunkMetadata>,
312 contents: Vec<String>,
313 root_path: &str,
314 ) -> Result<usize> {
315 if embeddings.is_empty() {
316 return Ok(0);
317 }
318
319 let dimension = embeddings[0].len();
320 let schema = Self::create_schema(dimension);
321
322 let table = self.get_table().await?;
324 let current_count = table.count_rows(None).await.unwrap_or(0) as u64;
325
326 let batch = Self::create_record_batch(
327 embeddings,
328 metadata.clone(),
329 contents.clone(),
330 schema.clone(),
331 )?;
332 let count = batch.num_rows();
333
334 let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema);
335
336 table
337 .add(Box::new(batches))
338 .execute()
339 .await
340 .context("Failed to add records to table")?;
341
342 self.get_or_create_bm25(root_path)?;
344
345 let bm25_docs: Vec<_> = (0..count)
347 .map(|i| {
348 let id = current_count + i as u64;
349 (id, contents[i].clone(), metadata[i].file_path.clone())
350 })
351 .collect();
352
353 let hash = Self::hash_root_path(root_path);
354 let bm25_indexes = self
355 .bm25_indexes
356 .read()
357 .map_err(|e| anyhow::anyhow!("Failed to acquire BM25 read lock: {}", e))?;
358
359 if let Some(bm25) = bm25_indexes.get(&hash) {
360 bm25.add_documents(bm25_docs)
361 .context("Failed to add documents to BM25 index")?;
362 }
363 drop(bm25_indexes);
364
365 tracing::info!(
366 "Stored {} embeddings with BM25 indexing for root: {}",
367 count,
368 root_path
369 );
370 Ok(count)
371 }
372
    /// Vector search, optionally fused with BM25 keyword results.
    ///
    /// - `hybrid == false`: pure vector search, limited to `limit`, with
    ///   each row's `1/(1+distance)` score filtered against `min_score`.
    /// - `hybrid == true`: over-fetches `limit * 3` vector candidates,
    ///   queries every per-root BM25 index with `query_text`, and merges the
    ///   two ranked lists with reciprocal rank fusion.
    ///
    /// `project` filters server-side; `root_path` filters client-side after
    /// rows are materialized.
    async fn search(
        &self,
        query_vector: Vec<f32>,
        query_text: &str,
        limit: usize,
        min_score: f32,
        project: Option<String>,
        root_path: Option<String>,
        hybrid: bool,
    ) -> Result<Vec<SearchResult>> {
        let table = self.get_table().await?;

        if hybrid {
            // Over-fetch so rank fusion has enough candidates to reorder.
            let search_limit = limit * 3;

            let query = table
                .vector_search(query_vector)
                .context("Failed to create vector search")?
                .limit(search_limit);

            // NOTE(review): the project name is interpolated unescaped into
            // the filter expression — a name containing a single quote would
            // break (or alter) the predicate. Confirm and escape if project
            // names are user-controlled.
            let stream = if let Some(ref project_name) = project {
                query
                    .only_if(format!("project = '{}'", project_name))
                    .execute()
                    .await
                    .context("Failed to execute search")?
            } else {
                query.execute().await.context("Failed to execute search")?
            };

            let results: Vec<RecordBatch> = stream
                .try_collect()
                .await
                .context("Failed to collect search results")?;

            // (id, score) pairs for fusion. Ids here are positions within
            // this query's concatenated result batches.
            let mut vector_results = Vec::new();
            let mut row_offset = 0u64;

            // id -> (vector score, optional keyword score); kept so the
            // final SearchResult can report both component scores.
            let mut original_scores: HashMap<u64, (f32, Option<f32>)> = HashMap::new();

            for batch in &results {
                let distance_array = batch
                    .column_by_name("_distance")
                    .context("Missing _distance column")?
                    .as_any()
                    .downcast_ref::<Float32Array>()
                    .context("Invalid _distance type")?;

                for i in 0..batch.num_rows() {
                    let distance = distance_array.value(i);
                    // Map distance to a similarity in (0, 1]; distance 0 -> 1.0.
                    let score = 1.0 / (1.0 + distance);
                    let id = row_offset + i as u64;

                    vector_results.push((id, score));
                    original_scores.insert(id, (score, None));
                }
                row_offset += batch.num_rows() as u64;
            }

            let bm25_indexes = self
                .bm25_indexes
                .read()
                .map_err(|e| anyhow::anyhow!("Failed to acquire BM25 read lock: {}", e))?;

            // Query every per-root BM25 index with the raw query text.
            // NOTE(review): BM25 document ids were assigned from the table's
            // global row count at insert time (see store_embeddings), while
            // the vector ids above are offsets local to this result set.
            // These look like two different id spaces — confirm they
            // coincide before trusting the fused ranking.
            let mut all_bm25_results = Vec::new();
            for (root_hash, bm25) in bm25_indexes.iter() {
                tracing::debug!("Searching BM25 index for root hash: {}", root_hash);
                let results = bm25
                    .search(query_text, search_limit)
                    .context("Failed to search BM25 index")?;

                // Record keyword scores, creating entries (with vector score
                // 0.0) for BM25-only hits.
                for result in &results {
                    original_scores
                        .entry(result.id)
                        .and_modify(|e| e.1 = Some(result.score))
                        .or_insert((0.0, Some(result.score))); }

                all_bm25_results.extend(results);
            }
            drop(bm25_indexes);

            let bm25_results = all_bm25_results;

            // Merge the two ranked lists; result is (id, fused score) pairs.
            let combined =
                crate::bm25_search::reciprocal_rank_fusion(vector_results, bm25_results, limit);

            let mut search_results = Vec::new();

            // Resolve each fused id back to its row in the vector result
            // batches by walking the per-batch offsets.
            for (id, combined_score) in combined {
                let mut found = false;
                let mut batch_offset = 0u64;

                for batch in &results {
                    if id >= batch_offset && id < batch_offset + batch.num_rows() as u64 {
                        let idx = (id - batch_offset) as usize;

                        let file_path_array = batch
                            .column_by_name("file_path")
                            .and_then(|c| c.as_any().downcast_ref::<StringArray>());
                        let root_path_array = batch
                            .column_by_name("root_path")
                            .and_then(|c| c.as_any().downcast_ref::<StringArray>());
                        let start_line_array = batch
                            .column_by_name("start_line")
                            .and_then(|c| c.as_any().downcast_ref::<UInt32Array>());
                        let end_line_array = batch
                            .column_by_name("end_line")
                            .and_then(|c| c.as_any().downcast_ref::<UInt32Array>());
                        let language_array = batch
                            .column_by_name("language")
                            .and_then(|c| c.as_any().downcast_ref::<StringArray>());
                        let content_array = batch
                            .column_by_name("content")
                            .and_then(|c| c.as_any().downcast_ref::<StringArray>());
                        let project_array = batch
                            .column_by_name("project")
                            .and_then(|c| c.as_any().downcast_ref::<StringArray>());

                        if let (
                            Some(fp),
                            Some(rp),
                            Some(sl),
                            Some(el),
                            Some(lang),
                            Some(cont),
                            Some(proj),
                        ) = (
                            file_path_array,
                            root_path_array,
                            start_line_array,
                            end_line_array,
                            language_array,
                            content_array,
                            project_array,
                        ) {
                            let (vector_score, keyword_score) =
                                original_scores.get(&id).copied().unwrap_or((0.0, None));

                            // A row passes if EITHER component score clears
                            // the threshold.
                            let passes_filter = vector_score >= min_score
                                || keyword_score.is_some_and(|k| k >= min_score);

                            if passes_filter {
                                let result_root_path = if rp.is_null(idx) {
                                    None
                                } else {
                                    Some(rp.value(idx).to_string())
                                };

                                // Client-side root-path filter: skip the row
                                // (but mark it found) when it belongs to a
                                // different root.
                                if let Some(ref filter_path) = root_path {
                                    if result_root_path.as_ref() != Some(filter_path) {
                                        found = true;
                                        break;
                                    }
                                }

                                search_results.push(SearchResult {
                                    score: combined_score, vector_score, keyword_score, file_path: fp.value(idx).to_string(),
                                    root_path: result_root_path,
                                    start_line: sl.value(idx) as usize,
                                    end_line: el.value(idx) as usize,
                                    language: lang.value(idx).to_string(),
                                    content: cont.value(idx).to_string(),
                                    project: if proj.is_null(idx) {
                                        None
                                    } else {
                                        Some(proj.value(idx).to_string())
                                    },
                                });
                            }
                            found = true;
                            break;
                        }
                    }
                    batch_offset += batch.num_rows() as u64;
                }

                if !found {
                    // Can happen for BM25-only ids that fall outside the
                    // vector result batches (see the id-space note above).
                    tracing::warn!("Could not find result for RRF ID {}", id);
                }
            }

            Ok(search_results)
        } else {
            // Pure vector search path: fetch exactly `limit` candidates.
            let query = table
                .vector_search(query_vector)
                .context("Failed to create vector search")?
                .limit(limit);

            // NOTE(review): same unescaped interpolation as the hybrid path.
            let stream = if let Some(ref project_name) = project {
                query
                    .only_if(format!("project = '{}'", project_name))
                    .execute()
                    .await
                    .context("Failed to execute search")?
            } else {
                query.execute().await.context("Failed to execute search")?
            };

            let results: Vec<RecordBatch> = stream
                .try_collect()
                .await
                .context("Failed to collect search results")?;

            let mut search_results = Vec::new();

            for batch in results {
                // Pull out each column once per batch, validating its type.
                let file_path_array = batch
                    .column_by_name("file_path")
                    .context("Missing file_path column")?
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .context("Invalid file_path type")?;

                let root_path_array = batch
                    .column_by_name("root_path")
                    .context("Missing root_path column")?
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .context("Invalid root_path type")?;

                let start_line_array = batch
                    .column_by_name("start_line")
                    .context("Missing start_line column")?
                    .as_any()
                    .downcast_ref::<UInt32Array>()
                    .context("Invalid start_line type")?;

                let end_line_array = batch
                    .column_by_name("end_line")
                    .context("Missing end_line column")?
                    .as_any()
                    .downcast_ref::<UInt32Array>()
                    .context("Invalid end_line type")?;

                let language_array = batch
                    .column_by_name("language")
                    .context("Missing language column")?
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .context("Invalid language type")?;

                let content_array = batch
                    .column_by_name("content")
                    .context("Missing content column")?
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .context("Invalid content type")?;

                let project_array = batch
                    .column_by_name("project")
                    .context("Missing project column")?
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .context("Invalid project type")?;

                let distance_array = batch
                    .column_by_name("_distance")
                    .context("Missing _distance column")?
                    .as_any()
                    .downcast_ref::<Float32Array>()
                    .context("Invalid _distance type")?;

                for i in 0..batch.num_rows() {
                    let distance = distance_array.value(i);
                    // Same distance -> similarity mapping as the hybrid path.
                    let score = 1.0 / (1.0 + distance);

                    if score >= min_score {
                        let result_root_path = if root_path_array.is_null(i) {
                            None
                        } else {
                            Some(root_path_array.value(i).to_string())
                        };

                        // Client-side root-path filter.
                        if let Some(ref filter_path) = root_path {
                            if result_root_path.as_ref() != Some(filter_path) {
                                continue;
                            }
                        }

                        search_results.push(SearchResult {
                            score,
                            vector_score: score,
                            keyword_score: None,
                            file_path: file_path_array.value(i).to_string(),
                            root_path: result_root_path,
                            start_line: start_line_array.value(i) as usize,
                            end_line: end_line_array.value(i) as usize,
                            language: language_array.value(i).to_string(),
                            content: content_array.value(i).to_string(),
                            project: if project_array.is_null(i) {
                                None
                            } else {
                                Some(project_array.value(i).to_string())
                            },
                        });
                    }
                }
            }

            Ok(search_results)
        }
    }
703
704 async fn search_filtered(
705 &self,
706 query_vector: Vec<f32>,
707 query_text: &str,
708 limit: usize,
709 min_score: f32,
710 project: Option<String>,
711 root_path: Option<String>,
712 hybrid: bool,
713 file_extensions: Vec<String>,
714 languages: Vec<String>,
715 path_patterns: Vec<String>,
716 ) -> Result<Vec<SearchResult>> {
717 let search_limit = limit * 3;
719
720 let mut results = self
722 .search(
723 query_vector,
724 query_text,
725 search_limit,
726 min_score,
727 project.clone(),
728 root_path.clone(),
729 hybrid,
730 )
731 .await?;
732
733 results.retain(|result| {
735 if !file_extensions.is_empty() {
737 let has_extension = file_extensions
738 .iter()
739 .any(|ext| result.file_path.ends_with(&format!(".{}", ext)));
740 if !has_extension {
741 return false;
742 }
743 }
744
745 if !languages.is_empty() && !languages.contains(&result.language) {
747 return false;
748 }
749
750 if !path_patterns.is_empty() {
752 if !glob_utils::matches_any_pattern(&result.file_path, &path_patterns) {
753 return false;
754 }
755 }
756
757 true
758 });
759
760 results.truncate(limit);
762
763 Ok(results)
764 }
765
766 async fn delete_by_file(&self, file_path: &str) -> Result<usize> {
767 {
771 let bm25_indexes = self
772 .bm25_indexes
773 .read()
774 .map_err(|e| anyhow::anyhow!("Failed to acquire BM25 read lock: {}", e))?;
775
776 for (root_hash, bm25) in bm25_indexes.iter() {
777 bm25.delete_by_file_path(file_path)
778 .context("Failed to delete from BM25 index")?;
779 tracing::debug!(
780 "Deleted BM25 entries for file: {} in index: {}",
781 file_path,
782 root_hash
783 );
784 }
785 } let table = self.get_table().await?;
788
789 let filter = format!("file_path = '{}'", file_path);
791
792 table
793 .delete(&filter)
794 .await
795 .context("Failed to delete records")?;
796
797 tracing::info!("Deleted embeddings for file: {}", file_path);
798
799 Ok(0)
801 }
802
803 async fn clear(&self) -> Result<()> {
804 self.connection
806 .drop_table(&self.table_name, &[])
807 .await
808 .context("Failed to drop table")?;
809
810 let bm25_indexes = self
812 .bm25_indexes
813 .read()
814 .map_err(|e| anyhow::anyhow!("Failed to acquire BM25 read lock: {}", e))?;
815
816 for (root_hash, bm25) in bm25_indexes.iter() {
817 bm25.clear().context("Failed to clear BM25 index")?;
818 tracing::info!("Cleared BM25 index for root hash: {}", root_hash);
819 }
820 drop(bm25_indexes);
821
822 tracing::info!("Cleared all embeddings and all per-project BM25 indexes");
823 Ok(())
824 }
825
826 async fn get_statistics(&self) -> Result<DatabaseStats> {
827 let table = self.get_table().await?;
828
829 let count_result = table
831 .count_rows(None)
832 .await
833 .context("Failed to count rows")?;
834
835 let stream = table
837 .query()
838 .select(lancedb::query::Select::Columns(vec![
839 "language".to_string(),
840 ]))
841 .execute()
842 .await
843 .context("Failed to query languages")?;
844
845 let query_result: Vec<RecordBatch> = stream
846 .try_collect()
847 .await
848 .context("Failed to collect language data")?;
849
850 let mut language_counts: HashMap<String, usize> = HashMap::new();
851
852 for batch in query_result {
853 let language_array = batch
854 .column_by_name("language")
855 .context("Missing language column")?
856 .as_any()
857 .downcast_ref::<StringArray>()
858 .context("Invalid language type")?;
859
860 for i in 0..batch.num_rows() {
861 let language = language_array.value(i);
862 *language_counts.entry(language.to_string()).or_insert(0) += 1;
863 }
864 }
865
866 let mut language_breakdown: Vec<(String, usize)> = language_counts.into_iter().collect();
867 language_breakdown.sort_by(|a, b| b.1.cmp(&a.1));
868
869 Ok(DatabaseStats {
870 total_points: count_result,
871 total_vectors: count_result,
872 language_breakdown,
873 })
874 }
875
    /// Intentional no-op: this backend keeps no write buffer of its own —
    /// each `add`/`delete` call completes before returning.
    /// NOTE(review): confirm LanceDB itself commits per call and needs no
    /// explicit flush for durability.
    async fn flush(&self) -> Result<()> {
        Ok(())
    }
880
881 async fn count_by_root_path(&self, root_path: &str) -> Result<usize> {
882 let table = self.get_table().await?;
883
884 let filter = format!("root_path = '{}'", root_path);
886 let count = table
887 .count_rows(Some(filter))
888 .await
889 .context("Failed to count rows by root path")?;
890
891 Ok(count)
892 }
893
894 async fn get_indexed_files(&self, root_path: &str) -> Result<Vec<String>> {
895 let table = self.get_table().await?;
896
897 let filter = format!("root_path = '{}'", root_path);
899 let stream = table
900 .query()
901 .only_if(filter)
902 .select(lancedb::query::Select::Columns(vec![
903 "file_path".to_string(),
904 ]))
905 .execute()
906 .await
907 .context("Failed to query indexed files")?;
908
909 let results: Vec<RecordBatch> = stream
910 .try_collect()
911 .await
912 .context("Failed to collect file paths")?;
913
914 let mut file_paths = std::collections::HashSet::new();
916
917 for batch in results {
918 let file_path_array = batch
919 .column_by_name("file_path")
920 .context("Missing file_path column")?
921 .as_any()
922 .downcast_ref::<StringArray>()
923 .context("Invalid file_path type")?;
924
925 for i in 0..batch.num_rows() {
926 file_paths.insert(file_path_array.value(i).to_string());
927 }
928 }
929
930 Ok(file_paths.into_iter().collect())
931 }
932}
933
934#[cfg(test)]
935mod tests;