rustfst/algorithms/rm_epsilon/
rm_epsilon_static.rs1use anyhow::Result;
2
3use crate::algorithms::dfs_visit::dfs_visit;
4use crate::algorithms::queues::AutoQueue;
5use crate::algorithms::rm_epsilon::{RmEpsilonInternalConfig, RmEpsilonState};
6use crate::algorithms::top_sort::TopOrderVisitor;
7use crate::algorithms::tr_filters::EpsilonTrFilter;
8use crate::algorithms::visitors::SccVisitor;
9use crate::algorithms::Queue;
10use crate::fst_properties::mutable_properties::rmepsilon_properties;
11use crate::fst_properties::FstProperties;
12use crate::fst_traits::MutableFst;
13use crate::semirings::Semiring;
14use crate::{StateId, Trs, EPS_LABEL};
15
16pub fn rm_epsilon<W: Semiring, F: MutableFst<W>>(fst: &mut F) -> Result<()> {
66 let tr_filter = EpsilonTrFilter {};
67 let queue = AutoQueue::new(fst, None, &tr_filter)?;
68 let opts = RmEpsilonInternalConfig::new_with_default(queue);
69 rm_epsilon_with_internal_config(fst, opts)
70}
71pub(crate) fn rm_epsilon_with_internal_config<W: Semiring, F: MutableFst<W>, Q: Queue>(
72 fst: &mut F,
73 opts: RmEpsilonInternalConfig<W, Q>,
74) -> Result<()> {
75 let connect = opts.connect;
76 let weight_threshold = opts.weight_threshold.clone();
77 let state_threshold = opts.state_threshold;
78
79 let start_state = match fst.start() {
80 None => return Ok(()),
81 Some(s) => s,
82 };
83
84 let mut noneps_in = vec![false; fst.num_states()];
87 noneps_in[start_state as usize] = true;
88
89 for state in fst.states_iter() {
90 for tr in fst.get_trs(state)?.trs() {
91 if tr.ilabel != EPS_LABEL || tr.olabel != EPS_LABEL {
92 noneps_in[tr.nextstate as usize] = true;
93 }
94 }
95 }
96
97 let mut states = vec![];
100
101 let fst_props = fst.properties();
102
103 if fst_props.contains(FstProperties::TOP_SORTED) {
104 states = fst.states_iter().collect();
105 } else if fst_props.contains(FstProperties::ACYCLIC) {
106 let mut visitor = TopOrderVisitor::new();
107 dfs_visit(fst, &mut visitor, &EpsilonTrFilter {}, false);
108
109 states.resize(visitor.order.len(), 0);
110 for i in 0..visitor.order.len() {
111 states[visitor.order[i] as usize] = i as StateId;
112 }
113 } else {
114 let mut visitor = SccVisitor::new(fst, true, false);
115 dfs_visit(fst, &mut visitor, &EpsilonTrFilter {}, false);
116
117 let scc = visitor.scc.as_ref().unwrap();
118
119 let mut first = vec![None; scc.len()];
120 let mut next = vec![None; scc.len()];
121
122 for i in 0..scc.len() {
123 if first[scc[i] as usize].is_some() {
124 next[i] = first[scc[i] as usize];
125 }
126 first[scc[i] as usize] = Some(i);
127 }
128
129 for mut opt_j in &first {
130 while let Some(j) = opt_j {
131 states.push(*j as StateId);
132 opt_j = &next[*j];
133 }
134 }
135 }
136
137 let mut rmeps_state = RmEpsilonState::new(fst.num_states(), opts);
138 let zero = W::zero();
139
140 for state in states.into_iter().rev() {
141 if !noneps_in[state as usize]
142 && (connect || weight_threshold != W::zero() || state_threshold.is_some())
143 {
144 continue;
145 }
146 let (trs, final_weight) = rmeps_state.expand::<F, _>(state, &*fst)?;
147
148 unsafe {
149 fst.pop_trs_unchecked(state);
150 fst.set_trs_unchecked(state, trs.into_iter().rev().collect());
151 if final_weight != zero {
152 fst.set_final_unchecked(state, final_weight);
153 } else {
154 fst.delete_final_weight_unchecked(state);
155 }
156 }
157 }
158
159 if connect || weight_threshold != W::zero() || state_threshold.is_some() {
160 for s in 0..(fst.num_states() as StateId) {
161 if !noneps_in[s as usize] {
162 fst.delete_trs(s)?;
163 }
164 }
165 }
166
167 fst.set_properties(rmepsilon_properties(fst.properties(), false));
168
169 if weight_threshold != W::zero() || state_threshold.is_some() {
170 todo!("Implement Prune!")
171 }
172
173 if connect && weight_threshold == W::zero() && state_threshold.is_none() {
174 crate::algorithms::connect(fst)?;
175 }
176 Ok(())
177}
178
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use proptest::{prelude::any, proptest};

    use super::*;
    use crate::fst_traits::Fst;
    use crate::prelude::{TropicalWeight, VectorFst};
    use crate::SymbolTable;

    proptest! {
        // Epsilon removal must not detach the input or output symbol table from an
        // arbitrary FST.
        #[test]
        fn test_proptest_rmepsilon_keeps_symts(mut fst in any::<VectorFst::<TropicalWeight>>()) {
            let shared_symbols = Arc::new(SymbolTable::new());
            fst.set_input_symbols(Arc::clone(&shared_symbols));
            fst.set_output_symbols(shared_symbols);

            rm_epsilon(&mut fst).unwrap();

            assert!(fst.input_symbols().is_some());
            assert!(fst.output_symbols().is_some());
        }
    }
}