use std::collections::{HashMap, HashSet};
use crate::semiring::{Semiring, StarSemiring};
use crate::wfst::{MutableWfst, StateId, WeightedTransition, Wfst, NO_STATE};
use super::connect::{connect, ConnectConfig};
use super::shortest_distance::ShortestDistanceConfig;
/// Options controlling `remove_epsilon` / `remove_epsilon_star`.
#[derive(Clone, Debug)]
pub struct EpsilonRemovalConfig {
    /// When `true`, the result is trimmed with `connect(fst, ConnectConfig::trim())`
    /// after epsilon removal. Both constructors below set it to `true`.
    pub connect: bool,
    /// Shortest-distance settings.
    /// NOTE(review): the epsilon-removal functions in this module do not
    /// currently read this field.
    pub distance_config: ShortestDistanceConfig,
}

impl Default for EpsilonRemovalConfig {
    /// Trimming enabled, `ShortestDistanceConfig::default()` distance settings.
    fn default() -> Self {
        let distance_config = ShortestDistanceConfig::default();
        Self {
            distance_config,
            connect: true,
        }
    }
}

impl EpsilonRemovalConfig {
    /// Trimming enabled, `ShortestDistanceConfig::acyclic()` distance settings.
    pub fn acyclic() -> Self {
        let distance_config = ShortestDistanceConfig::acyclic();
        Self {
            distance_config,
            connect: true,
        }
    }
}
/// Removes every epsilon transition (input *and* output label both `None`)
/// from `fst`, in place.
///
/// Let `C(p)` be the epsilon closure of state `p` as returned by
/// `compute_epsilon_closures`: each state `q` reachable from `p` through
/// epsilon transitions only, with its accumulated weight `d(p, q)`; `p`
/// itself is always a member. The rewritten machine has:
///
/// * for every state `p`, every `q ∈ C(p)` and every non-epsilon transition
///   `(q, a:b, r, w)`, a transition `(p, a:b, r, d(p, q) ⊗ w)`;
/// * final weight `⊕_{q ∈ C(p), q final} d(p, q) ⊗ final(q)` for `p`, which is
///   only written when that sum is non-zero.
///
/// All transitions and final weights are derived from the *unmodified*
/// machine before anything is written back, so no path is counted twice and
/// the result does not depend on the order in which states are processed.
/// Parallel arcs produced for a state are merged by `deduplicate_transitions`.
/// Closure members are visited in increasing state order, which makes the
/// output transition order deterministic.
///
/// When `config.connect` is `true`, the result is trimmed with
/// `connect(fst, ConnectConfig::trim())`. `config.distance_config` is not read.
///
/// # Errors
///
/// Returns [`EpsilonRemovalError::NoStartState`] if the machine has at least
/// one state but no start state. A machine with zero states is left untouched
/// and `Ok(())` is returned.
pub fn remove_epsilon<L, W, F>(
    fst: &mut F,
    config: EpsilonRemovalConfig,
) -> Result<(), EpsilonRemovalError>
where
    L: Clone + PartialEq,
    W: Semiring,
    F: MutableWfst<L, W> + Wfst<L, W>,
{
    let n = fst.num_states();
    if n == 0 {
        return Ok(());
    }
    if fst.start() == NO_STATE {
        return Err(EpsilonRemovalError::NoStartState);
    }

    let closures = compute_epsilon_closures(fst);

    // Phase 1: derive the new arcs and final weights from the untouched FST.
    let mut new_transitions: Vec<Vec<WeightedTransition<L, W>>> = Vec::with_capacity(n);
    let mut new_finals: Vec<Option<W>> = Vec::with_capacity(n);
    for (state, closure) in closures.iter().enumerate() {
        let state_id = state as StateId;

        // HashMap iteration order is unspecified; sort so the output is stable.
        let mut members: Vec<(StateId, &W)> = closure.iter().map(|(&q, d)| (q, d)).collect();
        members.sort_unstable_by_key(|&(q, _)| q);

        let mut arcs: Vec<WeightedTransition<L, W>> = Vec::new();
        let mut final_weight: Option<W> = None;
        for (member, distance) in members {
            for trans in fst.transitions(member) {
                if trans.input.is_none() && trans.output.is_none() {
                    // Epsilon arcs are already accounted for by `distance`.
                    continue;
                }
                arcs.push(WeightedTransition {
                    from: state_id,
                    to: trans.to,
                    input: trans.input.clone(),
                    output: trans.output.clone(),
                    weight: distance.times(&trans.weight),
                });
            }
            if fst.is_final(member) {
                let contribution = distance.times(&fst.final_weight(member));
                final_weight = Some(match final_weight {
                    Some(acc) => acc.plus(&contribution),
                    None => contribution,
                });
            }
        }
        deduplicate_transitions(&mut arcs);
        new_transitions.push(arcs);
        new_finals.push(final_weight);
    }

    // Phase 2: write the precomputed results back.
    for (state, (arcs, final_weight)) in new_transitions.into_iter().zip(new_finals).enumerate() {
        let state_id = state as StateId;
        fst.clear_transitions(state_id);
        for trans in arcs {
            fst.add_transition(trans);
        }
        if let Some(weight) = final_weight.filter(|w| !w.is_zero()) {
            fst.set_final(state_id, weight);
        }
    }

    if config.connect {
        connect(fst, ConnectConfig::trim());
    }
    Ok(())
}
/// Computes the epsilon closure of every state of `fst`.
///
/// Entry `p` of the result maps each state `q` reachable from `p` through
/// epsilon transitions (input and output both `None`) to the ⊕-sum, over
/// those epsilon paths `p ⇝ q`, of the ⊗-product of the path's weights.
/// Every state is a member of its own closure. A state that is not on an
/// epsilon cycle has self-distance `W::one()`.
///
/// If the epsilon subgraph reachable from `p` is acyclic, the sum is exact and
/// is computed by relaxing states in topological order.
///
/// If that subgraph contains a cycle, the sum may be infinite and cannot be
/// formed with plain `Semiring` operations. It is then truncated to paths of
/// at most `m` transitions, where `m` is the number of reachable states. This
/// is exact whenever cycles never improve a path (e.g. tropical weights that
/// are non-negative) and an approximation otherwise.
fn compute_epsilon_closures<L, W, F>(fst: &F) -> Vec<HashMap<StateId, W>>
where
    L: Clone,
    W: Semiring,
    F: Wfst<L, W>,
{
    let n = fst.num_states();
    // Epsilon-only adjacency lists, extracted once and shared by every source.
    let mut eps_arcs: Vec<Vec<(StateId, W)>> = Vec::with_capacity(n);
    for state in 0..n {
        let mut arcs = Vec::new();
        for trans in fst.transitions(state as StateId) {
            if trans.input.is_none() && trans.output.is_none() {
                arcs.push((trans.to, trans.weight.clone()));
            }
        }
        eps_arcs.push(arcs);
    }
    (0..n)
        .map(|state| closure_from_state(state as StateId, &eps_arcs))
        .collect()
}

/// Epsilon closure of one `source` state over the epsilon adjacency `eps_arcs`.
fn closure_from_state<W: Semiring>(
    source: StateId,
    eps_arcs: &[Vec<(StateId, W)>],
) -> HashMap<StateId, W> {
    // Discover the reachable epsilon subgraph and give its states dense local
    // indices; `source` is local index 0.
    let mut local: HashMap<StateId, usize> = HashMap::new();
    local.insert(source, 0);
    let mut nodes: Vec<StateId> = vec![source];
    let mut stack = vec![source];
    while let Some(q) = stack.pop() {
        for &(to, _) in &eps_arcs[q as usize] {
            let fresh = nodes.len();
            // `or_insert` stores `fresh` only for a state seen for the first time.
            if *local.entry(to).or_insert(fresh) == fresh {
                nodes.push(to);
                stack.push(to);
            }
        }
    }

    // The same subgraph expressed with local indices.
    let edges: Vec<Vec<(usize, &W)>> = nodes
        .iter()
        .map(|&q| {
            eps_arcs[q as usize]
                .iter()
                .map(|(to, w)| (local[to], w))
                .collect()
        })
        .collect();

    let distances = match epsilon_topological_order(&edges) {
        Some(order) => relax_acyclic(&edges, &order),
        None => relax_bounded(&edges),
    };

    nodes
        .into_iter()
        .zip(distances)
        .filter_map(|(q, d)| d.map(|w| (q, w)))
        .collect()
}

/// Kahn's algorithm over a dense local graph; `None` if the graph has a cycle.
fn epsilon_topological_order<W>(edges: &[Vec<(usize, &W)>]) -> Option<Vec<usize>> {
    let mut indegree = vec![0usize; edges.len()];
    for outgoing in edges {
        for &(to, _) in outgoing {
            indegree[to] += 1;
        }
    }
    let mut ready: Vec<usize> = (0..edges.len()).filter(|&i| indegree[i] == 0).collect();
    let mut order = Vec::with_capacity(edges.len());
    while let Some(i) = ready.pop() {
        order.push(i);
        for &(to, _) in &edges[i] {
            indegree[to] -= 1;
            if indegree[to] == 0 {
                ready.push(to);
            }
        }
    }
    if order.len() == edges.len() {
        Some(order)
    } else {
        None
    }
}

/// Exact path sums from local node 0 in an acyclic graph, given a topological order.
fn relax_acyclic<W: Semiring>(edges: &[Vec<(usize, &W)>], order: &[usize]) -> Vec<Option<W>> {
    let mut dist: Vec<Option<W>> = vec![None; edges.len()];
    dist[0] = Some(W::one());
    for &i in order {
        let current = match &dist[i] {
            Some(d) => d.clone(),
            None => continue,
        };
        for &(to, weight) in &edges[i] {
            plus_into(&mut dist[to], current.times(weight));
        }
    }
    dist
}

/// Path sums from local node 0 restricted to paths of at most `edges.len()`
/// transitions. Used for cyclic graphs, where the full sum may be infinite.
fn relax_bounded<W: Semiring>(edges: &[Vec<(usize, &W)>]) -> Vec<Option<W>> {
    let m = edges.len();
    let mut dist: Vec<Option<W>> = vec![None; m];
    dist[0] = Some(W::one());
    for _ in 0..m {
        // After round k, `dist` holds the sum over all paths of length <= k.
        let mut next: Vec<Option<W>> = vec![None; m];
        next[0] = Some(W::one());
        for (i, entry) in dist.iter().enumerate() {
            if let Some(d) = entry {
                for &(to, weight) in &edges[i] {
                    plus_into(&mut next[to], d.times(weight));
                }
            }
        }
        dist = next;
    }
    dist
}

/// `slot ⊕= value`, treating `None` as the semiring zero.
fn plus_into<W: Semiring>(slot: &mut Option<W>, value: W) {
    *slot = Some(match slot.take() {
        Some(acc) => acc.plus(&value),
        None => value,
    });
}
/// Merges transitions that share destination, input label and output label,
/// ⊕-summing their weights.
///
/// The surviving transition of each group is its first occurrence. Groups
/// keep the order of their first occurrence, so the output is deterministic.
/// Labels are compared with `PartialEq`, so transitions with different labels
/// are never merged, even when they go to the same state. Intended for a
/// single state's outgoing list; `from` is not part of the grouping key.
fn deduplicate_transitions<L, W>(transitions: &mut Vec<WeightedTransition<L, W>>)
where
    L: Clone + PartialEq,
    W: Semiring,
{
    if transitions.len() <= 1 {
        return;
    }
    let mut merged: Vec<WeightedTransition<L, W>> = Vec::with_capacity(transitions.len());
    // Destination -> positions in `merged`. Labels are only `PartialEq`, not
    // `Hash`, so within one destination the groups are scanned linearly.
    let mut by_dest: HashMap<StateId, Vec<usize>> = HashMap::new();
    for trans in transitions.drain(..) {
        let slots = by_dest.entry(trans.to).or_default();
        let existing = slots
            .iter()
            .copied()
            .find(|&i| merged[i].input == trans.input && merged[i].output == trans.output);
        match existing {
            Some(i) => {
                let summed = merged[i].weight.plus(&trans.weight);
                merged[i].weight = summed;
            }
            None => {
                slots.push(merged.len());
                merged.push(trans);
            }
        }
    }
    *transitions = merged;
}
/// Epsilon removal entry point for weights that form a star semiring.
///
/// This forwards directly to [`remove_epsilon`] with the same `fst` and
/// `config`. The `StarSemiring` bound is not otherwise used here.
///
/// # Errors
///
/// Exactly those of [`remove_epsilon`].
pub fn remove_epsilon_star<L, W, F>(
    fst: &mut F,
    config: EpsilonRemovalConfig,
) -> Result<(), EpsilonRemovalError>
where
    L: Clone + PartialEq,
    W: StarSemiring,
    F: MutableWfst<L, W> + Wfst<L, W>,
{
    remove_epsilon::<L, W, F>(fst, config)
}
/// Failure modes of epsilon removal.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EpsilonRemovalError {
    /// The WFST has states but no start state.
    NoStartState,
    /// An epsilon cycle whose weight does not converge.
    NonConvergentCycle,
}

impl std::fmt::Display for EpsilonRemovalError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::NoStartState => "WFST has no start state",
            Self::NonConvergentCycle => "Epsilon cycle with non-converging weight",
        };
        f.write_str(message)
    }
}

impl std::error::Error for EpsilonRemovalError {}
/// Returns `true` as soon as some state of `fst` is found to have a transition
/// whose input and output labels are both `None`; `false` otherwise.
pub fn has_epsilon_transitions<L, W, F>(fst: &F) -> bool
where
    L: Clone,
    W: Semiring,
    F: Wfst<L, W>,
{
    (0..fst.num_states()).any(|state| {
        IntoIterator::into_iter(fst.transitions(state as StateId))
            .any(|trans| trans.input.is_none() && trans.output.is_none())
    })
}
// Unit and property tests for epsilon removal over tropical-weight WFSTs.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::semiring::TropicalWeight;
    use crate::wfst::{MutableWfst, VectorWfst, VectorWfstBuilder};
    // Randomised checks over machines generated by `arb_tropical_wfst`.
    mod property_tests {
        use super::*;
        use crate::test_utils::arb_tropical_wfst;
        use proptest::prelude::*;
        proptest! {
            // A successful removal must leave no transition with both labels `None`.
            #[test]
            fn epsilon_removal_complete(
                mut fst in arb_tropical_wfst(8, 3)
            ) {
                if fst.num_states() == 0 || fst.start() == NO_STATE {
                    return Ok(());
                }
                let result = remove_epsilon(&mut fst, EpsilonRemovalConfig::default());
                if result.is_ok() {
                    prop_assert!(
                        !has_epsilon_transitions(&fst),
                        "FST still has epsilon transitions after removal"
                    );
                }
            }
            // The state count must never grow, whatever the result of removal.
            #[test]
            fn epsilon_removal_state_bound(
                mut fst in arb_tropical_wfst(8, 3)
            ) {
                if fst.num_states() == 0 {
                    return Ok(());
                }
                let original_states = fst.num_states();
                let _ = remove_epsilon(&mut fst, EpsilonRemovalConfig::default());
                prop_assert!(
                    fst.num_states() <= original_states,
                    "Epsilon removal increased states from {} to {}",
                    original_states,
                    fst.num_states()
                );
            }
            // Cross-checks `has_epsilon_transitions` against a hand-written scan
            // of every state's transitions.
            #[test]
            fn has_epsilon_after_removal(
                mut fst in arb_tropical_wfst(6, 2)
            ) {
                if fst.num_states() == 0 || fst.start() == NO_STATE {
                    return Ok(());
                }
                let _ = remove_epsilon(&mut fst, EpsilonRemovalConfig::default());
                let predicate_result = has_epsilon_transitions(&fst);
                let mut found_epsilon = false;
                for state in 0..fst.num_states() {
                    for trans in fst.transitions(state as StateId) {
                        if trans.input.is_none() && trans.output.is_none() {
                            found_epsilon = true;
                            break;
                        }
                    }
                    if found_epsilon {
                        break;
                    }
                }
                prop_assert_eq!(
                    predicate_result,
                    found_epsilon,
                    "has_epsilon_transitions() returned {}, but manual check found {}",
                    predicate_result,
                    found_epsilon
                );
            }
            // Once a machine is epsilon-free, a second removal pass must not
            // change its state count.
            #[test]
            fn epsilon_removal_identity_when_no_epsilon(
                fst in arb_tropical_wfst(6, 2)
            ) {
                if fst.num_states() == 0 || fst.start() == NO_STATE {
                    return Ok(());
                }
                let mut clean_fst = fst.clone();
                let _ = remove_epsilon(&mut clean_fst, EpsilonRemovalConfig::default());
                if !has_epsilon_transitions(&clean_fst) {
                    let original_states = clean_fst.num_states();
                    let mut second_fst = clean_fst.clone();
                    let _ = remove_epsilon(&mut second_fst, EpsilonRemovalConfig::default());
                    prop_assert_eq!(
                        second_fst.num_states(),
                        original_states,
                        "Second epsilon removal changed state count"
                    );
                }
            }
        }
    }
    // Fixture: 0 --eps/1.0--> 1 --a:a/2.0--> 2, start 0, final(2) = 0.5.
    fn build_simple_epsilon_chain() -> VectorWfst<char, TropicalWeight> {
        let mut fst = VectorWfst::new();
        fst.add_states(3);
        fst.set_start(0);
        fst.add_epsilon(0, 1, TropicalWeight::new(1.0));
        fst.add_arc(1, Some('a'), Some('a'), 2, TropicalWeight::new(2.0));
        fst.set_final(2, TropicalWeight::new(0.5));
        fst
    }
    // Fixture: 0 --a:a/1.0--> 1 --eps/0.5--> 2, start 0, final(2) = one.
    // After removal, state 1 should inherit finality through its epsilon arc.
    fn build_epsilon_to_final() -> VectorWfst<char, TropicalWeight> {
        let mut fst = VectorWfst::new();
        fst.add_states(3);
        fst.set_start(0);
        fst.add_arc(0, Some('a'), Some('a'), 1, TropicalWeight::new(1.0));
        fst.add_epsilon(1, 2, TropicalWeight::new(0.5));
        fst.set_final(2, TropicalWeight::one());
        fst
    }
    // An empty machine is accepted as-is.
    #[test]
    fn test_remove_epsilon_empty() {
        let mut fst: VectorWfst<char, TropicalWeight> = VectorWfst::new();
        let result = remove_epsilon(&mut fst, EpsilonRemovalConfig::default());
        assert!(result.is_ok());
    }
    // A machine with states but no start state is rejected.
    #[test]
    fn test_remove_epsilon_no_start() {
        let mut fst: VectorWfst<char, TropicalWeight> = VectorWfst::new();
        fst.add_state();
        let result = remove_epsilon(&mut fst, EpsilonRemovalConfig::default());
        assert_eq!(result, Err(EpsilonRemovalError::NoStartState));
    }
    #[test]
    fn test_remove_epsilon_simple_chain() {
        let mut fst = build_simple_epsilon_chain();
        assert!(has_epsilon_transitions(&fst));
        let result = remove_epsilon(&mut fst, EpsilonRemovalConfig::default());
        assert!(result.is_ok());
        assert!(!has_epsilon_transitions(&fst));
        assert_eq!(fst.num_states(), 3);
        assert_ne!(fst.start(), NO_STATE);
        let trans = fst.transitions(0);
        assert!(!trans.is_empty(), "State 0 should have transitions");
    }
    #[test]
    fn test_remove_epsilon_to_final() {
        let mut fst = build_epsilon_to_final();
        assert!(has_epsilon_transitions(&fst));
        let result = remove_epsilon(&mut fst, EpsilonRemovalConfig::default());
        assert!(result.is_ok());
        assert!(!has_epsilon_transitions(&fst));
        assert!(fst.is_final(1), "State 1 should be final after ε-removal");
    }
    // An epsilon-free machine keeps its states and per-state transition counts.
    #[test]
    fn test_remove_epsilon_no_epsilons() {
        let mut fst: VectorWfst<char, TropicalWeight> = VectorWfstBuilder::new()
            .add_states(3)
            .start(0)
            .arc(0, Some('a'), Some('a'), 1, TropicalWeight::new(1.0))
            .arc(1, Some('b'), Some('b'), 2, TropicalWeight::new(2.0))
            .final_state(2, TropicalWeight::one())
            .build();
        assert!(!has_epsilon_transitions(&fst));
        let result = remove_epsilon(&mut fst, EpsilonRemovalConfig::default());
        assert!(result.is_ok());
        assert_eq!(fst.num_states(), 3);
        assert_eq!(fst.transitions(0).len(), 1);
        assert_eq!(fst.transitions(1).len(), 1);
    }
    #[test]
    fn test_has_epsilon_transitions() {
        let with_eps = build_simple_epsilon_chain();
        assert!(has_epsilon_transitions(&with_eps));
        let without_eps: VectorWfst<char, TropicalWeight> = VectorWfstBuilder::new()
            .add_states(2)
            .start(0)
            .arc(0, Some('a'), Some('a'), 1, TropicalWeight::one())
            .final_state(1, TropicalWeight::one())
            .build();
        assert!(!has_epsilon_transitions(&without_eps));
    }
    // Pins the user-facing error messages.
    #[test]
    fn test_epsilon_removal_error_display() {
        assert_eq!(
            EpsilonRemovalError::NoStartState.to_string(),
            "WFST has no start state"
        );
        assert_eq!(
            EpsilonRemovalError::NonConvergentCycle.to_string(),
            "Epsilon cycle with non-converging weight"
        );
    }
}