use std::str::FromStr;
use std::sync::Arc;
use async_trait::async_trait;
use sqlx::{Row, SqlitePool};
use crate::attribute::AttributeValue;
use crate::error::{Error, Result};
use crate::index::vector_trait::{
DistanceMetric, VectorFilter, VectorIndex, VectorIndexCapabilities, VectorScope,
};
use crate::memory::{MemoryId, MemoryKind};
use crate::partition::PartitionPath;
use crate::summarizer::SummaryStyle;
use crate::summary::SummaryId;
/// Vector index backed by SQLite tables reached through a shared `sqlx` pool.
///
/// Every statement issued by this type targets the `memory_vec` or
/// `summary_vec` table. Cloning the index clones only the `Arc` handle, so all
/// clones share the same pool.
#[derive(Clone, Debug)]
pub struct SqliteVecIndex {
    /// Shared connection pool used for every statement issued by this index.
    pool: Arc<SqlitePool>,
}
impl SqliteVecIndex {
#[must_use]
pub fn new(pool: Arc<SqlitePool>) -> Self {
Self { pool }
}
}
#[async_trait]
impl VectorIndex for SqliteVecIndex {
async fn upsert_memory(
&self,
id: &MemoryId,
partition_path: &PartitionPath,
kind: Option<&MemoryKind>,
embedding: &[f32],
) -> Result<()> {
let blob: &[u8] = bytemuck::cast_slice(embedding);
let kind_tag: Option<&'static str> = kind.map(|k| k.as_persisted_str());
sqlx::query("DELETE FROM memory_vec WHERE memory_id = ?")
.bind(id.to_string())
.execute(self.pool.as_ref())
.await
.map_err(|e| Error::metadata("SqliteVecIndex::upsert_memory delete", e))?;
sqlx::query(
"INSERT INTO memory_vec(memory_id, partition_path, kind, embedding) \
VALUES (?, ?, ?, ?)",
)
.bind(id.to_string())
.bind(partition_path.as_str())
.bind(kind_tag)
.bind(blob)
.execute(self.pool.as_ref())
.await
.map_err(|e| Error::metadata("SqliteVecIndex::upsert_memory", e))?;
Ok(())
}
async fn upsert_summary(
&self,
id: &SummaryId,
parent_path: &str,
style: &SummaryStyle,
embedding: &[f32],
) -> Result<()> {
let blob: &[u8] = bytemuck::cast_slice(embedding);
sqlx::query("DELETE FROM summary_vec WHERE summary_id = ?")
.bind(id.to_string())
.execute(self.pool.as_ref())
.await
.map_err(|e| Error::metadata("SqliteVecIndex::upsert_summary delete", e))?;
sqlx::query(
"INSERT INTO summary_vec(summary_id, parent_path, style, embedding) \
VALUES (?, ?, ?, ?)",
)
.bind(id.to_string())
.bind(parent_path)
.bind(style.as_str().as_ref())
.bind(blob)
.execute(self.pool.as_ref())
.await
.map_err(|e| Error::metadata("SqliteVecIndex::upsert_summary", e))?;
Ok(())
}
async fn delete_memory(&self, id: &MemoryId) -> Result<()> {
sqlx::query("DELETE FROM memory_vec WHERE memory_id = ?")
.bind(id.to_string())
.execute(self.pool.as_ref())
.await
.map_err(|e| Error::metadata("SqliteVecIndex::delete_memory", e))?;
Ok(())
}
async fn delete_summary(&self, id: &SummaryId) -> Result<()> {
sqlx::query("DELETE FROM summary_vec WHERE summary_id = ?")
.bind(id.to_string())
.execute(self.pool.as_ref())
.await
.map_err(|e| Error::metadata("SqliteVecIndex::delete_summary", e))?;
Ok(())
}
/// Runs a k-nearest-neighbour search over `memory_vec`.
///
/// * `query` is bound as the raw bytes of the `f32` slice and matched against
///   the `embedding` column with `MATCH ? AND k = ?`.
/// * `scope` restricts the partition: `Tenant` adds no predicate, `Partition`
///   requires an exact `partition_path`, and `PartitionPrefix` accepts the
///   prefix itself or any path below `<prefix>/`.
/// * `filter`, when present, joins `memory_attribute` and requires a row with
///   the given key whose typed value column equals the filter value.
///
/// Returns `(memory id, distance)` pairs ordered by ascending distance. The
/// distance is read as `f64` and narrowed to `f32`.
///
/// NOTE(review): whether the scope/attribute predicates are applied before or
/// after the k-nearest selection depends on the `memory_vec` table
/// definition, which is not visible here; a filtered query may return fewer
/// than `k` rows.
///
/// # Errors
///
/// Returns `Error::metadata` if the query fails, a column cannot be read, or a
/// returned id does not parse as a `MemoryId`.
async fn knn_memory(
    &self,
    query: &[f32],
    k: u32,
    scope: VectorScope,
    filter: Option<&VectorFilter>,
) -> Result<Vec<(MemoryId, f32)>> {
    let qblob: &[u8] = bytemuck::cast_slice(query);

    let (scope_where, scope_binds): (&'static str, Vec<String>) = match &scope {
        VectorScope::Tenant => ("", Vec::new()),
        VectorScope::Partition(p) => {
            (" AND mv.partition_path = ?", vec![p.as_str().to_string()])
        }
        VectorScope::PartitionPrefix(prefix) => {
            // `%` and `_` are LIKE wildcards. Without escaping, a prefix such
            // as `user_data` would also match `userXdata/...` and leak rows
            // from sibling partitions, so escape them (and the escape
            // character itself) before appending the `/%` suffix.
            // Note: SQLite's LIKE is ASCII case-insensitive by default.
            let mut descendants = String::with_capacity(prefix.len() + 2);
            for ch in prefix.chars() {
                if matches!(ch, '\\' | '%' | '_') {
                    descendants.push('\\');
                }
                descendants.push(ch);
            }
            descendants.push_str("/%");
            (
                " AND (mv.partition_path = ? OR mv.partition_path LIKE ? ESCAPE '\\')",
                vec![prefix.clone(), descendants],
            )
        }
    };

    let (filter_join, filter_where, filter_binds) = match filter {
        None => ("", "", Vec::<String>::new()),
        Some(f) => {
            let (col, val_str) = attribute_to_filter_pair(&f.value);
            (
                " JOIN memory_attribute ma ON ma.memory_id = mv.memory_id",
                // Column names are chosen from a fixed whitelist of literals
                // so nothing returned by the helper is spliced into the SQL.
                match col {
                    "v_string" => " AND ma.key = ? AND ma.v_string = ?",
                    "v_int" => " AND ma.key = ? AND ma.v_int = ?",
                    "v_decimal" => " AND ma.key = ? AND ma.v_decimal = ?",
                    "v_timestamp" => " AND ma.key = ? AND ma.v_timestamp = ?",
                    "v_bool" => " AND ma.key = ? AND ma.v_bool = ?",
                    _ => " AND ma.key = ? AND ma.v_string = ?",
                },
                vec![f.key.clone(), val_str],
            )
        }
    };

    let sql = format!(
        "SELECT mv.memory_id AS id, distance \
         FROM memory_vec mv{filter_join} \
         WHERE mv.embedding MATCH ? AND k = ?{scope_where}{filter_where} \
         ORDER BY distance",
        filter_join = filter_join,
        scope_where = scope_where,
        filter_where = filter_where,
    );

    // Bind order must follow placeholder order: vector, k, scope, filter.
    let mut q = sqlx::query(&sql).bind(qblob).bind(i64::from(k));
    for s in &scope_binds {
        q = q.bind(s);
    }
    for s in &filter_binds {
        q = q.bind(s);
    }

    let rows = q
        .fetch_all(self.pool.as_ref())
        .await
        .map_err(|e| Error::metadata("SqliteVecIndex::knn_memory", e))?;

    let mut out = Vec::with_capacity(rows.len());
    for row in rows {
        let id_s: String = row
            .try_get("id")
            .map_err(|e| Error::metadata("read memory_id", e))?;
        let dist: f32 = row
            .try_get::<f64, _>("distance")
            .map_err(|e| Error::metadata("read distance", e))? as f32;
        let id = MemoryId::from_str(&id_s)
            .map_err(|_| Error::metadata("parse memory_id", std::io::Error::other("bad id")))?;
        out.push((id, dist));
    }
    Ok(out)
}
/// Runs a k-nearest-neighbour search over `summary_vec`.
///
/// * `query` is bound as the raw bytes of the `f32` slice and matched against
///   the `embedding` column with `MATCH ? AND k = ?`.
/// * `parent_path_prefix` restricts results to rows whose `parent_path` is
///   exactly the prefix or lies below `<prefix>/`.
///
/// Returns `(summary id, distance)` pairs ordered by ascending distance. The
/// distance is read as `f64` and narrowed to `f32`.
///
/// # Errors
///
/// Returns `Error::metadata` if the query fails, a column cannot be read, or a
/// returned id does not parse as a `SummaryId`.
async fn knn_summary(
    &self,
    query: &[f32],
    k: u32,
    parent_path_prefix: &str,
) -> Result<Vec<(SummaryId, f32)>> {
    let qblob: &[u8] = bytemuck::cast_slice(query);

    // `%` and `_` are LIKE wildcards. Without escaping, a prefix such as
    // `team_a` would also match `teamXa/...`, so escape them (and the escape
    // character itself) before appending the `/%` suffix.
    // Note: SQLite's LIKE is ASCII case-insensitive by default.
    let mut descendants = String::with_capacity(parent_path_prefix.len() + 2);
    for ch in parent_path_prefix.chars() {
        if matches!(ch, '\\' | '%' | '_') {
            descendants.push('\\');
        }
        descendants.push(ch);
    }
    descendants.push_str("/%");

    let sql = "SELECT summary_id AS id, distance \
               FROM summary_vec \
               WHERE embedding MATCH ? AND k = ? \
               AND (parent_path = ? OR parent_path LIKE ? ESCAPE '\\') \
               ORDER BY distance";
    let rows = sqlx::query(sql)
        .bind(qblob)
        .bind(i64::from(k))
        .bind(parent_path_prefix)
        .bind(descendants.as_str())
        .fetch_all(self.pool.as_ref())
        .await
        .map_err(|e| Error::metadata("SqliteVecIndex::knn_summary", e))?;

    let mut out = Vec::with_capacity(rows.len());
    for row in rows {
        let id_s: String = row
            .try_get("id")
            .map_err(|e| Error::metadata("read summary_id", e))?;
        let dist: f32 = row
            .try_get::<f64, _>("distance")
            .map_err(|e| Error::metadata("read distance", e))? as f32;
        let id = SummaryId::from_str(&id_s).map_err(|_| {
            Error::metadata("parse summary_id", std::io::Error::other("bad id"))
        })?;
        out.push((id, dist));
    }
    Ok(out)
}
/// Returns the fixed identifier string of this index backend.
fn id(&self) -> &str {
    const BACKEND_ID: &str = "sqlite-vec:vec0";
    BACKEND_ID
}
/// Reports the static capabilities advertised by this backend.
///
/// The values are constants: filtered KNN is advertised as supported, the
/// maximum dimension count is 4096, and the metric is cosine distance.
/// NOTE(review): the metric actually used depends on how the vector tables
/// were created, which is not visible here — confirm it matches.
fn capabilities(&self) -> VectorIndexCapabilities {
    let knn_filtered = true;
    let max_dimensions = 4096;
    let distance_metric = DistanceMetric::CosineDistance;
    VectorIndexCapabilities {
        knn_filtered,
        max_dimensions,
        distance_metric,
    }
}
}
/// Maps an attribute value to the name of the value column to compare
/// against and the textual form of the value to bind.
///
/// Booleans are rendered as `"1"` / `"0"`. Arrays are compared through
/// `v_string` using the JSON serialization of the whole value; if that
/// serialization fails the bound text is the empty string.
fn attribute_to_filter_pair(v: &AttributeValue) -> (&'static str, String) {
    let column = match v {
        AttributeValue::String(_) | AttributeValue::Array(_) => "v_string",
        AttributeValue::Int(_) => "v_int",
        AttributeValue::Decimal(_) => "v_decimal",
        AttributeValue::Bool(_) => "v_bool",
        AttributeValue::Timestamp(_) => "v_timestamp",
    };
    let rendered = match v {
        AttributeValue::String(text) => text.clone(),
        AttributeValue::Int(number) => number.to_string(),
        AttributeValue::Decimal(number) => number.to_string(),
        AttributeValue::Bool(flag) => String::from(if *flag { "1" } else { "0" }),
        AttributeValue::Timestamp(instant) => instant.to_string(),
        AttributeValue::Array(_) => serde_json::to_string(v).unwrap_or_default(),
    };
    (column, rendered)
}