use crate::semiring::{DivisibleSemiring, Semiring};
use crate::wfst::{MutableWfst, StateId, WeightedTransition, Wfst, NO_STATE};
use super::shortest_distance::{
reverse_shortest_distance, single_source_shortest_distance, ShortestDistanceConfig,
};
/// Which shortest-distance potentials [`push_weights`] reweights the FST with.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum PushDirection {
    /// Use forward potentials: shortest distances from the start state,
    /// computed by `single_source_shortest_distance`.
    Forward,
    /// Use reverse potentials: shortest distances to the final states,
    /// computed by `reverse_shortest_distance`. This is the default.
    #[default]
    Backward,
}
/// Options controlling [`push_weights`].
#[derive(Clone, Debug)]
pub struct PushConfig {
    /// Which potentials to compute and reweight with.
    pub direction: PushDirection,
    /// Presumably requests removal of states that cannot reach a final state.
    /// NOTE(review): `push_weights` in this file never reads this flag — confirm
    /// whether any caller elsewhere relies on it.
    pub remove_non_coaccessible: bool,
    /// Configuration handed to the shortest-distance computation.
    pub distance_config: ShortestDistanceConfig,
}
impl Default for PushConfig {
fn default() -> Self {
Self {
direction: PushDirection::Backward,
remove_non_coaccessible: true,
distance_config: ShortestDistanceConfig::default(),
}
}
}
impl PushConfig {
pub fn forward() -> Self {
Self {
direction: PushDirection::Forward,
..Default::default()
}
}
pub fn backward() -> Self {
Self {
direction: PushDirection::Backward,
..Default::default()
}
}
pub fn log_semiring() -> Self {
Self {
direction: PushDirection::Backward,
remove_non_coaccessible: true,
distance_config: ShortestDistanceConfig::default(),
}
}
}
/// Pushes weight along the paths of `fst`, in place, in the direction given by
/// `config.direction`.
///
/// Potentials are computed with `single_source_shortest_distance` for
/// [`PushDirection::Forward`] and with `reverse_shortest_distance` for
/// [`PushDirection::Backward`], using `config.distance_config`. Arcs and final
/// weights are then reweighted from those potentials.
///
/// `config.remove_non_coaccessible` is not consulted by this function.
///
/// An FST with no states is left untouched, and `Ok(())` is returned.
///
/// # Errors
///
/// * [`PushError::NoStartState`] if the FST has states but no start state.
/// * [`PushError::NoPotentials`] if the shortest-distance computation returns
///   `None`, or returns fewer potentials than the FST has states. The
///   reweighting passes need one potential per state.
pub fn push_weights<L, W, F>(fst: &mut F, config: PushConfig) -> Result<(), PushError>
where
    L: Clone,
    W: DivisibleSemiring,
    F: MutableWfst<L, W> + Wfst<L, W>,
{
    let n = fst.num_states();
    if n == 0 {
        return Ok(());
    }
    if fst.start() == NO_STATE {
        return Err(PushError::NoStartState);
    }
    // `config` is owned, so the distance configuration is moved rather than
    // cloned. `direction` is `Copy` and stays usable after this partial move.
    let potentials = match config.direction {
        PushDirection::Forward => single_source_shortest_distance(fst, config.distance_config),
        PushDirection::Backward => reverse_shortest_distance(fst, config.distance_config),
    }
    .ok_or(PushError::NoPotentials)?;
    // Both reweighting passes look up a potential for every state index.
    // Reject a short vector (including an empty one) here instead of
    // panicking on an out-of-bounds index later.
    if potentials.len() < n {
        return Err(PushError::NoPotentials);
    }
    match config.direction {
        PushDirection::Forward => push_forward_impl(fst, &potentials),
        PushDirection::Backward => push_backward_impl(fst, &potentials),
    }
    Ok(())
}
/// Reweights `fst` in place so that weight moves toward the final states.
///
/// `potentials[q]` is expected to hold d(q), the shortest distance from the
/// start state to `q`, as produced by `single_source_shortest_distance`.
///
/// * Each arc `e = p → q` becomes `d(p) ⊗ w(e) ⊗ d(q)⁻¹`.
/// * Each final weight becomes `d(f) ⊗ ρ(f)`.
///
/// The potentials telescope along any successful path, so a path's new weight
/// is `d(start) ⊗` its original weight. When the potentials are exact, the
/// weights of the arcs entering any non-start state ⊕-sum to `one()`.
///
/// Arcs leaving a state whose potential is zero or missing are dropped, as
/// are arcs whose target has no potential. Final weights of such states are
/// left unchanged. A potential that `divide` cannot invert is treated as
/// `one()`.
fn push_forward_impl<L, W, F>(fst: &mut F, potentials: &[W])
where
    L: Clone,
    W: DivisibleSemiring,
    F: MutableWfst<L, W> + Wfst<L, W>,
{
    let n = fst.num_states();

    // Build all reweighted arcs in a separate buffer so `fst` is only read
    // during this pass.
    let mut new_transitions: Vec<Vec<WeightedTransition<L, W>>> = vec![Vec::new(); n];
    for (state, arcs) in new_transitions.iter_mut().enumerate() {
        let p_from = match potentials.get(state) {
            Some(p) if !p.is_zero() => p,
            // Unreachable from the start state (or no potential): drop its arcs.
            _ => continue,
        };
        for trans in fst.transitions(state as StateId) {
            let p_to = match potentials.get(trans.to as usize) {
                Some(p) => p,
                None => continue,
            };
            let p_to_inv = W::one().divide(p_to).unwrap_or_else(W::one);
            let weight = p_from.times(&trans.weight).times(&p_to_inv);
            arcs.push(WeightedTransition {
                weight,
                ..trans.clone()
            });
        }
    }
    for (state, arcs) in new_transitions.into_iter().enumerate() {
        let state_id = state as StateId;
        fst.clear_transitions(state_id);
        for trans in arcs {
            fst.add_transition(trans);
        }
    }

    // The potential divided out of the last arc of each path is restored on
    // the final weight.
    for state in 0..n {
        let state_id = state as StateId;
        if !fst.is_final(state_id) {
            continue;
        }
        if let Some(p) = potentials.get(state).filter(|p| !p.is_zero()) {
            let new_final = p.times(&fst.final_weight(state_id));
            fst.set_final(state_id, new_final);
        }
    }
}
/// Reweights `fst` in place so that weight moves toward the start state.
///
/// `potentials[q]` is expected to hold V(q), the shortest distance from `q` to
/// the final states, as produced by `reverse_shortest_distance`.
///
/// * Each arc `e = p → q` becomes `V(p)⁻¹ ⊗ w(e) ⊗ V(q)`.
/// * Each final weight becomes `V(f)⁻¹ ⊗ ρ(f)`.
///
/// The inverse is applied on the left in both updates, so the potentials
/// telescope along a successful path from the start state. That path's pushed
/// weight is therefore `V(start)⁻¹ ⊗` its original weight.
///
/// `V(start)` itself is not stored anywhere in the result, so callers that
/// need the original totals must ⊗ it back in. When the potentials are exact,
/// each state's outgoing arc weights ⊕ its final weight sum to `one()`.
///
/// Final weights are divided by their own state's potential rather than reset
/// to `one()`. A final state that also has outgoing arcs can have a potential
/// different from its final weight, and resetting that final weight would
/// change the weight of every path ending there.
///
/// Arcs leaving or entering a state whose potential is zero or missing are
/// dropped. Final weights of states without a usable potential are left
/// unchanged. A potential that `divide` cannot invert is treated as `one()`.
fn push_backward_impl<L, W, F>(fst: &mut F, potentials: &[W])
where
    L: Clone,
    W: DivisibleSemiring,
    F: MutableWfst<L, W> + Wfst<L, W>,
{
    let n = fst.num_states();

    // Build all reweighted arcs in a separate buffer so `fst` is only read
    // during this pass.
    let mut new_transitions: Vec<Vec<WeightedTransition<L, W>>> = vec![Vec::new(); n];
    for (state, arcs) in new_transitions.iter_mut().enumerate() {
        let p_from = match potentials.get(state) {
            Some(p) if !p.is_zero() => p,
            _ => continue,
        };
        // The source potential is shared by every arc of this state.
        let p_from_inv = W::one().divide(p_from).unwrap_or_else(W::one);
        for trans in fst.transitions(state as StateId) {
            let p_to = match potentials.get(trans.to as usize) {
                Some(p) if !p.is_zero() => p,
                // Target cannot reach a final state (or has no potential).
                _ => continue,
            };
            let weight = p_from_inv.times(&trans.weight).times(p_to);
            arcs.push(WeightedTransition {
                weight,
                ..trans.clone()
            });
        }
    }
    for (state, arcs) in new_transitions.into_iter().enumerate() {
        let state_id = state as StateId;
        fst.clear_transitions(state_id);
        for trans in arcs {
            fst.add_transition(trans);
        }
    }

    // Final weights keep whatever part of V(f) is not accounted for by
    // continuing along outgoing arcs: ρ'(f) = V(f)⁻¹ ⊗ ρ(f).
    for state in 0..n {
        let state_id = state as StateId;
        if !fst.is_final(state_id) {
            continue;
        }
        if let Some(p) = potentials.get(state).filter(|p| !p.is_zero()) {
            let p_inv = W::one().divide(p).unwrap_or_else(W::one);
            let new_final = p_inv.times(&fst.final_weight(state_id));
            fst.set_final(state_id, new_final);
        }
    }
}
/// Errors reported by [`push_weights`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum PushError {
    /// The FST has states but its start state is `NO_STATE`.
    NoStartState,
    /// The shortest-distance computation returned `None`, or returned a
    /// potential vector that `push_weights` rejects (for example, an empty one).
    NoPotentials,
    /// Division by zero during reweighting.
    /// NOTE(review): nothing in this file constructs this variant.
    DivisionByZero,
}
impl std::fmt::Display for PushError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::NoStartState => write!(f, "WFST has no start state"),
Self::NoPotentials => write!(f, "Could not compute potentials"),
Self::DivisionByZero => write!(f, "Division by zero during reweighting"),
}
}
}
impl std::error::Error for PushError {}
/// Reports whether every state of `fst` is stochastic.
///
/// A state is stochastic here when its final weight ⊕ the weights of all its
/// outgoing transitions is approximately `W::one()`, as judged by `approx_eq`
/// with tolerance `epsilon`. The check stops at the first failing state. An
/// FST with no states is trivially stochastic.
pub fn is_stochastic<L, W, F>(fst: &F, epsilon: f64) -> bool
where
    L: Clone,
    W: Semiring,
    F: Wfst<L, W>,
{
    let one = W::one();
    (0..fst.num_states()).all(|index| {
        let id = index as StateId;
        let outgoing_mass = fst
            .transitions(id)
            .iter()
            .fold(fst.final_weight(id), |acc, arc| acc.plus(&arc.weight));
        outgoing_mass.approx_eq(&one, epsilon)
    })
}
// Tests for weight pushing: randomized property checks, hand-built examples
// for each push direction, the configuration presets, and error messages.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::semiring::TropicalWeight;
    use crate::wfst::{VectorWfst, VectorWfstBuilder};

    // Property tests over random acyclic tropical WFSTs produced by
    // `arb_acyclic_wfst_tropical`. The meaning of its two arguments is defined
    // in `test_utils`, not here.
    mod property_tests {
        use super::*;
        use crate::test_utils::arb_acyclic_wfst_tropical;
        use proptest::prelude::*;

        proptest! {
            // A successful backward push must not add or remove states.
            #[test]
            fn push_preserves_state_count(
                fst in arb_acyclic_wfst_tropical(8, 3)
            ) {
                // Skip inputs without states or without a start state.
                if fst.num_states() == 0 || fst.start() == NO_STATE {
                    return Ok(());
                }
                let original_states = fst.num_states();
                let mut pushed_fst = fst.clone();
                let result = push_weights(&mut pushed_fst, PushConfig::backward());
                if result.is_ok() {
                    prop_assert_eq!(
                        pushed_fst.num_states(),
                        original_states,
                        "Push changed state count from {} to {}",
                        original_states,
                        pushed_fst.num_states()
                    );
                }
            }

            // A forward push may drop arcs but must never increase the arc count.
            #[test]
            fn push_preserves_transitions(
                fst in arb_acyclic_wfst_tropical(6, 2)
            ) {
                if fst.num_states() == 0 || fst.start() == NO_STATE {
                    return Ok(());
                }
                let original_arc_count: usize = (0..fst.num_states())
                    .map(|s| fst.transitions(s as StateId).len())
                    .sum();
                let mut pushed_fst = fst.clone();
                let config = PushConfig {
                    direction: PushDirection::Forward,
                    remove_non_coaccessible: false,
                    distance_config: ShortestDistanceConfig::default(),
                };
                let result = push_weights(&mut pushed_fst, config);
                if result.is_ok() {
                    let new_arc_count: usize = (0..pushed_fst.num_states())
                        .map(|s| pushed_fst.transitions(s as StateId).len())
                        .sum();
                    prop_assert!(
                        new_arc_count <= original_arc_count,
                        "Push increased arc count from {} to {}",
                        original_arc_count,
                        new_arc_count
                    );
                }
            }

            // In either direction, a successful push leaves a start state in place.
            #[test]
            fn push_both_directions_valid(
                fst in arb_acyclic_wfst_tropical(6, 2)
            ) {
                if fst.num_states() == 0 || fst.start() == NO_STATE {
                    return Ok(());
                }
                let mut forward_fst = fst.clone();
                let forward_result = push_weights(&mut forward_fst, PushConfig::forward());
                let mut backward_fst = fst.clone();
                let backward_result = push_weights(&mut backward_fst, PushConfig::backward());
                if forward_result.is_ok() {
                    prop_assert!(forward_fst.start() != NO_STATE || fst.num_states() == 0);
                }
                if backward_result.is_ok() {
                    prop_assert!(backward_fst.start() != NO_STATE || fst.num_states() == 0);
                }
            }

            // The `_seed` input is unused; this is a fixed case run under proptest.
            #[test]
            fn push_empty_succeeds(_seed in 0u32..100) {
                let mut fst: VectorWfst<char, TropicalWeight> = VectorWfst::new();
                let result = push_weights(&mut fst, PushConfig::backward());
                prop_assert!(result.is_ok());
            }

            // One state but no start state must be rejected.
            #[test]
            fn push_no_start_fails(_seed in 0u32..100) {
                let mut fst: VectorWfst<char, TropicalWeight> = VectorWfst::new();
                fst.add_state();
                let result = push_weights(&mut fst, PushConfig::backward());
                prop_assert!(matches!(result, Err(PushError::NoStartState)));
            }
        }
    }

    // Linear chain 0 -a/1-> 1 -b/2-> 2 with final weight 0.5 on state 2:
    // one successful path, whose total weight the tests below take to be 3.5.
    fn build_simple_chain() -> VectorWfst<char, TropicalWeight> {
        VectorWfstBuilder::new()
            .add_states(3)
            .start(0)
            .arc(0, Some('a'), Some('a'), 1, TropicalWeight::new(1.0))
            .arc(1, Some('b'), Some('b'), 2, TropicalWeight::new(2.0))
            .final_state(2, TropicalWeight::new(0.5))
            .build()
    }

    // Diamond with two paths into final state 3 (final weight one):
    // 0 -a/1-> 1 -c/1-> 3 and 0 -b/2-> 2 -d/1-> 3.
    fn build_diamond() -> VectorWfst<char, TropicalWeight> {
        VectorWfstBuilder::new()
            .add_states(4)
            .start(0)
            .arc(0, Some('a'), Some('a'), 1, TropicalWeight::new(1.0))
            .arc(0, Some('b'), Some('b'), 2, TropicalWeight::new(2.0))
            .arc(1, Some('c'), Some('c'), 3, TropicalWeight::new(1.0))
            .arc(2, Some('d'), Some('d'), 3, TropicalWeight::new(1.0))
            .final_state(3, TropicalWeight::one())
            .build()
    }

    #[test]
    fn test_push_empty_fst() {
        let mut fst: VectorWfst<char, TropicalWeight> = VectorWfst::new();
        let result = push_weights(&mut fst, PushConfig::backward());
        assert!(result.is_ok());
    }

    #[test]
    fn test_push_no_start() {
        let mut fst: VectorWfst<char, TropicalWeight> = VectorWfst::new();
        fst.add_state();
        let result = push_weights(&mut fst, PushConfig::backward());
        assert_eq!(result, Err(PushError::NoStartState));
    }

    #[test]
    fn test_push_backward_chain() {
        let mut fst = build_simple_chain();
        // Compute the reverse potentials independently, so the start state's
        // potential can be used to reconstruct the original path weight.
        let potentials = reverse_shortest_distance(&fst, ShortestDistanceConfig::default())
            .expect("Should compute potentials");
        let initial_potential = potentials[fst.start() as usize].clone();
        let original_total = TropicalWeight::new(3.5);
        let result = push_weights(&mut fst, PushConfig::backward());
        assert!(result.is_ok());
        let start = fst.start();
        assert_ne!(start, NO_STATE);
        assert_eq!(fst.num_states(), 3);
        assert_eq!(fst.transitions(0).len(), 1);
        assert_eq!(fst.transitions(1).len(), 1);
        assert!(
            fst.final_weight(2).approx_eq(&TropicalWeight::one(), 0.001),
            "Final weight should be one after backward push, got {:?}",
            fst.final_weight(2)
        );
        // Follow the single path, multiplying arc weights and the final weight.
        let mut normalized_path = TropicalWeight::one();
        let mut current = start;
        while !fst.transitions(current).is_empty() {
            let trans = &fst.transitions(current)[0];
            normalized_path = normalized_path.times(&trans.weight);
            current = trans.to;
        }
        normalized_path = normalized_path.times(&fst.final_weight(current));
        // The push factors the start potential out of the path; multiplying it
        // back in must recover the original total.
        let reconstructed = initial_potential.times(&normalized_path);
        assert!(
            reconstructed.approx_eq(&original_total, 0.1),
            "V(i) ⊗ normalized_path should equal original: {:?} ⊗ {:?} = {:?}, expected {:?}",
            initial_potential,
            normalized_path,
            reconstructed,
            original_total
        );
    }

    #[test]
    fn test_push_backward_diamond() {
        let mut fst = build_diamond();
        let result = push_weights(&mut fst, PushConfig::backward());
        assert!(result.is_ok());
        assert_eq!(fst.num_states(), 4);
        assert_ne!(fst.start(), NO_STATE);
        assert!(fst.is_final(3));
    }

    // A forward push must preserve the total weight of the chain's single path.
    #[test]
    fn test_push_forward_chain() {
        let mut fst = build_simple_chain();
        let original_total = TropicalWeight::new(3.5);
        let result = push_weights(&mut fst, PushConfig::forward());
        assert!(result.is_ok());
        let start = fst.start();
        assert_ne!(start, NO_STATE);
        assert_eq!(fst.num_states(), 3);
        assert_eq!(fst.transitions(0).len(), 1);
        assert_eq!(fst.transitions(1).len(), 1);
        let mut total = TropicalWeight::one();
        let mut current = start;
        while !fst.transitions(current).is_empty() {
            let trans = &fst.transitions(current)[0];
            total = total.times(&trans.weight);
            current = trans.to;
        }
        total = total.times(&fst.final_weight(current));
        assert!(
            total.approx_eq(&original_total, 0.1),
            "Expected ~{:?}, got {:?}",
            original_total,
            total
        );
    }

    #[test]
    fn test_push_config_defaults() {
        let config = PushConfig::default();
        assert_eq!(config.direction, PushDirection::Backward);
        assert!(config.remove_non_coaccessible);
        let forward = PushConfig::forward();
        assert_eq!(forward.direction, PushDirection::Forward);
        let backward = PushConfig::backward();
        assert_eq!(backward.direction, PushDirection::Backward);
    }

    #[test]
    fn test_push_error_display() {
        assert_eq!(
            PushError::NoStartState.to_string(),
            "WFST has no start state"
        );
        assert_eq!(
            PushError::NoPotentials.to_string(),
            "Could not compute potentials"
        );
        assert_eq!(
            PushError::DivisionByZero.to_string(),
            "Division by zero during reweighting"
        );
    }
}