use std::collections::{HashMap, HashSet};
use std::path::Path;
use std::str::FromStr;
use serde::{Deserialize, Serialize};
use crate::types::{AppError, Document, Result};
/// Retrieval strategy selected by a `SearchRequest`.
///
/// Serialized in kebab-case (e.g. `"bm25"`); defaults to `Semantic`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum SearchStrategy {
    /// Dense-vector (embedding) similarity search.
    #[default]
    Semantic,
    /// Lexical ranking via the BM25 formula.
    Bm25,
    /// Levenshtein-based typo-tolerant matching.
    Fuzzy,
    /// Weighted fusion of the other strategies via reciprocal rank fusion.
    Hybrid,
}
impl FromStr for SearchStrategy {
    type Err = AppError;

    /// Case-insensitive parse that also accepts common aliases
    /// (`dense`, `vector`, `lexical`, `sparse`, `approximate`, `combined`, `rrf`).
    fn from_str(s: &str) -> Result<Self> {
        let normalized = s.to_lowercase();
        let strategy = match normalized.as_str() {
            "semantic" | "dense" | "vector" => Self::Semantic,
            "bm25" | "lexical" | "sparse" => Self::Bm25,
            "fuzzy" | "approximate" => Self::Fuzzy,
            "hybrid" | "combined" | "rrf" => Self::Hybrid,
            _ => {
                return Err(AppError::Internal(format!(
                    "Unknown search strategy: {}. Use: semantic, bm25, fuzzy, hybrid",
                    s
                )))
            }
        };
        Ok(strategy)
    }
}
impl std::fmt::Display for SearchStrategy {
    /// Writes the canonical lowercase name (the primary `FromStr` spelling).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Semantic => "semantic",
            Self::Bm25 => "bm25",
            Self::Fuzzy => "fuzzy",
            Self::Hybrid => "hybrid",
        })
    }
}
/// A single scored hit returned to the caller.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Document identifier.
    pub id: String,
    /// Raw document text.
    pub content: String,
    /// Relevance score (scale depends on the strategy that produced it).
    pub score: f32,
    /// Strategies that contributed this hit (useful for hybrid results).
    pub sources: Vec<SearchStrategy>,
    /// Optional caller-defined metadata carried alongside the document.
    pub metadata: Option<serde_json::Value>,
}
/// Record of a spelling fix applied to one query term.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryCorrection {
    /// Term as the user typed it.
    pub original: String,
    /// Vocabulary word it was corrected to.
    pub corrected: String,
    /// Levenshtein distance between the two.
    pub distance: usize,
}
/// Parameters for a search call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchRequest {
    /// Free-text query.
    pub query: String,
    /// Retrieval strategy; defaults to `SearchStrategy::Semantic`.
    #[serde(default)]
    pub strategy: SearchStrategy,
    /// Maximum number of results to return (default 10).
    #[serde(default = "default_top_k")]
    pub top_k: usize,
    /// Results scoring below this threshold are dropped (default 0.0).
    #[serde(default)]
    pub min_score: f32,
    /// Whether to apply a reranking pass (default false).
    #[serde(default)]
    pub rerank: bool,
    /// Name of the collection to search.
    pub collection: String,
    /// Per-strategy weights used when `strategy` is `Hybrid`.
    #[serde(default)]
    pub hybrid_weights: HybridWeights,
}
/// Serde default for `SearchRequest::top_k`.
fn default_top_k() -> usize {
    10
}
/// Relative weights for each sub-strategy when fusing hybrid results.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct HybridWeights {
    /// Weight of the dense/semantic ranked list.
    pub semantic: f32,
    /// Weight of the BM25 ranked list.
    pub bm25: f32,
    /// Weight of the fuzzy ranked list.
    pub fuzzy: f32,
}
impl Default for HybridWeights {
    /// Defaults favour the semantic list (0.6 / 0.3 / 0.1).
    fn default() -> Self {
        let (semantic, bm25, fuzzy) = (0.6, 0.3, 0.1);
        Self {
            semantic,
            bm25,
            fuzzy,
        }
    }
}
/// In-memory BM25 inverted index, serializable to JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Bm25Index {
    /// Per-document token lists (doc id -> tokens).
    documents: HashMap<String, Vec<String>>,
    /// Term -> ids of the documents containing it.
    inverted_index: HashMap<String, HashSet<String>>,
    /// Term -> number of documents containing it (document frequency).
    document_frequencies: HashMap<String, usize>,
    /// Number of indexed documents.
    doc_count: usize,
    /// Mean document length in tokens (0.0 when empty).
    avg_doc_length: f32,
    /// BM25 term-frequency saturation parameter.
    k1: f32,
    /// BM25 length-normalization parameter.
    b: f32,
}
impl Default for Bm25Index {
    /// An empty index with the standard BM25 parameters.
    ///
    /// This is a hand-written impl because `#[derive(Default)]` produced
    /// `k1 = 0.0, b = 0.0`, which degenerates BM25 to pure IDF (term
    /// frequency and length normalization are ignored) for any index built
    /// through `Default` rather than `Bm25Index::new()`.
    fn default() -> Self {
        Self {
            documents: HashMap::new(),
            inverted_index: HashMap::new(),
            document_frequencies: HashMap::new(),
            doc_count: 0,
            avg_doc_length: 0.0,
            k1: 1.2,
            b: 0.75,
        }
    }
}
impl Bm25Index {
    /// Creates an empty index with standard parameters (k1 = 1.2, b = 0.75).
    pub fn new() -> Self {
        Self {
            k1: 1.2,
            b: 0.75,
            ..Default::default()
        }
    }

    /// Creates an empty index with custom BM25 parameters.
    pub fn with_params(k1: f32, b: f32) -> Self {
        Self {
            k1,
            b,
            ..Default::default()
        }
    }

    /// Lowercases and splits on non-alphanumeric characters, dropping tokens
    /// shorter than two bytes (i.e. single ASCII characters).
    fn tokenize(text: &str) -> Vec<String> {
        text.to_lowercase()
            .split(|c: char| !c.is_alphanumeric())
            .filter(|s| s.len() > 1)
            .map(String::from)
            .collect()
    }

    /// Indexes `content` under `id`, replacing any existing document with the
    /// same id.
    ///
    /// Previously, re-adding an id inflated `doc_count` and the document
    /// frequencies (the stale posting data was never removed), corrupting the
    /// IDF statistics; removing the old entry first makes this an upsert.
    pub fn add_document(&mut self, id: &str, content: &str) {
        if self.documents.contains_key(id) {
            self.remove_document(id);
        }
        let tokens = Self::tokenize(content);
        let unique_terms: HashSet<_> = tokens.iter().cloned().collect();
        for term in unique_terms {
            *self.document_frequencies.entry(term.clone()).or_insert(0) += 1;
            self.inverted_index
                .entry(term)
                .or_default()
                .insert(id.to_string());
        }
        self.documents.insert(id.to_string(), tokens);
        self.doc_count += 1;
        self.recompute_avg_length();
    }

    /// Removes `id` and updates term statistics; no-op if the id is absent.
    pub fn remove_document(&mut self, id: &str) {
        if let Some(tokens) = self.documents.remove(id) {
            let unique_terms: HashSet<_> = tokens.into_iter().collect();
            for term in unique_terms {
                if let Some(df) = self.document_frequencies.get_mut(&term) {
                    *df = df.saturating_sub(1);
                    if *df == 0 {
                        self.document_frequencies.remove(&term);
                    }
                }
                if let Some(docs) = self.inverted_index.get_mut(&term) {
                    docs.remove(id);
                    if docs.is_empty() {
                        self.inverted_index.remove(&term);
                    }
                }
            }
            self.doc_count = self.doc_count.saturating_sub(1);
            self.recompute_avg_length();
        }
    }

    /// Refreshes `avg_doc_length` from the stored documents (0.0 when empty).
    fn recompute_avg_length(&mut self) {
        if self.doc_count == 0 {
            self.avg_doc_length = 0.0;
        } else {
            let total_tokens: usize = self.documents.values().map(|v| v.len()).sum();
            self.avg_doc_length = total_tokens as f32 / self.doc_count as f32;
        }
    }

    /// Smoothed inverse document frequency: `ln((n - df + 0.5)/(df + 0.5) + 1)`.
    /// The `+ 1.0` keeps the value positive even for very common terms.
    /// Returns 0.0 for unseen terms or an empty index.
    fn idf(&self, term: &str) -> f32 {
        let df = self.document_frequencies.get(term).copied().unwrap_or(0) as f32;
        let n = self.doc_count as f32;
        if df == 0.0 || n == 0.0 {
            return 0.0;
        }
        ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
    }

    /// BM25 score of one document against the (already tokenized) query.
    /// Returns 0.0 for unknown documents.
    fn score_document(&self, doc_id: &str, query_terms: &[String]) -> f32 {
        let doc_tokens = match self.documents.get(doc_id) {
            Some(tokens) => tokens,
            None => return 0.0,
        };
        let doc_len = doc_tokens.len() as f32;
        // Term frequencies within this document.
        let mut term_freq: HashMap<&str, usize> = HashMap::new();
        for token in doc_tokens {
            *term_freq.entry(token.as_str()).or_insert(0) += 1;
        }
        let mut score = 0.0;
        for term in query_terms {
            let tf = term_freq.get(term.as_str()).copied().unwrap_or(0) as f32;
            let idf = self.idf(term);
            let numerator = tf * (self.k1 + 1.0);
            let denominator =
                tf + self.k1 * (1.0 - self.b + self.b * doc_len / self.avg_doc_length);
            score += idf * numerator / denominator;
        }
        score
    }

    /// Returns up to `top_k` `(doc_id, score)` pairs, best first. Only
    /// documents sharing at least one term with the query are scored.
    pub fn search(&self, query: &str, top_k: usize) -> Vec<(String, f32)> {
        let query_terms = Self::tokenize(query);
        if query_terms.is_empty() {
            return Vec::new();
        }
        // Candidate set: union of the posting lists of the query terms.
        // Borrow the ids here; only survivors of the score filter are cloned.
        let mut candidates: HashSet<&String> = HashSet::new();
        for term in &query_terms {
            if let Some(docs) = self.inverted_index.get(term) {
                candidates.extend(docs.iter());
            }
        }
        let mut results: Vec<(String, f32)> = candidates
            .into_iter()
            .map(|id| (id.clone(), self.score_document(id, &query_terms)))
            .filter(|(_, score)| *score > 0.0)
            .collect();
        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        results.truncate(top_k);
        results
    }

    /// Number of indexed documents.
    pub fn len(&self) -> usize {
        self.doc_count
    }

    /// True when no documents are indexed.
    pub fn is_empty(&self) -> bool {
        self.doc_count == 0
    }

    /// Removes all documents and statistics.
    pub fn clear(&mut self) {
        self.documents.clear();
        self.inverted_index.clear();
        self.document_frequencies.clear();
        self.doc_count = 0;
        self.avg_doc_length = 0.0;
    }

    /// Serializes the index as JSON to `path`.
    ///
    /// # Errors
    /// Returns `AppError::Internal` on serialization or I/O failure.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let json = serde_json::to_string(self)
            .map_err(|e| AppError::Internal(format!("Failed to serialize BM25 index: {}", e)))?;
        std::fs::write(path, json)
            .map_err(|e| AppError::Internal(format!("Failed to write BM25 index file: {}", e)))?;
        Ok(())
    }

    /// Loads an index previously written by [`Bm25Index::save`].
    ///
    /// # Errors
    /// Returns `AppError::Internal` on read or deserialization failure.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let json = std::fs::read_to_string(path)
            .map_err(|e| AppError::Internal(format!("Failed to read BM25 index file: {}", e)))?;
        let index: Self = serde_json::from_str(&json)
            .map_err(|e| AppError::Internal(format!("Failed to deserialize BM25 index: {}", e)))?;
        Ok(index)
    }

    /// Loads from `path` when it exists; any failure falls back to a fresh
    /// index (load errors are deliberately swallowed as best-effort caching).
    pub fn load_or_new<P: AsRef<Path>>(path: P) -> Self {
        if path.as_ref().exists() {
            Self::load(path).unwrap_or_else(|_| Self::new())
        } else {
            Self::new()
        }
    }
}
/// Typo-tolerant index based on Levenshtein distance, serializable to JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FuzzyIndex {
    /// Doc id -> lowercased full content.
    documents: HashMap<String, String>,
    /// All distinct tokens ever indexed, used for query spelling correction.
    /// NOTE(review): entries are not pruned by `remove_document` — the
    /// vocabulary only shrinks on `clear()`; confirm this is intentional.
    vocabulary: HashSet<String>,
    /// Maximum edit distance accepted for a match or correction.
    max_distance: usize,
}
impl Default for FuzzyIndex {
fn default() -> Self {
Self {
documents: HashMap::new(),
vocabulary: HashSet::new(),
max_distance: 2,
}
}
}
impl FuzzyIndex {
    /// Creates an empty index with the default maximum edit distance (2).
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty index tolerating up to `max_distance` edits.
    pub fn with_max_distance(max_distance: usize) -> Self {
        Self {
            max_distance,
            ..Default::default()
        }
    }

    /// Lowercases and splits on non-alphanumeric characters, dropping tokens
    /// shorter than two bytes (i.e. single ASCII characters).
    fn tokenize(text: &str) -> Vec<String> {
        text.to_lowercase()
            .split(|c: char| !c.is_alphanumeric())
            .filter(|s| s.len() > 1)
            .map(String::from)
            .collect()
    }

    /// Stores `content` (lowercased) under `id` and adds its tokens to the
    /// correction vocabulary.
    pub fn add_document(&mut self, id: &str, content: &str) {
        let lower_content = content.to_lowercase();
        self.vocabulary.extend(Self::tokenize(&lower_content));
        self.documents.insert(id.to_string(), lower_content);
    }

    /// Removes the document. Its words intentionally stay in the vocabulary:
    /// pruning would require per-word reference counts.
    pub fn remove_document(&mut self, id: &str) {
        self.documents.remove(id);
    }

    /// Two-row dynamic-programming Levenshtein distance over Unicode scalar
    /// values; insertions, deletions and substitutions all cost 1.
    fn levenshtein_distance(s1: &str, s2: &str) -> usize {
        let s1_chars: Vec<char> = s1.chars().collect();
        let s2_chars: Vec<char> = s2.chars().collect();
        if s1_chars.is_empty() {
            return s2_chars.len();
        }
        if s2_chars.is_empty() {
            return s1_chars.len();
        }
        let len2 = s2_chars.len();
        let mut prev_row: Vec<usize> = (0..=len2).collect();
        let mut curr_row = vec![0; len2 + 1];
        for (i, c1) in s1_chars.iter().enumerate() {
            curr_row[0] = i + 1;
            for (j, c2) in s2_chars.iter().enumerate() {
                let cost = usize::from(c1 != c2);
                curr_row[j + 1] = (prev_row[j + 1] + 1)
                    .min(curr_row[j] + 1)
                    .min(prev_row[j] + cost);
            }
            std::mem::swap(&mut prev_row, &mut curr_row);
        }
        prev_row[len2]
    }

    /// Finds the closest vocabulary word within `max_distance` edits.
    ///
    /// Returns `(word, distance)`; distance 0 means the word is already in
    /// the vocabulary. Two fixes over the previous version: candidate pruning
    /// now compares character counts (the old byte-length comparison could
    /// over-prune multibyte words, since the distance itself is computed over
    /// chars), and equal-distance ties break lexicographically so the result
    /// is deterministic despite `HashSet` iteration order.
    pub fn correct_word(&self, word: &str) -> Option<(String, usize)> {
        let word_lower = word.to_lowercase();
        if self.vocabulary.contains(&word_lower) {
            return Some((word_lower, 0));
        }
        let word_chars = word_lower.chars().count();
        let mut best_match: Option<(String, usize)> = None;
        for vocab_word in &self.vocabulary {
            // The length difference is a lower bound on the edit distance.
            let vocab_chars = vocab_word.chars().count();
            if word_chars.abs_diff(vocab_chars) > self.max_distance {
                continue;
            }
            let distance = Self::levenshtein_distance(&word_lower, vocab_word);
            if distance > self.max_distance {
                continue;
            }
            let better = match &best_match {
                None => true,
                Some((best_word, best_dist)) => {
                    distance < *best_dist
                        || (distance == *best_dist && vocab_word.as_str() < best_word.as_str())
                }
            };
            if better {
                best_match = Some((vocab_word.clone(), distance));
            }
        }
        best_match
    }

    /// Corrects each query token via [`FuzzyIndex::correct_word`].
    ///
    /// Returns the corrected query (tokens joined with single spaces) plus a
    /// record of every non-trivial correction that was applied.
    pub fn correct_query(&self, query: &str) -> (String, Vec<QueryCorrection>) {
        let words = Self::tokenize(query);
        let mut corrected_words = Vec::with_capacity(words.len());
        let mut corrections = Vec::new();
        for word in &words {
            match self.correct_word(word) {
                Some((corrected, distance)) => {
                    if distance > 0 {
                        corrections.push(QueryCorrection {
                            original: word.clone(),
                            corrected: corrected.clone(),
                            distance,
                        });
                    }
                    corrected_words.push(corrected);
                }
                // No close vocabulary match: keep the word as typed.
                None => corrected_words.push(word.clone()),
            }
        }
        (corrected_words.join(" "), corrections)
    }

    /// Scores `text` against `query`: for each query word, take the best
    /// per-word similarity `1 - distance/max_len` within `max_distance`,
    /// average over all query words, and damp by the fraction matched.
    fn fuzzy_score(query: &str, text: &str, max_distance: usize) -> f32 {
        let query_lower = query.to_lowercase();
        let query_words: Vec<&str> = query_lower.split_whitespace().collect();
        let mut total_score = 0.0;
        let mut matched_words = 0usize;
        for query_word in &query_words {
            let mut best_score = 0.0f32;
            for text_word in text.split_whitespace() {
                // Skip one-byte words, mirroring the tokenizer's length filter.
                if text_word.len() < 2 {
                    continue;
                }
                let distance = Self::levenshtein_distance(query_word, text_word);
                if distance <= max_distance {
                    let max_len = query_word.len().max(text_word.len());
                    let score = 1.0 - (distance as f32 / max_len as f32);
                    best_score = best_score.max(score);
                }
            }
            if best_score > 0.0 {
                total_score += best_score;
                matched_words += 1;
            }
        }
        if matched_words == 0 {
            return 0.0;
        }
        (total_score / query_words.len() as f32)
            * (matched_words as f32 / query_words.len() as f32)
    }

    /// Returns up to `top_k` `(doc_id, score)` pairs, best first.
    /// O(documents × words): every stored document is scanned.
    pub fn search(&self, query: &str, top_k: usize) -> Vec<(String, f32)> {
        let mut results: Vec<(String, f32)> = self
            .documents
            .iter()
            .filter_map(|(id, content)| {
                let score = Self::fuzzy_score(query, content, self.max_distance);
                (score > 0.0).then(|| (id.clone(), score))
            })
            .collect();
        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        results.truncate(top_k);
        results
    }

    /// Number of indexed documents.
    pub fn len(&self) -> usize {
        self.documents.len()
    }

    /// True when no documents are indexed.
    pub fn is_empty(&self) -> bool {
        self.documents.is_empty()
    }

    /// Removes all documents and the vocabulary.
    pub fn clear(&mut self) {
        self.documents.clear();
        self.vocabulary.clear();
    }

    /// Number of distinct words available for spelling correction.
    pub fn vocabulary_size(&self) -> usize {
        self.vocabulary.len()
    }

    /// Serializes the index as JSON to `path`.
    ///
    /// # Errors
    /// Returns `AppError::Internal` on serialization or I/O failure.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let json = serde_json::to_string(self)
            .map_err(|e| AppError::Internal(format!("Failed to serialize fuzzy index: {}", e)))?;
        std::fs::write(path, json)
            .map_err(|e| AppError::Internal(format!("Failed to write fuzzy index file: {}", e)))?;
        Ok(())
    }

    /// Loads an index previously written by [`FuzzyIndex::save`].
    ///
    /// # Errors
    /// Returns `AppError::Internal` on read or deserialization failure.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let json = std::fs::read_to_string(path)
            .map_err(|e| AppError::Internal(format!("Failed to read fuzzy index file: {}", e)))?;
        let index: Self = serde_json::from_str(&json)
            .map_err(|e| AppError::Internal(format!("Failed to deserialize fuzzy index: {}", e)))?;
        Ok(index)
    }

    /// Loads from `path` when it exists; any failure falls back to a fresh
    /// index.
    pub fn load_or_new<P: AsRef<Path>>(path: P) -> Self {
        if path.as_ref().exists() {
            Self::load(path).unwrap_or_else(|_| Self::new())
        } else {
            Self::new()
        }
    }
}
/// Reciprocal rank fusion (RRF) over several ranked lists.
#[derive(Debug, Clone)]
pub struct RrfFusion {
    /// Rank-smoothing constant; 60 is the conventional default.
    k: f32,
}

impl Default for RrfFusion {
    fn default() -> Self {
        Self { k: 60.0 }
    }
}

impl RrfFusion {
    /// Fusion with the standard constant `k = 60`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Fusion with a custom smoothing constant.
    pub fn with_k(k: f32) -> Self {
        Self { k }
    }

    /// Fuses `(ranked_list, weight)` pairs into a single descending ranking.
    ///
    /// Each appearance of a document at 0-based rank `r` contributes
    /// `weight / (k + r + 1)`; the input scores themselves are ignored —
    /// only the ranks matter.
    pub fn fuse(&self, ranked_lists: &[(&[(String, f32)], f32)]) -> Vec<(String, f32)> {
        let mut fused_scores: HashMap<String, f32> = HashMap::new();
        for (list, weight) in ranked_lists {
            for (rank, (doc_id, _score)) in list.iter().enumerate() {
                let contribution = weight / (self.k + rank as f32 + 1.0);
                *fused_scores.entry(doc_id.clone()).or_insert(0.0) += contribution;
            }
        }
        let mut fused: Vec<_> = fused_scores.into_iter().collect();
        fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        fused
    }
}
#[derive(Debug, Default)]
pub struct SearchEngine {
pub bm25: Bm25Index,
pub fuzzy: FuzzyIndex,
pub rrf: RrfFusion,
}
impl SearchEngine {
    /// Creates an empty engine with default indexes.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds one document to both the BM25 and the fuzzy index.
    pub fn index_document(&mut self, doc: &Document) {
        self.bm25.add_document(&doc.id, &doc.content);
        self.fuzzy.add_document(&doc.id, &doc.content);
    }

    /// Adds a batch of documents.
    pub fn index_documents(&mut self, docs: &[Document]) {
        docs.iter().for_each(|doc| self.index_document(doc));
    }

    /// Removes a document from both indexes.
    pub fn remove_document(&mut self, id: &str) {
        self.bm25.remove_document(id);
        self.fuzzy.remove_document(id);
    }

    /// Lexical (BM25) search.
    pub fn search_bm25(&self, query: &str, top_k: usize) -> Vec<(String, f32)> {
        self.bm25.search(query, top_k)
    }

    /// Typo-tolerant search.
    pub fn search_fuzzy(&self, query: &str, top_k: usize) -> Vec<(String, f32)> {
        self.fuzzy.search(query, top_k)
    }

    /// Fuses caller-supplied semantic results with local BM25 and fuzzy
    /// results using weighted RRF. The local searches over-fetch (2 × top_k)
    /// so the fusion has enough candidates before truncation.
    pub fn search_hybrid(
        &self,
        query: &str,
        semantic_results: &[(String, f32)],
        weights: &HybridWeights,
        top_k: usize,
    ) -> Vec<(String, f32)> {
        let bm25_results = self.bm25.search(query, top_k * 2);
        let fuzzy_results = self.fuzzy.search(query, top_k * 2);
        let ranked_lists: Vec<(&[(String, f32)], f32)> = vec![
            (semantic_results, weights.semantic),
            (&bm25_results, weights.bm25),
            (&fuzzy_results, weights.fuzzy),
        ];
        let mut fused = self.rrf.fuse(&ranked_lists);
        fused.truncate(top_k);
        fused
    }

    /// BM25 search over the spell-corrected query.
    /// Returns `(results, corrected_query, corrections)`.
    pub fn search_bm25_with_correction(
        &self,
        query: &str,
        top_k: usize,
    ) -> (Vec<(String, f32)>, String, Vec<QueryCorrection>) {
        let (corrected_query, corrections) = self.fuzzy.correct_query(query);
        let results = self.bm25.search(&corrected_query, top_k);
        (results, corrected_query, corrections)
    }

    /// Hybrid search over the spell-corrected query.
    /// Returns `(results, corrected_query, corrections)`.
    pub fn search_hybrid_with_correction(
        &self,
        query: &str,
        semantic_results: &[(String, f32)],
        weights: &HybridWeights,
        top_k: usize,
    ) -> (Vec<(String, f32)>, String, Vec<QueryCorrection>) {
        let (corrected_query, corrections) = self.fuzzy.correct_query(query);
        let results = self.search_hybrid(&corrected_query, semantic_results, weights, top_k);
        (results, corrected_query, corrections)
    }

    /// Empties both indexes.
    pub fn clear(&mut self) {
        self.bm25.clear();
        self.fuzzy.clear();
    }

    /// Number of indexed documents (as tracked by the BM25 index).
    pub fn len(&self) -> usize {
        self.bm25.len()
    }

    /// True when nothing has been indexed.
    pub fn is_empty(&self) -> bool {
        self.bm25.is_empty()
    }

    /// Persists both indexes as JSON files under `dir` (created if needed).
    pub fn save<P: AsRef<Path>>(&self, dir: P) -> Result<()> {
        let dir = dir.as_ref();
        std::fs::create_dir_all(dir).map_err(|e| {
            AppError::Internal(format!("Failed to create search index directory: {}", e))
        })?;
        self.bm25.save(dir.join("bm25_index.json"))?;
        self.fuzzy.save(dir.join("fuzzy_index.json"))?;
        Ok(())
    }

    /// Restores an engine from files written by [`SearchEngine::save`].
    pub fn load<P: AsRef<Path>>(dir: P) -> Result<Self> {
        let dir = dir.as_ref();
        Ok(Self {
            bm25: Bm25Index::load(dir.join("bm25_index.json"))?,
            fuzzy: FuzzyIndex::load(dir.join("fuzzy_index.json"))?,
            rrf: RrfFusion::default(),
        })
    }

    /// Loads from `dir` when it exists; any failure yields a fresh engine.
    pub fn load_or_new<P: AsRef<Path>>(dir: P) -> Self {
        let dir = dir.as_ref();
        if dir.exists() {
            Self::load(dir).unwrap_or_else(|_| Self::new())
        } else {
            Self::new()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Each canonical strategy name parses to its variant.
    #[test]
    fn test_search_strategy_from_str() {
        assert_eq!(
            "semantic".parse::<SearchStrategy>().unwrap(),
            SearchStrategy::Semantic
        );
        assert_eq!(
            "bm25".parse::<SearchStrategy>().unwrap(),
            SearchStrategy::Bm25
        );
        assert_eq!(
            "fuzzy".parse::<SearchStrategy>().unwrap(),
            SearchStrategy::Fuzzy
        );
        assert_eq!(
            "hybrid".parse::<SearchStrategy>().unwrap(),
            SearchStrategy::Hybrid
        );
    }

    // The document containing all query terms ranks first.
    #[test]
    fn test_bm25_basic() {
        let mut index = Bm25Index::new();
        index.add_document("doc1", "The quick brown fox jumps over the lazy dog");
        index.add_document("doc2", "A fast brown fox leaps over sleeping dogs");
        index.add_document("doc3", "The cat sleeps on the mat");
        let results = index.search("quick brown fox", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "doc1");
    }

    // Higher term frequency outranks lower, all else being equal.
    #[test]
    fn test_bm25_ranking() {
        let mut index = Bm25Index::new();
        index.add_document("doc1", "apple apple apple");
        index.add_document("doc2", "apple banana");
        index.add_document("doc3", "banana banana banana");
        let results = index.search("apple", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "doc1");
    }

    // Removal updates doc_count and drops the document's posting data.
    #[test]
    fn test_bm25_remove_document() {
        let mut index = Bm25Index::new();
        index.add_document("doc1", "hello world");
        index.add_document("doc2", "goodbye world");
        assert_eq!(index.len(), 2);
        index.remove_document("doc1");
        assert_eq!(index.len(), 1);
        let results = index.search("hello", 10);
        assert!(results.is_empty());
    }

    // An exact word match hits the document that contains it.
    #[test]
    fn test_fuzzy_exact_match() {
        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "machine learning algorithms");
        index.add_document("doc2", "deep neural networks");
        let results = index.search("machine", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "doc1");
    }

    // A one-character typo ("machne") still matches within distance 2.
    #[test]
    fn test_fuzzy_typo_tolerance() {
        let mut index = FuzzyIndex::with_max_distance(2);
        index.add_document("doc1", "machine learning");
        index.add_document("doc2", "deep learning");
        let results = index.search("machne", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "doc1");
    }

    // Classic distance fixtures, including the empty-string edge cases.
    #[test]
    fn test_levenshtein_distance() {
        assert_eq!(FuzzyIndex::levenshtein_distance("kitten", "sitting"), 3);
        assert_eq!(FuzzyIndex::levenshtein_distance("hello", "hello"), 0);
        assert_eq!(FuzzyIndex::levenshtein_distance("", "abc"), 3);
        assert_eq!(FuzzyIndex::levenshtein_distance("abc", ""), 3);
    }

    // Documents ranked highly in both lists dominate the fused ranking.
    #[test]
    fn test_rrf_fusion() {
        let rrf = RrfFusion::new();
        let list1 = [
            ("doc1".to_string(), 0.9),
            ("doc2".to_string(), 0.8),
            ("doc3".to_string(), 0.7),
        ];
        let list2 = [
            ("doc2".to_string(), 0.95),
            ("doc1".to_string(), 0.85),
            ("doc4".to_string(), 0.75),
        ];
        let ranked_lists = vec![(&list1[..], 1.0), (&list2[..], 1.0)];
        let fused = rrf.fuse(&ranked_lists);
        assert!(!fused.is_empty());
        let top_ids: Vec<_> = fused.iter().take(2).map(|(id, _)| id.clone()).collect();
        assert!(top_ids.contains(&"doc1".to_string()));
        assert!(top_ids.contains(&"doc2".to_string()));
    }

    // Indexing through the engine populates both sub-indexes.
    #[test]
    fn test_search_engine_integration() {
        let mut engine = SearchEngine::new();
        let docs = vec![
            Document {
                id: "doc1".to_string(),
                content: "Rust programming language is fast and memory safe".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
            Document {
                id: "doc2".to_string(),
                content: "Python is popular for machine learning and data science".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
            Document {
                id: "doc3".to_string(),
                content: "JavaScript runs in web browsers".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
        ];
        engine.index_documents(&docs);
        assert_eq!(engine.len(), 3);
        let bm25_results = engine.search_bm25("Rust programming", 10);
        assert!(!bm25_results.is_empty());
        assert_eq!(bm25_results[0].0, "doc1");
        let fuzzy_results = engine.search_fuzzy("rust", 10);
        assert!(!fuzzy_results.is_empty(), "Fuzzy search should find 'rust'");
    }

    // Hybrid fusion of externally supplied semantic results with local
    // BM25/fuzzy lists returns a non-empty ranking.
    #[test]
    fn test_hybrid_search() {
        let mut engine = SearchEngine::new();
        let docs = vec![
            Document {
                id: "doc1".to_string(),
                content: "Vector databases enable semantic search".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
            Document {
                id: "doc2".to_string(),
                content: "BM25 is a lexical search algorithm".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
        ];
        engine.index_documents(&docs);
        let semantic_results = vec![("doc1".to_string(), 0.95), ("doc2".to_string(), 0.80)];
        let weights = HybridWeights {
            semantic: 0.5,
            bm25: 0.4,
            fuzzy: 0.1,
        };
        let hybrid = engine.search_hybrid("vector search", &semantic_results, &weights, 10);
        assert!(!hybrid.is_empty());
    }

    // Default weights are 0.6 / 0.3 / 0.1 (compared with a float tolerance).
    #[test]
    fn test_hybrid_weights_default() {
        let weights = HybridWeights::default();
        assert!((weights.semantic - 0.6).abs() < 0.001);
        assert!((weights.bm25 - 0.3).abs() < 0.001);
        assert!((weights.fuzzy - 0.1).abs() < 0.001);
    }

    // A word already in the vocabulary corrects to itself at distance 0.
    #[test]
    fn test_correct_word_exact_match() {
        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "programming language");
        let result = index.correct_word("programming");
        assert!(result.is_some());
        let (corrected, distance) = result.unwrap();
        assert_eq!(corrected, "programming");
        assert_eq!(distance, 0);
    }

    // A single-deletion typo corrects at distance 1.
    #[test]
    fn test_correct_word_with_typo() {
        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "programming language");
        let result = index.correct_word("progamming");
        assert!(result.is_some());
        let (corrected, distance) = result.unwrap();
        assert_eq!(corrected, "programming");
        assert_eq!(distance, 1);
    }

    // Words with no vocabulary entry within max_distance yield None.
    #[test]
    fn test_correct_word_no_match() {
        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "programming language");
        let result = index.correct_word("xyz");
        assert!(result.is_none());
    }

    // correct_query reports each applied fix with its distance.
    #[test]
    fn test_correct_query_single_typo() {
        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "rust programming language");
        let (corrected, corrections) = index.correct_query("progamming");
        assert_eq!(corrected, "programming");
        assert_eq!(corrections.len(), 1);
        assert_eq!(corrections[0].original, "progamming");
        assert_eq!(corrections[0].corrected, "programming");
        assert_eq!(corrections[0].distance, 1);
    }

    // Multiple typos in one query are all corrected and recorded.
    #[test]
    fn test_correct_query_multiple_typos() {
        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "rust programming language");
        let (corrected, corrections) = index.correct_query("progamming languge");
        assert_eq!(corrected, "programming language");
        assert_eq!(corrections.len(), 2);
    }

    // A clean query passes through with no corrections recorded.
    #[test]
    fn test_correct_query_no_typos() {
        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "rust programming language");
        let (corrected, corrections) = index.correct_query("programming language");
        assert_eq!(corrected, "programming language");
        assert!(corrections.is_empty());
    }

    // BM25 search still hits after the query is spell-corrected first.
    #[test]
    fn test_search_bm25_with_correction() {
        let mut engine = SearchEngine::new();
        let docs = vec![
            Document {
                id: "doc1".to_string(),
                content: "Rust is a systems programming language".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
            Document {
                id: "doc2".to_string(),
                content: "Python is popular for scripting".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
        ];
        engine.index_documents(&docs);
        let (results, corrected_query, corrections) =
            engine.search_bm25_with_correction("progamming", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "doc1");
        assert_eq!(corrected_query, "programming");
        assert_eq!(corrections.len(), 1);
        assert_eq!(corrections[0].original, "progamming");
        assert_eq!(corrections[0].corrected, "programming");
    }

    // clear() empties the vocabulary as well as the documents.
    #[test]
    fn test_vocabulary_cleared() {
        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "programming language");
        assert!(index.vocabulary_size() > 0);
        index.clear();
        assert_eq!(index.vocabulary_size(), 0);
        assert!(index.is_empty());
    }

    // Input casing does not affect correction; output is lowercase.
    #[test]
    fn test_typo_correction_case_insensitive() {
        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "Programming Language");
        let result = index.correct_word("PROGAMMING");
        assert!(result.is_some());
        let (corrected, _) = result.unwrap();
        assert_eq!(corrected, "programming");
    }

    // Round-trip through JSON preserves the BM25 index and its rankings.
    #[test]
    fn test_bm25_save_load() {
        let temp_dir = std::env::temp_dir().join("ares_test_bm25");
        let _ = std::fs::remove_dir_all(&temp_dir);
        std::fs::create_dir_all(&temp_dir).unwrap();
        let path = temp_dir.join("bm25_index.json");
        let mut index = Bm25Index::new();
        index.add_document("doc1", "The quick brown fox");
        index.add_document("doc2", "A lazy dog sleeps");
        assert_eq!(index.len(), 2);
        index.save(&path).unwrap();
        let loaded = Bm25Index::load(&path).unwrap();
        assert_eq!(loaded.len(), 2);
        let results = loaded.search("quick brown", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "doc1");
        let _ = std::fs::remove_dir_all(&temp_dir);
    }

    // Round-trip through JSON preserves the fuzzy index and vocabulary.
    #[test]
    fn test_fuzzy_save_load() {
        let temp_dir = std::env::temp_dir().join("ares_test_fuzzy");
        let _ = std::fs::remove_dir_all(&temp_dir);
        std::fs::create_dir_all(&temp_dir).unwrap();
        let path = temp_dir.join("fuzzy_index.json");
        let mut index = FuzzyIndex::new();
        index.add_document("doc1", "machine learning algorithms");
        index.add_document("doc2", "deep neural networks");
        assert_eq!(index.len(), 2);
        index.save(&path).unwrap();
        let loaded = FuzzyIndex::load(&path).unwrap();
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded.vocabulary_size(), index.vocabulary_size());
        let results = loaded.search("machine", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "doc1");
        let _ = std::fs::remove_dir_all(&temp_dir);
    }

    // Engine save/load restores both indexes into a working engine.
    #[test]
    fn test_search_engine_save_load() {
        let temp_dir = std::env::temp_dir().join("ares_test_engine");
        let _ = std::fs::remove_dir_all(&temp_dir);
        let mut engine = SearchEngine::new();
        let docs = vec![
            Document {
                id: "doc1".to_string(),
                content: "Rust programming language".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
            Document {
                id: "doc2".to_string(),
                content: "Python scripting language".to_string(),
                metadata: Default::default(),
                embedding: None,
            },
        ];
        engine.index_documents(&docs);
        assert_eq!(engine.len(), 2);
        engine.save(&temp_dir).unwrap();
        let loaded = SearchEngine::load(&temp_dir).unwrap();
        assert_eq!(loaded.len(), 2);
        let bm25_results = loaded.search_bm25("Rust programming", 10);
        assert!(!bm25_results.is_empty());
        assert_eq!(bm25_results[0].0, "doc1");
        let fuzzy_results = loaded.search_fuzzy("rust", 10);
        assert!(!fuzzy_results.is_empty());
        let _ = std::fs::remove_dir_all(&temp_dir);
    }

    // load_or_new falls back to an empty index for a missing file.
    #[test]
    fn test_load_or_new_missing_file() {
        let path = std::env::temp_dir().join("nonexistent_bm25_index.json");
        let _ = std::fs::remove_file(&path);
        let index = Bm25Index::load_or_new(&path);
        assert!(index.is_empty());
    }

    // load_or_new falls back to an empty engine for a missing directory.
    #[test]
    fn test_search_engine_load_or_new() {
        let temp_dir = std::env::temp_dir().join("ares_test_load_or_new");
        let _ = std::fs::remove_dir_all(&temp_dir);
        let engine = SearchEngine::load_or_new(&temp_dir);
        assert!(engine.is_empty());
    }
}