rustfst/algorithms/
tr_map.rs

1use anyhow::Result;
2use std::ops::Deref;
3
4use crate::fst_properties::FstProperties;
5use crate::fst_traits::MutableFst;
6use crate::semirings::Semiring;
7use crate::Tr;
8use crate::{Label, StateId, EPS_LABEL};
9
10/// Struct used to map final weights when performing a transition mapping.
11/// It will always be of the form `(EPS_LABEL, EPS_LABEL, final_weight)`
12/// where `final_weight` is the `final_weight` of the current state.
13///
14/// If the mapper modifies the input label or output one,
15/// a super final state will need to be created.
16#[derive(Clone, Debug)]
17pub struct FinalTr<W: Semiring> {
18    /// Input label. Default to `EPS_LABEL`.
19    pub ilabel: Label,
20    /// Output label. Default to `EPS_LABEL`.
21    pub olabel: Label,
22    /// Weight. Default to the final weight of the current state.
23    pub weight: W,
24}
25
/// Determines how final weights are mapped.
///
/// The enum is a plain, field-less tag, so it is `Copy` and can be compared,
/// printed and passed around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MapFinalAction {
    /// A final weight is mapped into a final weight. An error is raised if this
    /// is not possible.
    MapNoSuperfinal,
    /// A final weight is mapped to a transition to the superfinal state when the
    /// result cannot be represented as a final weight. The superfinal state will
    /// be added only if it is needed.
    MapAllowSuperfinal,
    /// A final weight is mapped to a transition to the superfinal state unless the
    /// result can be represented as a final weight of weight Zero(). The
    /// superfinal state is always added (if the input is not the empty FST).
    MapRequireSuperfinal,
}
41
42/// The TrMapper interfaces defines how trs and final weights are mapped.
43/// This is useful for implementing operations that do not change the number of
44/// trs.
45pub trait TrMapper<S: Semiring> {
46    /// How to modify the trs.
47    fn tr_map(&self, tr: &mut Tr<S>) -> Result<()>;
48
49    /// The mapper will be passed final weights as trs of the form
50    /// `FinalTr(EPS_LABEL, EPS_LABEL, weight)`.
51    fn final_tr_map(&self, final_tr: &mut FinalTr<S>) -> Result<()>;
52
53    /// Specifies final action the mapper requires (see above).
54    fn final_action(&self) -> MapFinalAction;
55
56    fn properties(&self, inprops: FstProperties) -> FstProperties;
57}
58
59impl<S: Semiring, T: TrMapper<S>, TP: Deref<Target = T>> TrMapper<S> for TP {
60    fn tr_map(&self, tr: &mut Tr<S>) -> Result<()> {
61        self.deref().tr_map(tr)
62    }
63
64    fn final_tr_map(&self, final_tr: &mut FinalTr<S>) -> Result<()> {
65        self.deref().final_tr_map(final_tr)
66    }
67
68    fn final_action(&self) -> MapFinalAction {
69        self.deref().final_action()
70    }
71
72    fn properties(&self, inprops: FstProperties) -> FstProperties {
73        self.deref().properties(inprops)
74    }
75}
76
77/// Maps every transition in the FST using an `TrMapper` object.
78pub fn tr_map<W, F, M>(ifst: &mut F, mapper: &M) -> Result<()>
79where
80    W: Semiring,
81    F: MutableFst<W>,
82    M: TrMapper<W>,
83{
84    if ifst.start().is_none() {
85        return Ok(());
86    }
87
88    let inprops = ifst.properties();
89
90    let final_action = mapper.final_action();
91    let mut superfinal: Option<StateId> = None;
92
93    if final_action == MapFinalAction::MapRequireSuperfinal {
94        let superfinal_id = ifst.add_state();
95        superfinal = Some(superfinal_id);
96        ifst.set_final(superfinal_id, W::one()).unwrap();
97    }
98
99    for state in 0..(ifst.num_states() as StateId) {
100        unsafe {
101            let mut it_tr = ifst.tr_iter_unchecked_mut(state);
102            for idx_tr in 0..it_tr.len() {
103                let mut tr = it_tr.get_unchecked(idx_tr).clone();
104                mapper.tr_map(&mut tr)?;
105                it_tr.set_tr_unchecked(idx_tr, tr);
106            }
107        }
108
109        if let Some(w) = unsafe { ifst.final_weight_unchecked(state) } {
110            let mut final_tr = FinalTr {
111                ilabel: EPS_LABEL,
112                olabel: EPS_LABEL,
113                weight: w,
114            };
115            mapper.final_tr_map(&mut final_tr)?;
116            match final_action {
117                MapFinalAction::MapNoSuperfinal => {
118                    if final_tr.ilabel != EPS_LABEL || final_tr.olabel != EPS_LABEL {
119                        bail!("TrMap: Non-zero tr labels for superfinal tr")
120                    }
121                    unsafe {
122                        ifst.set_final_unchecked(state, final_tr.weight);
123                    }
124                }
125                MapFinalAction::MapAllowSuperfinal => {
126                    if Some(state) != superfinal {
127                        if final_tr.ilabel != EPS_LABEL || final_tr.olabel != EPS_LABEL {
128                            if superfinal.is_none() {
129                                let superfinal_id = ifst.add_state();
130                                superfinal = Some(superfinal_id);
131                                unsafe {
132                                    // Checked because the state is created just above
133                                    ifst.set_final_unchecked(superfinal_id, W::one());
134                                }
135                            }
136                            unsafe {
137                                // Checked
138                                ifst.add_tr_unchecked(
139                                    state,
140                                    Tr::new(
141                                        final_tr.ilabel,
142                                        final_tr.olabel,
143                                        final_tr.weight,
144                                        superfinal.unwrap(), // Checked
145                                    ),
146                                );
147                                ifst.delete_final_weight_unchecked(state);
148                            }
149                        } else {
150                            unsafe {
151                                // Checked
152                                ifst.set_final_unchecked(state, final_tr.weight);
153                            }
154                        }
155                    }
156                }
157                MapFinalAction::MapRequireSuperfinal => {
158                    if Some(state) != superfinal
159                        && (final_tr.ilabel != EPS_LABEL
160                            || final_tr.olabel != EPS_LABEL
161                            || !final_tr.weight.is_zero())
162                    {
163                        unsafe {
164                            // checked
165                            ifst.add_tr_unchecked(
166                                state,
167                                Tr::new(
168                                    final_tr.ilabel,
169                                    final_tr.olabel,
170                                    final_tr.weight,
171                                    superfinal.unwrap(),
172                                ),
173                            );
174                            ifst.delete_final_weight_unchecked(state);
175                        }
176                    }
177                }
178            };
179        }
180    }
181
182    ifst.set_properties_with_mask(mapper.properties(inprops), FstProperties::all_properties());
183
184    Ok(())
185}