use alloc::string::String;
use alloc::vec::Vec;
use crate::dict::BUILTIN_WORDS;
use crate::freq::FreqMap;
use crate::segmenter::Tokenizer;
use crate::soundex::lk82;
use crate::token::TokenKind;
/// One spelling suggestion produced by [`SpellChecker::suggestions`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Suggestion {
    /// The dictionary word being suggested.
    pub word: String,
    /// Levenshtein distance between the input and `word`, counted in `char`s.
    pub edit_distance: u8,
    /// `true` when `word` has the same LK82 sound code as the input.
    pub soundex_match: bool,
    /// Frequency score of `word`; higher scores rank earlier among otherwise
    /// equal suggestions.
    pub freq_score: u32,
}
/// Dictionary-based spell checker that ranks candidates by sound code,
/// edit distance and word frequency.
pub struct SpellChecker {
    /// Newline-separated word list. Blank lines and lines starting with `#`
    /// are skipped when searching for suggestions.
    words_text: &'static str,
    /// Per-word frequency scores used to rank suggestions.
    freq: FreqMap,
    /// Segmenter that splits input text into tokens for `correct_text`.
    tokenizer: Tokenizer,
}
impl SpellChecker {
pub fn builtin() -> Self {
Self {
words_text: BUILTIN_WORDS,
freq: FreqMap::builtin(),
tokenizer: Tokenizer::new(),
}
}
pub fn suggestions(&self, word: &str, max_n: usize) -> Vec<Suggestion> {
if word.is_empty() || max_n == 0 {
return Vec::new();
}
let word_chars: Vec<char> = word.chars().collect();
let word_len = word_chars.len();
let input_lk82 = lk82(word);
let mut results: Vec<Suggestion> = Vec::new();
for line in self.words_text.lines() {
let candidate = line.trim();
if candidate.is_empty() || candidate.starts_with('#') {
continue;
}
let cand_len = candidate.chars().count();
if cand_len + 2 < word_len || word_len + 2 < cand_len {
continue;
}
let dist = levenshtein(&word_chars, candidate, 2);
if dist > 2 {
continue;
}
let soundex_match = lk82(candidate) == input_lk82;
let freq_score = self.freq.get(candidate);
results.push(Suggestion {
word: String::from(candidate),
edit_distance: dist as u8,
soundex_match,
freq_score,
});
}
results.sort_unstable_by(|a, b| {
b.soundex_match
.cmp(&a.soundex_match)
.then(a.edit_distance.cmp(&b.edit_distance))
.then(b.freq_score.cmp(&a.freq_score))
});
results.truncate(max_n);
results
}
pub fn did_you_mean(&self, word: &str) -> Option<String> {
match self.suggestions(word, 1).into_iter().next() {
Some(s) if s.edit_distance == 0 => None,
Some(s) => Some(s.word),
None => None,
}
}
pub fn correct_text(&self, text: &str) -> String {
if text.is_empty() {
return String::new();
}
let tokens = self.tokenizer.segment(text);
let mut result = String::with_capacity(text.len());
for token in &tokens {
if token.kind == TokenKind::Unknown && token.text.chars().count() >= 2 {
match self.did_you_mean(token.text) {
Some(correction) => result.push_str(&correction),
None => result.push_str(token.text),
}
} else {
result.push_str(token.text);
}
}
result
}
}
/// Character-level Levenshtein distance between `a_chars` and `b`, with an
/// early cut-off at `max_dist`.
///
/// Insertions, deletions and substitutions each cost 1, so a transposition
/// counts as 2. The exact distance is returned whenever it is `<= max_dist`.
/// Otherwise the result is only guaranteed to be greater than `max_dist`:
/// - the length-gap and row-minimum cut-offs return `max_dist + 1`;
/// - an empty side returns the other side's length.
///
/// Callers should therefore test `result > max_dist` rather than compare with
/// `max_dist + 1`.
///
/// Only two DP rows of `b.chars().count() + 1` cells are kept alive at a time.
fn levenshtein(a_chars: &[char], b: &str, max_dist: usize) -> usize {
    let target: Vec<char> = b.chars().collect();
    let (rows, cols) = (a_chars.len(), target.len());
    // Against an empty string, the distance is simply the other length.
    if rows == 0 {
        return cols;
    }
    if cols == 0 {
        return rows;
    }
    // The distance can never be smaller than the length gap.
    if rows.abs_diff(cols) > max_dist {
        return max_dist + 1;
    }
    // `above` holds the previous DP row; `below` is the row being filled.
    let mut above: Vec<usize> = (0..=cols).collect();
    let mut below: Vec<usize> = alloc::vec![0usize; cols + 1];
    for (row_idx, &from) in a_chars.iter().enumerate() {
        let row_no = row_idx + 1;
        below[0] = row_no;
        let mut best_in_row = row_no;
        for (col_idx, &to) in target.iter().enumerate() {
            let substitution = if from == to { 0 } else { 1 };
            let cell = (above[col_idx] + substitution)
                .min(above[col_idx + 1] + 1)
                .min(below[col_idx] + 1);
            below[col_idx + 1] = cell;
            best_in_row = best_in_row.min(cell);
        }
        // Row minima never decrease, so once a whole row exceeds the budget,
        // the final distance must exceed it too.
        if best_in_row > max_dist {
            return max_dist + 1;
        }
        core::mem::swap(&mut above, &mut below);
    }
    above[cols]
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh checker over the built-in dictionary.
    fn checker() -> SpellChecker {
        SpellChecker::builtin()
    }

    /// `levenshtein` with a cut-off high enough to always yield the exact
    /// distance for these short inputs.
    fn lev(a: &str, b: &str) -> usize {
        let a_chars: Vec<char> = a.chars().collect();
        levenshtein(&a_chars, b, 255)
    }

    // ---- suggestions ----------------------------------------------------

    #[test]
    fn empty_input_returns_empty() {
        assert!(checker().suggestions("", 5).is_empty());
    }

    #[test]
    fn zero_max_n_returns_empty() {
        assert!(checker().suggestions("กาน", 0).is_empty());
    }

    #[test]
    fn all_results_within_threshold() {
        let suggs = checker().suggestions("กาน", 20);
        assert!(
            suggs.iter().all(|s| s.edit_distance <= 2),
            "got out-of-range result: {suggs:?}"
        );
    }

    #[test]
    fn respects_max_n() {
        let suggs = checker().suggestions("กาน", 3);
        assert!(suggs.len() <= 3);
    }

    #[test]
    fn word_in_dict_appears_as_dist_zero() {
        let suggs = checker().suggestions("กิน", 10);
        let exact = suggs
            .iter()
            .find(|s| s.word == "กิน")
            .unwrap_or_else(|| panic!("expected กิน in suggestions; got: {suggs:?}"));
        assert_eq!(exact.edit_distance, 0);
    }

    #[test]
    fn sorted_soundex_first_then_distance() {
        let suggs = checker().suggestions("กาน", 20);
        for pair in suggs.windows(2) {
            let (a, b) = (&pair[0], &pair[1]);
            let ok = (a.soundex_match && !b.soundex_match)
                || (a.soundex_match == b.soundex_match && a.edit_distance < b.edit_distance)
                || (a.soundex_match == b.soundex_match
                    && a.edit_distance == b.edit_distance
                    && a.freq_score >= b.freq_score);
            assert!(ok, "sort order violated: {a:?} before {b:?}");
        }
    }

    #[test]
    fn transposition_within_distance_two() {
        let suggs = checker().suggestions("สวสัดี", 5);
        let hit = suggs
            .iter()
            .find(|s| s.word == "สวัสดี")
            .unwrap_or_else(|| {
                panic!("expected สวัสดี in suggestions for สวสัดี; got: {suggs:?}")
            });
        assert!(hit.edit_distance <= 2);
    }

    #[test]
    fn single_char_deletion_suggestion() {
        let suggs = checker().suggestions("กินข้า", 5);
        let hit = suggs
            .iter()
            .find(|s| s.word == "กินข้าว")
            .unwrap_or_else(|| {
                panic!("expected กินข้าว in suggestions for กินข้า; got: {suggs:?}")
            });
        assert_eq!(hit.edit_distance, 1);
    }

    // ---- levenshtein ----------------------------------------------------

    #[test]
    fn lev_identical() {
        assert_eq!(lev("กิน", "กิน"), 0);
    }

    #[test]
    fn lev_single_substitution() {
        assert_eq!(lev("กาน", "กาล"), 1);
    }

    #[test]
    fn lev_single_deletion() {
        assert_eq!(lev("กินข้าว", "กินข้า"), 1);
    }

    #[test]
    fn lev_single_insertion() {
        assert_eq!(lev("กินข้า", "กินข้าว"), 1);
    }

    #[test]
    fn lev_transposition_is_two() {
        assert_eq!(lev("สวสัดี", "สวัสดี"), 2);
    }

    #[test]
    fn lev_empty_strings() {
        assert_eq!(lev("", ""), 0);
        assert_eq!(lev("กิน", ""), 3);
        assert_eq!(lev("", "กิน"), 3);
    }

    #[test]
    fn lev_early_exit_above_max() {
        // The length gap alone exceeds the budget.
        let a_chars: Vec<char> = "กิน".chars().collect();
        assert_eq!(levenshtein(&a_chars, "สวัสดีครับ", 2), 3);
    }

    #[test]
    fn lev_early_exit_on_row_minimum() {
        // Equal lengths, so only the per-row cut-off can stop the scan.
        let a_chars: Vec<char> = "abc".chars().collect();
        assert_eq!(levenshtein(&a_chars, "xyz", 1), 2);
    }

    // ---- did_you_mean ---------------------------------------------------

    #[test]
    fn did_you_mean_correct_word_returns_none() {
        assert_eq!(checker().did_you_mean("กิน"), None);
    }

    #[test]
    fn did_you_mean_empty_returns_none() {
        assert_eq!(checker().did_you_mean(""), None);
    }

    #[test]
    fn did_you_mean_misspelled_word_returns_suggestion() {
        let result = checker().did_you_mean("กินข้า");
        assert!(result.is_some(), "expected a suggestion for กินข้า");
        assert_ne!(
            result.as_deref(),
            Some("กินข้า"),
            "suggestion should differ from input"
        );
    }

    // ---- correct_text ---------------------------------------------------

    #[test]
    fn correct_text_empty_returns_empty() {
        assert_eq!(checker().correct_text(""), "");
    }

    #[test]
    fn correct_text_passthrough_clean_input() {
        let out = checker().correct_text("กินข้าวกับปลา");
        assert!(!out.is_empty(), "output should not be empty");
        let tokenizer = crate::segmenter::Tokenizer::new();
        let tokens = tokenizer.segment(&out);
        assert!(
            tokens
                .iter()
                .all(|t| t.kind != crate::token::TokenKind::Unknown),
            "expected no Unknown tokens in corrected output; got: {tokens:?}"
        );
    }
}