rustfst/algorithms/compose/compose_fst_op.rs

1use std::borrow::Borrow;
2use std::fmt::Debug;
3use std::fs::{read, File};
4use std::hash::Hash;
5use std::io::BufWriter;
6use std::path::Path;
7use std::sync::Arc;
8
9use anyhow::{Context, Result};
10
11use crate::algorithms::compose::compose_filters::{ComposeFilter, ComposeFilterBuilder};
12use crate::algorithms::compose::filter_states::FilterState;
13use crate::algorithms::compose::lookahead_filters::lookahead_selector::Selector;
14use crate::algorithms::compose::matchers::{IterItemMatcher, MatcherFlags};
15use crate::algorithms::compose::matchers::{MatchType, Matcher, REQUIRE_PRIORITY};
16use crate::algorithms::compose::{ComposeFstOpOptions, ComposeStateTuple};
17use crate::algorithms::lazy::{AccessibleOpState, FstOp, SerializableOpState, StateTable};
18use crate::fst_properties::mutable_properties::compose_properties;
19use crate::fst_properties::FstProperties;
20use crate::fst_traits::Fst;
21use crate::parsers::SerializeBinary;
22use crate::semirings::Semiring;
23use crate::{StateId, Tr, Trs, TrsVec, EPS_LABEL, NO_LABEL};
24
25#[derive(Debug, Clone)]
26pub struct ComposeFstOpState<T: Hash + Eq + Clone> {
27    state_table: StateTable<T>,
28}
29
30impl<T: Hash + Eq + Clone> Default for ComposeFstOpState<T> {
31    fn default() -> Self {
32        Self::new()
33    }
34}
35
36impl<T: Hash + Eq + Clone> ComposeFstOpState<T> {
37    pub fn new() -> Self {
38        ComposeFstOpState {
39            state_table: StateTable::<T>::new(),
40        }
41    }
42}
43
44impl<T: Hash + Eq + Clone + SerializeBinary> SerializableOpState for ComposeFstOpState<T> {
45    /// Loads a ComposeFstOpState from a file in binary format.
46    fn read<P: AsRef<Path>>(path: P) -> Result<Self> {
47        let data = read(path.as_ref())
48            .with_context(|| format!("Can't open file : {:?}", path.as_ref()))?;
49
50        // Parse StateTable
51        let (_, state_table) = StateTable::<T>::parse_binary(&data)
52            .map_err(|e| format_err!("Error while parsing binary StateTable : {:?}", e))?;
53
54        Ok(Self { state_table })
55    }
56
57    /// Writes a ComposeFstOpState to a file in binary format.
58    fn write<P: AsRef<Path>>(&self, path: P) -> Result<()> {
59        let mut file = BufWriter::new(File::create(path)?);
60
61        // Write StateTable
62        self.state_table.write_binary(&mut file)?;
63        Ok(())
64    }
65}
66
67#[derive(Debug)]
68pub struct ComposeFstOp<W, F1, F2, B1, B2, M1, M2, CFB>
69where
70    W: Semiring,
71    F1: Fst<W>,
72    F2: Fst<W>,
73    B1: Borrow<F1> + Debug + Clone,
74    B2: Borrow<F2> + Debug + Clone,
75    M1: Matcher<W, F1, B1>,
76    M2: Matcher<W, F2, B2>,
77    CFB: ComposeFilterBuilder<W, F1, F2, B1, B2, M1, M2>,
78{
79    compose_filter_builder: CFB,
80    compose_state: ComposeFstOpState<
81        ComposeStateTuple<<CFB::CF as ComposeFilter<W, F1, F2, B1, B2, CFB::IM1, CFB::IM2>>::FS>,
82    >,
83    match_type: MatchType,
84    properties: FstProperties,
85    fst1: B1,
86    fst2: B2,
87}
88
89impl<W, F1, F2, B1, B2, M1, M2, CFB> Clone for ComposeFstOp<W, F1, F2, B1, B2, M1, M2, CFB>
90where
91    W: Semiring,
92    F1: Fst<W>,
93    F2: Fst<W>,
94    B1: Borrow<F1> + Debug + Clone,
95    B2: Borrow<F2> + Debug + Clone,
96    M1: Matcher<W, F1, B1>,
97    M2: Matcher<W, F2, B2>,
98    CFB: ComposeFilterBuilder<W, F1, F2, B1, B2, M1, M2>,
99{
100    fn clone(&self) -> Self {
101        Self {
102            compose_filter_builder: self.compose_filter_builder.clone(),
103            compose_state: self.compose_state.clone(),
104            match_type: self.match_type,
105            properties: self.properties,
106            fst1: self.fst1.clone(),
107            fst2: self.fst2.clone(),
108        }
109    }
110}
111
112impl<W, F1, F2, B1, B2, M1, M2, CFB> ComposeFstOp<W, F1, F2, B1, B2, M1, M2, CFB>
113where
114    W: Semiring,
115    F1: Fst<W>,
116    F2: Fst<W>,
117    B1: Borrow<F1> + Debug + Clone,
118    B2: Borrow<F2> + Debug + Clone,
119    M1: Matcher<W, F1, B1>,
120    M2: Matcher<W, F2, B2>,
121    CFB: ComposeFilterBuilder<W, F1, F2, B1, B2, M1, M2>,
122{
123    // Compose specifying two matcher types Matcher1 and Matcher2. Requires input
124    // FST (of the same Tr type, but o.w. arbitrary) match the corresponding
125    // matcher FST types). Recommended only for advanced use in demanding or
126    // specialized applications due to potential code bloat and matcher
127    // incompatibilities.
128    // fn new2(fst1: &'fst F1, fst2: &'fst F2) -> Result<Self> {
129    //     unimplemented!()
130    // }
131
132    pub fn new(
133        fst1: B1,
134        fst2: B2,
135        opts: ComposeFstOpOptions<
136            M1,
137            M2,
138            CFB,
139            ComposeFstOpState<
140                ComposeStateTuple<
141                    <CFB::CF as ComposeFilter<W, F1, F2, B1, B2, CFB::IM1, CFB::IM2>>::FS,
142                >,
143            >,
144        >,
145    ) -> Result<Self> {
146        let matcher1 = opts.matcher1;
147        let matcher2 = opts.matcher2;
148        let compose_filter_builder = opts.filter_builder.unwrap_or_else(|| {
149            ComposeFilterBuilder::new(fst1.clone(), fst2.clone(), matcher1, matcher2).unwrap()
150        });
151        let compose_filter = compose_filter_builder.build()?;
152        let match_type = Self::match_type(compose_filter.matcher1(), compose_filter.matcher2())?;
153
154        let fprops1 = fst1.borrow().properties();
155        let fprops2 = fst2.borrow().properties();
156        let cprops = compose_properties(fprops1, fprops2);
157        let properties = compose_filter.properties(cprops);
158
159        Ok(Self {
160            compose_filter_builder,
161            compose_state: opts.op_state.unwrap_or_default(),
162            match_type,
163            properties,
164            fst1,
165            fst2,
166        })
167    }
168
169    fn match_type(matcher1: &CFB::IM1, matcher2: &CFB::IM2) -> Result<MatchType> {
170        if matcher1.flags().contains(MatcherFlags::REQUIRE_MATCH)
171            && matcher1.match_type(true)? != MatchType::MatchOutput
172        {
173            bail!("ComposeFst: 1st argument cannot perform required matching (sort?)")
174        }
175        if matcher2.flags().contains(MatcherFlags::REQUIRE_MATCH)
176            && matcher2.match_type(true)? != MatchType::MatchInput
177        {
178            bail!("ComposeFst: 2nd argument cannot perform required matching (sort?)")
179        }
180
181        let type1 = matcher1.match_type(false)?;
182        let type2 = matcher2.match_type(false)?;
183        let mt = if type1 == MatchType::MatchOutput && type2 == MatchType::MatchInput {
184            MatchType::MatchBoth
185        } else if type1 == MatchType::MatchOutput {
186            MatchType::MatchOutput
187        } else if type2 == MatchType::MatchInput {
188            MatchType::MatchInput
189        } else if matcher1.match_type(true)? == MatchType::MatchOutput {
190            MatchType::MatchOutput
191        } else if matcher2.match_type(true)? == MatchType::MatchInput {
192            MatchType::MatchInput
193        } else {
194            bail!("ComposeFst: 1st argument cannot match on output labels and 2nd argument cannot match on input labels (sort?).")
195        };
196        Ok(mt)
197    }
198
199    fn match_input(&self, s1: StateId, s2: StateId, compose_filter: &CFB::CF) -> Result<bool> {
200        match self.match_type {
201            MatchType::MatchInput => Ok(true),
202            MatchType::MatchOutput => Ok(false),
203            _ => {
204                // Match both
205                let priority1 = compose_filter.matcher1().priority(s1)?;
206                let priority2 = compose_filter.matcher2().priority(s2)?;
207                if priority1 == REQUIRE_PRIORITY && priority2 == REQUIRE_PRIORITY {
208                    bail!("Both sides can't require match")
209                }
210                if priority1 == REQUIRE_PRIORITY {
211                    return Ok(false);
212                }
213                if priority2 == REQUIRE_PRIORITY {
214                    return Ok(true);
215                }
216                Ok(priority1 <= priority2)
217            }
218        }
219    }
220
221    fn ordered_expand(
222        &self,
223        sa: StateId,
224        sb: StateId,
225        match_input: bool,
226        mut compose_filter: CFB::CF,
227        selector: Selector,
228    ) -> Result<TrsVec<W>> {
229        let tr_loop = if match_input {
230            Tr::new(EPS_LABEL, NO_LABEL, W::one(), sb)
231        } else {
232            Tr::new(NO_LABEL, EPS_LABEL, W::one(), sb)
233        };
234        let mut trs = vec![];
235
236        match selector {
237            Selector::Fst1Matcher2 => {
238                self.match_tr(
239                    sa,
240                    &tr_loop,
241                    match_input,
242                    &mut compose_filter,
243                    selector,
244                    &mut trs,
245                )?;
246                for tr in self.fst1.borrow().get_trs(sb)?.trs() {
247                    self.match_tr(sa, tr, match_input, &mut compose_filter, selector, &mut trs)?;
248                }
249            }
250            Selector::Fst2Matcher1 => {
251                self.match_tr(
252                    sa,
253                    &tr_loop,
254                    match_input,
255                    &mut compose_filter,
256                    selector,
257                    &mut trs,
258                )?;
259                for tr in self.fst2.borrow().get_trs(sb)?.trs() {
260                    self.match_tr(sa, tr, match_input, &mut compose_filter, selector, &mut trs)?;
261                }
262            }
263        }
264        Ok(TrsVec(Arc::new(trs)))
265    }
266
267    fn add_tr(
268        &self,
269        mut arc1: Tr<W>,
270        arc2: Tr<W>,
271        fs: <CFB::CF as ComposeFilter<W, F1, F2, B1, B2, CFB::IM1, CFB::IM2>>::FS,
272    ) -> Result<Tr<W>> {
273        let tuple = ComposeStateTuple {
274            fs,
275            s1: arc1.nextstate,
276            s2: arc2.nextstate,
277        };
278        arc1.weight.times_assign(arc2.weight)?;
279        Ok(Tr::new(
280            arc1.ilabel,
281            arc2.olabel,
282            arc1.weight,
283            self.compose_state.state_table.find_id(tuple),
284        ))
285    }
286
287    fn match_tr_selected(
288        &self,
289        sa: StateId,
290        tr: &Tr<W>,
291        match_input: bool,
292        compose_filter: &mut CFB::CF,
293        it: impl Iterator<Item = IterItemMatcher<W>>,
294        trs: &mut Vec<Tr<W>>,
295    ) -> Result<()> {
296        let match_type = if match_input {
297            MatchType::MatchInput
298        } else {
299            MatchType::MatchOutput
300        };
301        for arca in it {
302            let mut arca = arca.into_tr(sa, match_type)?;
303            let mut arcb = tr.clone();
304            if match_input {
305                let fs = compose_filter.filter_tr(&mut arcb, &mut arca)?;
306                if fs
307                    != <CFB::CF as ComposeFilter<W, F1, F2, B1, B2, CFB::IM1, CFB::IM2>>::FS::new_no_state()
308                {
309                    trs.push(self.add_tr(arcb, arca, fs)?);
310                }
311            } else {
312                let fs = compose_filter.filter_tr(&mut arca, &mut arcb)?;
313
314                if fs
315                    != <CFB::CF as ComposeFilter<W, F1, F2, B1, B2, CFB::IM1, CFB::IM2>>::FS::new_no_state()
316                {
317                    trs.push(self.add_tr(arca, arcb, fs)?);
318                }
319            }
320        }
321        Ok(())
322    }
323
324    fn match_tr(
325        &self,
326        sa: StateId,
327        tr: &Tr<W>,
328        match_input: bool,
329        compose_filter: &mut CFB::CF,
330        selector: Selector,
331        trs: &mut Vec<Tr<W>>,
332    ) -> Result<()> {
333        let label = if match_input { tr.olabel } else { tr.ilabel };
334
335        match selector {
336            Selector::Fst2Matcher1 => self.match_tr_selected(
337                sa,
338                tr,
339                match_input,
340                compose_filter,
341                compose_filter.matcher1().iter(sa, label)?,
342                trs,
343            ),
344            Selector::Fst1Matcher2 => self.match_tr_selected(
345                sa,
346                tr,
347                match_input,
348                compose_filter,
349                compose_filter.matcher2().iter(sa, label)?,
350                trs,
351            ),
352        }
353    }
354}
355
356impl<W, F1, F2, B1, B2, M1, M2, CFB> AccessibleOpState
357    for ComposeFstOp<W, F1, F2, B1, B2, M1, M2, CFB>
358where
359    W: Semiring,
360    F1: Fst<W>,
361    F2: Fst<W>,
362    B1: Borrow<F1> + Debug + Clone,
363    B2: Borrow<F2> + Debug + Clone,
364    M1: Matcher<W, F1, B1>,
365    M2: Matcher<W, F2, B2>,
366    CFB: ComposeFilterBuilder<W, F1, F2, B1, B2, M1, M2>,
367    <CFB::CF as ComposeFilter<W, F1, F2, B1, B2, CFB::IM1, CFB::IM2>>::FS: SerializeBinary,
368{
369    type FstOpState = ComposeFstOpState<
370        ComposeStateTuple<<CFB::CF as ComposeFilter<W, F1, F2, B1, B2, CFB::IM1, CFB::IM2>>::FS>,
371    >;
372
373    fn get_op_state(&self) -> &Self::FstOpState {
374        &self.compose_state
375    }
376}
377
378impl<W, F1, F2, B1, B2, M1, M2, CFB> FstOp<W> for ComposeFstOp<W, F1, F2, B1, B2, M1, M2, CFB>
379where
380    W: Semiring,
381    F1: Fst<W>,
382    F2: Fst<W>,
383    B1: Borrow<F1> + Debug + Clone,
384    B2: Borrow<F2> + Debug + Clone,
385    M1: Matcher<W, F1, B1>,
386    M2: Matcher<W, F2, B2>,
387    CFB: ComposeFilterBuilder<W, F1, F2, B1, B2, M1, M2>,
388{
389    fn compute_start(&self) -> Result<Option<StateId>> {
390        let compose_filter = self.compose_filter_builder.build()?;
391        let s1 = self.fst1.borrow().start();
392        if s1.is_none() {
393            return Ok(None);
394        }
395        let s1 = s1.unwrap();
396        let s2 = self.fst2.borrow().start();
397        if s2.is_none() {
398            return Ok(None);
399        }
400        let s2 = s2.unwrap();
401        let fs = compose_filter.start();
402        let tuple = ComposeStateTuple { fs, s1, s2 };
403        Ok(Some(self.compose_state.state_table.find_id(tuple)))
404    }
405
406    fn compute_trs(&self, state: StateId) -> Result<TrsVec<W>> {
407        let tuple = self.compose_state.state_table.find_tuple(state);
408        let s1 = tuple.s1;
409        let s2 = tuple.s2;
410
411        let mut compose_filter = self.compose_filter_builder.build()?;
412        compose_filter.set_state(s1, s2, &tuple.fs)?;
413        if self.match_input(s1, s2, &compose_filter)? {
414            self.ordered_expand(s2, s1, true, compose_filter, Selector::Fst1Matcher2)
415        } else {
416            self.ordered_expand(s1, s2, false, compose_filter, Selector::Fst2Matcher1)
417        }
418    }
419
420    fn compute_final_weight(&self, state: StateId) -> Result<Option<W>> {
421        let tuple = self.compose_state.state_table.find_tuple(state);
422
423        // Construct a new ComposeFilter each time to avoid mutating the internal state.
424        let mut compose_filter = self.compose_filter_builder.build()?;
425
426        let s1 = tuple.s1;
427        let final1 = compose_filter.matcher1().final_weight(s1)?;
428        if final1.is_none() {
429            return Ok(None);
430        }
431        let mut final1 = final1.unwrap();
432
433        let s2 = tuple.s2;
434        let final2 = compose_filter.matcher2().final_weight(s2)?;
435        if final2.is_none() {
436            return Ok(None);
437        }
438        let mut final2 = final2.unwrap();
439
440        compose_filter.set_state(s1, s2, &tuple.fs)?;
441        compose_filter.filter_final(&mut final1, &mut final2)?;
442
443        final1.times_assign(&final2)?;
444        if final1.is_zero() {
445            Ok(None)
446        } else {
447            Ok(Some(final1))
448        }
449    }
450
451    fn properties(&self) -> FstProperties {
452        self.properties
453    }
454}