Skip to main content

trueno_rag/sqlite/
fts.rs

1//! FTS5 full-text search with BM25 ranking.
2//!
3//! Implements search over the `chunks_fts` virtual table using SQLite's
4//! built-in `bm25()` ranking function (Robertson & Zaragoza, 2009).
5
6use crate::Result;
7use rusqlite::Connection;
8
/// One hit produced by [`search`]: the matched chunk plus its relevance score.
#[derive(Debug, Clone)]
pub struct FtsResult {
    /// Identifier of the matching row in the `chunks` table.
    pub chunk_id: String,
    /// Identifier of the document that owns this chunk.
    pub doc_id: String,
    /// Full text stored for the chunk.
    pub content: String,
    /// Relevance score where a larger value means a better match.
    ///
    /// SQLite's `bm25()` reports relevance with the opposite sign (lower is
    /// more relevant), so this field holds the negated `bm25()` value.
    pub score: f64,
    /// Position of the chunk inside its parent document.
    pub position: i64,
}
24
/// Escape user input for FTS5 `MATCH` syntax.
///
/// FTS5 gives special meaning to bare keywords and punctuation in a query
/// (`AND`, `OR`, `NOT`, `NEAR`, `*`, `^`, `:`, `-`, parentheses, ...). Wrapping
/// every whitespace-separated token in double quotes turns it into an FTS5
/// string, so it is matched as literal text instead of parsed as an operator.
///
/// Any `"` characters inside a token are removed first, because they would
/// otherwise terminate the quoted string early. Tokens made up only of quotes
/// are dropped. The quoted tokens are joined with single spaces, which FTS5
/// treats as an implicit `AND` between them.
///
/// Returns an empty string when no usable token remains (empty,
/// whitespace-only, or quote-only input).
fn escape_fts5_query(query: &str) -> String {
    query
        // `split_whitespace` never yields empty tokens, so no extra filter is needed.
        .split_whitespace()
        .filter_map(|token| {
            let clean = token.replace('"', "");
            // Skip tokens that were nothing but quotes.
            (!clean.is_empty()).then(|| format!("\"{clean}\""))
        })
        .collect::<Vec<_>>()
        .join(" ")
}
45
46/// Search the FTS5 index with BM25 ranking.
47///
48/// Returns up to `k` results ordered by descending relevance.
49/// SQLite's `bm25()` returns negative scores where more negative = more
50/// relevant, so we negate to produce positive scores where higher = better.
51pub fn search(conn: &Connection, query: &str, k: usize) -> Result<Vec<FtsResult>> {
52    let escaped = escape_fts5_query(query);
53    if escaped.is_empty() {
54        return Ok(Vec::new());
55    }
56
57    let mut stmt = conn
58        .prepare_cached(
59            "SELECT c.id, c.doc_id, c.content, -bm25(chunks_fts) AS score, c.position
60             FROM chunks_fts
61             JOIN chunks c ON chunks_fts.rowid = c.rowid
62             WHERE chunks_fts MATCH ?1
63             ORDER BY score DESC
64             LIMIT ?2",
65        )
66        .map_err(|e| crate::Error::Query(format!("Failed to prepare FTS5 search: {e}")))?;
67
68    let results = stmt
69        .query_map(rusqlite::params![escaped, k as i64], |row| {
70            Ok(FtsResult {
71                chunk_id: row.get(0)?,
72                doc_id: row.get(1)?,
73                content: row.get(2)?,
74                score: row.get(3)?,
75                position: row.get(4)?,
76            })
77        })
78        .map_err(|e| crate::Error::Query(format!("FTS5 search failed: {e}")))?
79        .collect::<std::result::Result<Vec<_>, _>>()
80        .map_err(|e| crate::Error::Query(format!("FTS5 result mapping failed: {e}")))?;
81
82    Ok(results)
83}
84
85/// Optimize the FTS5 index by merging segments.
86///
87/// Should be called after large batch inserts, not on every query.
88pub fn optimize(conn: &Connection) -> Result<()> {
89    conn.execute("INSERT INTO chunks_fts(chunks_fts) VALUES('optimize')", [])
90        .map_err(|e| crate::Error::Query(format!("FTS5 optimize failed: {e}")))?;
91    Ok(())
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::schema;

    /// Open an in-memory database with the project schema applied.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        schema::initialize(&conn).unwrap();
        conn
    }

    /// Insert a chunk, creating its parent document row first if it is missing.
    ///
    /// NOTE(review): these tests write only to `chunks`; they assume
    /// `schema::initialize` keeps `chunks_fts` in sync (e.g. via triggers) —
    /// confirm against the schema module.
    fn insert_chunk(conn: &Connection, doc_id: &str, chunk_id: &str, content: &str, pos: i64) {
        // Ensure document exists
        conn.execute("INSERT OR IGNORE INTO documents (id, content) VALUES (?1, '')", [doc_id])
            .unwrap();
        conn.execute(
            "INSERT INTO chunks (id, doc_id, content, position) VALUES (?1, ?2, ?3, ?4)",
            rusqlite::params![chunk_id, doc_id, content, pos],
        )
        .unwrap();
    }

    #[test]
    fn test_search_returns_results() {
        let conn = setup();
        insert_chunk(&conn, "doc1", "c1", "SIMD vector operations for tensor math", 0);
        insert_chunk(&conn, "doc1", "c2", "GPU kernel dispatch and scheduling", 1);

        let results = search(&conn, "SIMD tensor", 10).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].chunk_id, "c1");
    }

    #[test]
    fn test_search_maps_all_fields() {
        let conn = setup();
        insert_chunk(&conn, "doc7", "c9", "SIMD vector operations", 4);

        let results = search(&conn, "SIMD", 10).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].chunk_id, "c9");
        assert_eq!(results[0].doc_id, "doc7");
        assert_eq!(results[0].content, "SIMD vector operations");
        assert_eq!(results[0].position, 4);
    }

    #[test]
    fn test_search_bm25_ordering() {
        let conn = setup();
        insert_chunk(&conn, "d1", "c1", "machine learning algorithms", 0);
        insert_chunk(
            &conn,
            "d2",
            "c2",
            "machine learning machine learning machine learning deep learning",
            0,
        );
        insert_chunk(&conn, "d3", "c3", "cooking recipes for dinner", 0);

        let results = search(&conn, "machine learning", 10).unwrap();
        // c2 has higher TF for "machine learning" so should rank higher
        assert!(results.len() >= 2);
        assert_eq!(results[0].chunk_id, "c2");
        assert_eq!(results[1].chunk_id, "c1");
        // "cooking recipes" should not match
        assert!(results.iter().all(|r| r.chunk_id != "c3"));
    }

    #[test]
    fn test_search_respects_limit() {
        let conn = setup();
        insert_chunk(&conn, "d1", "c1", "alpha beta", 0);
        insert_chunk(&conn, "d1", "c2", "alpha gamma", 1);
        insert_chunk(&conn, "d1", "c3", "alpha delta", 2);

        assert_eq!(search(&conn, "alpha", 2).unwrap().len(), 2);
        assert_eq!(search(&conn, "alpha", 10).unwrap().len(), 3);
        assert!(search(&conn, "alpha", 0).unwrap().is_empty());
    }

    #[test]
    fn test_search_empty_query() {
        let conn = setup();
        insert_chunk(&conn, "d1", "c1", "some content", 0);
        let results = search(&conn, "", 10).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_quote_only_query() {
        let conn = setup();
        insert_chunk(&conn, "d1", "c1", "some content", 0);
        let results = search(&conn, "\"\" \"", 10).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_no_matches() {
        let conn = setup();
        insert_chunk(&conn, "d1", "c1", "SIMD vector operations", 0);
        let results = search(&conn, "cryptocurrency blockchain", 10).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_operator_syntax_is_literal() {
        let conn = setup();
        insert_chunk(&conn, "d1", "c1", "SIMD vector operations", 0);
        insert_chunk(&conn, "d1", "c2", "rock and roll", 1);

        // `-vector` would be FTS5 syntax if passed raw; escaped it is the term "vector".
        let results = search(&conn, "SIMD -vector", 10).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].chunk_id, "c1");

        // A bare uppercase `AND` is an FTS5 operator; escaped it matches the word "and".
        let results = search(&conn, "AND", 10).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].chunk_id, "c2");
    }

    #[test]
    fn test_escape_fts5_query_special_chars() {
        // Should wrap tokens in quotes to prevent FTS5 operator interpretation
        assert_eq!(escape_fts5_query("hello AND world"), "\"hello\" \"AND\" \"world\"");
        // Embedded quotes are stripped; quote-only tokens disappear.
        assert_eq!(escape_fts5_query("a\"b \"\" c*"), "\"ab\" \"c*\"");
        assert_eq!(escape_fts5_query("   "), "");
    }

    #[test]
    fn test_porter_stemming() {
        let conn = setup();
        insert_chunk(&conn, "d1", "c1", "tokenizer tokenization tokenizing", 0);

        // "tokenize" should match via Porter stemming conflation
        let results = search(&conn, "tokenize", 10).unwrap();
        assert!(!results.is_empty(), "Porter stemmer should conflate 'tokenize' variants");
    }

    #[test]
    fn test_optimize_does_not_error() {
        let conn = setup();
        insert_chunk(&conn, "d1", "c1", "some content", 0);
        optimize(&conn).unwrap();
    }

    #[test]
    fn test_scores_are_positive() {
        let conn = setup();
        insert_chunk(&conn, "d1", "c1", "machine learning algorithms", 0);
        let results = search(&conn, "machine learning", 10).unwrap();
        assert!(!results.is_empty());
        for r in &results {
            assert!(r.score > 0.0, "Negated BM25 scores should be positive");
        }
    }
}