// engram/search/fuzzy.rs
//! Fuzzy search with typo tolerance (RML-877)
//!
//! Uses Levenshtein distance for typo correction and suggestion.

#![allow(clippy::needless_range_loop)]

use std::collections::{HashMap, HashSet};

/// Maximum Levenshtein distance for a vocabulary word to count as a fuzzy match.
const MAX_EDIT_DISTANCE: usize = 2;

/// Minimum word length (in bytes, per `str::len`) to apply fuzzy matching;
/// shorter words are passed through unchanged to avoid spurious corrections.
const MIN_WORD_LENGTH: usize = 4;
14
/// Fuzzy search engine with typo tolerance.
///
/// Builds a vocabulary from indexed text, then uses Levenshtein distance to
/// correct misspelled query terms and to produce suggestions.
pub struct FuzzyEngine {
    /// Vocabulary built from indexed content (lowercased alphanumeric words).
    vocabulary: HashSet<String>,
    /// Word frequency counts, used to rank competing corrections/suggestions.
    word_freq: HashMap<String, usize>,
}
22
23impl FuzzyEngine {
24    /// Create a new fuzzy engine
25    pub fn new() -> Self {
26        Self {
27            vocabulary: HashSet::new(),
28            word_freq: HashMap::new(),
29        }
30    }
31
32    /// Add text to the vocabulary
33    pub fn add_to_vocabulary(&mut self, text: &str) {
34        for word in tokenize(text) {
35            if word.len() >= MIN_WORD_LENGTH {
36                self.vocabulary.insert(word.clone());
37                *self.word_freq.entry(word).or_insert(0) += 1;
38            }
39        }
40    }
41
42    /// Find corrections for a query
43    pub fn correct_query(&self, query: &str) -> CorrectionResult {
44        let mut corrections = Vec::new();
45        let mut corrected_query = String::new();
46        let mut had_corrections = false;
47
48        for word in query.split_whitespace() {
49            let word_lower = word.to_lowercase();
50
51            // Skip short words or words already in vocabulary
52            if word_lower.len() < MIN_WORD_LENGTH || self.vocabulary.contains(&word_lower) {
53                if !corrected_query.is_empty() {
54                    corrected_query.push(' ');
55                }
56                corrected_query.push_str(word);
57                continue;
58            }
59
60            // Find best correction
61            if let Some(correction) = self.find_best_correction(&word_lower) {
62                corrections.push(Correction {
63                    original: word.to_string(),
64                    corrected: correction.clone(),
65                    distance: levenshtein(&word_lower, &correction),
66                });
67                had_corrections = true;
68
69                if !corrected_query.is_empty() {
70                    corrected_query.push(' ');
71                }
72                corrected_query.push_str(&correction);
73            } else {
74                if !corrected_query.is_empty() {
75                    corrected_query.push(' ');
76                }
77                corrected_query.push_str(word);
78            }
79        }
80
81        CorrectionResult {
82            original_query: query.to_string(),
83            corrected_query: if had_corrections {
84                Some(corrected_query)
85            } else {
86                None
87            },
88            corrections,
89            suggestions: self.get_suggestions(query, 5),
90        }
91    }
92
93    /// Find the best correction for a word
94    fn find_best_correction(&self, word: &str) -> Option<String> {
95        let mut best: Option<(String, usize, usize)> = None; // (word, distance, frequency)
96
97        for vocab_word in &self.vocabulary {
98            let distance = levenshtein(word, vocab_word);
99
100            if distance <= MAX_EDIT_DISTANCE {
101                let freq = *self.word_freq.get(vocab_word).unwrap_or(&0);
102
103                match &best {
104                    None => {
105                        best = Some((vocab_word.clone(), distance, freq));
106                    }
107                    Some((_, best_dist, best_freq)) => {
108                        // Prefer smaller distance, then higher frequency
109                        if distance < *best_dist || (distance == *best_dist && freq > *best_freq) {
110                            best = Some((vocab_word.clone(), distance, freq));
111                        }
112                    }
113                }
114            }
115        }
116
117        best.map(|(word, _, _)| word)
118    }
119
120    /// Get search suggestions based on prefix matching and similarity
121    fn get_suggestions(&self, query: &str, limit: usize) -> Vec<String> {
122        let query_lower = query.to_lowercase();
123        let mut suggestions: Vec<(String, usize)> = Vec::new();
124
125        for word in &self.vocabulary {
126            // Prefix match
127            if word.starts_with(&query_lower) {
128                let freq = *self.word_freq.get(word).unwrap_or(&0);
129                suggestions.push((word.clone(), freq));
130            }
131            // Similar words
132            else if query_lower.len() >= MIN_WORD_LENGTH {
133                let distance = levenshtein(&query_lower, word);
134                if distance <= MAX_EDIT_DISTANCE {
135                    let freq = *self.word_freq.get(word).unwrap_or(&0);
136                    suggestions.push((word.clone(), freq));
137                }
138            }
139        }
140
141        // Sort by frequency (descending)
142        suggestions.sort_by(|a, b| b.1.cmp(&a.1));
143
144        suggestions
145            .into_iter()
146            .take(limit)
147            .map(|(word, _)| word)
148            .collect()
149    }
150
151    /// Get vocabulary size
152    pub fn vocabulary_size(&self) -> usize {
153        self.vocabulary.len()
154    }
155}
156
157impl Default for FuzzyEngine {
158    fn default() -> Self {
159        Self::new()
160    }
161}
162
use serde::{Deserialize, Serialize};

/// Result of query correction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrectionResult {
    /// The query exactly as the caller supplied it.
    pub original_query: String,
    /// Fully corrected query, or `None` if no word was changed.
    pub corrected_query: Option<String>,
    /// Per-word corrections that were applied.
    pub corrections: Vec<Correction>,
    /// Related vocabulary words for the whole query (prefix/similarity based).
    pub suggestions: Vec<String>,
}

/// A single word correction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Correction {
    /// The word as it appeared in the query (original casing).
    pub original: String,
    /// The vocabulary word it was corrected to (lowercase).
    pub corrected: String,
    /// Levenshtein distance between the lowercased original and the correction.
    pub distance: usize,
}
181
/// Calculate the Levenshtein (edit) distance between two strings.
///
/// Operates on Unicode scalar values (`char`s), not bytes. Keeps a single
/// rolling row of the DP matrix, so memory is O(len of `b`).
pub fn levenshtein(a: &str, b: &str) -> usize {
    let b_chars: Vec<char> = b.chars().collect();

    if a.is_empty() {
        return b_chars.len();
    }
    if b_chars.is_empty() {
        return a.chars().count();
    }

    // row[j] = distance between the processed prefix of `a` and b[..j].
    let mut row: Vec<usize> = (0..=b_chars.len()).collect();

    for (i, ca) in a.chars().enumerate() {
        // `diagonal` carries the previous row's value at column j - 1.
        let mut diagonal = row[0];
        row[0] = i + 1;

        for (j, &cb) in b_chars.iter().enumerate() {
            let substitution = diagonal + usize::from(ca != cb);
            let deletion = row[j + 1] + 1;
            let insertion = row[j] + 1;
            diagonal = row[j + 1];
            row[j + 1] = substitution.min(deletion).min(insertion);
        }
    }

    *row.last().unwrap()
}
220
/// Damerau-Levenshtein distance, optimal string alignment variant:
/// Levenshtein edits plus transposition of two adjacent characters.
pub fn damerau_levenshtein(a: &str, b: &str) -> usize {
    let s: Vec<char> = a.chars().collect();
    let t: Vec<char> = b.chars().collect();
    let (n, m) = (s.len(), t.len());

    if n == 0 || m == 0 {
        return n.max(m);
    }

    // d[i][j] = distance between the first i chars of `a` and first j of `b`.
    let mut d = vec![vec![0usize; m + 1]; n + 1];
    for (i, row) in d.iter_mut().enumerate() {
        row[0] = i;
    }
    for (j, cell) in d[0].iter_mut().enumerate() {
        *cell = j;
    }

    for i in 1..=n {
        for j in 1..=m {
            let cost = usize::from(s[i - 1] != t[j - 1]);

            let mut best = (d[i - 1][j] + 1) // deletion
                .min(d[i][j - 1] + 1) // insertion
                .min(d[i - 1][j - 1] + cost); // substitution

            // Adjacent transposition (e.g. "ab" -> "ba").
            if i > 1 && j > 1 && s[i - 1] == t[j - 2] && s[i - 2] == t[j - 1] {
                best = best.min(d[i - 2][j - 2] + cost);
            }

            d[i][j] = best;
        }
    }

    d[n][m]
}
270
/// Tokenize text into lowercase alphanumeric words, dropping empty splits.
fn tokenize(text: &str) -> Vec<String> {
    let lowered = text.to_lowercase();
    lowered
        .split(|c: char| !c.is_alphanumeric())
        .filter_map(|piece| {
            if piece.is_empty() {
                None
            } else {
                Some(piece.to_string())
            }
        })
        .collect()
}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Distance basics: substitutions/insertions plus both empty-string edges.
    #[test]
    fn test_levenshtein() {
        assert_eq!(levenshtein("kitten", "sitting"), 3);
        assert_eq!(levenshtein("hello", "hello"), 0);
        assert_eq!(levenshtein("", "abc"), 3);
        assert_eq!(levenshtein("abc", ""), 3);
    }

    /// Adjacent transpositions count as a single edit.
    #[test]
    fn test_damerau_levenshtein() {
        assert_eq!(damerau_levenshtein("ab", "ba"), 1); // transposition
        assert_eq!(damerau_levenshtein("hello", "hlelo"), 1); // transposition
    }

    /// A one-character-deletion typo is corrected to the closest vocab word.
    #[test]
    fn test_fuzzy_engine() {
        let mut engine = FuzzyEngine::new();
        engine.add_to_vocabulary("authentication");
        engine.add_to_vocabulary("authorization");
        engine.add_to_vocabulary("automatic");

        let result = engine.correct_query("authentcation"); // typo
        assert!(result.corrected_query.is_some());
        assert_eq!(result.corrected_query.unwrap(), "authentication");
    }

    /// Prefix queries yield suggestions drawn from the vocabulary.
    #[test]
    fn test_suggestions() {
        let mut engine = FuzzyEngine::new();
        engine.add_to_vocabulary("authentication");
        engine.add_to_vocabulary("authorization");
        engine.add_to_vocabulary("automatic");

        let suggestions = engine.get_suggestions("auth", 5);
        assert!(!suggestions.is_empty());
        assert!(suggestions.iter().any(|s| s.starts_with("auth")));
    }
}