use std::collections::{HashMap, HashSet};
use std::str::FromStr;

use serde::{Deserialize, Serialize};

use crate::types::{AppError, Document, Result};

/// Retrieval strategy used to answer a search request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum SearchStrategy {
    /// Dense-vector (embedding) similarity search.
    #[default]
    Semantic,
    /// Lexical BM25 ranking over an inverted index.
    Bm25,
    /// Edit-distance based fuzzy matching.
    Fuzzy,
    /// All strategies combined via Reciprocal Rank Fusion.
    Hybrid,
}

impl FromStr for SearchStrategy {
    type Err = AppError;

    fn from_str(s: &str) -> Result<Self> {
        match s.to_lowercase().as_str() {
            "semantic" | "dense" | "vector" => Ok(Self::Semantic),
            "bm25" | "lexical" | "sparse" => Ok(Self::Bm25),
            "fuzzy" | "approximate" => Ok(Self::Fuzzy),
            "hybrid" | "combined" | "rrf" => Ok(Self::Hybrid),
            _ => Err(AppError::Internal(format!(
                "Unknown search strategy: {}. Use: semantic, bm25, fuzzy, hybrid",
                s
            ))),
        }
    }
}

impl std::fmt::Display for SearchStrategy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Semantic => "semantic",
            Self::Bm25 => "bm25",
            Self::Fuzzy => "fuzzy",
            Self::Hybrid => "hybrid",
        };
        write!(f, "{}", name)
    }
}

/// A single search hit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Document identifier.
    pub id: String,
    /// Document content.
    pub content: String,
    /// Relevance score (higher is better).
    pub score: f32,
    /// Strategies that contributed this result.
    pub sources: Vec<SearchStrategy>,
    /// Optional document metadata.
    pub metadata: Option<serde_json::Value>,
}

/// Parameters for a search request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchRequest {
    /// Query text.
    pub query: String,
    /// Retrieval strategy (defaults to `Semantic`).
    #[serde(default)]
    pub strategy: SearchStrategy,
    /// Maximum number of results to return (defaults to 10).
    #[serde(default = "default_top_k")]
    pub top_k: usize,
    /// Minimum score for returned results.
    #[serde(default)]
    pub min_score: f32,
    /// Whether to rerank results after retrieval.
    #[serde(default)]
    pub rerank: bool,
    /// Name of the collection to search.
    pub collection: String,
    /// Per-strategy weights used for hybrid fusion.
    #[serde(default)]
    pub hybrid_weights: HybridWeights,
}

fn default_top_k() -> usize {
    10
}

/// Relative weights applied to each strategy's ranked list during hybrid
/// fusion. The weights scale each list's contribution and need not sum to 1.0.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct HybridWeights {
    /// Weight for semantic (dense vector) results.
    pub semantic: f32,
    /// Weight for BM25 results.
    pub bm25: f32,
    /// Weight for fuzzy-match results.
    pub fuzzy: f32,
}

impl Default for HybridWeights {
    fn default() -> Self {
        Self {
            semantic: 0.6,
            bm25: 0.3,
            fuzzy: 0.1,
        }
    }
}

/// In-memory BM25 (Okapi BM25) index over tokenized document text.
///
/// Build instances with [`Bm25Index::new`] or [`Bm25Index::with_params`]; the
/// derived `Default` leaves the `k1` and `b` parameters at zero.
#[derive(Debug, Clone, Default)]
pub struct Bm25Index {
    /// Tokenized content keyed by document id.
    documents: HashMap<String, Vec<String>>,
    /// Term -> set of document ids containing it.
    inverted_index: HashMap<String, HashSet<String>>,
    /// Term -> number of documents containing it.
    document_frequencies: HashMap<String, usize>,
    /// Total number of indexed documents.
    doc_count: usize,
    /// Average document length in tokens.
    avg_doc_length: f32,
    /// BM25 term-frequency saturation parameter.
    k1: f32,
    /// BM25 length-normalization parameter.
    b: f32,
}

impl Bm25Index {
    /// Creates an index with the standard BM25 parameters (k1 = 1.2, b = 0.75).
    pub fn new() -> Self {
        Self {
            k1: 1.2,
            b: 0.75,
            ..Default::default()
        }
    }

    /// Creates an index with custom `k1` and `b` parameters.
    pub fn with_params(k1: f32, b: f32) -> Self {
        Self {
            k1,
            b,
            ..Default::default()
        }
    }

    fn tokenize(text: &str) -> Vec<String> {
        // Lowercase, split on non-alphanumeric characters, and keep only
        // tokens of at least two characters.
        text.to_lowercase()
            .split(|c: char| !c.is_alphanumeric())
            .filter(|s| s.len() > 1)
            .map(String::from)
            .collect()
    }

    pub fn add_document(&mut self, id: &str, content: &str) {
        // Re-indexing an existing id would otherwise inflate the document
        // count and document frequencies, so drop any previous entry first.
        if self.documents.contains_key(id) {
            self.remove_document(id);
        }

        let tokens = Self::tokenize(content);

        let unique_terms: HashSet<_> = tokens.iter().cloned().collect();
        for term in &unique_terms {
            *self.document_frequencies.entry(term.clone()).or_insert(0) += 1;
            self.inverted_index
                .entry(term.clone())
                .or_default()
                .insert(id.to_string());
        }

        self.documents.insert(id.to_string(), tokens);
        self.doc_count += 1;

        let total_tokens: usize = self.documents.values().map(|v| v.len()).sum();
        self.avg_doc_length = total_tokens as f32 / self.doc_count as f32;
    }

    pub fn remove_document(&mut self, id: &str) {
        if let Some(tokens) = self.documents.remove(id) {
            let unique_terms: HashSet<_> = tokens.into_iter().collect();
            for term in unique_terms {
                if let Some(df) = self.document_frequencies.get_mut(&term) {
                    *df = df.saturating_sub(1);
                    if *df == 0 {
                        self.document_frequencies.remove(&term);
                    }
                }
                if let Some(docs) = self.inverted_index.get_mut(&term) {
                    docs.remove(id);
                    if docs.is_empty() {
                        self.inverted_index.remove(&term);
                    }
                }
            }
            self.doc_count = self.doc_count.saturating_sub(1);

            if self.doc_count > 0 {
                let total_tokens: usize = self.documents.values().map(|v| v.len()).sum();
                self.avg_doc_length = total_tokens as f32 / self.doc_count as f32;
            } else {
                self.avg_doc_length = 0.0;
            }
        }
    }

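    // BM25 scoring used by `idf` and `score_document` below, per query term `t`
    // and document `D`:
    //
    //   score(t, D) = idf(t) * tf * (k1 + 1) / (tf + k1 * (1 - b + b * |D| / avgdl))
    //   idf(t)      = ln((N - df + 0.5) / (df + 0.5) + 1)
    //
    // where `tf` is the frequency of `t` in `D`, `df` the number of documents
    // containing `t`, `N` the total document count, and `avgdl` the average
    // document length in tokens. The "+ 1" inside the log keeps idf positive
    // even for terms that occur in every document.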
    fn idf(&self, term: &str) -> f32 {
        let df = self.document_frequencies.get(term).copied().unwrap_or(0) as f32;
        let n = self.doc_count as f32;
        if df == 0.0 || n == 0.0 {
            return 0.0;
        }
        ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
    }

    fn score_document(&self, doc_id: &str, query_terms: &[String]) -> f32 {
        let doc_tokens = match self.documents.get(doc_id) {
            Some(tokens) => tokens,
            None => return 0.0,
        };

        let doc_len = doc_tokens.len() as f32;
        let mut score = 0.0;

        let mut term_freq: HashMap<&str, usize> = HashMap::new();
        for token in doc_tokens {
            *term_freq.entry(token.as_str()).or_insert(0) += 1;
        }

        for term in query_terms {
            let tf = term_freq.get(term.as_str()).copied().unwrap_or(0) as f32;
            let idf = self.idf(term);

            let numerator = tf * (self.k1 + 1.0);
            let denominator =
                tf + self.k1 * (1.0 - self.b + self.b * doc_len / self.avg_doc_length);
            score += idf * numerator / denominator;
        }

        score
    }

    pub fn search(&self, query: &str, top_k: usize) -> Vec<(String, f32)> {
        let query_terms = Self::tokenize(query);
        if query_terms.is_empty() {
            return Vec::new();
        }

        let mut candidates: HashSet<String> = HashSet::new();
        for term in &query_terms {
            if let Some(docs) = self.inverted_index.get(term) {
                candidates.extend(docs.iter().cloned());
            }
        }

        let mut results: Vec<(String, f32)> = candidates
            .iter()
            .map(|id| {
                let score = self.score_document(id, &query_terms);
                (id.clone(), score)
            })
            .filter(|(_, score)| *score > 0.0)
            .collect();

        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        results.truncate(top_k);
        results
    }

    pub fn len(&self) -> usize {
        self.doc_count
    }

    pub fn is_empty(&self) -> bool {
        self.doc_count == 0
    }

    pub fn clear(&mut self) {
        self.documents.clear();
        self.inverted_index.clear();
        self.document_frequencies.clear();
        self.doc_count = 0;
        self.avg_doc_length = 0.0;
    }
}

/// Fuzzy (edit-distance) matcher over raw, lowercased document text.
///
/// Build instances with [`FuzzyIndex::new`] or [`FuzzyIndex::with_max_distance`];
/// the derived `Default` leaves `max_distance` at zero (exact word matches only).
#[derive(Debug, Clone, Default)]
pub struct FuzzyIndex {
    /// Lowercased content keyed by document id.
    documents: HashMap<String, String>,
    /// Maximum Levenshtein distance for a word to count as a match.
    max_distance: usize,
}

impl FuzzyIndex {
    /// Creates an index that tolerates up to two edits per word.
    pub fn new() -> Self {
        Self {
            max_distance: 2,
            ..Default::default()
        }
    }

    /// Creates an index with a custom maximum edit distance.
    pub fn with_max_distance(max_distance: usize) -> Self {
        Self {
            max_distance,
            ..Default::default()
        }
    }

    pub fn add_document(&mut self, id: &str, content: &str) {
        self.documents.insert(id.to_string(), content.to_lowercase());
    }

    pub fn remove_document(&mut self, id: &str) {
        self.documents.remove(id);
    }

    fn levenshtein_distance(s1: &str, s2: &str) -> usize {
        let len1 = s1.chars().count();
        let len2 = s2.chars().count();

        if len1 == 0 {
            return len2;
        }
        if len2 == 0 {
            return len1;
        }

        let s1_chars: Vec<char> = s1.chars().collect();
        let s2_chars: Vec<char> = s2.chars().collect();

        let mut prev_row: Vec<usize> = (0..=len2).collect();
        let mut curr_row = vec![0; len2 + 1];

        for (i, c1) in s1_chars.iter().enumerate() {
            curr_row[0] = i + 1;

            for (j, c2) in s2_chars.iter().enumerate() {
                let cost = if c1 == c2 { 0 } else { 1 };
                curr_row[j + 1] = (prev_row[j + 1] + 1)
                    .min(curr_row[j] + 1)
                    .min(prev_row[j] + cost);
            }

            std::mem::swap(&mut prev_row, &mut curr_row);
        }

        prev_row[len2]
    }

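    // Scoring sketch for `fuzzy_score`: each query word takes its best
    // normalized similarity,
    //   1 - distance / max(len(query_word), len(text_word)),
    // over all words in the document text (only candidates within
    // `max_distance` edits count). The per-word scores are averaged over the
    // number of query words and then scaled by the fraction of query words
    // that matched at all.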
    fn fuzzy_score(query: &str, text: &str, max_distance: usize) -> f32 {
        let query_lower = query.to_lowercase();
        let query_words: Vec<&str> = query_lower.split_whitespace().collect();

        let mut total_score = 0.0;
        let mut matched_words = 0;

        for query_word in &query_words {
            let mut best_score = 0.0f32;

            for text_word in text.split_whitespace() {
                if text_word.len() < 2 {
                    continue;
                }

                let distance = Self::levenshtein_distance(query_word, text_word);
                if distance <= max_distance {
                    let max_len = query_word.len().max(text_word.len());
                    let score = 1.0 - (distance as f32 / max_len as f32);
                    best_score = best_score.max(score);
                }
            }

            if best_score > 0.0 {
                total_score += best_score;
                matched_words += 1;
            }
        }

        if matched_words > 0 {
            (total_score / query_words.len() as f32)
                * (matched_words as f32 / query_words.len() as f32)
        } else {
            0.0
        }
    }

    pub fn search(&self, query: &str, top_k: usize) -> Vec<(String, f32)> {
        let mut results: Vec<(String, f32)> = self
            .documents
            .iter()
            .filter_map(|(id, content)| {
                let score = Self::fuzzy_score(query, content, self.max_distance);
                if score > 0.0 {
                    Some((id.clone(), score))
                } else {
                    None
                }
            })
            .collect();

        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        results.truncate(top_k);
        results
    }

    pub fn len(&self) -> usize {
        self.documents.len()
    }

    pub fn is_empty(&self) -> bool {
        self.documents.is_empty()
    }

    pub fn clear(&mut self) {
        self.documents.clear();
    }
}

/// Reciprocal Rank Fusion (RRF) combiner for merging ranked result lists.
#[derive(Debug, Clone)]
pub struct RrfFusion {
    /// Smoothing constant; 60.0 is the value conventionally used for RRF.
    k: f32,
}

impl Default for RrfFusion {
    fn default() -> Self {
        Self { k: 60.0 }
    }
}

impl RrfFusion {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn with_k(k: f32) -> Self {
        Self { k }
    }

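    // Reciprocal Rank Fusion: a document at 0-based `rank` in a list with
    // weight `w` contributes
    //   w / (k + rank + 1)
    // and the contributions are summed per document across all lists, so
    // documents ranked highly by several strategies rise to the top even when
    // their raw scores are not directly comparable.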
    pub fn fuse(&self, ranked_lists: &[(&[(String, f32)], f32)]) -> Vec<(String, f32)> {
        let mut fused_scores: HashMap<String, f32> = HashMap::new();

        for (results, weight) in ranked_lists {
            for (rank, (doc_id, _score)) in results.iter().enumerate() {
                let rrf_score = weight / (self.k + rank as f32 + 1.0);
                *fused_scores.entry(doc_id.clone()).or_insert(0.0) += rrf_score;
            }
        }

        let mut results: Vec<_> = fused_scores.into_iter().collect();
        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        results
    }
}

/// Combines the BM25 and fuzzy indexes with RRF fusion; semantic results are
/// supplied by the caller for hybrid search.
#[derive(Debug, Default)]
pub struct SearchEngine {
    /// Lexical BM25 index.
    pub bm25: Bm25Index,
    /// Fuzzy (edit distance) index.
    pub fuzzy: FuzzyIndex,
    /// Reciprocal Rank Fusion combiner.
    pub rrf: RrfFusion,
}

impl SearchEngine {
    /// Creates an engine with the standard BM25 and fuzzy parameters; the
    /// derived `Default` would leave those parameters at zero.
    pub fn new() -> Self {
        Self {
            bm25: Bm25Index::new(),
            fuzzy: FuzzyIndex::new(),
            rrf: RrfFusion::new(),
        }
    }

    pub fn index_document(&mut self, doc: &Document) {
        self.bm25.add_document(&doc.id, &doc.content);
        self.fuzzy.add_document(&doc.id, &doc.content);
    }

    pub fn index_documents(&mut self, docs: &[Document]) {
        for doc in docs {
            self.index_document(doc);
        }
    }

    pub fn remove_document(&mut self, id: &str) {
        self.bm25.remove_document(id);
        self.fuzzy.remove_document(id);
    }

    pub fn search_bm25(&self, query: &str, top_k: usize) -> Vec<(String, f32)> {
        self.bm25.search(query, top_k)
    }

    pub fn search_fuzzy(&self, query: &str, top_k: usize) -> Vec<(String, f32)> {
        self.fuzzy.search(query, top_k)
    }

    pub fn search_hybrid(
        &self,
        query: &str,
        semantic_results: &[(String, f32)],
        weights: &HybridWeights,
        top_k: usize,
    ) -> Vec<(String, f32)> {
        // Over-fetch from each local index before fusing so that documents
        // ranked just below `top_k` in one list can still surface after RRF.
        let bm25_results = self.bm25.search(query, top_k * 2);
        let fuzzy_results = self.fuzzy.search(query, top_k * 2);

        let ranked_lists: Vec<(&[(String, f32)], f32)> = vec![
            (semantic_results, weights.semantic),
            (&bm25_results, weights.bm25),
            (&fuzzy_results, weights.fuzzy),
        ];

        let mut fused = self.rrf.fuse(&ranked_lists);
        fused.truncate(top_k);
        fused
    }

    pub fn clear(&mut self) {
        self.bm25.clear();
        self.fuzzy.clear();
    }

    pub fn len(&self) -> usize {
        self.bm25.len()
    }

    pub fn is_empty(&self) -> bool {
        self.bm25.is_empty()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_search_strategy_from_str() {
        assert_eq!(
            "semantic".parse::<SearchStrategy>().unwrap(),
            SearchStrategy::Semantic
        );
        assert_eq!(
            "bm25".parse::<SearchStrategy>().unwrap(),
            SearchStrategy::Bm25
        );
        assert_eq!(
            "fuzzy".parse::<SearchStrategy>().unwrap(),
            SearchStrategy::Fuzzy
        );
        assert_eq!(
            "hybrid".parse::<SearchStrategy>().unwrap(),
            SearchStrategy::Hybrid
        );
    }
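
    // Companion check: strategy aliases resolve to the same variants and
    // `Display` yields the canonical lowercase name.
    #[test]
    fn test_search_strategy_aliases_and_display() {
        assert_eq!("rrf".parse::<SearchStrategy>().unwrap(), SearchStrategy::Hybrid);
        assert_eq!("LEXICAL".parse::<SearchStrategy>().unwrap(), SearchStrategy::Bm25);
        assert_eq!(SearchStrategy::Hybrid.to_string(), "hybrid");
        assert!("unknown".parse::<SearchStrategy>().is_err());
    }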

    #[test]
    fn test_bm25_basic() {
        let mut index = Bm25Index::new();
        index.add_document("doc1", "The quick brown fox jumps over the lazy dog");
        index.add_document("doc2", "A fast brown fox leaps over sleeping dogs");
        index.add_document("doc3", "The cat sleeps on the mat");

        let results = index.search("quick brown fox", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "doc1");
    }

    #[test]
    fn test_bm25_ranking() {
        let mut index = Bm25Index::new();
        index.add_document("doc1", "apple apple apple");
        index.add_document("doc2", "apple banana");
        index.add_document("doc3", "banana banana banana");

        let results = index.search("apple", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "doc1");
    }
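
    // Rarer terms receive a higher IDF, and the "+ 1" inside the log keeps the
    // value positive even for terms present in every document.
    #[test]
    fn test_bm25_idf_ordering() {
        let mut index = Bm25Index::new();
        index.add_document("doc1", "common rare");
        index.add_document("doc2", "common filler");

        assert!(index.idf("rare") > index.idf("common"));
        assert!(index.idf("common") > 0.0);
        assert_eq!(index.idf("missing"), 0.0);
    }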

    #[test]
    fn test_bm25_remove_document() {
        let mut index = Bm25Index::new();
        index.add_document("doc1", "hello world");
        index.add_document("doc2", "goodbye world");

        assert_eq!(index.len(), 2);

        index.remove_document("doc1");
        assert_eq!(index.len(), 1);

        let results = index.search("hello", 10);
        assert!(results.is_empty());
    }
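
    // Queries that tokenize to nothing short-circuit to an empty result set.
    #[test]
    fn test_bm25_empty_query() {
        let mut index = Bm25Index::new();
        index.add_document("doc1", "hello world");

        assert!(index.search("", 10).is_empty());
        assert!(index.search("?!", 10).is_empty());
    }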

    #[test]
    fn test_fuzzy_exact_match() {
        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "machine learning algorithms");
        index.add_document("doc2", "deep neural networks");

        let results = index.search("machine", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "doc1");
    }

    #[test]
    fn test_fuzzy_typo_tolerance() {
        let mut index = FuzzyIndex::with_max_distance(2);
        index.add_document("doc1", "machine learning");
        index.add_document("doc2", "deep learning");

        let results = index.search("machne", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "doc1");
    }

    #[test]
    fn test_levenshtein_distance() {
        assert_eq!(FuzzyIndex::levenshtein_distance("kitten", "sitting"), 3);
        assert_eq!(FuzzyIndex::levenshtein_distance("hello", "hello"), 0);
        assert_eq!(FuzzyIndex::levenshtein_distance("", "abc"), 3);
        assert_eq!(FuzzyIndex::levenshtein_distance("abc", ""), 3);
    }
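
    // An exact word match yields a per-word similarity of 1.0; with a
    // single-word query the overall fuzzy score is therefore 1.0.
    #[test]
    fn test_fuzzy_score_exact_word() {
        let score = FuzzyIndex::fuzzy_score("machine", "machine learning", 2);
        assert!((score - 1.0).abs() < 1e-6);
    }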

    #[test]
    fn test_rrf_fusion() {
        let rrf = RrfFusion::new();

        let list1 = vec![
            ("doc1".to_string(), 0.9),
            ("doc2".to_string(), 0.8),
            ("doc3".to_string(), 0.7),
        ];

        let list2 = vec![
            ("doc2".to_string(), 0.95),
            ("doc1".to_string(), 0.85),
            ("doc4".to_string(), 0.75),
        ];

        let ranked_lists = vec![(&list1[..], 1.0), (&list2[..], 1.0)];
        let fused = rrf.fuse(&ranked_lists);

        assert!(!fused.is_empty());
        let top_ids: Vec<_> = fused.iter().take(2).map(|(id, _)| id.clone()).collect();
        assert!(top_ids.contains(&"doc1".to_string()));
        assert!(top_ids.contains(&"doc2".to_string()));
    }
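
    // With the default k = 60 and a weight of 1.0, rank 0 contributes
    // 1 / 61 and rank 1 contributes 1 / 62 to the fused score.
    #[test]
    fn test_rrf_single_list_scores() {
        let rrf = RrfFusion::new();
        let list = vec![("doc1".to_string(), 0.9), ("doc2".to_string(), 0.5)];

        let fused = rrf.fuse(&[(&list[..], 1.0)]);
        assert_eq!(fused[0].0, "doc1");
        assert!((fused[0].1 - 1.0 / 61.0).abs() < 1e-6);
        assert!((fused[1].1 - 1.0 / 62.0).abs() < 1e-6);
    }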

    #[test]
    fn test_search_engine_integration() {
        let mut engine = SearchEngine::new();

        let docs = vec![
            Document {
                id: "doc1".to_string(),
                content: "Rust programming language is fast and memory safe".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
            Document {
                id: "doc2".to_string(),
                content: "Python is popular for machine learning and data science".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
            Document {
                id: "doc3".to_string(),
                content: "JavaScript runs in web browsers".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
        ];

        engine.index_documents(&docs);
        assert_eq!(engine.len(), 3);

        let bm25_results = engine.search_bm25("Rust programming", 10);
        assert!(!bm25_results.is_empty());
        assert_eq!(bm25_results[0].0, "doc1");

        let fuzzy_results = engine.search_fuzzy("rust", 10);
        assert!(!fuzzy_results.is_empty(), "Fuzzy search should find 'rust'");
    }

    #[test]
    fn test_hybrid_search() {
        let mut engine = SearchEngine::new();

        let docs = vec![
            Document {
                id: "doc1".to_string(),
                content: "Vector databases enable semantic search".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
            Document {
                id: "doc2".to_string(),
                content: "BM25 is a lexical search algorithm".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
        ];

        engine.index_documents(&docs);

        let semantic_results = vec![
            ("doc1".to_string(), 0.95),
            ("doc2".to_string(), 0.80),
        ];

        let weights = HybridWeights {
            semantic: 0.5,
            bm25: 0.4,
            fuzzy: 0.1,
        };

        let hybrid = engine.search_hybrid("vector search", &semantic_results, &weights, 10);
        assert!(!hybrid.is_empty());
    }

    #[test]
    fn test_hybrid_weights_default() {
        let weights = HybridWeights::default();
        assert!((weights.semantic - 0.6).abs() < 0.001);
        assert!((weights.bm25 - 0.3).abs() < 0.001);
        assert!((weights.fuzzy - 0.1).abs() < 0.001);
    }
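
    // Only `query` and `collection` are required in a request; the remaining
    // fields fall back to their serde defaults.
    #[test]
    fn test_search_request_defaults() {
        let req: SearchRequest =
            serde_json::from_str(r#"{"query": "rust", "collection": "docs"}"#).unwrap();

        assert_eq!(req.strategy, SearchStrategy::Semantic);
        assert_eq!(req.top_k, 10);
        assert!(!req.rerank);
        assert!((req.hybrid_weights.semantic - 0.6).abs() < 0.001);
    }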
}