1use rusqlite::{params, Connection};
2
3use crate::error::Result;
4
/// A single ranked hit produced by the search pipeline.
#[derive(Debug, Clone)]
pub struct SearchResult {
    // Document identifier, as stored in the `docs.id` column.
    pub id: String,
    // Fused relevance score (reciprocal-rank fusion, possibly proximity-boosted);
    // higher is better.
    pub score: f64,
    // Context window around the first matched query term; empty until filled
    // in by `search`.
    pub snippet: String,
}
12
/// Hybrid full-text search engine combining BM25 (porter-stemmed) and
/// trigram (substring) FTS5 retrieval over a single `docs` table, with
/// rank fusion, fuzzy query correction, proximity re-ranking, and
/// snippet extraction.
pub struct AdvancedSearch {
    // In-memory SQLite connection owning all tables and FTS indexes.
    db: Connection,
}
19
/// Database schema: the source-of-truth `docs` table plus two external-content
/// FTS5 indexes over it — `docs_porter` (porter-stemmed, for BM25 keyword
/// search) and `docs_trigram` (trigram tokenizer, for substring matching).
/// Because external-content FTS tables do not update themselves, AFTER
/// INSERT/DELETE/UPDATE triggers mirror every change into both indexes,
/// using the FTS5 `'delete'` command for removals.
const SCHEMA: &str = r#"
CREATE TABLE IF NOT EXISTS docs (
 id TEXT PRIMARY KEY,
 content TEXT NOT NULL
);

CREATE VIRTUAL TABLE IF NOT EXISTS docs_porter USING fts5(
 id,
 content,
 content='docs',
 content_rowid='rowid',
 tokenize='porter ascii'
);

CREATE VIRTUAL TABLE IF NOT EXISTS docs_trigram USING fts5(
 id,
 content,
 content='docs',
 content_rowid='rowid',
 tokenize='trigram'
);

-- Keep FTS tables in sync with the docs table.
CREATE TRIGGER IF NOT EXISTS docs_ai AFTER INSERT ON docs BEGIN
 INSERT INTO docs_porter(rowid, id, content)
 VALUES (new.rowid, new.id, new.content);
 INSERT INTO docs_trigram(rowid, id, content)
 VALUES (new.rowid, new.id, new.content);
END;

CREATE TRIGGER IF NOT EXISTS docs_ad AFTER DELETE ON docs BEGIN
 INSERT INTO docs_porter(docs_porter, rowid, id, content)
 VALUES ('delete', old.rowid, old.id, old.content);
 INSERT INTO docs_trigram(docs_trigram, rowid, id, content)
 VALUES ('delete', old.rowid, old.id, old.content);
END;

CREATE TRIGGER IF NOT EXISTS docs_au AFTER UPDATE ON docs BEGIN
 INSERT INTO docs_porter(docs_porter, rowid, id, content)
 VALUES ('delete', old.rowid, old.id, old.content);
 INSERT INTO docs_trigram(docs_trigram, rowid, id, content)
 VALUES ('delete', old.rowid, old.id, old.content);
 INSERT INTO docs_porter(rowid, id, content)
 VALUES (new.rowid, new.id, new.content);
 INSERT INTO docs_trigram(rowid, id, content)
 VALUES (new.rowid, new.id, new.content);
END;
"#;
70
/// Smoothing constant `k` for reciprocal-rank fusion: each result list
/// contributes `1 / (k + rank + 1)` per document. 60 is the value
/// conventionally used in the RRF literature (Cormack et al., 2009).
const RRF_K: f64 = 60.0;
76
/// Levenshtein edit distance between `a` and `b`, counted in Unicode
/// scalar values (chars), using a single-row dynamic-programming table
/// so memory is O(len(b)) rather than O(len(a) * len(b)).
fn levenshtein(a: &str, b: &str) -> usize {
    let target: Vec<char> = b.chars().collect();

    // row[j] holds the distance between the prefix of `a` consumed so far
    // and the first j chars of `b`; initialized for the empty `a` prefix.
    let mut row: Vec<usize> = (0..=target.len()).collect();

    for (i, ca) in a.chars().enumerate() {
        // `diag` carries the previous row's value at column j - 1.
        let mut diag = row[0];
        row[0] = i + 1;
        for (j, &cb) in target.iter().enumerate() {
            let above = row[j + 1];
            let substitute = diag + usize::from(ca != cb);
            row[j + 1] = substitute.min(above + 1).min(row[j] + 1);
            diag = above;
        }
    }

    *row.last().expect("row always has at least one entry")
}
101
/// Returns a window of `content` centered near the earliest case-insensitive
/// occurrence of any query term, with `…` marking truncation on either side.
/// Falls back to the start of `content` when no term is found.
///
/// NOTE(review): offsets come from searching the lowercased copy but are used
/// to slice the original; for the rare characters whose lowercase form has a
/// different byte length the anchor may drift slightly.
fn extract_snippet(content: &str, query_terms: &[&str], window: usize) -> String {
    let haystack = content.to_lowercase();

    // Earliest byte offset at which any term occurs; 0 when none match.
    let anchor = query_terms
        .iter()
        .filter_map(|term| haystack.find(&term.to_lowercase()))
        .min()
        .unwrap_or(0);

    let mut start = anchor.saturating_sub(window);
    let mut end = content.len().min(anchor + window);

    // Snap both cut points outward/inward to valid UTF-8 char boundaries
    // so the slice below cannot panic. Index 0 is always a boundary.
    while !content.is_char_boundary(start) {
        start -= 1;
    }
    while end < content.len() && !content.is_char_boundary(end) {
        end += 1;
    }

    let mut snippet = String::new();
    if start > 0 {
        snippet.push('…');
    }
    snippet.push_str(&content[start..end]);
    if end < content.len() {
        snippet.push('…');
    }
    snippet
}
147
148impl AdvancedSearch {
151 pub fn new() -> Result<Self> {
153 let db = Connection::open_in_memory()?;
154 db.execute_batch(SCHEMA)?;
155 Ok(Self { db })
156 }
157
158 pub fn index(&self, id: &str, content: &str) -> Result<()> {
161 self.db.execute(
162 "INSERT INTO docs (id, content) VALUES (?1, ?2)
163 ON CONFLICT(id) DO UPDATE SET content = excluded.content",
164 params![id, content],
165 )?;
166 Ok(())
167 }
168
169 pub fn search(&self, query: &str) -> Result<Vec<SearchResult>> {
172 let query = query.trim();
173 if query.is_empty() {
174 return Ok(Vec::new());
175 }
176
177 let terms: Vec<&str> = query.split_whitespace().collect();
178
179 let bm25 = self.bm25_search(query);
181
182 let trigram = self.trigram_search(query);
184
185 let mut results = self.reciprocal_rank_fusion(&bm25, &trigram);
187
188 if results.is_empty() {
190 if let Some(corrected) = self.fuzzy_correct(query) {
191 let bm25_c = self.bm25_search(&corrected);
192 let trigram_c = self.trigram_search(&corrected);
193 results = self.reciprocal_rank_fusion(&bm25_c, &trigram_c);
194 }
195 }
196
197 if terms.len() > 1 {
199 self.proximity_rerank(&mut results, &terms);
200 }
201
202 for r in &mut results {
204 if let Ok(content) = self.get_content(&r.id) {
205 r.snippet = extract_snippet(&content, &terms, 80);
206 }
207 }
208
209 Ok(results)
210 }
211
212 fn get_content(&self, id: &str) -> Result<String> {
215 let content: String = self.db.query_row(
216 "SELECT content FROM docs WHERE id = ?1",
217 params![id],
218 |row| row.get(0),
219 )?;
220 Ok(content)
221 }
222
223 fn bm25_search(&self, query: &str) -> Vec<(f64, String)> {
226 let mut stmt = match self.db.prepare(
227 "SELECT d.id, bm25(docs_porter) AS score
228 FROM docs_porter p
229 JOIN docs d ON d.rowid = p.rowid
230 WHERE docs_porter MATCH ?1
231 ORDER BY score",
232 ) {
233 Ok(s) => s,
234 Err(_) => return Vec::new(),
235 };
236
237 let rows = match stmt.query_map(params![query], |row| {
238 Ok((row.get::<_, String>(0)?, row.get::<_, f64>(1)?))
239 }) {
240 Ok(r) => r,
241 Err(_) => return Vec::new(),
242 };
243
244 rows.filter_map(|r| r.ok())
245 .map(|(id, score)| (score, id))
246 .collect()
247 }
248
249 fn trigram_search(&self, query: &str) -> Vec<(f64, String)> {
252 if query.len() < 3 {
254 return Vec::new();
255 }
256 let mut stmt = match self.db.prepare(
257 "SELECT d.id, bm25(docs_trigram) AS score
258 FROM docs_trigram t
259 JOIN docs d ON d.rowid = t.rowid
260 WHERE docs_trigram MATCH ?1
261 ORDER BY score",
262 ) {
263 Ok(s) => s,
264 Err(_) => return Vec::new(),
265 };
266
267 let rows = match stmt.query_map(params![query], |row| {
268 Ok((row.get::<_, String>(0)?, row.get::<_, f64>(1)?))
269 }) {
270 Ok(r) => r,
271 Err(_) => return Vec::new(),
272 };
273
274 rows.filter_map(|r| r.ok())
275 .map(|(id, score)| (score, id))
276 .collect()
277 }
278
279 fn reciprocal_rank_fusion(
284 &self,
285 a: &[(f64, String)],
286 b: &[(f64, String)],
287 ) -> Vec<SearchResult> {
288 use std::collections::HashMap;
289
290 let mut scores: HashMap<String, f64> = HashMap::new();
291
292 for (rank, (_score, id)) in a.iter().enumerate() {
293 *scores.entry(id.clone()).or_default() += 1.0 / (RRF_K + rank as f64 + 1.0);
294 }
295 for (rank, (_score, id)) in b.iter().enumerate() {
296 *scores.entry(id.clone()).or_default() += 1.0 / (RRF_K + rank as f64 + 1.0);
297 }
298
299 let mut results: Vec<SearchResult> = scores
300 .into_iter()
301 .map(|(id, score)| SearchResult {
302 id,
303 score,
304 snippet: String::new(),
305 })
306 .collect();
307
308 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
310 results
311 }
312
313 fn fuzzy_correct(&self, query: &str) -> Option<String> {
316 let vocab = self.vocabulary();
318 if vocab.is_empty() {
319 return None;
320 }
321
322 let terms: Vec<&str> = query.split_whitespace().collect();
323 let mut corrected_terms: Vec<String> = Vec::new();
324 let mut any_corrected = false;
325
326 for term in &terms {
327 let lower = term.to_lowercase();
328 let mut best: Option<(usize, String)> = None;
329 for word in &vocab {
330 let dist = levenshtein(&lower, word);
331 if dist > 0 && dist <= 2 {
332 if best.as_ref().map_or(true, |(d, _)| dist < *d) {
333 best = Some((dist, word.clone()));
334 }
335 }
336 }
337 if let Some((_dist, correction)) = best {
338 corrected_terms.push(correction);
339 any_corrected = true;
340 } else {
341 corrected_terms.push(lower);
342 }
343 }
344
345 if any_corrected {
346 Some(corrected_terms.join(" "))
347 } else {
348 None
349 }
350 }
351
352 fn vocabulary(&self) -> Vec<String> {
354 let mut stmt = match self.db.prepare("SELECT content FROM docs") {
355 Ok(s) => s,
356 Err(_) => return Vec::new(),
357 };
358 let rows = match stmt.query_map([], |row| row.get::<_, String>(0)) {
359 Ok(r) => r,
360 Err(_) => return Vec::new(),
361 };
362
363 let mut words = std::collections::HashSet::new();
364 for row in rows.flatten() {
365 for word in row.split_whitespace() {
366 let w: String = word
367 .chars()
368 .filter(|c| c.is_alphanumeric())
369 .collect::<String>()
370 .to_lowercase();
371 if w.len() >= 2 {
372 words.insert(w);
373 }
374 }
375 }
376 words.into_iter().collect()
377 }
378
379 fn proximity_rerank(&self, results: &mut Vec<SearchResult>, query_terms: &[&str]) {
382 for r in results.iter_mut() {
383 let content = match self.get_content(&r.id) {
384 Ok(c) => c,
385 Err(_) => continue,
386 };
387 let lower = content.to_lowercase();
388
389 let mut positions: Vec<usize> = Vec::new();
391 for term in query_terms {
392 if let Some(pos) = lower.find(&term.to_lowercase()) {
393 positions.push(pos);
394 }
395 }
396
397 if positions.len() >= 2 {
398 positions.sort_unstable();
399 let span = positions.last().unwrap() - positions.first().unwrap();
401 let boost = 1.0 + 1.0 / (1.0 + span as f64 / 50.0);
404 r.score *= boost;
405 }
406 }
407
408 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
410 }
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    // Convenience constructor; panics if the in-memory DB cannot be set up.
    fn make_search() -> AdvancedSearch {
        AdvancedSearch::new().unwrap()
    }

    #[test]
    fn test_index_and_bm25_search() {
        let s = make_search();
        s.index("d1", "the quick brown fox jumps over the lazy dog").unwrap();
        s.index("d2", "a fast red car drives on the highway").unwrap();

        let results = s.bm25_search("fox");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1, "d1");
    }

    #[test]
    fn test_trigram_search() {
        let s = make_search();
        s.index("d1", "authentication middleware handles tokens").unwrap();
        s.index("d2", "database migration scripts for postgres").unwrap();

        // "auth" is a substring of "authentication" only.
        let results = s.trigram_search("auth");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1, "d1");
    }

    #[test]
    fn test_rrf_merge_both_lists() {
        let s = make_search();
        s.index("d1", "rust programming language systems").unwrap();
        s.index("d2", "rust prevention coating for metal surfaces").unwrap();
        s.index("d3", "programming in python is fun").unwrap();

        // Only d1 matches both query terms, so it must rank first.
        let results = s.search("rust programming").unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "d1");
    }

    #[test]
    fn test_rrf_docs_in_both_rank_higher() {
        let s = make_search();
        s.index("d1", "alpha beta gamma delta").unwrap();
        s.index("d2", "alpha only here nothing else relevant").unwrap();

        let bm25 = s.bm25_search("alpha");
        let trigram = s.trigram_search("alpha");

        let merged = s.reciprocal_rank_fusion(&bm25, &trigram);
        assert!(!merged.is_empty());

        let in_bm25: std::collections::HashSet<_> = bm25.iter().map(|(_, id)| id.clone()).collect();
        let in_trigram: std::collections::HashSet<_> = trigram.iter().map(|(_, id)| id.clone()).collect();
        let in_both: std::collections::HashSet<_> = in_bm25.intersection(&in_trigram).cloned().collect();

        // For lists this short, RRF guarantees that any document present in
        // both input lists outscores any document present in only one.
        for both_doc in merged.iter().filter(|r| in_both.contains(&r.id)) {
            for single_doc in merged.iter().filter(|r| !in_both.contains(&r.id)) {
                assert!(
                    both_doc.score > single_doc.score,
                    "{} (in both lists) should outscore {} (in one list)",
                    both_doc.id,
                    single_doc.id
                );
            }
        }
    }

    #[test]
    fn test_fuzzy_correction() {
        let s = make_search();
        s.index("d1", "authentication middleware").unwrap();
        s.index("d2", "database migration").unwrap();

        // "authentcation" is one edit away from "authentication".
        let corrected = s.fuzzy_correct("authentcation");
        assert!(corrected.is_some());
        let c = corrected.unwrap();
        assert!(c.contains("authentication"), "corrected to: {}", c);
    }

    #[test]
    fn test_fuzzy_search_end_to_end() {
        let s = make_search();
        s.index("d1", "authentication middleware handles tokens").unwrap();

        // The misspelled query finds nothing directly, then succeeds after
        // the fuzzy-corrected retry.
        let results = s.search("authentcation").unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "d1");
    }

    #[test]
    fn test_proximity_reranking() {
        let s = make_search();
        // d1 has the query terms adjacent; d2 has them far apart.
        s.index("d1", "the error handler catches all exceptions").unwrap();
        s.index(
            "d2",
            "an error occurred in the system and after many lines of unrelated text the handler was invoked",
        ).unwrap();

        let results = s.search("error handler").unwrap();
        assert!(results.len() >= 2);
        assert_eq!(results[0].id, "d1");
    }

    #[test]
    fn test_smart_snippet_extraction() {
        let content = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. \
 The authentication module verifies JWT tokens. \
 Sed do eiusmod tempor incididunt ut labore.";
        let snippet = extract_snippet(content, &["authentication"], 40);
        assert!(snippet.contains("authentication"));
        // Match is mid-document, so at least one side must be truncated.
        assert!(snippet.contains("…"));
    }

    #[test]
    fn test_empty_query_returns_empty() {
        let s = make_search();
        s.index("d1", "some content").unwrap();
        let results = s.search("").unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_no_results_returns_empty() {
        let s = make_search();
        s.index("d1", "hello world").unwrap();
        // No direct match and no vocabulary word within edit distance 2.
        let results = s.search("zzzznonexistent").unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_index_upsert() {
        let s = make_search();
        s.index("d1", "original content about cats").unwrap();
        s.index("d1", "updated content about dogs").unwrap();

        // The upsert path must also update both FTS indexes via the triggers.
        let results = s.search("dogs").unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "d1");

        let results = s.search("cats").unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_levenshtein_distance() {
        assert_eq!(levenshtein("kitten", "sitting"), 3);
        assert_eq!(levenshtein("hello", "hello"), 0);
        assert_eq!(levenshtein("hello", "helo"), 1);
        assert_eq!(levenshtein("", "abc"), 3);
        assert_eq!(levenshtein("abc", ""), 3);
    }

    #[test]
    fn test_snippet_at_start() {
        // Match at offset 0 with a window covering the whole string:
        // no leading ellipsis.
        let content = "authentication is important for security";
        let snippet = extract_snippet(content, &["authentication"], 80);
        assert!(snippet.contains("authentication"));
        assert!(!snippet.starts_with('…'));
    }

    #[test]
    fn test_multiple_documents_search() {
        let s = make_search();
        for i in 0..10 {
            s.index(&format!("d{}", i), &format!("document number {} about testing", i))
                .unwrap();
        }
        let results = s.search("testing").unwrap();
        assert_eq!(results.len(), 10);
    }

    mod prop_tests {
        use super::*;
        use proptest::prelude::*;
        use std::collections::{HashMap, HashSet};

        proptest! {
            // Builds synthetic ranked lists (scores are ignored by RRF, only
            // order matters) and checks two invariants of the fusion:
            // completeness and the both-lists-rank-higher property.
            #[test]
            fn prop_rrf_merge_contains_all_unique_docs_and_both_rank_higher(
                shared_count in 1..4usize,
                a_only_count in 1..4usize,
                b_only_count in 1..4usize,
            ) {
                let s = make_search();

                let mut list_a: Vec<(f64, String)> = Vec::new();
                let mut list_b: Vec<(f64, String)> = Vec::new();

                for i in 0..shared_count {
                    let id = format!("shared_{}", i);
                    list_a.push((-(i as f64), id.clone()));
                    list_b.push((-(i as f64), id));
                }

                for i in 0..a_only_count {
                    let id = format!("a_only_{}", i);
                    list_a.push((-((shared_count + i) as f64), id));
                }

                for i in 0..b_only_count {
                    let id = format!("b_only_{}", i);
                    list_b.push((-((shared_count + i) as f64), id));
                }

                let merged = s.reciprocal_rank_fusion(&list_a, &list_b);

                let all_ids: HashSet<String> = list_a.iter().map(|(_, id)| id.clone())
                    .chain(list_b.iter().map(|(_, id)| id.clone()))
                    .collect();
                let merged_ids: HashSet<String> = merged.iter().map(|r| r.id.clone()).collect();
                prop_assert_eq!(
                    merged_ids, all_ids,
                    "Merged result set must contain all unique documents from both input lists"
                );

                let a_ids: HashSet<String> = list_a.iter().map(|(_, id)| id.clone()).collect();
                let b_ids: HashSet<String> = list_b.iter().map(|(_, id)| id.clone()).collect();
                let in_both: HashSet<String> = a_ids.intersection(&b_ids).cloned().collect();
                let in_one_only: HashSet<String> = a_ids.symmetric_difference(&b_ids).cloned().collect();

                if !in_both.is_empty() && !in_one_only.is_empty() {
                    let scores: HashMap<String, f64> = merged.iter()
                        .map(|r| (r.id.clone(), r.score))
                        .collect();

                    let min_both_score = in_both.iter()
                        .filter_map(|id| scores.get(id))
                        .cloned()
                        .fold(f64::INFINITY, f64::min);

                    let max_one_score = in_one_only.iter()
                        .filter_map(|id| scores.get(id))
                        .cloned()
                        .fold(f64::NEG_INFINITY, f64::max);

                    prop_assert!(
                        min_both_score > max_one_score,
                        "Documents in both lists (min score {}) must score higher \
 than documents in only one list (max score {})",
                        min_both_score, max_one_score
                    );
                }
            }
        }
    }
}