rustfst/algorithms/compose/lookahead_matchers/
tr_lookahead_matcher.rs1use std::borrow::Borrow;
2use std::fmt::Debug;
3use std::marker::PhantomData;
4use std::sync::Arc;
5
6use anyhow::Result;
7use unsafe_unwrap::UnsafeUnwrap;
8
9use crate::algorithms::compose::lookahead_matchers::{
10 LookAheadMatcherData, LookaheadMatcher, MatcherFlagsTrait,
11};
12use crate::algorithms::compose::matchers::{IterItemMatcher, MatchType, Matcher, MatcherFlags};
13use crate::fst_traits::Fst;
14use crate::semirings::Semiring;
15use crate::{Label, StateId, Tr, Trs, EPS_LABEL, NO_LABEL};
16
/// Lookahead matcher built on top of an inner matcher `M`.
///
/// Every `Matcher` query is forwarded to the wrapped `matcher`; the `MFT` type
/// parameter is never stored and only contributes the flag set returned by
/// `MFT::flags()`.
#[derive(Debug, Clone)]
pub struct TrLookAheadMatcher<W, F, B, M, MFT>
where
    W: Semiring,
    F: Fst<W>,
    B: Borrow<F>,
    M: Matcher<W, F, B>,
{
    /// Handle on the FST being matched; handed back by `fst()` and read directly
    /// for the final-state checks in `lookahead_fst`.
    fst: B,
    /// Inner matcher that answers all forwarded queries (built from a clone of `fst`).
    matcher: M,
    /// Ties the otherwise unused type parameters `W`, `F` and `MFT` to the struct.
    ghost: PhantomData<(W, F, MFT)>,
}
30
31impl<W, F, B, M, MFT> Matcher<W, F, B> for TrLookAheadMatcher<W, F, B, M, MFT>
32where
33 W: Semiring,
34 F: Fst<W>,
35 B: Borrow<F> + Debug + Clone,
36 M: Matcher<W, F, B>,
37 MFT: MatcherFlagsTrait,
38{
39 type Iter = M::Iter;
40
41 fn new(fst: B, match_type: MatchType) -> Result<Self> {
42 Ok(Self {
43 matcher: M::new(fst.clone(), match_type)?,
44 fst,
45 ghost: PhantomData,
46 })
47 }
48
49 fn iter(&self, state: StateId, label: Label) -> Result<Self::Iter> {
50 self.matcher.iter(state, label)
51 }
52
53 fn final_weight(&self, state: StateId) -> Result<Option<W>> {
54 self.matcher.final_weight(state)
55 }
56
57 fn match_type(&self, test: bool) -> Result<MatchType> {
58 self.matcher.match_type(test)
59 }
60
61 fn flags(&self) -> MatcherFlags {
62 self.matcher.flags()
63 | MatcherFlags::INPUT_LOOKAHEAD_MATCHER
64 | MatcherFlags::OUTPUT_LOOKAHEAD_MATCHER
65 | MFT::flags()
66 }
67
68 fn priority(&self, state: StateId) -> Result<usize> {
69 self.matcher.priority(state)
70 }
71
72 fn fst(&self) -> &B {
73 &self.fst
74 }
75}
76
77impl<W, F, B, M, MFT> LookaheadMatcher<W, F, B> for TrLookAheadMatcher<W, F, B, M, MFT>
78where
79 W: Semiring,
80 F: Fst<W>,
81 B: Borrow<F> + Debug + Clone,
82 M: Matcher<W, F, B>,
83 MFT: MatcherFlagsTrait,
84{
85 type MatcherData = ();
87
88 fn data(&self) -> Option<&Arc<Self::MatcherData>> {
89 None
90 }
91
92 fn new_with_data(
93 fst: B,
94 match_type: MatchType,
95 _data: Option<Arc<Self::MatcherData>>,
96 ) -> Result<Self> {
97 Self::new(fst, match_type)
98 }
99
100 fn create_data<F2: Fst<W>, BF2: Borrow<F2>>(
101 _fst: BF2,
102 _match_type: MatchType,
103 ) -> Result<Option<Self::MatcherData>> {
104 Ok(None)
105 }
106
107 fn init_lookahead_fst<LF: Fst<W>, BLF: Borrow<LF> + Clone>(
108 &mut self,
109 _lfst: &BLF,
110 ) -> Result<()> {
111 Ok(())
112 }
113
114 fn lookahead_fst<LF: Fst<W>, BLF: Borrow<LF>>(
115 &self,
116 matcher_state: StateId,
117 lfst: &BLF,
118 lfst_state: StateId,
119 ) -> Result<Option<LookAheadMatcherData<W>>> {
120 let mut result = false;
121 let mut nprefix = 0;
122 let mut la_matcher_data = LookAheadMatcherData::default();
123 if MFT::flags().contains(MatcherFlags::LOOKAHEAD_WEIGHT) {
124 la_matcher_data.clear_lookahead_weight();
125 }
126 if MFT::flags().contains(MatcherFlags::LOOKAHEAD_PREFIX) {
127 la_matcher_data.clear_lookahead_prefix();
128 }
129 if self.fst.borrow().is_final(matcher_state)? && lfst.borrow().is_final(lfst_state)? {
130 if !MFT::flags()
131 .contains(MatcherFlags::LOOKAHEAD_WEIGHT | MatcherFlags::LOOKAHEAD_PREFIX)
132 {
133 return Ok(Some(la_matcher_data));
134 }
135 nprefix += 1;
136 if MFT::flags().contains(MatcherFlags::LOOKAHEAD_WEIGHT) {
137 unsafe {
138 let fw_matcher_state = self
139 .fst
140 .borrow()
141 .final_weight_unchecked(matcher_state)
142 .unsafe_unwrap();
143 let fw_lfst_state = lfst
144 .borrow()
145 .final_weight_unchecked(lfst_state)
146 .unsafe_unwrap();
147 la_matcher_data
148 .lookahead_weight
149 .plus_assign(fw_matcher_state.times(fw_lfst_state)?)?;
150 }
151 }
152 result = true;
153 }
154 {
155 let mut iter = self.iter(matcher_state, NO_LABEL)?.peekable();
156 if iter.peek().is_some() {
157 if !MFT::flags()
158 .contains(MatcherFlags::LOOKAHEAD_WEIGHT | MatcherFlags::LOOKAHEAD_PREFIX)
159 {
160 return Ok(Some(la_matcher_data));
161 }
162 nprefix += 1;
163 if MFT::flags().contains(MatcherFlags::LOOKAHEAD_WEIGHT) {
164 for tr in iter {
165 match tr {
166 IterItemMatcher::Tr(a) => {
167 la_matcher_data.lookahead_weight.plus_assign(&a.weight)?
168 }
169 IterItemMatcher::EpsLoop => {
170 la_matcher_data.lookahead_weight.plus_assign(W::one())?
171 }
172 };
173 }
174 }
175 result = true;
176 }
177 }
178
179 let match_type = self.match_type(false)?;
180 for tr in lfst.borrow().get_trs(lfst_state)?.trs() {
181 let label = match match_type {
182 MatchType::MatchInput => tr.olabel,
183 MatchType::MatchOutput => tr.ilabel,
184 _ => bail!("Bad match type"),
185 };
186 if label == EPS_LABEL {
187 if !MFT::flags()
188 .contains(MatcherFlags::LOOKAHEAD_WEIGHT | MatcherFlags::LOOKAHEAD_PREFIX)
189 {
190 return Ok(Some(la_matcher_data));
191 }
192 if !MFT::flags().contains(MatcherFlags::LOOKAHEAD_NON_EPSILON_PREFIX) {
193 nprefix += 1;
194 }
195 if MFT::flags().contains(MatcherFlags::LOOKAHEAD_WEIGHT) {
196 la_matcher_data.lookahead_weight.plus_assign(&tr.weight)?;
197 }
198 result = true;
199 } else {
200 let mut iter = self.iter(matcher_state, label)?.peekable();
201 if iter.peek().is_some() {
202 if !MFT::flags()
203 .contains(MatcherFlags::LOOKAHEAD_WEIGHT | MatcherFlags::LOOKAHEAD_PREFIX)
204 {
205 return Ok(Some(la_matcher_data));
206 }
207 for matcher_value in iter {
208 nprefix += 1;
209 if MFT::flags().contains(MatcherFlags::LOOKAHEAD_WEIGHT) {
210 match matcher_value {
211 IterItemMatcher::Tr(a) => la_matcher_data
212 .lookahead_weight
213 .plus_assign(tr.weight.times(&a.weight)?)?,
214 IterItemMatcher::EpsLoop => la_matcher_data
215 .lookahead_weight
216 .plus_assign(tr.weight.times(W::one())?)?,
217 };
218 }
219 if MFT::flags().contains(MatcherFlags::LOOKAHEAD_PREFIX) && nprefix == 1 {
220 la_matcher_data.set_lookahead_prefix(tr.clone());
221 }
222 }
223 result = true;
224 }
225 }
226 }
227
228 if MFT::flags().contains(MatcherFlags::LOOKAHEAD_PREFIX) {
229 if nprefix == 1 {
230 la_matcher_data.clear_lookahead_weight();
231 } else {
232 la_matcher_data.clear_lookahead_prefix();
233 }
234 }
235
236 if result {
237 Ok(Some(la_matcher_data))
238 } else {
239 Ok(None)
240 }
241 }
242
243 fn lookahead_label(&self, state: StateId, label: Label) -> Result<bool> {
244 let mut it = self.matcher.iter(state, label)?;
245 Ok(it.next().is_some())
246 }
247
248 fn lookahead_prefix(&self, tr: &mut Tr<W>, la_matcher_data: &LookAheadMatcherData<W>) -> bool {
249 la_matcher_data.default_lookahead_prefix(tr)
250 }
251}