use crate::ocr::error::{OcrError, OcrResult};
/// A pass that rewrites recognized text, e.g. to fix spelling.
///
/// Implementors must be `Send + Sync`, so a single instance (or a `Box<dyn PostProcessor>`)
/// can be shared between threads.
pub trait PostProcessor: Sync + Send {
    /// Returns the corrected version of `input` as a new string; `input` is only borrowed.
    fn correct(&self, input: &str) -> String;
}
/// Post-processor that leaves text exactly as it was recognized.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoopCorrector;

impl PostProcessor for NoopCorrector {
    /// Returns an owned copy of `text`, unchanged.
    fn correct(&self, text: &str) -> String {
        text.to_owned()
    }
}
/// Spell corrector backed by a SymSpell index using `AsciiStringStrategy`.
///
/// Construct it with [`SymspellCorrector::with_wordlist`] or
/// [`SymspellCorrector::with_default_wordlist`].
pub struct SymspellCorrector {
    /// Edit-distance limit passed to every lookup.
    max_edit_distance: i64,
    /// The SymSpell index holding the loaded word list.
    inner: symspell::SymSpell<symspell::AsciiStringStrategy>,
}
impl SymspellCorrector {
    /// Builds a corrector from an in-memory word-frequency list.
    ///
    /// `content` holds one entry per line as whitespace-separated `word [count]`. Blank
    /// lines and lines starting with `#` are skipped, and columns after the count are
    /// ignored.
    /// - A missing count defaults to 1, so a plain word-per-line list works. This matches
    ///   how `build_word_set` treats the same format.
    /// - Words are lowercased on load because [`PostProcessor::correct`] looks up
    ///   lowercased input; mixed-case entries would otherwise never match exactly.
    ///
    /// `max_edit_distance` is used both as the dictionary's precomputed edit distance and
    /// as the lookup distance, so lookups never exceed what the index supports.
    ///
    /// # Errors
    ///
    /// Returns [`OcrError::Config`] if the SymSpell builder rejects the settings, or if a
    /// count column is not an integer. The message includes the 1-based line number.
    pub fn with_wordlist(content: &str, max_edit_distance: i64) -> OcrResult<Self> {
        use symspell::{SymSpell, SymSpellBuilder};
        let mut inner: SymSpell<symspell::AsciiStringStrategy> = SymSpellBuilder::default()
            .max_dictionary_edit_distance(max_edit_distance)
            .prefix_length(7)
            .count_threshold(1)
            .build()
            .map_err(|e| OcrError::Config(format!("symspell build: {e:?}")))?;
        for (index, line) in content.lines().enumerate() {
            let line = line.trim();
            if line.is_empty() || line.starts_with('#') {
                continue;
            }
            let mut columns = line.split_whitespace();
            let word = match columns.next() {
                Some(word) => word.to_lowercase(),
                None => continue,
            };
            // Validate the count here so a malformed wordlist surfaces as a configuration
            // error instead of reaching `load_dictionary_line`, which parses the count
            // itself (NOTE(review): current symspell releases unwrap that parse).
            let count = match columns.next() {
                None => 1,
                Some(raw) => raw.parse::<i64>().map_err(|e| {
                    OcrError::Config(format!(
                        "wordlist line {}: invalid count {raw:?} for {word:?}: {e}",
                        index + 1
                    ))
                })?,
            };
            inner.load_dictionary_line(&format!("{word} {count}"), 0, 1, " ");
        }
        Ok(Self {
            inner,
            max_edit_distance,
        })
    }

    /// Builds a corrector from the embedded [`DEFAULT_WORDLIST`] with edit distance 2.
    ///
    /// # Errors
    ///
    /// Same as [`SymspellCorrector::with_wordlist`].
    pub fn with_default_wordlist() -> OcrResult<Self> {
        Self::with_wordlist(DEFAULT_WORDLIST, 2)
    }
}
impl PostProcessor for SymspellCorrector {
    /// Spell-corrects `text` token by token, keeping all whitespace intact.
    ///
    /// `text` is split after every whitespace character. For each token:
    /// - Leading and trailing non-alphanumeric characters (punctuation, quotes, brackets,
    ///   the trailing whitespace) are set aside.
    /// - Only the remaining core is looked up, lowercased, for the single best SymSpell
    ///   suggestion within `max_edit_distance`.
    /// - The token is returned unchanged if its core is empty, contains a numeric character
    ///   (numbers, codes), or has no suggestion.
    /// - A suggestion inherits the core's casing: an all-caps core of two or more letters
    ///   yields an all-caps replacement; otherwise a leading capital is carried over.
    fn correct(&self, text: &str) -> String {
        use symspell::Verbosity;

        // Transfers the casing of `original` onto the dictionary term `term`.
        fn match_case(original: &str, term: &str) -> String {
            let letter_count = original.chars().filter(|c| c.is_alphabetic()).count();
            let all_caps = letter_count >= 2
                && original
                    .chars()
                    .filter(|c| c.is_alphabetic())
                    .all(char::is_uppercase);
            if all_caps {
                return term.to_uppercase();
            }
            let mut rest = term.chars();
            match (original.chars().next(), rest.next()) {
                (Some(first), Some(head)) if first.is_uppercase() => {
                    head.to_uppercase().chain(rest).collect()
                }
                _ => term.to_string(),
            }
        }

        text.split_inclusive(char::is_whitespace)
            .map(|token| {
                let after_lead = token.trim_start_matches(|c: char| !c.is_alphanumeric());
                let core = after_lead.trim_end_matches(|c: char| !c.is_alphanumeric());
                // Pure punctuation/whitespace and anything numeric is not a dictionary word;
                // "correcting" it would invent text (e.g. "--" -> "a" at distance 2).
                if core.is_empty() || core.chars().any(char::is_numeric) {
                    return token.to_string();
                }
                // Both slices sit on char boundaries: `after_lead` is a suffix of `token`
                // and `core` is a prefix of `after_lead`.
                let lead = &token[..token.len() - after_lead.len()];
                let trail = &after_lead[core.len()..];
                let suggestions =
                    self.inner
                        .lookup(&core.to_lowercase(), Verbosity::Top, self.max_edit_distance);
                match suggestions.first() {
                    Some(best) => {
                        let fixed = match_case(core, &best.term);
                        format!("{lead}{fixed}{trail}")
                    }
                    None => token.to_string(),
                }
            })
            .collect()
    }
}
/// Word list embedded at compile time by `include_str!("wordlist.txt")`; the path is
/// resolved relative to this source file.
///
/// [`SymspellCorrector::with_default_wordlist`] loads it. The parsers in this module
/// (`SymspellCorrector::with_wordlist`, `build_word_set`) read each line as
/// whitespace-separated `word count` columns and skip blank lines and lines starting
/// with `#`.
pub const DEFAULT_WORDLIST: &str = include_str!("wordlist.txt");
/// Decodes one line of glyphs with a beam search over each glyph's candidate labels.
///
/// Every surviving hypothesis is extended by every candidate of the next glyph. The
/// extension's score increment is
/// `alpha * (-distance) + beta * bigram_log_prob + gamma * dict_bonus`, accumulated
/// along the whole hypothesis. After each glyph only the `beam_width` best hypotheses
/// are kept.
///
/// * `glyphs` - recognized glyphs in reading order. A glyph without `alternatives`
///   contributes only the first character of its `text` (a space if `text` is empty),
///   with distance 0.
/// * `beam_width` - hypotheses kept per step; 0 is treated as 1.
/// * `wordlist` - word list (first whitespace column per line, `#` comments skipped)
///   used for the dictionary bonus.
///
/// Returns the highest-scoring text, or an empty string when `glyphs` is empty.
///
/// A score that comes out NaN (e.g. from a NaN recognizer distance) is ranked as
/// `-inf`. This keeps the ordering total, so the sort cannot panic and such
/// hypotheses fall to the bottom of the beam.
///
/// NOTE: the bigram ranker and word set are rebuilt on every call, and `dict_bonus`
/// scans the whole word set for its prefix check. Each extension therefore costs
/// time proportional to the number of words in `wordlist`.
pub fn beam_search_line(
    glyphs: &[crate::ocr::recognize::RecognizedLine],
    beam_width: usize,
    wordlist: &str,
) -> String {
    if glyphs.is_empty() {
        return String::new();
    }
    let beam_width = beam_width.max(1);
    let ranker = crate::ocr::bigram::BigramRanker::english();
    let dict = build_word_set(wordlist);

    #[derive(Clone)]
    struct Hypothesis {
        text: String,
        score: f32,
    }

    let mut beam: Vec<Hypothesis> = vec![Hypothesis {
        text: String::new(),
        score: 0.0,
    }];
    // Weights of the recognizer, bigram and dictionary terms.
    let alpha: f32 = 1.0;
    let beta: f32 = 0.6;
    let gamma: f32 = 0.8;

    for g in glyphs {
        let cands: Vec<(char, f32)> = if g.alternatives.is_empty() {
            vec![(g.text.chars().next().unwrap_or(' '), 0.0)]
        } else {
            g.alternatives.clone()
        };
        let mut extended: Vec<Hypothesis> = Vec::with_capacity(beam.len() * cands.len());
        for h in &beam {
            // The start of the line acts like a word boundary for the bigram model.
            let prev = h.text.chars().last().unwrap_or(' ').to_ascii_lowercase();
            for &(label, dist) in &cands {
                let bigram = ranker
                    .log_probs
                    .get(&(prev, label.to_ascii_lowercase()))
                    .copied()
                    .unwrap_or(ranker.floor_log_prob);
                let mut text = h.text.clone();
                text.push(label);
                let bonus = dict_bonus(&text, &dict);
                let recog = -dist;
                let score = h.score + alpha * recog + beta * bigram + gamma * bonus;
                let score = if score.is_nan() { f32::NEG_INFINITY } else { score };
                extended.push(Hypothesis { text, score });
            }
        }
        // Best first. `total_cmp` is a total order (no NaN survives above), which
        // `sort_by` requires; the stable sort keeps insertion order among ties.
        extended.sort_by(|a, b| b.score.total_cmp(&a.score));
        extended.truncate(beam_width);
        beam = extended;
    }

    // `beam` is never empty here: every glyph yields at least one candidate and
    // `beam_width >= 1`.
    beam.into_iter()
        .max_by(|a, b| a.score.total_cmp(&b.score))
        .map(|h| h.text)
        .unwrap_or_default()
}
/// Collects the ASCII-lowercased first column of every non-blank, non-`#` line of
/// `wordlist` into a set. Leading and trailing whitespace is ignored, and any count
/// column is discarded.
fn build_word_set(wordlist: &str) -> std::collections::HashSet<String> {
    wordlist
        .lines()
        .map(str::trim)
        .filter(|entry| !(entry.is_empty() || entry.starts_with('#')))
        .filter_map(|entry| entry.split_whitespace().next())
        .map(str::to_ascii_lowercase)
        .collect()
}
/// Scores the word currently being typed at the end of `text` against `dict`.
///
/// The "current word" is everything after the last whitespace character, ASCII-lowercased.
/// Returns 1.0 if it is a dictionary word, 0.2 if it is a strict prefix of some dictionary
/// word, and 0.0 otherwise. Also returns 0.0 when `text` is empty or ends in whitespace.
/// The prefix check iterates over every entry of `dict`.
fn dict_bonus(text: &str, dict: &std::collections::HashSet<String>) -> f32 {
    let tail = text.rsplit(char::is_whitespace).next().unwrap_or(text);
    let word = tail.to_ascii_lowercase();
    let is_strict_prefix =
        |entry: &String| entry.len() > word.len() && entry.starts_with(word.as_str());
    if word.is_empty() {
        0.0
    } else if dict.contains(word.as_str()) {
        1.0
    } else if dict.iter().any(is_strict_prefix) {
        0.2
    } else {
        0.0
    }
}