rustfst/fst_impls/vector_fst/
test.rs1#[cfg(test)]
2mod tests {
3 use rand::{rngs::StdRng, SeedableRng};
4
5 use anyhow::Result;
6
7 use crate::fst_impls::VectorFst;
8 use crate::fst_traits::{
9 CoreFst, ExpandedFst, Fst, MutableFst, SerializableFst, StateIterator,
10 };
11 use crate::semirings::{ProbabilityWeight, Semiring, TropicalWeight};
12 use crate::tr::Tr;
13 use crate::{SymbolTable, Trs};
14 use rand::seq::SliceRandom;
15 use std::sync::Arc;
16
/// Builds a two-state FST with two transitions leaving the first state and
/// checks the transition counts and contents reported for each state.
#[test]
fn test_small_fst() -> Result<()> {
    let mut fst = VectorFst::<ProbabilityWeight>::new();

    // Two states; the first one is the start state.
    let s1 = fst.add_state();
    let s2 = fst.add_state();
    fst.set_start(s1)?;

    // First transition s1 -> s2. The clone is required: `tr_1` is compared
    // against the stored transition further down.
    let tr_1 = Tr::new(3, 5, 10.0, s2);
    fst.add_tr(s1, tr_1.clone())?;
    assert_eq!(fst.num_trs(s1)?, 1);

    // Second transition s1 -> s2, likewise kept for the later comparison.
    let tr_2 = Tr::new(5, 7, 18.0, s2);
    fst.add_tr(s1, tr_2.clone())?;
    assert_eq!(fst.num_trs(s1)?, 2);
    assert_eq!(fst.get_trs(s1)?.trs().len(), 2);

    // The transitions leaving s1 must come back in the order they were added.
    let it_s1 = fst.get_trs(s1)?;
    assert_eq!(it_s1.len(), 2);
    assert_eq!(tr_1, it_s1.trs()[0]);
    assert_eq!(tr_2, it_s1.trs()[1]);

    // Nothing was added to s2, so it must report no outgoing transition.
    let it_s2 = fst.get_trs(s2)?;
    assert_eq!(it_s2.len(), 0);
    Ok(())
}
50
/// Replaces the first transition of a state through the mutable transition
/// iterator and checks that only that slot changed.
#[test]
fn test_mutable_iter_trs_small() -> Result<()> {
    let mut fst = VectorFst::<ProbabilityWeight>::new();

    let s1 = fst.add_state();
    let s2 = fst.add_state();
    fst.set_start(s1)?;

    // `tr_1` is overwritten below and never compared, so it is moved into the
    // FST instead of being cloned.
    let tr_1 = Tr::new(3, 5, 10.0, s2);
    fst.add_tr(s1, tr_1)?;

    // `tr_2` must survive the replacement untouched; keep a copy to compare.
    let tr_2 = Tr::new(5, 7, 18.0, s2);
    fst.add_tr(s1, tr_2.clone())?;

    // Replacement transition. Its destination `s2 + 55` is not one of the two
    // states created in this test.
    let new_tr_1 = Tr::new(15, 29, 33.0, s2 + 55);

    // Overwrite the transition at index 0 of s1.
    let mut tr_it = fst.tr_iter_mut(s1)?;
    tr_it.set_tr(0, new_tr_1.clone())?;

    let it_s1 = fst.get_trs(s1)?;
    assert_eq!(new_tr_1, it_s1[0]);
    assert_eq!(tr_2, it_s1[1]);
    assert_eq!(it_s1.len(), 2);

    Ok(())
}
80
/// Checks that a fresh FST has no start state and that `set_start` can be
/// called repeatedly, the last call winning.
#[test]
fn test_start_states() -> Result<()> {
    let mut fst = VectorFst::<ProbabilityWeight>::new();

    // Populate the FST with 1000 states, remembering every id handed out.
    let n_states = 1000;
    let mut ids = Vec::new();
    for _ in 0..n_states {
        ids.push(fst.add_state());
    }

    // No start state has been chosen yet.
    assert_eq!(None, fst.start());

    // Pick a start state, then override it with another one.
    for &position in &[18, 32] {
        let candidate = ids[position];
        fst.set_start(candidate)?;
        assert_eq!(Some(candidate), fst.start());
    }

    Ok(())
}
103
/// Marks every state of a 1000-state FST as final and checks that the final
/// state iterator goes from yielding nothing to yielding all of them.
#[test]
fn test_only_final_states() -> Result<()> {
    let mut fst = VectorFst::<ProbabilityWeight>::new();

    let n_states = 1000;
    let states: Vec<_> = (0..n_states).map(|_| fst.add_state()).collect();

    // Freshly added states are not final.
    assert_eq!(fst.final_states_iter().count(), 0);

    // Propagate a failure through the test's `Result` instead of panicking
    // inside a closure.
    for &state in &states {
        fst.set_final(state, ProbabilityWeight::one())?;
    }

    assert_eq!(fst.final_states_iter().count(), n_states);

    Ok(())
}
126
/// Assigns distinct final weights to a pseudo-random subset of states and
/// checks that each of them reports being final with the weight it was given.
#[test]
fn test_final_weight() -> Result<()> {
    let mut fst = VectorFst::<ProbabilityWeight>::new();

    let n_states = 1000;
    let n_final_states = 300;

    let mut pool: Vec<_> = (0..n_states).map(|_| fst.add_state()).collect();

    // Before any state is marked final, none may report a final weight.
    let some_state_has_weight = fst
        .states_iter()
        .any(|q| fst.final_weight(q).unwrap().is_some());
    assert!(!some_state_has_weight);

    // Fixed seed: the selection of final states is the same on every run.
    let mut rng = StdRng::from_seed([53; 32]);
    pool.shuffle(&mut rng);
    pool.truncate(n_final_states);
    let final_states = pool;

    // The state at position `idx` of the selection receives weight `idx + 1`.
    for (idx, &q) in final_states.iter().enumerate() {
        let weight = ProbabilityWeight::new(idx as f32 + 1_f32);
        fst.set_final(q, weight).unwrap();
    }

    let every_selected_is_final = final_states.iter().all(|&q| fst.is_final(q).unwrap());
    assert!(every_selected_is_final);

    let every_weight_matches = final_states.iter().enumerate().all(|(idx, &q)| {
        let expected = Some(ProbabilityWeight::new(idx as f32 + 1_f32));
        fst.final_weight(q).unwrap() == expected
    });
    assert!(every_weight_matches);

    Ok(())
}
165
/// Deletes one of two mutually connected states and checks how many
/// transitions the surviving state still reports.
#[test]
fn test_del_state_trs() -> Result<()> {
    let mut fst = VectorFst::<ProbabilityWeight>::new();

    let s1 = fst.add_state();
    let s2 = fst.add_state();

    // s1 -> s2, s2 -> s1 and a self-loop on s2, all with labels 0/0 and weight one.
    for &(from, to) in &[(s1, s2), (s2, s1), (s2, s2)] {
        fst.add_tr(from, Tr::new(0, 0, ProbabilityWeight::one(), to))?;
    }

    // Both accessors must agree on the number of outgoing transitions.
    for &(state, expected) in &[(s1, 1), (s2, 2)] {
        assert_eq!(fst.num_trs(state)?, expected);
    }
    for &(state, expected) in &[(s1, 1), (s2, 2)] {
        assert_eq!(fst.get_trs(state)?.len(), expected);
    }

    fst.del_state(s1)?;

    // After the deletion, the state with id 0 is expected to carry exactly one
    // transition.
    assert_eq!(fst.num_trs(0)?, 1);

    // Same check through the first state yielded by the state iterator.
    let survivor = fst.states_iter().next().unwrap();
    assert_eq!(fst.get_trs(survivor)?.len(), 1);
    Ok(())
}
190
/// Checks that deleting a state succeeds once and fails when repeated with
/// the same id.
#[test]
fn test_deleting_twice_same_state() -> Result<()> {
    let mut fst = VectorFst::<ProbabilityWeight>::new();
    let state = fst.add_state();

    // First deletion: the state exists.
    let first_attempt = fst.del_state(state);
    assert!(first_attempt.is_ok());

    // Second deletion with the same id must be reported as an error.
    let second_attempt = fst.del_state(state);
    assert!(second_attempt.is_err());

    Ok(())
}
208
/// Contrasts deleting two states one at a time with deleting them in a single
/// `del_states` call.
#[test]
fn test_del_multiple_states() {
    let mut fst1 = VectorFst::<ProbabilityWeight>::new();

    let s1 = fst1.add_state();
    let s2 = fst1.add_state();

    // Copy taken while both states still exist, used for the batch deletion.
    let mut fst2 = fst1.clone();

    // One at a time: after removing `s1`, the id held in `s2` is rejected.
    // NOTE(review): presumably because the remaining state gets renumbered —
    // confirm against `del_state`.
    assert!(fst1.del_state(s1).is_ok());
    assert!(fst1.del_state(s2).is_err());

    // In one call: both ids are accepted. The collection is passed directly,
    // no explicit `.into_iter()` is needed.
    assert!(fst2.del_states(vec![s1, s2]).is_ok());
}
227
/// Deletes a pseudo-random subset of 300 states out of 1000 in one call and
/// checks the resulting state count.
#[test]
fn test_del_states_big() -> Result<()> {
    let total = 1000;
    let doomed_count = 300;

    let mut fst = VectorFst::<ProbabilityWeight>::new();

    let mut pool: Vec<_> = (0..total).map(|_| fst.add_state()).collect();
    assert_eq!(fst.num_states(), total);

    // Fixed seed: the set of deleted states is the same on every run.
    let mut rng = StdRng::from_seed([53; 32]);
    pool.shuffle(&mut rng);
    pool.truncate(doomed_count);

    fst.del_states(pool)?;

    let remaining = total - doomed_count;
    assert_eq!(fst.num_states(), remaining);
    Ok(())
}
252
/// Parses a one-line text FST whose only state is listed with weight
/// `Infinity` and checks that it equals an FST holding a single start state
/// on which no final weight was set.
#[test]
fn test_parse_single_final_state() -> Result<()> {
    // Reference FST, built by hand: one state, which is also the start state.
    let mut expected = VectorFst::<TropicalWeight>::new();
    expected.add_state();
    expected.set_start(0)?;

    // Text form: state id, a tab, then the weight.
    let text = "0\tInfinity\n";
    let parsed = VectorFst::<TropicalWeight>::from_text_string(text)?;

    assert_eq!(expected, parsed);

    Ok(())
}
266
/// Checks that `del_all_states` empties an FST that has states, transitions,
/// a start state and a final state.
#[test]
fn test_del_all_states() -> Result<()> {
    let mut fst = VectorFst::<ProbabilityWeight>::new();

    let s1 = fst.add_state();
    let s2 = fst.add_state();

    // s1 -> s2, s2 -> s1 and a self-loop on s2, all with labels 0/0 and weight one.
    for &(from, to) in &[(s1, s2), (s2, s1), (s2, s2)] {
        let tr = Tr::new(0, 0, ProbabilityWeight::one(), to);
        fst.add_tr(from, tr)?;
    }

    fst.set_start(s1)?;
    fst.set_final(s2, ProbabilityWeight::one())?;

    let count_before = fst.num_states();
    assert_eq!(count_before, 2);

    fst.del_all_states();

    let count_after = fst.num_states();
    assert_eq!(count_after, 0);

    Ok(())
}
287
/// Attaches input and output symbol tables to an FST and checks that both
/// can be read back with the expected number of entries.
#[test]
fn test_attach_symt() -> Result<()> {
    let mut fst = VectorFst::<ProbabilityWeight>::new();

    let s1 = fst.add_state();
    let s2 = fst.add_state();

    fst.add_tr(s1, Tr::new(1, 0, ProbabilityWeight::one(), s2))?;
    fst.add_tr(s2, Tr::new(2, 0, ProbabilityWeight::one(), s1))?;
    fst.add_tr(s2, Tr::new(3, 0, ProbabilityWeight::one(), s2))?;

    fst.set_start(s1)?;
    fst.set_final(s2, ProbabilityWeight::one())?;

    // Input table: three symbols added on top of whatever `SymbolTable::new()`
    // starts with.
    let mut input_symt = SymbolTable::new();
    input_symt.add_symbol("a");
    input_symt.add_symbol("b");
    input_symt.add_symbol("c");
    fst.set_input_symbols(Arc::new(input_symt));

    // A new table already reports one entry (see the output table check
    // below; presumably the epsilon symbol), hence 1 + 3 = 4.
    let input_len = fst
        .input_symbols()
        .expect("input symbol table should be attached")
        .len();
    assert_eq!(input_len, 4);

    // Output table: attached without adding any symbol.
    fst.set_output_symbols(Arc::new(SymbolTable::new()));

    let output_len = fst
        .output_symbols()
        .expect("output symbol table should be attached")
        .len();
    assert_eq!(output_len, 1);

    Ok(())
}
332}