use std::marker::PhantomData;
use crate::semiring::{Semiring, TropicalWeight};
use crate::wfst::{MutableWfst, StateId, VectorWfst};
/// Per-operation weights used when building edit-distance transducers.
///
/// Each value is wrapped in a tropical weight and attached to the arc that
/// performs the corresponding edit.
#[derive(Debug, Clone, PartialEq)]
pub struct EditCosts {
    /// Cost of emitting a character that is not present in the input.
    pub insert: f64,
    /// Cost of dropping an input character.
    pub delete: f64,
    /// Cost of replacing a character with a different one.
    pub substitute: f64,
    /// Cost intended for swapping two adjacent characters (Damerau–Levenshtein).
    pub transpose: f64,
}

impl Default for EditCosts {
    /// Unit cost for every operation.
    fn default() -> Self {
        Self::uniform(1.0)
    }
}

impl EditCosts {
    /// Uses the same `cost` for insertion, deletion, substitution and transposition.
    #[must_use]
    pub fn uniform(cost: f64) -> Self {
        Self {
            insert: cost,
            delete: cost,
            substitute: cost,
            transpose: cost,
        }
    }

    /// Unit costs, except that a transposition costs `0.5`. This makes swapping
    /// two adjacent characters cheaper than other single edits.
    #[must_use]
    pub fn prefer_transpose() -> Self {
        Self {
            transpose: 0.5,
            ..Self::uniform(1.0)
        }
    }
}
/// Settings that control the shape and weights of an edit-distance transducer.
#[derive(Debug, Clone)]
pub struct EditDistanceConfig {
    /// Upper bound on the number of edit operations along a path.
    pub max_distance: usize,
    /// Weight charged for each kind of edit.
    pub costs: EditCosts,
    /// Whether adjacent-character transpositions are requested (Damerau–Levenshtein mode).
    pub include_transpositions: bool,
    /// Characters the transducer operates over; left empty by default.
    pub alphabet: Vec<char>,
}

impl Default for EditDistanceConfig {
    /// Levenshtein mode, at most 2 edits, `EditCosts::default()` weights, empty alphabet.
    fn default() -> Self {
        let costs = EditCosts::default();
        Self {
            alphabet: Vec::new(),
            include_transpositions: false,
            costs,
            max_distance: 2,
        }
    }
}

impl EditDistanceConfig {
    /// Shared constructor behind the named presets; other fields take their defaults.
    fn preset(max_distance: usize, include_transpositions: bool) -> Self {
        Self {
            max_distance,
            include_transpositions,
            ..Self::default()
        }
    }

    /// Plain Levenshtein settings (no transpositions) bounded by `max_distance`.
    pub fn levenshtein(max_distance: usize) -> Self {
        Self::preset(max_distance, false)
    }

    /// Damerau–Levenshtein settings (transpositions requested) bounded by `max_distance`.
    pub fn damerau_levenshtein(max_distance: usize) -> Self {
        Self::preset(max_distance, true)
    }

    /// Replaces the alphabet with the characters of `chars`, in order.
    pub fn with_alphabet(self, chars: &str) -> Self {
        Self {
            alphabet: chars.chars().collect(),
            ..self
        }
    }

    /// Replaces the alphabet with a copy of `chars`.
    pub fn with_alphabet_chars(self, chars: &[char]) -> Self {
        Self {
            alphabet: Vec::from(chars),
            ..self
        }
    }

    /// Replaces the edit costs.
    pub fn with_costs(self, costs: EditCosts) -> Self {
        Self { costs, ..self }
    }
}
/// Builds tropical-weight transducers that relate strings within
/// `max_distance` edit operations of each other.
///
/// Use [`build`](Self::build) for a query-independent transducer over the
/// configured alphabet. Use [`build_for_query`](Self::build_for_query) for a
/// transducer specialised to one input string.
#[derive(Debug, Clone)]
pub struct EditDistanceTransducer {
    config: EditDistanceConfig,
}

impl EditDistanceTransducer {
    /// Creates a builder from an explicit configuration.
    pub fn new(config: EditDistanceConfig) -> Self {
        Self { config }
    }

    /// Creates a Levenshtein builder (no transpositions) bounded by `max_distance` edits.
    pub fn levenshtein(max_distance: usize) -> Self {
        Self::new(EditDistanceConfig::levenshtein(max_distance))
    }

    /// Creates a Damerau–Levenshtein builder (adjacent transpositions enabled)
    /// bounded by `max_distance` edits.
    pub fn damerau_levenshtein(max_distance: usize) -> Self {
        Self::new(EditDistanceConfig::damerau_levenshtein(max_distance))
    }

    /// Sets the alphabet to the characters of `chars`.
    #[must_use]
    pub fn with_alphabet(mut self, chars: &str) -> Self {
        self.config = self.config.with_alphabet(chars);
        self
    }

    /// Builds a query-independent edit transducer over the configured alphabet.
    ///
    /// States `0..=max_distance` are "levels", one per number of edits spent.
    /// Level 0 is the start state, and every level is final.
    ///
    /// Arcs:
    /// - Every level carries identity self-loops with weight one.
    /// - Each level below the bound has deletion, insertion and substitution
    ///   arcs to the next level, weighted by the configured costs.
    /// - When transpositions are enabled, each ordered pair of distinct
    ///   characters `(a, b)` also gets a fresh auxiliary state. This lets
    ///   input `ab` become output `ba` for the transposition cost, also
    ///   advancing one level.
    ///
    /// Repeated alphabet characters are ignored. An empty alphabet yields a
    /// single-state transducer with no arcs.
    pub fn build(&self) -> VectorWfst<char, TropicalWeight> {
        let alphabet = Self::dedup_preserving_order(&self.config.alphabet);
        if alphabet.is_empty() {
            return self.build_identity_transducer();
        }
        let k = self.config.max_distance;
        let costs = &self.config.costs;
        let mut fst: VectorWfst<char, TropicalWeight> = VectorWfst::new();
        let levels: Vec<StateId> = (0..=k).map(|_| fst.add_state()).collect();
        fst.set_start(levels[0]);
        for &state in &levels {
            fst.set_final(state, TropicalWeight::one());
        }
        for (i, &from) in levels.iter().enumerate() {
            // `None` on the last level: the edit budget is exhausted there.
            let next = levels.get(i + 1).copied();
            for &c in &alphabet {
                fst.add_arc(from, Some(c), Some(c), from, TropicalWeight::one());
                if let Some(next) = next {
                    fst.add_arc(from, Some(c), None, next, TropicalWeight::new(costs.delete));
                    fst.add_arc(from, None, Some(c), next, TropicalWeight::new(costs.insert));
                    for &d in alphabet.iter().filter(|&&d| d != c) {
                        fst.add_arc(
                            from,
                            Some(c),
                            Some(d),
                            next,
                            TropicalWeight::new(costs.substitute),
                        );
                    }
                }
            }
            if let Some(next) = next.filter(|_| self.config.include_transpositions) {
                for &a in &alphabet {
                    for &b in alphabet.iter().filter(|&&b| b != a) {
                        Self::add_transposition(&mut fst, from, a, b, next, costs.transpose);
                    }
                }
            }
        }
        fst
    }

    /// Returns a single-state transducer (start and final, no arcs).
    /// [`build`](Self::build) uses it when no alphabet is configured.
    fn build_identity_transducer(&self) -> VectorWfst<char, TropicalWeight> {
        let mut fst: VectorWfst<char, TropicalWeight> = VectorWfst::new();
        let state = fst.add_state();
        fst.set_start(state);
        fst.set_final(state, TropicalWeight::one());
        fst
    }

    /// Builds a transducer specialised to `query`, which is read on the input tape.
    ///
    /// State `(pos, err)` means `pos` query characters have been consumed
    /// using `err` edits. It is encoded as `pos * (max_distance + 1) + err`,
    /// so the first `(len + 1) * (max_distance + 1)` states form this grid.
    /// `(0, 0)` is the start state, and every `(len, err)` is final.
    ///
    /// Arcs:
    /// - Match: weight one, `err` unchanged.
    /// - Deletion and substitution: consume a query character, `err + 1`.
    /// - Insertion: emit an alphabet character without consuming, `err + 1`.
    /// - Transposition (only when enabled): a two-arc path through a fresh
    ///   auxiliary state that rewrites adjacent distinct query characters
    ///   `ab` as `ba`, `err + 1`.
    ///
    /// If no alphabet is configured, substitutions and insertions use the
    /// sorted, de-duplicated characters of `query`. Otherwise repeated
    /// alphabet characters are ignored.
    pub fn build_for_query(&self, query: &str) -> VectorWfst<char, TropicalWeight> {
        let k = self.config.max_distance;
        let costs = &self.config.costs;
        let query_chars: Vec<char> = query.chars().collect();
        let n = query_chars.len();
        let mut fst: VectorWfst<char, TropicalWeight> = VectorWfst::new();
        fst.add_states((n + 1) * (k + 1));
        // NOTE: `as` truncates if the grid does not fit in `StateId`.
        let state_id = |pos: usize, err: usize| -> StateId { (pos * (k + 1) + err) as StateId };
        fst.set_start(state_id(0, 0));
        for err in 0..=k {
            fst.set_final(state_id(n, err), TropicalWeight::one());
        }
        let alphabet = if self.config.alphabet.is_empty() {
            let mut chars = query_chars.clone();
            chars.sort_unstable();
            chars.dedup();
            chars
        } else {
            Self::dedup_preserving_order(&self.config.alphabet)
        };
        for pos in 0..=n {
            let current = query_chars.get(pos).copied();
            // Adjacent pair starting at `pos` that a transposition may swap.
            let swap = current
                .zip(query_chars.get(pos + 1).copied())
                .filter(|&(a, b)| self.config.include_transpositions && a != b);
            for err in 0..=k {
                let from = state_id(pos, err);
                let can_edit = err < k;
                if let Some(q) = current {
                    fst.add_arc(
                        from,
                        Some(q),
                        Some(q),
                        state_id(pos + 1, err),
                        TropicalWeight::one(),
                    );
                    if can_edit {
                        fst.add_arc(
                            from,
                            Some(q),
                            None,
                            state_id(pos + 1, err + 1),
                            TropicalWeight::new(costs.delete),
                        );
                        for &c in alphabet.iter().filter(|&&c| c != q) {
                            fst.add_arc(
                                from,
                                Some(q),
                                Some(c),
                                state_id(pos + 1, err + 1),
                                TropicalWeight::new(costs.substitute),
                            );
                        }
                    }
                }
                if can_edit {
                    for &c in &alphabet {
                        fst.add_arc(
                            from,
                            None,
                            Some(c),
                            state_id(pos, err + 1),
                            TropicalWeight::new(costs.insert),
                        );
                    }
                    if let Some((a, b)) = swap {
                        Self::add_transposition(
                            &mut fst,
                            from,
                            a,
                            b,
                            state_id(pos + 2, err + 1),
                            costs.transpose,
                        );
                    }
                }
            }
        }
        fst
    }

    /// Adds the path `from --first:second/cost--> mid --second:first/one--> to`
    /// through a freshly allocated state `mid`. This rewrites input
    /// `first second` as output `second first`.
    fn add_transposition(
        fst: &mut VectorWfst<char, TropicalWeight>,
        from: StateId,
        first: char,
        second: char,
        to: StateId,
        cost: f64,
    ) {
        let mid = fst.add_state();
        fst.add_arc(from, Some(first), Some(second), mid, TropicalWeight::new(cost));
        fst.add_arc(mid, Some(second), Some(first), to, TropicalWeight::one());
    }

    /// Returns `chars` without repeats, keeping first occurrences in their original order.
    fn dedup_preserving_order(chars: &[char]) -> Vec<char> {
        let mut seen = std::collections::HashSet::with_capacity(chars.len());
        chars.iter().copied().filter(|&c| seen.insert(c)).collect()
    }
}
pub type DamerauLevenshteinTransducer = EditDistanceTransducer;
impl DamerauLevenshteinTransducer {
pub fn new_damerau(max_distance: usize) -> Self {
Self::damerau_levenshtein(max_distance)
}
}
/// Compact description of a query-specific edit-distance transducer.
///
/// States are not stored. Each one is identified arithmetically by a
/// `(pos, err)` pair: query characters consumed and edits spent.
/// `W` is a marker for the weight semiring; it is not stored.
pub struct LazyEditDistanceTransducer<W> {
    query: Vec<char>,
    max_distance: usize,
    costs: EditCosts,
    alphabet: Vec<char>,
    /// Error levels per query position (`max_distance + 1`); this is the stride of the state encoding.
    states_per_pos: usize,
    _phantom: PhantomData<W>,
}

impl<W: Semiring> LazyEditDistanceTransducer<W> {
    /// Describes the transducer for `query` with at most `max_distance` edits
    /// over `alphabet`, using `EditCosts::default()`.
    pub fn new(query: &str, max_distance: usize, alphabet: Vec<char>) -> Self {
        Self {
            query: query.chars().collect(),
            max_distance,
            costs: EditCosts::default(),
            alphabet,
            states_per_pos: max_distance + 1,
            _phantom: PhantomData,
        }
    }

    /// Replaces the edit costs.
    #[must_use]
    pub fn with_costs(mut self, costs: EditCosts) -> Self {
        self.costs = costs;
        self
    }

    /// Encodes `(pos, err)` as `pos * (max_distance + 1) + err`.
    ///
    /// Arguments are not range-checked; use [`is_valid_state`](Self::is_valid_state)
    /// on the result. The `as` cast truncates if the value exceeds `StateId`.
    #[inline]
    pub fn encode_state(&self, pos: usize, err: usize) -> StateId {
        (pos * self.states_per_pos + err) as StateId
    }

    /// Inverse of [`encode_state`](Self::encode_state): returns `(pos, err)`.
    #[inline]
    pub fn decode_state(&self, state: StateId) -> (usize, usize) {
        let state = state as usize;
        (state / self.states_per_pos, state % self.states_per_pos)
    }

    /// True when `state` decodes to a position within the query (`pos <= len`)
    /// and an error count within the bound.
    #[inline]
    pub fn is_valid_state(&self, state: StateId) -> bool {
        let (pos, err) = self.decode_state(state);
        // `err` is always < `states_per_pos` by construction of `decode_state`;
        // the explicit check is kept as a defensive guard.
        pos <= self.query.len() && err <= self.max_distance
    }

    /// Number of characters in the query.
    pub fn query_len(&self) -> usize {
        self.query.len()
    }

    /// The query as a slice of characters.
    pub fn query(&self) -> &[char] {
        &self.query
    }

    /// Maximum number of edits.
    pub fn max_distance(&self) -> usize {
        self.max_distance
    }

    /// The configured alphabet.
    pub fn alphabet(&self) -> &[char] {
        &self.alphabet
    }

    /// The configured edit costs, either as set by [`with_costs`](Self::with_costs) or the defaults.
    pub fn costs(&self) -> &EditCosts {
        &self.costs
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wfst::Wfst;

    #[test]
    fn test_edit_costs_default() {
        let costs = EditCosts::default();
        assert_eq!(costs.insert, 1.0);
        assert_eq!(costs.delete, 1.0);
        assert_eq!(costs.substitute, 1.0);
        assert_eq!(costs.transpose, 1.0);
    }

    #[test]
    fn test_edit_costs_uniform() {
        let costs = EditCosts::uniform(0.5);
        assert_eq!(costs.insert, 0.5);
        assert_eq!(costs.delete, 0.5);
        assert_eq!(costs.substitute, 0.5);
        assert_eq!(costs.transpose, 0.5);
    }

    #[test]
    fn test_edit_costs_prefer_transpose() {
        let costs = EditCosts::prefer_transpose();
        assert_eq!(costs.insert, 1.0);
        assert_eq!(costs.delete, 1.0);
        assert_eq!(costs.substitute, 1.0);
        assert_eq!(costs.transpose, 0.5);
    }

    #[test]
    fn test_edit_distance_config() {
        let config = EditDistanceConfig::levenshtein(3);
        assert_eq!(config.max_distance, 3);
        assert!(!config.include_transpositions);
        let config = EditDistanceConfig::damerau_levenshtein(2);
        assert_eq!(config.max_distance, 2);
        assert!(config.include_transpositions);
    }

    #[test]
    fn test_edit_distance_config_builders() {
        let config = EditDistanceConfig::levenshtein(1)
            .with_alphabet_chars(&['x', 'y'])
            .with_costs(EditCosts::uniform(2.0));
        assert_eq!(config.max_distance, 1);
        assert_eq!(config.alphabet, vec!['x', 'y']);
        assert_eq!(config.costs.substitute, 2.0);
        let config = EditDistanceConfig::default().with_alphabet("ab");
        assert_eq!(config.alphabet, vec!['a', 'b']);
    }

    #[test]
    fn test_edit_distance_transducer_creation() {
        let transducer = EditDistanceTransducer::levenshtein(2).with_alphabet("abc");
        let fst = transducer.build();
        assert!(!fst.is_empty());
    }

    #[test]
    fn test_edit_distance_transducer_states() {
        let transducer = EditDistanceTransducer::levenshtein(2).with_alphabet("ab");
        let fst = transducer.build();
        assert_eq!(fst.num_states(), 3);
        assert_eq!(fst.start(), 0);
        for s in 0..3 {
            assert!(fst.is_final(s));
        }
    }

    #[test]
    fn test_edit_distance_transducer_transitions() {
        let transducer = EditDistanceTransducer::levenshtein(1).with_alphabet("ab");
        let fst = transducer.build();
        let transitions = fst.transitions(0);
        assert!(!transitions.is_empty());
        let matches: Vec<_> = transitions
            .iter()
            .filter(|t| t.input == t.output && t.weight == TropicalWeight::one())
            .collect();
        assert!(!matches.is_empty());
    }

    #[test]
    fn test_edit_distance_transducer_arc_counts() {
        // Alphabet {a, b}, k = 1. On level 0, each character contributes one
        // identity, one deletion, one insertion and one substitution arc.
        // The last level only has identity loops.
        let fst = EditDistanceTransducer::levenshtein(1).with_alphabet("ab").build();
        assert_eq!(fst.transitions(0).iter().count(), 8);
        assert_eq!(fst.transitions(1).iter().count(), 2);
    }

    #[test]
    fn test_empty_alphabet_builds_identity_transducer() {
        let fst = EditDistanceTransducer::levenshtein(2).build();
        assert_eq!(fst.num_states(), 1);
        assert_eq!(fst.start(), 0);
        assert!(fst.is_final(0));
        assert!(fst.transitions(0).is_empty());
    }

    #[test]
    fn test_build_for_query() {
        let transducer = EditDistanceTransducer::levenshtein(2).with_alphabet("hello");
        let fst = transducer.build_for_query("helo");
        assert_eq!(fst.num_states(), 15);
    }

    #[test]
    fn test_build_for_query_start_and_finals() {
        let fst = EditDistanceTransducer::levenshtein(1)
            .with_alphabet("ab")
            .build_for_query("ab");
        assert_eq!(fst.num_states(), 6);
        assert_eq!(fst.start(), 0);
        // Finals are (pos = 2, err = 0) and (pos = 2, err = 1), i.e. states 4 and 5.
        assert!(fst.is_final(4));
        assert!(fst.is_final(5));
        for s in 0..4 {
            assert!(!fst.is_final(s));
        }
    }

    #[test]
    fn test_build_for_empty_query() {
        let fst = EditDistanceTransducer::levenshtein(1)
            .with_alphabet("ab")
            .build_for_query("");
        assert_eq!(fst.num_states(), 2);
        assert!(fst.is_final(0));
        assert!(fst.is_final(1));
        // Nothing to consume, so only the two insertion arcs leave the start.
        assert_eq!(fst.transitions(0).iter().count(), 2);
    }

    #[test]
    fn test_lazy_state_encoding() {
        let lazy: LazyEditDistanceTransducer<TropicalWeight> =
            LazyEditDistanceTransducer::new("test", 2, vec!['a', 'b']);
        for pos in 0..=4 {
            for err in 0..=2 {
                let encoded = lazy.encode_state(pos, err);
                let (dec_pos, dec_err) = lazy.decode_state(encoded);
                assert_eq!(dec_pos, pos);
                assert_eq!(dec_err, err);
            }
        }
    }

    #[test]
    fn test_lazy_validity_and_accessors() {
        let lazy: LazyEditDistanceTransducer<TropicalWeight> =
            LazyEditDistanceTransducer::new("ab", 1, vec!['a', 'b']);
        assert_eq!(lazy.query_len(), 2);
        assert_eq!(lazy.max_distance(), 1);
        assert_eq!(lazy.alphabet(), &['a', 'b'][..]);
        assert!(lazy.is_valid_state(lazy.encode_state(2, 1)));
        assert!(!lazy.is_valid_state(lazy.encode_state(3, 0)));
    }

    #[test]
    fn test_damerau_levenshtein_alias() {
        let transducer = DamerauLevenshteinTransducer::new_damerau(2);
        assert!(transducer.config.include_transpositions);
    }
}