use std::collections::HashMap;
use failure::Fallible;
use crate::algorithms::state_mappers::ArcSumMapper;
use crate::algorithms::{all_pairs_shortest_distance, state_map};
use crate::arc::Arc;
use crate::fst_traits::{ExpandedFst, FinalStatesIterator, MutableFst};
use crate::semirings::{Semiring, StarSemiring};
use crate::EPS_LABEL;
/// Copies `fst` into a freshly created `F2`, retaining only one class of arcs.
///
/// An arc is treated as an epsilon arc when both its input label and its output
/// label equal `EPS_LABEL`.
///
/// * `keep_only_epsilon == true`: only epsilon arcs are copied.
/// * `keep_only_epsilon == false`: only the other arcs are copied.
///
/// In both modes every state of `fst` gets a counterpart in the result, and the
/// start state and the final states (with their weights) are carried over.
///
/// # Errors
///
/// Propagates any error returned by `arcs_iter`, `add_arc`, `set_start` or
/// `set_final`.
fn compute_fst_epsilon<W, F1, F2>(fst: &F1, keep_only_epsilon: bool) -> Fallible<F2>
where
    W: Semiring,
    F1: ExpandedFst<W = W>,
    F2: MutableFst<W = W> + ExpandedFst<W = W>,
{
    let mut filtered = F2::new();

    // One new state per source state, remembered so arcs can be re-targeted.
    let state_of: HashMap<_, _> = fst
        .states_iter()
        .map(|source_state| (source_state, filtered.add_state()))
        .collect();

    for source_state in fst.states_iter() {
        for arc in fst.arcs_iter(source_state)? {
            let is_epsilon = arc.ilabel == EPS_LABEL && arc.olabel == EPS_LABEL;
            // Copy the arc only when its class matches the requested one.
            if is_epsilon != keep_only_epsilon {
                continue;
            }
            let from = state_of[&source_state];
            let copied = Arc::new(
                arc.ilabel,
                arc.olabel,
                arc.weight.clone(),
                state_of[&arc.nextstate],
            );
            filtered.add_arc(from, copied)?;
        }
    }

    if let Some(source_start) = fst.start() {
        filtered.set_start(state_of[&source_start])?;
    }

    for final_state in fst.final_states_iter() {
        let target = state_of[&final_state.state_id];
        filtered.set_final(target, final_state.final_weight)?;
    }

    Ok(filtered)
}
/// Returns a copy of `fst` in which the arcs labelled `EPS_LABEL` on both input
/// and output have been removed, their weights being folded into the remaining
/// arcs and final weights.
///
/// The steps are:
///
/// 1. Extract the epsilon-only part of `fst` and compute all-pairs shortest
///    distances on it.
/// 2. For every state `p`, collect each state `q != p` whose distance from `p`
///    differs from `W::zero()`.
/// 3. Starting from the epsilon-free part of `fst`, give `p` a copy of every
///    arc leaving such a `q`, with the arc weight pre-multiplied by the
///    distance from `p` to `q`.
/// 4. If `q` is final, add `distance(p, q) * final_weight(q)` to the final
///    weight of `p`.
/// 5. Run `ArcSumMapper` over every state (presumably merging arcs that became
///    duplicates; confirm against its definition).
///
/// # Errors
///
/// Propagates any error coming from the FST accessors/mutators, the shortest
/// distance computation, the semiring operations or `state_map`.
pub fn rm_epsilon<W, F1, F2>(fst: &F1) -> Fallible<F2>
where
    W: StarSemiring,
    F1: ExpandedFst<W = W>,
    F2: MutableFst<W = W> + ExpandedFst<W = W>,
{
    let fst_epsilon: F2 = compute_fst_epsilon(fst, true)?;
    let dists_fst_epsilon = all_pairs_shortest_distance(&fst_epsilon)?;

    // eps_closures[p] lists (q, distance(p, q)) for every q != p whose distance
    // from p in the epsilon-only FST is not the semiring zero.
    let mut eps_closures = vec![vec![]; fst_epsilon.num_states()];
    for p in fst_epsilon.states_iter() {
        for q in fst_epsilon.states_iter() {
            if p != q && dists_fst_epsilon[p][q] != W::zero() {
                eps_closures[p].push((q, &dists_fst_epsilon[p][q]));
            }
        }
    }

    let fst_no_epsilon: F2 = compute_fst_epsilon(fst, false)?;
    let mut output_fst = fst_no_epsilon.clone();

    for p in fst_no_epsilon.states_iter() {
        for (q, w_prime) in &eps_closures[p] {
            // Arcs leaving q become arcs leaving p, weighted by the epsilon path.
            for arc in fst_no_epsilon.arcs_iter(*q)? {
                output_fst.add_arc(
                    p,
                    Arc::new(
                        arc.ilabel,
                        arc.olabel,
                        w_prime.times(&arc.weight)?,
                        arc.nextstate,
                    ),
                )?;
            }

            if fst_no_epsilon.is_final(*q) {
                // Seed the final weight with zero only if `p` has none yet in
                // the *output* FST. Checking the untouched `fst_no_epsilon`
                // here would reset the accumulated weight on every iteration
                // and drop the contributions of all but the last final `q`.
                if !output_fst.is_final(p) {
                    output_fst.set_final(p, W::zero())?;
                }
                let rho_prime_p = output_fst.final_weight(p).unwrap();
                let rho_q = fst_no_epsilon.final_weight(*q).unwrap();
                let new_weight = rho_prime_p.plus(&w_prime.times(&rho_q)?)?;
                output_fst.set_final(p, new_weight)?;
            }
        }
    }

    let mut arc_sum_mapper = ArcSumMapper {};
    state_map(&mut output_fst, &mut arc_sum_mapper)?;
    Ok(output_fst)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fst_impls::VectorFst;
    use crate::fst_traits::PathsIterator;
    use crate::semirings::IntegerWeight;
    use crate::test_data::vector_fst::get_vector_fsts_for_tests;
    use counter::Counter;
    use failure::format_err;
    use failure::ResultExt;

    /// For every FST returned by `get_vector_fsts_for_tests`, checks that the
    /// multiset of paths collected from the epsilon-removed FST equals the one
    /// collected from the original FST.
    #[test]
    fn test_epsilon_removal_generic() -> Fallible<()> {
        for test_case in get_vector_fsts_for_tests() {
            let input_fst = &test_case.fst;
            let expected_paths: Counter<_> = input_fst.paths_iter().collect();

            // Attach the FST name to any error raised by the operation itself.
            let removal_result: Fallible<VectorFst<IntegerWeight>> = rm_epsilon(input_fst);
            let without_epsilon = removal_result.with_context(|_| {
                format_err!(
                    "Error when performing epsilon removal operation for wFST {:?}",
                    &test_case.name,
                )
            })?;

            let actual_paths: Counter<_> = without_epsilon.paths_iter().collect();
            assert_eq!(
                actual_paths, expected_paths,
                "Test failing for epsilon removal for wFST {:?}",
                &test_case.name
            );
        }
        Ok(())
    }
}