1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
use unsafe_unwrap::UnsafeUnwrap;

use crate::fst_traits::MutableFst;
use crate::semirings::Semiring;
use crate::{StateId, Tr, EPS_LABEL};

/// Ensures the FST has a super final state and returns its id.
///
/// # Definition
/// A super final state is the only final state of an FST, and its final weight is
/// `W::one()`.
///
/// # Behaviour
/// * If the FST has exactly one final state and its final weight is `W::one()`, the FST
///   is left untouched and that state is returned.
/// * Otherwise a new state is appended to the FST and made final with weight `W::one()`.
///   Every previously final state stops being final and instead receives an
///   epsilon-to-epsilon transition to the new state, weighted by its former final
///   weight. The new state is returned.
/// * In particular, when the FST has no final states at all, the new state is not
///   connected to any other state.
pub fn add_super_final_state<W: Semiring, F: MutableFst<W>>(ifst: &mut F) -> StateId {
    // Snapshot the final states up front: the loop below mutates final weights.
    let finals: Vec<_> = ifst.final_states_iter().collect();

    // Fast path: the FST already has a super final state.
    if let &[only_final] = finals.as_slice() {
        // SAFETY: `only_final` was yielded by `final_states_iter`, so it is a valid state id.
        let only_weight = unsafe { ifst.final_weight_unchecked(only_final) };
        if only_weight == Some(W::one()) {
            return only_final;
        }
    }

    let super_final_state = ifst.add_state();
    // SAFETY: `super_final_state` was just created by `add_state`, so it exists.
    unsafe { ifst.set_final_unchecked(super_final_state, W::one()) };

    for state in finals {
        // SAFETY: `state` was reported final by `final_states_iter` and its final weight
        // has not been taken yet (assuming the iterator yields each state at most once),
        // so the returned `Option` is `Some`.
        let weight = unsafe { ifst.take_final_weight_unchecked(state).unsafe_unwrap() };
        let tr = Tr {
            ilabel: EPS_LABEL,
            olabel: EPS_LABEL,
            weight,
            nextstate: super_final_state,
        };
        // SAFETY: `state` is a valid state id (it came from `final_states_iter`).
        unsafe { ifst.add_tr_unchecked(state, tr) };
    }

    super_final_state
}

#[cfg(test)]
mod tests {
    use anyhow::Result;

    use crate::fst_impls::VectorFst;
    use crate::fst_traits::{CoreFst, ExpandedFst};
    use crate::semirings::TropicalWeight;

    use super::*;

    /// Two final states: a new super final state is appended, both former final states
    /// lose their final weight and each gains exactly one transition to it.
    #[test]
    fn test_add_super_final_states() -> Result<()> {
        let mut fst = VectorFst::<TropicalWeight>::new();
        let s0 = fst.add_state();
        let s1 = fst.add_state();
        let s2 = fst.add_state();
        let s3 = fst.add_state();

        fst.set_start(s0)?;
        fst.emplace_tr(s0, 1, 0, 1.0, s1)?;
        fst.emplace_tr(s1, 1, 0, 1.0, s2)?;
        fst.emplace_tr(s1, 1, 0, 1.0, s3)?;

        fst.set_final(s2, 1.0)?;
        fst.set_final(s3, 1.0)?;

        let num_states = fst.num_states();

        let super_final_state = add_super_final_state(&mut fst);
        assert_eq!(num_states, super_final_state as usize);
        // Exactly one state must have been added.
        assert_eq!(num_states + 1, fst.num_states());
        assert!(!fst.is_final(s2)?);
        assert_eq!(1, fst.num_trs(s2)?);
        assert!(!fst.is_final(s3)?);
        assert_eq!(1, fst.num_trs(s3)?);
        assert_eq!(
            Some(TropicalWeight::one()),
            fst.final_weight(super_final_state)?
        );
        // The super final state is a sink.
        assert_eq!(0, fst.num_trs(super_final_state)?);
        Ok(())
    }

    /// A single final state with weight `one()` is already a super final state: the FST
    /// must be left untouched and that state returned.
    #[test]
    fn test_add_super_final_states_1() -> Result<()> {
        let mut fst = VectorFst::<TropicalWeight>::new();
        let s0 = fst.add_state();
        let s1 = fst.add_state();
        let s2 = fst.add_state();
        let s3 = fst.add_state();

        fst.set_start(s0)?;
        fst.emplace_tr(s0, 1, 0, 1.0, s1)?;
        fst.emplace_tr(s1, 1, 0, 1.0, s2)?;
        fst.emplace_tr(s2, 1, 0, 1.0, s3)?;

        fst.set_final(s3, TropicalWeight::one())?;

        let num_states = fst.num_states();

        let super_final_state = add_super_final_state(&mut fst);
        assert_eq!(s3, super_final_state);
        assert_eq!(
            Some(TropicalWeight::one()),
            fst.final_weight(super_final_state)?
        );
        // No state and no transition may have been added.
        assert_eq!(num_states, fst.num_states());
        assert_eq!(0, fst.num_trs(s3)?);
        Ok(())
    }

    /// A single final state whose weight is not `one()` still requires a new super
    /// final state.
    #[test]
    fn test_add_super_final_states_2() -> Result<()> {
        let mut fst = VectorFst::<TropicalWeight>::new();
        let s0 = fst.add_state();
        let s1 = fst.add_state();
        let s2 = fst.add_state();
        let s3 = fst.add_state();

        fst.set_start(s0)?;
        fst.emplace_tr(s0, 1, 0, 1.0, s1)?;
        fst.emplace_tr(s1, 1, 0, 1.0, s2)?;
        fst.emplace_tr(s2, 1, 0, 1.0, s3)?;

        fst.set_final(s3, 2.0)?;

        let num_states = fst.num_states();

        let super_final_state = add_super_final_state(&mut fst);
        assert_eq!(num_states, super_final_state as usize);
        assert_eq!(num_states + 1, fst.num_states());
        assert!(!fst.is_final(s3)?);
        assert_eq!(1, fst.num_trs(s3)?);
        assert_eq!(
            Some(TropicalWeight::one()),
            fst.final_weight(super_final_state)?
        );
        assert_eq!(0, fst.num_trs(super_final_state)?);
        Ok(())
    }

    /// No final states: an isolated super final state is added and every original
    /// state stays non-final.
    #[test]
    fn test_add_super_final_states_3() -> Result<()> {
        let mut fst = VectorFst::<TropicalWeight>::new();
        let s0 = fst.add_state();
        let s1 = fst.add_state();
        let s2 = fst.add_state();
        let s3 = fst.add_state();

        fst.set_start(s0)?;
        fst.emplace_tr(s0, 1, 0, 1.0, s1)?;
        fst.emplace_tr(s1, 1, 0, 1.0, s2)?;
        fst.emplace_tr(s2, 1, 0, 1.0, s3)?;

        let num_states = fst.num_states();

        let super_final_state = add_super_final_state(&mut fst);
        assert_eq!(num_states, super_final_state as usize);
        assert_eq!(num_states + 1, fst.num_states());
        assert_eq!(
            Some(TropicalWeight::one()),
            fst.final_weight(super_final_state)?
        );
        // The super final state is not connected to anything.
        assert_eq!(0, fst.num_trs(super_final_state)?);
        for &s in &[s0, s1, s2, s3] {
            assert!(!fst.is_final(s)?);
        }
        Ok(())
    }
}