/// A suggested replacement for a token that was not found in the vocabulary.
///
/// Values are produced by [`FuzzyVocabMatcher::correct`] and
/// [`FuzzyVocabMatcher::correct_all`].
#[derive(Debug, Clone, PartialEq)]
pub struct FuzzyCorrection {
    /// The vocabulary entry proposed as the replacement.
    pub token: &'static str,
    /// Edit distance between the input token and [`Self::token`].
    pub distance: u8,
    /// Heuristic score for the suggestion; higher means more trustworthy.
    pub confidence: f32,
}
/// Suggests vocabulary entries for tokens that are close to, but not exactly,
/// a known entry.
///
/// The matcher only borrows the vocabulary slice, so it is cheap to build and
/// to clone.
#[derive(Debug, Clone)]
pub struct FuzzyVocabMatcher<'v> {
    /// Borrowed list of known tokens that corrections are drawn from.
    vocab: &'v [&'static str],
}
impl<'v> FuzzyVocabMatcher<'v> {
    /// Creates a matcher over `vocab` without copying it.
    ///
    /// Known tokens are detected first with a binary search, which only gives
    /// a meaningful answer when `vocab` is sorted in ascending `str` order.
    /// An unsorted vocabulary still works: an exact match is then caught by
    /// the edit-distance scan instead.
    pub fn new(vocab: &'v [&'static str]) -> Self {
        Self { vocab }
    }

    /// Returns the single best correction for `token`, if there is one.
    ///
    /// `None` is returned when:
    /// - `token` is already a vocabulary entry,
    /// - `token` is shorter than [`MIN_FUZZY_LEN`] bytes,
    /// - no entry lies within [`MAX_EDIT_DISTANCE`] edits,
    /// - two or more entries tie for the smallest distance (ambiguous), or
    /// - the resulting confidence is below the default usefulness floor.
    pub fn correct(&self, token: &str) -> Option<FuzzyCorrection> {
        let matches = self.near_matches(token)?;
        let best_dist = matches.iter().map(|&(_, distance)| distance).min()?;

        // A tie at the minimum distance means we cannot pick a winner.
        let mut closest = matches
            .iter()
            .filter(|&&(_, distance)| distance == best_dist);
        let &(best_token, _) = closest.next()?;
        if closest.next().is_some() {
            return None;
        }

        let confidence = correction_confidence(best_dist, token.len());
        if confidence < MIN_USEFUL_CONFIDENCE {
            return None;
        }
        Some(FuzzyCorrection {
            token: best_token,
            distance: best_dist,
            confidence,
        })
    }

    /// Returns every plausible correction for `token`, using the default
    /// confidence floor.
    ///
    /// See [`Self::correct_all_with_floor`] for ordering and filtering rules.
    pub fn correct_all(&self, token: &str) -> Vec<FuzzyCorrection> {
        self.correct_all_with_floor(token, MIN_USEFUL_CONFIDENCE)
    }

    /// Returns every vocabulary entry within [`MAX_EDIT_DISTANCE`] edits of
    /// `token` whose confidence is at least `confidence_floor`.
    ///
    /// The result is sorted by ascending distance; entries with equal
    /// distance keep their vocabulary order (the sort is stable). The result
    /// is empty when `token` is already a vocabulary entry or is shorter than
    /// [`MIN_FUZZY_LEN`] bytes.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `confidence_floor` is outside `[0.0, 1.0]`.
    pub fn correct_all_with_floor(
        &self,
        token: &str,
        confidence_floor: f32,
    ) -> Vec<FuzzyCorrection> {
        debug_assert!(
            (0.0..=1.0).contains(&confidence_floor),
            "confidence_floor must be in [0.0, 1.0], got {confidence_floor}"
        );
        let token_len = token.len();
        let mut hits: Vec<FuzzyCorrection> = self
            .near_matches(token)
            .unwrap_or_default()
            .into_iter()
            .filter_map(|(candidate, distance)| {
                let confidence = correction_confidence(distance, token_len);
                if confidence < confidence_floor {
                    None
                } else {
                    Some(FuzzyCorrection {
                        token: candidate,
                        distance,
                        confidence,
                    })
                }
            })
            .collect();
        hits.sort_by_key(|hit| hit.distance);
        hits
    }

    /// Collects every vocabulary entry within [`MAX_EDIT_DISTANCE`] edits of
    /// `token`, paired with its distance, in vocabulary order.
    ///
    /// Returns `None` when no correction should be offered at all: the token
    /// is already a vocabulary entry, or it is shorter than
    /// [`MIN_FUZZY_LEN`] bytes.
    fn near_matches(&self, token: &str) -> Option<Vec<(&'static str, u8)>> {
        if self.vocab.binary_search(&token).is_ok() {
            return None;
        }
        let token_len = token.len();
        if token_len < MIN_FUZZY_LEN {
            return None;
        }

        // One pair of scratch rows is shared by every distance computation.
        // The distance routine indexes by the shorter input, which is never
        // longer than `token`, so `token_len + 1` cells always suffice.
        let mut prev = vec![0u8; token_len + 1];
        let mut curr = vec![0u8; token_len + 1];

        let mut matches = Vec::new();
        for &candidate in self.vocab {
            // Each edit changes the length by at most one, so a larger length
            // gap cannot be bridged within the edit budget.
            if token_len.abs_diff(candidate.len()) > usize::from(MAX_EDIT_DISTANCE) {
                continue;
            }
            let distance = levenshtein_with_scratch(token, candidate, &mut prev, &mut curr);
            if distance == 0 {
                // Identical to a vocabulary entry that the binary search
                // missed (possible when the slice is not sorted): the token
                // is known, so there is nothing to correct.
                return None;
            }
            if distance <= MAX_EDIT_DISTANCE {
                matches.push((candidate, distance));
            }
        }
        Some(matches)
    }
}
/// Minimum token length, in bytes, for fuzzy matching to be attempted;
/// shorter tokens never receive a correction.
pub const MIN_FUZZY_LEN: usize = 3;
/// Largest edit distance at which a vocabulary entry is still considered a
/// candidate; also used as the maximum allowed length difference when
/// pre-filtering candidates.
pub const MAX_EDIT_DISTANCE: u8 = 2;
/// Confidence below which `correct` discards its result; also the default
/// floor used by `correct_all`.
const MIN_USEFUL_CONFIDENCE: f32 = 0.45;
/// Maps an edit distance and the input token's length to a heuristic
/// confidence score.
///
/// - distance 0: always `1.0`.
/// - distance 1: `0.55`, plus `0.05` for each byte of length beyond 3,
///   capped at length 6 (maximum `0.70`).
/// - distance 2: `0.40`, plus `0.05` for each byte of length beyond 5,
///   capped at length 8 (maximum `0.55`).
/// - any larger distance: `0.0`.
fn correction_confidence(distance: u8, token_len: usize) -> f32 {
    // (base score, length at which the bonus starts, length at which it stops)
    let (base, bonus_from, bonus_until): (f32, usize, usize) = match distance {
        0 => return 1.0,
        1 => (0.55, 3, 6),
        2 => (0.40, 5, 8),
        _ => return 0.0,
    };
    let steps = token_len.min(bonus_until).saturating_sub(bonus_from);
    base + steps as f32 * 0.05
}
/// Test-only convenience wrapper around [`levenshtein_with_scratch`] that
/// allocates its own scratch rows, sized to the shorter input plus one.
#[cfg(test)]
pub(crate) fn levenshtein(a: &str, b: &str) -> u8 {
    let scratch_len = std::cmp::min(a.len(), b.len()) + 1;
    let mut row_above = vec![0u8; scratch_len];
    let mut row_current = vec![0u8; scratch_len];
    levenshtein_with_scratch(a, b, &mut row_above, &mut row_current)
}
/// Computes the byte-wise Levenshtein distance between `a` and `b` using two
/// caller-provided scratch rows, so repeated calls need no allocation.
///
/// The inputs are compared byte by byte, and the result saturates at
/// `u8::MAX`. The argument order does not matter: the shorter input is always
/// used as the row dimension.
///
/// Both scratch slices must be longer than `min(a.len(), b.len())` unless the
/// shorter input is empty. Their previous contents are irrelevant, because
/// every cell is written before it is read.
///
/// # Panics
///
/// In debug builds, panics if either scratch slice is too short; in release
/// builds a too-short slice leads to an out-of-bounds index panic.
pub(crate) fn levenshtein_with_scratch(
    a: &str,
    b: &str,
    prev_buf: &mut [u8],
    curr_buf: &mut [u8],
) -> u8 {
    // Clamp a length or index into the u8 range used by the DP cells.
    let clamp = |value: usize| value.min(usize::from(u8::MAX)) as u8;

    // Lay the shorter input along the row so the scratch rows stay small.
    let (short, long) = if a.len() <= b.len() {
        (a.as_bytes(), b.as_bytes())
    } else {
        (b.as_bytes(), a.as_bytes())
    };
    let cols = short.len();
    if cols == 0 {
        return clamp(long.len());
    }
    debug_assert!(
        prev_buf.len() > cols && curr_buf.len() > cols,
        "scratch buffers must have len > min(a.len(), b.len())"
    );

    let (mut above, mut row) = (prev_buf, curr_buf);

    // Row zero: turning an empty prefix into `short[..idx]` costs `idx` edits.
    for (idx, cell) in above.iter_mut().take(cols + 1).enumerate() {
        *cell = clamp(idx);
    }

    for (row_no, &long_byte) in long.iter().enumerate() {
        row[0] = clamp(row_no + 1);
        for (col, &short_byte) in short.iter().enumerate() {
            let substitution = above[col].saturating_add(u8::from(short_byte != long_byte));
            let deletion = above[col + 1].saturating_add(1);
            let insertion = row[col].saturating_add(1);
            row[col + 1] = substitution.min(deletion).min(insertion);
        }
        // The row just filled becomes the "above" row for the next pass.
        std::mem::swap(&mut above, &mut row);
    }
    above[cols]
}
/// Unit tests for the Levenshtein routine, the confidence heuristic and the
/// matcher, against both a small local vocabulary and the real
/// `marque_ism` correction vocabulary.
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    // --- Levenshtein distance -------------------------------------------

    #[test]
    fn lev_identical_strings() {
        assert_eq!(levenshtein("SECRET", "SECRET"), 0);
    }

    // Plain Levenshtein has no transposition edit, so a swap costs 2.
    #[test]
    fn lev_single_transpose() {
        assert_eq!(levenshtein("SERCET", "SECRET"), 2);
    }

    #[test]
    fn lev_insertion() {
        assert_eq!(levenshtein("CONFIDETIAL", "CONFIDENTIAL"), 1);
    }

    #[test]
    fn lev_transposition() {
        assert_eq!(levenshtein("NOFORN", "NOFRON"), 2);
    }

    #[test]
    fn lev_empty_vs_nonempty() {
        assert_eq!(levenshtein("", "SECRET"), 6);
        assert_eq!(levenshtein("SECRET", ""), 6);
    }

    #[test]
    fn lev_symmetry() {
        assert_eq!(
            levenshtein("NOFORN", "NOFRON"),
            levenshtein("NOFRON", "NOFORN")
        );
    }

    // --- Matcher over a small, sorted test vocabulary --------------------

    static TEST_VOCAB: &[&str] = &[
        "CONFIDENTIAL",
        "FOUO",
        "NOFORN",
        "SECRET",
        "TOP SECRET",
        "UNCLASSIFIED",
    ];

    fn matcher() -> FuzzyVocabMatcher<'static> {
        FuzzyVocabMatcher::new(TEST_VOCAB)
    }

    #[test]
    fn known_token_returns_none() {
        assert!(matcher().correct("SECRET").is_none());
    }

    #[test]
    fn short_token_returns_none() {
        assert!(matcher().correct("S").is_none());
        assert!(matcher().correct("NF").is_none());
    }

    #[test]
    fn confidetial_corrects_to_confidential() {
        let result = matcher().correct("CONFIDETIAL");
        assert_eq!(result.as_ref().map(|c| c.token), Some("CONFIDENTIAL"));
        assert_eq!(result.map(|c| c.distance), Some(1));
    }

    #[test]
    fn nofron_corrects_to_noforn() {
        let result = matcher().correct("NOFRON");
        assert_eq!(result.map(|c| c.token), Some("NOFORN"));
    }

    #[test]
    fn sercet_corrects_to_secret() {
        let result = matcher().correct("SERCET");
        assert_eq!(result.as_ref().map(|c| c.token), Some("SECRET"));
        let c = result.unwrap();
        assert!(
            c.confidence > 0.44,
            "confidence should be non-trivial: {}",
            c.confidence
        );
    }

    #[test]
    fn confidence_increases_with_token_length() {
        let short_conf = correction_confidence(1, 4);
        let long_conf = correction_confidence(1, 12);
        assert!(
            long_conf > short_conf,
            "expected long_conf {long_conf} > short_conf {short_conf}"
        );
    }

    #[test]
    fn completely_unrelated_string_returns_none() {
        assert!(matcher().correct("BANANA").is_none());
    }

    #[test]
    fn ambiguous_corrections_return_none() {
        let vocab = &["BOOK", "COOK"];
        let matcher = FuzzyVocabMatcher::new(vocab);
        assert!(matcher.correct("NOOK").is_none());
    }

    #[test]
    fn distance_2_edit_returns_correction_for_long_tokens() {
        let result = matcher().correct("UNCLASSIFEID");
        assert_eq!(result.map(|c| c.token), Some("UNCLASSIFIED"));
    }

    #[test]
    fn correction_confidence_distance1_scales_with_length() {
        let eps = 1e-5_f32;
        assert!((correction_confidence(1, 3) - 0.55).abs() < eps);
        assert!((correction_confidence(1, 4) - 0.60).abs() < eps);
        assert!((correction_confidence(1, 6) - 0.70).abs() < eps);
        assert!((correction_confidence(1, 12) - 0.70).abs() < eps);
    }

    // --- Matcher over the real correction vocabulary ---------------------

    #[test]
    fn real_vocab_corrects_noforon_to_noforn() {
        use marque_ism::CapcoTokenSet;
        use marque_ism::token_set::TokenSet as _;
        let vocab = CapcoTokenSet.correction_vocab();
        let matcher = FuzzyVocabMatcher::new(vocab);
        let result = matcher.correct("NOFORON");
        assert_eq!(result.as_ref().map(|c| c.token), Some("NOFORN"));
        assert_eq!(result.map(|c| c.distance), Some(1));
    }

    #[test]
    fn real_vocab_corrects_nofron_to_noforn() {
        use marque_ism::CapcoTokenSet;
        use marque_ism::token_set::TokenSet as _;
        let vocab = CapcoTokenSet.correction_vocab();
        let matcher = FuzzyVocabMatcher::new(vocab);
        let result = matcher.correct("NOFRON");
        assert_eq!(result.map(|c| c.token), Some("NOFORN"));
    }

    #[test]
    fn real_vocab_corrects_orcon_typo() {
        use marque_ism::CapcoTokenSet;
        use marque_ism::token_set::TokenSet as _;
        let vocab = CapcoTokenSet.correction_vocab();
        let matcher = FuzzyVocabMatcher::new(vocab);
        let result = matcher.correct("ORCN");
        assert_eq!(result.as_ref().map(|c| c.token), Some("ORCON"));
    }

    #[test]
    fn real_vocab_emits_multi_word_banner_for_whitespace_free_typo() {
        use marque_ism::CapcoTokenSet;
        use marque_ism::token_set::TokenSet as _;
        let vocab = CapcoTokenSet.correction_vocab();
        let matcher = FuzzyVocabMatcher::new(vocab);
        let result = matcher.correct("SBUNOFORN");
        assert_eq!(result.as_ref().map(|c| c.token), Some("SBU NOFORN"));
        assert_eq!(result.map(|c| c.distance), Some(1));
    }

    #[test]
    fn correction_confidence_distance2_scales_with_length() {
        let eps = 1e-5_f32;
        assert!((correction_confidence(2, 5) - 0.40).abs() < eps);
        assert!((correction_confidence(2, 6) - 0.45).abs() < eps);
        assert!((correction_confidence(2, 8) - 0.55).abs() < eps);
        assert!((correction_confidence(2, 15) - 0.55).abs() < eps);
    }

    // --- correct_all / correct_all_with_floor ----------------------------

    #[test]
    fn correct_all_returns_empty_for_known_token() {
        let result = matcher().correct_all("SECRET");
        assert!(
            result.is_empty(),
            "known token must return empty vec from correct_all, got {result:?}"
        );
    }

    #[test]
    fn correct_all_returns_empty_for_short_token() {
        assert!(matcher().correct_all("S").is_empty());
        assert!(matcher().correct_all("NF").is_empty());
    }

    #[test]
    fn correct_all_returns_multiple_for_ambiguous_input() {
        let vocab = &["BOOK", "COOK"];
        let m = FuzzyVocabMatcher::new(vocab);
        let result = m.correct_all("NOOK");
        assert_eq!(
            result.len(),
            2,
            "correct_all must return all tied candidates; got {result:?}"
        );
        let tokens: Vec<&str> = result.iter().map(|c| c.token).collect();
        assert!(tokens.contains(&"BOOK"), "BOOK must be among alternates");
        assert!(tokens.contains(&"COOK"), "COOK must be among alternates");
    }

    #[test]
    fn correct_all_with_floor_filters_confidence() {
        let vocab = &["NOFORN", "UNCLASSIFIED"];
        let m = FuzzyVocabMatcher::new(vocab);
        let with_default_floor = m.correct_all("NOFORON");
        assert_eq!(
            with_default_floor.len(),
            1,
            "default floor should keep the distance-1 match; got {with_default_floor:?}"
        );
        assert_eq!(with_default_floor[0].token, "NOFORN");
        // Distance 1 on a 7-byte token scores 0.55 + 3 * 0.05 = 0.70.
        let with_high_floor = m.correct_all_with_floor("NOFORON", 0.80);
        assert!(
            with_high_floor.is_empty(),
            "floor 0.80 must exclude NOFORON→NOFORN (confidence 0.70); got {with_high_floor:?}"
        );
    }

    #[test]
    fn correct_all_with_zero_floor_includes_distance2_on_short_token() {
        let vocab = &["AUS", "AUT"];
        let m = FuzzyVocabMatcher::new(vocab);
        let default_floor = m.correct_all("ASU");
        assert!(
            default_floor.is_empty(),
            "default floor must exclude distance-2 3-char correction; got {default_floor:?}"
        );
        let zero_floor = m.correct_all_with_floor("ASU", 0.0);
        assert!(
            !zero_floor.is_empty(),
            "zero floor must include distance-2 3-char corrections; got {zero_floor:?}"
        );
        let tokens: Vec<&str> = zero_floor.iter().map(|c| c.token).collect();
        assert!(
            tokens.contains(&"AUS") || tokens.contains(&"AUT"),
            "at least one of AUS/AUT must be a distance-2 candidate of ASU; got {zero_floor:?}"
        );
    }

    #[test]
    fn correct_all_result_is_sorted_by_distance_ascending() {
        // The closest entry ("NORORN", distance 1) is deliberately LAST in
        // vocabulary order, behind two distance-2 entries, so this only
        // passes if `correct_all` really sorts its hits by distance.
        let vocab = &["NOFORN", "NONORN", "NORORN"];
        let m = FuzzyVocabMatcher::new(vocab);
        let result = m.correct_all("NORRRN");
        let distances: Vec<u8> = result.iter().map(|c| c.distance).collect();
        assert_eq!(
            distances,
            vec![1, 2, 2],
            "correct_all result not sorted by distance: {:?}",
            result
        );
        assert_eq!(result[0].token, "NORORN");
    }
}