rustfst/algorithms/queues/auto_queue.rs

1use anyhow::Result;
2
3use crate::algorithms::dfs_visit::dfs_visit;
4use crate::algorithms::tr_filters::TrFilter;
5use crate::algorithms::visitors::SccVisitor;
6use crate::algorithms::{Queue, QueueType};
7use crate::fst_properties::FstProperties;
8use crate::fst_traits::ExpandedFst;
9use crate::semirings::{Semiring, SemiringProperties};
10
11use super::{
12    natural_less, FifoQueue, LifoQueue, NaturalShortestFirstQueue, SccQueue, StateOrderQueue,
13    TopOrderQueue, TrivialQueue,
14};
15use crate::{StateId, Trs};
16
17#[derive(Debug)]
18pub struct AutoQueue {
19    queue: Box<dyn Queue>,
20}
21
22impl AutoQueue {
23    pub fn new<W: Semiring, F: ExpandedFst<W>, A: TrFilter<W>>(
24        fst: &F,
25        distance: Option<&Vec<W>>,
26        tr_filter: &A,
27    ) -> Result<Self> {
28        let props = fst.properties();
29
30        let queue: Box<dyn Queue>;
31
32        if props.contains(FstProperties::TOP_SORTED) || fst.start().is_none() {
33            queue = Box::<StateOrderQueue>::default();
34        } else if props.contains(FstProperties::ACYCLIC) {
35            queue = Box::new(TopOrderQueue::new(fst, tr_filter));
36        } else if props.contains(FstProperties::UNWEIGHTED)
37            && W::properties().contains(SemiringProperties::IDEMPOTENT)
38        {
39            queue = Box::<LifoQueue>::default();
40        } else {
41            let mut scc_visitor = SccVisitor::new(fst, true, false);
42            dfs_visit(fst, &mut scc_visitor, tr_filter, false);
43            let sccs: Vec<_> = scc_visitor
44                .scc
45                .unwrap()
46                .into_iter()
47                .map(|v| v as StateId)
48                .collect();
49            let n_sccs = scc_visitor.nscc as usize;
50
51            let mut queue_types = vec![QueueType::TrivialQueue; n_sccs];
52            let less = if distance.is_some()
53                && !distance.unwrap().is_empty()
54                && W::properties().contains(SemiringProperties::PATH)
55            {
56                Some(natural_less)
57            } else {
58                None
59            };
60
61            // Finds the queue type to use per SCC.
62            let mut unweighted = false;
63            let mut all_trivial = false;
64            Self::scc_queue_type(
65                fst,
66                &sccs,
67                less,
68                &mut queue_types,
69                &mut all_trivial,
70                &mut unweighted,
71                tr_filter,
72            )?;
73
74            if unweighted {
75                // If unweighted and semiring is idempotent, uses LIFO queue.
76                queue = Box::<LifoQueue>::default();
77            } else if all_trivial {
78                // If all the SCC are trivial, the FST is acyclic and the scc number gives
79                // the topological order.
80                queue = Box::new(TopOrderQueue::from_precomputed_order(sccs));
81            } else {
82                // AutoQueue: using SCC meta-discipline
83                let mut queues: Vec<Box<dyn Queue>> = Vec::with_capacity(n_sccs);
84                for queue_type in queue_types.iter().take(n_sccs) {
85                    match queue_type {
86                        QueueType::TrivialQueue => queues.push(Box::<TrivialQueue>::default()),
87                        QueueType::ShortestFirstQueue => queues.push(Box::new(
88                            NaturalShortestFirstQueue::new(distance.unwrap().clone()),
89                        )),
90                        QueueType::LifoQueue => queues.push(Box::<LifoQueue>::default()),
91                        _ => queues.push(Box::<FifoQueue>::default()),
92                    }
93                }
94                queue = Box::new(SccQueue::new(queues, sccs));
95            }
96        }
97
98        Ok(Self { queue })
99    }
100
101    pub fn scc_queue_type<
102        W: Semiring,
103        F: ExpandedFst<W>,
104        C: Fn(&W, &W) -> Result<bool>,
105        A: TrFilter<W>,
106    >(
107        fst: &F,
108        sccs: &[StateId],
109        compare: Option<C>,
110        queue_types: &mut [QueueType],
111        all_trivial: &mut bool,
112        unweighted: &mut bool,
113        tr_filter: &A,
114    ) -> Result<()> {
115        *all_trivial = true;
116        *unweighted = true;
117
118        queue_types
119            .iter_mut()
120            .for_each(|v| *v = QueueType::TrivialQueue);
121
122        for state in 0..(fst.num_states() as StateId) {
123            for tr in unsafe { fst.get_trs_unchecked(state).trs() } {
124                if !tr_filter.keep(tr) {
125                    continue;
126                }
127                if sccs[state as usize] == sccs[tr.nextstate as usize] {
128                    let queue_type =
129                        unsafe { queue_types.get_unchecked_mut(sccs[state as usize] as usize) };
130                    if compare.is_none() || compare.as_ref().unwrap()(&tr.weight, &W::one())? {
131                        *queue_type = QueueType::FifoQueue;
132                    } else if *queue_type == QueueType::TrivialQueue
133                        || *queue_type == QueueType::LifoQueue
134                    {
135                        if !W::properties().contains(SemiringProperties::IDEMPOTENT)
136                            || (!tr.weight.is_zero() && !tr.weight.is_one())
137                        {
138                            *queue_type = QueueType::ShortestFirstQueue;
139                        } else {
140                            *queue_type = QueueType::LifoQueue;
141                        }
142                    }
143
144                    if *queue_type != QueueType::TrivialQueue {
145                        *all_trivial = false;
146                    }
147                }
148
149                if !W::properties().contains(SemiringProperties::IDEMPOTENT)
150                    || (!tr.weight.is_zero() && !tr.weight.is_one())
151                {
152                    *unweighted = false;
153                }
154            }
155        }
156        Ok(())
157    }
158}
159
160impl Queue for AutoQueue {
161    fn head(&mut self) -> Option<StateId> {
162        self.queue.head()
163    }
164
165    fn enqueue(&mut self, state: StateId) {
166        self.queue.enqueue(state)
167    }
168
169    fn dequeue(&mut self) -> Option<StateId> {
170        self.queue.dequeue()
171    }
172
173    fn update(&mut self, state: StateId) {
174        self.queue.update(state)
175    }
176
177    fn is_empty(&self) -> bool {
178        self.queue.is_empty()
179    }
180
181    fn clear(&mut self) {
182        self.queue.clear()
183    }
184
185    fn queue_type(&self) -> QueueType {
186        QueueType::AutoQueue
187    }
188}