use crate::core::{GraphRAGError, Result};
use std::path::PathBuf;
#[cfg(feature = "lancedb")]
use std::sync::Arc;
#[cfg(feature = "lancedb")]
use arrow_array::{
FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
};
#[cfg(feature = "lancedb")]
use arrow_array::types::Float32Type;
#[cfg(feature = "lancedb")]
use arrow_schema::{DataType, Field, Schema, SchemaRef};
#[cfg(feature = "lancedb")]
use lancedb::query::{ExecutableQuery, QueryBase};
/// Settings used when constructing a [`LanceVectorStore`].
#[derive(Debug, Clone)]
pub struct LanceConfig {
/// Number of `f32` components every stored embedding and every query vector
/// must have; the store rejects vectors of any other length.
pub dimension: usize,
/// Requested index type.
/// NOTE(review): not read anywhere in the visible code — confirm whether an
/// index is built elsewhere.
pub index_type: IndexType,
/// Requested distance metric.
/// NOTE(review): not read anywhere in the visible code — confirm whether
/// searches are expected to honour it.
pub distance_metric: DistanceMetric,
}
/// Kind of vector index requested through [`LanceConfig::index_type`].
///
/// Derives `PartialEq`, `Eq` and `Hash` so callers can compare configurations
/// and use the value as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IndexType {
    /// No index (exhaustive scan) — presumably; the variant is not consumed in this file.
    Flat,
    /// HNSW graph index (by name; the variant is not consumed in this file).
    Hnsw,
    /// IVF index (by name; the variant is not consumed in this file).
    Ivf,
}
/// Distance metric requested through [`LanceConfig::distance_metric`].
///
/// Derives `PartialEq`, `Eq` and `Hash` so callers can compare configurations
/// and use the value as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Euclidean (L2) distance.
    L2,
    /// Cosine distance.
    Cosine,
    /// Dot-product based distance.
    Dot,
}
impl Default for LanceConfig {
fn default() -> Self {
Self {
dimension: 768, index_type: IndexType::Hnsw,
distance_metric: DistanceMetric::Cosine,
}
}
}
/// Vector store backed by a LanceDB table.
///
/// The LanceDB handles only exist when the crate is built with the `lancedb`
/// feature; without it the struct carries just the path and configuration.
pub struct LanceVectorStore {
// Location that was passed to `new` and used as the LanceDB connection URI.
path: PathBuf,
// Configuration supplied at construction; `dimension` is checked on every
// store and search call.
config: LanceConfig,
// Database connection returned by `lancedb::connect` in `new`.
#[cfg(feature = "lancedb")]
connection: lancedb::Connection,
// Handle to the `embeddings` table opened or created in `new`.
#[cfg(feature = "lancedb")]
table: lancedb::Table,
}
// A single hand-written impl serves both feature configurations: it prints only
// `path` and `config`, which exist with or without the `lancedb` feature, and
// deliberately leaves the feature-gated `connection` / `table` handles out.
impl std::fmt::Debug for LanceVectorStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LanceVectorStore")
            .field("path", &self.path)
            .field("config", &self.config)
            .finish()
    }
}
impl LanceVectorStore {
/// Connects to the LanceDB database at `path` and opens its `embeddings`
/// table, creating an empty one when it cannot be opened.
///
/// # Errors
///
/// Returns [`GraphRAGError::Config`] when
/// * `config.dimension` does not fit the `i32` width Arrow uses for
///   `FixedSizeList`,
/// * `path` is not valid UTF-8,
/// * the connection cannot be established, or
/// * the table can neither be opened nor created.
#[cfg(feature = "lancedb")]
pub async fn new(path: PathBuf, config: LanceConfig) -> Result<Self> {
    // Checked conversion: an `as i32` cast would silently wrap oversized
    // dimensions into a bogus (possibly negative) list width.
    let list_width = i32::try_from(config.dimension).map_err(|_| GraphRAGError::Config {
        message: format!(
            "Embedding dimension {} exceeds the maximum supported width {}",
            config.dimension,
            i32::MAX
        ),
    })?;

    let uri = path.to_str().ok_or_else(|| GraphRAGError::Config {
        message: "Invalid path encoding".to_string(),
    })?;
    let db = lancedb::connect(uri)
        .execute()
        .await
        .map_err(|e| GraphRAGError::Config {
            message: format!("Failed to connect to LanceDB: {}", e),
        })?;

    let table = match db.open_table("embeddings").execute().await {
        Ok(table) => {
            #[cfg(feature = "tracing")]
            tracing::info!("Opened existing LanceDB table at: {:?}", path);
            table
        },
        Err(_open_err) => {
            // Best effort: any open failure is treated as "table missing". The
            // cause is logged so a real problem is not lost if creation
            // subsequently fails too.
            #[cfg(feature = "tracing")]
            tracing::debug!(
                "Could not open LanceDB table 'embeddings' ({}); creating it",
                _open_err
            );

            // The schema is only needed when the table has to be created.
            let schema = Arc::new(Schema::new(vec![
                Field::new("id", DataType::Utf8, false),
                Field::new(
                    "vector",
                    DataType::FixedSizeList(
                        Arc::new(Field::new("item", DataType::Float32, true)),
                        list_width,
                    ),
                    false,
                ),
            ]));
            let empty_batches = create_empty_batch(schema)?;
            db.create_table("embeddings", empty_batches)
                .execute()
                .await
                .map_err(|e| GraphRAGError::Config {
                    message: format!("Failed to create LanceDB table: {}", e),
                })?
        },
    };

    #[cfg(feature = "tracing")]
    tracing::info!("LanceDB vector store initialized at: {:?}", path);

    Ok(Self {
        path,
        config,
        connection: db,
        table,
    })
}
/// Appends one `(id, embedding)` row to the table.
///
/// # Errors
///
/// Returns [`GraphRAGError::Config`] when the embedding length differs from
/// the configured dimension, or when reading the table schema, building the
/// record batch, or the insert itself fails.
#[cfg(feature = "lancedb")]
pub async fn store_embedding(&self, id: &str, embedding: Vec<f32>) -> Result<()> {
    let expected = self.config.dimension;
    let actual = embedding.len();
    if actual != expected {
        return Err(GraphRAGError::Config {
            message: format!(
                "Embedding dimension mismatch: expected {}, got {}",
                expected, actual
            ),
        });
    }

    // One-row columns matching the table layout: ids first, vectors second.
    let id_column = Arc::new(StringArray::from(vec![id]));
    let row: Vec<Option<f32>> = embedding.into_iter().map(Some).collect();
    let vector_column = Arc::new(
        FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
            vec![Some(row)],
            expected as i32,
        ),
    );

    let schema = match self.table.schema().await {
        Ok(schema) => schema,
        Err(e) => {
            return Err(GraphRAGError::Config {
                message: format!("Failed to get table schema: {}", e),
            })
        },
    };
    let batch = match RecordBatch::try_new(schema.clone(), vec![id_column, vector_column]) {
        Ok(batch) => batch,
        Err(e) => {
            return Err(GraphRAGError::Config {
                message: format!("Failed to create record batch: {}", e),
            })
        },
    };

    let reader = RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema);
    let inserted = self.table.add(Box::new(reader)).execute().await;
    if let Err(e) = inserted {
        return Err(GraphRAGError::Config {
            message: format!("Failed to store embedding: {}", e),
        });
    }

    #[cfg(feature = "tracing")]
    tracing::debug!("Stored embedding for id: {}", id);
    Ok(())
}
/// Returns up to `k` stored vectors nearest to `query`, in the order LanceDB
/// yields them.
///
/// Each hit's `score` is derived from the `_distance` column of the result as
/// `1 / (1 + distance)`, so an exact match scores `1.0` and larger distances
/// approach `0.0`. If the result carries no `_distance` column the previous
/// rank-based score `1 / rank` is used instead.
///
/// NOTE(review): the query does not pass `config.distance_metric` to LanceDB,
/// so the library's default metric applies — confirm whether that is intended.
///
/// # Errors
///
/// * [`GraphRAGError::Config`] when `query.len()` differs from the configured
///   dimension.
/// * [`GraphRAGError::VectorSearch`] when the query cannot be built, executed
///   or collected, or when the result columns do not have the expected
///   `(Utf8 id, FixedSizeList vector)` layout.
#[cfg(feature = "lancedb")]
pub async fn search_similar(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
    use arrow_array::cast::AsArray;
    use futures::stream::TryStreamExt;

    // Maps a distance onto (0, 1]; negative values are clamped to zero and a
    // NaN distance is treated as "no similarity".
    fn score_from_distance(distance: f32) -> f32 {
        if distance.is_nan() {
            return 0.0;
        }
        1.0 / (1.0 + distance.max(0.0))
    }

    if query.len() != self.config.dimension {
        return Err(GraphRAGError::Config {
            message: format!(
                "Query dimension mismatch: expected {}, got {}",
                self.config.dimension,
                query.len()
            ),
        });
    }
    // Asking for zero neighbours has exactly one answer; skip the round trip.
    if k == 0 {
        return Ok(Vec::new());
    }

    let batches = self
        .table
        .query()
        .limit(k)
        .nearest_to(query)
        .map_err(|e| GraphRAGError::VectorSearch {
            message: format!("Failed to create query: {}", e),
        })?
        .execute()
        .await
        .map_err(|e| GraphRAGError::VectorSearch {
            message: format!("Failed to execute search: {}", e),
        })?
        .try_collect::<Vec<_>>()
        .await
        .map_err(|e| GraphRAGError::VectorSearch {
            message: format!("Failed to collect results: {}", e),
        })?;

    let total_rows: usize = batches.iter().map(|batch| batch.num_rows()).sum();
    let mut search_results = Vec::with_capacity(total_rows);

    for batch in &batches {
        if batch.num_columns() < 2 {
            return Err(GraphRAGError::VectorSearch {
                message: format!(
                    "Unexpected result schema: expected at least 2 columns, got {}",
                    batch.num_columns()
                ),
            });
        }
        // Checked casts: a table with a different layout yields an error
        // instead of a panic.
        let ids = batch
            .column(0)
            .as_string_opt::<i32>()
            .ok_or_else(|| GraphRAGError::VectorSearch {
                message: "Unexpected result schema: column 0 is not a Utf8 array".to_string(),
            })?;
        let vectors = batch
            .column(1)
            .as_fixed_size_list_opt()
            .ok_or_else(|| GraphRAGError::VectorSearch {
                message: "Unexpected result schema: column 1 is not a FixedSizeList".to_string(),
            })?;
        // `None` when the result has no usable `_distance` column.
        let mut distances = batch
            .column_by_name("_distance")
            .and_then(|column| column.as_primitive_opt::<Float32Type>())
            .map(|column| column.iter());

        for (idx, id) in ids.iter().enumerate() {
            let row = vectors.value(idx);
            let embedding = match row.as_primitive_opt::<Float32Type>() {
                Some(values) => values.values().to_vec(),
                // Kept from the original: a non-f32 row degrades to a zero vector.
                None => vec![0.0; self.config.dimension],
            };
            let score = match distances.as_mut().and_then(|iter| iter.next()) {
                Some(Some(distance)) => score_from_distance(distance),
                // A null distance carries no similarity information.
                Some(None) => 0.0,
                // No distance column: fall back to the rank-based score.
                None => 1.0 / (search_results.len() as f32 + 1.0),
            };
            search_results.push(SearchResult {
                id: id.unwrap_or("").to_string(),
                score,
                embedding,
            });
        }
    }
    Ok(search_results)
}
/// Appends every `(id, embedding)` pair to the table in a single insert.
///
/// An empty input is a no-op. All embeddings are validated before anything is
/// written, and the first one with the wrong length aborts the call.
///
/// # Errors
///
/// Returns [`GraphRAGError::Config`] on a dimension mismatch, or when reading
/// the table schema, building the record batch, or the insert itself fails.
#[cfg(feature = "lancedb")]
pub async fn store_embeddings_batch(&self, embeddings: Vec<(String, Vec<f32>)>) -> Result<()> {
    if embeddings.is_empty() {
        return Ok(());
    }

    let expected = self.config.dimension;
    let first_mismatch = embeddings
        .iter()
        .find(|(_, embedding)| embedding.len() != expected);
    if let Some((id, embedding)) = first_mismatch {
        return Err(GraphRAGError::Config {
            message: format!(
                "Embedding dimension mismatch for '{}': expected {}, got {}",
                id,
                expected,
                embedding.len()
            ),
        });
    }

    // Column 0: ids, in input order.
    let id_values: Vec<&str> = embeddings.iter().map(|(id, _)| id.as_str()).collect();
    let id_column = Arc::new(StringArray::from(id_values));

    // Column 1: one fixed-size list per embedding, in the same order.
    let rows: Vec<Option<Vec<Option<f32>>>> = embeddings
        .iter()
        .map(|(_, values)| Some(values.iter().copied().map(Some).collect()))
        .collect();
    let vector_column = Arc::new(
        FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(rows, expected as i32),
    );

    let schema = match self.table.schema().await {
        Ok(schema) => schema,
        Err(e) => {
            return Err(GraphRAGError::Config {
                message: format!("Failed to get table schema: {}", e),
            })
        },
    };
    let batch = match RecordBatch::try_new(schema.clone(), vec![id_column, vector_column]) {
        Ok(batch) => batch,
        Err(e) => {
            return Err(GraphRAGError::Config {
                message: format!("Failed to create record batch: {}", e),
            })
        },
    };

    let reader = RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema);
    let inserted = self.table.add(Box::new(reader)).execute().await;
    if let Err(e) = inserted {
        return Err(GraphRAGError::Config {
            message: format!("Failed to store embeddings batch: {}", e),
        });
    }

    #[cfg(feature = "tracing")]
    tracing::debug!("Stored {} embeddings in batch", embeddings.len());
    Ok(())
}
/// Looks up the stored vector for `id`.
///
/// Returns `Ok(None)` when no row matches. When several rows share the id,
/// the first one yielded by the query is returned.
///
/// # Errors
///
/// Returns [`GraphRAGError::VectorSearch`] when the query fails, its results
/// cannot be collected, or the second result column is not a
/// `FixedSizeList`.
#[cfg(feature = "lancedb")]
pub async fn get_embedding(&self, id: &str) -> Result<Option<Vec<f32>>> {
    use arrow_array::cast::AsArray;
    use futures::stream::TryStreamExt;

    // `id` is spliced into a SQL string literal, so embedded single quotes
    // must be doubled; otherwise an id such as `o'brien` produces an invalid
    // filter and a crafted id could change the predicate.
    let escaped_id = id.replace('\'', "''");

    let results = self
        .table
        .query()
        .only_if(format!("id = '{}'", escaped_id))
        .execute()
        .await
        .map_err(|e| GraphRAGError::VectorSearch {
            message: format!("Failed to query by ID: {}", e),
        })?
        .try_collect::<Vec<_>>()
        .await
        .map_err(|e| GraphRAGError::VectorSearch {
            message: format!("Failed to collect results: {}", e),
        })?;

    for batch in &results {
        if batch.num_rows() == 0 {
            continue;
        }
        // Checked lookup and cast: an unexpected layout is reported as an
        // error rather than panicking.
        let vectors = batch
            .columns()
            .get(1)
            .and_then(|column| column.as_fixed_size_list_opt())
            .ok_or_else(|| GraphRAGError::VectorSearch {
                message: "Unexpected result schema: column 1 is not a FixedSizeList".to_string(),
            })?;
        let row = vectors.value(0);
        if let Some(values) = row.as_primitive_opt::<Float32Type>() {
            return Ok(Some(values.values().to_vec()));
        }
    }
    Ok(None)
}
/// Returns the number of rows currently in the table.
///
/// # Errors
///
/// Returns [`GraphRAGError::Config`] when LanceDB fails to count the rows.
#[cfg(feature = "lancedb")]
pub async fn count(&self) -> Result<usize> {
    match self.table.count_rows(None).await {
        Ok(rows) => Ok(rows),
        Err(e) => Err(GraphRAGError::Config {
            message: format!("Failed to count rows: {}", e),
        }),
    }
}
/// Stand-in constructor for builds without the `lancedb` feature.
///
/// # Errors
///
/// Always returns [`GraphRAGError::Config`], because no store can be created
/// without the feature.
#[cfg(not(feature = "lancedb"))]
pub async fn new(_path: PathBuf, _config: LanceConfig) -> Result<Self> {
    let message = String::from("lancedb feature not enabled");
    Err(GraphRAGError::Config { message })
}
}
/// One hit produced by a similarity search.
///
/// Derives `PartialEq` so results can be compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// Identifier of the matched row.
    pub id: String,
    /// Relevance score assigned by the search; better matches receive higher
    /// values.
    pub score: f32,
    /// The stored vector of the matched row.
    pub embedding: Vec<f32>,
}
/// Builds a reader that yields a single zero-row batch for `schema`, used to
/// create a table that has a schema but no data yet.
///
/// `schema` is expected to have a string column at index 0 and a
/// `FixedSizeList` of `f32` at index 1.
///
/// # Errors
///
/// Returns [`GraphRAGError::Config`] when the schema has no field at index 1,
/// when that field is not a `FixedSizeList`, or when the empty batch cannot
/// be assembled for the schema.
#[cfg(feature = "lancedb")]
fn create_empty_batch(schema: SchemaRef) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>> {
    // Checked lookup: `schema.field(1)` would panic on a schema with fewer
    // than two fields.
    let list_size = match schema.fields().get(1).map(|field| field.data_type()) {
        Some(DataType::FixedSizeList(_, size)) => *size,
        _ => {
            return Err(GraphRAGError::Config {
                message: "Expected FixedSizeList data type for vector field".to_string(),
            })
        },
    };

    let empty_batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(StringArray::from(Vec::<String>::new())),
            Arc::new(
                FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
                    std::iter::empty::<Option<Vec<Option<f32>>>>(),
                    list_size,
                ),
            ),
        ],
    )
    .map_err(|e| GraphRAGError::Config {
        message: format!("Failed to create empty batch: {}", e),
    })?;

    let reader = RecordBatchIterator::new(vec![Ok(empty_batch)].into_iter(), schema);
    Ok(Box::new(reader))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every field of the default configuration, not just the dimension,
    /// so an accidental change to the default index or metric is caught.
    #[test]
    fn test_lance_config_default() {
        let config = LanceConfig::default();
        assert_eq!(config.dimension, 768);
        assert!(matches!(config.index_type, IndexType::Hnsw));
        assert!(matches!(config.distance_metric, DistanceMetric::Cosine));
    }
}