rustfst/algorithms/compose/
matcher_fst.rs

1use std::borrow::Borrow;
2use std::fmt::Debug;
3use std::marker::PhantomData;
4use std::sync::Arc;
5
6use anyhow::Result;
7
8use crate::algorithms::compose::lookahead_matchers::{LabelLookAheadRelabeler, LookaheadMatcher};
9use crate::algorithms::compose::matchers::MatchType;
10use crate::algorithms::compose::FstAddOn;
11use crate::algorithms::compose::LabelReachableData;
12use crate::fst_properties::FstProperties;
13use crate::fst_traits::{
14    CoreFst, ExpandedFst, Fst, FstIntoIterator, FstIterator, MutableFst, StateIterator,
15};
16use crate::semirings::Semiring;
17use crate::{StateId, SymbolTable};
18
19type InnerFstAddOn<F, T> = FstAddOn<F, (Option<Arc<T>>, Option<Arc<T>>)>;
20
21#[derive(Clone, PartialEq, Debug)]
22pub struct MatcherFst<W, F, B, M, T> {
23    fst_add_on: InnerFstAddOn<F, T>,
24    matcher: PhantomData<M>,
25    w: PhantomData<(W, B)>,
26}
27
28impl<W, F, B, M, T> MatcherFst<W, F, B, M, T> {
29    pub fn fst(&self) -> &F {
30        self.fst_add_on.fst()
31    }
32
33    pub fn addon(&self) -> &(Option<Arc<T>>, Option<Arc<T>>) {
34        self.fst_add_on.add_on()
35    }
36
37    pub fn data(&self, match_type: MatchType) -> Option<&Arc<T>> {
38        let data = self.fst_add_on.add_on();
39        if match_type == MatchType::MatchInput {
40            data.0.as_ref()
41        } else {
42            data.1.as_ref()
43        }
44    }
45}
46
47// TODO: To be generalized
48impl<W, F, B, M> MatcherFst<W, F, B, M, M::MatcherData>
49where
50    W: Semiring,
51    F: MutableFst<W>,
52    B: Borrow<F>,
53    M: LookaheadMatcher<W, F, B, MatcherData = LabelReachableData>,
54{
55    pub fn new(mut fst: F) -> Result<Self> {
56        let imatcher_data = M::create_data::<F, _>(&fst, MatchType::MatchInput)?;
57        let omatcher_data = M::create_data::<F, _>(&fst, MatchType::MatchOutput)?;
58
59        let mut add_on = (imatcher_data, omatcher_data);
60        LabelLookAheadRelabeler::init(&mut fst, &mut add_on)?;
61
62        let add_on = (add_on.0.map(Arc::new), add_on.1.map(Arc::new));
63
64        let fst_add_on = FstAddOn::new(fst, add_on);
65        Ok(Self {
66            fst_add_on,
67            matcher: PhantomData,
68            w: PhantomData,
69        })
70    }
71
72    // Construct a new Matcher Fst intended for LookAhead composition and relabel fst2 wrt to the first fst.
73    pub fn new_with_relabeling<F2: MutableFst<W>>(
74        mut fst: F,
75        fst2: &mut F2,
76        relabel_input: bool,
77    ) -> Result<Self> {
78        let imatcher_data = M::create_data::<F, _>(&fst, MatchType::MatchInput)?;
79        let omatcher_data = M::create_data::<F, _>(&fst, MatchType::MatchOutput)?;
80
81        let mut add_on = (imatcher_data, omatcher_data);
82        LabelLookAheadRelabeler::init(&mut fst, &mut add_on)?;
83        LabelLookAheadRelabeler::relabel(fst2, &mut add_on, relabel_input)?;
84
85        let add_on = (add_on.0.map(Arc::new), add_on.1.map(Arc::new));
86
87        let fst_add_on = FstAddOn::new(fst, add_on);
88        Ok(Self {
89            fst_add_on,
90            matcher: PhantomData,
91            w: PhantomData,
92        })
93    }
94}
95
96impl<W: Semiring, F: CoreFst<W>, B: Borrow<F>, M, T> CoreFst<W> for MatcherFst<W, F, B, M, T> {
97    type TRS = <FstAddOn<F, T> as CoreFst<W>>::TRS;
98
99    fn start(&self) -> Option<StateId> {
100        self.fst_add_on.start()
101    }
102
103    fn final_weight(&self, state_id: StateId) -> Result<Option<W>> {
104        self.fst_add_on.final_weight(state_id)
105    }
106
107    unsafe fn final_weight_unchecked(&self, state_id: StateId) -> Option<W> {
108        self.fst_add_on.final_weight_unchecked(state_id)
109    }
110
111    fn num_trs(&self, s: StateId) -> Result<usize> {
112        self.fst_add_on.num_trs(s)
113    }
114
115    unsafe fn num_trs_unchecked(&self, s: StateId) -> usize {
116        self.fst_add_on.num_trs_unchecked(s)
117    }
118
119    fn get_trs(&self, state_id: StateId) -> Result<Self::TRS> {
120        self.fst_add_on.get_trs(state_id)
121    }
122
123    unsafe fn get_trs_unchecked(&self, state_id: StateId) -> Self::TRS {
124        self.fst_add_on.get_trs_unchecked(state_id)
125    }
126
127    fn properties(&self) -> FstProperties {
128        self.fst_add_on.properties()
129    }
130
131    fn num_input_epsilons(&self, state: StateId) -> Result<usize> {
132        self.fst_add_on.num_input_epsilons(state)
133    }
134
135    fn num_output_epsilons(&self, state: StateId) -> Result<usize> {
136        self.fst_add_on.num_output_epsilons(state)
137    }
138}
139
140impl<'a, W, F: StateIterator<'a>, B: Borrow<F>, M, T> StateIterator<'a>
141    for MatcherFst<W, F, B, M, T>
142{
143    type Iter = <F as StateIterator<'a>>::Iter;
144
145    fn states_iter(&'a self) -> Self::Iter {
146        self.fst_add_on.states_iter()
147    }
148}
149
150impl<'a, W, F, B, M, T> FstIterator<'a, W> for MatcherFst<W, F, B, M, T>
151where
152    W: Semiring,
153    F: FstIterator<'a, W>,
154    B: Borrow<F>,
155{
156    type FstIter = F::FstIter;
157
158    fn fst_iter(&'a self) -> Self::FstIter {
159        self.fst_add_on.fst_iter()
160    }
161}
162
163impl<W, F, B, M, T> Fst<W> for MatcherFst<W, F, B, M, T>
164where
165    W: Semiring,
166    F: Fst<W>,
167    B: Borrow<F> + Debug,
168    M: Debug,
169    T: Debug,
170{
171    fn input_symbols(&self) -> Option<&Arc<SymbolTable>> {
172        self.fst_add_on.input_symbols()
173    }
174
175    fn output_symbols(&self) -> Option<&Arc<SymbolTable>> {
176        self.fst_add_on.output_symbols()
177    }
178
179    fn set_input_symbols(&mut self, symt: Arc<SymbolTable>) {
180        self.fst_add_on.set_input_symbols(symt)
181    }
182
183    fn set_output_symbols(&mut self, symt: Arc<SymbolTable>) {
184        self.fst_add_on.set_output_symbols(symt)
185    }
186
187    fn take_input_symbols(&mut self) -> Option<Arc<SymbolTable>> {
188        self.fst_add_on.take_input_symbols()
189    }
190
191    fn take_output_symbols(&mut self) -> Option<Arc<SymbolTable>> {
192        self.fst_add_on.take_output_symbols()
193    }
194}
195
196impl<W, F, M, B, T> ExpandedFst<W> for MatcherFst<W, F, B, M, T>
197where
198    W: Semiring,
199    F: ExpandedFst<W>,
200    B: Borrow<F> + Debug + PartialEq + Clone,
201    M: Debug + Clone + PartialEq,
202    T: Debug + Clone + PartialEq,
203{
204    fn num_states(&self) -> usize {
205        self.fst_add_on.num_states()
206    }
207}
208
209impl<W, F, B, M, T> FstIntoIterator<W> for MatcherFst<W, F, B, M, T>
210where
211    W: Semiring,
212    F: FstIntoIterator<W>,
213    B: Borrow<F> + Debug + PartialEq + Clone,
214    T: Debug,
215{
216    type TrsIter = F::TrsIter;
217    type FstIter = F::FstIter;
218
219    fn fst_into_iter(self) -> Self::FstIter {
220        self.fst_add_on.fst_into_iter()
221    }
222}