use crate::Result;
use rusqlite::Connection;
/// One row produced by [`search`]: a chunk that matched the full-text query,
/// together with its relevance score.
#[derive(Debug, Clone)]
pub struct FtsResult {
/// Value of the `id` column of the matching `chunks` row.
pub chunk_id: String,
/// Value of the `doc_id` column of the matching `chunks` row.
pub doc_id: String,
/// Value of the `content` column of the matching `chunks` row.
pub content: String,
/// Negated `bm25(chunks_fts)` value computed by the search query; results
/// are returned in descending order of this score.
pub score: f64,
/// Value of the `position` column of the matching `chunks` row
/// (presumably the chunk's ordinal within its document — confirm against the schema).
pub position: i64,
}
/// Turns free-form user input into a string of double-quoted terms for use
/// on the right-hand side of an FTS5 `MATCH`.
///
/// The input is split on whitespace. Every `"` character is removed from each
/// word, and whatever remains is wrapped in double quotes, so words such as
/// `AND` are passed as quoted strings rather than bare keywords. Words that
/// are empty once their quotes are removed are dropped.
///
/// Returns the quoted terms joined by single spaces, or an empty string when
/// no usable term is left.
fn escape_fts5_query(query: &str) -> String {
    let mut quoted_terms: Vec<String> = Vec::new();

    for word in query.split_whitespace() {
        // Embedded quotes are discarded rather than escaped.
        let stripped = word.replace('"', "");
        if !stripped.is_empty() {
            quoted_terms.push(format!("\"{stripped}\""));
        }
    }

    quoted_terms.join(" ")
}
pub fn search(conn: &Connection, query: &str, k: usize) -> Result<Vec<FtsResult>> {
let escaped = escape_fts5_query(query);
if escaped.is_empty() {
return Ok(Vec::new());
}
let mut stmt = conn
.prepare_cached(
"SELECT c.id, c.doc_id, c.content, -bm25(chunks_fts) AS score, c.position
FROM chunks_fts
JOIN chunks c ON chunks_fts.rowid = c.rowid
WHERE chunks_fts MATCH ?1
ORDER BY score DESC
LIMIT ?2",
)
.map_err(|e| crate::Error::Query(format!("Failed to prepare FTS5 search: {e}")))?;
let results = stmt
.query_map(rusqlite::params![escaped, k as i64], |row| {
Ok(FtsResult {
chunk_id: row.get(0)?,
doc_id: row.get(1)?,
content: row.get(2)?,
score: row.get(3)?,
position: row.get(4)?,
})
})
.map_err(|e| crate::Error::Query(format!("FTS5 search failed: {e}")))?
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| crate::Error::Query(format!("FTS5 result mapping failed: {e}")))?;
Ok(results)
}
pub fn optimize(conn: &Connection) -> Result<()> {
conn.execute("INSERT INTO chunks_fts(chunks_fts) VALUES('optimize')", [])
.map_err(|e| crate::Error::Query(format!("FTS5 optimize failed: {e}")))?;
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::schema;

    /// Opens an in-memory database and applies the project schema to it.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        schema::initialize(&conn).unwrap();
        conn
    }

    /// Inserts a chunk, creating its parent document first if it is missing.
    fn insert_chunk(conn: &Connection, doc_id: &str, chunk_id: &str, content: &str, pos: i64) {
        conn.execute("INSERT OR IGNORE INTO documents (id, content) VALUES (?1, '')", [doc_id])
            .unwrap();
        conn.execute(
            "INSERT INTO chunks (id, doc_id, content, position) VALUES (?1, ?2, ?3, ?4)",
            rusqlite::params![chunk_id, doc_id, content, pos],
        )
        .unwrap();
    }

    #[test]
    fn test_search_returns_results() {
        let conn = setup();
        insert_chunk(&conn, "doc1", "c1", "SIMD vector operations for tensor math", 0);
        insert_chunk(&conn, "doc1", "c2", "GPU kernel dispatch and scheduling", 1);
        let results = search(&conn, "SIMD tensor", 10).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].chunk_id, "c1");
    }

    #[test]
    fn test_search_bm25_ordering() {
        let conn = setup();
        insert_chunk(&conn, "d1", "c1", "machine learning algorithms", 0);
        insert_chunk(
            &conn,
            "d2",
            "c2",
            "machine learning machine learning machine learning deep learning",
            0,
        );
        insert_chunk(&conn, "d3", "c3", "cooking recipes for dinner", 0);
        let results = search(&conn, "machine learning", 10).unwrap();
        assert!(results.len() >= 2);
        assert_eq!(results[0].chunk_id, "c2");
        assert_eq!(results[1].chunk_id, "c1");
        assert!(results.iter().all(|r| r.chunk_id != "c3"));
    }

    #[test]
    fn test_search_empty_query() {
        let conn = setup();
        insert_chunk(&conn, "d1", "c1", "some content", 0);
        let results = search(&conn, "", 10).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_no_matches() {
        let conn = setup();
        insert_chunk(&conn, "d1", "c1", "SIMD vector operations", 0);
        let results = search(&conn, "cryptocurrency blockchain", 10).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_respects_limit() {
        let conn = setup();
        insert_chunk(&conn, "d1", "c1", "alpha one", 0);
        insert_chunk(&conn, "d1", "c2", "alpha two", 1);
        insert_chunk(&conn, "d1", "c3", "alpha three", 2);

        // All three chunks match, so only `k` can be what bounds the result.
        let limited = search(&conn, "alpha", 2).unwrap();
        assert_eq!(limited.len(), 2);

        let none = search(&conn, "alpha", 0).unwrap();
        assert!(none.is_empty());
    }

    #[test]
    fn test_escape_fts5_query_special_chars() {
        // Pin the exact output so token order and separators are checked too,
        // not just the presence of each quoted term.
        let escaped = escape_fts5_query("hello AND world");
        assert_eq!(escaped, "\"hello\" \"AND\" \"world\"");
    }

    #[test]
    fn test_escape_fts5_query_strips_quotes_and_blank_input() {
        assert_eq!(escape_fts5_query("\"quoted\" te\"rm"), "\"quoted\" \"term\"");
        // Tokens made only of quotes disappear entirely.
        assert_eq!(escape_fts5_query("\" \"\" x"), "\"x\"");
        assert_eq!(escape_fts5_query(""), "");
        assert_eq!(escape_fts5_query("   \t\n "), "");
    }

    #[test]
    fn test_porter_stemming() {
        let conn = setup();
        insert_chunk(&conn, "d1", "c1", "tokenizer tokenization tokenizing", 0);
        let results = search(&conn, "tokenize", 10).unwrap();
        assert!(!results.is_empty(), "Porter stemmer should conflate 'tokenize' variants");
    }

    #[test]
    fn test_optimize_does_not_error() {
        let conn = setup();
        insert_chunk(&conn, "d1", "c1", "some content", 0);
        optimize(&conn).unwrap();
    }

    #[test]
    fn test_scores_are_positive() {
        let conn = setup();
        insert_chunk(&conn, "d1", "c1", "machine learning algorithms", 0);
        let results = search(&conn, "machine learning", 10).unwrap();
        assert!(!results.is_empty());
        for r in &results {
            assert!(r.score > 0.0, "Negated BM25 scores should be positive");
        }
    }
}