use anyhow::Result;
use super::super::database::MemoryDatabase;
/// One row returned by `search_bm25` from the `memories_fts` full-text table.
#[derive(Debug, Clone)]
pub struct Bm25Hit {
/// Value of the `rowid` column of the matching `memories_fts` row.
pub rowid: i64,
/// Value of the `id` column of the matching row.
pub id: String,
/// Relevance score. `search_bm25` selects `-bm25(memories_fts)` and orders by it
/// descending, so larger values rank first.
pub score: f64,
}
pub fn search_bm25(db: &MemoryDatabase, query: &str, limit: usize) -> Result<Vec<Bm25Hit>> {
let conn = db.conn();
let fts_query = if query.contains('"')
|| query.contains("AND")
|| query.contains("OR")
|| query.contains("NOT")
{
query.to_string()
} else {
let mut tokens = Vec::new();
for word in query.split_whitespace() {
if word.chars().all(|c| c.is_ascii_alphanumeric()) {
if word.len() >= 2 {
tokens.push(word.to_string());
}
} else {
let has_cjk = !word.is_ascii();
if has_cjk {
for ch in word.chars() {
if !ch.is_ascii() && ch.is_alphabetic() {
tokens.push(ch.to_string());
}
}
}
if word.len() >= 2 {
tokens.push(word.to_string());
}
}
}
if tokens.is_empty() {
return Ok(Vec::new());
}
tokens.join(" OR ")
};
let sql = "SELECT rowid, id, -bm25(memories_fts) as score
FROM memories_fts
WHERE memories_fts MATCH ?1
ORDER BY score DESC
LIMIT ?2"
.to_string();
let mut stmt = match conn.prepare(&sql) {
Ok(s) => s,
Err(e) => {
tracing::debug!(query = %fts_query, error = %e, "FTS5 query parse error");
return Ok(Vec::new());
}
};
let rows = match stmt.query_map(rusqlite::params![fts_query, limit], |row| {
Ok(Bm25Hit {
rowid: row.get(0)?,
id: row.get(1)?,
score: row.get(2)?,
})
}) {
Ok(r) => r,
Err(e) => {
tracing::debug!(query = %fts_query, error = %e, "FTS5 query execution error");
return Ok(Vec::new());
}
};
let mut results = Vec::new();
for row in rows {
results.push(row?);
}
Ok(results)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a fresh in-memory database for a single test.
    fn fresh_db() -> MemoryDatabase {
        MemoryDatabase::open_in_memory(256).unwrap()
    }

    /// Two rows are inserted; a query for "Rust programming" must put the Rust row first.
    #[test]
    fn test_bm25_basic_search() {
        let store = fresh_db();
        {
            // Scoped so the connection handle is released before searching.
            let handle = store.conn();
            let insert_rust = "INSERT INTO memories (id, memory_type, content, importance, tier, source, created_at, updated_at)
VALUES ('bm-test-1', 'fact', 'Rust is a systems programming language', 0.6, 'warm', 'test', '2026-01-01T00:00:00Z', '2026-01-01T00:00:00Z')";
            handle.execute(insert_rust, []).unwrap();
            let insert_python = "INSERT INTO memories (id, memory_type, content, importance, tier, source, created_at, updated_at)
VALUES ('bm-test-2', 'fact', 'Python is great for data science', 0.5, 'warm', 'test', '2026-01-01T00:00:00Z', '2026-01-01T00:00:00Z')";
            handle.execute(insert_python, []).unwrap();
        }

        let hits = search_bm25(&store, "Rust programming", 10).unwrap();
        assert!(!hits.is_empty(), "BM25 should find results");
        assert_eq!(hits[0].id, "bm-test-1");
    }

    /// Searching a database with no inserted rows yields no hits.
    #[test]
    fn test_bm25_no_results() {
        let store = fresh_db();
        let hits = search_bm25(&store, "nonexistent", 10).unwrap();
        assert!(hits.is_empty());
    }

    /// A Korean word present in a stored row must be found.
    #[test]
    fn test_bm25_korean() {
        let store = fresh_db();
        {
            // Scoped so the connection handle is released before searching.
            let handle = store.conn();
            let insert_korean = "INSERT INTO memories (id, memory_type, content, importance, tier, source, created_at, updated_at)
VALUES ('kr-bm-1', 'fact', '한국어 메모리 테스트입니다', 0.5, 'warm', 'test', '2026-01-01T00:00:00Z', '2026-01-01T00:00:00Z')";
            handle.execute(insert_korean, []).unwrap();
        }

        let hits = search_bm25(&store, "한국어", 10).unwrap();
        assert!(!hits.is_empty(), "Korean BM25 should find results");
    }

    /// With twenty matching rows and a limit of five, no more than five hits come back.
    #[test]
    fn test_bm25_limit() {
        let store = fresh_db();
        {
            // Scoped so the connection handle is released before searching.
            let handle = store.conn();
            for n in 0..20 {
                let insert_row = format!(
                    "INSERT INTO memories (id, memory_type, content, importance, tier, source, created_at, updated_at)
VALUES ('limit-{}', 'fact', 'test content number {}', 0.5, 'warm', 'test', '2026-01-01T00:00:00Z', '2026-01-01T00:00:00Z')",
                    n, n
                );
                handle.execute(&insert_row, []).unwrap();
            }
        }

        let hits = search_bm25(&store, "test content", 5).unwrap();
        assert!(hits.len() <= 5, "Should respect limit");
    }
}