rustfst/algorithms/factor_weight/config.rs

1use bitflags::bitflags;
2
3use crate::{Label, KDELTA};
4
5bitflags! {
6    /// What kind of weight should be factored ? Tr weight ? Final weights ?
7    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
8    pub struct FactorWeightType: u32 {
9        /// Factor weights located on the Trs.
10        const FACTOR_FINAL_WEIGHTS = 0b01;
11        /// Factor weights located in the final states.
12        const FACTOR_ARC_WEIGHTS = 0b10;
13    }
14}
15
#[cfg(test)]
impl FactorWeightType {
    /// Test helper that assembles a mode from two independent switches.
    ///
    /// * `factor_final_weights` — when `true`, the result contains
    ///   `FactorWeightType::FACTOR_FINAL_WEIGHTS`.
    /// * `factor_tr_weights` — when `true`, the result contains
    ///   `FactorWeightType::FACTOR_ARC_WEIGHTS`.
    ///
    /// When both switches are `false`, the empty flag set is returned.
    pub fn from_bools(factor_final_weights: bool, factor_tr_weights: bool) -> FactorWeightType {
        let mut selected = Self::empty();
        selected.set(Self::FACTOR_FINAL_WEIGHTS, factor_final_weights);
        selected.set(Self::FACTOR_ARC_WEIGHTS, factor_tr_weights);
        selected
    }
}
29
30/// Configuration to control the behaviour of the `factor_weight` algorithm.
31#[derive(Clone, Debug, PartialEq)]
32pub struct FactorWeightOptions {
33    /// Quantization delta
34    pub delta: f32,
35    /// Factor transition weights and/or final weights
36    pub mode: FactorWeightType,
37    /// Input label of transition when factoring final weights.
38    pub final_ilabel: Label,
39    /// Output label of transition when factoring final weights.
40    pub final_olabel: Label,
41    /// When factoring final w' results in > 1 trs at state, increments ilabels to make distinct ?
42    pub increment_final_ilabel: bool,
43    /// When factoring final w' results in > 1 trs at state, increments olabels to make distinct ?
44    pub increment_final_olabel: bool,
45}
46
47impl FactorWeightOptions {
48    #[allow(unused)]
49    pub fn new(mode: FactorWeightType) -> FactorWeightOptions {
50        FactorWeightOptions {
51            delta: KDELTA,
52            mode,
53            final_ilabel: 0,
54            final_olabel: 0,
55            increment_final_ilabel: false,
56            increment_final_olabel: false,
57        }
58    }
59}