use anyhow::{bail, Context, Result};
use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use futures_util::StreamExt;
use lancedb::{index::Index, query::ExecutableQuery, query::QueryBase, Table};
use std::sync::Arc;
use std::path::PathBuf;
use crate::paths::Paths;
/// Row count the table must reach before `LanceDb::post` tries to create an index on the
/// `vector` column (only when the table was not already indexed when opened).
/// NOTE(review): presumably chosen because ANN index training needs a minimum number of
/// rows — confirm the lower bound required by `Index::Auto`.
const MIN_ROWS_FOR_INDEX: usize = 300;
/// One record returned by a similarity query: the stored id and its text content.
///
/// Derives `PartialEq`/`Eq` so results can be compared directly (e.g. in tests or when
/// de-duplicating); adding derives is backward-compatible for existing users.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Entry {
    /// Caller-supplied identifier of the stored row.
    pub id: String,
    /// The stored text content of the row.
    pub content: String,
}
/// Path of the database that `LanceDb::new` passes to `lancedb::connect`.
///
/// Resolution is delegated entirely to the project's `Paths::get_insights_db`.
fn get_db_path() -> PathBuf {
    Paths::get_insights_db()
}
/// Rejects ids that are empty, too long, or contain quote/statement characters.
///
/// `LanceDb` interpolates the id directly into single-quoted filter strings (see
/// `LanceDb::delete`), so `'`, `"` and `;` are refused outright rather than escaped.
/// Note that the length limit is measured in bytes (`str::len`), not in characters,
/// even though the error message says "characters".
///
/// # Errors
/// Returns an error if the byte length is outside `1..=100` or a forbidden character
/// is present.
fn validate_id(id: &str) -> Result<()> {
    let length_ok = (1..=100).contains(&id.len());
    if !length_ok {
        bail!("Invalid id: must be non-empty and <= 100 characters");
    }

    let has_forbidden = id.chars().any(|c| matches!(c, '\'' | ';' | '"'));
    if has_forbidden {
        bail!("Invalid id: contains forbidden characters");
    }

    Ok(())
}
/// Handle to one LanceDB table with `id`, `content` and fixed-size `vector` columns.
pub struct LanceDb {
    /// Open table handle that every operation goes through.
    table: Table,
    /// Whether the table had any index when it was opened in `LanceDb::new`.
    /// This is a snapshot: the methods take `&self` and never write it afterwards.
    indexed: bool,
    /// Expected embedding length; inserted and query vectors must have exactly this length.
    vector_dim: usize,
}
impl LanceDb {
    /// Opens (or creates) `table_name` in the insights database for `vector_dim`-length embeddings.
    ///
    /// If the table exists, its `vector` column must be a `FixedSizeList` of length
    /// `vector_dim`. Otherwise an empty table is created with non-nullable `id`, `content`
    /// and `vector` columns.
    ///
    /// # Errors
    /// Fails if:
    /// - `vector_dim` is zero or does not fit in Arrow's `i32` list size;
    /// - the DB path is not valid UTF-8;
    /// - an existing table lacks a `FixedSizeList` `vector` column or has a different dimension;
    /// - any LanceDB call fails.
    pub async fn new(table_name: &str, vector_dim: usize) -> Result<Self> {
        if vector_dim == 0 {
            bail!("vector_dim must be greater than zero");
        }
        // Arrow stores a FixedSizeList length as i32; a plain `as` cast would silently wrap.
        let list_size = i32::try_from(vector_dim)
            .context("vector_dim does not fit in an Arrow FixedSizeList size (i32)")?;

        let db_path = get_db_path();
        let db = lancedb::connect(db_path.to_str().context("Invalid DB path")?)
            .execute()
            .await?;

        let names = db.table_names().execute().await?;
        let (table, indexed) = if names.iter().any(|name| name == table_name) {
            let tbl = db.open_table(table_name).execute().await?;
            let existing_schema = tbl.schema().await?;
            let existing_dim = match existing_schema.field_with_name("vector").map(Field::data_type)
            {
                Ok(DataType::FixedSizeList(_, dim)) => *dim,
                Ok(other) => bail!(
                    "Table '{}' has a 'vector' column of unsupported type {:?}",
                    table_name,
                    other
                ),
                Err(_) => bail!("Table '{}' has no 'vector' column", table_name),
            };
            if existing_dim != list_size {
                bail!(
                    "Table '{}' exists with vector dimension {} but current embedding model produces dimension {}. \
                     Please delete the table and recreate it to switch models.",
                    table_name, existing_dim, vector_dim
                );
            }
            let indexed = !tbl.list_indices().await?.is_empty();
            (tbl, indexed)
        } else {
            let schema = Arc::new(Schema::new(vec![
                Field::new("id", DataType::Utf8, false),
                Field::new("content", DataType::Utf8, false),
                Field::new(
                    "vector",
                    DataType::FixedSizeList(
                        Arc::new(Field::new("item", DataType::Float32, false)),
                        list_size,
                    ),
                    false,
                ),
            ]));
            let tbl = db
                .create_table(table_name, vec![RecordBatch::new_empty(schema)])
                .execute()
                .await?;
            (tbl, false)
        };

        Ok(Self { table, indexed, vector_dim })
    }

    /// Inserts a row `(id, content, vector)` unless a row with identical `content` already exists.
    ///
    /// Returns `id` in both cases. When the content already exists, nothing is written and the
    /// returned value is the *requested* id, not the id of the existing row.
    /// Once the table holds at least `MIN_ROWS_FOR_INDEX` rows and has no index yet, an index
    /// is created on `vector`.
    ///
    /// # Errors
    /// Fails on a vector of the wrong length, an invalid `id` (see `validate_id`), or any
    /// LanceDB/Arrow error. Note that a failure while creating the index is reported even
    /// though the row has already been added.
    pub async fn post(&self, id: &str, content: &str, vector: Vec<f32>) -> Result<String> {
        if vector.len() != self.vector_dim {
            bail!("vector dimension must be {}, got {}", self.vector_dim, vector.len());
        }
        validate_id(id)?;
        if self.exists_by_content(content).await? {
            return Ok(id.to_string());
        }

        let schema = self.table.schema().await?;
        // Reuse the table's own list item field and size: `RecordBatch::try_new` compares
        // FixedSizeList types including the item field's nullability, so a hand-built field
        // that differs from the stored schema makes every insert fail.
        let (item_field, list_size) = match schema.field_with_name("vector")?.data_type() {
            DataType::FixedSizeList(item, size) => (Arc::clone(item), *size),
            other => bail!("column 'vector' has unexpected type {other:?}"),
        };
        let vector_array = FixedSizeListArray::try_new(
            item_field,
            list_size,
            Arc::new(Float32Array::from(vector)),
            None,
        )?;
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![id])),
                Arc::new(StringArray::from(vec![content])),
                Arc::new(vector_array),
            ],
        )?;
        self.table.add(vec![batch]).execute().await?;

        // `indexed` is a snapshot from `new` and cannot be updated through `&self`. Without
        // asking the table whether an index already exists, every insert past the threshold
        // would call `create_index` again.
        if !self.indexed
            && self.table.count_rows(None).await? >= MIN_ROWS_FOR_INDEX
            && self.table.list_indices().await?.is_empty()
        {
            self.table
                .create_index(&["vector"], Index::Auto)
                .execute()
                .await?;
        }
        Ok(id.to_string())
    }

    /// Returns up to `limit` entries nearest to `query_vector`.
    ///
    /// Entries come back in the order LanceDB yields them (presumably ascending distance —
    /// TODO confirm).
    ///
    /// # Errors
    /// Fails on a query vector of the wrong length, a LanceDB error, or a result batch that
    /// lacks string `id`/`content` columns.
    pub async fn get(&self, query_vector: &[f32], limit: usize) -> Result<Vec<Entry>> {
        if query_vector.len() != self.vector_dim {
            bail!("query vector dimension must be {}, got {}", self.vector_dim, query_vector.len());
        }
        let mut stream = self
            .table
            .query()
            .nearest_to(query_vector)?
            .limit(limit)
            .execute()
            .await?;

        let mut entries = Vec::new();
        while let Some(batch_result) = stream.next().await {
            let batch: RecordBatch = batch_result?;
            // Look columns up by name (results also carry the vector and distance columns)
            // and downcast once per batch rather than once per row.
            let ids = Self::string_column(&batch, "id")?;
            let contents = Self::string_column(&batch, "content")?;
            entries.extend(ids.iter().zip(contents.iter()).map(|(id, content)| Entry {
                id: id.unwrap_or_default().to_string(),
                content: content.unwrap_or_default().to_string(),
            }));
        }
        Ok(entries)
    }

    /// Fetches column `name` from `batch` as a `StringArray`, erroring if it is missing or not Utf8.
    fn string_column<'a>(batch: &'a RecordBatch, name: &str) -> Result<&'a StringArray> {
        batch
            .column_by_name(name)
            .with_context(|| format!("query result is missing column '{name}'"))?
            .as_any()
            .downcast_ref::<StringArray>()
            .with_context(|| format!("column '{name}' is not a Utf8 string column"))
    }

    /// Reports whether any row's `content` equals `content` exactly.
    ///
    /// Single quotes are escaped by doubling them (SQL string-literal escaping) before the
    /// value is embedded in the filter expression.
    pub async fn exists_by_content(&self, content: &str) -> Result<bool> {
        let escaped = content.replace('\'', "''");
        let count = self.table.count_rows(Some(format!("content = '{}'", escaped))).await?;
        Ok(count > 0)
    }

    /// Replaces the row identified by `id` with `new_content` / `new_vector`.
    ///
    /// If some row already holds `new_content`, nothing changes (the old row for `id` is kept).
    /// Not atomic: the old row is deleted before the new one is inserted, so a failure in the
    /// insert leaves `id` absent.
    ///
    /// # Errors
    /// Fails on a vector of the wrong length, an invalid `id`, or any LanceDB error.
    pub async fn patch(&self, id: &str, new_content: &str, new_vector: Vec<f32>) -> Result<()> {
        if new_vector.len() != self.vector_dim {
            bail!("vector dimension must be {}, got {}", self.vector_dim, new_vector.len());
        }
        // Validate up front (as `post` does) so an invalid id is reported even when the
        // content-dedupe check would short-circuit.
        validate_id(id)?;
        if self.exists_by_content(new_content).await? {
            return Ok(());
        }
        self.delete(id).await?;
        self.post(id, new_content, new_vector).await?;
        Ok(())
    }

    /// Deletes every row whose `id` equals `id`.
    ///
    /// `id` goes through `validate_id` (which rejects quotes) before being interpolated into
    /// the filter string.
    pub async fn delete(&self, id: &str) -> Result<()> {
        validate_id(id)?;
        self.table.delete(&format!("id = '{}'", id)).await?;
        Ok(())
    }

    /// Calls `create_index` on the `vector` column with `Index::Auto`, regardless of row count.
    ///
    /// NOTE(review): whether an existing index is replaced depends on LanceDB's
    /// `create_index` defaults — confirm.
    pub async fn rebuild_index(&self) -> Result<()> {
        self.table
            .create_index(&["vector"], Index::Auto)
            .execute()
            .await?;
        Ok(())
    }
}