use anyhow::Result;
use crate::algorithms::dfs_visit::dfs_visit;
use crate::algorithms::queues::AutoQueue;
use crate::algorithms::rm_epsilon::{RmEpsilonInternalConfig, RmEpsilonState};
use crate::algorithms::top_sort::TopOrderVisitor;
use crate::algorithms::tr_filters::EpsilonTrFilter;
use crate::algorithms::visitors::SccVisitor;
use crate::algorithms::Queue;
use crate::fst_properties::mutable_properties::rmepsilon_properties;
use crate::fst_properties::FstProperties;
use crate::fst_traits::MutableFst;
use crate::semirings::Semiring;
use crate::{StateId, Trs, EPS_LABEL};
/// Runs epsilon removal on `fst` in place using the default internal configuration.
///
/// An [`AutoQueue`] is built from `fst` restricted by an [`EpsilonTrFilter`], wrapped in
/// `RmEpsilonInternalConfig::new_with_default`, and handed to
/// `rm_epsilon_with_internal_config`, which does the actual work.
///
/// # Errors
///
/// Returns any error produced while constructing the queue or by
/// `rm_epsilon_with_internal_config`.
pub fn rm_epsilon<W: Semiring, F: MutableFst<W>>(fst: &mut F) -> Result<()> {
    let eps_filter = EpsilonTrFilter {};
    // No distance vector is supplied to the queue (second argument is `None`).
    let auto_queue = AutoQueue::new(fst, None, &eps_filter)?;
    rm_epsilon_with_internal_config(fst, RmEpsilonInternalConfig::new_with_default(auto_queue))
}
/// Epsilon-removal driver parameterised by an explicit internal configuration.
///
/// Steps, in order:
///
/// 1. If `fst` has no start state, return `Ok(())` without touching it.
/// 2. Mark every state that is the start state or the destination of a transition whose
///    input or output label differs from `EPS_LABEL`.
/// 3. Pick a processing order for the states:
///    * the state ids as iterated, when the FST is flagged `TOP_SORTED`;
///    * the order derived from a `TopOrderVisitor` run over the epsilon-filtered graph, when
///      the FST is flagged `ACYCLIC`;
///    * otherwise, states grouped by the component ids computed by an `SccVisitor` over the
///      epsilon-filtered graph.
/// 4. Walk that order in reverse and replace each state's transitions and final weight by
///    the result of `RmEpsilonState::expand`. States that were not marked in step 2 are
///    skipped when `connect` is set or a weight/state threshold is configured.
/// 5. When `connect` is set or a threshold is configured, delete the transitions of every
///    unmarked state.
/// 6. Update the FST properties through `rmepsilon_properties`, then either prune (not
///    implemented, see *Panics*) or run `connect` if requested.
///
/// # Errors
///
/// Propagates errors from `get_trs`, `RmEpsilonState::expand`, `delete_trs` and
/// `crate::algorithms::connect`.
///
/// # Panics
///
/// * Hits `todo!("Implement Prune!")` when `opts.weight_threshold != W::zero()` or
///   `opts.state_threshold` is `Some`. This happens *after* `fst` has already been
///   rewritten.
/// * Unwraps `SccVisitor::scc`, so panics if the visitor did not populate it.
pub(crate) fn rm_epsilon_with_internal_config<W: Semiring, F: MutableFst<W>, Q: Queue>(
    fst: &mut F,
    opts: RmEpsilonInternalConfig<W, Q>,
) -> Result<()> {
    // Copy out what is needed from `opts` before it is moved into `RmEpsilonState`.
    let connect = opts.connect;
    let weight_threshold = opts.weight_threshold.clone();
    let state_threshold = opts.state_threshold;

    let start_state = match fst.start() {
        None => return Ok(()),
        Some(s) => s,
    };

    // These predicates do not depend on the state being processed, so evaluate them once
    // instead of rebuilding `W::zero()` and re-comparing on every loop iteration.
    let zero = W::zero();
    let pruning = weight_threshold != zero || state_threshold.is_some();
    let skip_eps_only_states = connect || pruning;

    // noneps_in[s] is true iff `s` is the start state or has at least one incoming
    // transition that is not epsilon on both sides.
    let mut noneps_in = vec![false; fst.num_states()];
    noneps_in[start_state as usize] = true;
    for state in fst.states_iter() {
        for tr in fst.get_trs(state)?.trs() {
            if tr.ilabel != EPS_LABEL || tr.olabel != EPS_LABEL {
                noneps_in[tr.nextstate as usize] = true;
            }
        }
    }

    // Processing order of the states; it is consumed in reverse further down.
    let fst_props = fst.properties();
    let states: Vec<StateId> = if fst_props.contains(FstProperties::TOP_SORTED) {
        fst.states_iter().collect()
    } else if fst_props.contains(FstProperties::ACYCLIC) {
        let mut visitor = TopOrderVisitor::new();
        dfs_visit(fst, &mut visitor, &EpsilonTrFilter {}, false);

        // Invert `order`: the state at index `i` is placed at position `order[i]`.
        let mut ordered = vec![0; visitor.order.len()];
        for (state, &position) in visitor.order.iter().enumerate() {
            ordered[position as usize] = state as StateId;
        }
        ordered
    } else {
        let mut visitor = SccVisitor::new(fst, true, false);
        dfs_visit(fst, &mut visitor, &EpsilonTrFilter {}, false);
        let scc = visitor.scc.as_ref().unwrap();

        // Intrusive singly linked lists, one per component id: `first[c]` is the head of
        // the list for component `c`, `next[i]` is the element following state `i`.
        let mut first: Vec<Option<usize>> = vec![None; scc.len()];
        let mut next: Vec<Option<usize>> = vec![None; scc.len()];
        for (i, &component) in scc.iter().enumerate() {
            // Push `i` at the front of its component's list.
            next[i] = first[component as usize].replace(i);
        }

        // Flatten the lists by increasing component id.
        let mut ordered = Vec::with_capacity(scc.len());
        for &head in &first {
            let mut cursor = head;
            while let Some(j) = cursor {
                ordered.push(j as StateId);
                cursor = next[j];
            }
        }
        ordered
    };

    let mut rmeps_state = RmEpsilonState::new(fst.num_states(), opts);

    for state in states.into_iter().rev() {
        if skip_eps_only_states && !noneps_in[state as usize] {
            continue;
        }
        let (trs, final_weight) = rmeps_state.expand::<F, _>(state, &*fst)?;

        // SAFETY: `state` is drawn from `states`, which is built above from this FST's own
        // state iterator or from the indices of the DFS visitors' tables.
        // NOTE(review): this assumes those tables hold one entry per state of `fst` and
        // that the `*_unchecked` methods only require `state` to be an existing state id;
        // confirm against the `MutableFst` contract.
        unsafe {
            fst.pop_trs_unchecked(state);
            // The expanded transitions are installed in reverse of the order returned.
            fst.set_trs_unchecked(state, trs.into_iter().rev().collect());
            if final_weight != zero {
                fst.set_final_unchecked(state, final_weight);
            } else {
                fst.delete_final_weight_unchecked(state);
            }
        }
    }

    if skip_eps_only_states {
        // States skipped above still carry their original transitions; drop them.
        for s in 0..(fst.num_states() as StateId) {
            if !noneps_in[s as usize] {
                fst.delete_trs(s)?;
            }
        }
    }

    fst.set_properties(rmepsilon_properties(fst.properties(), false));

    if pruning {
        todo!("Implement Prune!")
    }
    if connect && !pruning {
        crate::algorithms::connect(fst)?;
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fst_traits::Fst;
    use crate::prelude::{TropicalWeight, VectorFst};
    use crate::SymbolTable;
    use proptest::prelude::any;
    use proptest::proptest;
    use std::sync::Arc;

    proptest! {
        // Property: whatever FST proptest generates, `rm_epsilon` must leave both the
        // input and the output symbol tables attached.
        #[test]
        fn test_proptest_rmepsilon_keeps_symts(mut fst in any::<VectorFst::<TropicalWeight>>()) {
            // One empty table shared by both sides; only its presence is checked.
            let shared_table = Arc::new(SymbolTable::new());
            fst.set_input_symbols(Arc::clone(&shared_table));
            fst.set_output_symbols(Arc::clone(&shared_table));

            rm_epsilon(&mut fst).unwrap();

            assert!(fst.input_symbols().is_some());
            assert!(fst.output_symbols().is_some());
        }
    }
}