rustfst/algorithms/compose/lookahead_matchers/tr_lookahead_matcher.rs

1use std::borrow::Borrow;
2use std::fmt::Debug;
3use std::marker::PhantomData;
4use std::sync::Arc;
5
6use anyhow::Result;
7use unsafe_unwrap::UnsafeUnwrap;
8
9use crate::algorithms::compose::lookahead_matchers::{
10    LookAheadMatcherData, LookaheadMatcher, MatcherFlagsTrait,
11};
12use crate::algorithms::compose::matchers::{IterItemMatcher, MatchType, Matcher, MatcherFlags};
13use crate::fst_traits::Fst;
14use crate::semirings::Semiring;
15use crate::{Label, StateId, Tr, Trs, EPS_LABEL, NO_LABEL};
16
17#[derive(Debug, Clone)]
18pub struct TrLookAheadMatcher<W, F, B, M, MFT>
19where
20    W: Semiring,
21    F: Fst<W>,
22    B: Borrow<F>,
23    M: Matcher<W, F, B>,
24{
25    // matcher fst
26    fst: B,
27    matcher: M,
28    ghost: PhantomData<(W, F, MFT)>,
29}
30
31impl<W, F, B, M, MFT> Matcher<W, F, B> for TrLookAheadMatcher<W, F, B, M, MFT>
32where
33    W: Semiring,
34    F: Fst<W>,
35    B: Borrow<F> + Debug + Clone,
36    M: Matcher<W, F, B>,
37    MFT: MatcherFlagsTrait,
38{
39    type Iter = M::Iter;
40
41    fn new(fst: B, match_type: MatchType) -> Result<Self> {
42        Ok(Self {
43            matcher: M::new(fst.clone(), match_type)?,
44            fst,
45            ghost: PhantomData,
46        })
47    }
48
49    fn iter(&self, state: StateId, label: Label) -> Result<Self::Iter> {
50        self.matcher.iter(state, label)
51    }
52
53    fn final_weight(&self, state: StateId) -> Result<Option<W>> {
54        self.matcher.final_weight(state)
55    }
56
57    fn match_type(&self, test: bool) -> Result<MatchType> {
58        self.matcher.match_type(test)
59    }
60
61    fn flags(&self) -> MatcherFlags {
62        self.matcher.flags()
63            | MatcherFlags::INPUT_LOOKAHEAD_MATCHER
64            | MatcherFlags::OUTPUT_LOOKAHEAD_MATCHER
65            | MFT::flags()
66    }
67
68    fn priority(&self, state: StateId) -> Result<usize> {
69        self.matcher.priority(state)
70    }
71
72    fn fst(&self) -> &B {
73        &self.fst
74    }
75}
76
77impl<W, F, B, M, MFT> LookaheadMatcher<W, F, B> for TrLookAheadMatcher<W, F, B, M, MFT>
78where
79    W: Semiring,
80    F: Fst<W>,
81    B: Borrow<F> + Debug + Clone,
82    M: Matcher<W, F, B>,
83    MFT: MatcherFlagsTrait,
84{
85    // NullAddon
86    type MatcherData = ();
87
88    fn data(&self) -> Option<&Arc<Self::MatcherData>> {
89        None
90    }
91
92    fn new_with_data(
93        fst: B,
94        match_type: MatchType,
95        _data: Option<Arc<Self::MatcherData>>,
96    ) -> Result<Self> {
97        Self::new(fst, match_type)
98    }
99
100    fn create_data<F2: Fst<W>, BF2: Borrow<F2>>(
101        _fst: BF2,
102        _match_type: MatchType,
103    ) -> Result<Option<Self::MatcherData>> {
104        Ok(None)
105    }
106
107    fn init_lookahead_fst<LF: Fst<W>, BLF: Borrow<LF> + Clone>(
108        &mut self,
109        _lfst: &BLF,
110    ) -> Result<()> {
111        Ok(())
112    }
113
114    fn lookahead_fst<LF: Fst<W>, BLF: Borrow<LF>>(
115        &self,
116        matcher_state: StateId,
117        lfst: &BLF,
118        lfst_state: StateId,
119    ) -> Result<Option<LookAheadMatcherData<W>>> {
120        let mut result = false;
121        let mut nprefix = 0;
122        let mut la_matcher_data = LookAheadMatcherData::default();
123        if MFT::flags().contains(MatcherFlags::LOOKAHEAD_WEIGHT) {
124            la_matcher_data.clear_lookahead_weight();
125        }
126        if MFT::flags().contains(MatcherFlags::LOOKAHEAD_PREFIX) {
127            la_matcher_data.clear_lookahead_prefix();
128        }
129        if self.fst.borrow().is_final(matcher_state)? && lfst.borrow().is_final(lfst_state)? {
130            if !MFT::flags()
131                .contains(MatcherFlags::LOOKAHEAD_WEIGHT | MatcherFlags::LOOKAHEAD_PREFIX)
132            {
133                return Ok(Some(la_matcher_data));
134            }
135            nprefix += 1;
136            if MFT::flags().contains(MatcherFlags::LOOKAHEAD_WEIGHT) {
137                unsafe {
138                    let fw_matcher_state = self
139                        .fst
140                        .borrow()
141                        .final_weight_unchecked(matcher_state)
142                        .unsafe_unwrap();
143                    let fw_lfst_state = lfst
144                        .borrow()
145                        .final_weight_unchecked(lfst_state)
146                        .unsafe_unwrap();
147                    la_matcher_data
148                        .lookahead_weight
149                        .plus_assign(fw_matcher_state.times(fw_lfst_state)?)?;
150                }
151            }
152            result = true;
153        }
154        {
155            let mut iter = self.iter(matcher_state, NO_LABEL)?.peekable();
156            if iter.peek().is_some() {
157                if !MFT::flags()
158                    .contains(MatcherFlags::LOOKAHEAD_WEIGHT | MatcherFlags::LOOKAHEAD_PREFIX)
159                {
160                    return Ok(Some(la_matcher_data));
161                }
162                nprefix += 1;
163                if MFT::flags().contains(MatcherFlags::LOOKAHEAD_WEIGHT) {
164                    for tr in iter {
165                        match tr {
166                            IterItemMatcher::Tr(a) => {
167                                la_matcher_data.lookahead_weight.plus_assign(&a.weight)?
168                            }
169                            IterItemMatcher::EpsLoop => {
170                                la_matcher_data.lookahead_weight.plus_assign(W::one())?
171                            }
172                        };
173                    }
174                }
175                result = true;
176            }
177        }
178
179        let match_type = self.match_type(false)?;
180        for tr in lfst.borrow().get_trs(lfst_state)?.trs() {
181            let label = match match_type {
182                MatchType::MatchInput => tr.olabel,
183                MatchType::MatchOutput => tr.ilabel,
184                _ => bail!("Bad match type"),
185            };
186            if label == EPS_LABEL {
187                if !MFT::flags()
188                    .contains(MatcherFlags::LOOKAHEAD_WEIGHT | MatcherFlags::LOOKAHEAD_PREFIX)
189                {
190                    return Ok(Some(la_matcher_data));
191                }
192                if !MFT::flags().contains(MatcherFlags::LOOKAHEAD_NON_EPSILON_PREFIX) {
193                    nprefix += 1;
194                }
195                if MFT::flags().contains(MatcherFlags::LOOKAHEAD_WEIGHT) {
196                    la_matcher_data.lookahead_weight.plus_assign(&tr.weight)?;
197                }
198                result = true;
199            } else {
200                let mut iter = self.iter(matcher_state, label)?.peekable();
201                if iter.peek().is_some() {
202                    if !MFT::flags()
203                        .contains(MatcherFlags::LOOKAHEAD_WEIGHT | MatcherFlags::LOOKAHEAD_PREFIX)
204                    {
205                        return Ok(Some(la_matcher_data));
206                    }
207                    for matcher_value in iter {
208                        nprefix += 1;
209                        if MFT::flags().contains(MatcherFlags::LOOKAHEAD_WEIGHT) {
210                            match matcher_value {
211                                IterItemMatcher::Tr(a) => la_matcher_data
212                                    .lookahead_weight
213                                    .plus_assign(tr.weight.times(&a.weight)?)?,
214                                IterItemMatcher::EpsLoop => la_matcher_data
215                                    .lookahead_weight
216                                    .plus_assign(tr.weight.times(W::one())?)?,
217                            };
218                        }
219                        if MFT::flags().contains(MatcherFlags::LOOKAHEAD_PREFIX) && nprefix == 1 {
220                            la_matcher_data.set_lookahead_prefix(tr.clone());
221                        }
222                    }
223                    result = true;
224                }
225            }
226        }
227
228        if MFT::flags().contains(MatcherFlags::LOOKAHEAD_PREFIX) {
229            if nprefix == 1 {
230                la_matcher_data.clear_lookahead_weight();
231            } else {
232                la_matcher_data.clear_lookahead_prefix();
233            }
234        }
235
236        if result {
237            Ok(Some(la_matcher_data))
238        } else {
239            Ok(None)
240        }
241    }
242
243    fn lookahead_label(&self, state: StateId, label: Label) -> Result<bool> {
244        let mut it = self.matcher.iter(state, label)?;
245        Ok(it.next().is_some())
246    }
247
248    fn lookahead_prefix(&self, tr: &mut Tr<W>, la_matcher_data: &LookAheadMatcherData<W>) -> bool {
249        la_matcher_data.default_lookahead_prefix(tr)
250    }
251}