use anyhow::Result;
use crate::fst_properties::mutable_properties::reweight_properties;
use crate::fst_properties::FstProperties;
use crate::fst_traits::MutableFst;
use crate::semirings::{DivideType, WeaklyDivisibleSemiring};
/// Direction in which [`reweight`] moves weight through an FST.
///
/// In the formulas below `w` is a transition weight, `p` the potential of the
/// transition's source state and `q` the potential of its destination state.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ReweightType {
    /// Each transition weight becomes `(w ⊗ q) / p` using left division.
    ReweightToInitial,
    /// Each transition weight becomes `(p ⊗ w) / q` using right division.
    ReweightToFinal,
}
/// Reweights `fst` in place according to a vector of per-state `potentials`.
///
/// `potentials[s]` is the potential of state `s`. A state whose index lies past
/// the end of `potentials` is treated as having potential `W::zero()`.
///
/// Writing `p` for the potential of a transition's source state and `q` for the
/// potential of its destination state:
///
/// * [`ReweightType::ReweightToInitial`]: a transition weight `w` becomes
///   `(w ⊗ q) / p` (left division) and a final weight `ρ` becomes `ρ / p`
///   (left division).
/// * [`ReweightType::ReweightToFinal`]: a transition weight `w` becomes
///   `(p ⊗ w) / q` (right division) and a final weight `ρ` becomes `p ⊗ ρ`.
///
/// Transitions whose source or destination potential is zero are left
/// untouched. When reweighting to the initial state, final weights of states
/// with a zero potential are left untouched as well.
///
/// Finally, if the potential `d` of the start state is neither zero nor one,
/// every transition leaving the start state and the start state's final weight
/// (if any) are multiplied on the left by `d` (`ReweightToInitial`) or by
/// `1 / d` using right division (`ReweightToFinal`).
///
/// The FST properties are then updated through `reweight_properties`.
///
/// # Errors
///
/// Propagates any error returned by the semiring operations (`times`,
/// `divide`, `divide_assign`) or by the FST accessors.
pub fn reweight<W, F>(fst: &mut F, potentials: &[W], reweight_type: ReweightType) -> Result<()>
where
    F: MutableFst<W>,
    W: WeaklyDivisibleSemiring,
{
    let zero = W::zero();
    let num_states = fst.num_states();
    if num_states == 0 {
        return Ok(());
    }

    // Pass 1: transition weights. Only states that actually have a potential
    // are visited; states past the end of `potentials` have an implicit zero
    // potential, for which transitions are never modified.
    for (state, d_s) in potentials.iter().enumerate().take(num_states) {
        if d_s.is_zero() {
            continue;
        }
        // SAFETY: `state < num_states` thanks to `take(num_states)`, and
        // `idx_tr` stays below `it_tr.len()`.
        // NOTE(review): assumes the `_unchecked` accessors only require a
        // valid state id / transition index -- confirm against `MutableFst`.
        unsafe {
            let mut it_tr = fst.tr_iter_unchecked_mut(state);
            for idx_tr in 0..it_tr.len() {
                let tr = it_tr.get_unchecked(idx_tr);
                let d_ns = potentials.get(tr.nextstate).unwrap_or(&zero);
                if d_ns.is_zero() {
                    continue;
                }
                let weight = match reweight_type {
                    ReweightType::ReweightToInitial => tr
                        .weight
                        .times(d_ns)?
                        .divide(d_s, DivideType::DivideLeft)?,
                    ReweightType::ReweightToFinal => d_s
                        .times(&tr.weight)?
                        .divide(d_ns, DivideType::DivideRight)?,
                };
                it_tr.set_weight_unchecked(idx_tr, weight);
            }
        }
    }

    // Pass 2: final weights. A missing potential is read as zero, which also
    // covers the states past the end of `potentials`.
    for state_id in 0..num_states {
        // SAFETY: `state_id < num_states` (same assumption as in pass 1).
        if let Some(mut final_weight) = unsafe { fst.final_weight_unchecked(state_id) } {
            let d_s = potentials.get(state_id).unwrap_or(&zero);
            match reweight_type {
                ReweightType::ReweightToFinal => {
                    // The potential goes on the LEFT (`d_s ⊗ final`), matching
                    // the transition formula above and the start-state
                    // correction below; this matters for non-commutative
                    // semirings.
                    final_weight = d_s.times(&final_weight)?;
                }
                ReweightType::ReweightToInitial => {
                    if d_s.is_zero() {
                        continue;
                    }
                    final_weight.divide_assign(d_s, DivideType::DivideLeft)?;
                }
            }
            // SAFETY: `state_id < num_states` (same assumption as in pass 1).
            unsafe { fst.set_final_unchecked(state_id, final_weight) };
        }
    }

    // Pass 3: compensate at the start state so that the reweighting of its
    // outgoing transitions and final weight accounts for its own potential.
    if let Some(start_state) = fst.start() {
        let d_s = potentials.get(start_state).unwrap_or(&zero);
        if !d_s.is_one() && !d_s.is_zero() {
            // The left factor is loop-invariant: `d_s` itself when reweighting
            // to the initial state, `1 / d_s` (right division) otherwise.
            let inverse = match reweight_type {
                ReweightType::ReweightToInitial => None,
                ReweightType::ReweightToFinal => {
                    Some(W::one().divide(d_s, DivideType::DivideRight)?)
                }
            };
            let factor = inverse.as_ref().unwrap_or(d_s);

            // SAFETY: `start_state` comes from `fst.start()` and `idx_tr`
            // stays below `it_tr.len()`.
            // NOTE(review): assumes `start()` only ever returns a valid state
            // id -- confirm against the `MutableFst` implementation.
            unsafe {
                let mut it_tr = fst.tr_iter_unchecked_mut(start_state);
                for idx_tr in 0..it_tr.len() {
                    let tr = it_tr.get_unchecked(idx_tr);
                    let weight = factor.times(&tr.weight)?;
                    it_tr.set_weight_unchecked(idx_tr, weight);
                }
            }

            if let Some(final_weight) = fst.final_weight(start_state)? {
                let new_weight = factor.times(final_weight)?;
                fst.set_final(start_state, new_weight)?;
            }
        }
    }

    fst.set_properties_with_mask(
        reweight_properties(fst.properties()),
        FstProperties::all_properties(),
    );
    Ok(())
}