Skip to main content

sqz_engine/
advanced_search.rs

1use rusqlite::{params, Connection};
2
3use crate::error::Result;
4
/// One hit returned by [`AdvancedSearch::search`].
#[derive(Clone, Debug)]
pub struct SearchResult {
    /// Identifier of the matching document, as it was passed to `index`.
    pub id: String,
    /// Fused relevance score; results are ordered from highest to lowest.
    pub score: f64,
    /// Excerpt of the document content around the first matched query term.
    pub snippet: String,
}
12
13/// Advanced search engine combining BM25 (porter stemming) and trigram
14/// substring search, merged via Reciprocal Rank Fusion.  Backed by an
15/// in-memory SQLite database with two FTS5 virtual tables.
16pub struct AdvancedSearch {
17    db: Connection,
18}
19
20// ── Schema ────────────────────────────────────────────────────────────────────
21
/// DDL executed once when the engine is created.
///
/// Defines the `docs` base table, two external-content FTS5 tables over it
/// (`docs_porter` and `docs_trigram`), and the insert/delete/update triggers
/// that mirror every change of `docs` into both indexes.
const SCHEMA: &str = r#"
CREATE TABLE IF NOT EXISTS docs (
    id      TEXT PRIMARY KEY,
    content TEXT NOT NULL
);

CREATE VIRTUAL TABLE IF NOT EXISTS docs_porter USING fts5(
    id,
    content,
    content='docs',
    content_rowid='rowid',
    tokenize='porter ascii'
);

CREATE VIRTUAL TABLE IF NOT EXISTS docs_trigram USING fts5(
    id,
    content,
    content='docs',
    content_rowid='rowid',
    tokenize='trigram'
);

-- Keep FTS tables in sync with the docs table.
CREATE TRIGGER IF NOT EXISTS docs_ai AFTER INSERT ON docs BEGIN
    INSERT INTO docs_porter(rowid, id, content)
    VALUES (new.rowid, new.id, new.content);
    INSERT INTO docs_trigram(rowid, id, content)
    VALUES (new.rowid, new.id, new.content);
END;

CREATE TRIGGER IF NOT EXISTS docs_ad AFTER DELETE ON docs BEGIN
    INSERT INTO docs_porter(docs_porter, rowid, id, content)
    VALUES ('delete', old.rowid, old.id, old.content);
    INSERT INTO docs_trigram(docs_trigram, rowid, id, content)
    VALUES ('delete', old.rowid, old.id, old.content);
END;

CREATE TRIGGER IF NOT EXISTS docs_au AFTER UPDATE ON docs BEGIN
    INSERT INTO docs_porter(docs_porter, rowid, id, content)
    VALUES ('delete', old.rowid, old.id, old.content);
    INSERT INTO docs_trigram(docs_trigram, rowid, id, content)
    VALUES ('delete', old.rowid, old.id, old.content);
    INSERT INTO docs_porter(rowid, id, content)
    VALUES (new.rowid, new.id, new.content);
    INSERT INTO docs_trigram(rowid, id, content)
    VALUES (new.rowid, new.id, new.content);
END;
"#;
70
71// ── RRF constant ──────────────────────────────────────────────────────────────
72
/// Smoothing constant `k` of Reciprocal Rank Fusion, where each list
/// contributes `1 / (k + rank)` with a 1-based rank.  60 is the standard
/// value from the original paper by Cormack, Clarke & Buettcher.
const RRF_K: f64 = 60.0;
76
77// ── Helpers ───────────────────────────────────────────────────────────────────
78
/// Levenshtein edit distance between `a` and `b`, counted in `char`s.
///
/// Insertions, deletions and substitutions each cost 1.  Only a single DP
/// row is kept, so memory use is proportional to the length of `b`.
fn levenshtein(a: &str, b: &str) -> usize {
    let target: Vec<char> = b.chars().collect();

    // row[j] holds the distance between the prefix of `a` consumed so far
    // and the first `j` chars of `b`.
    let mut row: Vec<usize> = (0..=target.len()).collect();

    for (consumed, source_char) in a.chars().enumerate() {
        // Value of the cell up-and-left of the one being computed.
        let mut diagonal = row[0];
        row[0] = consumed + 1;

        for (j, &target_char) in target.iter().enumerate() {
            let above = row[j + 1];
            let substitution = diagonal + usize::from(source_char != target_char);
            let deletion = above + 1;
            let insertion = row[j] + 1;
            row[j + 1] = deletion.min(insertion).min(substitution);
            diagonal = above;
        }
    }

    row[target.len()]
}
101
/// Translate a byte offset into `content.to_lowercase()` back to the byte
/// offset of the character of `content` that produced it.
///
/// Lowercasing can change a character's encoded length, so offsets in the
/// lowercased copy are not valid offsets into the original text.
fn original_offset(content: &str, lowered_offset: usize) -> usize {
    let mut lowered_end = 0usize;
    for (idx, ch) in content.char_indices() {
        lowered_end += ch.to_lowercase().map(char::len_utf8).sum::<usize>();
        if lowered_end > lowered_offset {
            return idx;
        }
    }
    content.len()
}

/// Extract a snippet: up to `window` bytes of text on each side of the
/// earliest case-insensitive occurrence of any query term.
///
/// * `content` – the full document text.
/// * `query_terms` – terms to look for; matching ignores case.
/// * `window` – number of bytes kept before and after the match position.
///
/// When no term occurs in `content`, the window is anchored at the start of
/// the text.  Cut points are widened to the nearest `char` boundary, and a
/// `…` marker is added on each side that was truncated.
fn extract_snippet(content: &str, query_terms: &[&str], window: usize) -> String {
    let lower = content.to_lowercase();

    // Earliest match in the lowercased copy, mapped back to an offset in the
    // original text (the two only coincide when lowercasing keeps lengths).
    let pos = query_terms
        .iter()
        .filter_map(|term| lower.find(&term.to_lowercase()))
        .min()
        .map_or(0, |lowered| original_offset(content, lowered));

    // Snap outwards to char boundaries so slicing cannot split a code point.
    let mut start = pos.saturating_sub(window);
    while !content.is_char_boundary(start) {
        start -= 1;
    }
    let mut end = pos.saturating_add(window).min(content.len());
    while !content.is_char_boundary(end) {
        end += 1;
    }

    let mut snippet = String::with_capacity(end - start + 2 * '…'.len_utf8());
    if start > 0 {
        snippet.push('…');
    }
    snippet.push_str(&content[start..end]);
    if end < content.len() {
        snippet.push('…');
    }
    snippet
}
139
140// ── AdvancedSearch impl ───────────────────────────────────────────────────────
141
142impl AdvancedSearch {
143    /// Create a new `AdvancedSearch` backed by an in-memory SQLite database.
144    pub fn new() -> Result<Self> {
145        let db = Connection::open_in_memory()?;
146        db.execute_batch(SCHEMA)?;
147        Ok(Self { db })
148    }
149
150    /// Index a document.  If a document with the same `id` already exists it
151    /// is replaced.
152    pub fn index(&self, id: &str, content: &str) -> Result<()> {
153        self.db.execute(
154            "INSERT INTO docs (id, content) VALUES (?1, ?2)
155             ON CONFLICT(id) DO UPDATE SET content = excluded.content",
156            params![id, content],
157        )?;
158        Ok(())
159    }
160
161    /// Run an advanced search combining BM25, trigram, RRF, fuzzy correction,
162    /// proximity reranking, and smart snippet extraction.
163    pub fn search(&self, query: &str) -> Result<Vec<SearchResult>> {
164        let query = query.trim();
165        if query.is_empty() {
166            return Ok(Vec::new());
167        }
168
169        let terms: Vec<&str> = query.split_whitespace().collect();
170
171        // 1. BM25 search (porter stemming).
172        let bm25 = self.bm25_search(query);
173
174        // 2. Trigram substring search.
175        let trigram = self.trigram_search(query);
176
177        // 3. Merge via Reciprocal Rank Fusion.
178        let mut results = self.reciprocal_rank_fusion(&bm25, &trigram);
179
180        // 4. If no results, try fuzzy correction (Levenshtein ≤ 2).
181        if results.is_empty() {
182            if let Some(corrected) = self.fuzzy_correct(query) {
183                let bm25_c = self.bm25_search(&corrected);
184                let trigram_c = self.trigram_search(&corrected);
185                results = self.reciprocal_rank_fusion(&bm25_c, &trigram_c);
186            }
187        }
188
189        // 5. Proximity reranking for multi-term queries.
190        if terms.len() > 1 {
191            self.proximity_rerank(&mut results, &terms);
192        }
193
194        // 6. Smart snippet extraction.
195        for r in &mut results {
196            if let Ok(content) = self.get_content(&r.id) {
197                r.snippet = extract_snippet(&content, &terms, 80);
198            }
199        }
200
201        Ok(results)
202    }
203
204    // ── Internal helpers ──────────────────────────────────────────────────────
205
206    fn get_content(&self, id: &str) -> Result<String> {
207        let content: String = self.db.query_row(
208            "SELECT content FROM docs WHERE id = ?1",
209            params![id],
210            |row| row.get(0),
211        )?;
212        Ok(content)
213    }
214
215    /// BM25 search on the porter-stemmed FTS5 table.
216    /// Returns `(bm25_score, doc_id)` pairs ordered by relevance.
217    fn bm25_search(&self, query: &str) -> Vec<(f64, String)> {
218        let mut stmt = match self.db.prepare(
219            "SELECT d.id, bm25(docs_porter) AS score
220             FROM docs_porter p
221             JOIN docs d ON d.rowid = p.rowid
222             WHERE docs_porter MATCH ?1
223             ORDER BY score",
224        ) {
225            Ok(s) => s,
226            Err(_) => return Vec::new(),
227        };
228
229        let rows = match stmt.query_map(params![query], |row| {
230            Ok((row.get::<_, String>(0)?, row.get::<_, f64>(1)?))
231        }) {
232            Ok(r) => r,
233            Err(_) => return Vec::new(),
234        };
235
236        rows.filter_map(|r| r.ok())
237            .map(|(id, score)| (score, id))
238            .collect()
239    }
240
241    /// Trigram substring search on the trigram FTS5 table.
242    /// Returns `(bm25_score, doc_id)` pairs ordered by relevance.
243    fn trigram_search(&self, query: &str) -> Vec<(f64, String)> {
244        // Trigram tokenizer requires the query to be at least 3 chars.
245        if query.len() < 3 {
246            return Vec::new();
247        }
248        let mut stmt = match self.db.prepare(
249            "SELECT d.id, bm25(docs_trigram) AS score
250             FROM docs_trigram t
251             JOIN docs d ON d.rowid = t.rowid
252             WHERE docs_trigram MATCH ?1
253             ORDER BY score",
254        ) {
255            Ok(s) => s,
256            Err(_) => return Vec::new(),
257        };
258
259        let rows = match stmt.query_map(params![query], |row| {
260            Ok((row.get::<_, String>(0)?, row.get::<_, f64>(1)?))
261        }) {
262            Ok(r) => r,
263            Err(_) => return Vec::new(),
264        };
265
266        rows.filter_map(|r| r.ok())
267            .map(|(id, score)| (score, id))
268            .collect()
269    }
270
271    /// Merge two ranked lists using Reciprocal Rank Fusion.
272    ///
273    /// Documents appearing in both lists receive a higher fused score than
274    /// documents appearing in only one.
275    fn reciprocal_rank_fusion(
276        &self,
277        a: &[(f64, String)],
278        b: &[(f64, String)],
279    ) -> Vec<SearchResult> {
280        use std::collections::HashMap;
281
282        let mut scores: HashMap<String, f64> = HashMap::new();
283
284        for (rank, (_score, id)) in a.iter().enumerate() {
285            *scores.entry(id.clone()).or_default() += 1.0 / (RRF_K + rank as f64 + 1.0);
286        }
287        for (rank, (_score, id)) in b.iter().enumerate() {
288            *scores.entry(id.clone()).or_default() += 1.0 / (RRF_K + rank as f64 + 1.0);
289        }
290
291        let mut results: Vec<SearchResult> = scores
292            .into_iter()
293            .map(|(id, score)| SearchResult {
294                id,
295                score,
296                snippet: String::new(),
297            })
298            .collect();
299
300        // Sort descending by fused score.
301        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
302        results
303    }
304
305    /// Attempt fuzzy correction for the query.  Collects all known terms from
306    /// the docs table and finds the closest match within Levenshtein distance 2.
307    fn fuzzy_correct(&self, query: &str) -> Option<String> {
308        // Build a vocabulary of unique words from indexed documents.
309        let vocab = self.vocabulary();
310        if vocab.is_empty() {
311            return None;
312        }
313
314        let terms: Vec<&str> = query.split_whitespace().collect();
315        let mut corrected_terms: Vec<String> = Vec::new();
316        let mut any_corrected = false;
317
318        for term in &terms {
319            let lower = term.to_lowercase();
320            let mut best: Option<(usize, String)> = None;
321            for word in &vocab {
322                let dist = levenshtein(&lower, word);
323                if dist > 0 && dist <= 2 {
324                    if best.as_ref().map_or(true, |(d, _)| dist < *d) {
325                        best = Some((dist, word.clone()));
326                    }
327                }
328            }
329            if let Some((_dist, correction)) = best {
330                corrected_terms.push(correction);
331                any_corrected = true;
332            } else {
333                corrected_terms.push(lower);
334            }
335        }
336
337        if any_corrected {
338            Some(corrected_terms.join(" "))
339        } else {
340            None
341        }
342    }
343
344    /// Collect unique lowercase words from all indexed documents.
345    fn vocabulary(&self) -> Vec<String> {
346        let mut stmt = match self.db.prepare("SELECT content FROM docs") {
347            Ok(s) => s,
348            Err(_) => return Vec::new(),
349        };
350        let rows = match stmt.query_map([], |row| row.get::<_, String>(0)) {
351            Ok(r) => r,
352            Err(_) => return Vec::new(),
353        };
354
355        let mut words = std::collections::HashSet::new();
356        for row in rows.flatten() {
357            for word in row.split_whitespace() {
358                let w: String = word
359                    .chars()
360                    .filter(|c| c.is_alphanumeric())
361                    .collect::<String>()
362                    .to_lowercase();
363                if w.len() >= 2 {
364                    words.insert(w);
365                }
366            }
367        }
368        words.into_iter().collect()
369    }
370
371    /// Proximity reranking: boost results where query terms appear close
372    /// together in the document content.
373    fn proximity_rerank(&self, results: &mut Vec<SearchResult>, query_terms: &[&str]) {
374        for r in results.iter_mut() {
375            let content = match self.get_content(&r.id) {
376                Ok(c) => c,
377                Err(_) => continue,
378            };
379            let lower = content.to_lowercase();
380
381            // Find positions of each query term.
382            let mut positions: Vec<usize> = Vec::new();
383            for term in query_terms {
384                if let Some(pos) = lower.find(&term.to_lowercase()) {
385                    positions.push(pos);
386                }
387            }
388
389            if positions.len() >= 2 {
390                positions.sort_unstable();
391                // Compute the span (distance between first and last term).
392                let span = positions.last().unwrap() - positions.first().unwrap();
393                // Boost: closer terms → higher boost.  A span of 0 gives max
394                // boost of 2×; very distant terms give ~1× (no boost).
395                let boost = 1.0 + 1.0 / (1.0 + span as f64 / 50.0);
396                r.score *= boost;
397            }
398        }
399
400        // Re-sort after boosting.
401        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
402    }
403}
404
405// ── Tests ─────────────────────────────────────────────────────────────────────
406
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh, empty engine for a single test.
    fn make_search() -> AdvancedSearch {
        AdvancedSearch::new().unwrap()
    }

    #[test]
    fn test_index_and_bm25_search() {
        let s = make_search();
        s.index("d1", "the quick brown fox jumps over the lazy dog").unwrap();
        s.index("d2", "a fast red car drives on the highway").unwrap();

        let results = s.bm25_search("fox");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1, "d1");
    }

    #[test]
    fn test_trigram_search() {
        let s = make_search();
        s.index("d1", "authentication middleware handles tokens").unwrap();
        s.index("d2", "database migration scripts for postgres").unwrap();

        let results = s.trigram_search("auth");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1, "d1");
    }

    #[test]
    fn test_rrf_merge_both_lists() {
        let s = make_search();
        s.index("d1", "rust programming language systems").unwrap();
        s.index("d2", "rust prevention coating for metal surfaces").unwrap();
        s.index("d3", "programming in python is fun").unwrap();

        // "rust programming" should match d1 in both porter and trigram,
        // giving it a higher RRF score than d2 or d3.
        let results = s.search("rust programming").unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "d1");
    }

    #[test]
    fn test_rrf_docs_in_both_rank_higher() {
        let s = make_search();
        // d1 contains both "alpha" and "beta"; d2 contains only "alpha".
        s.index("d1", "alpha beta gamma delta").unwrap();
        s.index("d2", "alpha only here nothing else relevant").unwrap();

        // The BM25 list holds both documents, the trigram list only d1, so
        // d1 is the one document present in both rankings.
        let bm25 = s.bm25_search("alpha");
        let trigram = s.trigram_search("beta");
        assert_eq!(bm25.len(), 2);
        assert_eq!(trigram.len(), 1);
        assert_eq!(trigram[0].1, "d1");

        let merged = s.reciprocal_rank_fusion(&bm25, &trigram);
        assert_eq!(merged.len(), 2);

        // The document appearing in both lists must come first with a
        // strictly higher fused score.
        assert_eq!(merged[0].id, "d1");
        assert_eq!(merged[1].id, "d2");
        assert!(merged[0].score > merged[1].score);
    }

    #[test]
    fn test_fuzzy_correction() {
        let s = make_search();
        s.index("d1", "authentication middleware").unwrap();
        s.index("d2", "database migration").unwrap();

        // "authentcation" is a typo (missing 'i'), Levenshtein distance 1.
        let corrected = s.fuzzy_correct("authentcation");
        assert!(corrected.is_some());
        let c = corrected.unwrap();
        assert!(c.contains("authentication"), "corrected to: {}", c);
    }

    #[test]
    fn test_fuzzy_search_end_to_end() {
        let s = make_search();
        s.index("d1", "authentication middleware handles tokens").unwrap();

        // Typo query — should still find d1 via fuzzy correction.
        let results = s.search("authentcation").unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "d1");
    }

    #[test]
    fn test_proximity_reranking() {
        let s = make_search();
        // d1: terms "error" and "handler" are close together.
        s.index("d1", "the error handler catches all exceptions").unwrap();
        // d2: terms "error" and "handler" are far apart.
        s.index(
            "d2",
            "an error occurred in the system and after many lines of unrelated text the handler was invoked",
        ).unwrap();

        let results = s.search("error handler").unwrap();
        assert!(results.len() >= 2);
        // d1 should rank higher due to proximity boost.
        assert_eq!(results[0].id, "d1");
    }

    #[test]
    fn test_smart_snippet_extraction() {
        let content = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. \
                        The authentication module verifies JWT tokens. \
                        Sed do eiusmod tempor incididunt ut labore.";
        let snippet = extract_snippet(content, &["authentication"], 40);
        assert!(snippet.contains("authentication"));
        // Should have ellipsis since it's in the middle.
        assert!(snippet.contains("…"));
    }

    #[test]
    fn test_empty_query_returns_empty() {
        let s = make_search();
        s.index("d1", "some content").unwrap();
        let results = s.search("").unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_no_results_returns_empty() {
        let s = make_search();
        s.index("d1", "hello world").unwrap();
        let results = s.search("zzzznonexistent").unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_index_upsert() {
        let s = make_search();
        s.index("d1", "original content about cats").unwrap();
        s.index("d1", "updated content about dogs").unwrap();

        let results = s.search("dogs").unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "d1");

        let results = s.search("cats").unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_levenshtein_distance() {
        assert_eq!(levenshtein("kitten", "sitting"), 3);
        assert_eq!(levenshtein("hello", "hello"), 0);
        assert_eq!(levenshtein("hello", "helo"), 1);
        assert_eq!(levenshtein("", "abc"), 3);
        assert_eq!(levenshtein("abc", ""), 3);
    }

    #[test]
    fn test_snippet_at_start() {
        let content = "authentication is important for security";
        let snippet = extract_snippet(content, &["authentication"], 80);
        assert!(snippet.contains("authentication"));
        // Should not have leading ellipsis since match is at start.
        assert!(!snippet.starts_with('…'));
    }

    #[test]
    fn test_multiple_documents_search() {
        let s = make_search();
        for i in 0..10 {
            s.index(&format!("d{}", i), &format!("document number {} about testing", i))
                .unwrap();
        }
        let results = s.search("testing").unwrap();
        assert_eq!(results.len(), 10);
    }

    mod prop_tests {
        use super::*;
        use proptest::prelude::*;
        use std::collections::{HashMap, HashSet};

        // **Validates: Requirements 41.1, 41.2**
        //
        // Property 41: Advanced search RRF merges correctly
        //
        // For any two ranked lists, the Reciprocal Rank Fusion merge SHALL:
        // 1. Produce a result set containing all unique documents from both inputs.
        // 2. Assign a higher RRF score to documents appearing in both lists than
        //    to documents appearing in only one list (at comparable rank positions).
        proptest! {
            #[test]
            fn prop_rrf_merge_contains_all_unique_docs_and_both_rank_higher(
                // Generate 2-6 unique doc IDs for list A, 2-6 for list B,
                // with at least some overlap guaranteed by construction.
                shared_count in 1..4usize,
                a_only_count in 1..4usize,
                b_only_count in 1..4usize,
            ) {
                let s = make_search();

                // Build document sets: shared docs appear in both lists,
                // a_only docs appear only in list A, b_only only in list B.
                let mut list_a: Vec<(f64, String)> = Vec::new();
                let mut list_b: Vec<(f64, String)> = Vec::new();

                // Shared documents — present in both lists.
                for i in 0..shared_count {
                    let id = format!("shared_{}", i);
                    // Use descending scores so rank = index.
                    list_a.push((-(i as f64), id.clone()));
                    list_b.push((-(i as f64), id));
                }

                // A-only documents.
                for i in 0..a_only_count {
                    let id = format!("a_only_{}", i);
                    list_a.push((-((shared_count + i) as f64), id));
                }

                // B-only documents.
                for i in 0..b_only_count {
                    let id = format!("b_only_{}", i);
                    list_b.push((-((shared_count + i) as f64), id));
                }

                let merged = s.reciprocal_rank_fusion(&list_a, &list_b);

                // ── Property 1: merged contains all unique docs from both lists ──
                let all_ids: HashSet<String> = list_a.iter().map(|(_, id)| id.clone())
                    .chain(list_b.iter().map(|(_, id)| id.clone()))
                    .collect();
                let merged_ids: HashSet<String> = merged.iter().map(|r| r.id.clone()).collect();
                prop_assert_eq!(
                    merged_ids, all_ids,
                    "Merged result set must contain all unique documents from both input lists"
                );

                // ── Property 2: docs in both lists score higher than docs in only one ──
                let a_ids: HashSet<String> = list_a.iter().map(|(_, id)| id.clone()).collect();
                let b_ids: HashSet<String> = list_b.iter().map(|(_, id)| id.clone()).collect();
                let in_both: HashSet<String> = a_ids.intersection(&b_ids).cloned().collect();
                let in_one_only: HashSet<String> = a_ids.symmetric_difference(&b_ids).cloned().collect();

                if !in_both.is_empty() && !in_one_only.is_empty() {
                    let scores: HashMap<String, f64> = merged.iter()
                        .map(|r| (r.id.clone(), r.score))
                        .collect();

                    let min_both_score = in_both.iter()
                        .filter_map(|id| scores.get(id))
                        .cloned()
                        .fold(f64::INFINITY, f64::min);

                    let max_one_score = in_one_only.iter()
                        .filter_map(|id| scores.get(id))
                        .cloned()
                        .fold(f64::NEG_INFINITY, f64::max);

                    prop_assert!(
                        min_both_score > max_one_score,
                        "Documents in both lists (min score {}) must score higher \
                         than documents in only one list (max score {})",
                        min_both_score, max_one_score
                    );
                }
            }
        }
    }
}