rustfst/algorithms/minimize.rs

1use std::cmp::max;
2use std::cmp::Ordering;
3use std::collections::hash_map::Entry;
4use std::collections::HashMap;
5use std::collections::HashSet;
6use std::marker::PhantomData;
7
8use anyhow::Result;
9use binary_heap_plus::BinaryHeap;
10use stable_bst::TreeMap;
11
12use crate::algorithms::encode::EncodeType;
13use crate::algorithms::factor_weight::factor_iterators::GallicFactorLeft;
14use crate::algorithms::factor_weight::{factor_weight, FactorWeightOptions, FactorWeightType};
15use crate::algorithms::partition::Partition;
16use crate::algorithms::queues::LifoQueue;
17use crate::algorithms::tr_compares::ILabelCompare;
18use crate::algorithms::tr_mappers::QuantizeMapper;
19use crate::algorithms::tr_unique;
20use crate::algorithms::weight_converters::{FromGallicConverter, ToGallicConverter};
21use crate::algorithms::Queue;
22use crate::algorithms::{
23    connect,
24    encode::{decode, encode},
25    tr_map, tr_sort, weight_convert, ReweightType,
26};
27use crate::algorithms::{push_weights_with_config, reverse, PushWeightsConfig};
28use crate::fst_impls::VectorFst;
29use crate::fst_properties::FstProperties;
30use crate::fst_traits::{AllocableFst, CoreFst, ExpandedFst, Fst, MutableFst};
31use crate::semirings::{
32    GallicWeightLeft, Semiring, SemiringProperties, WeaklyDivisibleSemiring, WeightQuantize,
33};
34use crate::EPS_LABEL;
35use crate::KDELTA;
36use crate::{Label, StateId, Trs};
37use crate::{Tr, KSHORTESTDELTA};
38use itertools::Itertools;
39use std::cell::RefCell;
40use std::rc::Rc;
41
42/// Configuration for minimization.
43#[derive(Clone, Copy, PartialOrd, PartialEq)]
44pub struct MinimizeConfig {
45    pub delta: f32,
46    pub allow_nondet: bool,
47}
48
49impl MinimizeConfig {
50    pub fn new(delta: f32, allow_nondet: bool) -> Self {
51        Self {
52            delta,
53            allow_nondet,
54        }
55    }
56
57    pub fn with_delta(self, delta: f32) -> Self {
58        Self { delta, ..self }
59    }
60
61    pub fn with_allow_nondet(self, allow_nondet: bool) -> Self {
62        Self {
63            allow_nondet,
64            ..self
65        }
66    }
67}
68
69impl Default for MinimizeConfig {
70    fn default() -> Self {
71        Self {
72            delta: KSHORTESTDELTA,
73            allow_nondet: false,
74        }
75    }
76}
77
78/// In place minimization of deterministic weighted automata and transducers,
79/// and also non-deterministic ones if they use an idempotent semiring.
80/// For transducers, the algorithm produces a compact factorization of the minimal transducer.
81pub fn minimize<W, F>(ifst: &mut F) -> Result<()>
82where
83    F: MutableFst<W> + ExpandedFst<W> + AllocableFst<W>,
84    W: WeaklyDivisibleSemiring + WeightQuantize,
85    W::ReverseWeight: WeightQuantize,
86{
87    minimize_with_config(ifst, MinimizeConfig::default())
88}
89
90/// In place minimization of deterministic weighted automata and transducers,
91/// and also non-deterministic ones if they use an idempotent semiring.
92/// For transducers, the algorithm produces a compact factorization of the minimal transducer.
93pub fn minimize_with_config<W, F>(ifst: &mut F, config: MinimizeConfig) -> Result<()>
94where
95    F: MutableFst<W> + ExpandedFst<W> + AllocableFst<W>,
96    W: WeaklyDivisibleSemiring + WeightQuantize,
97    W::ReverseWeight: WeightQuantize,
98{
99    let delta = config.delta;
100    let allow_nondet = config.allow_nondet;
101
102    let props = ifst.compute_and_update_properties(
103        FstProperties::ACCEPTOR
104            | FstProperties::I_DETERMINISTIC
105            | FstProperties::WEIGHTED
106            | FstProperties::UNWEIGHTED,
107    )?;
108
109    let allow_acyclic_minimization = if props.contains(FstProperties::I_DETERMINISTIC) {
110        true
111    } else {
112        if !W::properties().contains(SemiringProperties::IDEMPOTENT) {
113            bail!("Cannot minimize a non-deterministic FST over a non-idempotent semiring")
114        } else if !allow_nondet {
115            bail!("Refusing to minimize a non-deterministic FST with allow_nondet = false")
116        }
117
118        false
119    };
120
121    if !props.contains(FstProperties::ACCEPTOR) {
122        // Weighted transducer
123        let mut to_gallic = ToGallicConverter {};
124        let mut gfst: VectorFst<GallicWeightLeft<W>> = weight_convert(ifst, &mut to_gallic)?;
125        let push_weights_config = PushWeightsConfig::default().with_delta(delta);
126        push_weights_with_config(
127            &mut gfst,
128            ReweightType::ReweightToInitial,
129            push_weights_config,
130        )?;
131
132        let quantize_mapper = QuantizeMapper::new(delta);
133        tr_map(&mut gfst, &quantize_mapper)?;
134
135        let encode_table = encode(&mut gfst, EncodeType::EncodeWeightsAndLabels)?;
136
137        acceptor_minimize(&mut gfst, allow_acyclic_minimization)?;
138
139        decode(&mut gfst, encode_table)?;
140
141        let factor_opts: FactorWeightOptions = FactorWeightOptions {
142            delta: KDELTA,
143            mode: FactorWeightType::FACTOR_FINAL_WEIGHTS | FactorWeightType::FACTOR_ARC_WEIGHTS,
144            final_ilabel: 0,
145            final_olabel: 0,
146            increment_final_ilabel: false,
147            increment_final_olabel: false,
148        };
149
150        let fwfst: VectorFst<_> =
151            factor_weight::<_, VectorFst<GallicWeightLeft<W>>, _, _, GallicFactorLeft<W>>(
152                &gfst,
153                factor_opts,
154            )?;
155
156        let mut from_gallic = FromGallicConverter {
157            superfinal_label: EPS_LABEL,
158        };
159        *ifst = weight_convert(&fwfst, &mut from_gallic)?;
160
161        Ok(())
162    } else if props.contains(FstProperties::WEIGHTED) {
163        // Weighted acceptor
164        let push_weights_config = PushWeightsConfig::default().with_delta(delta);
165        push_weights_with_config(ifst, ReweightType::ReweightToInitial, push_weights_config)?;
166        let quantize_mapper = QuantizeMapper::new(delta);
167        tr_map(ifst, &quantize_mapper)?;
168        let encode_table = encode(ifst, EncodeType::EncodeWeightsAndLabels)?;
169        acceptor_minimize(ifst, allow_acyclic_minimization)?;
170        decode(ifst, encode_table)
171    } else {
172        // Unweighted acceptor
173        acceptor_minimize(ifst, allow_acyclic_minimization)
174    }
175}
176
177/// In place minimization for weighted final state acceptor.
178/// If `allow_acyclic_minimization` is true and the input is acyclic, then a specific
179/// minimization is applied.
180///
181/// An error is returned if the input fst is not a weighted acceptor.
182pub fn acceptor_minimize<W: Semiring, F: MutableFst<W> + ExpandedFst<W>>(
183    ifst: &mut F,
184    allow_acyclic_minimization: bool,
185) -> Result<()> {
186    let props = ifst.compute_and_update_properties(
187        FstProperties::ACCEPTOR | FstProperties::UNWEIGHTED | FstProperties::ACYCLIC,
188    )?;
189    if !props.contains(FstProperties::ACCEPTOR | FstProperties::UNWEIGHTED) {
190        bail!("FST is not an unweighted acceptor");
191    }
192
193    connect(ifst)?;
194
195    if ifst.num_states() == 0 {
196        return Ok(());
197    }
198
199    if allow_acyclic_minimization && props.contains(FstProperties::ACYCLIC) {
200        // Acyclic minimization
201        tr_sort(ifst, ILabelCompare {});
202        let minimizer = AcyclicMinimizer::new(ifst)?;
203        merge_states(minimizer.get_partition(), ifst)?;
204    } else {
205        let p = cyclic_minimize(ifst)?;
206        merge_states(p, ifst)?;
207    }
208
209    tr_unique(ifst);
210
211    Ok(())
212}
213
214fn merge_states<W: Semiring, F: MutableFst<W>>(
215    partition: Rc<RefCell<Partition>>,
216    fst: &mut F,
217) -> Result<()> {
218    let mut state_map = vec![None; partition.borrow().num_classes()];
219
220    for (i, state_map_i) in state_map
221        .iter_mut()
222        .enumerate()
223        .take(partition.borrow().num_classes())
224    {
225        *state_map_i = partition.borrow().iter(i).next();
226    }
227
228    for c in 0..partition.borrow().num_classes() {
229        for s in partition.borrow().iter(c) {
230            if s == state_map[c].unwrap() {
231                let mut it_tr = fst.tr_iter_mut(s as StateId)?;
232                for idx_tr in 0..it_tr.len() {
233                    let tr = unsafe { it_tr.get_unchecked(idx_tr) };
234                    let nextstate =
235                        state_map[partition.borrow().get_class_id(tr.nextstate as usize)].unwrap();
236                    unsafe { it_tr.set_nextstate_unchecked(idx_tr, nextstate as StateId) };
237                }
238            } else {
239                for tr in fst
240                    .get_trs(s as StateId)?
241                    .trs()
242                    .iter()
243                    .cloned()
244                    .map(|mut tr| {
245                        tr.nextstate = state_map
246                            [partition.borrow().get_class_id(tr.nextstate as usize)]
247                        .unwrap() as StateId;
248                        tr
249                    })
250                {
251                    fst.add_tr(state_map[c].unwrap() as StateId, tr)?;
252                }
253            }
254        }
255    }
256
257    fst.set_start(
258        state_map[partition
259            .borrow()
260            .get_class_id(fst.start().unwrap() as usize)]
261        .unwrap() as StateId,
262    )?;
263
264    connect(fst)?;
265
266    Ok(())
267}
268
269// Compute the height (distance) to final state
270pub fn fst_depth<W: Semiring, F: Fst<W>>(
271    fst: &F,
272    state_id_cour: StateId,
273    accessible_states: &mut HashSet<StateId>,
274    fully_examined_states: &mut HashSet<StateId>,
275    heights: &mut Vec<i32>,
276) -> Result<()> {
277    accessible_states.insert(state_id_cour);
278
279    for _ in heights.len()..=(state_id_cour as usize) {
280        heights.push(-1);
281    }
282
283    let mut height_cur_state = 0;
284    for tr in fst.get_trs(state_id_cour)?.trs() {
285        let nextstate = tr.nextstate;
286
287        if !accessible_states.contains(&nextstate) {
288            fst_depth(
289                fst,
290                nextstate,
291                accessible_states,
292                fully_examined_states,
293                heights,
294            )?;
295        }
296
297        height_cur_state = max(height_cur_state, 1 + heights[nextstate as usize]);
298    }
299    fully_examined_states.insert(state_id_cour);
300
301    heights[state_id_cour as usize] = height_cur_state;
302
303    Ok(())
304}
305
/// Computes the partition of the states of a connected, acyclic, unweighted acceptor into
/// classes of equivalent states. `acceptor_minimize` uses it on its acyclic path.
///
/// States are first grouped by height (see `fst_depth`). Each height class is then split by
/// comparing states with `StateComparator`, processing heights in increasing order.
struct AcyclicMinimizer {
    // Shared with the `StateComparator` built in `refine`, which reads class ids while the
    // partition is being refined.
    partition: Rc<RefCell<Partition>>,
}

impl AcyclicMinimizer {
    /// Builds the partition for `fst`: groups states by height, then refines each group.
    ///
    /// # Errors
    /// Propagates errors from computing the heights of the states.
    ///
    /// # Panics
    /// Panics if `fst` has no start state, or if comparing two states fails during `refine`.
    pub fn new<W: Semiring, F: MutableFst<W>>(fst: &mut F) -> Result<Self> {
        let mut c = Self {
            partition: Rc::new(RefCell::new(Partition::empty_new())),
        };
        c.initialize(fst)?;
        c.refine(fst);
        Ok(c)
    }

    /// Creates one class per height value and puts every state into the class of its height.
    ///
    /// NOTE(review): assumes every state is reachable from the start state (the caller
    /// connects the FST beforehand). An unvisited state would keep height `-1`, which is not
    /// a valid class index.
    fn initialize<W: Semiring, F: MutableFst<W>>(&mut self, fst: &mut F) -> Result<()> {
        let mut accessible_state = HashSet::new();
        let mut fully_examined_states = HashSet::new();
        let mut heights = Vec::new();
        fst_depth(
            fst,
            fst.start().unwrap(),
            &mut accessible_state,
            &mut fully_examined_states,
            &mut heights,
        )?;
        self.partition.borrow_mut().initialize(heights.len());
        // Class ids 0..=max_height, one per height value.
        self.partition
            .borrow_mut()
            .allocate_classes((heights.iter().max().unwrap() + 1) as usize);
        for (s, h) in heights.iter().enumerate() {
            self.partition.borrow_mut().add(s, *h as usize);
        }
        Ok(())
    }

    /// Splits every height class into classes of equivalent states.
    ///
    /// Heights are processed in increasing order. Every successor of a state of height `h`
    /// has a strictly smaller height (see `fst_depth`), so the destination classes that
    /// `StateComparator` inspects are no longer modified when class `h` is split.
    fn refine<W: Semiring, F: MutableFst<W>>(&mut self, fst: &mut F) {
        let state_cmp = StateComparator {
            fst,
            partition: Rc::clone(&self.partition),
            w: PhantomData,
        };

        // Number of height classes; classes added by the loop below are not revisited.
        let height = self.partition.borrow().num_classes();
        for h in 0..height {
            // A binary search tree is needed here to order the state ids and build the
            // partition. States are the keys, but they are compared with `state_cmp`, so
            // states with identical signatures share an entry. The `stable_bst` crate is
            // quite old but does the job.
            // TODO: Benchmark this implementation; maybe rewrite it.
            let mut equiv_classes =
                TreeMap::<StateId, StateId, _>::with_comparator(|a: &StateId, b: &StateId| {
                    state_cmp.compare(*a, *b).unwrap()
                });

            // The first state of the class keeps the existing class id `h`.
            let it_partition: Vec<_> = self.partition.borrow().iter(h).collect();
            equiv_classes.insert(it_partition[0] as StateId, h as StateId);

            // A state whose signature has no entry yet gets a freshly allocated class;
            // otherwise it reuses the class of the matching entry.
            for e in it_partition.iter().skip(1) {
                equiv_classes.get_or_insert(*e as StateId, || {
                    self.partition.borrow_mut().add_class() as StateId
                });
            }

            // Move each state whose assigned class differs from its current one.
            for s in it_partition {
                let old_class = self.partition.borrow().get_class_id(s);
                let new_class = *equiv_classes.get(&(s as StateId)).unwrap();

                if old_class != (new_class as usize) {
                    self.partition
                        .borrow_mut()
                        .move_element(s, new_class as usize);
                }
            }
        }
    }

    /// Consumes the minimizer and returns the computed partition.
    pub fn get_partition(self) -> Rc<RefCell<Partition>> {
        self.partition
    }
}
384
385struct StateComparator<'a, W: Semiring, F: MutableFst<W>> {
386    fst: &'a F,
387    partition: Rc<RefCell<Partition>>,
388    w: PhantomData<W>,
389}
390
391impl<'a, W: Semiring, F: MutableFst<W>> StateComparator<'a, W, F> {
392    fn do_compare(&self, x: StateId, y: StateId) -> Result<bool> {
393        let xfinal = self.fst.final_weight(x)?.unwrap_or_else(W::zero);
394        let yfinal = self.fst.final_weight(y)?.unwrap_or_else(W::zero);
395
396        if xfinal < yfinal {
397            return Ok(true);
398        } else if xfinal > yfinal {
399            return Ok(false);
400        }
401
402        if self.fst.num_trs(x)? < self.fst.num_trs(y)? {
403            return Ok(true);
404        }
405        if self.fst.num_trs(x)? > self.fst.num_trs(y)? {
406            return Ok(false);
407        }
408
409        let it_x_owner = self.fst.get_trs(x)?;
410        let it_x = it_x_owner.trs().iter();
411        let it_y_owner = self.fst.get_trs(y)?;
412        let it_y = it_y_owner.trs().iter();
413
414        for (arc1, arc2) in it_x.zip(it_y) {
415            if arc1.ilabel < arc2.ilabel {
416                return Ok(true);
417            }
418            if arc1.ilabel > arc2.ilabel {
419                return Ok(false);
420            }
421            let id_1 = self
422                .partition
423                .borrow()
424                .get_class_id(arc1.nextstate as usize);
425            let id_2 = self
426                .partition
427                .borrow()
428                .get_class_id(arc2.nextstate as usize);
429            if id_1 < id_2 {
430                return Ok(true);
431            }
432            if id_1 > id_2 {
433                return Ok(false);
434            }
435        }
436        Ok(false)
437    }
438
439    pub fn compare(&self, x: StateId, y: StateId) -> Result<Ordering> {
440        if x == y {
441            return Ok(Ordering::Equal);
442        }
443
444        let x_y = self.do_compare(x, y).unwrap();
445        let y_x = self.do_compare(y, x).unwrap();
446
447        if !(x_y) && !(y_x) {
448            return Ok(Ordering::Equal);
449        }
450
451        if x_y {
452            Ok(Ordering::Less)
453        } else {
454            Ok(Ordering::Greater)
455        }
456    }
457}
458
459fn pre_partition<W: Semiring, F: MutableFst<W>>(
460    fst: &F,
461    partition: &Rc<RefCell<Partition>>,
462    queue: &mut LifoQueue,
463) {
464    let mut next_class: StateId = 0;
465    let num_states = fst.num_states();
466
467    let mut state_to_initial_class: Vec<StateId> = vec![0; num_states];
468    {
469        let mut hash_to_class_nonfinal = HashMap::<Vec<Label>, StateId>::new();
470        let mut hash_to_class_final = HashMap::<Vec<Label>, StateId>::new();
471
472        for (s, state_to_initial_class_s) in state_to_initial_class
473            .iter_mut()
474            .enumerate()
475            .take(num_states)
476        {
477            let this_map = if unsafe { fst.is_final_unchecked(s as StateId) } {
478                &mut hash_to_class_final
479            } else {
480                &mut hash_to_class_nonfinal
481            };
482
483            let ilabels = fst
484                .get_trs(s as StateId)
485                .unwrap()
486                .trs()
487                .iter()
488                .map(|e| e.ilabel)
489                .dedup()
490                .collect_vec();
491
492            match this_map.entry(ilabels) {
493                Entry::Occupied(e) => {
494                    *state_to_initial_class_s = *e.get();
495                }
496                Entry::Vacant(e) => {
497                    e.insert(next_class);
498                    *state_to_initial_class_s = next_class;
499                    next_class += 1;
500                }
501            };
502        }
503    }
504
505    partition.borrow_mut().allocate_classes(next_class as usize);
506    for (s, c) in state_to_initial_class.iter().enumerate().take(num_states) {
507        partition.borrow_mut().add(s, *c as usize);
508    }
509
510    for c in 0..next_class {
511        queue.enqueue(c);
512    }
513}
514
515fn cyclic_minimize<W: Semiring, F: MutableFst<W>>(fst: &mut F) -> Result<Rc<RefCell<Partition>>> {
516    // Initialize
517    let mut tr: VectorFst<W::ReverseWeight> = reverse(fst)?;
518    tr_sort(&mut tr, ILabelCompare {});
519
520    let partition = Rc::new(RefCell::new(Partition::new(tr.num_states() - 1)));
521    let mut queue = LifoQueue::default();
522    pre_partition(fst, &partition, &mut queue);
523
524    let comp = TrIterCompare {};
525
526    let mut aiter_queue = BinaryHeap::new_by(|v1, v2| {
527        if comp.compare(v1, v2) {
528            Ordering::Less
529        } else {
530            Ordering::Greater
531        }
532    });
533
534    // Compute
535    while let Some(c) = queue.dequeue() {
536        // Split
537        for s in partition.borrow().iter(c as usize) {
538            if tr.num_trs(s as StateId + 1)? > 0 {
539                aiter_queue.push(TrsIterCollected {
540                    idx: 0,
541                    trs: tr.get_trs(s as StateId + 1)?,
542                    w: PhantomData,
543                });
544            }
545        }
546
547        let mut prev_label = -1;
548        while !aiter_queue.is_empty() {
549            let mut aiter = aiter_queue.pop().unwrap();
550            if aiter.done() {
551                continue;
552            }
553            let tr = aiter.value().unwrap();
554            let from_state = tr.nextstate - 1;
555            let from_label = tr.ilabel;
556            if prev_label != from_label as i32 {
557                partition.borrow_mut().finalize_split(&mut Some(&mut queue));
558            }
559            let from_class = partition.borrow().get_class_id(from_state as usize);
560            if partition.borrow().get_class_size(from_class) > 1 {
561                partition.borrow_mut().split_on(from_state as usize);
562            }
563            prev_label = from_label as i32;
564            aiter.next();
565            if !aiter.done() {
566                aiter_queue.push(aiter);
567            }
568        }
569
570        partition.borrow_mut().finalize_split(&mut Some(&mut queue));
571    }
572
573    // Get Partition
574    Ok(partition)
575}
576
577struct TrsIterCollected<W: Semiring, T: Trs<W>> {
578    idx: usize,
579    trs: T,
580    w: PhantomData<W>,
581}
582
583impl<W: Semiring, T: Trs<W>> TrsIterCollected<W, T> {
584    fn value(&self) -> Option<&Tr<W>> {
585        self.trs.trs().get(self.idx)
586    }
587
588    fn done(&self) -> bool {
589        self.idx >= self.trs.len()
590    }
591
592    fn next(&mut self) {
593        self.idx += 1;
594    }
595}
596
597#[derive(Debug, Clone)]
598struct TrIterCompare {}
599
600impl TrIterCompare {
601    fn compare<W: Semiring, T: Trs<W>>(
602        &self,
603        x: &TrsIterCollected<W, T>,
604        y: &TrsIterCollected<W, T>,
605    ) -> bool {
606        let xarc = x.value().unwrap();
607        let yarc = y.value().unwrap();
608        xarc.ilabel > yarc.ilabel
609    }
610}
611
// Tests for `minimize` / `minimize_with_config`.
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use ::proptest::prelude::*;
    use algorithms::determinize::*;
    use std::sync::Arc;

    // Regression test named after issue 158: minimizing this FST must not change whether
    // `path` is accepted.
    #[test]
    fn test_minimize_issue_158() {
        let text_fst = r#"0	5	101	101	0
0	4	100	100	0
0	3	99	99	0
0	2	98	98	0
0	1	97	97	0
1	10	101	101	0
1	9	100	100	0
1	8	99	99	0
1	7	98	98	0
1	6	97	97	0
2	11	101	101	0
2	10	100	100	0
2	9	99	99	0
2	8	98	98	0
2	7	97	97	0
3	11	100	100	0
3	10	99	99	0
3	9	98	98	0
3	8	97	97	0
4	11	99	99	0
4	10	98	98	0
4	9	97	97	0
5	11	98	98	0
5	10	97	97	0
6	15	101	101	0
6	14	100	100	0
6	13	99	99	0
6	12	98	98	0
7	16	101	101	0
7	15	100	100	0
7	14	99	99	0
7	13	98	98	0
7	12	97	97	0
8	16	100	100	0
8	15	99	99	0
8	14	98	98	0
8	13	97	97	0
9	16	99	99	0
9	15	98	98	0
9	14	97	97	0
10	16	98	98	0
10	15	97	97	0
11	16	97	97	0
12	17	101	101	0
13	17	100	100	0
14	17	99	99	0
15	17	98	98	0
16	17	97	97	0
17	18	32	32	0
18	0
        "#;
        let path = fst_path![97, 98, 97, 100, 32];
        let mut fst: VectorFst<TropicalWeight> = VectorFst::from_text_string(text_fst).unwrap();
        let accept1 = check_path_in_fst(&fst, &path);
        minimize(&mut fst).unwrap();
        let accept2 = check_path_in_fst(&fst, &path);

        assert_eq!(accept1, accept2);
    }

    // Minimizing arbitrary FSTs (non-deterministic inputs allowed) must complete without
    // error; as the name suggests, this guards against minimization hanging.
    proptest! {
        #[test]
        fn test_proptest_minimize_timeout(mut fst in any::<VectorFst::<TropicalWeight>>()) {
            let config = MinimizeConfig::default().with_allow_nondet(true);
            minimize_with_config(&mut fst, config).unwrap();
        }
    }

    // Minimization must preserve the language: determinizing the FST before and after
    // minimizing it should give isomorphic results. Currently ignored (see attribute comment).
    proptest! {
        #[test]
        #[ignore] // falls into the same infinite loop as the timeout test
        fn test_minimize_proptest(mut fst in any::<VectorFst::<TropicalWeight>>()) {
            let det:VectorFst<_> = determinize_with_config(&fst, DeterminizeConfig::default().with_det_type(DeterminizeType::DeterminizeNonFunctional)).unwrap();
            let min_config = MinimizeConfig::default().with_allow_nondet(true);
            minimize_with_config(&mut fst, min_config).unwrap();
            let det_config = DeterminizeConfig::default().with_det_type(DeterminizeType::DeterminizeNonFunctional);
            let min_det:VectorFst<_> = determinize_with_config(&fst, det_config).unwrap();
            prop_assert!(isomorphic(&det, &min_det).unwrap())
        }
    }

    // Minimization must keep the input and output symbol tables attached to the FST.
    proptest! {
        #[test]
        fn test_proptest_minimize_keeps_symts(mut fst in any::<VectorFst::<TropicalWeight>>()) {
            let symt = Arc::new(SymbolTable::new());
            fst.set_input_symbols(Arc::clone(&symt));
            fst.set_output_symbols(Arc::clone(&symt));

            minimize_with_config(&mut fst, MinimizeConfig::default().with_allow_nondet(true)).unwrap();

            assert!(fst.input_symbols().is_some());
            assert!(fst.output_symbols().is_some());
        }
    }
}