use crate::semiring::{DivisibleSemiring, LogWeight, Semiring};
use crate::wfst::{MutableWfst, StateId, WeightedTransition, Wfst, NO_STATE};
/// Options accepted by [`prepare_for_beam_search`].
#[derive(Clone, Debug)]
pub struct LogPushConfig {
    /// When `true`, the WFST is checked for stochasticity after pushing and
    /// the outcome is reported in the result.
    pub verify_stochastic: bool,
    /// Tolerance handed to the stochasticity check.
    pub stochastic_epsilon: f64,
    /// Forwarded to [`apply_log_push`] as its `normalize_finals` argument.
    pub normalize_finals: bool,
}

impl Default for LogPushConfig {
    /// Verification disabled, tolerance `1e-6`, final normalization enabled.
    fn default() -> Self {
        LogPushConfig {
            normalize_finals: true,
            stochastic_epsilon: 1e-6,
            verify_stochastic: false,
        }
    }
}

impl LogPushConfig {
    /// Returns the default configuration with the stochasticity check enabled.
    pub fn verified() -> Self {
        let base = Self::default();
        LogPushConfig {
            verify_stochastic: true,
            ..base
        }
    }
}
/// Summary returned by `prepare_for_beam_search`.
#[derive(Clone, Debug, PartialEq)]
pub struct BeamSearchPrepResult {
/// `true` when weight pushing was applied; `false` when the WFST had no states.
pub pushed: bool,
/// Potential of the start state as computed before pushing.
/// `LogWeight::zero()` when the WFST is empty or the start id is out of range.
pub total_weight: LogWeight,
/// Outcome of the stochasticity check, or `None` when the check was not requested.
/// An empty WFST always reports `Some(true)`.
pub is_stochastic: Option<bool>,
/// Number of states in the WFST.
pub num_states: usize,
/// Number of transitions, counted before the push is applied.
pub num_transitions: usize,
}
/// Pushes the weights of `fst` in place and reports what was done.
///
/// The steps are: compute per-state potentials with
/// [`compute_log_potentials`], reweight the WFST with [`apply_log_push`], and,
/// if `config.verify_stochastic` is set, run the stochasticity check with
/// `config.stochastic_epsilon` as tolerance.
///
/// An empty WFST is returned untouched with `pushed == false`.
///
/// Note that `num_transitions` in the result is counted *before* pushing.
///
/// # Errors
///
/// Returns [`LogPushError::NoStartState`] when the WFST has states but no
/// start state, and propagates any error from [`compute_log_potentials`] or
/// [`apply_log_push`].
pub fn prepare_for_beam_search<L, F>(
    fst: &mut F,
    config: LogPushConfig,
) -> Result<BeamSearchPrepResult, LogPushError>
where
    L: Clone,
    F: MutableWfst<L, LogWeight> + Wfst<L, LogWeight>,
{
    let state_count = fst.num_states();
    if state_count == 0 {
        // Nothing to push; report an empty, trivially stochastic machine.
        return Ok(BeamSearchPrepResult {
            pushed: false,
            total_weight: LogWeight::zero(),
            is_stochastic: Some(true),
            num_states: 0,
            num_transitions: 0,
        });
    }
    if fst.start() == NO_STATE {
        return Err(LogPushError::NoStartState);
    }

    // Tally transitions up front, before the push rewrites them.
    let mut transition_count = 0usize;
    for state in 0..state_count {
        transition_count += fst.transitions(state as StateId).len();
    }

    let potentials = compute_log_potentials(fst)?;

    // The start potential is the total weight that pushing factors out.
    let total_weight = match potentials.get(fst.start() as usize) {
        Some(weight) => weight.clone(),
        None => LogWeight::zero(),
    };

    apply_log_push(fst, &potentials, config.normalize_finals)?;

    let is_stochastic = config
        .verify_stochastic
        .then(|| verify_stochastic(&*fst, config.stochastic_epsilon));

    Ok(BeamSearchPrepResult {
        pushed: true,
        total_weight,
        is_stochastic,
        num_states: state_count,
        num_transitions: transition_count,
    })
}
/// Computes, for every state, the log-semiring sum of the weights of all
/// paths from that state to a final state (final weight included).
///
/// Each potential is seeded with the state's final weight (or
/// `LogWeight::zero()` for non-final states) and then accumulated in a single
/// pass over the reversed order returned by `compute_topological_order`.
/// Transitions whose destination id is not a valid state index are ignored.
///
/// NOTE(review): when the WFST contains a cycle the ordering helper falls
/// back to plain index order, so this single pass presumably does not yield
/// converged potentials. Confirm callers only pass acyclic machines.
///
/// # Errors
///
/// * [`LogPushError::NoStartState`] if the WFST has states but no start state.
/// * [`LogPushError::NoPathToFinal`] if the start state's potential is zero.
///
/// An empty WFST yields an empty vector.
pub fn compute_log_potentials<L, F>(fst: &F) -> Result<Vec<LogWeight>, LogPushError>
where
    L: Clone,
    F: Wfst<L, LogWeight>,
{
    let state_count = fst.num_states();
    if state_count == 0 {
        return Ok(Vec::new());
    }
    if fst.start() == NO_STATE {
        return Err(LogPushError::NoStartState);
    }

    // Seed: final states start from their final weight, all others from zero.
    let mut potentials: Vec<LogWeight> = (0..state_count)
        .map(|index| {
            let id = index as StateId;
            if fst.is_final(id) {
                fst.final_weight(id)
            } else {
                LogWeight::zero()
            }
        })
        .collect();

    let order = compute_topological_order(fst);
    for &source in order.iter().rev() {
        let source_idx = source as usize;
        for arc in fst.transitions(source) {
            let target_idx = arc.to as usize;
            if target_idx < state_count {
                // Updated in place on purpose: a self-loop reads the partial sum.
                let through_arc = arc.weight.times(&potentials[target_idx]);
                let updated = potentials[source_idx].plus(&through_arc);
                potentials[source_idx] = updated;
            }
        }
    }

    // An out-of-range start id is not treated as an error here.
    let start_idx = fst.start() as usize;
    let start_is_dead = potentials
        .get(start_idx)
        .map_or(false, |potential| potential.is_zero());
    if start_is_dead {
        Err(LogPushError::NoPathToFinal)
    } else {
        Ok(potentials)
    }
}
pub fn apply_log_push<L, F>(
fst: &mut F,
potentials: &[LogWeight],
normalize_finals: bool,
) -> Result<(), LogPushError>
where
L: Clone,
F: MutableWfst<L, LogWeight> + Wfst<L, LogWeight>,
{
let n = fst.num_states();
if n == 0 {
return Ok(());
}
let mut new_transitions: Vec<Vec<WeightedTransition<L, LogWeight>>> = vec![Vec::new(); n];
for state in 0..n {
let state_id = state as StateId;
let p_from = &potentials[state];
if p_from.is_zero() {
continue;
}
for trans in fst.transitions(state_id).to_vec() {
let to_idx = trans.to as usize;
if to_idx >= potentials.len() {
continue;
}
let p_to = &potentials[to_idx];
if p_to.is_zero() {
continue;
}
let w_times_to = trans.weight.times(p_to);
let new_weight = w_times_to
.divide(p_from)
.unwrap_or_else(|| trans.weight.clone());
new_transitions[state].push(WeightedTransition {
from: trans.from,
to: trans.to,
input: trans.input,
output: trans.output,
weight: new_weight,
});
}
}
for state in 0..n {
let state_id = state as StateId;
fst.clear_transitions(state_id);
for trans in new_transitions[state].drain(..) {
fst.add_transition(trans);
}
}
if normalize_finals {
for state in 0..n {
let state_id = state as StateId;
if fst.is_final(state_id) {
fst.set_final(state_id, LogWeight::one());
}
}
}
Ok(())
}
/// Returns `true` when, for every state, the final weight combined (via
/// `plus`) with all outgoing transition weights is within `epsilon` of
/// `LogWeight::one()`.
///
/// States that are non-final and have no outgoing transitions are ignored:
/// they carry no weight at all and cannot be normalized.
fn verify_stochastic<L, F>(fst: &F, epsilon: f64) -> bool
where
    L: Clone,
    F: Wfst<L, LogWeight>,
{
    let one = LogWeight::one();
    (0..fst.num_states()).all(|state| {
        let state_id = state as StateId;
        let transitions = fst.transitions(state_id);

        // Dead ends are exempt, so skip them before summing anything.
        if !fst.is_final(state_id) && transitions.is_empty() {
            return true;
        }

        let total = transitions
            .iter()
            .fold(fst.final_weight(state_id), |acc, trans| {
                acc.plus(&trans.weight)
            });
        total.approx_eq(&one, epsilon)
    })
}
/// Produces a processing order for the states of `fst`.
///
/// States are emitted once all their incoming transitions have been
/// accounted for (Kahn-style in-degree counting, using a LIFO work list).
/// Transitions whose destination is not a valid state index are ignored.
///
/// If not every state could be ordered this way, which happens when the
/// graph contains a cycle, the plain index order `0..num_states` is returned
/// instead. That fallback is *not* a topological order.
fn compute_topological_order<L, F>(fst: &F) -> Vec<StateId>
where
    L: Clone,
    F: Wfst<L, LogWeight>,
{
    let state_count = fst.num_states();

    // Count incoming transitions per state.
    let mut pending_inputs = vec![0usize; state_count];
    for source in 0..state_count {
        for arc in fst.transitions(source as StateId) {
            if let Some(count) = pending_inputs.get_mut(arc.to as usize) {
                *count += 1;
            }
        }
    }

    // Work list of states with no unprocessed predecessors (popped from the back).
    let mut ready: Vec<StateId> = (0..state_count as StateId)
        .filter(|&id| pending_inputs[id as usize] == 0)
        .collect();

    let mut order = Vec::with_capacity(state_count);
    while let Some(current) = ready.pop() {
        order.push(current);
        for arc in fst.transitions(current) {
            if let Some(count) = pending_inputs.get_mut(arc.to as usize) {
                *count -= 1;
                if *count == 0 {
                    ready.push(arc.to);
                }
            }
        }
    }

    if order.len() < state_count {
        // Some states never became ready: fall back to index order.
        (0..state_count as StateId).collect()
    } else {
        order
    }
}
/// Errors reported by the log-push routines.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum LogPushError {
    /// The WFST has states but no start state.
    NoStartState,
    /// The start state cannot reach any final state.
    NoPathToFinal,
    /// A division failed while reweighting.
    DivisionByZero,
}

impl std::fmt::Display for LogPushError {
    /// Writes the fixed, human-readable message for the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            LogPushError::NoStartState => "WFST has no start state",
            LogPushError::NoPathToFinal => "No path from start to final states",
            LogPushError::DivisionByZero => "Division by zero during reweighting",
        };
        f.write_str(message)
    }
}

impl std::error::Error for LogPushError {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wfst::{MutableWfst as MutableWfstTrait, VectorWfst};

    /// Three states in a line: 0 --a/1.0--> 1 --b/2.0--> 2 (final, weight one).
    fn build_simple_chain() -> VectorWfst<char, LogWeight> {
        let mut fst = VectorWfst::new();
        let s0 = fst.add_state();
        let s1 = fst.add_state();
        let s2 = fst.add_state();
        fst.set_start(s0);
        fst.set_final(s2, LogWeight::one());
        for &(from, label, to, cost) in [(s0, 'a', s1, 1.0), (s1, 'b', s2, 2.0)].iter() {
            fst.add_arc(from, Some(label), Some(label), to, LogWeight::new(cost));
        }
        fst
    }

    /// Two states joined by two parallel arcs (a/1.0 and b/2.0); state 1 is final.
    fn build_parallel_paths() -> VectorWfst<char, LogWeight> {
        let mut fst = VectorWfst::new();
        let s0 = fst.add_state();
        let s1 = fst.add_state();
        fst.set_start(s0);
        fst.set_final(s1, LogWeight::one());
        for &(from, label, to, cost) in [(s0, 'a', s1, 1.0), (s0, 'b', s1, 2.0)].iter() {
            fst.add_arc(from, Some(label), Some(label), to, LogWeight::new(cost));
        }
        fst
    }

    /// Diamond: 0 branches to 1 and 2, both of which rejoin at final state 3.
    fn build_diamond() -> VectorWfst<char, LogWeight> {
        let mut fst = VectorWfst::new();
        let s0 = fst.add_state();
        let s1 = fst.add_state();
        let s2 = fst.add_state();
        let s3 = fst.add_state();
        fst.set_start(s0);
        fst.set_final(s3, LogWeight::one());
        let arcs = [
            (s0, 'a', s1, 1.0),
            (s0, 'b', s2, 2.0),
            (s1, 'c', s3, 1.0),
            (s2, 'd', s3, 0.5),
        ];
        for &(from, label, to, cost) in arcs.iter() {
            fst.add_arc(from, Some(label), Some(label), to, LogWeight::new(cost));
        }
        fst
    }

    /// Potentials along a chain are the suffix sums of the arc costs.
    #[test]
    fn test_compute_potentials_chain() {
        let fst = build_simple_chain();
        let potentials = compute_log_potentials(&fst).expect("Should compute potentials");
        assert_eq!(potentials.len(), 3);

        let at_final = LogWeight::one();
        assert!(
            potentials[2].approx_eq(&at_final, 0.001),
            "Final state potential should be one, got {:?}",
            potentials[2]
        );

        let at_middle = LogWeight::new(2.0);
        assert!(
            potentials[1].approx_eq(&at_middle, 0.001),
            "State 1 potential should be 2.0, got {:?}",
            potentials[1]
        );

        let at_start = LogWeight::new(3.0);
        assert!(
            potentials[0].approx_eq(&at_start, 0.001),
            "State 0 potential should be 3.0, got {:?}",
            potentials[0]
        );
    }

    /// Parallel arcs combine with log-add at the source state.
    #[test]
    fn test_compute_potentials_parallel() {
        let fst = build_parallel_paths();
        let potentials = compute_log_potentials(&fst).expect("Should compute potentials");
        assert!(potentials[1].approx_eq(&LogWeight::one(), 0.001));

        let mass = (-1.0_f64).exp() + (-2.0_f64).exp();
        let expected = -mass.ln();
        let expected_weight = LogWeight::new(expected);
        assert!(
            potentials[0].approx_eq(&expected_weight, 0.001),
            "State 0 potential should be {:?}, got {:?}",
            expected,
            potentials[0]
        );
    }

    /// Pushing a chain reports its total weight and leaves a path of weight one.
    #[test]
    fn test_prepare_for_beam_search_chain() {
        let mut fst = build_simple_chain();
        let result =
            prepare_for_beam_search(&mut fst, LogPushConfig::default()).expect("Should prepare");
        assert!(result.pushed);
        assert_eq!(result.num_states, 3);
        assert_eq!(result.num_transitions, 2);

        let expected_total = LogWeight::new(3.0);
        assert!(
            result.total_weight.approx_eq(&expected_total, 0.001),
            "Total weight should be 3.0, got {:?}",
            result.total_weight
        );
        assert!(fst.final_weight(2).approx_eq(&LogWeight::one(), 0.001));

        let first_arc = &fst.transitions(0)[0];
        let second_arc = &fst.transitions(1)[0];
        let arcs_only = first_arc.weight.times(&second_arc.weight);
        let path_weight = arcs_only.times(&fst.final_weight(2));
        assert!(
            path_weight.approx_eq(&LogWeight::one(), 0.001),
            "Normalized path weight should be one, got {:?}",
            path_weight
        );
    }

    /// After pushing, the arcs leaving the start state sum to one.
    #[test]
    fn test_prepare_for_beam_search_parallel() {
        let mut fst = build_parallel_paths();
        let result =
            prepare_for_beam_search(&mut fst, LogPushConfig::verified()).expect("Should prepare");
        assert!(result.pushed);
        assert_eq!(result.is_stochastic, Some(true));

        let total = fst
            .transitions(0)
            .iter()
            .fold(LogWeight::zero(), |acc, trans| acc.plus(&trans.weight));
        assert!(
            total.approx_eq(&LogWeight::one(), 0.01),
            "Pushed weights should sum to one, got {:?} (expected ~0.0 in log space)",
            total
        );
    }

    /// The diamond keeps its shape and passes the stochasticity check.
    #[test]
    fn test_prepare_for_beam_search_diamond() {
        let mut fst = build_diamond();
        let result =
            prepare_for_beam_search(&mut fst, LogPushConfig::verified()).expect("Should prepare");
        assert!(result.pushed);
        assert_eq!(result.num_states, 4);
        assert_eq!(result.num_transitions, 4);
        assert!(
            matches!(result.is_stochastic, Some(true)),
            "Should be stochastic after push"
        );
    }

    /// An empty WFST is accepted and reported as not pushed.
    #[test]
    fn test_prepare_empty_fst() {
        let mut fst: VectorWfst<char, LogWeight> = VectorWfst::new();
        let result = prepare_for_beam_search(&mut fst, LogPushConfig::default())
            .expect("Should handle empty");
        assert!(!result.pushed);
        assert_eq!(result.num_states, 0);
    }

    /// A WFST with states but no start state is rejected.
    #[test]
    fn test_prepare_no_start() {
        let mut fst: VectorWfst<char, LogWeight> = VectorWfst::new();
        fst.add_state();
        let result = prepare_for_beam_search(&mut fst, LogPushConfig::default());
        assert_eq!(result, Err(LogPushError::NoStartState));
    }

    /// A start state that cannot reach the final state is rejected.
    #[test]
    fn test_prepare_no_path_to_final() {
        let mut fst: VectorWfst<char, LogWeight> = VectorWfst::new();
        let s0 = fst.add_state();
        let s1 = fst.add_state();
        fst.set_start(s0);
        fst.set_final(s1, LogWeight::one());
        let result = prepare_for_beam_search(&mut fst, LogPushConfig::default());
        assert_eq!(result, Err(LogPushError::NoPathToFinal));
    }

    /// Every error variant renders its fixed message.
    #[test]
    fn test_error_display() {
        let cases = [
            (LogPushError::NoStartState, "WFST has no start state"),
            (
                LogPushError::NoPathToFinal,
                "No path from start to final states",
            ),
            (
                LogPushError::DivisionByZero,
                "Division by zero during reweighting",
            ),
        ];
        for (error, message) in cases.iter() {
            assert_eq!(error.to_string(), *message);
        }
    }

    /// The reported total weight equals the start potential before pushing.
    #[test]
    fn test_log_push_preserves_total_weight() {
        let mut fst = build_simple_chain();
        let original_total = {
            let original_potentials = compute_log_potentials(&fst).expect("potentials");
            original_potentials[0].clone()
        };
        let result = prepare_for_beam_search(&mut fst, LogPushConfig::default()).expect("push");
        assert!(
            result.total_weight.approx_eq(&original_total, 0.001),
            "Total weight should be preserved: expected {:?}, got {:?}",
            original_total,
            result.total_weight
        );
    }
}