use super::{Database, Error, Memory};
/// Result alias for this module's database operations, with [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
impl Database {
    /// Migrates an existing `memories_fts` FTS5 table to the schema that
    /// carries an unindexed `project_id` column.
    ///
    /// The function is a no-op when `memories_fts` does not exist or already
    /// has a `project_id` column. Otherwise, inside a single transaction, it
    /// drops the old table and its triggers, recreates the table as an
    /// external-content index over `memories`, installs insert/delete/update
    /// triggers that keep the index in sync, and back-fills the index from
    /// every row of `memories`.
    ///
    /// NOTE(review): this function never creates `memories_fts` when it is
    /// missing; presumably the base schema is created elsewhere — confirm.
    ///
    /// # Errors
    ///
    /// Returns an error when any SQLite statement fails, when the `memories`
    /// table is missing, or when the row count reported by `memories_fts`
    /// after the back-fill differs from the row count of `memories`.
    pub fn initialize_fts(&self) -> Result<()> {
        // COUNT(*) always yields exactly one row, so a missing table shows up
        // as 0 while genuine SQLite failures are propagated instead of being
        // mistaken for "table absent".
        let fts_tables: i64 = self.conn.query_row(
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='memories_fts'",
            [],
            |row| row.get(0),
        )?;
        if fts_tables == 0 {
            return Ok(());
        }

        let has_project_id: bool = self.conn.query_row(
            "SELECT COUNT(*) FROM pragma_table_info('memories_fts') WHERE name = 'project_id'",
            [],
            |row| row.get::<_, i64>(0).map(|count| count > 0),
        )?;
        if has_project_id {
            return Ok(());
        }

        // Every early return below drops `tx` without committing; rusqlite's
        // default drop behaviour for a transaction is to roll it back.
        let tx = self.conn.unchecked_transaction()?;

        let memories_exists: bool = tx.query_row(
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='memories'",
            [],
            |row| row.get::<_, i64>(0).map(|count| count > 0),
        )?;
        if !memories_exists {
            return Err(Error::Sqlite(
                "External content table 'memories' does not exist".to_string(),
            ));
        }

        let memory_count: i64 =
            tx.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?;

        tx.execute_batch(
            "DROP TABLE IF EXISTS memories_fts;
             DROP TRIGGER IF EXISTS memories_fts_insert;
             DROP TRIGGER IF EXISTS memories_fts_delete;
             DROP TRIGGER IF EXISTS memories_fts_update;
             CREATE VIRTUAL TABLE memories_fts USING fts5(
                 content,
                 project_id UNINDEXED,
                 tokenize='porter unicode61',
                 content_rowid='rowid',
                 content='memories'
             );
             CREATE TRIGGER memories_fts_insert AFTER INSERT ON memories BEGIN
                 INSERT INTO memories_fts(rowid, content, project_id)
                 VALUES (new.rowid, new.content, new.project_id);
             END;
             CREATE TRIGGER memories_fts_delete AFTER DELETE ON memories BEGIN
                 INSERT INTO memories_fts(memories_fts, rowid, content, project_id)
                 VALUES('delete', old.rowid, old.content, old.project_id);
             END;
             CREATE TRIGGER memories_fts_update AFTER UPDATE ON memories BEGIN
                 INSERT INTO memories_fts(memories_fts, rowid, content, project_id)
                 VALUES('delete', old.rowid, old.content, old.project_id);
                 INSERT INTO memories_fts(rowid, content, project_id)
                 VALUES (new.rowid, new.content, new.project_id);
             END;
             INSERT INTO memories_fts(rowid, content, project_id)
             SELECT rowid, content, project_id FROM memories;",
        )
        .map_err(|e| Error::Sqlite(format!("FTS5 schema migration failed: {}", e)))?;

        // NOTE(review): for an external-content FTS5 table this COUNT(*) may be
        // answered from `memories` itself, in which case the comparison cannot
        // fail — confirm that this check detects anything.
        let fts_count: i64 =
            tx.query_row("SELECT COUNT(*) FROM memories_fts", [], |row| row.get(0))?;
        if fts_count != memory_count {
            tx.rollback()?;
            return Err(Error::Sqlite(format!(
                "FTS5 migration incomplete: expected {} rows, got {} rows",
                memory_count, fts_count
            )));
        }

        tx.commit()?;
        Ok(())
    }

    /// Runs a BM25-ranked full-text search over the memories of one project.
    ///
    /// * `query` - free text; it is split on whitespace and every term is
    ///   quoted (see `escape_fts_query`). A query without any term returns an
    ///   empty vector without touching the index.
    /// * `project_id` - only rows of `memories` with this project id match.
    /// * `limit` - maximum number of rows, checked by
    ///   `super::search::validate_limit` before anything else happens.
    /// * `memory_types` - `None` or an empty slice applies no type filter;
    ///   otherwise the row's `type` must be one of the listed values.
    /// * `statuses` - `None` restricts results to status `"active"`; an empty
    ///   slice applies no status filter; otherwise the row's `status` must be
    ///   one of the listed values.
    ///
    /// Rows are ordered by `bm25(memories_fts)` ascending and that raw score is
    /// stored in `Memory::similarity`. Per the FTS5 documentation better
    /// matches have numerically lower scores, so this is not a 0..1 similarity.
    ///
    /// # Errors
    ///
    /// Returns an error when `limit` is rejected, when the FTS migration
    /// fails, when SQLite reports an error, or when a stored embedding blob
    /// cannot be decoded.
    pub fn search_bm25(
        &self,
        query: &str,
        project_id: &str,
        limit: usize,
        memory_types: Option<&[&str]>,
        statuses: Option<&[&str]>,
    ) -> Result<Vec<Memory>> {
        super::search::validate_limit(limit)?;
        if !self.is_fts_initialized()? {
            self.initialize_fts()?;
        }

        let escaped_query = Self::escape_fts_query(query);
        if escaped_query.is_empty() {
            return Ok(Vec::new());
        }

        // `limit` has already passed `validate_limit` above.
        let limit_i64 = limit as i64;

        // Renders `count` numbered placeholders starting at `?first`.
        let placeholders = |first: usize, count: usize| -> String {
            (first..first + count)
                .map(|i| format!("?{}", i))
                .collect::<Vec<_>>()
                .join(", ")
        };

        // Clauses and bound values are built together so that placeholder
        // numbers always equal the 1-based position of the value in `params`.
        let mut where_clauses = vec![
            "memories_fts MATCH ?1".to_string(),
            "m.project_id = ?2".to_string(),
        ];
        let mut params: Vec<&dyn rusqlite::ToSql> = vec![&escaped_query, &project_id];

        match statuses {
            None => {
                params.push(&"active");
                where_clauses.push(format!("m.status = ?{}", params.len()));
            }
            Some(list) if !list.is_empty() => {
                where_clauses.push(format!(
                    "m.status IN ({})",
                    placeholders(params.len() + 1, list.len())
                ));
                for status in list {
                    params.push(status);
                }
            }
            // An explicitly empty list means "do not filter by status".
            Some(_) => {}
        }

        if let Some(types) = memory_types {
            if !types.is_empty() {
                where_clauses.push(format!(
                    "m.type IN ({})",
                    placeholders(params.len() + 1, types.len())
                ));
                for memory_type in types {
                    params.push(memory_type);
                }
            }
        }

        // The limit is always the last bound value.
        params.push(&limit_i64);

        let sql = format!(
            r#"
            SELECT m.id, m.project_id, m.content, m.metadata, m.embedding, m.created_at, m.updated_at, m.type, m.status, m.superseded_by,
                   bm25(memories_fts) as bm25_score
            FROM memories_fts
            JOIN memories m ON m.rowid = memories_fts.rowid
            WHERE {}
            ORDER BY bm25(memories_fts)
            LIMIT ?{}
            "#,
            where_clauses.join(" AND "),
            params.len()
        );

        let mut stmt = self.conn.prepare(&sql)?;
        let memories = stmt
            .query_map(params.as_slice(), |row| {
                // Column 4 is the raw embedding blob; a decode failure is
                // reported as a conversion error for that column.
                let blob: Vec<u8> = row.get(4)?;
                let embedding = super::embedding::blob_to_vec(&blob).map_err(|e| {
                    rusqlite::Error::FromSqlConversionFailure(
                        4,
                        rusqlite::types::Type::Blob,
                        Box::new(e),
                    )
                })?;
                Ok(Memory {
                    id: row.get(0)?,
                    project_id: row.get(1)?,
                    content: row.get(2)?,
                    metadata: row.get(3)?,
                    embedding,
                    created_at: row.get(5)?,
                    updated_at: row.get(6)?,
                    memory_type: row.get(7)?,
                    status: row.get(8)?,
                    superseded_by: row.get(9)?,
                    similarity: Some(row.get::<_, f64>(10)?),
                })
            })?
            .collect::<rusqlite::Result<Vec<Memory>>>()?;
        Ok(memories)
    }

    /// Reports whether `memories_fts` exists and `COUNT(*)` over it is
    /// greater than zero.
    ///
    /// A table that exists but reports zero rows counts as "not initialized",
    /// so `search_bm25` re-runs `initialize_fts` in that case.
    fn is_fts_initialized(&self) -> Result<bool> {
        let table_count: i64 = self.conn.query_row(
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='memories_fts'",
            [],
            |row| row.get(0),
        )?;
        if table_count == 0 {
            return Ok(false);
        }

        let row_count: i64 =
            self.conn
                .query_row("SELECT COUNT(*) FROM memories_fts", [], |row| row.get(0))?;
        Ok(row_count > 0)
    }

    /// Turns free text into an FTS5 query made of quoted strings.
    ///
    /// The input is split on whitespace and every term is wrapped in double
    /// quotes, so FTS5 operators and punctuation in user input are treated as
    /// plain text. Embedded double quotes are escaped by doubling them, which
    /// is the only escape FTS5 strings define; backslashes are passed through
    /// unchanged because FTS5 gives them no special meaning. Terms are joined
    /// with single spaces. Input without any term yields an empty string.
    fn escape_fts_query(query: &str) -> String {
        query
            .split_whitespace()
            .map(|word| format!("\"{}\"", word.replace('"', "\"\"")))
            .collect::<Vec<_>>()
            .join(" ")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Project id shared by every test in this module.
    const PROJECT: &str = "proj1";

    /// Placeholder embedding; the tests only exercise text search, so the
    /// values themselves are irrelevant.
    fn dummy_embedding() -> Vec<f32> {
        vec![0.1f32; 384]
    }

    /// Opens a fresh database inside its own temporary directory.
    ///
    /// The `TempDir` guard is returned (first, so that a
    /// `let (_dir, db) = ...` binding drops the database before the
    /// directory) instead of being leaked with `mem::forget`; the directory
    /// is removed once the guard goes out of scope at the end of the test.
    fn create_test_db() -> (TempDir, Database) {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.db");
        let db = Database::open(&path).unwrap();
        (dir, db)
    }

    /// Inserts one memory for `PROJECT` with the placeholder embedding and no
    /// metadata.
    fn insert_memory(db: &Database, content: &str, memory_type: &str, status: &str) {
        let embedding = dummy_embedding();
        db.insert(PROJECT, content, &embedding, None, memory_type, status)
            .unwrap();
    }

    /// Number of hits for `query` using the default filters and a limit of 10.
    fn hit_count(db: &Database, query: &str) -> usize {
        db.search_bm25(query, PROJECT, 10, None, None)
            .unwrap()
            .len()
    }

    #[test]
    fn test_fts5_search() {
        let (_dir, db) = create_test_db();
        insert_memory(&db, "rust programming", "fact", "active");
        insert_memory(&db, "python data", "fact", "active");

        let results = db.search_bm25("rust", PROJECT, 10, None, None).unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("rust"));
    }

    #[test]
    fn test_fts5_triggers() {
        let (_dir, db) = create_test_db();
        let embedding = dummy_embedding();

        // The insert trigger indexes the new row.
        let id = db
            .insert(PROJECT, "original text", &embedding, None, "fact", "active")
            .unwrap();
        assert_eq!(hit_count(&db, "original"), 1);

        // The update trigger re-indexes the changed content.
        db.update(
            &id,
            Some("updated text"),
            Some(&embedding.as_slice()),
            None,
            None,
            None,
        )
        .unwrap();
        assert_eq!(hit_count(&db, "updated"), 1);

        // The delete trigger removes the row from the index.
        db.delete(&id).unwrap();
        assert_eq!(hit_count(&db, "updated"), 0);
    }

    #[test]
    fn test_fts5_limit_validation() {
        let (_dir, db) = create_test_db();
        assert!(db.search_bm25("test", PROJECT, 0, None, None).is_err());
        assert!(db
            .search_bm25("test", PROJECT, 100_000, None, None)
            .is_err());
    }

    #[test]
    fn test_fts5_special_characters() {
        let (_dir, db) = create_test_db();
        insert_memory(&db, "test with \"quotes\"", "fact", "active");
        insert_memory(&db, "test with 'apos'", "fact", "active");
        insert_memory(&db, "test with\\slash", "fact", "active");

        assert_eq!(hit_count(&db, "test with \"quotes\""), 1);
        assert_eq!(hit_count(&db, "test with\\slash"), 1);
    }

    #[test]
    fn test_fts5_empty_query() {
        let (_dir, db) = create_test_db();
        insert_memory(&db, "test content", "fact", "active");

        let results = db.search_bm25("", PROJECT, 10, None, None).unwrap();
        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_fts5_phrase_search() {
        let (_dir, db) = create_test_db();
        insert_memory(&db, "rust programming", "fact", "active");
        insert_memory(&db, "rust error handling", "fact", "active");

        let results = db
            .search_bm25("rust programming", PROJECT, 10, None, None)
            .unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("programming"));
    }

    #[test]
    fn test_fts5_unicode_text() {
        let (_dir, db) = create_test_db();
        insert_memory(&db, "café résumé 日本語", "fact", "active");

        let results = db.search_bm25("café", PROJECT, 10, None, None).unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("café"));
    }

    #[test]
    fn test_initialize_fts_migration() {
        // `dir` stays alive until the end of the test, after both database
        // handles below have been dropped, and is cleaned up on drop.
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.db");
        {
            let db = Database::open(&path).unwrap();
            insert_memory(&db, "before migration", "fact", "active");
        }
        {
            let db = Database::open(&path).unwrap();
            db.initialize_fts().unwrap();
            assert_eq!(hit_count(&db, "before"), 1);
        }
    }

    #[test]
    fn test_initialize_fts_consistency_handling() {
        // `dir` stays alive until the end of the test, after every database
        // handle below has been dropped, and is cleaned up on drop.
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.db");
        {
            let db = Database::open(&path).unwrap();
            for content in ["first", "second", "third"] {
                insert_memory(&db, content, "fact", "active");
            }
        }
        {
            let db = Database::open(&path).unwrap();
            db.initialize_fts().unwrap();
            let fts_count: i64 = db
                .conn()
                .query_row("SELECT COUNT(*) FROM memories_fts", [], |row| row.get(0))
                .unwrap();
            assert_eq!(fts_count, 3);
        }
        {
            // Re-running the initialization on a reopened database must leave
            // every row searchable.
            let db = Database::open(&path).unwrap();
            db.initialize_fts().unwrap();
            for term in ["first", "second", "third"] {
                assert_eq!(hit_count(&db, term), 1);
            }
        }
    }

    #[test]
    fn test_bm25_search_filters_by_status() {
        let (_dir, db) = create_test_db();
        for status in ["active", "superseded", "candidate"] {
            insert_memory(&db, "test content", "fact", status);
        }

        let results = db
            .search_bm25("test", PROJECT, 10, None, Some(&["active"]))
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].status, "active");

        let results = db
            .search_bm25("test", PROJECT, 10, None, Some(&["active", "candidate"]))
            .unwrap();
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_bm25_search_default_excludes_non_active() {
        let (_dir, db) = create_test_db();
        insert_memory(&db, "active memory", "fact", "active");
        insert_memory(&db, "candidate memory", "fact", "candidate");
        insert_memory(&db, "superseded memory", "fact", "superseded");

        let results = db.search_bm25("memory", PROJECT, 10, None, None).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].status, "active");
    }

    #[test]
    fn test_bm25_search_filters_by_type() {
        let (_dir, db) = create_test_db();
        insert_memory(&db, "a fact about rust programming", "fact", "active");
        insert_memory(&db, "a preference for python data", "preference", "active");
        insert_memory(&db, "a procedure for testing code", "procedure", "active");

        let results = db
            .search_bm25("rust", PROJECT, 10, Some(&["fact"]), None)
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].memory_type, "fact");

        let results = db
            .search_bm25("python", PROJECT, 10, Some(&["fact", "preference"]), None)
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].memory_type, "preference");
    }
}