use crate::error::Result;
use crate::string_metrics::{DamerauLevenshteinMetric, StringMetric};
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use super::dictionary::DictionaryCorrector;
use super::error_model::ErrorModel;
use super::ngram::NGramModel;
use super::SpellingCorrector;
/// Configuration knobs for [`StatisticalCorrector`].
#[derive(Debug, Clone)]
pub struct StatisticalCorrectorConfig {
/// Maximum edit distance between an input word and a candidate correction.
pub max_edit_distance: usize,
/// When false, words are lowercased before lookup and comparison.
pub case_sensitive: bool,
/// Maximum number of ranked suggestions returned for a word.
pub max_suggestions: usize,
/// Dictionary entries below this frequency are skipped as candidates.
pub min_frequency: usize,
/// Intended n-gram order for the language model.
/// NOTE(review): not read by the visible code — both `Default` and
/// `from_dictionary_corrector` hard-code `NGramModel::new(3)`; confirm.
pub ngram_order: usize,
/// Weight of the language-model × error-model term in candidate scoring.
pub language_model_weight: f64,
/// Weight of the edit-distance similarity term in candidate scoring.
pub edit_distance_weight: f64,
/// When true, sentence correction conditions candidates on preceding words.
pub use_context: bool,
/// Number of preceding words kept as context during beam search.
pub context_window: usize,
/// Beam width: candidates kept per word during sentence correction.
pub max_candidates: usize,
}
impl Default for StatisticalCorrectorConfig {
fn default() -> Self {
Self {
max_edit_distance: 2,
case_sensitive: false,
max_suggestions: 5,
min_frequency: 1,
ngram_order: 3,
language_model_weight: 0.7,
edit_distance_weight: 0.3,
use_context: true,
context_window: 2,
max_candidates: 5,
}
}
}
/// Context-aware spelling corrector that ranks candidates by combining edit
/// distance, an n-gram language model, and a noisy-channel error model.
pub struct StatisticalCorrector {
/// Known words mapped to their observed frequency.
dictionary: HashMap<String, usize>,
/// Scoring and search parameters.
config: StatisticalCorrectorConfig,
/// Pluggable string-distance metric (Damerau-Levenshtein by default).
metric: Arc<dyn StringMetric + Send + Sync>,
/// N-gram model providing (contextual) word probabilities.
language_model: NGramModel,
/// Model of how plausible a given misspelling is for a candidate word.
error_model: ErrorModel,
}
impl Clone for StatisticalCorrector {
fn clone(&self) -> Self {
Self {
dictionary: self.dictionary.clone(),
config: self.config.clone(),
metric: self.metric.clone(),
language_model: self.language_model.clone(),
error_model: self.error_model.clone(),
}
}
}
impl std::fmt::Debug for StatisticalCorrector {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The dictionary is summarized by its size and the metric by a
        // placeholder, since neither is usefully printable in full.
        let dictionary_summary = format!("<{} words>", self.dictionary.len());
        f.debug_struct("StatisticalCorrector")
            .field("dictionary", &dictionary_summary)
            .field("config", &self.config)
            .field("metric", &"<StringMetric>")
            .field("language_model", &self.language_model)
            .field("error_model", &self.error_model)
            .finish()
    }
}
impl Default for StatisticalCorrector {
fn default() -> Self {
let dict_corrector = DictionaryCorrector::default();
let mut language_model = NGramModel::new(3);
let sampletexts = [
"The quick brown fox jumps over the lazy dog.",
"She sells seashells by the seashore.",
"How much wood would a woodchuck chuck if a woodchuck could chuck wood?",
"To be or not to be, that is the question.",
"Four score and seven years ago our fathers brought forth on this continent a new nation.",
"Ask not what your country can do for you, ask what you can do for your country.",
"That's one small step for man, one giant leap for mankind.",
"I have a dream that one day this nation will rise up and live out the true meaning of its creed.",
"The only thing we have to fear is fear itself.",
"We hold these truths to be self-evident, that all men are created equal.",
];
for text in &sampletexts {
language_model.addtext(text);
}
Self {
dictionary: dict_corrector.dictionary,
config: StatisticalCorrectorConfig::default(),
metric: Arc::new(DamerauLevenshteinMetric::new()),
language_model,
error_model: ErrorModel::default(),
}
}
}
impl StatisticalCorrector {
pub fn new(config: StatisticalCorrectorConfig) -> Self {
Self {
config,
..Default::default()
}
}
pub fn from_dictionary_corrector(dictcorrector: &DictionaryCorrector) -> Self {
let config = StatisticalCorrectorConfig {
max_edit_distance: dictcorrector.config.max_edit_distance,
case_sensitive: dictcorrector.config.case_sensitive,
max_suggestions: dictcorrector.config.max_suggestions,
min_frequency: dictcorrector.config.min_frequency,
..StatisticalCorrectorConfig::default()
};
Self {
dictionary: dictcorrector.dictionary.clone(),
config,
metric: dictcorrector.metric.clone(),
language_model: NGramModel::new(3),
error_model: ErrorModel::default(),
}
}
/// Trains the language model on the contents of a corpus file.
///
/// # Errors
/// Propagates any error from the underlying model's file handling.
pub fn add_corpus_file<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
self.language_model.add_corpus_file(path)
}
/// Trains the language model on one piece of free text.
pub fn add_trainingtext(&mut self, text: &str) {
self.language_model.addtext(text);
}
/// Replaces the language model wholesale.
pub fn set_language_model(&mut self, model: NGramModel) {
self.language_model = model;
}
/// Replaces the error model wholesale.
pub fn set_error_model(&mut self, model: ErrorModel) {
self.error_model = model;
}
/// Swaps in a different string-distance metric.
pub fn set_metric<M: StringMetric + Send + Sync + 'static>(&mut self, metric: M) {
self.metric = Arc::new(metric);
}
/// Replaces the configuration wholesale.
pub fn set_config(&mut self, config: StatisticalCorrectorConfig) {
self.config = config;
}
/// Ranks correction candidates for `word`, optionally conditioned on the
/// preceding `context` words.
///
/// Returns up to `config.max_suggestions` `(candidate, score)` pairs sorted
/// by descending score. A word already in the dictionary is returned as the
/// sole candidate with score 1.0.
fn get_contextual_corrections(&self, word: &str, context: &[String]) -> Vec<(String, f64)> {
// Correct words short-circuit with a perfect score.
if self.is_correct(word) {
return vec![(word.to_string(), 1.0)];
}
// Normalize case for the distance comparison when configured.
let word_to_check = if !self.config.case_sensitive {
word.to_lowercase()
} else {
word.to_string()
};
let mut candidates: Vec<(String, f64)> = Vec::new();
for (dict_word, frequency) in &self.dictionary {
// Rare dictionary entries are not considered as corrections.
if *frequency < self.config.min_frequency {
continue;
}
let dict_word_normalized = if !self.config.case_sensitive {
dict_word.to_lowercase()
} else {
dict_word.clone()
};
// Cheap length-based pruning: strings whose lengths differ by more than
// max_edit_distance cannot be within that edit distance.
// NOTE(review): `len()` is byte length, so this is only exact for
// single-byte (ASCII) text — confirm inputs are effectively ASCII.
if dict_word_normalized.len() > word_to_check.len() + self.config.max_edit_distance
|| dict_word_normalized.len() + self.config.max_edit_distance < word_to_check.len()
{
continue;
}
if let Ok(distance) = self.metric.distance(&word_to_check, &dict_word_normalized) {
// The metric returns f64; round before the integer cutoff comparison.
let distance_usize = distance.round() as usize;
if distance_usize <= self.config.max_edit_distance {
// Edit similarity in (0, 1]: 1.0 at distance 0, decaying with distance.
let edit_score = 1.0 / (1.0 + distance);
// Language-model evidence: contextual probability when context is
// enabled, plain unigram probability otherwise.
let lm_score = if self.config.use_context {
self.language_model.probability(dict_word, context)
} else {
self.language_model.unigram_probability(dict_word)
};
// Noisy-channel term: how plausible this particular misspelling is
// for this candidate.
let error_score = self
.error_model
.error_probability(&word_to_check, &dict_word_normalized);
// Weighted blend of edit similarity and (LM × error-model) evidence.
let combined_score = (self.config.edit_distance_weight * edit_score)
+ (self.config.language_model_weight * lm_score * error_score);
candidates.push((dict_word.clone(), combined_score));
}
}
}
// Best-scoring candidates first; NaN-safe via the Ordering::Equal fallback.
candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
candidates.truncate(self.config.max_suggestions);
candidates
}
/// Corrects each word of `sentence` and splices the corrections back into
/// the original string.
///
/// With `use_context` enabled this runs a beam search over per-word
/// candidates, scoring whole partial sentences; otherwise each misspelled
/// word is corrected independently.
pub fn correct_sentence(&self, sentence: &str) -> Result<String> {
// Tokenize on whitespace, trimming non-alphanumeric characters from both
// ends of each token; tokens that become empty are dropped.
let words: Vec<String> = sentence
.split_whitespace()
.map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()).to_string())
.filter(|s| !s.is_empty())
.collect();
if words.is_empty() {
return Ok(sentence.to_string());
}
// Context-free mode: correct each misspelled word on its own.
// NOTE(review): `String::replace` substitutes every occurrence of the
// token, including as a substring of longer words — confirm acceptable.
if !self.config.use_context {
let mut result = sentence.to_string();
for word in &words {
if !self.is_correct(word) {
if let Ok(correction) = self.correct(word) {
if correction != *word {
result = result.replace(word, &correction);
}
}
}
}
return Ok(result);
}
let context_window = self.config.context_window;
let max_candidates = self.config.max_candidates;
// Beam state: (chosen words so far, cumulative score, trailing context of
// up to `context_window` words). Start with a single empty beam.
let mut beams: Vec<(Vec<String>, f64, Vec<String>)> = vec![(Vec::new(), 0.0, Vec::new())];
for word in &words {
let mut new_beams = Vec::new();
for (partial, score, context) in beams {
// Rank candidates for this word given this beam's own context.
let candidates = self.get_contextual_corrections(word, &context);
for (candidate, candidate_score) in candidates.iter().take(max_candidates) {
let mut new_partial = partial.clone();
new_partial.push(candidate.clone());
// Slide the context window forward by one word.
let mut new_context = context.clone();
new_context.push(candidate.clone());
if new_context.len() > context_window {
new_context.remove(0);
}
// Candidate scores are summed across words (higher is better).
let new_score = score + candidate_score;
new_beams.push((new_partial, new_score, new_context));
}
}
// Prune to the `max_candidates` best partial sentences.
new_beams.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
new_beams.truncate(max_candidates);
beams = new_beams;
}
// Splice the best beam's word choices back into the original sentence.
// NOTE(review): same `replace`-all caveat as above.
if let Some((best_sentence, _, _)) = beams.first() {
let mut result = sentence.to_string();
for (i, original) in words.iter().enumerate() {
if i < best_sentence.len() && original != &best_sentence[i] {
result = result.replace(original, &best_sentence[i]);
}
}
Ok(result)
} else {
Ok(sentence.to_string())
}
}
/// Adds a word to the dictionary, overwriting any existing frequency.
pub fn add_word(&mut self, word: &str, frequency: usize) {
self.dictionary.insert(word.to_string(), frequency);
}
/// Removes a word from the dictionary; a no-op if it is absent.
pub fn remove_word(&mut self, word: &str) {
self.dictionary.remove(word);
}
/// Number of words in the correction dictionary.
pub fn dictionary_size(&self) -> usize {
self.dictionary.len()
}
/// Number of distinct words known to the language model.
pub fn vocabulary_size(&self) -> usize {
self.language_model.vocabulary_size()
}
}
impl SpellingCorrector for StatisticalCorrector {
/// Returns the single best correction for `word`, or the word itself when
/// it is already correct or no candidate is found.
fn correct(&self, word: &str) -> Result<String> {
    // Known words pass through untouched.
    if self.is_correct(word) {
        return Ok(word.to_string());
    }
    // Ask for the top-ranked suggestion only; fall back to the input.
    let mut suggestions = self.get_suggestions(word, 1)?;
    match suggestions.pop() {
        Some(best) => Ok(best),
        None => Ok(word.to_string()),
    }
}
/// Returns up to `limit` ranked suggestions for `word`. A correctly spelled
/// word is its own (only) suggestion.
fn get_suggestions(&self, word: &str, limit: usize) -> Result<Vec<String>> {
    if self.is_correct(word) {
        return Ok(vec![word.to_string()]);
    }
    // No surrounding words are available here, so rank without context.
    let ranked = self.get_contextual_corrections(word, &[]);
    let mut suggestions = Vec::with_capacity(limit.min(ranked.len()));
    for (candidate, _score) in ranked.into_iter().take(limit) {
        suggestions.push(candidate);
    }
    Ok(suggestions)
}
/// Whether `word` is present in the dictionary, honoring
/// `config.case_sensitive`.
fn is_correct(&self, word: &str) -> bool {
    if self.config.case_sensitive {
        self.dictionary.contains_key(word)
    } else {
        // Lowercase the probe once, not once per dictionary entry (the
        // original recomputed `word.to_lowercase()` inside the closure,
        // allocating a fresh String for every key). The scan is still
        // O(dictionary size) because keys are stored in original case.
        let lowered = word.to_lowercase();
        self.dictionary
            .keys()
            .any(|dict_word| dict_word.to_lowercase() == lowered)
    }
}
/// Corrects `text` sentence by sentence, splitting on '.', '?', and '!'.
///
/// Each corrected sentence is substituted back into the original text, so
/// surrounding punctuation and spacing are preserved.
/// NOTE(review): `String::replace` substitutes every occurrence of a
/// sentence string, which can over-replace if a sentence repeats — confirm.
fn correcttext(&self, text: &str) -> Result<String> {
let sentences: Vec<&str> = text
.split(['.', '?', '!'])
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.collect();
if sentences.is_empty() {
return Ok(text.to_string());
}
let mut result = text.to_string();
for sentence in sentences {
// Redundant with the filter above, but harmless.
if sentence.trim().is_empty() {
continue;
}
let corrected_sentence = self.correct_sentence(sentence)?;
if corrected_sentence != sentence {
result = result.replace(sentence, &corrected_sentence);
}
}
Ok(result)
}
}
#[cfg(test)]
mod tests {
use super::*;
// End-to-end check: single-word correction and full-text correction against
// a freshly trained model with boosted dictionary frequencies.
#[test]
fn test_statistical_corrector_basic() {
let mut corrector = StatisticalCorrector::default();
corrector.add_trainingtext("The quick brown fox jumps over the lazy dog.");
corrector.add_trainingtext("Programming languages like Python and Rust are popular.");
corrector.add_trainingtext("I received your message about the meeting tomorrow.");
corrector.add_word("received", 100);
corrector.add_word("message", 100);
corrector.add_word("meeting", 100);
corrector.add_word("tomorrow", 100);
assert_eq!(
corrector.correct("recieved").expect("Operation failed"),
"received"
);
assert_eq!(
corrector.correct("mesage").expect("Operation failed"),
"message"
);
let text = "I recieved your mesage about the meating tommorow.";
let corrected = corrector.correcttext(text).expect("Operation failed");
assert!(corrected.contains("received"));
assert!(corrected.contains("message"));
assert!(corrected.contains("meeting"));
assert!(corrected.contains("tomorrow"));
}
// The same misspelling ("bnk") should resolve to "bank" in two different
// sentence contexts, exercising the context-aware beam search.
#[test]
fn test_statistical_corrector_context_aware() {
let mut corrector = StatisticalCorrector::default();
corrector.add_trainingtext("I went to the bank to deposit money.");
corrector.add_trainingtext("The river bank was muddy after the rain.");
corrector.add_trainingtext("I need to address the issues in the meeting.");
corrector.add_trainingtext("What is your home address?");
corrector.add_word("bank", 100);
corrector.add_word("deposit", 100);
corrector.add_word("money", 100);
corrector.add_word("river", 100);
corrector.add_word("muddy", 100);
corrector.add_word("rain", 100);
let text1 = "I went to the bnk to deposit money.";
let text2 = "The river bnk was muddy after the rain.";
let corrected1 = corrector.correcttext(text1).expect("Operation failed");
let corrected2 = corrector.correcttext(text2).expect("Operation failed");
assert!(corrected1.contains("bank"));
assert!(corrected2.contains("bank"));
}
// A corrector built from a DictionaryCorrector shares its dictionary and
// agrees with it on basic correctness checks.
#[test]
fn test_from_dictionary_corrector() {
let dict_corrector = DictionaryCorrector::default();
let stat_corrector = StatisticalCorrector::from_dictionary_corrector(&dict_corrector);
assert_eq!(
dict_corrector.dictionary_size(),
stat_corrector.dictionary_size()
);
let word = "recieve";
assert!(dict_corrector.correct(word).is_ok());
assert!(stat_corrector.correct(word).is_ok());
assert_eq!(
dict_corrector.is_correct("receive"),
stat_corrector.is_correct("receive")
);
}
}