use anyhow::Result;
use crate::algorithms::fst_convert_from_ref;
use crate::algorithms::tr_mappers::QuantizeMapper;
use crate::fst_traits::{AllocableFst, Fst, FstIntoIterator, MutableFst};
use crate::semirings::{Semiring, WeightQuantize};
use crate::{Trs, KDELTA};
/// Extension of [`Fst`] for FSTs that can report their total number of states.
///
/// Besides [`num_states`](ExpandedFst::num_states), it provides default
/// helpers for comparing and building FSTs with quantized weights.
pub trait ExpandedFst<W: Semiring>: Fst<W> + Clone + PartialEq + FstIntoIterator<W> {
    /// Returns the number of states in the FST.
    fn num_states(&self) -> usize;

    /// Returns `true` if `self` and `fst2` are equal once every weight is
    /// quantized with [`KDELTA`].
    ///
    /// The structure must match exactly:
    /// - the same number of states;
    /// - the same start state;
    /// - for every state, the same number of transitions in the same order,
    ///   with identical `ilabel`, `olabel` and `nextstate`.
    ///
    /// Transition weights and final weights are compared after quantization.
    ///
    /// Every time `false` is returned, a one-line reason is printed to stdout.
    /// A weight that cannot be quantized is reported and treated as a
    /// mismatch; it does not cause a panic.
    fn equal_quantized<F2: ExpandedFst<W>>(&self, fst2: &F2) -> bool
    where
        W: WeightQuantize,
    {
        let n = self.num_states();
        if fst2.num_states() != n {
            println!("Not the same number of states");
            return false;
        }
        if self.start() != fst2.start() {
            println!("Not the same start state");
            return false;
        }
        for state in 0..n {
            // SAFETY: `state < n`, and both FSTs report exactly `n` states
            // (checked above). This relies on the ExpandedFst convention that
            // state ids are `0..num_states()`.
            let owned_trs1 = unsafe { self.get_trs_unchecked(state) };
            // SAFETY: same bound as above, applied to `fst2`.
            let owned_trs2 = unsafe { fst2.get_trs_unchecked(state) };
            let trs1 = owned_trs1.trs();
            let trs2 = owned_trs2.trs();

            if trs1.len() != trs2.len() {
                println!("Not the same number of trs for state {:?}", state);
                return false;
            }
            for (tr1, tr2) in trs1.iter().zip(trs2.iter()) {
                if tr1.ilabel != tr2.ilabel
                    || tr1.olabel != tr2.olabel
                    || tr1.nextstate != tr2.nextstate
                {
                    println!(
                        "Not the same tr labels or nextstate for state {:?}",
                        state
                    );
                    return false;
                }
                match (tr1.weight.quantize(KDELTA), tr2.weight.quantize(KDELTA)) {
                    (Ok(w1), Ok(w2)) if w1 == w2 => {}
                    (Ok(_), Ok(_)) => {
                        println!("Not the same tr weight for state {:?}", state);
                        return false;
                    }
                    // Quantization failed on at least one side: the weights
                    // cannot be compared, so report it instead of panicking.
                    _ => {
                        println!("Failed to quantize a tr weight for state {:?}", state);
                        return false;
                    }
                }
            }

            // SAFETY: same bound as above; `state` is valid for both FSTs.
            let fw1 = unsafe { self.final_weight_unchecked(state) };
            // SAFETY: same bound as above, applied to `fst2`.
            let fw2 = unsafe { fst2.final_weight_unchecked(state) };
            // `Option<Result<_>>` -> `Result<Option<_>>`, so that a missing
            // final weight stays `Ok(None)` and only a quantization failure
            // becomes `Err`.
            let fw1 = fw1.map(|w| w.quantize(KDELTA)).transpose();
            let fw2 = fw2.map(|w| w.quantize(KDELTA)).transpose();
            match (fw1, fw2) {
                (Ok(a), Ok(b)) if a == b => {}
                (Ok(_), Ok(_)) => {
                    println!("Not the same final weight for state {:?}", state);
                    return false;
                }
                _ => {
                    println!(
                        "Failed to quantize a final weight for state {:?}",
                        state
                    );
                    return false;
                }
            }
        }
        true
    }

    /// Returns a copy of this FST, converted to `F2`, with
    /// [`QuantizeMapper`] applied to it through `tr_map`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `tr_map`. Presumably this includes a
    /// weight failing to quantize; confirm against `QuantizeMapper`.
    fn quantize<F2: MutableFst<W> + AllocableFst<W>>(&self) -> Result<F2>
    where
        W: WeightQuantize,
    {
        let mut fst_tr_map: F2 = fst_convert_from_ref(self);
        let mut mapper = QuantizeMapper {};
        fst_tr_map.tr_map(&mut mapper)?;
        Ok(fst_tr_map)
    }
}