1use std::cmp::max;
2use std::cmp::Ordering;
3use std::collections::hash_map::Entry;
4use std::collections::HashMap;
5use std::collections::HashSet;
6use std::marker::PhantomData;
7
8use anyhow::Result;
9use binary_heap_plus::BinaryHeap;
10use stable_bst::TreeMap;
11
12use crate::algorithms::encode::EncodeType;
13use crate::algorithms::factor_weight::factor_iterators::GallicFactorLeft;
14use crate::algorithms::factor_weight::{factor_weight, FactorWeightOptions, FactorWeightType};
15use crate::algorithms::partition::Partition;
16use crate::algorithms::queues::LifoQueue;
17use crate::algorithms::tr_compares::ILabelCompare;
18use crate::algorithms::tr_mappers::QuantizeMapper;
19use crate::algorithms::tr_unique;
20use crate::algorithms::weight_converters::{FromGallicConverter, ToGallicConverter};
21use crate::algorithms::Queue;
22use crate::algorithms::{
23 connect,
24 encode::{decode, encode},
25 tr_map, tr_sort, weight_convert, ReweightType,
26};
27use crate::algorithms::{push_weights_with_config, reverse, PushWeightsConfig};
28use crate::fst_impls::VectorFst;
29use crate::fst_properties::FstProperties;
30use crate::fst_traits::{AllocableFst, CoreFst, ExpandedFst, Fst, MutableFst};
31use crate::semirings::{
32 GallicWeightLeft, Semiring, SemiringProperties, WeaklyDivisibleSemiring, WeightQuantize,
33};
34use crate::EPS_LABEL;
35use crate::KDELTA;
36use crate::{Label, StateId, Trs};
37use crate::{Tr, KSHORTESTDELTA};
38use itertools::Itertools;
39use std::cell::RefCell;
40use std::rc::Rc;
41
/// Options accepted by [`minimize_with_config`].
///
/// The type is a small `Copy` value; the `with_*` methods consume `self` and
/// return an updated copy, so they can be chained.
#[derive(Debug, Clone, Copy, PartialOrd, PartialEq)]
pub struct MinimizeConfig {
    /// Tolerance forwarded to weight pushing and weight quantization.
    pub delta: f32,
    /// When `false`, a non-deterministic input is rejected with an error.
    pub allow_nondet: bool,
}

impl MinimizeConfig {
    /// Builds a configuration from explicit values for both fields.
    pub fn new(delta: f32, allow_nondet: bool) -> Self {
        Self {
            delta,
            allow_nondet,
        }
    }

    /// Returns a copy of `self` with `delta` replaced.
    pub fn with_delta(self, delta: f32) -> Self {
        Self { delta, ..self }
    }

    /// Returns a copy of `self` with `allow_nondet` replaced.
    pub fn with_allow_nondet(self, allow_nondet: bool) -> Self {
        Self {
            allow_nondet,
            ..self
        }
    }
}
68
69impl Default for MinimizeConfig {
70 fn default() -> Self {
71 Self {
72 delta: KSHORTESTDELTA,
73 allow_nondet: false,
74 }
75 }
76}
77
78pub fn minimize<W, F>(ifst: &mut F) -> Result<()>
82where
83 F: MutableFst<W> + ExpandedFst<W> + AllocableFst<W>,
84 W: WeaklyDivisibleSemiring + WeightQuantize,
85 W::ReverseWeight: WeightQuantize,
86{
87 minimize_with_config(ifst, MinimizeConfig::default())
88}
89
90pub fn minimize_with_config<W, F>(ifst: &mut F, config: MinimizeConfig) -> Result<()>
94where
95 F: MutableFst<W> + ExpandedFst<W> + AllocableFst<W>,
96 W: WeaklyDivisibleSemiring + WeightQuantize,
97 W::ReverseWeight: WeightQuantize,
98{
99 let delta = config.delta;
100 let allow_nondet = config.allow_nondet;
101
102 let props = ifst.compute_and_update_properties(
103 FstProperties::ACCEPTOR
104 | FstProperties::I_DETERMINISTIC
105 | FstProperties::WEIGHTED
106 | FstProperties::UNWEIGHTED,
107 )?;
108
109 let allow_acyclic_minimization = if props.contains(FstProperties::I_DETERMINISTIC) {
110 true
111 } else {
112 if !W::properties().contains(SemiringProperties::IDEMPOTENT) {
113 bail!("Cannot minimize a non-deterministic FST over a non-idempotent semiring")
114 } else if !allow_nondet {
115 bail!("Refusing to minimize a non-deterministic FST with allow_nondet = false")
116 }
117
118 false
119 };
120
121 if !props.contains(FstProperties::ACCEPTOR) {
122 let mut to_gallic = ToGallicConverter {};
124 let mut gfst: VectorFst<GallicWeightLeft<W>> = weight_convert(ifst, &mut to_gallic)?;
125 let push_weights_config = PushWeightsConfig::default().with_delta(delta);
126 push_weights_with_config(
127 &mut gfst,
128 ReweightType::ReweightToInitial,
129 push_weights_config,
130 )?;
131
132 let quantize_mapper = QuantizeMapper::new(delta);
133 tr_map(&mut gfst, &quantize_mapper)?;
134
135 let encode_table = encode(&mut gfst, EncodeType::EncodeWeightsAndLabels)?;
136
137 acceptor_minimize(&mut gfst, allow_acyclic_minimization)?;
138
139 decode(&mut gfst, encode_table)?;
140
141 let factor_opts: FactorWeightOptions = FactorWeightOptions {
142 delta: KDELTA,
143 mode: FactorWeightType::FACTOR_FINAL_WEIGHTS | FactorWeightType::FACTOR_ARC_WEIGHTS,
144 final_ilabel: 0,
145 final_olabel: 0,
146 increment_final_ilabel: false,
147 increment_final_olabel: false,
148 };
149
150 let fwfst: VectorFst<_> =
151 factor_weight::<_, VectorFst<GallicWeightLeft<W>>, _, _, GallicFactorLeft<W>>(
152 &gfst,
153 factor_opts,
154 )?;
155
156 let mut from_gallic = FromGallicConverter {
157 superfinal_label: EPS_LABEL,
158 };
159 *ifst = weight_convert(&fwfst, &mut from_gallic)?;
160
161 Ok(())
162 } else if props.contains(FstProperties::WEIGHTED) {
163 let push_weights_config = PushWeightsConfig::default().with_delta(delta);
165 push_weights_with_config(ifst, ReweightType::ReweightToInitial, push_weights_config)?;
166 let quantize_mapper = QuantizeMapper::new(delta);
167 tr_map(ifst, &quantize_mapper)?;
168 let encode_table = encode(ifst, EncodeType::EncodeWeightsAndLabels)?;
169 acceptor_minimize(ifst, allow_acyclic_minimization)?;
170 decode(ifst, encode_table)
171 } else {
172 acceptor_minimize(ifst, allow_acyclic_minimization)
174 }
175}
176
177pub fn acceptor_minimize<W: Semiring, F: MutableFst<W> + ExpandedFst<W>>(
183 ifst: &mut F,
184 allow_acyclic_minimization: bool,
185) -> Result<()> {
186 let props = ifst.compute_and_update_properties(
187 FstProperties::ACCEPTOR | FstProperties::UNWEIGHTED | FstProperties::ACYCLIC,
188 )?;
189 if !props.contains(FstProperties::ACCEPTOR | FstProperties::UNWEIGHTED) {
190 bail!("FST is not an unweighted acceptor");
191 }
192
193 connect(ifst)?;
194
195 if ifst.num_states() == 0 {
196 return Ok(());
197 }
198
199 if allow_acyclic_minimization && props.contains(FstProperties::ACYCLIC) {
200 tr_sort(ifst, ILabelCompare {});
202 let minimizer = AcyclicMinimizer::new(ifst)?;
203 merge_states(minimizer.get_partition(), ifst)?;
204 } else {
205 let p = cyclic_minimize(ifst)?;
206 merge_states(p, ifst)?;
207 }
208
209 tr_unique(ifst);
210
211 Ok(())
212}
213
214fn merge_states<W: Semiring, F: MutableFst<W>>(
215 partition: Rc<RefCell<Partition>>,
216 fst: &mut F,
217) -> Result<()> {
218 let mut state_map = vec![None; partition.borrow().num_classes()];
219
220 for (i, state_map_i) in state_map
221 .iter_mut()
222 .enumerate()
223 .take(partition.borrow().num_classes())
224 {
225 *state_map_i = partition.borrow().iter(i).next();
226 }
227
228 for c in 0..partition.borrow().num_classes() {
229 for s in partition.borrow().iter(c) {
230 if s == state_map[c].unwrap() {
231 let mut it_tr = fst.tr_iter_mut(s as StateId)?;
232 for idx_tr in 0..it_tr.len() {
233 let tr = unsafe { it_tr.get_unchecked(idx_tr) };
234 let nextstate =
235 state_map[partition.borrow().get_class_id(tr.nextstate as usize)].unwrap();
236 unsafe { it_tr.set_nextstate_unchecked(idx_tr, nextstate as StateId) };
237 }
238 } else {
239 for tr in fst
240 .get_trs(s as StateId)?
241 .trs()
242 .iter()
243 .cloned()
244 .map(|mut tr| {
245 tr.nextstate = state_map
246 [partition.borrow().get_class_id(tr.nextstate as usize)]
247 .unwrap() as StateId;
248 tr
249 })
250 {
251 fst.add_tr(state_map[c].unwrap() as StateId, tr)?;
252 }
253 }
254 }
255 }
256
257 fst.set_start(
258 state_map[partition
259 .borrow()
260 .get_class_id(fst.start().unwrap() as usize)]
261 .unwrap() as StateId,
262 )?;
263
264 connect(fst)?;
265
266 Ok(())
267}
268
269pub fn fst_depth<W: Semiring, F: Fst<W>>(
271 fst: &F,
272 state_id_cour: StateId,
273 accessible_states: &mut HashSet<StateId>,
274 fully_examined_states: &mut HashSet<StateId>,
275 heights: &mut Vec<i32>,
276) -> Result<()> {
277 accessible_states.insert(state_id_cour);
278
279 for _ in heights.len()..=(state_id_cour as usize) {
280 heights.push(-1);
281 }
282
283 let mut height_cur_state = 0;
284 for tr in fst.get_trs(state_id_cour)?.trs() {
285 let nextstate = tr.nextstate;
286
287 if !accessible_states.contains(&nextstate) {
288 fst_depth(
289 fst,
290 nextstate,
291 accessible_states,
292 fully_examined_states,
293 heights,
294 )?;
295 }
296
297 height_cur_state = max(height_cur_state, 1 + heights[nextstate as usize]);
298 }
299 fully_examined_states.insert(state_id_cour);
300
301 heights[state_id_cour as usize] = height_cur_state;
302
303 Ok(())
304}
305
306struct AcyclicMinimizer {
307 partition: Rc<RefCell<Partition>>,
308}
309
310impl AcyclicMinimizer {
311 pub fn new<W: Semiring, F: MutableFst<W>>(fst: &mut F) -> Result<Self> {
312 let mut c = Self {
313 partition: Rc::new(RefCell::new(Partition::empty_new())),
314 };
315 c.initialize(fst)?;
316 c.refine(fst);
317 Ok(c)
318 }
319
320 fn initialize<W: Semiring, F: MutableFst<W>>(&mut self, fst: &mut F) -> Result<()> {
321 let mut accessible_state = HashSet::new();
322 let mut fully_examined_states = HashSet::new();
323 let mut heights = Vec::new();
324 fst_depth(
325 fst,
326 fst.start().unwrap(),
327 &mut accessible_state,
328 &mut fully_examined_states,
329 &mut heights,
330 )?;
331 self.partition.borrow_mut().initialize(heights.len());
332 self.partition
333 .borrow_mut()
334 .allocate_classes((heights.iter().max().unwrap() + 1) as usize);
335 for (s, h) in heights.iter().enumerate() {
336 self.partition.borrow_mut().add(s, *h as usize);
337 }
338 Ok(())
339 }
340
341 fn refine<W: Semiring, F: MutableFst<W>>(&mut self, fst: &mut F) {
342 let state_cmp = StateComparator {
343 fst,
344 partition: Rc::clone(&self.partition),
345 w: PhantomData,
346 };
347
348 let height = self.partition.borrow().num_classes();
349 for h in 0..height {
350 let mut equiv_classes =
354 TreeMap::<StateId, StateId, _>::with_comparator(|a: &StateId, b: &StateId| {
355 state_cmp.compare(*a, *b).unwrap()
356 });
357
358 let it_partition: Vec<_> = self.partition.borrow().iter(h).collect();
359 equiv_classes.insert(it_partition[0] as StateId, h as StateId);
360
361 for e in it_partition.iter().skip(1) {
362 equiv_classes.get_or_insert(*e as StateId, || {
363 self.partition.borrow_mut().add_class() as StateId
364 });
365 }
366
367 for s in it_partition {
368 let old_class = self.partition.borrow().get_class_id(s);
369 let new_class = *equiv_classes.get(&(s as StateId)).unwrap();
370
371 if old_class != (new_class as usize) {
372 self.partition
373 .borrow_mut()
374 .move_element(s, new_class as usize);
375 }
376 }
377 }
378 }
379
380 pub fn get_partition(self) -> Rc<RefCell<Partition>> {
381 self.partition
382 }
383}
384
385struct StateComparator<'a, W: Semiring, F: MutableFst<W>> {
386 fst: &'a F,
387 partition: Rc<RefCell<Partition>>,
388 w: PhantomData<W>,
389}
390
391impl<'a, W: Semiring, F: MutableFst<W>> StateComparator<'a, W, F> {
392 fn do_compare(&self, x: StateId, y: StateId) -> Result<bool> {
393 let xfinal = self.fst.final_weight(x)?.unwrap_or_else(W::zero);
394 let yfinal = self.fst.final_weight(y)?.unwrap_or_else(W::zero);
395
396 if xfinal < yfinal {
397 return Ok(true);
398 } else if xfinal > yfinal {
399 return Ok(false);
400 }
401
402 if self.fst.num_trs(x)? < self.fst.num_trs(y)? {
403 return Ok(true);
404 }
405 if self.fst.num_trs(x)? > self.fst.num_trs(y)? {
406 return Ok(false);
407 }
408
409 let it_x_owner = self.fst.get_trs(x)?;
410 let it_x = it_x_owner.trs().iter();
411 let it_y_owner = self.fst.get_trs(y)?;
412 let it_y = it_y_owner.trs().iter();
413
414 for (arc1, arc2) in it_x.zip(it_y) {
415 if arc1.ilabel < arc2.ilabel {
416 return Ok(true);
417 }
418 if arc1.ilabel > arc2.ilabel {
419 return Ok(false);
420 }
421 let id_1 = self
422 .partition
423 .borrow()
424 .get_class_id(arc1.nextstate as usize);
425 let id_2 = self
426 .partition
427 .borrow()
428 .get_class_id(arc2.nextstate as usize);
429 if id_1 < id_2 {
430 return Ok(true);
431 }
432 if id_1 > id_2 {
433 return Ok(false);
434 }
435 }
436 Ok(false)
437 }
438
439 pub fn compare(&self, x: StateId, y: StateId) -> Result<Ordering> {
440 if x == y {
441 return Ok(Ordering::Equal);
442 }
443
444 let x_y = self.do_compare(x, y).unwrap();
445 let y_x = self.do_compare(y, x).unwrap();
446
447 if !(x_y) && !(y_x) {
448 return Ok(Ordering::Equal);
449 }
450
451 if x_y {
452 Ok(Ordering::Less)
453 } else {
454 Ok(Ordering::Greater)
455 }
456 }
457}
458
459fn pre_partition<W: Semiring, F: MutableFst<W>>(
460 fst: &F,
461 partition: &Rc<RefCell<Partition>>,
462 queue: &mut LifoQueue,
463) {
464 let mut next_class: StateId = 0;
465 let num_states = fst.num_states();
466
467 let mut state_to_initial_class: Vec<StateId> = vec![0; num_states];
468 {
469 let mut hash_to_class_nonfinal = HashMap::<Vec<Label>, StateId>::new();
470 let mut hash_to_class_final = HashMap::<Vec<Label>, StateId>::new();
471
472 for (s, state_to_initial_class_s) in state_to_initial_class
473 .iter_mut()
474 .enumerate()
475 .take(num_states)
476 {
477 let this_map = if unsafe { fst.is_final_unchecked(s as StateId) } {
478 &mut hash_to_class_final
479 } else {
480 &mut hash_to_class_nonfinal
481 };
482
483 let ilabels = fst
484 .get_trs(s as StateId)
485 .unwrap()
486 .trs()
487 .iter()
488 .map(|e| e.ilabel)
489 .dedup()
490 .collect_vec();
491
492 match this_map.entry(ilabels) {
493 Entry::Occupied(e) => {
494 *state_to_initial_class_s = *e.get();
495 }
496 Entry::Vacant(e) => {
497 e.insert(next_class);
498 *state_to_initial_class_s = next_class;
499 next_class += 1;
500 }
501 };
502 }
503 }
504
505 partition.borrow_mut().allocate_classes(next_class as usize);
506 for (s, c) in state_to_initial_class.iter().enumerate().take(num_states) {
507 partition.borrow_mut().add(s, *c as usize);
508 }
509
510 for c in 0..next_class {
511 queue.enqueue(c);
512 }
513}
514
515fn cyclic_minimize<W: Semiring, F: MutableFst<W>>(fst: &mut F) -> Result<Rc<RefCell<Partition>>> {
516 let mut tr: VectorFst<W::ReverseWeight> = reverse(fst)?;
518 tr_sort(&mut tr, ILabelCompare {});
519
520 let partition = Rc::new(RefCell::new(Partition::new(tr.num_states() - 1)));
521 let mut queue = LifoQueue::default();
522 pre_partition(fst, &partition, &mut queue);
523
524 let comp = TrIterCompare {};
525
526 let mut aiter_queue = BinaryHeap::new_by(|v1, v2| {
527 if comp.compare(v1, v2) {
528 Ordering::Less
529 } else {
530 Ordering::Greater
531 }
532 });
533
534 while let Some(c) = queue.dequeue() {
536 for s in partition.borrow().iter(c as usize) {
538 if tr.num_trs(s as StateId + 1)? > 0 {
539 aiter_queue.push(TrsIterCollected {
540 idx: 0,
541 trs: tr.get_trs(s as StateId + 1)?,
542 w: PhantomData,
543 });
544 }
545 }
546
547 let mut prev_label = -1;
548 while !aiter_queue.is_empty() {
549 let mut aiter = aiter_queue.pop().unwrap();
550 if aiter.done() {
551 continue;
552 }
553 let tr = aiter.value().unwrap();
554 let from_state = tr.nextstate - 1;
555 let from_label = tr.ilabel;
556 if prev_label != from_label as i32 {
557 partition.borrow_mut().finalize_split(&mut Some(&mut queue));
558 }
559 let from_class = partition.borrow().get_class_id(from_state as usize);
560 if partition.borrow().get_class_size(from_class) > 1 {
561 partition.borrow_mut().split_on(from_state as usize);
562 }
563 prev_label = from_label as i32;
564 aiter.next();
565 if !aiter.done() {
566 aiter_queue.push(aiter);
567 }
568 }
569
570 partition.borrow_mut().finalize_split(&mut Some(&mut queue));
571 }
572
573 Ok(partition)
575}
576
577struct TrsIterCollected<W: Semiring, T: Trs<W>> {
578 idx: usize,
579 trs: T,
580 w: PhantomData<W>,
581}
582
583impl<W: Semiring, T: Trs<W>> TrsIterCollected<W, T> {
584 fn value(&self) -> Option<&Tr<W>> {
585 self.trs.trs().get(self.idx)
586 }
587
588 fn done(&self) -> bool {
589 self.idx >= self.trs.len()
590 }
591
592 fn next(&mut self) {
593 self.idx += 1;
594 }
595}
596
597#[derive(Debug, Clone)]
598struct TrIterCompare {}
599
600impl TrIterCompare {
601 fn compare<W: Semiring, T: Trs<W>>(
602 &self,
603 x: &TrsIterCollected<W, T>,
604 y: &TrsIterCollected<W, T>,
605 ) -> bool {
606 let xarc = x.value().unwrap();
607 let yarc = y.value().unwrap();
608 xarc.ilabel > yarc.ilabel
609 }
610}
611
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use ::proptest::prelude::*;
    use algorithms::determinize::*;
    use std::sync::Arc;

    /// Minimization must not change whether the probe path is accepted.
    #[test]
    fn test_minimize_issue_158() {
        let text_fst = r#"0 5 101 101 0
0 4 100 100 0
0 3 99 99 0
0 2 98 98 0
0 1 97 97 0
1 10 101 101 0
1 9 100 100 0
1 8 99 99 0
1 7 98 98 0
1 6 97 97 0
2 11 101 101 0
2 10 100 100 0
2 9 99 99 0
2 8 98 98 0
2 7 97 97 0
3 11 100 100 0
3 10 99 99 0
3 9 98 98 0
3 8 97 97 0
4 11 99 99 0
4 10 98 98 0
4 9 97 97 0
5 11 98 98 0
5 10 97 97 0
6 15 101 101 0
6 14 100 100 0
6 13 99 99 0
6 12 98 98 0
7 16 101 101 0
7 15 100 100 0
7 14 99 99 0
7 13 98 98 0
7 12 97 97 0
8 16 100 100 0
8 15 99 99 0
8 14 98 98 0
8 13 97 97 0
9 16 99 99 0
9 15 98 98 0
9 14 97 97 0
10 16 98 98 0
10 15 97 97 0
11 16 97 97 0
12 17 101 101 0
13 17 100 100 0
14 17 99 99 0
15 17 98 98 0
16 17 97 97 0
17 18 32 32 0
18 0
        "#;
        let probe = fst_path![97, 98, 97, 100, 32];
        let mut fst: VectorFst<TropicalWeight> = VectorFst::from_text_string(text_fst).unwrap();

        let accepted_before = check_path_in_fst(&fst, &probe);
        minimize(&mut fst).unwrap();
        let accepted_after = check_path_in_fst(&fst, &probe);

        assert_eq!(accepted_before, accepted_after);
    }

    proptest! {
        // Minimization of an arbitrary FST must return without error.
        #[test]
        fn test_proptest_minimize_timeout(mut fst in any::<VectorFst::<TropicalWeight>>()) {
            let min_config = MinimizeConfig::default().with_allow_nondet(true);
            minimize_with_config(&mut fst, min_config).unwrap();
        }
    }

    proptest! {
        // Determinizing before and after minimization must give isomorphic FSTs.
        #[test]
        #[ignore]
        fn test_minimize_proptest(mut fst in any::<VectorFst::<TropicalWeight>>()) {
            let det_config_before = DeterminizeConfig::default()
                .with_det_type(DeterminizeType::DeterminizeNonFunctional);
            let expected: VectorFst<_> = determinize_with_config(&fst, det_config_before).unwrap();

            let min_config = MinimizeConfig::default().with_allow_nondet(true);
            minimize_with_config(&mut fst, min_config).unwrap();

            let det_config_after = DeterminizeConfig::default()
                .with_det_type(DeterminizeType::DeterminizeNonFunctional);
            let actual: VectorFst<_> = determinize_with_config(&fst, det_config_after).unwrap();

            prop_assert!(isomorphic(&expected, &actual).unwrap())
        }
    }

    proptest! {
        // Symbol tables attached before minimization must still be attached afterwards.
        #[test]
        fn test_proptest_minimize_keeps_symts(mut fst in any::<VectorFst::<TropicalWeight>>()) {
            let shared_symt = Arc::new(SymbolTable::new());
            fst.set_input_symbols(Arc::clone(&shared_symt));
            fst.set_output_symbols(Arc::clone(&shared_symt));

            let min_config = MinimizeConfig::default().with_allow_nondet(true);
            minimize_with_config(&mut fst, min_config).unwrap();

            assert!(fst.input_symbols().is_some());
            assert!(fst.output_symbols().is_some());
        }
    }
}