use std::collections::HashMap;
use std::marker::PhantomData;
use crate::semiring::{Semiring, TropicalWeight};
use crate::wfst::{MutableWfst, VectorWfst};
/// Costs and limits that control how a [`ConfusionMatrix`] is turned into a
/// transducer.
#[derive(Clone, Debug)]
pub struct ConfusionConfig {
    /// Cost placed on the arc that maps a character to itself.
    pub identity_cost: f64,
    /// Fallback substitution cost.
    ///
    /// NOTE(review): no code visible in this part of the file reads this
    /// field; confirm where it is consumed.
    pub default_substitution_cost: f64,
    /// Deletion cost used by `ConfusionTransducer::build_with_indels` when the
    /// matrix has no explicit deletion entry for a character.
    pub default_deletion_cost: f64,
    /// Insertion cost used by `ConfusionTransducer::build_with_indels` when the
    /// matrix has no explicit insertion entry for a character.
    pub default_insertion_cost: f64,
    /// Whether identity (`c -> c`) arcs are added for every alphabet character.
    pub include_identity: bool,
    /// Optional cap on the number of substitution candidates kept per intended
    /// character by `ConfusionTransducer::build`; `None` means no cap.
    pub max_confusions_per_char: Option<usize>,
}

impl Default for ConfusionConfig {
    /// Free identity arcs, substitutions at 2.0, insertions and deletions at
    /// 1.5, identity arcs enabled and no cap on confusions.
    fn default() -> Self {
        Self {
            include_identity: true,
            max_confusions_per_char: None,
            identity_cost: 0.0,
            default_substitution_cost: 2.0,
            default_insertion_cost: 1.5,
            default_deletion_cost: 1.5,
        }
    }
}
/// Table of per-character edit costs: substitutions, deletions, insertions
/// and transpositions.
///
/// Substitutions and transpositions are keyed by an ordered pair of
/// characters; an entry for `(a, b)` says nothing about `(b, a)` unless it
/// was added through one of the symmetric helpers.
#[derive(Clone, Debug, Default)]
pub struct ConfusionMatrix {
    /// `(intended, observed) -> cost`.
    substitutions: HashMap<(char, char), f64>,
    /// `intended -> cost` of dropping that character.
    deletions: HashMap<char, f64>,
    /// `observed -> cost` of a spurious extra character.
    insertions: HashMap<char, f64>,
    /// `(first, second) -> cost` of swapping two characters.
    transpositions: HashMap<(char, char), f64>,
}

impl ConfusionMatrix {
    /// Creates an empty matrix.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the cost of observing `observed` where `intended` was meant,
    /// replacing any previous entry for that ordered pair.
    pub fn add_substitution(&mut self, intended: char, observed: char, cost: f64) -> &mut Self {
        self.substitutions.insert((intended, observed), cost);
        self
    }

    /// Sets the same substitution cost for both `a -> b` and `b -> a`.
    pub fn add_symmetric_substitution(&mut self, a: char, b: char, cost: f64) -> &mut Self {
        self.substitutions.insert((a, b), cost);
        self.substitutions.insert((b, a), cost);
        self
    }

    /// Sets the cost of deleting `intended`.
    pub fn add_deletion(&mut self, intended: char, cost: f64) -> &mut Self {
        self.deletions.insert(intended, cost);
        self
    }

    /// Sets the cost of inserting a spurious `observed`.
    pub fn add_insertion(&mut self, observed: char, cost: f64) -> &mut Self {
        self.insertions.insert(observed, cost);
        self
    }

    /// Sets the transposition cost for the pair in both orders.
    pub fn add_transposition(&mut self, a: char, b: char, cost: f64) -> &mut Self {
        self.transpositions.insert((a, b), cost);
        self.transpositions.insert((b, a), cost);
        self
    }

    /// Cost of `intended -> observed`, if one was recorded.
    pub fn substitution_cost(&self, intended: char, observed: char) -> Option<f64> {
        self.substitutions.get(&(intended, observed)).copied()
    }

    /// Cost of deleting `intended`, if one was recorded.
    pub fn deletion_cost(&self, intended: char) -> Option<f64> {
        self.deletions.get(&intended).copied()
    }

    /// Cost of inserting `observed`, if one was recorded.
    pub fn insertion_cost(&self, observed: char) -> Option<f64> {
        self.insertions.get(&observed).copied()
    }

    /// Cost of transposing `a` and `b` (in that order), if one was recorded.
    pub fn transposition_cost(&self, a: char, b: char) -> Option<f64> {
        self.transpositions.get(&(a, b)).copied()
    }

    /// Iterates over every `(observed, cost)` recorded for `intended`.
    ///
    /// The iteration order follows the underlying `HashMap` and is therefore
    /// unspecified; use [`ConfusionMatrix::confusable_with`] for a stable
    /// order. Each call scans all substitution entries.
    pub fn substitutions_for(&self, intended: char) -> impl Iterator<Item = (char, f64)> + '_ {
        self.substitutions
            .iter()
            .filter(move |((i, _), _)| *i == intended)
            .map(|((_, o), c)| (*o, *c))
    }

    /// Returns every `(observed, cost)` recorded for `intended`, cheapest
    /// first.
    ///
    /// Entries are ordered by cost (using `f64::total_cmp`, so NaN costs do
    /// not panic) and ties are broken by character, which makes the result
    /// independent of hash-map iteration order.
    pub fn confusable_with(&self, intended: char) -> Vec<(char, f64)> {
        let mut candidates: Vec<(char, f64)> = self.substitutions_for(intended).collect();
        candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
        candidates
    }

    /// Folds `other` into `self`.
    ///
    /// Entries that only exist in `other` are copied; for entries present in
    /// both, the smaller cost (per `f64::min`) is kept.
    pub fn merge(&mut self, other: &ConfusionMatrix) {
        Self::merge_min(&mut self.substitutions, &other.substitutions);
        Self::merge_min(&mut self.deletions, &other.deletions);
        Self::merge_min(&mut self.insertions, &other.insertions);
        Self::merge_min(&mut self.transpositions, &other.transpositions);
    }

    /// Copies `source` into `target`, keeping the minimum cost on key clashes.
    fn merge_min<K>(target: &mut HashMap<K, f64>, source: &HashMap<K, f64>)
    where
        K: Eq + std::hash::Hash + Copy,
    {
        for (&key, &cost) in source {
            target
                .entry(key)
                .and_modify(|existing| *existing = existing.min(cost))
                .or_insert(cost);
        }
    }

    /// Returns every character mentioned anywhere in the matrix, sorted and
    /// without duplicates.
    pub fn alphabet(&self) -> Vec<char> {
        let mut chars: Vec<char> = self
            .substitutions
            .keys()
            .flat_map(|(a, b)| [*a, *b])
            .chain(self.deletions.keys().copied())
            .chain(self.insertions.keys().copied())
            .chain(self.transpositions.keys().flat_map(|(a, b)| [*a, *b]))
            .collect();
        // Stability is irrelevant for plain `char`s, so the cheaper sort is fine.
        chars.sort_unstable();
        chars.dedup();
        chars
    }

    /// Number of ordered substitution entries.
    pub fn num_substitutions(&self) -> usize {
        self.substitutions.len()
    }

    /// Number of deletion entries.
    pub fn num_deletions(&self) -> usize {
        self.deletions.len()
    }

    /// Number of insertion entries.
    pub fn num_insertions(&self) -> usize {
        self.insertions.len()
    }

    /// Number of ordered transposition entries (a pair added through
    /// [`ConfusionMatrix::add_transposition`] counts once per direction).
    pub fn num_transpositions(&self) -> usize {
        self.transpositions.len()
    }
}
/// Builds single-state weighted transducers from a [`ConfusionMatrix`].
///
/// `W` is the weight type of the produced transducer; costs are converted
/// through `TropicalWeight` by the build methods.
#[derive(Clone, Debug)]
pub struct ConfusionTransducer<W: Semiring> {
    matrix: ConfusionMatrix,
    config: ConfusionConfig,
    _phantom: PhantomData<W>,
}

impl<W: Semiring> ConfusionTransducer<W> {
    /// Wraps `matrix` with the default [`ConfusionConfig`].
    pub fn from_matrix(matrix: ConfusionMatrix) -> Self {
        Self::with_config(matrix, ConfusionConfig::default())
    }

    /// Wraps `matrix` with an explicit configuration.
    pub fn with_config(matrix: ConfusionMatrix, config: ConfusionConfig) -> Self {
        Self {
            matrix,
            config,
            _phantom: PhantomData,
        }
    }

    /// The confusion matrix this transducer is built from.
    pub fn matrix(&self) -> &ConfusionMatrix {
        &self.matrix
    }

    /// The configuration used by the build methods.
    pub fn config(&self) -> &ConfusionConfig {
        &self.config
    }

    /// Substitution candidates for `intended`, cheapest first, limited to
    /// `max_confusions_per_char` when that is set.
    ///
    /// Self-substitutions are dropped *before* truncation so they cannot use
    /// up a slot, and ties on cost are broken by character so the result does
    /// not depend on hash-map iteration order. `f64::total_cmp` is used so a
    /// NaN cost cannot cause a panic while sorting.
    fn ranked_substitutions(&self, intended: char) -> Vec<(char, f64)> {
        let mut candidates: Vec<(char, f64)> = self
            .matrix
            .substitutions_for(intended)
            .filter(|&(observed, _)| observed != intended)
            .collect();
        candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
        if let Some(max) = self.config.max_confusions_per_char {
            candidates.truncate(max);
        }
        candidates
    }

    /// Builds a substitution-only transducer.
    ///
    /// The result has one state that is both the start state and final (with
    /// weight `W::one()`). For every character of the matrix alphabet it
    /// carries an optional identity arc (see `include_identity` and
    /// `identity_cost`) plus one arc per retained substitution, all looping
    /// on that state.
    pub fn build(&self) -> VectorWfst<char, W>
    where
        W: From<TropicalWeight>,
    {
        let mut fst = VectorWfst::new();
        let state = fst.add_state();
        fst.set_start(state);
        fst.set_final(state, W::one());

        for c in self.matrix.alphabet() {
            if self.config.include_identity {
                let weight = W::from(TropicalWeight::new(self.config.identity_cost));
                fst.add_arc(state, Some(c), Some(c), state, weight);
            }
            for (observed, cost) in self.ranked_substitutions(c) {
                let weight = W::from(TropicalWeight::new(cost));
                fst.add_arc(state, Some(c), Some(observed), state, weight);
            }
        }
        fst
    }

    /// Builds a transducer that also models deletions and insertions.
    ///
    /// Labels are `Option<char>`; a deleted character is emitted as
    /// `Some(None)` on the output side and an inserted one is consumed as
    /// `Some(None)` on the input side.
    /// NOTE(review): presumably the inner `None` label stands for "no
    /// character" — confirm against the `VectorWfst` label conventions.
    ///
    /// Substitution arcs follow the same ranking and
    /// `max_confusions_per_char` cap as [`ConfusionTransducer::build`].
    /// Deletion and insertion arcs are added for every alphabet character,
    /// falling back to the configured default costs when the matrix has no
    /// explicit entry.
    pub fn build_with_indels(&self) -> VectorWfst<Option<char>, W>
    where
        W: From<TropicalWeight>,
    {
        let mut fst: VectorWfst<Option<char>, W> = VectorWfst::new();
        let state = fst.add_state();
        fst.set_start(state);
        fst.set_final(state, W::one());

        let alphabet = self.matrix.alphabet();
        for &c in &alphabet {
            if self.config.include_identity {
                let weight = W::from(TropicalWeight::new(self.config.identity_cost));
                fst.add_arc(state, Some(Some(c)), Some(Some(c)), state, weight);
            }
            for (observed, cost) in self.ranked_substitutions(c) {
                let weight = W::from(TropicalWeight::new(cost));
                fst.add_arc(state, Some(Some(c)), Some(Some(observed)), state, weight);
            }
            let del_cost = self
                .matrix
                .deletion_cost(c)
                .unwrap_or(self.config.default_deletion_cost);
            let weight = W::from(TropicalWeight::new(del_cost));
            fst.add_arc(state, Some(Some(c)), Some(None), state, weight);
        }
        // Insertion arcs are added after all per-character arcs, as before.
        for &c in &alphabet {
            let ins_cost = self
                .matrix
                .insertion_cost(c)
                .unwrap_or(self.config.default_insertion_cost);
            let weight = W::from(TropicalWeight::new(ins_cost));
            fst.add_arc(state, Some(None), Some(Some(c)), state, weight);
        }
        fst
    }
}
/// Estimates substitution costs from `(correct, observed)` string pairs.
///
/// Characters are aligned purely by position; when the two strings differ in
/// length the surplus tail of the longer one is ignored. For every aligned
/// pair with `intended != observed` the recorded cost is
/// `-ln((count + smoothing) / (total + smoothing * 256.0))`, where `count` is
/// how often that pair was seen and `total` is how often `intended` was seen.
///
/// NOTE(review): the constant `256.0` presumably stands for an assumed
/// alphabet size — confirm before changing it.
pub fn train_confusion_matrix(pairs: &[(String, String)], smoothing: f64) -> ConfusionMatrix {
    let mut pair_tally: HashMap<(char, char), f64> = HashMap::new();
    let mut intended_tally: HashMap<char, f64> = HashMap::new();

    let aligned = pairs
        .iter()
        .flat_map(|(correct, observed)| correct.chars().zip(observed.chars()));
    for (want, got) in aligned {
        *pair_tally.entry((want, got)).or_default() += 1.0;
        *intended_tally.entry(want).or_default() += 1.0;
    }

    let mut trained = ConfusionMatrix::new();
    for (&(want, got), &seen) in &pair_tally {
        // Matching characters only contribute to the totals, never to an entry.
        if want == got {
            continue;
        }
        let total = intended_tally.get(&want).copied().unwrap_or(1.0);
        let likelihood = (seen + smoothing) / (total + smoothing * 256.0);
        trained.add_substitution(want, got, -likelihood.ln());
    }
    trained
}
/// Unshifted characters of a QWERTY keyboard, one string per row starting
/// with the digit row.
///
/// `keyboard_confusion_from_layout` treats these rows as a grid indexed by
/// character position within each string, so neighbours are determined by
/// equal or adjacent column indices.
const QWERTY_ROWS: &[&str] = &[
"1234567890-=",
"qwertyuiop[]\\",
"asdfghjkl;'",
"zxcvbnm,./",
];
/// Unshifted characters of a Dvorak keyboard, one string per row starting
/// with the digit row.
///
/// `keyboard_confusion_from_layout` treats these rows as a grid indexed by
/// character position within each string, so neighbours are determined by
/// equal or adjacent column indices.
const DVORAK_ROWS: &[&str] = &[
"1234567890[]",
"',.pyfgcrl/=\\",
"aoeuidhtns-",
";qjkxbmwvz",
];
fn keyboard_confusion_from_layout(
rows: &[&str],
base_cost: f64,
diagonal_penalty: f64,
) -> ConfusionMatrix {
let mut matrix = ConfusionMatrix::new();
let mut positions: HashMap<char, (usize, usize)> = HashMap::new();
for (row_idx, row) in rows.iter().enumerate() {
for (col_idx, c) in row.chars().enumerate() {
positions.insert(c, (row_idx, col_idx));
positions.insert(c.to_ascii_uppercase(), (row_idx, col_idx));
}
}
for (row_idx, row) in rows.iter().enumerate() {
for (col_idx, c) in row.chars().enumerate() {
let offsets: [(i32, i32); 8] = [
(-1, -1),
(-1, 0),
(-1, 1),
(0, -1),
(0, 1),
(1, -1),
(1, 0),
(1, 1),
];
for (dr, dc) in offsets {
let new_row = row_idx as i32 + dr;
let new_col = col_idx as i32 + dc;
if new_row >= 0 && new_row < rows.len() as i32 {
if let Some(adj_char) = rows[new_row as usize].chars().nth(new_col as usize) {
let cost = if dr != 0 && dc != 0 {
base_cost + diagonal_penalty
} else {
base_cost
};
matrix.add_symmetric_substitution(c, adj_char, cost);
matrix.add_symmetric_substitution(
c.to_ascii_uppercase(),
adj_char.to_ascii_uppercase(),
cost,
);
}
}
}
}
}
matrix
}
/// Substitution matrix for adjacent keys on the QWERTY layout: 0.5 for
/// horizontal/vertical neighbours, with an extra 0.2 for diagonal ones.
pub fn qwerty_confusion_matrix() -> ConfusionMatrix {
    let base_cost = 0.5;
    let diagonal_penalty = 0.2;
    keyboard_confusion_from_layout(QWERTY_ROWS, base_cost, diagonal_penalty)
}
/// Substitution matrix for adjacent keys on the Dvorak layout: 0.5 for
/// horizontal/vertical neighbours, with an extra 0.2 for diagonal ones.
pub fn dvorak_confusion_matrix() -> ConfusionMatrix {
    let base_cost = 0.5;
    let diagonal_penalty = 0.2;
    keyboard_confusion_from_layout(DVORAK_ROWS, base_cost, diagonal_penalty)
}
/// Substitution matrix for visually similar glyphs, grouped into three cost
/// tiers (0.3, 0.7 and 1.2).
///
/// Every pair is recorded in both directions with the same cost.
pub fn ocr_confusion_matrix() -> ConfusionMatrix {
    // Each tier lists one direction of every pair; the reverse direction is
    // added by `add_symmetric_substitution`.
    const TIERS: [(f64, &[(char, char)]); 3] = [
        (
            0.3,
            &[
                ('0', 'O'),
                ('0', 'o'),
                ('1', 'l'),
                ('1', 'I'),
                ('l', 'I'),
                ('5', 'S'),
                ('8', 'B'),
                ('2', 'Z'),
            ],
        ),
        (
            0.7,
            &[
                ('c', 'e'),
                ('n', 'h'),
                ('u', 'v'),
                ('f', 't'),
                ('i', 'j'),
                ('m', 'n'),
                ('a', 'o'),
                ('g', 'q'),
                ('p', 'P'),
                ('k', 'K'),
            ],
        ),
        (0.7 + 0.5, &[('b', 'd'), ('p', 'q'), ('6', 'G'), ('9', 'g')]),
    ];

    let mut matrix = ConfusionMatrix::new();
    for (cost, glyph_pairs) in TIERS {
        for &(a, b) in glyph_pairs {
            matrix.add_symmetric_substitution(a, b, cost);
        }
    }
    matrix
}
/// QWERTY adjacency confusions merged with the OCR glyph confusions.
///
/// Where both sources define the same substitution, `ConfusionMatrix::merge`
/// decides which cost survives.
pub fn combined_confusion_matrix() -> ConfusionMatrix {
    let mut combined = qwerty_confusion_matrix();
    let ocr = ocr_confusion_matrix();
    combined.merge(&ocr);
    combined
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wfst::Wfst;

    /// Explicit entries are returned verbatim; missing ones yield `None`.
    #[test]
    fn test_confusion_matrix_basic() {
        let mut matrix = ConfusionMatrix::new();
        matrix.add_substitution('a', 'e', 0.5);
        matrix.add_substitution('a', 'o', 0.8);
        matrix.add_deletion('x', 1.0);
        matrix.add_insertion('z', 1.2);
        assert_eq!(matrix.substitution_cost('a', 'e'), Some(0.5));
        assert_eq!(matrix.substitution_cost('a', 'o'), Some(0.8));
        assert_eq!(matrix.substitution_cost('b', 'c'), None);
        // A one-directional entry must not imply the reverse direction.
        assert_eq!(matrix.substitution_cost('e', 'a'), None);
        assert_eq!(matrix.deletion_cost('x'), Some(1.0));
        assert_eq!(matrix.insertion_cost('z'), Some(1.2));
    }

    #[test]
    fn test_symmetric_substitution() {
        let mut matrix = ConfusionMatrix::new();
        matrix.add_symmetric_substitution('a', 'e', 0.5);
        assert_eq!(matrix.substitution_cost('a', 'e'), Some(0.5));
        assert_eq!(matrix.substitution_cost('e', 'a'), Some(0.5));
        assert_eq!(matrix.num_substitutions(), 2);
    }

    /// Transpositions are stored for both orders of the pair.
    #[test]
    fn test_transposition() {
        let mut matrix = ConfusionMatrix::new();
        matrix.add_transposition('t', 'h', 0.9);
        assert_eq!(matrix.transposition_cost('t', 'h'), Some(0.9));
        assert_eq!(matrix.transposition_cost('h', 't'), Some(0.9));
        assert_eq!(matrix.transposition_cost('t', 'e'), None);
    }

    #[test]
    fn test_confusable_with() {
        let mut matrix = ConfusionMatrix::new();
        matrix.add_substitution('a', 'e', 0.5);
        matrix.add_substitution('a', 'o', 0.8);
        matrix.add_substitution('a', 'i', 1.0);
        // An entry for a different intended character must not leak in.
        matrix.add_substitution('b', 'd', 0.1);
        let confusable = matrix.confusable_with('a');
        assert_eq!(confusable.len(), 3);
        assert!(confusable.contains(&('e', 0.5)));
        assert!(confusable.contains(&('o', 0.8)));
        assert!(confusable.contains(&('i', 1.0)));
    }

    /// The alphabet is the sorted, de-duplicated union of all characters.
    #[test]
    fn test_alphabet() {
        let mut matrix = ConfusionMatrix::new();
        matrix.add_substitution('a', 'e', 0.5);
        matrix.add_substitution('b', 'c', 0.5);
        matrix.add_deletion('x', 1.0);
        // 'a' appears twice across tables and must only be listed once.
        matrix.add_insertion('a', 1.0);
        let alphabet = matrix.alphabet();
        assert_eq!(alphabet, vec!['a', 'b', 'c', 'e', 'x']);
    }

    /// Merging keeps the cheaper cost on clashes and copies new entries.
    #[test]
    fn test_merge() {
        let mut matrix1 = ConfusionMatrix::new();
        matrix1.add_substitution('a', 'e', 0.8);
        matrix1.add_substitution('b', 'c', 0.5);
        matrix1.add_deletion('x', 1.0);
        let mut matrix2 = ConfusionMatrix::new();
        matrix2.add_substitution('a', 'e', 0.5);
        matrix2.add_substitution('d', 'f', 0.6);
        matrix2.add_deletion('x', 2.0);
        matrix2.add_insertion('z', 1.2);
        matrix1.merge(&matrix2);
        assert_eq!(matrix1.substitution_cost('a', 'e'), Some(0.5));
        assert_eq!(matrix1.substitution_cost('b', 'c'), Some(0.5));
        assert_eq!(matrix1.substitution_cost('d', 'f'), Some(0.6));
        assert_eq!(matrix1.deletion_cost('x'), Some(1.0));
        assert_eq!(matrix1.insertion_cost('z'), Some(1.2));
    }

    #[test]
    fn test_build_transducer() {
        let mut matrix = ConfusionMatrix::new();
        matrix.add_substitution('a', 'e', 0.5);
        matrix.add_substitution('a', 'o', 0.8);
        let transducer = ConfusionTransducer::<TropicalWeight>::from_matrix(matrix);
        let fst = transducer.build();
        assert_eq!(fst.num_states(), 1);
        let start = fst.start();
        assert!(fst.is_final(start));
    }

    #[test]
    fn test_qwerty_matrix() {
        let matrix = qwerty_confusion_matrix();
        assert!(matrix.substitution_cost('q', 'w').is_some());
        assert!(matrix.substitution_cost('w', 'q').is_some());
        assert!(matrix.substitution_cost('a', 's').is_some());
        // Upper-case variants are generated alongside the lower-case ones.
        assert!(matrix.substitution_cost('Q', 'W').is_some());
        // Keys that are far apart must not be confusable.
        assert!(matrix.substitution_cost('q', 'p').is_none());
    }

    #[test]
    fn test_ocr_matrix() {
        let matrix = ocr_confusion_matrix();
        assert!(matrix.substitution_cost('0', 'O').is_some());
        assert!(matrix.substitution_cost('1', 'l').is_some());
        assert!(matrix.substitution_cost('l', 'I').is_some());
    }

    #[test]
    fn test_combined_matrix() {
        let matrix = combined_confusion_matrix();
        assert!(matrix.substitution_cost('q', 'w').is_some());
        assert!(matrix.substitution_cost('0', 'O').is_some());
    }

    #[test]
    fn test_train_confusion_matrix() {
        let pairs = vec![
            ("hello".to_string(), "hallo".to_string()),
            ("hello".to_string(), "hella".to_string()),
            ("world".to_string(), "warld".to_string()),
        ];
        let matrix = train_confusion_matrix(&pairs, 0.1);
        assert!(matrix.substitution_cost('e', 'a').is_some());
        assert!(matrix.substitution_cost('o', 'a').is_some());
        // Correctly typed characters never become substitution entries.
        assert!(matrix.substitution_cost('h', 'h').is_none());
    }

    #[test]
    fn test_config() {
        let matrix = qwerty_confusion_matrix();
        let config = ConfusionConfig {
            identity_cost: 0.1,
            include_identity: true,
            max_confusions_per_char: Some(3),
            ..Default::default()
        };
        let transducer = ConfusionTransducer::<TropicalWeight>::with_config(matrix, config);
        assert_eq!(transducer.config().identity_cost, 0.1);
        assert_eq!(transducer.config().max_confusions_per_char, Some(3));
        // Building with a cap must still succeed and yield the single state.
        let fst = transducer.build();
        assert_eq!(fst.num_states(), 1);
    }

    #[test]
    fn test_build_with_indels() {
        let mut matrix = ConfusionMatrix::new();
        matrix.add_substitution('a', 'e', 0.5);
        matrix.add_deletion('x', 1.0);
        matrix.add_insertion('z', 1.2);
        let transducer = ConfusionTransducer::<TropicalWeight>::from_matrix(matrix);
        let fst = transducer.build_with_indels();
        assert_eq!(fst.num_states(), 1);
        let start = fst.start();
        assert!(fst.is_final(start));
    }
}