rustfst/algorithms/rm_epsilon/
rm_epsilon_static.rs

1use anyhow::Result;
2
3use crate::algorithms::dfs_visit::dfs_visit;
4use crate::algorithms::queues::AutoQueue;
5use crate::algorithms::rm_epsilon::{RmEpsilonInternalConfig, RmEpsilonState};
6use crate::algorithms::top_sort::TopOrderVisitor;
7use crate::algorithms::tr_filters::EpsilonTrFilter;
8use crate::algorithms::visitors::SccVisitor;
9use crate::algorithms::Queue;
10use crate::fst_properties::mutable_properties::rmepsilon_properties;
11use crate::fst_properties::FstProperties;
12use crate::fst_traits::MutableFst;
13use crate::semirings::Semiring;
14use crate::{StateId, Trs, EPS_LABEL};
15
16/// This operation removes epsilon-transitions (when both the input and
17/// output labels are an epsilon) from a transducer. The result will be an
18/// equivalent FST that has no such epsilon transitions.
19///
20/// # Example 1
21/// ```
22/// # use rustfst::semirings::{Semiring, IntegerWeight};
23/// # use rustfst::fst_impls::VectorFst;
24/// # use rustfst::fst_traits::MutableFst;
25/// # use rustfst::algorithms::rm_epsilon::rm_epsilon;
26/// # use rustfst::Tr;
27/// # use rustfst::EPS_LABEL;
28/// # use anyhow::Result;
29/// # fn main() -> Result<()> {
30/// let mut fst = VectorFst::new();
31/// let s0 = fst.add_state();
32/// let s1 = fst.add_state();
33/// fst.add_tr(s0, Tr::new(32, 25, IntegerWeight::new(78), s1));
34/// fst.add_tr(s1, Tr::new(EPS_LABEL, EPS_LABEL, IntegerWeight::new(13), s0));
35/// fst.set_start(s0)?;
36/// fst.set_final(s0, IntegerWeight::new(5))?;
37///
38/// let mut fst_no_epsilon = fst.clone();
39/// rm_epsilon(&mut fst_no_epsilon)?;
40///
41/// let mut fst_no_epsilon_ref = VectorFst::<IntegerWeight>::new();
42/// let s0 = fst_no_epsilon_ref.add_state();
43/// let s1 = fst_no_epsilon_ref.add_state();
44/// fst_no_epsilon_ref.add_tr(s0, Tr::new(32, 25, 78, s1));
45/// fst_no_epsilon_ref.add_tr(s1, Tr::new(32, 25, 78 * 13, s1));
46/// fst_no_epsilon_ref.set_start(s0)?;
47/// fst_no_epsilon_ref.set_final(s0, 5)?;
48/// fst_no_epsilon_ref.set_final(s1, 5 * 13)?;
49///
50/// assert_eq!(fst_no_epsilon, fst_no_epsilon_ref);
51/// # Ok(())
52/// # }
53/// ```
54///
55/// # Example 2
56///
57/// ## Input
58///
59/// ![rmepsilon_in](https://raw.githubusercontent.com/Garvys/rustfst-images-doc/master/images/rmepsilon_in.svg?sanitize=true)
60///
61/// ## RmEpsilon
62///
63/// ![rmepsilon_out](https://raw.githubusercontent.com/Garvys/rustfst-images-doc/master/images/rmepsilon_out.svg?sanitize=true)
64///
65pub fn rm_epsilon<W: Semiring, F: MutableFst<W>>(fst: &mut F) -> Result<()> {
66    let tr_filter = EpsilonTrFilter {};
67    let queue = AutoQueue::new(fst, None, &tr_filter)?;
68    let opts = RmEpsilonInternalConfig::new_with_default(queue);
69    rm_epsilon_with_internal_config(fst, opts)
70}
71pub(crate) fn rm_epsilon_with_internal_config<W: Semiring, F: MutableFst<W>, Q: Queue>(
72    fst: &mut F,
73    opts: RmEpsilonInternalConfig<W, Q>,
74) -> Result<()> {
75    let connect = opts.connect;
76    let weight_threshold = opts.weight_threshold.clone();
77    let state_threshold = opts.state_threshold;
78
79    let start_state = match fst.start() {
80        None => return Ok(()),
81        Some(s) => s,
82    };
83
84    // noneps_in[s] will be set to true iff s admits a non-epsilon incoming
85    // transition or is the start state.
86    let mut noneps_in = vec![false; fst.num_states()];
87    noneps_in[start_state as usize] = true;
88
89    for state in fst.states_iter() {
90        for tr in fst.get_trs(state)?.trs() {
91            if tr.ilabel != EPS_LABEL || tr.olabel != EPS_LABEL {
92                noneps_in[tr.nextstate as usize] = true;
93            }
94        }
95    }
96
97    // States sorted in topological order when (acyclic) or generic topological
98    // order (cyclic).
99    let mut states = vec![];
100
101    let fst_props = fst.properties();
102
103    if fst_props.contains(FstProperties::TOP_SORTED) {
104        states = fst.states_iter().collect();
105    } else if fst_props.contains(FstProperties::ACYCLIC) {
106        let mut visitor = TopOrderVisitor::new();
107        dfs_visit(fst, &mut visitor, &EpsilonTrFilter {}, false);
108
109        states.resize(visitor.order.len(), 0);
110        for i in 0..visitor.order.len() {
111            states[visitor.order[i] as usize] = i as StateId;
112        }
113    } else {
114        let mut visitor = SccVisitor::new(fst, true, false);
115        dfs_visit(fst, &mut visitor, &EpsilonTrFilter {}, false);
116
117        let scc = visitor.scc.as_ref().unwrap();
118
119        let mut first = vec![None; scc.len()];
120        let mut next = vec![None; scc.len()];
121
122        for i in 0..scc.len() {
123            if first[scc[i] as usize].is_some() {
124                next[i] = first[scc[i] as usize];
125            }
126            first[scc[i] as usize] = Some(i);
127        }
128
129        for mut opt_j in &first {
130            while let Some(j) = opt_j {
131                states.push(*j as StateId);
132                opt_j = &next[*j];
133            }
134        }
135    }
136
137    let mut rmeps_state = RmEpsilonState::new(fst.num_states(), opts);
138    let zero = W::zero();
139
140    for state in states.into_iter().rev() {
141        if !noneps_in[state as usize]
142            && (connect || weight_threshold != W::zero() || state_threshold.is_some())
143        {
144            continue;
145        }
146        let (trs, final_weight) = rmeps_state.expand::<F, _>(state, &*fst)?;
147
148        unsafe {
149            fst.pop_trs_unchecked(state);
150            fst.set_trs_unchecked(state, trs.into_iter().rev().collect());
151            if final_weight != zero {
152                fst.set_final_unchecked(state, final_weight);
153            } else {
154                fst.delete_final_weight_unchecked(state);
155            }
156        }
157    }
158
159    if connect || weight_threshold != W::zero() || state_threshold.is_some() {
160        for s in 0..(fst.num_states() as StateId) {
161            if !noneps_in[s as usize] {
162                fst.delete_trs(s)?;
163            }
164        }
165    }
166
167    fst.set_properties(rmepsilon_properties(fst.properties(), false));
168
169    if weight_threshold != W::zero() || state_threshold.is_some() {
170        todo!("Implement Prune!")
171    }
172
173    if connect && weight_threshold == W::zero() && state_threshold.is_none() {
174        crate::algorithms::connect(fst)?;
175    }
176    Ok(())
177}
178
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use proptest::prelude::any;
    use proptest::proptest;

    use super::*;
    use crate::fst_traits::Fst;
    use crate::prelude::{TropicalWeight, VectorFst};
    use crate::SymbolTable;

    proptest! {
        // Symbol tables attached before epsilon removal must still be
        // attached once the algorithm has run.
        #[test]
        fn test_proptest_rmepsilon_keeps_symts(mut fst in any::<VectorFst::<TropicalWeight>>()) {
            // The same (empty) table is shared by the input and output sides.
            let shared_symbols = Arc::new(SymbolTable::new());
            fst.set_input_symbols(Arc::clone(&shared_symbols));
            fst.set_output_symbols(shared_symbols);

            rm_epsilon(&mut fst).unwrap();

            assert!(fst.input_symbols().is_some());
            assert!(fst.output_symbols().is_some());
        }
    }
}