// brainwires_storage/databases/lance/mod.rs

//! LanceDB unified database backend.
//!
//! [`LanceDatabase`] implements both [`StorageBackend`] and [`VectorDatabase`]
//! using a single shared `lancedb::Connection`. This replaces the former
//! `LanceBackend` + `LanceVectorDB` split.
//!
//! # Feature flag
//!
//! Requires the `lance-backend` feature (included in `native` by default).

11pub mod arrow_convert;
12
13use anyhow::{Context, Result};
14use arrow_array::{
15    Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
16    UInt32Array, types::Float32Type,
17};
18use arrow_schema::{DataType, Field, Schema};
19use futures::stream::TryStreamExt;
20use lancedb::Table;
21use lancedb::connection::Connection;
22use lancedb::query::{ExecutableQuery, QueryBase};
23use sha2::{Digest, Sha256};
24use std::collections::HashMap;
25use std::sync::{Arc, RwLock};
26
27use crate::bm25_search::{BM25Search, RrfScorer, SearchScorer};
28use crate::databases::traits::{
29    ChunkMetadata, DatabaseStats, SearchResult, StorageBackend, VectorDatabase,
30};
31use crate::databases::types::{FieldDef, Filter, Record, ScoredRecord};
32use crate::glob_utils;
33
34use arrow_convert::{
35    batch_to_records, extract_field_value, field_defs_to_schema, filter_to_sql, records_to_batch,
36};
37
/// Default table name for RAG embeddings.
const RAG_TABLE_NAME: &str = "code_embeddings";

/// Unified LanceDB database backend.
///
/// Holds a single `lancedb::Connection` and implements both
/// [`StorageBackend`] (for domain stores) and [`VectorDatabase`] (for RAG).
///
/// # Example
///
/// ```ignore
/// let db = Arc::new(LanceDatabase::new("/path/to/db").await?);
///
/// // Use as StorageBackend
/// let messages = MessageStore::new(db.clone(), embeddings);
///
/// // Use as VectorDatabase
/// db.initialize(384).await?;
/// db.store_embeddings(embeddings, metadata, contents, root_path).await?;
/// ```
pub struct LanceDatabase {
    /// Single shared connection to the underlying LanceDB database.
    connection: Connection,
    /// Path the connection was opened with (local directory).
    db_path: String,
    /// RAG table name (default: "code_embeddings").
    rag_table_name: String,
    /// Per-project BM25 search indexes for keyword matching, keyed by a
    /// 16-hex-char SHA-256 prefix of the project root path.
    bm25_indexes: Arc<RwLock<HashMap<String, BM25Search>>>,
    /// Pluggable search scorer for hybrid result fusion (default: RRF).
    scorer: Arc<dyn SearchScorer>,
}
68
69impl LanceDatabase {
70    /// Create a new LanceDB database at the given path.
71    ///
72    /// The path can be a local directory. Parent directories are created
73    /// automatically.
74    pub async fn new(db_path: impl Into<String>) -> Result<Self> {
75        let db_path = db_path.into();
76
77        if let Some(parent) = std::path::Path::new(&db_path).parent() {
78            std::fs::create_dir_all(parent).context("Failed to create database directory")?;
79        }
80
81        let connection = lancedb::connect(&db_path)
82            .execute()
83            .await
84            .context("Failed to connect to LanceDB")?;
85
86        Ok(Self {
87            connection,
88            db_path,
89            rag_table_name: RAG_TABLE_NAME.to_string(),
90            bm25_indexes: Arc::new(RwLock::new(HashMap::new())),
91            scorer: Arc::new(RrfScorer),
92        })
93    }
94
95    /// Create with the platform default LanceDB path.
96    pub async fn with_default_path() -> Result<Self> {
97        let db_path = Self::default_lancedb_path();
98        Self::new(db_path).await
99    }
100
101    /// Set a custom search scorer for hybrid result fusion.
102    pub fn with_scorer(mut self, scorer: Arc<dyn SearchScorer>) -> Self {
103        self.scorer = scorer;
104        self
105    }
106
107    /// Get the underlying LanceDB connection (for legacy code).
108    pub fn connection(&self) -> &Connection {
109        &self.connection
110    }
111
112    /// Get the database path.
113    pub fn db_path(&self) -> &str {
114        &self.db_path
115    }
116
117    /// Report backend capabilities.
118    pub fn capabilities(&self) -> crate::databases::BackendCapabilities {
119        crate::databases::BackendCapabilities {
120            vector_search: true,
121        }
122    }
123
124    /// Get default database path.
125    pub fn default_lancedb_path() -> String {
126        crate::paths::PlatformPaths::default_lancedb_path()
127            .to_string_lossy()
128            .to_string()
129    }
130
131    // ── VectorDatabase helpers ──────────────────────────────────────────
132
133    fn hash_root_path(root_path: &str) -> String {
134        let mut hasher = Sha256::new();
135        hasher.update(root_path.as_bytes());
136        let result = hasher.finalize();
137        format!("{:x}", result)[..16].to_string()
138    }
139
140    fn bm25_path_for_root(&self, root_path: &str) -> String {
141        let hash = Self::hash_root_path(root_path);
142        format!("{}/bm25_{}", self.db_path, hash)
143    }
144
145    fn get_or_create_bm25(&self, root_path: &str) -> Result<()> {
146        let hash = Self::hash_root_path(root_path);
147
148        {
149            let indexes = self.bm25_indexes.read().map_err(|e| {
150                anyhow::anyhow!("Failed to acquire read lock on BM25 indexes: {}", e)
151            })?;
152            if indexes.contains_key(&hash) {
153                return Ok(());
154            }
155        }
156
157        let mut indexes = self
158            .bm25_indexes
159            .write()
160            .map_err(|e| anyhow::anyhow!("Failed to acquire write lock on BM25 indexes: {}", e))?;
161
162        if indexes.contains_key(&hash) {
163            return Ok(());
164        }
165
166        let bm25_path = self.bm25_path_for_root(root_path);
167        tracing::info!(
168            "Creating BM25 index for root path '{}' at: {}",
169            root_path,
170            bm25_path
171        );
172
173        let bm25_index = BM25Search::new(&bm25_path)
174            .with_context(|| format!("Failed to initialize BM25 index for root: {}", root_path))?;
175
176        indexes.insert(hash, bm25_index);
177        Ok(())
178    }
179
180    fn create_rag_schema(dimension: usize) -> Arc<Schema> {
181        Arc::new(Schema::new(vec![
182            Field::new(
183                "vector",
184                DataType::FixedSizeList(
185                    Arc::new(Field::new("item", DataType::Float32, true)),
186                    dimension as i32,
187                ),
188                false,
189            ),
190            Field::new("id", DataType::Utf8, false),
191            Field::new("file_path", DataType::Utf8, false),
192            Field::new("root_path", DataType::Utf8, true),
193            Field::new("start_line", DataType::UInt32, false),
194            Field::new("end_line", DataType::UInt32, false),
195            Field::new("language", DataType::Utf8, false),
196            Field::new("extension", DataType::Utf8, false),
197            Field::new("file_hash", DataType::Utf8, false),
198            Field::new("indexed_at", DataType::Utf8, false),
199            Field::new("content", DataType::Utf8, false),
200            Field::new("project", DataType::Utf8, true),
201        ]))
202    }
203
204    async fn get_rag_table(&self) -> Result<Table> {
205        self.connection
206            .open_table(&self.rag_table_name)
207            .execute()
208            .await
209            .context("Failed to open RAG table")
210    }
211
212    fn create_rag_record_batch(
213        embeddings: Vec<Vec<f32>>,
214        metadata: Vec<ChunkMetadata>,
215        contents: Vec<String>,
216        schema: Arc<Schema>,
217    ) -> Result<RecordBatch> {
218        let num_rows = embeddings.len();
219        let dimension = embeddings[0].len();
220
221        let vector_array = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
222            embeddings
223                .into_iter()
224                .map(|v| Some(v.into_iter().map(Some))),
225            dimension as i32,
226        );
227
228        let id_array = StringArray::from(
229            (0..num_rows)
230                .map(|i| format!("{}:{}", metadata[i].file_path, metadata[i].start_line))
231                .collect::<Vec<_>>(),
232        );
233        let file_path_array = StringArray::from(
234            metadata
235                .iter()
236                .map(|m| m.file_path.as_str())
237                .collect::<Vec<_>>(),
238        );
239        let root_path_array = StringArray::from(
240            metadata
241                .iter()
242                .map(|m| m.root_path.as_deref())
243                .collect::<Vec<_>>(),
244        );
245        let start_line_array = UInt32Array::from(
246            metadata
247                .iter()
248                .map(|m| m.start_line as u32)
249                .collect::<Vec<_>>(),
250        );
251        let end_line_array = UInt32Array::from(
252            metadata
253                .iter()
254                .map(|m| m.end_line as u32)
255                .collect::<Vec<_>>(),
256        );
257        let language_array = StringArray::from(
258            metadata
259                .iter()
260                .map(|m| m.language.as_deref().unwrap_or("Unknown"))
261                .collect::<Vec<_>>(),
262        );
263        let extension_array = StringArray::from(
264            metadata
265                .iter()
266                .map(|m| m.extension.as_deref().unwrap_or(""))
267                .collect::<Vec<_>>(),
268        );
269        let file_hash_array = StringArray::from(
270            metadata
271                .iter()
272                .map(|m| m.file_hash.as_str())
273                .collect::<Vec<_>>(),
274        );
275        let indexed_at_array = StringArray::from(
276            metadata
277                .iter()
278                .map(|m| m.indexed_at.to_string())
279                .collect::<Vec<_>>(),
280        );
281        let content_array =
282            StringArray::from(contents.iter().map(|s| s.as_str()).collect::<Vec<_>>());
283        let project_array = StringArray::from(
284            metadata
285                .iter()
286                .map(|m| m.project.as_deref())
287                .collect::<Vec<_>>(),
288        );
289
290        RecordBatch::try_new(
291            schema,
292            vec![
293                Arc::new(vector_array),
294                Arc::new(id_array),
295                Arc::new(file_path_array),
296                Arc::new(root_path_array),
297                Arc::new(start_line_array),
298                Arc::new(end_line_array),
299                Arc::new(language_array),
300                Arc::new(extension_array),
301                Arc::new(file_hash_array),
302                Arc::new(indexed_at_array),
303                Arc::new(content_array),
304                Arc::new(project_array),
305            ],
306        )
307        .context("Failed to create RecordBatch")
308    }
309}
310
311// ── StorageBackend impl ─────────────────────────────────────────────────
312
313#[async_trait::async_trait]
314impl StorageBackend for LanceDatabase {
315    async fn ensure_table(&self, table_name: &str, schema: &[FieldDef]) -> Result<()> {
316        let table_names = self.connection.table_names().execute().await?;
317        if table_names.contains(&table_name.to_string()) {
318            return Ok(());
319        }
320
321        let arrow_schema = Arc::new(field_defs_to_schema(schema));
322        let batches = RecordBatchIterator::new(vec![], arrow_schema);
323        self.connection
324            .create_table(table_name, Box::new(batches))
325            .execute()
326            .await
327            .with_context(|| format!("Failed to create table '{table_name}'"))?;
328        Ok(())
329    }
330
331    async fn insert(&self, table_name: &str, records: Vec<Record>) -> Result<()> {
332        if records.is_empty() {
333            return Ok(());
334        }
335
336        let table = self
337            .connection
338            .open_table(table_name)
339            .execute()
340            .await
341            .with_context(|| format!("Failed to open table '{table_name}'"))?;
342
343        let batch = records_to_batch(&records)?;
344        let schema = batch.schema();
345        let batches = RecordBatchIterator::new(vec![Ok(batch)], schema);
346        table
347            .add(Box::new(batches))
348            .execute()
349            .await
350            .with_context(|| format!("Failed to insert into '{table_name}'"))?;
351        Ok(())
352    }
353
354    async fn query(
355        &self,
356        table_name: &str,
357        filter: Option<&Filter>,
358        limit: Option<usize>,
359    ) -> Result<Vec<Record>> {
360        let table = self
361            .connection
362            .open_table(table_name)
363            .execute()
364            .await
365            .with_context(|| format!("Failed to open table '{table_name}'"))?;
366
367        let mut q = table.query();
368        if let Some(f) = filter {
369            q = q.only_if(filter_to_sql(f));
370        }
371        if let Some(n) = limit {
372            q = q.limit(n);
373        }
374
375        let batches: Vec<RecordBatch> = q
376            .execute()
377            .await
378            .with_context(|| format!("Failed to query '{table_name}'"))?
379            .try_collect()
380            .await?;
381
382        let mut results = Vec::new();
383        for batch in &batches {
384            batch_to_records(batch, &mut results)?;
385        }
386        Ok(results)
387    }
388
389    async fn delete(&self, table_name: &str, filter: &Filter) -> Result<()> {
390        let table = self
391            .connection
392            .open_table(table_name)
393            .execute()
394            .await
395            .with_context(|| format!("Failed to open table '{table_name}'"))?;
396
397        table
398            .delete(&filter_to_sql(filter))
399            .await
400            .with_context(|| format!("Failed to delete from '{table_name}'"))?;
401        Ok(())
402    }
403
404    async fn count(&self, table_name: &str, filter: Option<&Filter>) -> Result<usize> {
405        let table = self
406            .connection
407            .open_table(table_name)
408            .execute()
409            .await
410            .with_context(|| format!("Failed to open table '{table_name}'"))?;
411
412        let mut q = table.query();
413        if let Some(f) = filter {
414            q = q.only_if(filter_to_sql(f));
415        }
416        let batches: Vec<RecordBatch> = q.execute().await?.try_collect().await?;
417        Ok(batches.iter().map(|b| b.num_rows()).sum())
418    }
419
420    async fn vector_search(
421        &self,
422        table_name: &str,
423        _vector_column: &str,
424        vector: Vec<f32>,
425        limit: usize,
426        filter: Option<&Filter>,
427    ) -> Result<Vec<ScoredRecord>> {
428        let table = self
429            .connection
430            .open_table(table_name)
431            .execute()
432            .await
433            .with_context(|| format!("Failed to open table '{table_name}'"))?;
434
435        let mut q = table.vector_search(vector)?;
436        q = q.limit(limit);
437        if let Some(f) = filter {
438            q = q.only_if(filter_to_sql(f));
439        }
440
441        let batches: Vec<RecordBatch> = q.execute().await?.try_collect().await?;
442
443        let mut results = Vec::new();
444        for batch in &batches {
445            let distance_col = batch
446                .column_by_name("_distance")
447                .and_then(|c| c.as_any().downcast_ref::<Float32Array>());
448
449            for row in 0..batch.num_rows() {
450                let mut record = Vec::new();
451                for (col_idx, field) in batch.schema().fields().iter().enumerate() {
452                    if field.name() == "_distance" {
453                        continue;
454                    }
455                    let val = extract_field_value(batch, col_idx, row, field)?;
456                    record.push((field.name().clone(), val));
457                }
458
459                let distance = distance_col.map_or(0.0, |c| c.value(row));
460                let score = 1.0 / (1.0 + distance);
461
462                results.push(ScoredRecord { record, score });
463            }
464        }
465        Ok(results)
466    }
467}
468
469// ── VectorDatabase impl ────────────────────────────────────────────────
470
471#[async_trait::async_trait]
472impl VectorDatabase for LanceDatabase {
473    async fn initialize(&self, dimension: usize) -> Result<()> {
474        tracing::info!(
475            "Initializing LanceDB with dimension {} at {}",
476            dimension,
477            self.db_path
478        );
479
480        let table_names = self
481            .connection
482            .table_names()
483            .execute()
484            .await
485            .context("Failed to list tables")?;
486
487        if table_names.contains(&self.rag_table_name) {
488            tracing::info!("Table '{}' already exists", self.rag_table_name);
489            return Ok(());
490        }
491
492        let schema = Self::create_rag_schema(dimension);
493        let empty_batch = RecordBatch::new_empty(schema.clone());
494        let batches =
495            RecordBatchIterator::new(vec![empty_batch].into_iter().map(Ok), schema.clone());
496
497        self.connection
498            .create_table(&self.rag_table_name, Box::new(batches))
499            .execute()
500            .await
501            .context("Failed to create table")?;
502
503        tracing::info!("Created table '{}'", self.rag_table_name);
504        Ok(())
505    }
506
507    async fn store_embeddings(
508        &self,
509        embeddings: Vec<Vec<f32>>,
510        metadata: Vec<ChunkMetadata>,
511        contents: Vec<String>,
512        root_path: &str,
513    ) -> Result<usize> {
514        if embeddings.is_empty() {
515            return Ok(0);
516        }
517
518        let dimension = embeddings[0].len();
519        let schema = Self::create_rag_schema(dimension);
520
521        let table = self.get_rag_table().await?;
522        let current_count = table.count_rows(None).await.unwrap_or(0) as u64;
523
524        let batch = Self::create_rag_record_batch(
525            embeddings,
526            metadata.clone(),
527            contents.clone(),
528            schema.clone(),
529        )?;
530        let count = batch.num_rows();
531
532        let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema);
533
534        table
535            .add(Box::new(batches))
536            .execute()
537            .await
538            .context("Failed to add records to table")?;
539
540        self.get_or_create_bm25(root_path)?;
541
542        let bm25_docs: Vec<_> = (0..count)
543            .map(|i| {
544                let id = current_count + i as u64;
545                (id, contents[i].clone(), metadata[i].file_path.clone())
546            })
547            .collect();
548
549        let hash = Self::hash_root_path(root_path);
550        let bm25_indexes = self
551            .bm25_indexes
552            .read()
553            .map_err(|e| anyhow::anyhow!("Failed to acquire BM25 read lock: {}", e))?;
554
555        if let Some(bm25) = bm25_indexes.get(&hash) {
556            bm25.add_documents(bm25_docs)
557                .context("Failed to add documents to BM25 index")?;
558        }
559        drop(bm25_indexes);
560
561        tracing::info!(
562            "Stored {} embeddings with BM25 indexing for root: {}",
563            count,
564            root_path
565        );
566        Ok(count)
567    }
568
569    async fn search(
570        &self,
571        query_vector: Vec<f32>,
572        query_text: &str,
573        limit: usize,
574        min_score: f32,
575        project: Option<String>,
576        root_path: Option<String>,
577        hybrid: bool,
578    ) -> Result<Vec<SearchResult>> {
579        let table = self.get_rag_table().await?;
580
581        if hybrid {
582            let search_limit = limit * 3;
583
584            let query = table
585                .vector_search(query_vector)
586                .context("Failed to create vector search")?
587                .limit(search_limit);
588
589            let stream = if let Some(ref project_name) = project {
590                query
591                    .only_if(format!("project = '{}'", project_name))
592                    .execute()
593                    .await
594                    .context("Failed to execute search")?
595            } else {
596                query.execute().await.context("Failed to execute search")?
597            };
598
599            let results: Vec<RecordBatch> = stream
600                .try_collect()
601                .await
602                .context("Failed to collect search results")?;
603
604            let mut vector_results = Vec::new();
605            let mut row_offset = 0u64;
606            let mut original_scores: HashMap<u64, (f32, Option<f32>)> = HashMap::new();
607
608            for batch in &results {
609                let distance_array = batch
610                    .column_by_name("_distance")
611                    .context("Missing _distance column")?
612                    .as_any()
613                    .downcast_ref::<Float32Array>()
614                    .context("Invalid _distance type")?;
615
616                for i in 0..batch.num_rows() {
617                    let distance = distance_array.value(i);
618                    let score = 1.0 / (1.0 + distance);
619                    let id = row_offset + i as u64;
620                    vector_results.push((id, score));
621                    original_scores.insert(id, (score, None));
622                }
623                row_offset += batch.num_rows() as u64;
624            }
625
626            let bm25_indexes = self
627                .bm25_indexes
628                .read()
629                .map_err(|e| anyhow::anyhow!("Failed to acquire BM25 read lock: {}", e))?;
630
631            let mut all_bm25_results = Vec::new();
632            for (root_hash, bm25) in bm25_indexes.iter() {
633                tracing::debug!("Searching BM25 index for root hash: {}", root_hash);
634                let bm25_results = bm25
635                    .search(query_text, search_limit)
636                    .context("Failed to search BM25 index")?;
637
638                for result in &bm25_results {
639                    original_scores
640                        .entry(result.id)
641                        .and_modify(|e| e.1 = Some(result.score))
642                        .or_insert((0.0, Some(result.score)));
643                }
644
645                all_bm25_results.extend(bm25_results);
646            }
647            drop(bm25_indexes);
648
649            let combined = self.scorer.fuse(vector_results, all_bm25_results, limit);
650
651            let mut search_results = Vec::new();
652
653            for (id, combined_score) in combined {
654                let mut found = false;
655                let mut batch_offset = 0u64;
656
657                for batch in &results {
658                    if id >= batch_offset && id < batch_offset + batch.num_rows() as u64 {
659                        let idx = (id - batch_offset) as usize;
660
661                        let file_path_array = batch
662                            .column_by_name("file_path")
663                            .and_then(|c| c.as_any().downcast_ref::<StringArray>());
664                        let root_path_array = batch
665                            .column_by_name("root_path")
666                            .and_then(|c| c.as_any().downcast_ref::<StringArray>());
667                        let start_line_array = batch
668                            .column_by_name("start_line")
669                            .and_then(|c| c.as_any().downcast_ref::<UInt32Array>());
670                        let end_line_array = batch
671                            .column_by_name("end_line")
672                            .and_then(|c| c.as_any().downcast_ref::<UInt32Array>());
673                        let language_array = batch
674                            .column_by_name("language")
675                            .and_then(|c| c.as_any().downcast_ref::<StringArray>());
676                        let content_array = batch
677                            .column_by_name("content")
678                            .and_then(|c| c.as_any().downcast_ref::<StringArray>());
679                        let project_array = batch
680                            .column_by_name("project")
681                            .and_then(|c| c.as_any().downcast_ref::<StringArray>());
682                        let indexed_at_array = batch
683                            .column_by_name("indexed_at")
684                            .and_then(|c| c.as_any().downcast_ref::<StringArray>());
685
686                        if let (
687                            Some(fp),
688                            Some(rp),
689                            Some(sl),
690                            Some(el),
691                            Some(lang),
692                            Some(cont),
693                            Some(proj),
694                        ) = (
695                            file_path_array,
696                            root_path_array,
697                            start_line_array,
698                            end_line_array,
699                            language_array,
700                            content_array,
701                            project_array,
702                        ) {
703                            let (vector_score, keyword_score) =
704                                original_scores.get(&id).copied().unwrap_or((0.0, None));
705
706                            let passes_filter = vector_score >= min_score
707                                || keyword_score.is_some_and(|k| k >= min_score);
708
709                            if passes_filter {
710                                let result_root_path = if rp.is_null(idx) {
711                                    None
712                                } else {
713                                    Some(rp.value(idx).to_string())
714                                };
715
716                                if let Some(ref filter_path) = root_path
717                                    && result_root_path.as_ref() != Some(filter_path)
718                                {
719                                    found = true;
720                                    break;
721                                }
722
723                                search_results.push(SearchResult {
724                                    score: combined_score,
725                                    vector_score,
726                                    keyword_score,
727                                    file_path: fp.value(idx).to_string(),
728                                    root_path: result_root_path,
729                                    start_line: sl.value(idx) as usize,
730                                    end_line: el.value(idx) as usize,
731                                    language: lang.value(idx).to_string(),
732                                    content: cont.value(idx).to_string(),
733                                    project: if proj.is_null(idx) {
734                                        None
735                                    } else {
736                                        Some(proj.value(idx).to_string())
737                                    },
738                                    indexed_at: indexed_at_array
739                                        .and_then(|ia| ia.value(idx).parse::<i64>().ok())
740                                        .unwrap_or(0),
741                                });
742                            }
743                            found = true;
744                            break;
745                        }
746                    }
747                    batch_offset += batch.num_rows() as u64;
748                }
749
750                if !found {
751                    tracing::warn!("Could not find result for RRF ID {}", id);
752                }
753            }
754
755            Ok(search_results)
756        } else {
757            // Pure vector search
758            let query = table
759                .vector_search(query_vector)
760                .context("Failed to create vector search")?
761                .limit(limit);
762
763            let stream = if let Some(ref project_name) = project {
764                query
765                    .only_if(format!("project = '{}'", project_name))
766                    .execute()
767                    .await
768                    .context("Failed to execute search")?
769            } else {
770                query.execute().await.context("Failed to execute search")?
771            };
772
773            let results: Vec<RecordBatch> = stream
774                .try_collect()
775                .await
776                .context("Failed to collect search results")?;
777
778            let mut search_results = Vec::new();
779
780            for batch in results {
781                let file_path_array = batch
782                    .column_by_name("file_path")
783                    .context("Missing file_path column")?
784                    .as_any()
785                    .downcast_ref::<StringArray>()
786                    .context("Invalid file_path type")?;
787
788                let root_path_array = batch
789                    .column_by_name("root_path")
790                    .context("Missing root_path column")?
791                    .as_any()
792                    .downcast_ref::<StringArray>()
793                    .context("Invalid root_path type")?;
794
795                let start_line_array = batch
796                    .column_by_name("start_line")
797                    .context("Missing start_line column")?
798                    .as_any()
799                    .downcast_ref::<UInt32Array>()
800                    .context("Invalid start_line type")?;
801
802                let end_line_array = batch
803                    .column_by_name("end_line")
804                    .context("Missing end_line column")?
805                    .as_any()
806                    .downcast_ref::<UInt32Array>()
807                    .context("Invalid end_line type")?;
808
809                let language_array = batch
810                    .column_by_name("language")
811                    .context("Missing language column")?
812                    .as_any()
813                    .downcast_ref::<StringArray>()
814                    .context("Invalid language type")?;
815
816                let content_array = batch
817                    .column_by_name("content")
818                    .context("Missing content column")?
819                    .as_any()
820                    .downcast_ref::<StringArray>()
821                    .context("Invalid content type")?;
822
823                let project_array = batch
824                    .column_by_name("project")
825                    .context("Missing project column")?
826                    .as_any()
827                    .downcast_ref::<StringArray>()
828                    .context("Invalid project type")?;
829
830                let distance_array = batch
831                    .column_by_name("_distance")
832                    .context("Missing _distance column")?
833                    .as_any()
834                    .downcast_ref::<Float32Array>()
835                    .context("Invalid _distance type")?;
836
837                let indexed_at_array = batch
838                    .column_by_name("indexed_at")
839                    .and_then(|c| c.as_any().downcast_ref::<StringArray>());
840
841                for i in 0..batch.num_rows() {
842                    let distance = distance_array.value(i);
843                    let score = 1.0 / (1.0 + distance);
844
845                    if score >= min_score {
846                        let result_root_path = if root_path_array.is_null(i) {
847                            None
848                        } else {
849                            Some(root_path_array.value(i).to_string())
850                        };
851
852                        if let Some(ref filter_path) = root_path
853                            && result_root_path.as_ref() != Some(filter_path)
854                        {
855                            continue;
856                        }
857
858                        search_results.push(SearchResult {
859                            score,
860                            vector_score: score,
861                            keyword_score: None,
862                            file_path: file_path_array.value(i).to_string(),
863                            root_path: result_root_path,
864                            start_line: start_line_array.value(i) as usize,
865                            end_line: end_line_array.value(i) as usize,
866                            language: language_array.value(i).to_string(),
867                            content: content_array.value(i).to_string(),
868                            project: if project_array.is_null(i) {
869                                None
870                            } else {
871                                Some(project_array.value(i).to_string())
872                            },
873                            indexed_at: indexed_at_array
874                                .and_then(|ia| ia.value(i).parse::<i64>().ok())
875                                .unwrap_or(0),
876                        });
877                    }
878                }
879            }
880
881            Ok(search_results)
882        }
883    }
884
885    async fn search_filtered(
886        &self,
887        query_vector: Vec<f32>,
888        query_text: &str,
889        limit: usize,
890        min_score: f32,
891        project: Option<String>,
892        root_path: Option<String>,
893        hybrid: bool,
894        file_extensions: Vec<String>,
895        languages: Vec<String>,
896        path_patterns: Vec<String>,
897    ) -> Result<Vec<SearchResult>> {
898        let search_limit = limit * 3;
899
900        let mut results = self
901            .search(
902                query_vector,
903                query_text,
904                search_limit,
905                min_score,
906                project,
907                root_path,
908                hybrid,
909            )
910            .await?;
911
912        results.retain(|result| {
913            if !file_extensions.is_empty() {
914                let has_extension = file_extensions
915                    .iter()
916                    .any(|ext| result.file_path.ends_with(&format!(".{}", ext)));
917                if !has_extension {
918                    return false;
919                }
920            }
921
922            if !languages.is_empty() && !languages.contains(&result.language) {
923                return false;
924            }
925
926            if !path_patterns.is_empty()
927                && !glob_utils::matches_any_pattern(&result.file_path, &path_patterns)
928            {
929                return false;
930            }
931
932            true
933        });
934
935        results.truncate(limit);
936        Ok(results)
937    }
938
939    async fn delete_by_file(&self, file_path: &str) -> Result<usize> {
940        {
941            let bm25_indexes = self
942                .bm25_indexes
943                .read()
944                .map_err(|e| anyhow::anyhow!("Failed to acquire BM25 read lock: {}", e))?;
945
946            for (root_hash, bm25) in bm25_indexes.iter() {
947                bm25.delete_by_file_path(file_path)
948                    .context("Failed to delete from BM25 index")?;
949                tracing::debug!(
950                    "Deleted BM25 entries for file: {} in index: {}",
951                    file_path,
952                    root_hash
953                );
954            }
955        }
956
957        let table = self.get_rag_table().await?;
958        let filter = format!("file_path = '{}'", file_path);
959        table
960            .delete(&filter)
961            .await
962            .context("Failed to delete records")?;
963
964        tracing::info!("Deleted embeddings for file: {}", file_path);
965        Ok(0)
966    }
967
968    async fn clear(&self) -> Result<()> {
969        self.connection
970            .drop_table(&self.rag_table_name, &[])
971            .await
972            .context("Failed to drop table")?;
973
974        let bm25_indexes = self
975            .bm25_indexes
976            .read()
977            .map_err(|e| anyhow::anyhow!("Failed to acquire BM25 read lock: {}", e))?;
978
979        for (root_hash, bm25) in bm25_indexes.iter() {
980            bm25.clear().context("Failed to clear BM25 index")?;
981            tracing::info!("Cleared BM25 index for root hash: {}", root_hash);
982        }
983        drop(bm25_indexes);
984
985        tracing::info!("Cleared all embeddings and all per-project BM25 indexes");
986        Ok(())
987    }
988
989    async fn get_statistics(&self) -> Result<DatabaseStats> {
990        let table = self.get_rag_table().await?;
991
992        let count_result = table
993            .count_rows(None)
994            .await
995            .context("Failed to count rows")?;
996
997        let stream = table
998            .query()
999            .select(lancedb::query::Select::Columns(vec![
1000                "language".to_string(),
1001            ]))
1002            .execute()
1003            .await
1004            .context("Failed to query languages")?;
1005
1006        let query_result: Vec<RecordBatch> = stream
1007            .try_collect()
1008            .await
1009            .context("Failed to collect language data")?;
1010
1011        let mut language_counts: HashMap<String, usize> = HashMap::new();
1012
1013        for batch in query_result {
1014            let language_array = batch
1015                .column_by_name("language")
1016                .context("Missing language column")?
1017                .as_any()
1018                .downcast_ref::<StringArray>()
1019                .context("Invalid language type")?;
1020
1021            for i in 0..batch.num_rows() {
1022                let language = language_array.value(i);
1023                *language_counts.entry(language.to_string()).or_insert(0) += 1;
1024            }
1025        }
1026
1027        let mut language_breakdown: Vec<(String, usize)> = language_counts.into_iter().collect();
1028        language_breakdown.sort_by(|a, b| b.1.cmp(&a.1));
1029
1030        Ok(DatabaseStats {
1031            total_points: count_result,
1032            total_vectors: count_result,
1033            language_breakdown,
1034        })
1035    }
1036
    /// Flush pending writes to durable storage.
    ///
    /// No-op for this backend: there is no buffered state held here to flush.
    /// NOTE(review): assumes each LanceDB write call commits on its own —
    /// confirm against the lancedb documentation.
    async fn flush(&self) -> Result<()> {
        Ok(())
    }
1040
1041    async fn count_by_root_path(&self, root_path: &str) -> Result<usize> {
1042        let table = self.get_rag_table().await?;
1043        let filter = format!("root_path = '{}'", root_path);
1044        let count = table
1045            .count_rows(Some(filter))
1046            .await
1047            .context("Failed to count rows by root path")?;
1048        Ok(count)
1049    }
1050
1051    async fn get_indexed_files(&self, root_path: &str) -> Result<Vec<String>> {
1052        let table = self.get_rag_table().await?;
1053        let filter = format!("root_path = '{}'", root_path);
1054        let stream = table
1055            .query()
1056            .only_if(filter)
1057            .select(lancedb::query::Select::Columns(vec![
1058                "file_path".to_string(),
1059            ]))
1060            .execute()
1061            .await
1062            .context("Failed to query indexed files")?;
1063
1064        let results: Vec<RecordBatch> = stream
1065            .try_collect()
1066            .await
1067            .context("Failed to collect file paths")?;
1068
1069        let mut file_paths = std::collections::HashSet::new();
1070        for batch in results {
1071            let file_path_array = batch
1072                .column_by_name("file_path")
1073                .context("Missing file_path column")?
1074                .as_any()
1075                .downcast_ref::<StringArray>()
1076                .context("Invalid file_path type")?;
1077
1078            for i in 0..batch.num_rows() {
1079                file_paths.insert(file_path_array.value(i).to_string());
1080            }
1081        }
1082
1083        Ok(file_paths.into_iter().collect())
1084    }
1085
1086    async fn search_with_embeddings(
1087        &self,
1088        query_vector: Vec<f32>,
1089        query_text: &str,
1090        limit: usize,
1091        min_score: f32,
1092        project: Option<String>,
1093        root_path: Option<String>,
1094        hybrid: bool,
1095    ) -> Result<(Vec<SearchResult>, Vec<Vec<f32>>)> {
1096        let results = self
1097            .search(
1098                query_vector,
1099                query_text,
1100                limit,
1101                min_score,
1102                project,
1103                root_path,
1104                hybrid,
1105            )
1106            .await?;
1107
1108        if results.is_empty() {
1109            return Ok((results, Vec::new()));
1110        }
1111
1112        let table = self.get_rag_table().await?;
1113        let mut embeddings = Vec::with_capacity(results.len());
1114
1115        for result in &results {
1116            let filter = format!(
1117                "file_path = '{}' AND start_line = {}",
1118                result.file_path, result.start_line
1119            );
1120            let stream = table
1121                .query()
1122                .only_if(filter)
1123                .select(lancedb::query::Select::Columns(vec!["vector".to_string()]))
1124                .limit(1)
1125                .execute()
1126                .await
1127                .context("Failed to query embedding vector")?;
1128
1129            let batches: Vec<RecordBatch> = stream
1130                .try_collect()
1131                .await
1132                .context("Failed to collect embedding vector")?;
1133
1134            let mut found = false;
1135            for batch in &batches {
1136                if batch.num_rows() > 0
1137                    && let Some(vector_col) = batch.column_by_name("vector")
1138                    && let Some(fsl) = vector_col.as_any().downcast_ref::<FixedSizeListArray>()
1139                {
1140                    let values = fsl
1141                        .value(0)
1142                        .as_any()
1143                        .downcast_ref::<Float32Array>()
1144                        .map(|a| a.values().to_vec())
1145                        .unwrap_or_default();
1146                    embeddings.push(values);
1147                    found = true;
1148                    break;
1149                }
1150            }
1151            if !found {
1152                embeddings.push(Vec::new());
1153            }
1154        }
1155
1156        Ok((results, embeddings))
1157    }
1158}
1159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::databases::types::{FieldValue, Filter};
    use tempfile::TempDir;

    /// A freshly created database reports the path it was opened with.
    #[tokio::test]
    async fn test_lance_database_new() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.lance");
        let path_str = path.to_str().unwrap();
        let db = LanceDatabase::new(path_str).await.unwrap();
        assert_eq!(db.db_path(), path_str);
    }

    /// Insert / query / count / delete round-trip via the StorageBackend API.
    #[tokio::test]
    async fn test_lance_storage_backend_crud() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.lance");
        let db = LanceDatabase::new(path.to_str().unwrap()).await.unwrap();

        // Minimal two-column table: string key + integer payload.
        let table_schema = vec![
            FieldDef::required("id", crate::databases::types::FieldType::Utf8),
            FieldDef::required("value", crate::databases::types::FieldType::Int64),
        ];
        db.ensure_table("test_table", &table_schema).await.unwrap();

        let row = vec![
            ("id".to_string(), FieldValue::Utf8(Some("row1".to_string()))),
            ("value".to_string(), FieldValue::Int64(Some(42))),
        ];
        db.insert("test_table", vec![row]).await.unwrap();

        assert_eq!(db.query("test_table", None, None).await.unwrap().len(), 1);
        assert_eq!(db.count("test_table", None).await.unwrap(), 1);

        let by_id = Filter::Eq("id".into(), FieldValue::Utf8(Some("row1".into())));
        db.delete("test_table", &by_id).await.unwrap();

        assert_eq!(db.count("test_table", None).await.unwrap(), 0);
    }

    /// Nearest-neighbour search ranks the exact-match vector first and
    /// returns results in descending score order.
    #[tokio::test]
    async fn test_lance_vector_search() {
        use crate::databases::types::FieldType;

        let dir = TempDir::new().unwrap();
        let path = dir.path().join("vec_search.lance");
        let db = LanceDatabase::new(path.to_str().unwrap()).await.unwrap();

        let dim = 4;
        let table_schema = vec![
            FieldDef::required("id", FieldType::Utf8),
            FieldDef::required("embedding", FieldType::Vector(dim)),
        ];
        db.ensure_table("vectors", &table_schema).await.unwrap();

        // Three rows with distinct vectors.
        let mut rows = Vec::new();
        for (id, vector) in [
            ("a", vec![1.0, 0.0, 0.0, 0.0]),
            ("b", vec![0.0, 1.0, 0.0, 0.0]),
            ("c", vec![0.9, 0.1, 0.0, 0.0]),
        ] {
            rows.push(vec![
                ("id".to_string(), FieldValue::Utf8(Some(id.to_string()))),
                ("embedding".to_string(), FieldValue::Vector(vector)),
            ]);
        }
        db.insert("vectors", rows).await.unwrap();

        // Query nearest to [1, 0, 0, 0] — "a" is an exact match
        // (distance 0 → highest score) and must rank first.
        let hits = db
            .vector_search("vectors", "embedding", vec![1.0, 0.0, 0.0, 0.0], 3, None)
            .await
            .unwrap();

        assert!(!hits.is_empty(), "vector_search should return results");

        let top_id = hits[0]
            .record
            .iter()
            .find(|(field, _)| field == "id")
            .and_then(|(_, value)| value.as_str())
            .unwrap();
        assert_eq!(top_id, "a");

        for pair in hits.windows(2) {
            assert!(
                pair[0].score >= pair[1].score,
                "scores should be descending: {} >= {}",
                pair[0].score,
                pair[1].score
            );
        }
    }

    /// The capabilities report advertises vector search support.
    #[tokio::test]
    async fn test_lance_capabilities() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("caps.lance");
        let db = LanceDatabase::new(path.to_str().unwrap()).await.unwrap();

        assert!(
            db.capabilities().vector_search,
            "LanceDatabase should support vector search"
        );
    }

    /// StorageBackend and VectorDatabase traits work side by side on one
    /// shared connection.
    #[tokio::test]
    async fn test_lance_shared_connection() {
        use crate::databases::types::FieldType;

        let dir = TempDir::new().unwrap();
        let path = dir.path().join("shared.lance");
        let db = LanceDatabase::new(path.to_str().unwrap()).await.unwrap();

        // Exercise the StorageBackend half of the API.
        let table_schema = vec![FieldDef::required("name", FieldType::Utf8)];
        db.ensure_table("store_table", &table_schema).await.unwrap();
        db.insert(
            "store_table",
            vec![vec![(
                "name".to_string(),
                FieldValue::Utf8(Some("test".to_string())),
            )]],
        )
        .await
        .unwrap();

        // Exercise the VectorDatabase half on the very same instance.
        db.initialize(4).await.unwrap();

        // Both halves observe the shared connection.
        assert_eq!(db.count("store_table", None).await.unwrap(), 1);
        assert_eq!(db.get_statistics().await.unwrap().total_vectors, 0);
    }
}