1use std::collections::{HashMap, HashSet};
24use std::path::Path;
25use std::str::FromStr;
26
27use serde::{Deserialize, Serialize};
28
29use crate::types::{AppError, Document, Result};
30
/// Strategy used to execute a search request.
///
/// Serialized in kebab-case; also parseable from several aliases via the
/// `FromStr` impl below (e.g. "dense", "lexical", "rrf").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum SearchStrategy {
    /// Dense-vector similarity search (the default strategy).
    #[default]
    Semantic,
    /// Lexical ranking via Okapi BM25.
    Bm25,
    /// Edit-distance (Levenshtein) approximate matching.
    Fuzzy,
    /// Combination of the other strategies via reciprocal rank fusion.
    Hybrid,
}
49
50impl FromStr for SearchStrategy {
51 type Err = AppError;
52
53 fn from_str(s: &str) -> Result<Self> {
54 match s.to_lowercase().as_str() {
55 "semantic" | "dense" | "vector" => Ok(Self::Semantic),
56 "bm25" | "lexical" | "sparse" => Ok(Self::Bm25),
57 "fuzzy" | "approximate" => Ok(Self::Fuzzy),
58 "hybrid" | "combined" | "rrf" => Ok(Self::Hybrid),
59 _ => Err(AppError::Internal(format!(
60 "Unknown search strategy: {}. Use: semantic, bm25, fuzzy, hybrid",
61 s
62 ))),
63 }
64 }
65}
66
67impl std::fmt::Display for SearchStrategy {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 let name = match self {
70 Self::Semantic => "semantic",
71 Self::Bm25 => "bm25",
72 Self::Fuzzy => "fuzzy",
73 Self::Hybrid => "hybrid",
74 };
75 write!(f, "{}", name)
76 }
77}
78
/// A single scored hit returned by a search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Identifier of the matched document.
    pub id: String,
    /// Document content.
    pub content: String,
    /// Relevance score (higher is better; scale depends on the strategy).
    pub score: f32,
    /// Strategies that contributed this hit.
    pub sources: Vec<SearchStrategy>,
    /// Arbitrary attached metadata, if any.
    /// NOTE(review): schema is caller-defined — not interpreted here.
    pub metadata: Option<serde_json::Value>,
}
97
/// Record of one typo correction applied to a query word.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryCorrection {
    /// The word as the user typed it.
    pub original: String,
    /// The vocabulary word it was corrected to.
    pub corrected: String,
    /// Levenshtein distance between `original` and `corrected`.
    pub distance: usize,
}
108
/// Parameters for a search call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchRequest {
    /// Raw query text.
    pub query: String,
    /// Strategy to execute (defaults to [`SearchStrategy::Semantic`]).
    #[serde(default)]
    pub strategy: SearchStrategy,
    /// Maximum number of results to return (defaults to 10).
    #[serde(default = "default_top_k")]
    pub top_k: usize,
    /// Minimum acceptable score (defaults to 0.0).
    /// NOTE(review): not consumed within this module — verify the caller
    /// applies it.
    #[serde(default)]
    pub min_score: f32,
    /// Whether to apply reranking (defaults to false).
    /// NOTE(review): reranking itself is not implemented in this module.
    #[serde(default)]
    pub rerank: bool,
    /// Target collection name.
    pub collection: String,
    /// Per-strategy weights used when `strategy` is hybrid.
    #[serde(default)]
    pub hybrid_weights: HybridWeights,
}
132
/// Serde default for `SearchRequest::top_k`
/// (referenced by `#[serde(default = "default_top_k")]`).
fn default_top_k() -> usize {
    10
}
136
/// Relative weights for each strategy when fusing hybrid results.
/// Defaults to 0.6 / 0.3 / 0.1 (see the `Default` impl below).
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct HybridWeights {
    /// Weight applied to the semantic (vector) result list.
    pub semantic: f32,
    /// Weight applied to the BM25 result list.
    pub bm25: f32,
    /// Weight applied to the fuzzy result list.
    pub fuzzy: f32,
}
147
impl Default for HybridWeights {
    /// Semantic-dominant mix: 0.6 semantic, 0.3 BM25, 0.1 fuzzy.
    fn default() -> Self {
        Self {
            semantic: 0.6,
            bm25: 0.3,
            fuzzy: 0.1,
        }
    }
}
157
/// In-memory Okapi BM25 lexical index.
///
/// NOTE(review): the derived `Default` zeroes `k1` and `b`; construct via
/// [`Bm25Index::new`] / [`Bm25Index::with_params`] to get sensible BM25
/// parameters.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Bm25Index {
    /// Token list per document id (tokenization order preserved).
    documents: HashMap<String, Vec<String>>,
    /// term -> set of document ids containing that term.
    inverted_index: HashMap<String, HashSet<String>>,
    /// term -> number of documents containing that term.
    document_frequencies: HashMap<String, usize>,
    /// Number of indexed documents.
    doc_count: usize,
    /// Mean token count across all indexed documents.
    avg_doc_length: f32,
    /// BM25 term-frequency saturation parameter.
    k1: f32,
    /// BM25 document-length normalization parameter.
    b: f32,
}
182
183impl Bm25Index {
    /// Creates an empty index with the standard Okapi BM25 parameters
    /// (k1 = 1.2, b = 0.75).
    pub fn new() -> Self {
        Self {
            k1: 1.2,
            b: 0.75,
            ..Default::default()
        }
    }
192
    /// Creates an empty index with custom BM25 parameters.
    ///
    /// `k1` controls term-frequency saturation; `b` controls document-length
    /// normalization.
    pub fn with_params(k1: f32, b: f32) -> Self {
        Self {
            k1,
            b,
            ..Default::default()
        }
    }
201
202 fn tokenize(text: &str) -> Vec<String> {
204 text.to_lowercase()
205 .split(|c: char| !c.is_alphanumeric())
206 .filter(|s| !s.is_empty() && s.len() > 1)
207 .map(String::from)
208 .collect()
209 }
210
211 pub fn add_document(&mut self, id: &str, content: &str) {
213 let tokens = Self::tokenize(content);
214
215 let unique_terms: HashSet<_> = tokens.iter().cloned().collect();
217 for term in &unique_terms {
218 *self.document_frequencies.entry(term.clone()).or_insert(0) += 1;
219 self.inverted_index
220 .entry(term.clone())
221 .or_default()
222 .insert(id.to_string());
223 }
224
225 self.documents.insert(id.to_string(), tokens);
227 self.doc_count += 1;
228
229 let total_tokens: usize = self.documents.values().map(|v| v.len()).sum();
231 self.avg_doc_length = total_tokens as f32 / self.doc_count as f32;
232 }
233
    /// Removes a document and updates all derived statistics.
    ///
    /// No-op if `id` was never indexed. Document frequencies and posting
    /// lists are decremented and pruned per unique term; the average
    /// document length is recomputed over the remaining documents.
    pub fn remove_document(&mut self, id: &str) {
        if let Some(tokens) = self.documents.remove(id) {
            // Each unique term contributed exactly 1 to its document
            // frequency at add time, so subtract 1 per unique term here.
            let unique_terms: HashSet<_> = tokens.into_iter().collect();
            for term in unique_terms {
                if let Some(df) = self.document_frequencies.get_mut(&term) {
                    *df = df.saturating_sub(1);
                    if *df == 0 {
                        self.document_frequencies.remove(&term);
                    }
                }
                if let Some(docs) = self.inverted_index.get_mut(&term) {
                    docs.remove(id);
                    if docs.is_empty() {
                        self.inverted_index.remove(&term);
                    }
                }
            }
            self.doc_count = self.doc_count.saturating_sub(1);

            // Recompute the running average; reset to 0 when the index
            // becomes empty to avoid a division by zero.
            if self.doc_count > 0 {
                let total_tokens: usize = self.documents.values().map(|v| v.len()).sum();
                self.avg_doc_length = total_tokens as f32 / self.doc_count as f32;
            } else {
                self.avg_doc_length = 0.0;
            }
        }
    }
263
264 fn idf(&self, term: &str) -> f32 {
266 let df = self.document_frequencies.get(term).copied().unwrap_or(0) as f32;
267 let n = self.doc_count as f32;
268 if df == 0.0 || n == 0.0 {
269 return 0.0;
270 }
271 ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
272 }
273
    /// Computes the Okapi BM25 score of one document for the query terms.
    ///
    /// Returns 0.0 for unknown document ids.
    /// NOTE(review): if `avg_doc_length` is 0 while documents exist (all
    /// documents tokenized to nothing), the length-normalization term
    /// becomes NaN — assumed unreachable since such documents produce no
    /// posting-list candidates; confirm if tokenization rules change.
    fn score_document(&self, doc_id: &str, query_terms: &[String]) -> f32 {
        let doc_tokens = match self.documents.get(doc_id) {
            Some(tokens) => tokens,
            None => return 0.0,
        };

        let doc_len = doc_tokens.len() as f32;
        let mut score = 0.0;

        // Term frequencies within this single document.
        let mut term_freq: HashMap<&str, usize> = HashMap::new();
        for token in doc_tokens {
            *term_freq.entry(token.as_str()).or_insert(0) += 1;
        }

        for term in query_terms {
            let tf = term_freq.get(term.as_str()).copied().unwrap_or(0) as f32;
            let idf = self.idf(term);

            // Okapi BM25: k1 saturates term frequency, b scales the
            // document-length normalization.
            let numerator = tf * (self.k1 + 1.0);
            let denominator =
                tf + self.k1 * (1.0 - self.b + self.b * doc_len / self.avg_doc_length);
            score += idf * numerator / denominator;
        }

        score
    }
303
304 pub fn search(&self, query: &str, top_k: usize) -> Vec<(String, f32)> {
306 let query_terms = Self::tokenize(query);
307 if query_terms.is_empty() {
308 return Vec::new();
309 }
310
311 let mut candidates: HashSet<String> = HashSet::new();
313 for term in &query_terms {
314 if let Some(docs) = self.inverted_index.get(term) {
315 candidates.extend(docs.iter().cloned());
316 }
317 }
318
319 let mut results: Vec<(String, f32)> = candidates
321 .iter()
322 .map(|id| {
323 let score = self.score_document(id, &query_terms);
324 (id.clone(), score)
325 })
326 .filter(|(_, score)| *score > 0.0)
327 .collect();
328
329 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
331
332 results.truncate(top_k);
334 results
335 }
336
    /// Number of indexed documents.
    pub fn len(&self) -> usize {
        self.doc_count
    }

    /// True when no documents are indexed.
    pub fn is_empty(&self) -> bool {
        self.doc_count == 0
    }

    /// Resets the index to the empty state. The BM25 parameters `k1` and
    /// `b` are intentionally left unchanged.
    pub fn clear(&mut self) {
        self.documents.clear();
        self.inverted_index.clear();
        self.document_frequencies.clear();
        self.doc_count = 0;
        self.avg_doc_length = 0.0;
    }
355
    /// Serializes the whole index (including parameters) as JSON to `path`.
    ///
    /// # Errors
    /// Returns `AppError::Internal` on serialization or file-write failure.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let json = serde_json::to_string(self)
            .map_err(|e| AppError::Internal(format!("Failed to serialize BM25 index: {}", e)))?;
        std::fs::write(path, json)
            .map_err(|e| AppError::Internal(format!("Failed to write BM25 index file: {}", e)))?;
        Ok(())
    }

    /// Loads an index previously written by [`Bm25Index::save`].
    ///
    /// # Errors
    /// Returns `AppError::Internal` on read or deserialization failure.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let json = std::fs::read_to_string(path)
            .map_err(|e| AppError::Internal(format!("Failed to read BM25 index file: {}", e)))?;
        let index: Self = serde_json::from_str(&json)
            .map_err(|e| AppError::Internal(format!("Failed to deserialize BM25 index: {}", e)))?;
        Ok(index)
    }

    /// Loads from `path` if it exists, otherwise returns a fresh index.
    /// NOTE(review): a corrupt file is silently replaced by an empty index;
    /// consider surfacing/logging the load error.
    pub fn load_or_new<P: AsRef<Path>>(path: P) -> Self {
        if path.as_ref().exists() {
            Self::load(path).unwrap_or_else(|_| Self::new())
        } else {
            Self::new()
        }
    }
390}
391
/// Edit-distance (Levenshtein) search index with typo-correction support.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FuzzyIndex {
    /// Lowercased full text per document id.
    documents: HashMap<String, String>,
    /// Every token ever indexed; used as the candidate pool for typo
    /// correction. NOTE(review): never pruned on document removal — see
    /// `remove_document`.
    vocabulary: HashSet<String>,
    /// Maximum Levenshtein distance still considered a match.
    max_distance: usize,
}
408
impl Default for FuzzyIndex {
    /// Empty index tolerating up to 2 edits per word.
    fn default() -> Self {
        Self {
            documents: HashMap::new(),
            vocabulary: HashSet::new(),
            max_distance: 2,
        }
    }
}
418
419impl FuzzyIndex {
    /// Creates an empty index with the default `max_distance` of 2.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty index with a custom edit-distance tolerance.
    pub fn with_max_distance(max_distance: usize) -> Self {
        Self {
            max_distance,
            ..Default::default()
        }
    }
432
433 fn tokenize(text: &str) -> Vec<String> {
435 text.to_lowercase()
436 .split(|c: char| !c.is_alphanumeric())
437 .filter(|s| !s.is_empty() && s.len() > 1)
438 .map(String::from)
439 .collect()
440 }
441
442 pub fn add_document(&mut self, id: &str, content: &str) {
444 let lower_content = content.to_lowercase();
445
446 for word in Self::tokenize(&lower_content) {
448 self.vocabulary.insert(word);
449 }
450
451 self.documents.insert(id.to_string(), lower_content);
452 }
453
    /// Removes a document's text from the index.
    ///
    /// NOTE(review): the vocabulary is not pruned here, so words from
    /// removed documents can still be offered as corrections until `clear`
    /// is called — confirm whether that is intended.
    pub fn remove_document(&mut self, id: &str) {
        self.documents.remove(id);
    }
460
461 fn levenshtein_distance(s1: &str, s2: &str) -> usize {
463 let len1 = s1.chars().count();
464 let len2 = s2.chars().count();
465
466 if len1 == 0 {
467 return len2;
468 }
469 if len2 == 0 {
470 return len1;
471 }
472
473 let s1_chars: Vec<char> = s1.chars().collect();
474 let s2_chars: Vec<char> = s2.chars().collect();
475
476 let mut prev_row: Vec<usize> = (0..=len2).collect();
477 let mut curr_row = vec![0; len2 + 1];
478
479 for (i, c1) in s1_chars.iter().enumerate() {
480 curr_row[0] = i + 1;
481
482 for (j, c2) in s2_chars.iter().enumerate() {
483 let cost = if c1 == c2 { 0 } else { 1 };
484 curr_row[j + 1] = (prev_row[j + 1] + 1)
485 .min(curr_row[j] + 1)
486 .min(prev_row[j] + cost);
487 }
488
489 std::mem::swap(&mut prev_row, &mut curr_row);
490 }
491
492 prev_row[len2]
493 }
494
495 pub fn correct_word(&self, word: &str) -> Option<(String, usize)> {
498 let word_lower = word.to_lowercase();
499
500 if self.vocabulary.contains(&word_lower) {
502 return Some((word_lower, 0));
503 }
504
505 let mut best_match: Option<(String, usize)> = None;
507
508 for vocab_word in &self.vocabulary {
509 let len_diff = (word_lower.len() as isize - vocab_word.len() as isize).unsigned_abs();
511 if len_diff > self.max_distance {
512 continue;
513 }
514
515 let distance = Self::levenshtein_distance(&word_lower, vocab_word);
516 if distance <= self.max_distance {
517 match &best_match {
518 None => best_match = Some((vocab_word.clone(), distance)),
519 Some((_, best_dist)) if distance < *best_dist => {
520 best_match = Some((vocab_word.clone(), distance));
521 }
522 _ => {}
523 }
524 }
525 }
526
527 best_match
528 }
529
530 pub fn correct_query(&self, query: &str) -> (String, Vec<QueryCorrection>) {
533 let words = Self::tokenize(query);
534 let mut corrected_words = Vec::with_capacity(words.len());
535 let mut corrections = Vec::new();
536
537 for word in &words {
538 if let Some((corrected, distance)) = self.correct_word(word) {
539 if distance > 0 {
540 corrections.push(QueryCorrection {
541 original: word.clone(),
542 corrected: corrected.clone(),
543 distance,
544 });
545 }
546 corrected_words.push(corrected);
547 } else {
548 corrected_words.push(word.clone());
550 }
551 }
552
553 (corrected_words.join(" "), corrections)
554 }
555
556 fn fuzzy_score(query: &str, text: &str, max_distance: usize) -> f32 {
558 let query_lower = query.to_lowercase();
559 let query_words: Vec<&str> = query_lower.split_whitespace().collect();
560
561 let mut total_score = 0.0;
563 let mut matched_words = 0;
564
565 for query_word in &query_words {
566 let mut best_score = 0.0f32;
567
568 for text_word in text.split_whitespace() {
569 if text_word.len() < 2 {
570 continue;
571 }
572
573 let distance = Self::levenshtein_distance(query_word, text_word);
574 if distance <= max_distance {
575 let max_len = query_word.len().max(text_word.len());
576 let score = 1.0 - (distance as f32 / max_len as f32);
577 best_score = best_score.max(score);
578 }
579 }
580
581 if best_score > 0.0 {
582 total_score += best_score;
583 matched_words += 1;
584 }
585 }
586
587 if matched_words > 0 {
588 (total_score / query_words.len() as f32)
589 * (matched_words as f32 / query_words.len() as f32)
590 } else {
591 0.0
592 }
593 }
594
595 pub fn search(&self, query: &str, top_k: usize) -> Vec<(String, f32)> {
597 let mut results: Vec<(String, f32)> = self
598 .documents
599 .iter()
600 .filter_map(|(id, content)| {
601 let score = Self::fuzzy_score(query, content, self.max_distance);
602 if score > 0.0 {
603 Some((id.clone(), score))
604 } else {
605 None
606 }
607 })
608 .collect();
609
610 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
612
613 results.truncate(top_k);
615 results
616 }
617
    /// Number of indexed documents.
    pub fn len(&self) -> usize {
        self.documents.len()
    }

    /// True when no documents are indexed.
    pub fn is_empty(&self) -> bool {
        self.documents.is_empty()
    }

    /// Removes all documents and the accumulated vocabulary.
    pub fn clear(&mut self) {
        self.documents.clear();
        self.vocabulary.clear();
    }

    /// Number of distinct tokens available for typo correction.
    pub fn vocabulary_size(&self) -> usize {
        self.vocabulary.len()
    }
638
    /// Serializes the index (documents + vocabulary) as JSON to `path`.
    ///
    /// # Errors
    /// Returns `AppError::Internal` on serialization or file-write failure.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let json = serde_json::to_string(self)
            .map_err(|e| AppError::Internal(format!("Failed to serialize fuzzy index: {}", e)))?;
        std::fs::write(path, json)
            .map_err(|e| AppError::Internal(format!("Failed to write fuzzy index file: {}", e)))?;
        Ok(())
    }

    /// Loads an index previously written by [`FuzzyIndex::save`].
    ///
    /// # Errors
    /// Returns `AppError::Internal` on read or deserialization failure.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let json = std::fs::read_to_string(path)
            .map_err(|e| AppError::Internal(format!("Failed to read fuzzy index file: {}", e)))?;
        let index: Self = serde_json::from_str(&json)
            .map_err(|e| AppError::Internal(format!("Failed to deserialize fuzzy index: {}", e)))?;
        Ok(index)
    }

    /// Loads from `path` if it exists, otherwise returns a fresh index.
    /// NOTE(review): a corrupt file is silently replaced by an empty index.
    pub fn load_or_new<P: AsRef<Path>>(path: P) -> Self {
        if path.as_ref().exists() {
            Self::load(path).unwrap_or_else(|_| Self::new())
        } else {
            Self::new()
        }
    }
673}
674
/// Reciprocal Rank Fusion (RRF) combiner for multiple ranked result lists.
#[derive(Debug, Clone)]
pub struct RrfFusion {
    /// RRF smoothing constant; larger values flatten rank differences.
    k: f32,
}
685
impl Default for RrfFusion {
    /// k = 60, a commonly used RRF constant.
    fn default() -> Self {
        Self { k: 60.0 }
    }
}
691
692impl RrfFusion {
693 pub fn new() -> Self {
695 Self::default()
696 }
697
698 pub fn with_k(k: f32) -> Self {
700 Self { k }
701 }
702
703 pub fn fuse(&self, ranked_lists: &[(&[(String, f32)], f32)]) -> Vec<(String, f32)> {
707 let mut fused_scores: HashMap<String, f32> = HashMap::new();
708
709 for (results, weight) in ranked_lists {
710 for (rank, (doc_id, _score)) in results.iter().enumerate() {
711 let rrf_score = weight / (self.k + rank as f32 + 1.0);
713 *fused_scores.entry(doc_id.clone()).or_insert(0.0) += rrf_score;
714 }
715 }
716
717 let mut results: Vec<_> = fused_scores.into_iter().collect();
719 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
720 results
721 }
722}
723
724#[derive(Debug, Default)]
734pub struct SearchEngine {
735 pub bm25: Bm25Index,
737 pub fuzzy: FuzzyIndex,
739 pub rrf: RrfFusion,
741}
742
743impl SearchEngine {
    /// Creates an empty engine; delegates to `Default`.
    pub fn new() -> Self {
        Self::default()
    }
748
    /// Indexes a document into both lexical indexes.
    pub fn index_document(&mut self, doc: &Document) {
        self.bm25.add_document(&doc.id, &doc.content);
        self.fuzzy.add_document(&doc.id, &doc.content);
    }

    /// Indexes a batch of documents.
    pub fn index_documents(&mut self, docs: &[Document]) {
        for doc in docs {
            self.index_document(doc);
        }
    }

    /// Removes a document from both lexical indexes.
    pub fn remove_document(&mut self, id: &str) {
        self.bm25.remove_document(id);
        self.fuzzy.remove_document(id);
    }

    /// BM25 (lexical) search.
    pub fn search_bm25(&self, query: &str, top_k: usize) -> Vec<(String, f32)> {
        self.bm25.search(query, top_k)
    }

    /// Fuzzy (edit-distance) search.
    pub fn search_fuzzy(&self, query: &str, top_k: usize) -> Vec<(String, f32)> {
        self.fuzzy.search(query, top_k)
    }
777
778 pub fn search_hybrid(
782 &self,
783 query: &str,
784 semantic_results: &[(String, f32)],
785 weights: &HybridWeights,
786 top_k: usize,
787 ) -> Vec<(String, f32)> {
788 let bm25_results = self.bm25.search(query, top_k * 2);
789 let fuzzy_results = self.fuzzy.search(query, top_k * 2);
790
791 let ranked_lists: Vec<(&[(String, f32)], f32)> = vec![
792 (semantic_results, weights.semantic),
793 (&bm25_results, weights.bm25),
794 (&fuzzy_results, weights.fuzzy),
795 ];
796
797 let mut fused = self.rrf.fuse(&ranked_lists);
798 fused.truncate(top_k);
799 fused
800 }
801
    /// BM25 search with automatic typo correction applied to the query.
    ///
    /// Returns the results, the corrected query actually searched, and the
    /// corrections that were applied (empty when the query was already
    /// well-spelled).
    pub fn search_bm25_with_correction(
        &self,
        query: &str,
        top_k: usize,
    ) -> (Vec<(String, f32)>, String, Vec<QueryCorrection>) {
        let (corrected_query, corrections) = self.fuzzy.correct_query(query);
        let results = self.bm25.search(&corrected_query, top_k);
        (results, corrected_query, corrections)
    }

    /// Hybrid search with automatic typo correction applied to the query.
    ///
    /// NOTE(review): `semantic_results` are presumably computed by the
    /// caller from the *uncorrected* query — confirm whether they should be
    /// recomputed from the corrected one.
    pub fn search_hybrid_with_correction(
        &self,
        query: &str,
        semantic_results: &[(String, f32)],
        weights: &HybridWeights,
        top_k: usize,
    ) -> (Vec<(String, f32)>, String, Vec<QueryCorrection>) {
        let (corrected_query, corrections) = self.fuzzy.correct_query(query);
        let results = self.search_hybrid(&corrected_query, semantic_results, weights, top_k);
        (results, corrected_query, corrections)
    }
839
    /// Clears both lexical indexes.
    pub fn clear(&mut self) {
        self.bm25.clear();
        self.fuzzy.clear();
    }

    /// Number of indexed documents (taken from the BM25 index; both indexes
    /// are kept in sync by `index_document`/`remove_document`).
    pub fn len(&self) -> usize {
        self.bm25.len()
    }

    /// True when nothing has been indexed.
    pub fn is_empty(&self) -> bool {
        self.bm25.is_empty()
    }
855
    /// Persists both indexes as JSON files inside `dir` (created if needed).
    ///
    /// # Errors
    /// Returns `AppError::Internal` on directory-creation or write failure.
    pub fn save<P: AsRef<Path>>(&self, dir: P) -> Result<()> {
        let dir = dir.as_ref();
        std::fs::create_dir_all(dir).map_err(|e| {
            AppError::Internal(format!("Failed to create search index directory: {}", e))
        })?;

        self.bm25.save(dir.join("bm25_index.json"))?;
        self.fuzzy.save(dir.join("fuzzy_index.json"))?;

        Ok(())
    }

    /// Loads both indexes from `dir`; the RRF combiner is rebuilt with its
    /// default constant (it carries no persistent state).
    ///
    /// # Errors
    /// Returns `AppError::Internal` if either index file fails to load.
    pub fn load<P: AsRef<Path>>(dir: P) -> Result<Self> {
        let dir = dir.as_ref();
        let bm25 = Bm25Index::load(dir.join("bm25_index.json"))?;
        let fuzzy = FuzzyIndex::load(dir.join("fuzzy_index.json"))?;

        Ok(Self {
            bm25,
            fuzzy,
            rrf: RrfFusion::default(),
        })
    }

    /// Loads from `dir` when present, otherwise starts empty.
    /// NOTE(review): load failures are silently discarded here.
    pub fn load_or_new<P: AsRef<Path>>(dir: P) -> Self {
        let dir = dir.as_ref();
        if dir.exists() {
            Self::load(dir).unwrap_or_else(|_| Self::new())
        } else {
            Self::new()
        }
    }
903}
904
905#[cfg(test)]
910mod tests {
911 use super::*;
912
    /// Every canonical strategy name parses to its matching variant.
    #[test]
    fn test_search_strategy_from_str() {
        assert_eq!(
            "semantic".parse::<SearchStrategy>().unwrap(),
            SearchStrategy::Semantic
        );
        assert_eq!(
            "bm25".parse::<SearchStrategy>().unwrap(),
            SearchStrategy::Bm25
        );
        assert_eq!(
            "fuzzy".parse::<SearchStrategy>().unwrap(),
            SearchStrategy::Fuzzy
        );
        assert_eq!(
            "hybrid".parse::<SearchStrategy>().unwrap(),
            SearchStrategy::Hybrid
        );
    }
932
    /// The document containing all query terms ranks first.
    #[test]
    fn test_bm25_basic() {
        let mut index = Bm25Index::new();
        index.add_document("doc1", "The quick brown fox jumps over the lazy dog");
        index.add_document("doc2", "A fast brown fox leaps over sleeping dogs");
        index.add_document("doc3", "The cat sleeps on the mat");

        let results = index.search("quick brown fox", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "doc1");
    }

    /// Higher term frequency ranks higher for a single-term query.
    #[test]
    fn test_bm25_ranking() {
        let mut index = Bm25Index::new();
        index.add_document("doc1", "apple apple apple");
        index.add_document("doc2", "apple banana");
        index.add_document("doc3", "banana banana banana");

        let results = index.search("apple", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "doc1");
    }

    /// Removing a document drops it from counts and from search results.
    #[test]
    fn test_bm25_remove_document() {
        let mut index = Bm25Index::new();
        index.add_document("doc1", "hello world");
        index.add_document("doc2", "goodbye world");

        assert_eq!(index.len(), 2);

        index.remove_document("doc1");
        assert_eq!(index.len(), 1);

        let results = index.search("hello", 10);
        assert!(results.is_empty());
    }
972
    /// Exact-word queries also match via fuzzy search.
    #[test]
    fn test_fuzzy_exact_match() {
        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "machine learning algorithms");
        index.add_document("doc2", "deep neural networks");

        let results = index.search("machine", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "doc1");
    }

    /// A one-character deletion still matches within max_distance = 2.
    #[test]
    fn test_fuzzy_typo_tolerance() {
        let mut index = FuzzyIndex::with_max_distance(2);
        index.add_document("doc1", "machine learning");
        index.add_document("doc2", "deep learning");

        let results = index.search("machne", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "doc1");
    }

    /// Classic Levenshtein reference values, including empty-string edges.
    #[test]
    fn test_levenshtein_distance() {
        assert_eq!(FuzzyIndex::levenshtein_distance("kitten", "sitting"), 3);
        assert_eq!(FuzzyIndex::levenshtein_distance("hello", "hello"), 0);
        assert_eq!(FuzzyIndex::levenshtein_distance("", "abc"), 3);
        assert_eq!(FuzzyIndex::levenshtein_distance("abc", ""), 3);
    }
1003
    /// Documents ranked highly in both lists should top the fused ranking.
    #[test]
    fn test_rrf_fusion() {
        let rrf = RrfFusion::new();

        let list1 = [
            ("doc1".to_string(), 0.9),
            ("doc2".to_string(), 0.8),
            ("doc3".to_string(), 0.7),
        ];

        let list2 = [
            ("doc2".to_string(), 0.95),
            ("doc1".to_string(), 0.85),
            ("doc4".to_string(), 0.75),
        ];

        let ranked_lists = vec![(&list1[..], 1.0), (&list2[..], 1.0)];
        let fused = rrf.fuse(&ranked_lists);

        assert!(!fused.is_empty());
        // doc1 and doc2 appear in both lists, so they should lead.
        let top_ids: Vec<_> = fused.iter().take(2).map(|(id, _)| id.clone()).collect();
        assert!(top_ids.contains(&"doc1".to_string()));
        assert!(top_ids.contains(&"doc2".to_string()));
    }
1029
    /// End-to-end: index three documents, then query via BM25 and fuzzy.
    #[test]
    fn test_search_engine_integration() {
        let mut engine = SearchEngine::new();

        let docs = vec![
            Document {
                id: "doc1".to_string(),
                content: "Rust programming language is fast and memory safe".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
            Document {
                id: "doc2".to_string(),
                content: "Python is popular for machine learning and data science".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
            Document {
                id: "doc3".to_string(),
                content: "JavaScript runs in web browsers".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
        ];

        engine.index_documents(&docs);
        assert_eq!(engine.len(), 3);

        let bm25_results = engine.search_bm25("Rust programming", 10);
        assert!(!bm25_results.is_empty());
        assert_eq!(bm25_results[0].0, "doc1");

        let fuzzy_results = engine.search_fuzzy("rust", 10);
        assert!(!fuzzy_results.is_empty(), "Fuzzy search should find 'rust'");
    }

    /// Hybrid fusion combines external semantic results with both lexical
    /// strategies and returns a non-empty ranking.
    #[test]
    fn test_hybrid_search() {
        let mut engine = SearchEngine::new();

        let docs = vec![
            Document {
                id: "doc1".to_string(),
                content: "Vector databases enable semantic search".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
            Document {
                id: "doc2".to_string(),
                content: "BM25 is a lexical search algorithm".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
        ];

        engine.index_documents(&docs);

        // Simulated semantic scores, as a vector backend would supply.
        let semantic_results = vec![("doc1".to_string(), 0.95), ("doc2".to_string(), 0.80)];

        let weights = HybridWeights {
            semantic: 0.5,
            bm25: 0.4,
            fuzzy: 0.1,
        };

        let hybrid = engine.search_hybrid("vector search", &semantic_results, &weights, 10);
        assert!(!hybrid.is_empty());
    }

    /// Default weights are 0.6 / 0.3 / 0.1.
    #[test]
    fn test_hybrid_weights_default() {
        let weights = HybridWeights::default();
        assert!((weights.semantic - 0.6).abs() < 0.001);
        assert!((weights.bm25 - 0.3).abs() < 0.001);
        assert!((weights.fuzzy - 0.1).abs() < 0.001);
    }
1110
    /// A word already in the vocabulary is returned unchanged at distance 0.
    #[test]
    fn test_correct_word_exact_match() {
        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "programming language");

        let result = index.correct_word("programming");
        assert!(result.is_some());
        let (corrected, distance) = result.unwrap();
        assert_eq!(corrected, "programming");
        assert_eq!(distance, 0);
    }

    /// A single-character typo is corrected at distance 1.
    #[test]
    fn test_correct_word_with_typo() {
        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "programming language");

        let result = index.correct_word("progamming");
        assert!(result.is_some());
        let (corrected, distance) = result.unwrap();
        assert_eq!(corrected, "programming");
        assert_eq!(distance, 1);
    }

    /// Words too far from everything in the vocabulary yield no correction.
    #[test]
    fn test_correct_word_no_match() {
        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "programming language");

        let result = index.correct_word("xyz");
        assert!(result.is_none());
    }

    /// Query-level correction fixes the typo and reports it.
    #[test]
    fn test_correct_query_single_typo() {
        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "rust programming language");

        let (corrected, corrections) = index.correct_query("progamming");
        assert_eq!(corrected, "programming");
        assert_eq!(corrections.len(), 1);
        assert_eq!(corrections[0].original, "progamming");
        assert_eq!(corrections[0].corrected, "programming");
        assert_eq!(corrections[0].distance, 1);
    }

    /// Multiple typos in one query are each corrected.
    #[test]
    fn test_correct_query_multiple_typos() {
        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "rust programming language");

        let (corrected, corrections) = index.correct_query("progamming languge");
        assert_eq!(corrected, "programming language");
        assert_eq!(corrections.len(), 2);
    }

    /// A clean query passes through with no recorded corrections.
    #[test]
    fn test_correct_query_no_typos() {
        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "rust programming language");

        let (corrected, corrections) = index.correct_query("programming language");
        assert_eq!(corrected, "programming language");
        assert!(corrections.is_empty());
    }

    /// The corrected query is what actually gets searched.
    #[test]
    fn test_search_bm25_with_correction() {
        let mut engine = SearchEngine::new();

        let docs = vec![
            Document {
                id: "doc1".to_string(),
                content: "Rust is a systems programming language".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
            Document {
                id: "doc2".to_string(),
                content: "Python is popular for scripting".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
        ];

        engine.index_documents(&docs);

        let (results, corrected_query, corrections) =
            engine.search_bm25_with_correction("progamming", 10);

        assert!(!results.is_empty());
        assert_eq!(results[0].0, "doc1");
        assert_eq!(corrected_query, "programming");
        assert_eq!(corrections.len(), 1);
        assert_eq!(corrections[0].original, "progamming");
        assert_eq!(corrections[0].corrected, "programming");
    }

    /// `clear` empties the vocabulary along with the documents.
    #[test]
    fn test_vocabulary_cleared() {
        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "programming language");

        assert!(index.vocabulary_size() > 0);

        index.clear();

        assert_eq!(index.vocabulary_size(), 0);
        assert!(index.is_empty());
    }

    /// Correction is case-insensitive on both the query and indexed sides.
    #[test]
    fn test_typo_correction_case_insensitive() {
        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "Programming Language");

        let result = index.correct_word("PROGAMMING");
        assert!(result.is_some());
        let (corrected, _) = result.unwrap();
        assert_eq!(corrected, "programming");
    }
1244
    /// Round-trips a BM25 index through JSON on disk.
    #[test]
    fn test_bm25_save_load() {
        let temp_dir = std::env::temp_dir().join("ares_test_bm25");
        let _ = std::fs::remove_dir_all(&temp_dir);
        std::fs::create_dir_all(&temp_dir).unwrap();
        let path = temp_dir.join("bm25_index.json");

        let mut index = Bm25Index::new();
        index.add_document("doc1", "The quick brown fox");
        index.add_document("doc2", "A lazy dog sleeps");
        assert_eq!(index.len(), 2);

        index.save(&path).unwrap();

        let loaded = Bm25Index::load(&path).unwrap();
        assert_eq!(loaded.len(), 2);

        let results = loaded.search("quick brown", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "doc1");

        let _ = std::fs::remove_dir_all(&temp_dir);
    }

    /// Round-trips a fuzzy index (documents and vocabulary) through JSON.
    #[test]
    fn test_fuzzy_save_load() {
        let temp_dir = std::env::temp_dir().join("ares_test_fuzzy");
        let _ = std::fs::remove_dir_all(&temp_dir);
        std::fs::create_dir_all(&temp_dir).unwrap();
        let path = temp_dir.join("fuzzy_index.json");

        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "machine learning algorithms");
        index.add_document("doc2", "deep neural networks");
        assert_eq!(index.len(), 2);

        index.save(&path).unwrap();

        let loaded = FuzzyIndex::load(&path).unwrap();
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded.vocabulary_size(), index.vocabulary_size());

        let results = loaded.search("machine", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "doc1");

        let _ = std::fs::remove_dir_all(&temp_dir);
    }
1307
    /// Round-trips the whole engine (both indexes) through a directory.
    #[test]
    fn test_search_engine_save_load() {
        let temp_dir = std::env::temp_dir().join("ares_test_engine");
        let _ = std::fs::remove_dir_all(&temp_dir);

        let mut engine = SearchEngine::new();
        let docs = vec![
            Document {
                id: "doc1".to_string(),
                content: "Rust programming language".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
            Document {
                id: "doc2".to_string(),
                content: "Python scripting language".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
        ];
        engine.index_documents(&docs);
        assert_eq!(engine.len(), 2);

        engine.save(&temp_dir).unwrap();

        let loaded = SearchEngine::load(&temp_dir).unwrap();
        assert_eq!(loaded.len(), 2);

        let bm25_results = loaded.search_bm25("Rust programming", 10);
        assert!(!bm25_results.is_empty());
        assert_eq!(bm25_results[0].0, "doc1");

        let fuzzy_results = loaded.search_fuzzy("rust", 10);
        assert!(!fuzzy_results.is_empty());

        let _ = std::fs::remove_dir_all(&temp_dir);
    }

    /// `load_or_new` falls back to an empty index when the file is absent.
    #[test]
    fn test_load_or_new_missing_file() {
        let path = std::env::temp_dir().join("nonexistent_bm25_index.json");
        let _ = std::fs::remove_file(&path);
        let index = Bm25Index::load_or_new(&path);
        assert!(index.is_empty());
    }

    /// `load_or_new` on the engine also tolerates a missing directory.
    #[test]
    fn test_search_engine_load_or_new() {
        let temp_dir = std::env::temp_dir().join("ares_test_load_or_new");
        let _ = std::fs::remove_dir_all(&temp_dir);
        let engine = SearchEngine::load_or_new(&temp_dir);
        assert!(engine.is_empty());
    }
1371}