rustfst/algorithms/compose/
matcher_fst.rs1use std::borrow::Borrow;
2use std::fmt::Debug;
3use std::marker::PhantomData;
4use std::sync::Arc;
5
6use anyhow::Result;
7
8use crate::algorithms::compose::lookahead_matchers::{LabelLookAheadRelabeler, LookaheadMatcher};
9use crate::algorithms::compose::matchers::MatchType;
10use crate::algorithms::compose::FstAddOn;
11use crate::algorithms::compose::LabelReachableData;
12use crate::fst_properties::FstProperties;
13use crate::fst_traits::{
14 CoreFst, ExpandedFst, Fst, FstIntoIterator, FstIterator, MutableFst, StateIterator,
15};
16use crate::semirings::Semiring;
17use crate::{StateId, SymbolTable};
18
19type InnerFstAddOn<F, T> = FstAddOn<F, (Option<Arc<T>>, Option<Arc<T>>)>;
20
21#[derive(Clone, PartialEq, Debug)]
22pub struct MatcherFst<W, F, B, M, T> {
23 fst_add_on: InnerFstAddOn<F, T>,
24 matcher: PhantomData<M>,
25 w: PhantomData<(W, B)>,
26}
27
28impl<W, F, B, M, T> MatcherFst<W, F, B, M, T> {
29 pub fn fst(&self) -> &F {
30 self.fst_add_on.fst()
31 }
32
33 pub fn addon(&self) -> &(Option<Arc<T>>, Option<Arc<T>>) {
34 self.fst_add_on.add_on()
35 }
36
37 pub fn data(&self, match_type: MatchType) -> Option<&Arc<T>> {
38 let data = self.fst_add_on.add_on();
39 if match_type == MatchType::MatchInput {
40 data.0.as_ref()
41 } else {
42 data.1.as_ref()
43 }
44 }
45}
46
47impl<W, F, B, M> MatcherFst<W, F, B, M, M::MatcherData>
49where
50 W: Semiring,
51 F: MutableFst<W>,
52 B: Borrow<F>,
53 M: LookaheadMatcher<W, F, B, MatcherData = LabelReachableData>,
54{
55 pub fn new(mut fst: F) -> Result<Self> {
56 let imatcher_data = M::create_data::<F, _>(&fst, MatchType::MatchInput)?;
57 let omatcher_data = M::create_data::<F, _>(&fst, MatchType::MatchOutput)?;
58
59 let mut add_on = (imatcher_data, omatcher_data);
60 LabelLookAheadRelabeler::init(&mut fst, &mut add_on)?;
61
62 let add_on = (add_on.0.map(Arc::new), add_on.1.map(Arc::new));
63
64 let fst_add_on = FstAddOn::new(fst, add_on);
65 Ok(Self {
66 fst_add_on,
67 matcher: PhantomData,
68 w: PhantomData,
69 })
70 }
71
72 pub fn new_with_relabeling<F2: MutableFst<W>>(
74 mut fst: F,
75 fst2: &mut F2,
76 relabel_input: bool,
77 ) -> Result<Self> {
78 let imatcher_data = M::create_data::<F, _>(&fst, MatchType::MatchInput)?;
79 let omatcher_data = M::create_data::<F, _>(&fst, MatchType::MatchOutput)?;
80
81 let mut add_on = (imatcher_data, omatcher_data);
82 LabelLookAheadRelabeler::init(&mut fst, &mut add_on)?;
83 LabelLookAheadRelabeler::relabel(fst2, &mut add_on, relabel_input)?;
84
85 let add_on = (add_on.0.map(Arc::new), add_on.1.map(Arc::new));
86
87 let fst_add_on = FstAddOn::new(fst, add_on);
88 Ok(Self {
89 fst_add_on,
90 matcher: PhantomData,
91 w: PhantomData,
92 })
93 }
94}
95
96impl<W: Semiring, F: CoreFst<W>, B: Borrow<F>, M, T> CoreFst<W> for MatcherFst<W, F, B, M, T> {
97 type TRS = <FstAddOn<F, T> as CoreFst<W>>::TRS;
98
99 fn start(&self) -> Option<StateId> {
100 self.fst_add_on.start()
101 }
102
103 fn final_weight(&self, state_id: StateId) -> Result<Option<W>> {
104 self.fst_add_on.final_weight(state_id)
105 }
106
107 unsafe fn final_weight_unchecked(&self, state_id: StateId) -> Option<W> {
108 self.fst_add_on.final_weight_unchecked(state_id)
109 }
110
111 fn num_trs(&self, s: StateId) -> Result<usize> {
112 self.fst_add_on.num_trs(s)
113 }
114
115 unsafe fn num_trs_unchecked(&self, s: StateId) -> usize {
116 self.fst_add_on.num_trs_unchecked(s)
117 }
118
119 fn get_trs(&self, state_id: StateId) -> Result<Self::TRS> {
120 self.fst_add_on.get_trs(state_id)
121 }
122
123 unsafe fn get_trs_unchecked(&self, state_id: StateId) -> Self::TRS {
124 self.fst_add_on.get_trs_unchecked(state_id)
125 }
126
127 fn properties(&self) -> FstProperties {
128 self.fst_add_on.properties()
129 }
130
131 fn num_input_epsilons(&self, state: StateId) -> Result<usize> {
132 self.fst_add_on.num_input_epsilons(state)
133 }
134
135 fn num_output_epsilons(&self, state: StateId) -> Result<usize> {
136 self.fst_add_on.num_output_epsilons(state)
137 }
138}
139
140impl<'a, W, F: StateIterator<'a>, B: Borrow<F>, M, T> StateIterator<'a>
141 for MatcherFst<W, F, B, M, T>
142{
143 type Iter = <F as StateIterator<'a>>::Iter;
144
145 fn states_iter(&'a self) -> Self::Iter {
146 self.fst_add_on.states_iter()
147 }
148}
149
150impl<'a, W, F, B, M, T> FstIterator<'a, W> for MatcherFst<W, F, B, M, T>
151where
152 W: Semiring,
153 F: FstIterator<'a, W>,
154 B: Borrow<F>,
155{
156 type FstIter = F::FstIter;
157
158 fn fst_iter(&'a self) -> Self::FstIter {
159 self.fst_add_on.fst_iter()
160 }
161}
162
163impl<W, F, B, M, T> Fst<W> for MatcherFst<W, F, B, M, T>
164where
165 W: Semiring,
166 F: Fst<W>,
167 B: Borrow<F> + Debug,
168 M: Debug,
169 T: Debug,
170{
171 fn input_symbols(&self) -> Option<&Arc<SymbolTable>> {
172 self.fst_add_on.input_symbols()
173 }
174
175 fn output_symbols(&self) -> Option<&Arc<SymbolTable>> {
176 self.fst_add_on.output_symbols()
177 }
178
179 fn set_input_symbols(&mut self, symt: Arc<SymbolTable>) {
180 self.fst_add_on.set_input_symbols(symt)
181 }
182
183 fn set_output_symbols(&mut self, symt: Arc<SymbolTable>) {
184 self.fst_add_on.set_output_symbols(symt)
185 }
186
187 fn take_input_symbols(&mut self) -> Option<Arc<SymbolTable>> {
188 self.fst_add_on.take_input_symbols()
189 }
190
191 fn take_output_symbols(&mut self) -> Option<Arc<SymbolTable>> {
192 self.fst_add_on.take_output_symbols()
193 }
194}
195
196impl<W, F, M, B, T> ExpandedFst<W> for MatcherFst<W, F, B, M, T>
197where
198 W: Semiring,
199 F: ExpandedFst<W>,
200 B: Borrow<F> + Debug + PartialEq + Clone,
201 M: Debug + Clone + PartialEq,
202 T: Debug + Clone + PartialEq,
203{
204 fn num_states(&self) -> usize {
205 self.fst_add_on.num_states()
206 }
207}
208
209impl<W, F, B, M, T> FstIntoIterator<W> for MatcherFst<W, F, B, M, T>
210where
211 W: Semiring,
212 F: FstIntoIterator<W>,
213 B: Borrow<F> + Debug + PartialEq + Clone,
214 T: Debug,
215{
216 type TrsIter = F::TrsIter;
217 type FstIter = F::FstIter;
218
219 fn fst_into_iter(self) -> Self::FstIter {
220 self.fst_add_on.fst_into_iter()
221 }
222}