use anyhow::Result;
use crate::fst_properties::mutable_properties::reweight_properties;
use crate::fst_properties::FstProperties;
use crate::fst_traits::MutableFst;
use crate::semirings::{DivideType, WeaklyDivisibleSemiring};
use crate::{StateId, Tr, EPS_LABEL};
/// Direction in which [`reweight`] moves weight through an FST.
///
/// The variant order is significant for the derived `PartialOrd`:
/// `ReweightToInitial < ReweightToFinal`.
#[derive(Debug, PartialOrd, PartialEq, Eq, Copy, Clone)]
pub enum ReweightType {
    /// Reweight toward the initial state.
    ReweightToInitial,
    /// Reweight toward the final states.
    ReweightToFinal,
}
/// Reweights `fst` in place according to per-state `potentials`.
///
/// `potentials[s]` is the potential `d[s]` of state `s`. States that have no
/// entry in `potentials` are treated as having potential `W::zero()`.
///
/// With `w` the weight of a transition `s -> ns` and `rho(s)` the final weight
/// of `s`:
///
/// * `ReweightToInitial`: `w` becomes `(w ⊗ d[ns])` left-divided by `d[s]`,
///   and `rho(s)` becomes `rho(s)` left-divided by `d[s]`. Both updates are
///   skipped when `d[s]` is zero.
/// * `ReweightToFinal`: `w` becomes `(d[s] ⊗ w)` right-divided by `d[ns]`
///   (skipped when `d[s]` is zero), and `rho(s)` becomes `d[s] ⊗ rho(s)`.
///
/// A transition is left untouched when the potential of its destination is
/// zero.
///
/// If the potential of the start state is neither zero nor one, it is
/// compensated for: either by rescaling the start state's outgoing transitions
/// and final weight (when the FST is `INITIAL_ACYCLIC`), or by adding a new
/// start state connected to the old one through an epsilon transition.
///
/// # Errors
///
/// Propagates any error returned by the semiring operations or by the FST
/// mutation / property-computation calls.
pub fn reweight<W, F>(fst: &mut F, potentials: &[W], reweight_type: ReweightType) -> Result<()>
where
    F: MutableFst<W>,
    W: WeaklyDivisibleSemiring,
{
    let zero = W::zero();
    let num_states = fst.num_states();
    if num_states == 0 {
        return Ok(());
    }

    // Pass 1: transition weights. States without a usable potential (missing
    // or zero) keep their transitions unchanged; their final weights are
    // handled in pass 2.
    for state in 0..(num_states as StateId) {
        let d_s = potentials.get(state as usize).unwrap_or(&zero);
        if d_s.is_zero() {
            continue;
        }
        // SAFETY: `state` is in `0..num_states` and `idx_tr` is in
        // `0..it_tr.len()`, which is presumably what the unchecked accessors
        // require.
        unsafe {
            let mut it_tr = fst.tr_iter_unchecked_mut(state);
            for idx_tr in 0..it_tr.len() {
                let tr = it_tr.get_unchecked(idx_tr);
                let d_ns = potentials.get(tr.nextstate as usize).unwrap_or(&zero);
                if d_ns.is_zero() {
                    continue;
                }
                let weight = match reweight_type {
                    ReweightType::ReweightToInitial => {
                        tr.weight.times(d_ns)?.divide(d_s, DivideType::DivideLeft)?
                    }
                    ReweightType::ReweightToFinal => {
                        (d_s.times(&tr.weight)?).divide(d_ns, DivideType::DivideRight)?
                    }
                };
                it_tr.set_weight_unchecked(idx_tr, weight);
            }
        }
    }

    // Pass 2: final weights. A missing potential counts as zero, so with
    // `ReweightToFinal` the final weight of such a state is multiplied by zero.
    for state in 0..(num_states as StateId) {
        // SAFETY: `state` is in `0..num_states`.
        let final_weight = match unsafe { fst.final_weight_unchecked(state) } {
            Some(weight) => weight,
            None => continue,
        };
        let d_s = potentials.get(state as usize).unwrap_or(&zero);
        let new_weight = match reweight_type {
            // The potential goes on the left, consistent with the transition
            // update `d[s] ⊗ w`; the order matters for non-commutative semirings.
            ReweightType::ReweightToFinal => d_s.times(&final_weight)?,
            ReweightType::ReweightToInitial => {
                if d_s.is_zero() {
                    continue;
                }
                final_weight.divide(d_s, DivideType::DivideLeft)?
            }
        };
        // SAFETY: `state` is in `0..num_states`.
        unsafe { fst.set_final_unchecked(state, new_weight) };
    }

    // Pass 3: compensate for the potential of the start state.
    if let Some(start_state) = fst.start() {
        let d_s = potentials.get(start_state as usize).unwrap_or(&zero);
        if !d_s.is_one() && !d_s.is_zero() {
            fst.compute_and_update_properties(FstProperties::INITIAL_ACYCLIC)?;

            // Factor applied on the left of everything leaving the start state.
            let start_factor = match reweight_type {
                ReweightType::ReweightToInitial => d_s.clone(),
                ReweightType::ReweightToFinal => W::one().divide(d_s, DivideType::DivideRight)?,
            };

            if fst.properties().contains(FstProperties::INITIAL_ACYCLIC) {
                // Rescale the start state's outgoing transitions and final
                // weight in place.
                // SAFETY: `start_state` was returned by `fst.start()` and
                // `idx_tr` is in `0..it_tr.len()`.
                unsafe {
                    let mut it_tr = fst.tr_iter_unchecked_mut(start_state);
                    for idx_tr in 0..it_tr.len() {
                        let tr = it_tr.get_unchecked(idx_tr);
                        let weight = start_factor.times(&tr.weight)?;
                        it_tr.set_weight_unchecked(idx_tr, weight);
                    }
                }
                if let Some(final_weight) = fst.final_weight(start_state)? {
                    let new_weight = start_factor.times(final_weight)?;
                    fst.set_final(start_state, new_weight)?;
                }
            } else {
                // Otherwise prepend a fresh start state that carries the
                // factor on an epsilon transition to the old start state.
                let s = fst.add_state();
                fst.add_tr(s, Tr::new(EPS_LABEL, EPS_LABEL, start_factor, start_state))?;
                fst.set_start(s)?;
            }
        }
    }

    fst.set_properties_with_mask(
        reweight_properties(fst.properties()),
        FstProperties::all_properties(),
    );
    Ok(())
}