rustfst/algorithms/compose/matchers/
multi_eps_matcher.rs1use std::borrow::Borrow;
2use std::fmt::Debug;
3use std::iter::Peekable;
4use std::marker::PhantomData;
5use std::sync::Arc;
6
7use anyhow::Result;
8use itertools::Itertools;
9use nom::lib::std::collections::BTreeSet;
10
11use bitflags::bitflags;
12
13use crate::algorithms::compose::matchers::{IterItemMatcher, MatchType, Matcher, MatcherFlags};
14use crate::fst_traits::Fst;
15use crate::semirings::Semiring;
16use crate::{Label, StateId, EPS_LABEL, NO_LABEL};
17
18bitflags! {
19 #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
20 pub struct MultiEpsMatcherFlags: u32 {
21 const MULTI_EPS_LOOP = 1u32;
22 const MULTI_EPS_LIST = 2u32;
23 }
24}
25
26#[derive(Clone, Debug)]
27pub struct MultiEpsMatcher<W, F, B, M>
28where
29 W: Semiring,
30 F: Fst<W>,
31 B: Borrow<F> + Debug,
32 M: Matcher<W, F, B>,
33{
34 matcher: Arc<M>,
35 flags: MultiEpsMatcherFlags,
36 multi_eps_labels: CompactSet<Label>,
37 ghost: PhantomData<(W, F, B)>,
38}
39
40pub struct IteratorMultiEpsMatcher<W, F, B, M>
41where
42 W: Semiring,
43 F: Fst<W>,
44 B: Borrow<F> + Debug,
45 M: Matcher<W, F, B>,
46{
47 iter_matcher: Option<Peekable<M::Iter>>,
48 iter_labels: Option<(Vec<Label>, usize)>,
49 matcher: Arc<M>,
50 matcher_state: StateId,
51 done: bool,
52 ghost: PhantomData<W>,
53}
54
55impl<W, F, B, M> Clone for IteratorMultiEpsMatcher<W, F, B, M>
56where
57 W: Semiring,
58 F: Fst<W>,
59 B: Borrow<F> + Debug,
60 M: Matcher<W, F, B>,
61{
62 fn clone(&self) -> Self {
63 unimplemented!()
64 }
73}
74
75impl<W, F, B, M> Iterator for IteratorMultiEpsMatcher<W, F, B, M>
76where
77 W: Semiring,
78 F: Fst<W>,
79 B: Borrow<F> + Debug,
80 M: Matcher<W, F, B>,
81{
82 type Item = IterItemMatcher<W>;
83
84 fn next(&mut self) -> Option<Self::Item> {
85 if let Some(ref mut matcher_iter) = &mut self.iter_matcher {
86 let res = matcher_iter.next();
87 let done = res.is_none();
88 if done {
89 if let Some((multi_eps_labels, pos_labels)) = &mut self.iter_labels {
90 if *pos_labels >= multi_eps_labels.len() {
91 return None;
92 }
93 *pos_labels += 1;
94 while *pos_labels < multi_eps_labels.len() {
95 let mut it = self
96 .matcher
97 .iter(self.matcher_state, multi_eps_labels[*pos_labels] as Label)
98 .unwrap()
99 .peekable();
100 if it.peek().is_some() {
101 *matcher_iter = it;
102 break;
103 }
104 *pos_labels += 1;
105 }
106 if *pos_labels < multi_eps_labels.len() {
107 self.done = false;
108 res
109 } else {
110 *matcher_iter = self
111 .matcher
112 .iter(self.matcher_state, NO_LABEL)
113 .unwrap()
114 .peekable();
115 matcher_iter.next()
116 }
117 } else {
118 res
119 }
120 } else {
121 res
122 }
123 } else if self.done {
124 None
125 } else {
126 self.done = true;
127 Some(IterItemMatcher::EpsLoop)
128 }
129 }
130}
131
132impl<W, F, B, M> MultiEpsMatcher<W, F, B, M>
133where
134 W: Semiring,
135 F: Fst<W>,
136 B: Borrow<F> + Debug,
137 M: Matcher<W, F, B>,
138{
139 pub fn new_with_opts<IM: Into<Option<Arc<M>>>>(
140 fst: B,
141 match_type: MatchType,
142 flags: MultiEpsMatcherFlags,
143 matcher: IM,
144 ) -> Result<Self> {
145 let matcher = matcher
146 .into()
147 .unwrap_or_else(|| Arc::new(M::new(fst, match_type).unwrap()));
148 Ok(Self {
149 matcher,
150 flags,
151 multi_eps_labels: CompactSet::new(NO_LABEL),
152 ghost: PhantomData,
153 })
154 }
155
156 pub fn matcher(&self) -> &Arc<M> {
157 &self.matcher
158 }
159
160 pub fn clear_multi_eps_labels(&mut self) {
161 self.multi_eps_labels.clear()
162 }
163
164 pub fn add_multi_eps_label(&mut self, label: Label) -> Result<()> {
165 if label == EPS_LABEL {
166 bail!("MultiEpsMatcher: Bad multi-eps label: 0")
167 }
168 self.multi_eps_labels.insert(label);
169 Ok(())
170 }
171
172 pub fn remove_multi_eps_label(&mut self, label: Label) -> Result<()> {
173 if label == EPS_LABEL {
174 bail!("MultiEpsMatcher: Bad multi-eps label: 0")
175 }
176 self.multi_eps_labels.erase(label);
177 Ok(())
178 }
179}
180
181impl<W, F, B, M> Matcher<W, F, B> for MultiEpsMatcher<W, F, B, M>
182where
183 W: Semiring,
184 F: Fst<W>,
185 B: Borrow<F> + Debug,
186 M: Matcher<W, F, B>,
187{
188 type Iter = IteratorMultiEpsMatcher<W, F, B, M>;
189
190 fn new(fst: B, match_type: MatchType) -> Result<Self> {
191 Self::new_with_opts(
192 fst,
193 match_type,
194 MultiEpsMatcherFlags::MULTI_EPS_LOOP | MultiEpsMatcherFlags::MULTI_EPS_LIST,
195 None,
196 )
197 }
198
199 fn iter(&self, state: StateId, label: Label) -> Result<Self::Iter> {
200 let (iter_matcher, iter_labels) = if label == EPS_LABEL {
201 (Some(self.matcher.iter(state, EPS_LABEL)?.peekable()), None)
202 } else if label == NO_LABEL {
203 if self.flags.contains(MultiEpsMatcherFlags::MULTI_EPS_LIST) {
204 let multi_eps_labels = self.multi_eps_labels.iter().cloned().collect_vec();
206
207 let mut iter_matcher = None;
208 let mut pos_labels = 0;
209 while pos_labels < multi_eps_labels.len() {
210 let mut it = self
211 .matcher
212 .iter(state, multi_eps_labels[pos_labels])?
213 .peekable();
214 if it.peek().is_some() {
215 iter_matcher = Some(it);
216 break;
217 }
218 pos_labels += 1;
219 }
220
221 if pos_labels < multi_eps_labels.len() {
222 (iter_matcher, Some((multi_eps_labels, pos_labels)))
223 } else {
224 (Some(self.matcher.iter(state, NO_LABEL)?.peekable()), None)
225 }
226 } else {
227 (Some(self.matcher.iter(state, NO_LABEL)?.peekable()), None)
228 }
229 } else if self.flags.contains(MultiEpsMatcherFlags::MULTI_EPS_LOOP)
230 && self.multi_eps_labels.contains(&label)
231 {
232 (None, None)
234 } else {
235 (Some(self.matcher.iter(state, label)?.peekable()), None)
236 };
237 Ok(IteratorMultiEpsMatcher {
238 iter_matcher,
239 iter_labels,
240 matcher: Arc::clone(&self.matcher),
241 ghost: PhantomData,
242 done: false,
243 matcher_state: state,
244 })
245 }
246
247 fn final_weight(&self, state: StateId) -> Result<Option<W>> {
248 self.matcher.final_weight(state)
249 }
250
251 fn match_type(&self, test: bool) -> Result<MatchType> {
252 self.matcher.match_type(test)
253 }
254
255 fn flags(&self) -> MatcherFlags {
256 self.matcher.flags()
257 }
258
259 fn priority(&self, state: StateId) -> Result<usize> {
260 self.matcher.priority(state)
261 }
262
263 fn fst(&self) -> &B {
264 self.matcher.fst()
265 }
266}
267
/// Key type usable in a `CompactSet`: an ordered, copyable value that can be
/// shifted up or down by a `usize` offset.
trait CompactSetKey: Copy + Ord {
    /// Returns `self` increased by `v`.
    fn add(self, v: usize) -> Self;
    /// Returns `self` decreased by `v`.
    fn sub(self, v: usize) -> Self;
}

impl CompactSetKey for usize {
    fn add(self, v: usize) -> Self {
        let shifted = self + v;
        shifted
    }

    fn sub(self, v: usize) -> Self {
        let shifted = self - v;
        shifted
    }
}

impl CompactSetKey for u32 {
    fn add(self, v: usize) -> Self {
        // The offset is narrowed with `as`, i.e. truncated to 32 bits.
        let delta = v as u32;
        self + delta
    }

    fn sub(self, v: usize) -> Self {
        // The offset is narrowed with `as`, i.e. truncated to 32 bits.
        let delta = v as u32;
        self - delta
    }
}
290
/// Ordered set of keys that additionally tracks a lower and an upper bound of
/// its content. `no_key` is the sentinel stored in both bounds while the set
/// is empty.
#[derive(Clone, Debug)]
struct CompactSet<K> {
    /// The stored keys.
    set: BTreeSet<K>,
    /// Lower bound of the stored keys, or `no_key` when empty.
    min_key: K,
    /// Upper bound of the stored keys, or `no_key` when empty.
    max_key: K,
    /// Sentinel meaning "no bound".
    no_key: K,
}

impl<K: Copy + Ord> CompactSet<K> {
    /// Creates an empty set using `no_key` as the "no bound" sentinel.
    pub fn new(no_key: K) -> Self {
        Self {
            set: BTreeSet::new(),
            min_key: no_key,
            max_key: no_key,
            no_key,
        }
    }

    /// Adds `key` to the set and widens the tracked bounds so they include it.
    pub fn insert(&mut self, key: K) {
        self.set.insert(key);
        if self.min_key == self.no_key || key < self.min_key {
            self.min_key = key;
        }
        // The upper bound only grows: it is replaced when `key` exceeds it.
        if self.max_key == self.no_key || key > self.max_key {
            self.max_key = key;
        }
    }

    /// Removes every key and resets both bounds to the sentinel.
    pub fn clear(&mut self) {
        self.set.clear();
        self.min_key = self.no_key;
        self.max_key = self.no_key;
    }

    /// Iterates over the stored keys in ascending order.
    pub fn iter(&self) -> std::collections::btree_set::Iter<'_, K> {
        self.set.iter()
    }

    /// Returns the tracked lower bound (`no_key` when empty).
    #[allow(unused)]
    pub fn lower_bound(&self) -> K {
        self.min_key
    }

    /// Returns the tracked upper bound (`no_key` when empty).
    #[allow(unused)]
    pub fn upper_bound(&self) -> K {
        self.max_key
    }
}
339
340impl<K: CompactSetKey> CompactSet<K> {
341 pub fn erase(&mut self, key: K) {
342 self.set.remove(&key);
343 if self.set.is_empty() {
344 self.min_key = self.no_key;
345 self.max_key = self.no_key;
346 } else if key == self.min_key {
347 self.min_key = self.min_key.add(1);
348 } else if key == self.max_key {
349 self.max_key = self.max_key.sub(1);
350 }
351 }
352
353 pub fn contains(&self, key: &K) -> bool {
354 if self.min_key == self.no_key || *key < self.min_key || *key > self.max_key {
355 false
357 } else if self.min_key != self.no_key
358 && self.max_key.add(1) as K == self.min_key.add(self.set.len())
359 {
360 true
362 } else {
363 self.set.contains(key)
364 }
365 }
366}