rustfst/algorithms/
push.rs

1use anyhow::Result;
2
3use bitflags::bitflags;
4
5use crate::algorithms::factor_weight::factor_iterators::{GallicFactorLeft, GallicFactorRight};
6use crate::algorithms::factor_weight::{factor_weight, FactorWeightOptions, FactorWeightType};
7use crate::algorithms::fst_convert::fst_convert_from_ref;
8use crate::algorithms::tr_mappers::RmWeightMapper;
9use crate::algorithms::weight_converters::{FromGallicConverter, ToGallicConverter};
10use crate::algorithms::{
11    reweight, shortest_distance_with_config, tr_map, weight_convert, ReweightType,
12    ShortestDistanceConfig,
13};
14use crate::fst_impls::VectorFst;
15use crate::fst_traits::{AllocableFst, ExpandedFst, MutableFst};
16use crate::semirings::{DivideType, Semiring};
17use crate::semirings::{
18    GallicWeightLeft, GallicWeightRight, StringWeightLeft, StringWeightRight,
19    WeaklyDivisibleSemiring, WeightQuantize,
20};
21use crate::{StateId, KDELTA};
22
23bitflags! {
24    /// Configuration to control the behaviour of the pushing algorithm.
25    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
26    pub struct PushType: u32 {
27        const PUSH_WEIGHTS = 0b01;
28        const PUSH_LABELS = 0b10;
29        const REMOVE_TOTAL_WEIGHT = 0b100;
30        const REMOVE_COMMON_AFFIX = 0b1000;
31    }
32}
33
/// Options accepted by [`push_weights_with_config`].
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub struct PushWeightsConfig {
    /// Delta forwarded to the shortest-distance computation.
    delta: f32,
    /// When `true`, the total weight is divided out of the FST after reweighting.
    remove_total_weight: bool,
}
40
41impl Default for PushWeightsConfig {
42    fn default() -> Self {
43        Self {
44            delta: KDELTA,
45            remove_total_weight: false,
46        }
47    }
48}
49
50impl PushWeightsConfig {
51    pub fn new(delta: f32, remove_total_weight: bool) -> Self {
52        Self {
53            delta,
54            remove_total_weight,
55        }
56    }
57
58    pub fn with_delta(self, delta: f32) -> Self {
59        Self { delta, ..self }
60    }
61
62    pub fn with_remove_total_weight(self, remove_total_weight: bool) -> Self {
63        Self {
64            remove_total_weight,
65            ..self
66        }
67    }
68}
69
70/// Push the weights in an FST.
71///
72/// If pushing towards the initial state, the sum of the weight of the
73/// outgoing transitions and final weight at a non-initial state is
74/// equal to One() in the resulting machine. If pushing towards the
75/// final state, the same property holds on the reverse machine.
76pub fn push_weights<W, F>(fst: &mut F, reweight_type: ReweightType) -> Result<()>
77where
78    F: MutableFst<W>,
79    W: WeaklyDivisibleSemiring,
80{
81    push_weights_with_config(fst, reweight_type, PushWeightsConfig::default())
82}
83
84/// Push the weights in an FST, optionally removing the total weight.
85///
86/// If pushing towards the initial state, the sum of the weight of the
87/// outgoing transitions and final weight at a non-initial state is
88/// equal to One() in the resulting machine. If pushing towards the
89/// final state, the same property holds on the reverse machine.
90pub fn push_weights_with_config<W, F>(
91    fst: &mut F,
92    reweight_type: ReweightType,
93    config: PushWeightsConfig,
94) -> Result<()>
95where
96    F: MutableFst<W>,
97    W: WeaklyDivisibleSemiring,
98{
99    let remove_total_weight = config.remove_total_weight;
100    let delta = config.delta;
101    let dist = shortest_distance_with_config(
102        fst,
103        reweight_type == ReweightType::ReweightToInitial,
104        ShortestDistanceConfig::new(delta),
105    )?;
106
107    if remove_total_weight {
108        let total_weight =
109            compute_total_weight(fst, &dist, reweight_type == ReweightType::ReweightToInitial)?;
110        reweight(fst, &dist, reweight_type)?;
111        remove_weight(
112            fst,
113            total_weight,
114            reweight_type == ReweightType::ReweightToFinal,
115        )?;
116    } else {
117        reweight(fst, &dist, reweight_type)?;
118    }
119    Ok(())
120}
121
122fn compute_total_weight<W, F>(fst: &F, dist: &[W], reverse: bool) -> Result<W>
123where
124    W: Semiring,
125    F: ExpandedFst<W>,
126{
127    if reverse {
128        Ok(fst
129            .start()
130            .and_then(|start| dist.get(start as usize))
131            .cloned()
132            .unwrap_or_else(W::zero))
133    } else {
134        let mut sum = W::zero();
135        for (s, dist_s) in dist.iter().enumerate() {
136            sum.plus_assign(dist_s.times(
137                unsafe { fst.final_weight_unchecked(s as StateId) }.unwrap_or_else(W::zero),
138            )?)?;
139        }
140        Ok(sum)
141    }
142}
143
144fn remove_weight<W, F>(fst: &mut F, weight: W, at_final: bool) -> Result<()>
145where
146    F: MutableFst<W>,
147    W: WeaklyDivisibleSemiring,
148{
149    if weight.is_one() || weight.is_zero() {
150        return Ok(());
151    }
152    if at_final {
153        unsafe {
154            for s in fst.states_range() {
155                if let Some(mut final_weight) = fst.final_weight_unchecked(s) {
156                    final_weight.divide_assign(&weight, DivideType::DivideRight)?;
157                    fst.set_final_unchecked(s, final_weight);
158                }
159            }
160        }
161    } else if let Some(start) = fst.start() {
162        unsafe {
163            let mut it_tr = fst.tr_iter_unchecked_mut(start);
164            for idx_tr in 0..it_tr.len() {
165                let tr = it_tr.get_unchecked(idx_tr);
166                let weight = tr.weight.divide(&weight, DivideType::DivideLeft)?;
167                it_tr.set_weight_unchecked(idx_tr, weight);
168            }
169            if let Some(mut final_weight) = fst.final_weight_unchecked(start) {
170                final_weight.divide_assign(&weight, DivideType::DivideLeft)?;
171                fst.set_final_unchecked(start, final_weight);
172            }
173        }
174    }
175    Ok(())
176}
177
// Label pushing (optionally combined with weight pushing) through the gallic
// semiring. Expands to an expression evaluating to the `Result` of the final
// `weight_convert` call; it uses `?` internally and refers to the weight type
// `W` of the calling function, so it must be expanded inside a function that
// returns a compatible `Result` and has `W` in scope.
//
// Arguments: input FST, reweight type, push flags, delta (all identifiers),
// then the gallic weight type, the string weight type name and the gallic
// factor iterator type matching the pushing direction.
macro_rules! m_labels_pushing {
    ($ifst: ident, $reweight_type: ident, $push_type: ident, $delta: ident, $gallic_weight: ty, $string_weight: ident, $gallic_factor: ty) => {{
        let to_initial = $reweight_type == ReweightType::ReweightToInitial;
        let with_weights = $push_type.intersects(PushType::PUSH_WEIGHTS);
        let remove_affix = $push_type.intersects(PushType::REMOVE_COMMON_AFFIX);
        let remove_total = $push_type.intersects(PushType::REMOVE_TOTAL_WEIGHT);

        // Move the input into the gallic semiring.
        let mut to_gallic = ToGallicConverter {};
        let mut gfst: VectorFst<$gallic_weight> = weight_convert($ifst, &mut to_gallic)?;

        let gdistance = if with_weights {
            shortest_distance_with_config(
                &gfst,
                to_initial,
                ShortestDistanceConfig::new($delta),
            )?
        } else {
            // Labels only: distances are computed on a copy whose weights
            // have been removed before the gallic conversion.
            let weight_remover = RmWeightMapper {};
            let mut unweighted: VectorFst<_> = fst_convert_from_ref($ifst);
            tr_map(&mut unweighted, &weight_remover)?;
            let gallic_unweighted: VectorFst<$gallic_weight> =
                weight_convert(&unweighted, &mut to_gallic)?;
            shortest_distance_with_config(
                &gallic_unweighted,
                to_initial,
                ShortestDistanceConfig::new($delta),
            )?
        };

        // The weight to divide out must be read before `gfst` is reweighted.
        let weight_to_remove = if remove_affix || remove_total {
            let mut total_weight = compute_total_weight(&gfst, &gdistance, to_initial)?;
            if !remove_affix {
                total_weight.set_value1($string_weight::one());
            }
            if !remove_total {
                total_weight.set_value2(W::one());
            }
            Some(total_weight)
        } else {
            None
        };

        reweight(&mut gfst, gdistance.as_slice(), $reweight_type)?;

        if let Some(total_weight) = weight_to_remove {
            remove_weight(
                &mut gfst,
                total_weight,
                $reweight_type == ReweightType::ReweightToFinal,
            )?;
        }

        // Factor the gallic weights, then convert back to the original semiring.
        let factor_opts = FactorWeightOptions::new(
            FactorWeightType::FACTOR_FINAL_WEIGHTS | FactorWeightType::FACTOR_ARC_WEIGHTS,
        );
        let factored: VectorFst<$gallic_weight> =
            factor_weight::<_, VectorFst<$gallic_weight>, _, _, $gallic_factor>(
                &gfst,
                factor_opts,
            )?;
        let mut from_gallic = FromGallicConverter {
            superfinal_label: 0,
        };
        weight_convert(&factored, &mut from_gallic)
    }};
}
234
/// Options accepted by [`push_with_config`].
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub struct PushConfig {
    /// Delta forwarded to the shortest-distance computations.
    delta: f32,
}
240
241impl Default for PushConfig {
242    fn default() -> Self {
243        Self { delta: KDELTA }
244    }
245}
246
247impl PushConfig {
248    pub fn new(delta: f32) -> Self {
249        Self { delta }
250    }
251
252    pub fn with_delta(self, delta: f32) -> Self {
253        Self { delta }
254    }
255}
256
257/// Push the weights and/or labels of the input FST into the output
258/// mutable FST by pushing weights and/or labels towards the initial state or final states.
259pub fn push<W, F1, F2>(ifst: &F1, reweight_type: ReweightType, push_type: PushType) -> Result<F2>
260where
261    F1: ExpandedFst<W>,
262    F2: ExpandedFst<W> + MutableFst<W> + AllocableFst<W>,
263    W: WeaklyDivisibleSemiring + WeightQuantize,
264    <W as Semiring>::ReverseWeight: 'static,
265{
266    push_with_config(ifst, reweight_type, push_type, PushConfig::default())
267}
268
269/// Push the weights and/or labels of the input FST into the output
270/// mutable FST by pushing weights and/or labels towards the initial state or final states.
271pub fn push_with_config<W, F1, F2>(
272    ifst: &F1,
273    reweight_type: ReweightType,
274    push_type: PushType,
275    config: PushConfig,
276) -> Result<F2>
277where
278    F1: ExpandedFst<W>,
279    F2: ExpandedFst<W> + MutableFst<W> + AllocableFst<W>,
280    W: WeaklyDivisibleSemiring + WeightQuantize,
281    <W as Semiring>::ReverseWeight: 'static,
282{
283    let delta = config.delta;
284    if push_type.intersects(PushType::PUSH_WEIGHTS) && !push_type.intersects(PushType::PUSH_LABELS)
285    {
286        // Only weights pushing
287        let mut ofst = fst_convert_from_ref(ifst);
288        let push_weights_config =
289            PushWeightsConfig::new(delta, push_type.intersects(PushType::REMOVE_TOTAL_WEIGHT));
290        push_weights_with_config(&mut ofst, reweight_type, push_weights_config)?;
291        Ok(ofst)
292    } else if push_type.intersects(PushType::PUSH_LABELS) {
293        match reweight_type {
294            ReweightType::ReweightToInitial => m_labels_pushing!(
295                ifst,
296                reweight_type,
297                push_type,
298                delta,
299                GallicWeightLeft<W>,
300                StringWeightLeft,
301                GallicFactorLeft<W>
302            ),
303            ReweightType::ReweightToFinal => m_labels_pushing!(
304                ifst,
305                reweight_type,
306                push_type,
307                delta,
308                GallicWeightRight<W>,
309                StringWeightRight,
310                GallicFactorRight<W>
311            ),
312        }
313    } else {
314        // NO Labels/Weights pushing
315        Ok(fst_convert_from_ref(ifst))
316    }
317}