use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashSet};
use serde::{Deserialize, Serialize};
use crate::spelling::dictionary::SpellingDictionary;
use crate::spelling::typo_patterns::TypoPatterns;
use crate::util::levenshtein::LevenshteinMatcher;
/// A candidate correction for a queried word, together with the data used to rank it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Suggestion {
/// The suggested word.
pub word: String,
/// Ranking score; the `Ord` impl for this type sorts higher scores first.
pub score: f64,
/// Edit distance between the queried word and `word` (0 when the query itself is in the dictionary).
pub distance: usize,
/// Frequency of `word` as reported by the dictionary.
pub frequency: u32,
}
impl Suggestion {
pub fn new(word: String, score: f64, distance: usize, frequency: u32) -> Self {
Suggestion {
word,
score,
distance,
frequency,
}
}
}
// Marker impl required so that `Suggestion` can implement `Ord`.
// NOTE(review): `PartialEq` is derived and `score` is an `f64`, so a NaN score makes a
// value compare unequal to itself — confirm NaN scores never reach this type.
impl Eq for Suggestion {}
impl Ord for Suggestion {
    /// Ranks suggestions so that the *best* one compares as `Less`.
    ///
    /// Sorting ascending therefore yields best-first order. The keys, in
    /// priority order, are:
    ///
    /// 1. higher `score` first (a NaN score ranks below every real score),
    /// 2. smaller `distance` first,
    /// 3. higher `frequency` first,
    /// 4. `word` in ascending lexicographic order.
    ///
    /// The tie-breakers keep the ordering consistent with the derived
    /// `PartialEq` (two suggestions compare `Equal` only when all four fields
    /// agree) and make the order of equally scored suggestions deterministic.
    fn cmp(&self, other: &Self) -> Ordering {
        // NaN has no place in a total order; treat it as the worst possible
        // score so it can never outrank a real suggestion.
        let rank = |score: f64| if score.is_nan() { f64::NEG_INFINITY } else { score };

        rank(other.score)
            .partial_cmp(&rank(self.score))
            // Unreachable after NaN normalisation; kept so this never panics.
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.distance.cmp(&other.distance))
            .then_with(|| other.frequency.cmp(&self.frequency))
            .then_with(|| self.word.cmp(&other.word))
    }
}
impl PartialOrd for Suggestion {
    /// Always returns `Some`, deferring entirely to the `Ord` implementation
    /// so the two orderings cannot disagree.
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, rhs);
        Some(ordering)
    }
}
/// Tunable parameters for `SuggestionEngine`.
#[derive(Debug, Clone)]
pub struct SuggestionConfig {
/// Largest edit distance a candidate may have; a value of 2 or more also enables a second round of edits during candidate generation.
pub max_distance: usize,
/// Upper bound on the number of suggestions returned by `suggest`.
pub max_suggestions: usize,
/// Candidates whose dictionary frequency is below this value are discarded.
pub min_frequency: u32,
/// Weight applied to the distance component of the score.
pub distance_weight: f64,
/// Weight applied to the frequency component of the score.
pub frequency_weight: f64,
/// Enables keyboard-neighbour substitutions during edit generation and the keyboard bonus during scoring.
pub use_keyboard_distance: bool,
/// NOTE(review): not read anywhere in the visible part of this file — confirm where (or whether) it is used.
pub use_phonetic: bool,
}
impl Default for SuggestionConfig {
fn default() -> Self {
SuggestionConfig {
max_distance: 2,
max_suggestions: 5,
min_frequency: 1,
distance_weight: 0.6,
frequency_weight: 0.4,
use_keyboard_distance: true,
use_phonetic: false,
}
}
}
/// Produces ranked spelling suggestions by combining a dictionary with a scoring configuration.
pub struct SuggestionEngine {
// Word list consulted for membership, per-word frequency and prefix lookups.
dictionary: SpellingDictionary,
// Limits and weights applied while generating and scoring candidates.
config: SuggestionConfig,
}
impl SuggestionEngine {
pub fn new(dictionary: SpellingDictionary) -> Self {
SuggestionEngine {
dictionary,
config: SuggestionConfig::default(),
}
}
pub fn with_config(dictionary: SpellingDictionary, config: SuggestionConfig) -> Self {
SuggestionEngine { dictionary, config }
}
/// Replaces the engine's configuration; the previous one is dropped.
pub fn set_config(&mut self, config: SuggestionConfig) {
    let previous = std::mem::replace(&mut self.config, config);
    drop(previous);
}
/// Returns up to `max_suggestions` corrections for `word`, best first.
///
/// The word is lower-cased before any lookup. If the lower-cased form is
/// already in the dictionary, it is returned as the sole suggestion with
/// score `1.0` and distance `0`. Otherwise every generated candidate that is
/// within `max_distance` of the query and has at least `min_frequency` is
/// scored and ranked.
pub fn suggest(&self, word: &str) -> Vec<Suggestion> {
    let query = word.to_lowercase();

    // A known word needs no correction.
    if self.dictionary.contains(&query) {
        let known_frequency = self.dictionary.frequency(&query);
        return vec![Suggestion::new(query, 1.0, 0, known_frequency)];
    }

    let limits = &self.config;
    let mut ranked = BinaryHeap::new();
    let matcher = LevenshteinMatcher::new(query.clone());

    for candidate in self.generate_candidates(&query) {
        // Skip candidates the matcher reports as farther away than the threshold.
        let distance = match matcher.distance_threshold(&candidate, limits.max_distance) {
            Some(found) => found,
            None => continue,
        };

        let frequency = self.dictionary.frequency(&candidate);
        if frequency < limits.min_frequency {
            continue;
        }

        let score = self.calculate_score(&query, &candidate, distance, frequency);
        ranked.push(Suggestion::new(candidate, score, distance, frequency));
    }

    // `Suggestion`'s ordering places higher scores first, so the ascending
    // vector produced here is already best-first.
    let mut best = ranked.into_sorted_vec();
    best.truncate(limits.max_suggestions);
    best
}
/// Collects the dictionary words that are plausible corrections of `word`.
///
/// Candidates are drawn from three sources:
///
/// 1. every single edit of `word`;
/// 2. when `max_distance >= 2`, every single edit of those edits;
/// 3. dictionary words that start with the first three characters of `word`
///    (or with the whole word if it is shorter).
///
/// Only strings for which `self.dictionary.contains` is true are returned.
fn generate_candidates(&self, word: &str) -> HashSet<String> {
    // Computed once and reused both as candidates and as the seeds for the
    // second round of edits.
    let first_edits = self.generate_edits(word, 1);
    let mut candidates = HashSet::new();

    if self.config.max_distance >= 2 {
        for edit in &first_edits {
            // Filter while extending so the (very large) set of second-order
            // edit strings is never held in memory all at once.
            candidates.extend(
                self.generate_edits(edit, 1)
                    .into_iter()
                    .filter(|second| self.dictionary.contains(second)),
            );
        }
    }
    candidates.extend(first_edits);

    // Cut the prefix on a character boundary: slicing at a fixed byte offset
    // panics when that offset falls inside a multi-byte character.
    let prefix_end = word
        .char_indices()
        .nth(3)
        .map_or(word.len(), |(index, _)| index);
    candidates.extend(self.dictionary.words_with_prefix(&word[..prefix_end]));

    candidates.retain(|candidate| self.dictionary.contains(candidate));
    candidates
}
/// Produces every string reachable from `word` by one edit.
///
/// The edits are: deleting one character, swapping two adjacent characters,
/// replacing one character with a different letter in `'a'..='z'`, and
/// inserting a letter in `'a'..='z'` at any position. When
/// `use_keyboard_distance` is enabled, each character is additionally
/// replaced by every key returned by `TypoPatterns::nearby_keys`.
///
/// `max_distance` is only inspected for zero, in which case the result is
/// empty; any other value yields the single-edit set.
fn generate_edits(&self, word: &str, max_distance: usize) -> HashSet<String> {
    let mut edits = HashSet::new();
    if max_distance == 0 {
        return edits;
    }

    let letters: Vec<char> = word.chars().collect();
    let count = letters.len();

    // Deletions: drop the character at each position in turn.
    for skip in 0..count {
        let shortened: String = letters
            .iter()
            .enumerate()
            .filter(|(position, _)| *position != skip)
            .map(|(_, letter)| *letter)
            .collect();
        edits.insert(shortened);
    }

    // Transpositions: swap each adjacent pair.
    for right in 1..count {
        let mut swapped = letters.clone();
        swapped.swap(right - 1, right);
        edits.insert(swapped.iter().collect::<String>());
    }

    // Substitutions: replace each character with every other lowercase letter.
    for (position, &current) in letters.iter().enumerate() {
        for replacement in ('a'..='z').filter(|&letter| letter != current) {
            let mut replaced = letters.clone();
            replaced[position] = replacement;
            edits.insert(replaced.iter().collect::<String>());
        }
    }

    // Insertions: place each lowercase letter into every gap, including both ends.
    for gap in 0..=count {
        let (head, tail) = letters.split_at(gap);
        for extra in 'a'..='z' {
            let grown: String = head
                .iter()
                .chain(std::iter::once(&extra))
                .chain(tail.iter())
                .collect();
            edits.insert(grown);
        }
    }

    // Keyboard-neighbour substitutions.
    if self.config.use_keyboard_distance {
        for (position, &current) in letters.iter().enumerate() {
            let neighbours = TypoPatterns::nearby_keys(current);
            for &neighbour in &neighbours {
                let mut replaced = letters.clone();
                replaced[position] = neighbour;
                edits.insert(replaced.iter().collect::<String>());
            }
        }
    }

    edits
}
/// Scores `candidate` as a correction for `original`; the result never exceeds `1.0`.
///
/// The base score is a weighted sum of two components:
///
/// * distance: `1 / (1 + distance)`, weighted by `distance_weight`;
/// * frequency: `ln(frequency) / ln(total_frequency)`, or `0.0` when
///   `frequency` is zero or the dictionary total is at most one, weighted by
///   `frequency_weight`.
///
/// The base score is then multiplied by a length penalty (`0.9` when the two
/// words differ in character count), the shared-prefix bonus, and — when
/// `use_keyboard_distance` is enabled — a `1.1` bonus if
/// `TypoPatterns::keyboard_distance` is smaller than the edit distance.
fn calculate_score(
    &self,
    original: &str,
    candidate: &str,
    distance: usize,
    frequency: u32,
) -> f64 {
    // Yields exactly 1.0 for distance 0, so no special case is needed.
    let distance_score = 1.0 / (1.0 + distance as f64);

    let total_freq = self.dictionary.total_frequency();
    let frequency_score = if frequency == 0 || total_freq <= 1 {
        // Avoids ln(0) and a zero denominator.
        0.0
    } else {
        (frequency as f64).ln() / (total_freq as f64).ln()
    };

    // Compare lengths in characters, not bytes, so that multi-byte characters
    // do not trigger the penalty for words of equal visible length.
    let length_penalty = if original.chars().count() == candidate.chars().count() {
        1.0
    } else {
        0.9
    };

    let prefix_bonus = self.calculate_prefix_bonus(original, candidate);

    let keyboard_bonus = if self.config.use_keyboard_distance
        && TypoPatterns::keyboard_distance(original, candidate) < distance as f64
    {
        1.1
    } else {
        1.0
    };

    let base_score = distance_score * self.config.distance_weight
        + frequency_score * self.config.frequency_weight;
    (base_score * length_penalty * prefix_bonus * keyboard_bonus).min(1.0)
}
/// Returns a multiplier in `1.0..=1.2` rewarding a shared leading run of characters.
///
/// The bonus is `1.0 + 0.2 * shared / longest`, where `shared` is the number
/// of leading characters the two words have in common and `longest` is the
/// character count of the longer word. Two empty words yield `1.0`.
fn calculate_prefix_bonus(&self, original: &str, candidate: &str) -> f64 {
    let shared = original
        .chars()
        .zip(candidate.chars())
        .take_while(|(left, right)| left == right)
        .count();
    let longest = original.chars().count().max(candidate.chars().count());

    match longest {
        // Both words are empty: nothing to compare, so no bonus.
        0 => 1.0,
        _ => 1.0 + (shared as f64 / longest as f64) * 0.2,
    }
}
/// Returns `(word_count, total_frequency)` as reported by the dictionary.
pub fn dictionary_stats(&self) -> (usize, u64) {
    let words = self.dictionary.word_count();
    let occurrences = self.dictionary.total_frequency();
    (words, occurrences)
}
/// Reports whether the dictionary contains `word` exactly as given.
///
/// NOTE(review): unlike `suggest`, this does not lower-case `word` before the
/// lookup — confirm whether the dictionary normalises case itself.
pub fn is_correct(&self, word: &str) -> bool {
    let dictionary = &self.dictionary;
    dictionary.contains(word)
}
/// Adds `word` to the dictionary with the given `frequency`.
///
/// The word is copied into an owned `String` and handed to the dictionary
/// unchanged.
pub fn add_word(&mut self, word: &str, frequency: u32) {
    let owned = String::from(word);
    self.dictionary.add_word(owned, frequency);
}
/// Returns the frequency the dictionary reports for `word`.
pub fn word_frequency(&self, word: &str) -> u32 {
    let dictionary = &self.dictionary;
    dictionary.frequency(word)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::spelling::dictionary::BuiltinDictionary;
#[test]
fn test_suggestion_ordering() {
    // Deliberately listed out of rank order; sorting must put the highest score first.
    let mut ranked = vec![
        Suggestion::new(String::from("hello"), 0.9, 1, 100),
        Suggestion::new(String::from("world"), 0.8, 1, 50),
        Suggestion::new(String::from("test"), 0.95, 0, 200),
    ];
    ranked.sort();

    let order: Vec<&str> = ranked.iter().map(|s| s.word.as_str()).collect();
    assert_eq!(order, ["test", "hello", "world"]);
}
#[test]
fn test_suggestion_engine_correct_word() {
    let engine = SuggestionEngine::new(BuiltinDictionary::minimal());

    // A word already in the dictionary comes back as the single, perfect suggestion.
    let suggestions = engine.suggest("hello");
    assert_eq!(suggestions.len(), 1);

    let only = &suggestions[0];
    assert_eq!(only.word, "hello");
    assert_eq!(only.distance, 0);
    assert!((only.score - 1.0).abs() < 1e-6);
}
#[test]
fn test_suggestion_engine_typos() {
    let engine = SuggestionEngine::new(BuiltinDictionary::minimal());

    // Missing letter.
    let for_helo = engine.suggest("helo");
    assert!(!for_helo.is_empty());
    assert!(for_helo.iter().any(|s| s.word == "hello"));

    // Transposed letters.
    let for_serach = engine.suggest("serach");
    assert!(!for_serach.is_empty());
}
#[test]
fn test_suggestion_engine_configuration() {
    // Tighter limits than the defaults: one edit, three results, no keyboard heuristics.
    let strict = SuggestionConfig {
        max_distance: 1,
        max_suggestions: 3,
        min_frequency: 1,
        distance_weight: 0.8,
        frequency_weight: 0.2,
        use_keyboard_distance: false,
        use_phonetic: false,
    };
    let engine = SuggestionEngine::with_config(BuiltinDictionary::minimal(), strict);

    let suggestions = engine.suggest("helo");
    assert!(suggestions.len() <= 3);
    assert!(suggestions.iter().all(|s| s.distance <= 1));
}
#[test]
fn test_generate_edits() {
    let engine = SuggestionEngine::new(BuiltinDictionary::minimal());
    let edits = engine.generate_edits("cat", 1);

    // Deletions.
    for deletion in &["at", "ct", "ca"] {
        assert!(edits.contains(*deletion));
    }

    assert!(edits.len() > 50);

    // Substitutions.
    for substitution in &["bat", "cot"] {
        assert!(edits.contains(*substitution));
    }
}
#[test]
fn test_prefix_bonus() {
    let engine = SuggestionEngine::new(BuiltinDictionary::minimal());

    // "searching" shares a six-character prefix with "search"; "church" shares none.
    let long_shared_prefix = engine.calculate_prefix_bonus("search", "searching");
    let no_shared_prefix = engine.calculate_prefix_bonus("search", "church");

    assert!(long_shared_prefix > no_shared_prefix);
    assert!(long_shared_prefix > 1.0);
}
#[test]
fn test_keyboard_distance_in_suggestions() {
    // Keyboard heuristics switched on explicitly; every other setting stays at its default.
    let keyboard_aware = SuggestionConfig {
        use_keyboard_distance: true,
        ..Default::default()
    };
    let engine = SuggestionEngine::with_config(BuiltinDictionary::minimal(), keyboard_aware);

    let for_gello = engine.suggest("gello");
    assert!(!for_gello.is_empty());
}
}