use crate::{Result, StorageError};
use std::cmp::min;
pub struct FuzzySearcher {
max_distance: u32,
use_ngram_filter: bool,
}
impl FuzzySearcher {
pub fn new(max_distance: u32) -> Self {
Self {
max_distance,
use_ngram_filter: true, }
}
pub fn without_ngram_filter(max_distance: u32) -> Self {
Self {
max_distance,
use_ngram_filter: false,
}
}
pub fn search_in_index(
&self,
index: &crate::index::TextFTSIndex,
query: &str,
) -> Result<Vec<u64>> {
let dictionary = index.get_all_terms();
let candidates = self.find_fuzzy_matches(query, &dictionary);
let mut result_docs = std::collections::HashSet::new();
for term in candidates {
if let Some(posting) = index.get_posting_list(&term)? {
for doc_id in posting.doc_ids() {
result_docs.insert(doc_id);
}
}
}
Ok(result_docs.into_iter().collect())
}
fn find_fuzzy_matches(&self, query: &str, dictionary: &[String]) -> Vec<String> {
if !self.use_ngram_filter || dictionary.len() < 100 {
return self.find_fuzzy_matches_direct(query, dictionary);
}
let query_ngrams = generate_ngrams(query, 2); let query_ngram_count = query_ngrams.len();
let min_shared = if query_ngram_count > 2 * self.max_distance as usize {
query_ngram_count - 2 * self.max_distance as usize
} else {
0
};
let candidates: Vec<&String> = dictionary
.iter()
.filter(|term| {
let len_diff = (query.len() as i32 - term.len() as i32).abs() as u32;
if len_diff > self.max_distance {
return false;
}
let term_ngrams = generate_ngrams(term, 2);
let shared_count = count_shared_ngrams(&query_ngrams, &term_ngrams);
shared_count >= min_shared
})
.collect();
candidates
.into_iter()
.filter_map(|term| {
let distance = levenshtein_distance(query, term);
if distance <= self.max_distance {
Some(term.clone())
} else {
None
}
})
.collect()
}
fn find_fuzzy_matches_direct(&self, query: &str, dictionary: &[String]) -> Vec<String> {
dictionary
.iter()
.filter_map(|term| {
let distance = levenshtein_distance(query, term);
if distance <= self.max_distance {
Some(term.clone())
} else {
None
}
})
.collect()
}
}
fn generate_ngrams(s: &str, n: usize) -> Vec<String> {
if s.len() < n {
return vec![s.to_string()];
}
let chars: Vec<char> = s.chars().collect();
(0..=chars.len() - n)
.map(|i| chars[i..i + n].iter().collect())
.collect()
}
fn count_shared_ngrams(ngrams1: &[String], ngrams2: &[String]) -> usize {
use std::collections::HashSet;
let set1: HashSet<&String> = ngrams1.iter().collect();
let set2: HashSet<&String> = ngrams2.iter().collect();
set1.intersection(&set2).count()
}
pub fn levenshtein_distance(s1: &str, s2: &str) -> u32 {
let chars1: Vec<char> = s1.chars().collect();
let chars2: Vec<char> = s2.chars().collect();
let len1 = chars1.len();
let len2 = chars2.len();
if len1 == 0 {
return len2 as u32;
}
if len2 == 0 {
return len1 as u32;
}
let mut prev_row: Vec<u32> = (0..=len2 as u32).collect();
let mut curr_row: Vec<u32> = vec![0; len2 + 1];
for i in 1..=len1 {
curr_row[0] = i as u32;
for j in 1..=len2 {
let cost = if chars1[i - 1] == chars2[j - 1] { 0 } else { 1 };
curr_row[j] = min(
min(
prev_row[j] + 1, curr_row[j - 1] + 1 ),
prev_row[j - 1] + cost );
}
std::mem::swap(&mut prev_row, &mut curr_row);
}
prev_row[len2]
}
pub fn damerau_levenshtein_distance(s1: &str, s2: &str) -> u32 {
let chars1: Vec<char> = s1.chars().collect();
let chars2: Vec<char> = s2.chars().collect();
let len1 = chars1.len();
let len2 = chars2.len();
if len1 == 0 {
return len2 as u32;
}
if len2 == 0 {
return len1 as u32;
}
let mut matrix = vec![vec![0u32; len2 + 1]; len1 + 1];
for i in 0..=len1 {
matrix[i][0] = i as u32;
}
for j in 0..=len2 {
matrix[0][j] = j as u32;
}
for i in 1..=len1 {
for j in 1..=len2 {
let cost = if chars1[i - 1] == chars2[j - 1] { 0 } else { 1 };
matrix[i][j] = min(
min(
matrix[i - 1][j] + 1, matrix[i][j - 1] + 1 ),
matrix[i - 1][j - 1] + cost );
if i > 1 && j > 1
&& chars1[i - 1] == chars2[j - 2]
&& chars1[i - 2] == chars2[j - 1]
{
matrix[i][j] = min(matrix[i][j], matrix[i - 2][j - 2] + 1);
}
}
}
matrix[len1][len2]
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_levenshtein_basic() {
assert_eq!(levenshtein_distance("", ""), 0);
assert_eq!(levenshtein_distance("abc", "abc"), 0);
assert_eq!(levenshtein_distance("", "abc"), 3);
assert_eq!(levenshtein_distance("abc", ""), 3);
}
#[test]
fn test_levenshtein_examples() {
assert_eq!(levenshtein_distance("kitten", "sitting"), 3);
assert_eq!(levenshtein_distance("saturday", "sunday"), 3);
assert_eq!(levenshtein_distance("database", "databse"), 1);
assert_eq!(levenshtein_distance("algorithm", "logarithm"), 3);
}
#[test]
fn test_damerau_levenshtein() {
assert_eq!(damerau_levenshtein_distance("ab", "ba"), 1);
assert_eq!(damerau_levenshtein_distance("abc", "acb"), 1);
}
#[test]
fn test_ngram_generation() {
let ngrams = generate_ngrams("hello", 2);
assert_eq!(ngrams, vec!["he", "el", "ll", "lo"]);
let ngrams = generate_ngrams("ab", 2);
assert_eq!(ngrams, vec!["ab"]);
let ngrams = generate_ngrams("a", 2);
assert_eq!(ngrams, vec!["a"]);
}
#[test]
fn test_ngram_shared_count() {
let ngrams1 = vec!["he".to_string(), "el".to_string(), "ll".to_string(), "lo".to_string()];
let ngrams2 = vec!["he".to_string(), "el".to_string(), "lp".to_string()];
assert_eq!(count_shared_ngrams(&ngrams1, &ngrams2), 2);
}
#[test]
fn test_fuzzy_search_with_ngram_filter() {
let dictionary = vec![
"learning".to_string(),
"machine".to_string(),
"algorithm".to_string(),
"database".to_string(),
"network".to_string(),
];
let searcher = FuzzySearcher::new(1);
let results = searcher.find_fuzzy_matches("learing", &dictionary);
assert!(results.contains(&"learning".to_string()));
let searcher_direct = FuzzySearcher::without_ngram_filter(1);
let results_direct = searcher_direct.find_fuzzy_matches("learing", &dictionary);
assert_eq!(results.len(), results_direct.len());
}
}