use failure::Fallible;
use crate::fst_traits::{ExpandedFst, FinalStatesIterator, Fst, MutableFst};
use crate::semirings::{DivideType, Semiring, WeaklyDivisibleSemiring};
/// Selects which set of potential-based formulas [`reweight`] applies to the arcs,
/// final weights and start state of an FST.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReweightType {
    /// Arc `p -> n` with weight `w` becomes `(w ⊗ V(n)) / V(p)` (left division), final
    /// weights are left-divided by their state's potential, and the start state's arcs and
    /// final weight are left-multiplied by `V(start)`.
    ReweightToInitial,
    /// Arc `p -> n` with weight `w` becomes `(V(p) ⊗ w) / V(n)` (right division), final
    /// weights are left-multiplied by their state's potential, and the start state's arcs
    /// and final weight are left-multiplied by `1 / V(start)`.
    ReweightToFinal,
}
/// Looks up the potential of `$state` in the slice `$dist`.
///
/// Evaluates to a reference to the element. When `$state` is out of range, it returns early
/// from the enclosing function (via `?`) with a `format_err!` error reading
/// "State {} not in dists array". The enclosing function must therefore return a `Result`
/// whose error type that value converts into.
macro_rules! state_to_dist {
    ($state: expr, $dist: expr) => {{
        // Evaluate the index expression exactly once; it is needed both for the lookup
        // and for the error message.
        let state = $state;
        // No trailing `;` here: the macro is used in expression position
        // (`let d = state_to_dist!(..);`), where a trailing semicolon in the expansion
        // triggers the `semicolon_in_expressions_from_macros` lint.
        $dist
            .get(state)
            .ok_or_else(|| format_err!("State {} not in dists array", state))?
    }};
}
/// Reweights the arcs and final weights of `fst` in place using per-state `potentials`.
///
/// In what follows, `V(s)` is `potentials[s]`.
///
/// * `ReweightType::ReweightToInitial`:
///   - Every arc `p -> n` with weight `w` becomes `(w ⊗ V(n)) / V(p)` (left division).
///   - Every final weight `ρ(s)` with a non-zero `V(s)` becomes `ρ(s) / V(s)` (left division).
///   - The start state's arcs and final weight are left-multiplied by `V(start)`.
/// * `ReweightType::ReweightToFinal`:
///   - Every arc `p -> n` with weight `w` becomes `(V(p) ⊗ w) / V(n)` (right division).
///   - Every final weight becomes `V(s) ⊗ ρ(s)`.
///   - The start state's arcs and final weight are left-multiplied by `1 / V(start)`
///     (right division).
///
/// Arcs whose source or destination potential is zero are left unchanged. The start-state
/// adjustment is skipped when `V(start)` is zero or one.
///
/// `potentials` may be shorter than the number of states. Any state whose index is out of
/// range is treated as having a zero potential, matching OpenFst's `Reweight`.
///
/// NOTE(review): when the start state has incoming arcs, OpenFst inserts a fresh start state
/// instead of rescaling the existing one. This implementation always rescales the start
/// state's own arcs — confirm callers only pass FSTs whose start state has no incoming arcs.
///
/// # Errors
///
/// Propagates any error returned by the semiring `times`/`divide` operations and by
/// `arcs_iter_mut` / `set_final`.
pub fn reweight<F>(fst: &mut F, potentials: &[F::W], reweight_type: ReweightType) -> Fallible<()>
where
    F: Fst + ExpandedFst + MutableFst,
    F::W: WeaklyDivisibleSemiring,
{
    let num_states = fst.num_states();
    if num_states == 0 {
        return Ok(());
    }

    // Potential used for every state past the end of `potentials`. Previously, an arc,
    // final state or start state beyond the slice made the whole call fail.
    let zero = F::W::zero();

    // Reweight every arc from the potentials of its source and destination states.
    for state in 0..num_states {
        let d_s = potentials.get(state).unwrap_or(&zero);
        if d_s.is_zero() {
            continue;
        }
        for arc in fst.arcs_iter_mut(state)? {
            let d_ns = potentials.get(arc.nextstate).unwrap_or(&zero);
            if d_ns.is_zero() {
                continue;
            }
            arc.weight = match reweight_type {
                ReweightType::ReweightToInitial => {
                    (&arc.weight.times(d_ns)?).divide(d_s, DivideType::DivideLeft)?
                }
                ReweightType::ReweightToFinal => {
                    (d_s.times(&arc.weight)?).divide(d_ns, DivideType::DivideRight)?
                }
            };
        }
    }

    // Reweight the final weights. They are collected first because `set_final` needs a
    // mutable borrow of `fst`. Under `ReweightToFinal`, out-of-range states get
    // `zero ⊗ ρ(s)` here.
    let final_states: Vec<_> = fst.final_states_iter().collect();
    for final_state in final_states {
        let d_s = potentials.get(final_state.state_id).unwrap_or(&zero);
        let new_weight = match reweight_type {
            ReweightType::ReweightToFinal => d_s.times(&final_state.final_weight)?,
            ReweightType::ReweightToInitial => {
                if d_s.is_zero() {
                    continue;
                }
                (&final_state.final_weight).divide(d_s, DivideType::DivideLeft)?
            }
        };
        fst.set_final(final_state.state_id, new_weight)?;
    }

    // Fold the start state's potential into its outgoing arcs and its final weight.
    if let Some(start_state) = fst.start() {
        let d_s = potentials.get(start_state).unwrap_or(&zero);
        if !d_s.is_one() && !d_s.is_zero() {
            // The factor is the same for every arc, so compute it once.
            let factor = match reweight_type {
                ReweightType::ReweightToInitial => d_s.clone(),
                ReweightType::ReweightToFinal => {
                    F::W::one().divide(d_s, DivideType::DivideRight)?
                }
            };
            for arc in fst.arcs_iter_mut(start_state)? {
                arc.weight = factor.times(&arc.weight)?;
            }
            if let Some(final_weight) = fst.final_weight(start_state) {
                let new_weight = factor.times(&final_weight)?;
                fst.set_final(start_state, new_weight)?;
            }
        }
    }
    Ok(())
}