rustfst/fst_impls/vector_fst/
test.rs

1#[cfg(test)]
2mod tests {
3    use rand::{rngs::StdRng, SeedableRng};
4
5    use anyhow::Result;
6
7    use crate::fst_impls::VectorFst;
8    use crate::fst_traits::{
9        CoreFst, ExpandedFst, Fst, MutableFst, SerializableFst, StateIterator,
10    };
11    use crate::semirings::{ProbabilityWeight, Semiring, TropicalWeight};
12    use crate::tr::Tr;
13    use crate::{SymbolTable, Trs};
14    use rand::seq::SliceRandom;
15    use std::sync::Arc;
16
17    #[test]
18    fn test_small_fst() -> Result<()> {
19        let mut fst = VectorFst::<ProbabilityWeight>::new();
20
21        // States
22        let s1 = fst.add_state();
23        let s2 = fst.add_state();
24
25        fst.set_start(s1)?;
26
27        // Trs
28        let tr_1 = Tr::new(3, 5, 10.0, s2);
29        fst.add_tr(s1, tr_1.clone())?;
30
31        assert_eq!(fst.num_trs(s1).unwrap(), 1);
32
33        let tr_2 = Tr::new(5, 7, 18.0, s2);
34        fst.add_tr(s1, tr_2.clone())?;
35        assert_eq!(fst.num_trs(s1).unwrap(), 2);
36        assert_eq!(fst.get_trs(s1)?.trs().iter().count(), 2);
37
38        // Iterates on trs leaving s1
39        let it_s1 = fst.get_trs(s1)?;
40        assert_eq!(it_s1.len(), 2);
41        assert_eq!(tr_1, it_s1.trs()[0]);
42        assert_eq!(tr_2, it_s1.trs()[1]);
43
44        // Iterates on trs leaving s2
45        let it_s2 = fst.get_trs(s2)?;
46
47        assert_eq!(it_s2.len(), 0);
48        Ok(())
49    }
50
51    #[test]
52    fn test_mutable_iter_trs_small() -> Result<()> {
53        let mut fst = VectorFst::<ProbabilityWeight>::new();
54
55        // States
56        let s1 = fst.add_state();
57        let s2 = fst.add_state();
58
59        fst.set_start(s1)?;
60
61        // Trs
62        let tr_1 = Tr::new(3, 5, 10.0, s2);
63        fst.add_tr(s1, tr_1.clone())?;
64        let tr_2 = Tr::new(5, 7, 18.0, s2);
65        fst.add_tr(s1, tr_2.clone())?;
66
67        let new_tr_1 = Tr::new(15, 29, 33.0, s2 + 55);
68
69        // Modify first transition leaving s1
70        let mut tr_it = fst.tr_iter_mut(s1)?;
71        tr_it.set_tr(0, new_tr_1.clone())?;
72
73        let it_s1 = fst.get_trs(s1)?;
74        assert_eq!(new_tr_1, it_s1[0]);
75        assert_eq!(tr_2, it_s1[1]);
76        assert_eq!(it_s1.len(), 2);
77
78        Ok(())
79    }
80
81    #[test]
82    fn test_start_states() -> Result<()> {
83        let mut fst = VectorFst::<ProbabilityWeight>::new();
84
85        let n_states = 1000;
86
87        // Add N states to the FST
88        let states: Vec<_> = (0..n_states).map(|_| fst.add_state()).collect();
89
90        // Should be no start state
91        assert_eq!(fst.start(), None);
92
93        // New start state
94        fst.set_start(states[18])?;
95        assert_eq!(fst.start(), Some(states[18]));
96
97        // New start state
98        fst.set_start(states[32])?;
99        assert_eq!(fst.start(), Some(states[32]));
100
101        Ok(())
102    }
103
104    #[test]
105    fn test_only_final_states() -> Result<()> {
106        let mut fst = VectorFst::<ProbabilityWeight>::new();
107
108        let n_states = 1000;
109
110        // Add N states to the FST
111        let states: Vec<_> = (0..n_states).map(|_| fst.add_state()).collect();
112
113        // Number of final states should be zero
114        assert_eq!(fst.final_states_iter().count(), 0);
115
116        // Set all states as final
117        states
118            .iter()
119            .for_each(|v| fst.set_final(*v, ProbabilityWeight::one()).unwrap());
120
121        // Number of final states should be n_states
122        assert_eq!(fst.final_states_iter().count(), n_states);
123
124        Ok(())
125    }
126
127    #[test]
128    fn test_final_weight() -> Result<()> {
129        let mut fst = VectorFst::<ProbabilityWeight>::new();
130
131        let n_states = 1000;
132        let n_final_states = 300;
133
134        // Add N states to the FST
135        let mut states: Vec<_> = (0..n_states).map(|_| fst.add_state()).collect();
136
137        // None of the states are final => None final weight
138        assert!(fst
139            .states_iter()
140            .map(|state_id| fst.final_weight(state_id).unwrap())
141            .all(|v| v.is_none()));
142
143        // Select randomly n_final_states
144        let mut rg = StdRng::from_seed([53; 32]);
145        states.shuffle(&mut rg);
146        let final_states: Vec<_> = states.into_iter().take(n_final_states).collect();
147
148        // Set those as final with a weight equals to its position in the vector
149        final_states.iter().enumerate().for_each(|(idx, state_id)| {
150            fst.set_final(*state_id, ProbabilityWeight::new(idx as f32 + 1_f32))
151                .unwrap()
152        });
153
154        // Check they are final with the correct weight
155        assert!(final_states
156            .iter()
157            .all(|state_id| fst.is_final(*state_id).unwrap()));
158        assert!(final_states
159            .iter()
160            .enumerate()
161            .all(|(idx, state_id)| fst.final_weight(*state_id).unwrap()
162                == Some(ProbabilityWeight::new(idx as f32 + 1_f32))));
163        Ok(())
164    }
165
166    #[test]
167    fn test_del_state_trs() -> Result<()> {
168        let mut fst = VectorFst::<ProbabilityWeight>::new();
169
170        let s1 = fst.add_state();
171        let s2 = fst.add_state();
172
173        fst.add_tr(s1, Tr::new(0, 0, ProbabilityWeight::one(), s2))?;
174        fst.add_tr(s2, Tr::new(0, 0, ProbabilityWeight::one(), s1))?;
175        fst.add_tr(s2, Tr::new(0, 0, ProbabilityWeight::one(), s2))?;
176
177        assert_eq!(fst.num_trs(s1)?, 1);
178        assert_eq!(fst.num_trs(s2)?, 2);
179        assert_eq!(fst.get_trs(s1)?.len(), 1);
180        assert_eq!(fst.get_trs(s2)?.len(), 2);
181
182        fst.del_state(s1)?;
183
184        assert_eq!(fst.num_trs(0)?, 1);
185
186        let only_state = fst.states_iter().next().unwrap();
187        assert_eq!(fst.get_trs(only_state)?.len(), 1);
188        Ok(())
189    }
190
191    #[test]
192    fn test_deleting_twice_same_state() -> Result<()> {
193        let mut fst1 = VectorFst::<ProbabilityWeight>::new();
194
195        let s = fst1.add_state();
196
197        //        let mut fst2 = fst1.clone();
198
199        // Perform test with del_state
200        assert!(fst1.del_state(s).is_ok());
201        assert!(fst1.del_state(s).is_err());
202
203        // Perform test with del_states
204        //        let states_to_remove = vec![s, s];
205        //        assert!(fst2.del_states(states_to_remove.into_iter()).is_err());
206        Ok(())
207    }
208
209    #[test]
210    fn test_del_multiple_states() {
211        // Test to check that
212        let mut fst1 = VectorFst::<ProbabilityWeight>::new();
213
214        let s1 = fst1.add_state();
215        let s2 = fst1.add_state();
216
217        let mut fst2 = fst1.clone();
218
219        // Pass because s2 state id is modified by the first call
220        assert!(fst1.del_state(s1).is_ok());
221        assert!(fst1.del_state(s2).is_err());
222
223        // Test that the above issue doesn't arrive when calling del_states
224        let states_to_remove = vec![s1, s2];
225        assert!(fst2.del_states(states_to_remove.into_iter()).is_ok());
226    }
227
228    #[test]
229    fn test_del_states_big() -> Result<()> {
230        let n_states = 1000;
231        let n_states_to_delete = 300;
232
233        let mut fst = VectorFst::<ProbabilityWeight>::new();
234
235        // Add N states to the FST
236        let mut states: Vec<_> = (0..n_states).map(|_| fst.add_state()).collect();
237
238        // Check those N states do exist
239        assert_eq!(fst.num_states(), n_states);
240
241        // Sample n_states_to_delete to remove from the FST
242        let mut rg = StdRng::from_seed([53; 32]);
243        states.shuffle(&mut rg);
244        let states_to_delete: Vec<_> = states.into_iter().take(n_states_to_delete).collect();
245
246        fst.del_states(states_to_delete)?;
247
248        // Check they are correctly removed
249        assert_eq!(fst.num_states(), n_states - n_states_to_delete);
250        Ok(())
251    }
252
253    #[test]
254    fn test_parse_single_final_state() -> Result<()> {
255        let parsed_fst = VectorFst::<TropicalWeight>::from_text_string("0\tInfinity\n")?;
256
257        let mut fst_ref: VectorFst<TropicalWeight> = VectorFst::new();
258
259        fst_ref.add_state();
260        fst_ref.set_start(0)?;
261
262        assert_eq!(fst_ref, parsed_fst);
263
264        Ok(())
265    }
266
267    #[test]
268    fn test_del_all_states() -> Result<()> {
269        let mut fst = VectorFst::<ProbabilityWeight>::new();
270
271        let s1 = fst.add_state();
272        let s2 = fst.add_state();
273
274        fst.add_tr(s1, Tr::new(0, 0, ProbabilityWeight::one(), s2))?;
275        fst.add_tr(s2, Tr::new(0, 0, ProbabilityWeight::one(), s1))?;
276        fst.add_tr(s2, Tr::new(0, 0, ProbabilityWeight::one(), s2))?;
277
278        fst.set_start(s1)?;
279        fst.set_final(s2, ProbabilityWeight::one())?;
280
281        assert_eq!(fst.num_states(), 2);
282        fst.del_all_states();
283        assert_eq!(fst.num_states(), 0);
284
285        Ok(())
286    }
287
288    #[test]
289    fn test_attach_symt() -> Result<()> {
290        let mut fst = VectorFst::<ProbabilityWeight>::new();
291
292        let s1 = fst.add_state();
293        let s2 = fst.add_state();
294
295        fst.add_tr(s1, Tr::new(1, 0, ProbabilityWeight::one(), s2))?;
296        fst.add_tr(s2, Tr::new(2, 0, ProbabilityWeight::one(), s1))?;
297        fst.add_tr(s2, Tr::new(3, 0, ProbabilityWeight::one(), s2))?;
298
299        fst.set_start(s1)?;
300        fst.set_final(s2, ProbabilityWeight::one())?;
301
302        // Test input symbol table
303        {
304            let mut symt = SymbolTable::new();
305            symt.add_symbol("a"); // 1
306            symt.add_symbol("b"); // 2
307            symt.add_symbol("c"); // 3
308
309            fst.set_input_symbols(Arc::new(symt));
310        }
311        {
312            let symt = fst.input_symbols();
313            assert!(symt.is_some());
314            let symt = symt.unwrap();
315            assert_eq!(symt.len(), 4);
316        }
317
318        // Test output symbol table
319        {
320            let symt = SymbolTable::new();
321            fst.set_output_symbols(Arc::new(symt));
322        }
323        {
324            let symt = fst.output_symbols();
325            assert!(symt.is_some());
326            let symt = symt.unwrap();
327            assert_eq!(symt.len(), 1);
328        }
329
330        Ok(())
331    }
332}