1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
use anyhow::Result;

use crate::algorithms::{FinalTr, MapFinalAction};
use crate::fst_properties::FstProperties;
use crate::fst_traits::{AllocableFst, ExpandedFst, MutableFst};
use crate::semirings::Semiring;
use crate::{Tr, Trs, EPS_LABEL};
use unsafe_unwrap::UnsafeUnwrap;

/// Describes how the weights of an FST expressed in the semiring `SI` are turned into
/// weights of the semiring `SO`.
///
/// Useful for changing the semiring of an FST, see [`weight_convert`].
pub trait WeightConverter<SI: Semiring, SO: Semiring> {
    /// Maps one transition of the input FST to a transition of the output FST.
    ///
    /// [`weight_convert`] calls it once for every transition of every input state.
    fn tr_map(&mut self, tr: &Tr<SI>) -> Result<Tr<SO>>;

    /// Maps the final weight of a state, presented as an epsilon-labelled [`FinalTr`].
    ///
    /// [`weight_convert`] calls it once for every state of the input FST that has a final
    /// weight. What happens to the result is governed by [`WeightConverter::final_action`].
    fn final_tr_map(&mut self, final_tr: &FinalTr<SI>) -> Result<FinalTr<SO>>;

    /// Selects how mapped final transitions are materialised in the output FST
    /// (plain final weights, optional superfinal state, or mandatory superfinal state).
    fn final_action(&self) -> MapFinalAction;

    /// Derives properties of the output FST from the properties `iprops` of the input FST.
    ///
    /// [`weight_convert`] combines (bitwise OR) the returned value with the properties it
    /// computes on the output FST.
    fn properties(&self, iprops: FstProperties) -> FstProperties;
}

/// Convert an FST in a given Semiring to another Semiring using a WeightConverter
/// to specify how the conversion should be performed.
///
/// Every state of `fst_in` is copied to the output with the same id. Every transition is
/// mapped through [`WeightConverter::tr_map`]. The final weight of each final state is
/// wrapped into an epsilon-labelled [`FinalTr`] and mapped through
/// [`WeightConverter::final_tr_map`]. How the mapped final transition ends up in the
/// output depends on [`WeightConverter::final_action`]:
///
/// - `MapNoSuperfinal`: the mapped weight becomes the state's final weight. The mapped
///   transition must keep epsilon labels.
/// - `MapAllowSuperfinal`: if the mapped transition carries a non-epsilon label, a
///   superfinal state (final weight one) is created on first need. The state is then
///   linked to it by a transition instead of being final. Otherwise the mapped weight
///   becomes the state's final weight.
/// - `MapRequireSuperfinal`: a superfinal state (final weight one) is always created. Each
///   final input state gets a transition to it, unless the mapped transition has epsilon
///   labels and a zero weight. The superfinal state is the only final state of the output.
///
/// If `fst_in` has no start state, an empty FST is returned.
///
/// The output properties are `mapper.properties(<input properties>)` OR'ed with the
/// properties computed on the output FST.
///
/// # Errors
///
/// Returns an error if `MapNoSuperfinal` is used and the mapper produces a final
/// transition with a non-epsilon label. Also propagates any error returned by the mapper
/// or by the FST operations.
pub fn weight_convert<W1, W2, F1, F2, M>(fst_in: &F1, mapper: &mut M) -> Result<F2>
where
    W1: Semiring,
    W2: Semiring,
    F1: ExpandedFst<W1>,
    F2: MutableFst<W2> + AllocableFst<W2>,
    M: WeightConverter<W1, W2>,
{
    let iprops = fst_in.properties();
    let mut fst_out = F2::new();
    let final_action = mapper.final_action();

    // Empty FST: without a start state there is nothing to convert.
    let start_state = match fst_in.start() {
        Some(s) => s,
        None => return Ok(fst_out),
    };

    // Reserve enough space for all the states (plus a possible superfinal state) to avoid
    // re-allocations.
    let mut num_states_needed = fst_in.num_states();
    if final_action != MapFinalAction::MapNoSuperfinal {
        num_states_needed += 1;
    }
    fst_out.reserve_states(num_states_needed);

    // Mirror the input states so that state ids are identical in both FSTs.
    for _ in fst_in.states_iter() {
        fst_out.add_state();
    }

    // With MapRequireSuperfinal the superfinal state exists even if no input state is
    // final. With MapAllowSuperfinal it is only created when first needed (see below).
    let mut superfinal = None;
    if final_action == MapFinalAction::MapRequireSuperfinal {
        let superfinal_id = fst_out.add_state();
        fst_out.set_final(superfinal_id, W2::one())?;
        superfinal = Some(superfinal_id);
    }

    fst_out.set_start(start_state)?;

    for state in fst_in.states_iter() {
        fst_out.reserve_trs(state, fst_in.num_trs(state)?)?;
        for tr in fst_in.get_trs(state)?.trs() {
            fst_out.add_tr(state, mapper.tr_map(tr)?)?;
        }

        // SAFETY: `state` is yielded by `fst_in.states_iter()`, so it is an existing state
        // of `fst_in`.
        if let Some(w) = unsafe { fst_in.final_weight_unchecked(state) } {
            let final_tr = FinalTr {
                ilabel: EPS_LABEL,
                olabel: EPS_LABEL,
                weight: w.clone(),
            };
            let mapped_final_tr = mapper.final_tr_map(&final_tr)?;
            match final_action {
                MapFinalAction::MapNoSuperfinal => {
                    // A final weight cannot carry labels and there is no superfinal state
                    // to route them to.
                    if mapped_final_tr.ilabel != EPS_LABEL || mapped_final_tr.olabel != EPS_LABEL {
                        bail!("TrMap: Non-zero tr labels for superfinal tr")
                    }

                    fst_out.set_final(state, mapped_final_tr.weight)?;
                }
                MapFinalAction::MapAllowSuperfinal => {
                    if mapped_final_tr.ilabel != EPS_LABEL || mapped_final_tr.olabel != EPS_LABEL {
                        // Labels must be emitted by a transition: route them to the
                        // superfinal state, creating it on first need.
                        let superfinal_id = match superfinal {
                            Some(id) => id,
                            None => {
                                let id = fst_out.add_state();
                                fst_out.set_final(id, W2::one())?;
                                superfinal = Some(id);
                                id
                            }
                        };

                        fst_out.add_tr(
                            state,
                            Tr::new(
                                mapped_final_tr.ilabel,
                                mapped_final_tr.olabel,
                                mapped_final_tr.weight,
                                superfinal_id,
                            ),
                        )?;

                        fst_out.delete_final_weight(state)?;
                    } else {
                        fst_out.set_final(state, mapped_final_tr.weight)?;
                    }
                }
                MapFinalAction::MapRequireSuperfinal => {
                    // An epsilon-labelled, zero-weight final transition contributes
                    // nothing, so no transition is added for it.
                    if mapped_final_tr.ilabel != EPS_LABEL
                        || mapped_final_tr.olabel != EPS_LABEL
                        || !mapped_final_tr.weight.is_zero()
                    {
                        let superfinal_id = superfinal.expect(
                            "superfinal state is created before the loop for MapRequireSuperfinal",
                        );
                        fst_out.add_tr(
                            state,
                            Tr::new(
                                mapped_final_tr.ilabel,
                                mapped_final_tr.olabel,
                                mapped_final_tr.weight,
                                superfinal_id,
                            ),
                        )?;
                    }
                    // Only the superfinal state may remain final.
                    fst_out.delete_final_weight(state)?;
                }
            }
        }
    }

    let oprops = fst_out.properties();
    fst_out.set_properties_with_mask(
        mapper.properties(iprops) | oprops,
        FstProperties::all_properties(),
    );

    Ok(fst_out)
}