// rustfst/algorithms/compose/matchers/multi_eps_matcher.rs

1use std::borrow::Borrow;
2use std::fmt::Debug;
3use std::iter::Peekable;
4use std::marker::PhantomData;
5use std::sync::Arc;
6
7use anyhow::Result;
8use itertools::Itertools;
9use nom::lib::std::collections::BTreeSet;
10
11use bitflags::bitflags;
12
13use crate::algorithms::compose::matchers::{IterItemMatcher, MatchType, Matcher, MatcherFlags};
14use crate::fst_traits::Fst;
15use crate::semirings::Semiring;
16use crate::{Label, StateId, EPS_LABEL, NO_LABEL};
17
18bitflags! {
19    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
20    pub struct MultiEpsMatcherFlags: u32 {
21        const MULTI_EPS_LOOP =  1u32;
22        const MULTI_EPS_LIST =  2u32;
23    }
24}
25
26#[derive(Clone, Debug)]
27pub struct MultiEpsMatcher<W, F, B, M>
28where
29    W: Semiring,
30    F: Fst<W>,
31    B: Borrow<F> + Debug,
32    M: Matcher<W, F, B>,
33{
34    matcher: Arc<M>,
35    flags: MultiEpsMatcherFlags,
36    multi_eps_labels: CompactSet<Label>,
37    ghost: PhantomData<(W, F, B)>,
38}
39
40pub struct IteratorMultiEpsMatcher<W, F, B, M>
41where
42    W: Semiring,
43    F: Fst<W>,
44    B: Borrow<F> + Debug,
45    M: Matcher<W, F, B>,
46{
47    iter_matcher: Option<Peekable<M::Iter>>,
48    iter_labels: Option<(Vec<Label>, usize)>,
49    matcher: Arc<M>,
50    matcher_state: StateId,
51    done: bool,
52    ghost: PhantomData<W>,
53}
54
55impl<W, F, B, M> Clone for IteratorMultiEpsMatcher<W, F, B, M>
56where
57    W: Semiring,
58    F: Fst<W>,
59    B: Borrow<F> + Debug,
60    M: Matcher<W, F, B>,
61{
62    fn clone(&self) -> Self {
63        unimplemented!()
64        // Self {
65        //     iter_matcher: self.iter_matcher.clone(),
66        //     iter_labels: self.iter_labels.clone(),
67        //     matcher: Arc::clone(&self.matcher),
68        //     ghost: PhantomData,
69        //     done: self.done,
70        //     matcher_state: self.matcher_state,
71        // }
72    }
73}
74
75impl<W, F, B, M> Iterator for IteratorMultiEpsMatcher<W, F, B, M>
76where
77    W: Semiring,
78    F: Fst<W>,
79    B: Borrow<F> + Debug,
80    M: Matcher<W, F, B>,
81{
82    type Item = IterItemMatcher<W>;
83
84    fn next(&mut self) -> Option<Self::Item> {
85        if let Some(ref mut matcher_iter) = &mut self.iter_matcher {
86            let res = matcher_iter.next();
87            let done = res.is_none();
88            if done {
89                if let Some((multi_eps_labels, pos_labels)) = &mut self.iter_labels {
90                    if *pos_labels >= multi_eps_labels.len() {
91                        return None;
92                    }
93                    *pos_labels += 1;
94                    while *pos_labels < multi_eps_labels.len() {
95                        let mut it = self
96                            .matcher
97                            .iter(self.matcher_state, multi_eps_labels[*pos_labels] as Label)
98                            .unwrap()
99                            .peekable();
100                        if it.peek().is_some() {
101                            *matcher_iter = it;
102                            break;
103                        }
104                        *pos_labels += 1;
105                    }
106                    if *pos_labels < multi_eps_labels.len() {
107                        self.done = false;
108                        res
109                    } else {
110                        *matcher_iter = self
111                            .matcher
112                            .iter(self.matcher_state, NO_LABEL)
113                            .unwrap()
114                            .peekable();
115                        matcher_iter.next()
116                    }
117                } else {
118                    res
119                }
120            } else {
121                res
122            }
123        } else if self.done {
124            None
125        } else {
126            self.done = true;
127            Some(IterItemMatcher::EpsLoop)
128        }
129    }
130}
131
132impl<W, F, B, M> MultiEpsMatcher<W, F, B, M>
133where
134    W: Semiring,
135    F: Fst<W>,
136    B: Borrow<F> + Debug,
137    M: Matcher<W, F, B>,
138{
139    pub fn new_with_opts<IM: Into<Option<Arc<M>>>>(
140        fst: B,
141        match_type: MatchType,
142        flags: MultiEpsMatcherFlags,
143        matcher: IM,
144    ) -> Result<Self> {
145        let matcher = matcher
146            .into()
147            .unwrap_or_else(|| Arc::new(M::new(fst, match_type).unwrap()));
148        Ok(Self {
149            matcher,
150            flags,
151            multi_eps_labels: CompactSet::new(NO_LABEL),
152            ghost: PhantomData,
153        })
154    }
155
156    pub fn matcher(&self) -> &Arc<M> {
157        &self.matcher
158    }
159
160    pub fn clear_multi_eps_labels(&mut self) {
161        self.multi_eps_labels.clear()
162    }
163
164    pub fn add_multi_eps_label(&mut self, label: Label) -> Result<()> {
165        if label == EPS_LABEL {
166            bail!("MultiEpsMatcher: Bad multi-eps label: 0")
167        }
168        self.multi_eps_labels.insert(label);
169        Ok(())
170    }
171
172    pub fn remove_multi_eps_label(&mut self, label: Label) -> Result<()> {
173        if label == EPS_LABEL {
174            bail!("MultiEpsMatcher: Bad multi-eps label: 0")
175        }
176        self.multi_eps_labels.erase(label);
177        Ok(())
178    }
179}
180
181impl<W, F, B, M> Matcher<W, F, B> for MultiEpsMatcher<W, F, B, M>
182where
183    W: Semiring,
184    F: Fst<W>,
185    B: Borrow<F> + Debug,
186    M: Matcher<W, F, B>,
187{
188    type Iter = IteratorMultiEpsMatcher<W, F, B, M>;
189
190    fn new(fst: B, match_type: MatchType) -> Result<Self> {
191        Self::new_with_opts(
192            fst,
193            match_type,
194            MultiEpsMatcherFlags::MULTI_EPS_LOOP | MultiEpsMatcherFlags::MULTI_EPS_LIST,
195            None,
196        )
197    }
198
199    fn iter(&self, state: StateId, label: Label) -> Result<Self::Iter> {
200        let (iter_matcher, iter_labels) = if label == EPS_LABEL {
201            (Some(self.matcher.iter(state, EPS_LABEL)?.peekable()), None)
202        } else if label == NO_LABEL {
203            if self.flags.contains(MultiEpsMatcherFlags::MULTI_EPS_LIST) {
204                // TODO: Didn't find a way to store the iterator in IteratorMultiEpsMatcher.
205                let multi_eps_labels = self.multi_eps_labels.iter().cloned().collect_vec();
206
207                let mut iter_matcher = None;
208                let mut pos_labels = 0;
209                while pos_labels < multi_eps_labels.len() {
210                    let mut it = self
211                        .matcher
212                        .iter(state, multi_eps_labels[pos_labels])?
213                        .peekable();
214                    if it.peek().is_some() {
215                        iter_matcher = Some(it);
216                        break;
217                    }
218                    pos_labels += 1;
219                }
220
221                if pos_labels < multi_eps_labels.len() {
222                    (iter_matcher, Some((multi_eps_labels, pos_labels)))
223                } else {
224                    (Some(self.matcher.iter(state, NO_LABEL)?.peekable()), None)
225                }
226            } else {
227                (Some(self.matcher.iter(state, NO_LABEL)?.peekable()), None)
228            }
229        } else if self.flags.contains(MultiEpsMatcherFlags::MULTI_EPS_LOOP)
230            && self.multi_eps_labels.contains(&label)
231        {
232            // Empty iter
233            (None, None)
234        } else {
235            (Some(self.matcher.iter(state, label)?.peekable()), None)
236        };
237        Ok(IteratorMultiEpsMatcher {
238            iter_matcher,
239            iter_labels,
240            matcher: Arc::clone(&self.matcher),
241            ghost: PhantomData,
242            done: false,
243            matcher_state: state,
244        })
245    }
246
247    fn final_weight(&self, state: StateId) -> Result<Option<W>> {
248        self.matcher.final_weight(state)
249    }
250
251    fn match_type(&self, test: bool) -> Result<MatchType> {
252        self.matcher.match_type(test)
253    }
254
255    fn flags(&self) -> MatcherFlags {
256        self.matcher.flags()
257    }
258
259    fn priority(&self, state: StateId) -> Result<usize> {
260        self.matcher.priority(state)
261    }
262
263    fn fst(&self) -> &B {
264        self.matcher.fst()
265    }
266}
267
/// Integer-like key usable in a `CompactSet`: it can be shifted up or down
/// by a `usize` amount.
trait CompactSetKey: Copy + Ord {
    /// Returns `self + v`.
    fn add(self, v: usize) -> Self;
    /// Returns `self - v`.
    fn sub(self, v: usize) -> Self;
}

/// Implements `CompactSetKey` for primitive integer types using plain `+` /
/// `-`, converting the `usize` offset with an `as` cast.
macro_rules! impl_compact_set_key {
    ($($int:ty),+ $(,)?) => {
        $(
            impl CompactSetKey for $int {
                fn add(self, v: usize) -> Self {
                    self + v as $int
                }
                fn sub(self, v: usize) -> Self {
                    self - v as $int
                }
            }
        )+
    };
}

impl_compact_set_key!(usize, u32);
290
/// Ordered set that additionally tracks a lower and an upper bound of its
/// keys so that range checks can be answered without touching the tree.
#[derive(Clone, Debug)]
struct CompactSet<K> {
    /// The stored keys.
    set: BTreeSet<K>,
    /// Lower bound of the stored keys, or `no_key` while the set is empty.
    min_key: K,
    /// Upper bound of the stored keys, or `no_key` while the set is empty.
    max_key: K,
    /// Sentinel meaning "no bound recorded".
    no_key: K,
}

impl<K: Copy + Ord> CompactSet<K> {
    /// Creates an empty set that uses `no_key` as its "no bound" sentinel.
    pub fn new(no_key: K) -> Self {
        Self {
            set: BTreeSet::new(),
            min_key: no_key,
            max_key: no_key,
            no_key,
        }
    }

    /// Adds `key` and widens the recorded bounds so they include it.
    pub fn insert(&mut self, key: K) {
        self.set.insert(key);
        if self.min_key == self.no_key || key < self.min_key {
            self.min_key = key;
        }
        // The upper bound must grow when a larger key arrives; comparing the
        // other way round would make it follow the smallest key instead.
        if self.max_key == self.no_key || self.max_key < key {
            self.max_key = key;
        }
    }

    /// Removes every key and resets both bounds to the sentinel.
    pub fn clear(&mut self) {
        self.set.clear();
        self.min_key = self.no_key;
        self.max_key = self.no_key;
    }

    /// Iterates over the stored keys in ascending order.
    pub fn iter(&self) -> std::collections::btree_set::Iter<'_, K> {
        self.set.iter()
    }

    /// Returns the recorded lower bound (`no_key` when empty).
    #[allow(unused)]
    pub fn lower_bound(&self) -> K {
        self.min_key
    }

    /// Returns the recorded upper bound (`no_key` when empty).
    #[allow(unused)]
    pub fn upper_bound(&self) -> K {
        self.max_key
    }
}
339
340impl<K: CompactSetKey> CompactSet<K> {
341    pub fn erase(&mut self, key: K) {
342        self.set.remove(&key);
343        if self.set.is_empty() {
344            self.min_key = self.no_key;
345            self.max_key = self.no_key;
346        } else if key == self.min_key {
347            self.min_key = self.min_key.add(1);
348        } else if key == self.max_key {
349            self.max_key = self.max_key.sub(1);
350        }
351    }
352
353    pub fn contains(&self, key: &K) -> bool {
354        if self.min_key == self.no_key || *key < self.min_key || *key > self.max_key {
355            // out of range
356            false
357        } else if self.min_key != self.no_key
358            && self.max_key.add(1) as K == self.min_key.add(self.set.len())
359        {
360            // dense range
361            true
362        } else {
363            self.set.contains(key)
364        }
365    }
366}