use std::str::FromStr;
use std::sync::Arc;
use async_trait::async_trait;
use sqlx::{Row, SqlitePool};
use crate::error::{Error, Result};
use crate::index::lexical_trait::{LexicalIndex, LexicalIndexCapabilities};
use crate::index::vector_trait::VectorScope;
use crate::memory::MemoryId;
use crate::partition::PartitionPath;
use crate::summary::SummaryId;
/// Lexical index whose queries run against the `memory_fts` and `summary_fts`
/// tables through a shared SQLite connection pool.
///
/// Cloning copies only the `Arc` handle, so every clone uses the same pool.
#[derive(Debug, Clone)]
pub struct Fts5Index {
// Shared handle to the pool every statement in this file is executed on.
pool: Arc<SqlitePool>,
}
impl Fts5Index {
/// Creates an index that executes its statements on `pool`.
///
/// The pool is stored as given; nothing is queried or created here.
#[must_use]
pub fn new(pool: Arc<SqlitePool>) -> Self {
Self { pool }
}
}
#[async_trait]
impl LexicalIndex for Fts5Index {
async fn upsert_memory(
&self,
id: &MemoryId,
partition_path: &PartitionPath,
content: &str,
) -> Result<()> {
sqlx::query("DELETE FROM memory_fts WHERE memory_id = ?")
.bind(id.to_string())
.execute(self.pool.as_ref())
.await
.map_err(|e| Error::metadata("Fts5Index::upsert_memory delete", e))?;
sqlx::query("INSERT INTO memory_fts(content, memory_id, partition_path) VALUES (?, ?, ?)")
.bind(content)
.bind(id.to_string())
.bind(partition_path.as_str())
.execute(self.pool.as_ref())
.await
.map_err(|e| Error::metadata("Fts5Index::upsert_memory insert", e))?;
Ok(())
}
async fn upsert_summary(&self, id: &SummaryId, parent_path: &str, content: &str) -> Result<()> {
sqlx::query("DELETE FROM summary_fts WHERE summary_id = ?")
.bind(id.to_string())
.execute(self.pool.as_ref())
.await
.map_err(|e| Error::metadata("Fts5Index::upsert_summary delete", e))?;
sqlx::query("INSERT INTO summary_fts(content, summary_id, parent_path) VALUES (?, ?, ?)")
.bind(content)
.bind(id.to_string())
.bind(parent_path)
.execute(self.pool.as_ref())
.await
.map_err(|e| Error::metadata("Fts5Index::upsert_summary insert", e))?;
Ok(())
}
async fn delete_memory(&self, id: &MemoryId) -> Result<()> {
sqlx::query("DELETE FROM memory_fts WHERE memory_id = ?")
.bind(id.to_string())
.execute(self.pool.as_ref())
.await
.map_err(|e| Error::metadata("Fts5Index::delete_memory", e))?;
Ok(())
}
async fn delete_summary(&self, id: &SummaryId) -> Result<()> {
sqlx::query("DELETE FROM summary_fts WHERE summary_id = ?")
.bind(id.to_string())
.execute(self.pool.as_ref())
.await
.map_err(|e| Error::metadata("Fts5Index::delete_summary", e))?;
Ok(())
}
async fn search_memory(
&self,
query: &str,
k: u32,
scope: VectorScope,
) -> Result<Vec<(MemoryId, f32)>> {
let scope_clause: (&'static str, Vec<String>) = match &scope {
VectorScope::Tenant => ("", Vec::new()),
VectorScope::Partition(p) => (" AND partition_path = ?", vec![p.as_str().to_string()]),
VectorScope::PartitionPrefix(prefix) => (
" AND (partition_path = ? OR partition_path LIKE ?)",
vec![prefix.clone(), format!("{prefix}/%")],
),
};
let sql = format!(
"SELECT memory_id AS id, -bm25(memory_fts) AS score \
FROM memory_fts \
WHERE memory_fts MATCH ?{scope_where} \
ORDER BY score DESC LIMIT ?",
scope_where = scope_clause.0,
);
let mut q = sqlx::query(&sql).bind(query);
for s in &scope_clause.1 {
q = q.bind(s);
}
q = q.bind(i64::from(k));
let rows = q
.fetch_all(self.pool.as_ref())
.await
.map_err(|e| Error::metadata("Fts5Index::search_memory", e))?;
let mut out = Vec::with_capacity(rows.len());
for row in rows {
let id_s: String = row
.try_get("id")
.map_err(|e| Error::metadata("read memory_id", e))?;
let score: f32 = row
.try_get::<f64, _>("score")
.map_err(|e| Error::metadata("read score", e))? as f32;
let id = MemoryId::from_str(&id_s)
.map_err(|_| Error::metadata("parse memory_id", std::io::Error::other("bad id")))?;
out.push((id, score));
}
Ok(out)
}
async fn search_summary(
&self,
query: &str,
k: u32,
parent_path_prefix: &str,
) -> Result<Vec<(SummaryId, f32)>> {
let prefix_eq = parent_path_prefix.to_string();
let prefix_like = format!("{parent_path_prefix}/%");
let sql = "SELECT summary_id AS id, -bm25(summary_fts) AS score \
FROM summary_fts \
WHERE summary_fts MATCH ? \
AND (parent_path = ? OR parent_path LIKE ?) \
ORDER BY score DESC LIMIT ?";
let rows = sqlx::query(sql)
.bind(query)
.bind(&prefix_eq)
.bind(&prefix_like)
.bind(i64::from(k))
.fetch_all(self.pool.as_ref())
.await
.map_err(|e| Error::metadata("Fts5Index::search_summary", e))?;
let mut out = Vec::with_capacity(rows.len());
for row in rows {
let id_s: String = row
.try_get("id")
.map_err(|e| Error::metadata("read summary_id", e))?;
let score: f32 = row
.try_get::<f64, _>("score")
.map_err(|e| Error::metadata("read score", e))? as f32;
let id = SummaryId::from_str(&id_s).map_err(|_| {
Error::metadata("parse summary_id", std::io::Error::other("bad id"))
})?;
out.push((id, score));
}
Ok(out)
}
fn id(&self) -> &str {
"sqlite-fts5"
}
fn capabilities(&self) -> LexicalIndexCapabilities {
LexicalIndexCapabilities::default()
}
}