use anyhow::{Context, Result, anyhow, bail};
use arrow_array::types::Float32Type;
use arrow_array::{
FixedSizeListArray, Float32Array, Float64Array, Int32Array, Int64Array, RecordBatch,
RecordBatchIterator, StringArray, UInt64Array,
};
use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use lancedb::index::Index;
use lancedb::query::{ExecutableQuery, QueryBase};
use std::collections::HashMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::path::Path;
use std::sync::Arc;
/// Metadata for a single indexed document.
#[derive(Debug, Clone)]
pub struct DocMeta {
    /// Document path; acts as the unique key in the 'documents' table.
    pub path: String,
    /// File size in bytes when the document was indexed (staleness check).
    pub size_bytes: u64,
    /// Modification timestamp when the document was indexed (staleness check).
    pub mtime: i64,
}
impl DocMeta {
    /// Derives a stable, strictly positive `i32` id from the document path.
    ///
    /// Ids are deterministic for a given path within one Rust version, but
    /// `DefaultHasher` output is not guaranteed stable across Rust releases.
    pub fn id(&self) -> i32 {
        let mut hasher = DefaultHasher::new();
        self.path.hash(&mut hasher);
        // `wrapping_abs` avoids the debug-mode overflow panic that plain
        // `.abs()` hits when the truncated hash equals i32::MIN; that single
        // wrapped (still negative) case — and a zero hash — are mapped to 1
        // by `.max(1)`, keeping every id strictly positive. All other paths
        // produce the same id as before.
        (hasher.finish() as i32).wrapping_abs().max(1)
    }
}
/// A single ANN search hit: a document path ranked by vector distance.
#[derive(Debug, Clone)]
pub struct RankedDoc {
/// Path of the matched document.
pub path: String,
/// Cosine distance to the query vector (lower is closer).
pub distance: f32,
}
/// Summary statistics for the workspace's document store.
#[derive(Debug, Clone)]
pub struct WorkspaceStats {
/// Total number of rows in the 'documents' table (0 when the table is absent).
pub total_documents: usize,
/// Whether a vector index exists on the 'vector' column.
pub has_index: bool,
/// Label of the index type when present (currently always "IVF_PQ").
pub index_type: Option<String>,
}
/// Thin wrapper around a LanceDB connection that manages the single
/// 'documents' table (paths, metadata, and embedding vectors).
pub struct Store {
db: lancedb::Connection,
}
impl Store {
/// Opens (or creates) the LanceDB database stored as `documents.lance`
/// inside `workspace_dir`.
///
/// # Errors
/// Returns an error when the LanceDB connection cannot be established.
pub async fn open(workspace_dir: &str) -> Result<Self> {
    let db_path = Path::new(workspace_dir)
        .join("documents.lance")
        .to_string_lossy()
        .into_owned();
    let connect_result = lancedb::connect(&db_path).execute().await;
    let db = connect_result
        .with_context(|| format!("failed to open LanceDB connection at {db_path}"))?;
    Ok(Self { db })
}
/// Looks up which of `paths` are already stored and returns their recorded
/// metadata keyed by path.
///
/// Returns an empty map when the 'documents' table has not been created yet.
///
/// # Errors
/// Fails if the table cannot be listed/opened, a query fails, or the result
/// schema is missing an expected column or has an unexpected column type.
pub async fn get_existing_docs(&self, paths: &[String]) -> Result<HashMap<String, DocMeta>> {
let mut existing = HashMap::new();
let tables = self
.db
.table_names()
.execute()
.await
.context("failed to list LanceDB tables")?;
// No table yet means nothing can match; not an error.
if !tables.contains(&"documents".to_string()) {
return Ok(existing);
}
let tbl = self
.db
.open_table("documents")
.execute()
.await
.context("failed to open 'documents' table")?;
// Chunk the lookup so the generated `path IN (...)` filter stays bounded
// at 1000 entries per query.
for chunk in paths.chunks(1000) {
let filter_expr = build_in_filter(chunk);
let stream = tbl
.query()
.only_if(filter_expr)
.execute()
.await
.context("failed to execute documents query")?;
let batches: Vec<RecordBatch> = stream
.try_collect()
.await
.context("failed to collect query result batches")?;
for batch in batches {
// Resolve column positions by name so physical column order never matters.
let schema = batch.schema();
let path_idx = schema
.index_of("path")
.context("missing 'path' column in documents schema")?;
let size_idx = schema
.index_of("size_bytes")
.context("missing 'size_bytes' column in documents schema")?;
let mtime_idx = schema
.index_of("mtime")
.context("missing 'mtime' column in documents schema")?;
let path_array = batch
.column(path_idx)
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow!("unexpected type for 'path' column"))?;
let size_array = batch
.column(size_idx)
.as_any()
.downcast_ref::<UInt64Array>()
.ok_or_else(|| anyhow!("unexpected type for 'size_bytes' column"))?;
let mtime_array = batch
.column(mtime_idx)
.as_any()
.downcast_ref::<Int64Array>()
.ok_or_else(|| anyhow!("unexpected type for 'mtime' column"))?;
// Materialize one DocMeta per row, keyed by its path.
for i in 0..batch.num_rows() {
let path = path_array.value(i).to_string();
let size_bytes = size_array.value(i);
let mtime = mtime_array.value(i);
existing.insert(
path.clone(),
DocMeta {
path,
size_bytes,
mtime,
},
);
}
}
}
Ok(existing)
}
/// Removes every row whose `path` matches an entry in `paths`.
///
/// An empty path list, or a 'documents' table that does not exist yet,
/// is treated as a successful no-op.
///
/// # Errors
/// Fails when the table cannot be listed/opened or a delete fails.
pub async fn delete_documents(&self, paths: &[String]) -> Result<()> {
    if paths.is_empty() {
        return Ok(());
    }
    let table_names = self
        .db
        .table_names()
        .execute()
        .await
        .context("failed to list LanceDB tables")?;
    if !table_names.iter().any(|name| name == "documents") {
        return Ok(());
    }
    let documents = self
        .db
        .open_table("documents")
        .execute()
        .await
        .context("failed to open 'documents' table")?;
    // Delete in bounded batches so the generated `IN (...)` predicate never
    // grows beyond 1000 entries.
    for batch in paths.chunks(1000) {
        let filter_expr = build_in_filter(batch);
        documents
            .delete(&filter_expr)
            .await
            .with_context(|| format!("failed to delete documents with filter: {filter_expr}"))?;
    }
    Ok(())
}
/// Inserts or replaces documents together with their embedding vectors.
///
/// `metas[i]` is paired with `embeddings[i]`; all embeddings must share the
/// same non-zero length, which fixes the dimension of the vector column.
/// Existing rows with the same paths are deleted first, giving upsert
/// semantics. After a successful write, a vector index is (re)ensured.
///
/// NOTE(review): the delete-then-insert sequence is not atomic — a failure
/// between the two steps can leave the affected paths temporarily missing.
///
/// # Errors
/// Fails on length/dimension mismatches, Arrow batch construction errors,
/// or any LanceDB table operation failure.
pub async fn upsert_documents(&self, metas: &[DocMeta], embeddings: &[Vec<f32>]) -> Result<()> {
if metas.len() != embeddings.len() {
bail!(
"metas and embeddings length mismatch: {} vs {}",
metas.len(),
embeddings.len()
);
}
// Nothing to write; also avoids indexing embeddings[0] below.
if embeddings.is_empty() {
return Ok(()); }
// All vectors must agree on dimensionality for the FixedSizeList column.
let dim = embeddings[0].len();
if dim == 0 {
bail!("embeddings must be non-empty vectors");
}
if embeddings.iter().any(|e| e.len() != dim) {
bail!("all embeddings must have equal length");
}
// Upsert semantics: drop any existing rows for these paths first.
let paths: Vec<String> = metas.iter().map(|m| m.path.clone()).collect();
self.delete_documents(&paths).await?;
// Arrow schema mirrored exactly by the column arrays built below.
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("path", DataType::Utf8, false),
Field::new("size_bytes", DataType::UInt64, false),
Field::new("mtime", DataType::Int64, false),
Field::new(
"vector",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
dim as i32,
),
true,
),
]));
let id_array = Int32Array::from_iter_values(metas.iter().map(|meta| meta.id()));
let path_array =
StringArray::from(metas.iter().map(|m| m.path.as_str()).collect::<Vec<_>>());
let size_bytes_array = UInt64Array::from_iter_values(metas.iter().map(|m| m.size_bytes));
let mtime_array = Int64Array::from_iter_values(metas.iter().map(|m| m.mtime));
// Flatten each embedding into the fixed-size list column.
let vector_array = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
embeddings
.iter()
.map(|embedding| Some(embedding.iter().cloned().map(Some))),
dim as i32,
);
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(id_array),
Arc::new(path_array),
Arc::new(size_bytes_array),
Arc::new(mtime_array),
Arc::new(vector_array),
],
)?;
let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema.clone());
let tables = self
.db
.table_names()
.execute()
.await
.context("failed to list LanceDB tables")?;
let table_existed = tables.contains(&"documents".to_string());
// Create the table on first use; otherwise append to the existing one.
if !table_existed {
self.db
.create_table("documents", Box::new(batches))
.execute()
.await
.context("failed to create 'documents' table")?;
} else {
let tbl = self
.db
.open_table("documents")
.execute()
.await
.context("failed to open 'documents' table")?;
tbl.add(Box::new(batches))
.execute()
.await
.context("failed to append batches to 'documents' table")?;
}
// Best-effort index creation/maintenance once the data has landed.
self.ensure_vector_index().await?;
Ok(())
}
/// Ensures the 'documents' table has a vector index: creates one when
/// absent, or runs a best-effort optimize pass when one already exists.
///
/// Creation failures caused by too little data are downgraded to a warning
/// (small workspaces fall back to brute-force search); disk-space and
/// permission problems are surfaced as hard errors, everything else is
/// propagated as-is.
async fn ensure_vector_index(&self) -> Result<()> {
    let tbl = self
        .db
        .open_table("documents")
        .execute()
        .await
        .context("failed to open 'documents' table")?;
    let indices = tbl
        .list_indices()
        .await
        .context("failed to list indices for 'documents' table")?;
    let vector_indexed = indices
        .iter()
        .any(|index| index.columns.iter().any(|column| column == "vector"));
    if vector_indexed {
        // Index already present: optimization failure is non-fatal.
        if tbl.optimize(Default::default()).await.is_err() {
            eprintln!("Warning: Failed to optimize vector index");
        }
        return Ok(());
    }
    if let Err(e) = tbl.create_index(&["vector"], Index::Auto).execute().await {
        // Classify the failure from its message text.
        let error_msg = e.to_string();
        if error_msg.contains("Not enough rows to train PQ")
            || error_msg.contains("Requires 256 rows")
        {
            eprintln!(
                "Warning: Skipping vector index creation due to insufficient data (need at least 256 rows for PQ index). Database will use brute-force search."
            );
        } else if error_msg.contains("No space left on device") {
            return Err(anyhow!(
                "Insufficient disk space to create vector index. Consider freeing up space or using a different workspace location."
            ));
        } else if error_msg.contains("Permission denied") {
            return Err(anyhow!(
                "Permission denied while creating vector index. Check workspace directory permissions."
            ));
        } else {
            return Err(e.into());
        }
    }
    Ok(())
}
/// Reports the document count and vector-index status for the workspace.
///
/// Returns zeroed stats when the 'documents' table does not exist yet.
///
/// # Errors
/// Fails if tables cannot be listed, the table cannot be opened, the row
/// count fails, or indices cannot be listed.
pub async fn get_stats(&self) -> Result<WorkspaceStats> {
    let tables = self
        .db
        .table_names()
        .execute()
        .await
        .context("failed to list LanceDB tables")?;
    if !tables.contains(&"documents".to_string()) {
        return Ok(WorkspaceStats {
            total_documents: 0,
            has_index: false,
            index_type: None,
        });
    }
    let tbl = self
        .db
        .open_table("documents")
        .execute()
        .await
        .context("failed to open 'documents' table")?;
    // Ask the table for its row count directly instead of streaming every
    // row (including the embedding vectors) into memory just to count them.
    let total_documents = tbl
        .count_rows(None)
        .await
        .context("failed to count rows in 'documents' table")?;
    let indices = tbl
        .list_indices()
        .await
        .context("failed to list indices for 'documents' table")?;
    let has_vector_index = indices
        .iter()
        .any(|idx| idx.columns.contains(&"vector".to_string()));
    // NOTE(review): the index is created with `Index::Auto`, so "IVF_PQ" is
    // the expected-but-not-guaranteed index type; confirm before surfacing
    // this label as authoritative to users.
    let index_type = has_vector_index.then(|| "IVF_PQ".to_string());
    Ok(WorkspaceStats {
        total_documents,
        has_index: has_vector_index,
        index_type,
    })
}
/// Collects every document path currently stored in the 'documents' table.
///
/// Returns an empty vector when the table does not exist yet.
///
/// # Errors
/// Fails if the table cannot be listed/opened, the scan fails, or the
/// result schema lacks a string-typed 'path' column.
pub async fn get_all_document_paths(&self) -> Result<Vec<String>> {
    let table_names = self
        .db
        .table_names()
        .execute()
        .await
        .context("failed to list LanceDB tables")?;
    if !table_names.iter().any(|name| name == "documents") {
        return Ok(Vec::new());
    }
    let documents = self
        .db
        .open_table("documents")
        .execute()
        .await
        .context("failed to open 'documents' table")?;
    let stream = documents
        .query()
        .execute()
        .await
        .context("failed to execute query for all document paths")?;
    let record_batches: Vec<RecordBatch> = stream
        .try_collect()
        .await
        .context("failed to collect batches for all document paths")?;
    let mut collected = Vec::new();
    for record_batch in record_batches {
        // Locate the 'path' column by name and decode it as strings.
        let schema = record_batch.schema();
        let path_idx = schema
            .index_of("path")
            .context("missing 'path' column in documents schema")?;
        let values = record_batch
            .column(path_idx)
            .as_any()
            .downcast_ref::<StringArray>()
            .ok_or_else(|| anyhow!("unexpected type for 'path' column"))?;
        collected.extend((0..record_batch.num_rows()).map(|row| values.value(row).to_string()));
    }
    Ok(collected)
}
/// Convenience wrapper around [`Self::ann_filter_top_k_with_params`] that
/// applies the default search tuning (refine factor 5, 10 probes).
pub async fn ann_filter_top_k(
    &self,
    query_vec: &[f32],
    subset_paths: &[String],
    doc_top_k: usize,
    in_batch_size: usize,
) -> Result<Vec<RankedDoc>> {
    // Default ANN tuning parameters.
    const DEFAULT_REFINE_FACTOR: u32 = 5;
    const DEFAULT_NPROBES: u32 = 10;
    self.ann_filter_top_k_with_params(
        query_vec,
        subset_paths,
        doc_top_k,
        in_batch_size,
        Some(DEFAULT_REFINE_FACTOR),
        Some(DEFAULT_NPROBES),
    )
    .await
}
/// ANN search restricted to `subset_paths`, returning up to `doc_top_k`
/// documents ranked by ascending cosine distance.
///
/// The subset is searched in chunks of `in_batch_size` paths; when the same
/// path appears in several chunks' results, only its best (smallest)
/// distance is kept. `refine_factor` and `nprobes`, when given, tune the
/// underlying index search.
///
/// # Errors
/// Fails if the table cannot be opened, a query cannot be built or run, or
/// the result schema lacks the expected path/distance columns.
pub async fn ann_filter_top_k_with_params(
&self,
query_vec: &[f32],
subset_paths: &[String],
doc_top_k: usize,
in_batch_size: usize,
refine_factor: Option<u32>,
nprobes: Option<u32>,
) -> Result<Vec<RankedDoc>> {
if subset_paths.is_empty() || doc_top_k == 0 {
return Ok(Vec::new());
}
// NOTE(review): unlike the other methods, this opens the table
// unconditionally and will error if it does not exist yet — confirm
// callers only search after documents were upserted.
let tbl = self
.db
.open_table("documents")
.execute()
.await
.context("failed to open 'documents' table")?;
// Best (lowest) distance seen per path across all chunked queries.
let mut best_by_path: HashMap<String, f32> = HashMap::new();
// `.max(1)` guards against an `in_batch_size` of 0, which would panic in chunks().
for chunk in subset_paths.chunks(in_batch_size.max(1)) {
let filter_expr = build_in_filter(chunk);
let mut query = tbl
.query()
.only_if(filter_expr)
.nearest_to(query_vec)
.context("failed to set nearest_to on query")?
.distance_type(lancedb::DistanceType::Cosine)
.limit(doc_top_k);
if let Some(rf) = refine_factor {
query = query.refine_factor(rf);
}
if let Some(np) = nprobes {
query = query.nprobes(np as usize);
}
let stream = query
.execute()
.await
.context("failed to execute ANN query batch")?;
let batches: Vec<RecordBatch> = stream
.try_collect()
.await
.context("failed to collect ANN query batches")?;
for batch in batches {
let schema = batch.schema();
let path_idx = schema
.index_of("path")
.context("missing 'path' column in ANN result schema")?;
// The distance column is usually named `_distance`; fall back to
// `distance` for robustness against schema variations.
let distance_idx = schema
.index_of("_distance")
.or_else(|_| schema.index_of("distance"))
.context("missing 'distance' column in ANN result schema")?;
let path_col = batch.column(path_idx);
let dist_col = batch.column(distance_idx);
let path_array = path_col
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow!("unexpected type for 'path' column in ANN result"))?;
// Distances may be delivered as f32 or f64; handle both layouts.
if let Some(dist_array) = dist_col.as_any().downcast_ref::<Float32Array>() {
for i in 0..batch.num_rows() {
let path = path_array.value(i).to_string();
let distance = dist_array.value(i);
// Keep the minimum distance per path.
match best_by_path.get_mut(&path) {
Some(existing) => {
if distance < *existing {
*existing = distance;
}
}
None => {
best_by_path.insert(path, distance);
}
}
}
} else if let Some(dist_array) = dist_col.as_any().downcast_ref::<Float64Array>() {
for i in 0..batch.num_rows() {
let path = path_array.value(i).to_string();
let distance_f32 = dist_array.value(i) as f32;
match best_by_path.get_mut(&path) {
Some(existing) => {
if distance_f32 < *existing {
*existing = distance_f32;
}
}
None => {
best_by_path.insert(path, distance_f32);
}
}
}
} else {
bail!("unsupported distance column type");
}
}
}
// Rank ascending by distance; `partial_cmp` + `Equal` fallback avoids a
// panic if a NaN distance ever appears.
let mut ranked: Vec<RankedDoc> = best_by_path
.into_iter()
.map(|(path, distance)| RankedDoc { path, distance })
.collect();
ranked.sort_by(|a, b| {
a.distance
.partial_cmp(&b.distance)
.unwrap_or(std::cmp::Ordering::Equal)
});
ranked.truncate(doc_top_k);
Ok(ranked)
}
}
/// Builds a SQL `IN` predicate matching the `path` column against the given
/// paths. Single quotes inside a path are doubled per SQL string-escaping
/// rules, so the filter stays well-formed for paths containing `'`.
pub fn build_in_filter(paths: &[String]) -> String {
    let quoted = paths
        .iter()
        .map(|path| format!("'{}'", path.replace('\'', "''")))
        .collect::<Vec<_>>()
        .join(",");
    format!("path IN ({quoted})")
}
// Integration-style tests: each test runs against a fresh Store backed by a
// throwaway temporary directory.
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
// Creates a Store in a temp dir; the returned TempDir guard keeps the
// directory alive for the duration of the test.
async fn create_test_store() -> (Store, TempDir) {
let temp_dir = TempDir::new().expect("Failed to create temp dir");
let store = Store::open(temp_dir.path().to_str().unwrap())
.await
.expect("Failed to create store");
(store, temp_dir)
}
// Three fixed documents with matching 4-dimensional embeddings.
fn create_test_docs() -> (Vec<DocMeta>, Vec<Vec<f32>>) {
let docs = vec![
DocMeta {
path: "/test/doc1.txt".to_string(),
size_bytes: 100,
mtime: 1234567890,
},
DocMeta {
path: "/test/doc2.txt".to_string(),
size_bytes: 200,
mtime: 1234567891,
},
DocMeta {
path: "/test/doc3.txt".to_string(),
size_bytes: 150,
mtime: 1234567892,
},
];
let embeddings = vec![
vec![0.1, 0.2, 0.3, 0.4],
vec![0.5, 0.6, 0.7, 0.8],
vec![0.9, 1.0, 1.1, 1.2],
];
(docs, embeddings)
}
// A brand-new store reports zero documents and no index.
#[tokio::test]
async fn test_store_creation_and_stats_empty() {
let (store, _temp_dir) = create_test_store().await;
let stats = store.get_stats().await.expect("Failed to get stats");
assert_eq!(stats.total_documents, 0);
assert!(!stats.has_index);
assert_eq!(stats.index_type, None);
}
// Upserting three docs makes the count 3; index creation may be skipped
// for small datasets, so the index type is only checked when present.
#[tokio::test]
async fn test_upsert_documents_and_stats() {
let (store, _temp_dir) = create_test_store().await;
let (docs, embeddings) = create_test_docs();
store
.upsert_documents(&docs, &embeddings)
.await
.expect("Failed to upsert documents");
let stats = store.get_stats().await.expect("Failed to get stats");
assert_eq!(stats.total_documents, 3);
if stats.has_index {
assert_eq!(stats.index_type, Some("IVF_PQ".to_string()));
}
}
// Path listing is empty before upsert and contains all three paths after.
#[tokio::test]
async fn test_get_all_document_paths() {
let (store, _temp_dir) = create_test_store().await;
let (docs, embeddings) = create_test_docs();
let paths = store
.get_all_document_paths()
.await
.expect("Failed to get document paths");
assert!(paths.is_empty());
store
.upsert_documents(&docs, &embeddings)
.await
.expect("Failed to upsert documents");
let paths = store
.get_all_document_paths()
.await
.expect("Failed to get document paths");
assert_eq!(paths.len(), 3);
assert!(paths.contains(&"/test/doc1.txt".to_string()));
assert!(paths.contains(&"/test/doc2.txt".to_string()));
assert!(paths.contains(&"/test/doc3.txt".to_string()));
}
// get_existing_docs returns metadata only for stored paths and ignores
// unknown ones.
#[tokio::test]
async fn test_get_existing_docs() {
let (store, _temp_dir) = create_test_store().await;
let (docs, embeddings) = create_test_docs();
store
.upsert_documents(&docs, &embeddings)
.await
.expect("Failed to upsert documents");
let query_paths = vec![
"/test/doc1.txt".to_string(),
"/test/doc2.txt".to_string(),
"/test/nonexistent.txt".to_string(),
];
let existing = store
.get_existing_docs(&query_paths)
.await
.expect("Failed to get existing docs");
assert_eq!(existing.len(), 2);
assert!(existing.contains_key("/test/doc1.txt"));
assert!(existing.contains_key("/test/doc2.txt"));
assert!(!existing.contains_key("/test/nonexistent.txt"));
let doc1_meta = existing.get("/test/doc1.txt").unwrap();
assert_eq!(doc1_meta.size_bytes, 100);
assert_eq!(doc1_meta.mtime, 1234567890);
}
// Deleting two of three docs leaves exactly the remaining path behind.
#[tokio::test]
async fn test_delete_documents() {
let (store, _temp_dir) = create_test_store().await;
let (docs, embeddings) = create_test_docs();
store
.upsert_documents(&docs, &embeddings)
.await
.expect("Failed to upsert documents");
let all_paths = store
.get_all_document_paths()
.await
.expect("Failed to get document paths");
assert_eq!(all_paths.len(), 3);
let to_delete = vec!["/test/doc1.txt".to_string(), "/test/doc3.txt".to_string()];
store
.delete_documents(&to_delete)
.await
.expect("Failed to delete documents");
let remaining_paths = store
.get_all_document_paths()
.await
.expect("Failed to get document paths");
assert_eq!(remaining_paths.len(), 1);
assert!(remaining_paths.contains(&"/test/doc2.txt".to_string()));
}
// ANN search with default params returns at most k results in ascending
// distance order.
#[tokio::test]
async fn test_ann_filter_top_k() {
let (store, _temp_dir) = create_test_store().await;
let (docs, embeddings) = create_test_docs();
store
.upsert_documents(&docs, &embeddings)
.await
.expect("Failed to upsert documents");
let query_vec = vec![0.2, 0.3, 0.4, 0.5];
let subset_paths = vec![
"/test/doc1.txt".to_string(),
"/test/doc2.txt".to_string(),
"/test/doc3.txt".to_string(),
];
let results = store
.ann_filter_top_k(&query_vec, &subset_paths, 2, 1000)
.await
.expect("Failed to perform ANN search");
assert!(!results.is_empty());
assert!(results.len() <= 2);
// Results must be sorted by ascending distance.
for i in 1..results.len() {
assert!(results[i - 1].distance <= results[i].distance);
}
}
// Custom refine_factor/nprobes values still produce bounded results.
#[tokio::test]
async fn test_ann_filter_top_k_with_custom_params() {
let (store, _temp_dir) = create_test_store().await;
let (docs, embeddings) = create_test_docs();
store
.upsert_documents(&docs, &embeddings)
.await
.expect("Failed to upsert documents");
let query_vec = vec![0.2, 0.3, 0.4, 0.5];
let subset_paths = vec![
"/test/doc1.txt".to_string(),
"/test/doc2.txt".to_string(),
"/test/doc3.txt".to_string(),
];
let results = store
.ann_filter_top_k_with_params(&query_vec, &subset_paths, 2, 1000, Some(3), Some(5))
.await
.expect("Failed to perform ANN search with custom params");
assert!(!results.is_empty());
assert!(results.len() <= 2);
}
// Upserting the same path twice keeps a single row with updated metadata.
#[tokio::test]
async fn test_upsert_replaces_existing() {
let (store, _temp_dir) = create_test_store().await;
let initial_doc = DocMeta {
path: "/test/doc.txt".to_string(),
size_bytes: 100,
mtime: 1000,
};
let initial_embedding = vec![vec![1.0, 2.0, 3.0, 4.0]];
store
.upsert_documents(&[initial_doc], &initial_embedding)
.await
.expect("Failed to insert initial document");
let paths = store
.get_all_document_paths()
.await
.expect("Failed to get paths");
assert_eq!(paths.len(), 1);
let updated_doc = DocMeta {
path: "/test/doc.txt".to_string(),
size_bytes: 200,
mtime: 2000,
};
let updated_embedding = vec![vec![5.0, 6.0, 7.0, 8.0]];
store
.upsert_documents(&[updated_doc], &updated_embedding)
.await
.expect("Failed to update document");
let paths = store
.get_all_document_paths()
.await
.expect("Failed to get paths");
assert_eq!(paths.len(), 1);
let existing = store
.get_existing_docs(&["/test/doc.txt".to_string()])
.await
.expect("Failed to get existing docs");
let doc_meta = existing.get("/test/doc.txt").unwrap();
assert_eq!(doc_meta.size_bytes, 200);
assert_eq!(doc_meta.mtime, 2000);
}
// The IN-filter builder quotes every path and doubles embedded quotes.
#[test]
fn test_build_in_filter() {
let paths = vec![
"file1.txt".to_string(),
"file2.txt".to_string(),
"file with spaces.txt".to_string(),
"file'with'quotes.txt".to_string(),
];
let filter = build_in_filter(&paths);
assert!(filter.starts_with("path IN ("));
assert!(filter.ends_with(")"));
assert!(filter.contains("'file1.txt'"));
assert!(filter.contains("'file2.txt'"));
assert!(filter.contains("'file with spaces.txt'"));
assert!(filter.contains("'file''with''quotes.txt'"));
}
// Ids derive from the path only, are non-negative, and differ for
// distinct paths.
#[test]
fn test_doc_meta_id_generation() {
let doc1 = DocMeta {
path: "test1.txt".to_string(),
size_bytes: 100,
mtime: 1000,
};
let doc2 = DocMeta {
path: "test2.txt".to_string(),
size_bytes: 100,
mtime: 1000,
};
let id1 = doc1.id();
let id2 = doc2.id();
assert_ne!(id1, id2);
assert!(id1 >= 0);
assert!(id2 >= 0);
}
}