use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::error::Result;
use crate::spelling::dictionary::{BuiltinDictionary, SpellingDictionary};
use crate::spelling::suggest::{Suggestion, SuggestionConfig, SuggestionEngine};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrectorConfig {
pub max_distance: usize,
pub max_suggestions: usize,
pub min_frequency: u32,
pub auto_correct: bool,
pub auto_correct_threshold: f64,
pub use_index_terms: bool,
pub learn_from_queries: bool,
}
impl Default for CorrectorConfig {
fn default() -> Self {
CorrectorConfig {
max_distance: 2,
max_suggestions: 5,
min_frequency: 1,
auto_correct: false,
auto_correct_threshold: 0.8,
use_index_terms: true,
learn_from_queries: true,
}
}
}
/// Outcome of running one query through `SpellingCorrector::correct`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrectionResult {
/// The query exactly as it was passed to `correct`.
pub original: String,
/// The rewritten query; `Some` only when at least one word was auto-corrected.
pub corrected: Option<String>,
/// Suggestions for each query word the engine did not accept, keyed by the
/// normalized word (alphabetic characters only, lowercased). Words for which
/// the engine returned no suggestions are not present.
pub word_suggestions: HashMap<String, Vec<Suggestion>>,
/// Average per-word confidence: 1.0 for an accepted word, the best
/// suggestion's score for a word with suggestions, 0.5 for a word with
/// neither. Stays 1.0 when the query contains no words.
pub confidence: f64,
/// `true` if any word was replaced automatically.
pub auto_corrected: bool,
}
impl CorrectionResult {
pub fn new(original: String) -> Self {
CorrectionResult {
original,
corrected: None,
word_suggestions: HashMap::new(),
confidence: 1.0,
auto_corrected: false,
}
}
pub fn has_suggestions(&self) -> bool {
!self.word_suggestions.is_empty()
}
pub fn best_suggestion(&self, word: &str) -> Option<&Suggestion> {
self.word_suggestions.get(word)?.first()
}
pub fn query(&self) -> &str {
self.corrected.as_ref().unwrap_or(&self.original)
}
pub fn should_show_did_you_mean(&self) -> bool {
self.has_suggestions() && !self.auto_corrected && self.confidence < 0.7
}
}
/// Query spell-checker built on a `SuggestionEngine`, with optional learning
/// from index terms and from the queries it sees.
pub struct SpellingCorrector {
/// Engine that owns the dictionary and produces suggestions.
engine: SuggestionEngine,
/// Corrector settings; the suggestion-related ones are mirrored into the engine.
config: CorrectorConfig,
/// Number of times each (normalized) query word has been seen.
query_history: HashMap<String, u32>,
}
impl SpellingCorrector {
pub fn new() -> Self {
let dictionary = BuiltinDictionary::english();
let config = CorrectorConfig::default();
let suggestion_config = SuggestionConfig {
max_distance: config.max_distance,
max_suggestions: config.max_suggestions,
min_frequency: config.min_frequency,
..Default::default()
};
let engine = SuggestionEngine::with_config(dictionary, suggestion_config);
SpellingCorrector {
engine,
config,
query_history: HashMap::new(),
}
}
pub fn with_dictionary(dictionary: SpellingDictionary) -> Self {
let config = CorrectorConfig::default();
let suggestion_config = SuggestionConfig {
max_distance: config.max_distance,
max_suggestions: config.max_suggestions,
min_frequency: config.min_frequency,
..Default::default()
};
let engine = SuggestionEngine::with_config(dictionary, suggestion_config);
SpellingCorrector {
engine,
config,
query_history: HashMap::new(),
}
}
pub fn with_config(dictionary: SpellingDictionary, config: CorrectorConfig) -> Self {
let suggestion_config = SuggestionConfig {
max_distance: config.max_distance,
max_suggestions: config.max_suggestions,
min_frequency: config.min_frequency,
..Default::default()
};
let engine = SuggestionEngine::with_config(dictionary, suggestion_config);
SpellingCorrector {
engine,
config,
query_history: HashMap::new(),
}
}
pub fn set_config(&mut self, config: CorrectorConfig) {
let suggestion_config = SuggestionConfig {
max_distance: config.max_distance,
max_suggestions: config.max_suggestions,
min_frequency: config.min_frequency,
..Default::default()
};
self.engine.set_config(suggestion_config);
self.config = config;
}
pub fn correct(&mut self, query: &str) -> CorrectionResult {
let mut result = CorrectionResult::new(query.to_string());
if self.config.learn_from_queries {
self.learn_query(query);
}
let words = self.extract_query_words(query);
let mut corrected_words = Vec::new();
let mut total_confidence = 0.0;
let mut corrections_made = 0;
for word in &words {
if self.engine.is_correct(word) {
corrected_words.push(word.clone());
total_confidence += 1.0;
} else {
let suggestions = self.engine.suggest(word);
if !suggestions.is_empty() {
result
.word_suggestions
.insert(word.clone(), suggestions.clone());
let best_suggestion = &suggestions[0];
total_confidence += best_suggestion.score;
if self.config.auto_correct
&& best_suggestion.score >= self.config.auto_correct_threshold
{
corrected_words.push(best_suggestion.word.clone());
corrections_made += 1;
result.auto_corrected = true;
} else {
corrected_words.push(word.clone());
}
} else {
corrected_words.push(word.clone());
total_confidence += 0.5; }
}
}
result.confidence = if words.is_empty() {
1.0
} else {
total_confidence / words.len() as f64
};
if corrections_made > 0 {
result.corrected = Some(corrected_words.join(" "));
}
result
}
pub fn suggest_word(&self, word: &str) -> Vec<Suggestion> {
self.engine.suggest(word)
}
pub fn is_correct(&self, word: &str) -> bool {
self.engine.is_correct(word)
}
pub fn learn_from_terms<I>(&mut self, terms: I) -> Result<()>
where
I: IntoIterator<Item = (String, u32)>,
{
if !self.config.use_index_terms {
return Ok(());
}
for (term, frequency) in terms {
if term.len() >= 2 && term.chars().all(|c| c.is_alphabetic() || c == '-') {
self.engine.add_word(&term, frequency);
}
}
Ok(())
}
fn learn_query(&mut self, query: &str) {
let words = self.extract_query_words(query);
for word in words {
if word.len() >= 3 && word.chars().all(|c| c.is_alphabetic()) {
let entry = self.query_history.entry(word.clone()).or_insert(0);
*entry = entry.saturating_add(1);
let count = *entry;
if count >= 5 {
self.engine.add_word(&word, count);
}
}
}
}
fn extract_query_words(&self, query: &str) -> Vec<String> {
let stop_words = [
"is", "a", "an", "the", "and", "or", "not", "in", "on", "at", "to", "for", "of",
"with", "by",
];
query
.split_whitespace()
.filter_map(|word| {
let cleaned: String = word
.chars()
.filter(|c| c.is_alphabetic())
.collect::<String>()
.to_lowercase();
if cleaned.len() >= 2 && !stop_words.contains(&cleaned.as_str()) {
Some(cleaned)
} else {
None
}
})
.collect()
}
pub fn stats(&self) -> CorrectorStats {
let (dict_words, dict_frequency) = self.engine.dictionary_stats();
CorrectorStats {
dictionary_words: dict_words,
dictionary_total_frequency: dict_frequency,
queries_learned: self.query_history.len(),
total_query_frequency: self.query_history.values().sum(),
}
}
pub fn clear_query_history(&mut self) {
self.query_history.clear();
}
}
impl Default for SpellingCorrector {
fn default() -> Self {
Self::new()
}
}
/// Snapshot returned by `SpellingCorrector::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrectorStats {
/// Word count reported by the engine's `dictionary_stats()`.
pub dictionary_words: usize,
/// Total frequency reported by the engine's `dictionary_stats()`.
pub dictionary_total_frequency: u64,
/// Number of distinct words in the query history. Despite the name, this
/// counts words, not queries.
pub queries_learned: usize,
/// Sum of all per-word counts in the query history.
pub total_query_frequency: u32,
}
/// "Did you mean ...?" helper layered on top of a [`SpellingCorrector`].
pub struct DidYouMean {
/// Corrector used to analyse every query; mutated because correcting a query
/// may update its query history.
corrector: SpellingCorrector,
}
impl DidYouMean {
    /// Wraps `corrector` for "did you mean" prompts.
    pub fn new(corrector: SpellingCorrector) -> Self {
        DidYouMean { corrector }
    }

    /// Returns a suggested rewrite of `query`, or `None` when no prompt should
    /// be shown or no token would change.
    ///
    /// The rewrite is rebuilt from the original whitespace-separated tokens. A
    /// token is replaced by its best suggestion when that suggestion scores
    /// above 0.5; the replacement covers the whole token, including attached
    /// punctuation. All other tokens are kept verbatim, which includes stop
    /// words such as `and`/`or`/`not`, so the suggestion never silently drops
    /// part of the query.
    ///
    /// Runs `SpellingCorrector::correct`, which may record the query in the
    /// corrector's history; calling both this and `should_suggest` for the same
    /// query records it twice.
    pub fn suggest(&mut self, query: &str) -> Option<String> {
        // Only suggestions scoring strictly above this replace a token.
        const MIN_REPLACEMENT_SCORE: f64 = 0.5;

        let result = self.corrector.correct(query);
        if !result.should_show_did_you_mean() {
            return None;
        }

        let mut made_corrections = false;
        let mut rewritten: Vec<String> = Vec::new();
        for token in query.split_whitespace() {
            // A whitespace-free token yields at most one query word, and none
            // when it is a stop word or too short.
            let replacement = self
                .corrector
                .extract_query_words(token)
                .into_iter()
                .next()
                .and_then(|word| result.word_suggestions.get(&word))
                .and_then(|suggestions| suggestions.first())
                .filter(|best| best.score > MIN_REPLACEMENT_SCORE);
            match replacement {
                Some(best) => {
                    rewritten.push(best.word.clone());
                    made_corrections = true;
                }
                None => rewritten.push(token.to_string()),
            }
        }

        if made_corrections {
            Some(rewritten.join(" "))
        } else {
            None
        }
    }

    /// Returns `true` if a "did you mean" prompt should be shown for `query`.
    /// Like `suggest`, this runs `SpellingCorrector::correct` and may therefore
    /// update the query history.
    pub fn should_suggest(&mut self, query: &str) -> bool {
        self.corrector.correct(query).should_show_did_you_mean()
    }
}
// Unit tests for `SpellingCorrector`, `CorrectionResult` and `DidYouMean`.
// Several tests only exercise a code path without asserting on its outcome,
// because that outcome depends on the contents of `BuiltinDictionary::minimal()`
// and on the suggestion engine's scoring, neither of which is visible here.
#[cfg(test)]
mod tests {
use super::*;
use crate::spelling::dictionary::BuiltinDictionary;
// The default constructor loads a non-empty built-in dictionary.
#[test]
fn test_corrector_creation() {
let corrector = SpellingCorrector::new();
let stats = corrector.stats();
assert!(stats.dictionary_words > 0);
}
// The minimal dictionary is expected to contain "hello" and "search".
#[test]
fn test_corrector_with_custom_dictionary() {
let dict = BuiltinDictionary::minimal();
let corrector = SpellingCorrector::with_dictionary(dict);
assert!(corrector.is_correct("hello"));
assert!(corrector.is_correct("search"));
}
// Extraction strips punctuation, lowercases, and drops stop words ("is", "a").
#[test]
fn test_word_extraction() {
let corrector = SpellingCorrector::new();
let words = corrector.extract_query_words("Hello, world! This is a test.");
assert!(words.contains(&"hello".to_string()));
assert!(words.contains(&"world".to_string()));
assert!(words.contains(&"test".to_string()));
assert!(!words.contains(&"is".to_string()));
}
// A correctly spelled query yields no suggestions and no prompt. A misspelled
// query is returned unchanged because auto_correct is off by default.
#[test]
fn test_correction_result() {
let dict = BuiltinDictionary::minimal();
let mut corrector = SpellingCorrector::with_dictionary(dict);
let result = corrector.correct("hello world");
assert!(!result.has_suggestions());
assert_eq!(result.query(), "hello world");
assert!(!result.should_show_did_you_mean());
let result = corrector.correct("helo wrld");
assert_eq!(result.query(), "helo wrld"); }
// NOTE(review): smoke test only; the result is discarded. Consider asserting on
// `auto_corrected` / `corrected` once the minimal dictionary's expected
// suggestion for "helo" is pinned down.
#[test]
fn test_auto_correction() {
let dict = BuiltinDictionary::minimal();
let config = CorrectorConfig {
auto_correct: true,
auto_correct_threshold: 0.5, ..Default::default()
};
let mut corrector = SpellingCorrector::with_config(dict, config);
let _result = corrector.correct("helo");
}
// Passes when "helo" gets suggestions, or trivially when "hello" is absent from
// the dictionary. NOTE(review): test_corrector_with_custom_dictionary already
// expects "hello" to be present, so the second disjunct is presumably never
// the one that makes this pass.
#[test]
fn test_suggestion_for_single_word() {
let dict = BuiltinDictionary::minimal();
let dict_clone = dict.clone();
let corrector = SpellingCorrector::with_dictionary(dict);
let suggestions = corrector.suggest_word("helo");
assert!(!suggestions.is_empty() || !dict_clone.contains("hello")); }
// No prompt for a correctly spelled query. The misspelled case is only
// exercised, not asserted.
#[test]
fn test_did_you_mean() {
let dict = BuiltinDictionary::minimal();
let corrector = SpellingCorrector::with_dictionary(dict);
let mut dym = DidYouMean::new(corrector);
assert!(!dym.should_suggest("hello world"));
assert!(dym.suggest("hello world").is_none());
let _should_suggest = dym.should_suggest("helo wrld");
}
// A fresh corrector has dictionary data but an empty query history.
#[test]
fn test_corrector_stats() {
let corrector = SpellingCorrector::new();
let stats = corrector.stats();
assert!(stats.dictionary_words > 0);
assert!(stats.dictionary_total_frequency > 0);
assert_eq!(stats.queries_learned, 0); assert_eq!(stats.total_query_frequency, 0);
}
// set_config replaces the stored configuration.
#[test]
fn test_config_update() {
let mut corrector = SpellingCorrector::new();
let new_config = CorrectorConfig {
max_suggestions: 10,
auto_correct: true,
..Default::default()
};
corrector.set_config(new_config.clone());
assert_eq!(corrector.config.max_suggestions, 10);
assert!(corrector.config.auto_correct);
}
// With learn_from_queries enabled, correcting queries populates the history.
#[test]
fn test_learning_queries() {
let dict = BuiltinDictionary::minimal();
let config = CorrectorConfig {
learn_from_queries: true,
..Default::default()
};
let mut corrector = SpellingCorrector::with_config(dict, config);
corrector.correct("hello world test");
corrector.correct("hello programming test");
let stats = corrector.stats();
assert!(stats.queries_learned > 0);
assert!(stats.total_query_frequency > 0);
}
}