use super::gradient::{ArcIndex, GradientAccumulator, GradientWfst};
use crate::semiring::{LogWeight, Semiring};
use crate::wfst::StateId;
/// Outcome of [`viterbi_path_with_grad`]: the best score, the arcs that
/// realise it, and the per-arc gradients recorded for that path.
#[derive(Clone, Debug)]
pub struct ViterbiGradResult {
/// Score of the best path (accumulated arc weights plus the final weight);
/// `LogWeight::zero()` when no final state is reachable.
pub score: LogWeight,
/// Arcs of the best path, ordered from the start side to the final state.
/// Empty when no final state is reachable.
pub path: Vec<ArcIndex>,
/// Accumulator that received a gradient of `1.0` for every arc in `path`.
pub gradients: GradientAccumulator,
}
/// Computes the score of the single best (lowest-value) path through
/// `grad_fst`, from its start state to any final state.
///
/// States are relaxed once each, in the order produced by
/// `compute_topological_order`; a candidate replaces the stored cost of its
/// destination only when it is strictly smaller.
///
/// # Returns
///
/// `LogWeight::new(best)` where `best` is the smallest "path cost + final
/// weight" over all final states, or `LogWeight::zero()` when the transducer
/// has no states or no final state is reachable.
pub fn viterbi_score<L: Clone + Send + Sync>(grad_fst: &GradientWfst<L>) -> LogWeight {
    let state_count = grad_fst.num_states();
    if state_count == 0 {
        return LogWeight::zero();
    }

    // Cheapest known cost to reach every state; infinity marks "not reached yet".
    let source = grad_fst.start();
    let mut best_cost = vec![f64::INFINITY; state_count];
    best_cost[source as usize] = 0.0;

    for src in compute_topological_order(grad_fst) {
        let cost_here = best_cost[src as usize];
        if !cost_here.is_infinite() {
            for arc in grad_fst.transitions(src) {
                let candidate = cost_here + arc.weight.value();
                let slot = &mut best_cost[arc.to as usize];
                if candidate < *slot {
                    *slot = candidate;
                }
            }
        }
    }

    // Fold the final weights in and keep the strictly smallest total.
    let winner = (0..state_count as StateId)
        .filter(|&s| grad_fst.is_final(s))
        .map(|s| best_cost[s as usize] + grad_fst.final_weight(s).value())
        .fold(f64::INFINITY, |acc, total| if total < acc { total } else { acc });

    match winner.is_infinite() {
        true => LogWeight::zero(),
        false => LogWeight::new(winner),
    }
}
/// Finds the best (lowest-value) path through `grad_fst` and records a
/// gradient of `1.0` for every arc on it.
///
/// The forward pass relaxes each state once, in the order given by
/// `compute_topological_order`, remembering for every state the
/// `(source state, arc index)` pair that produced its best cost. The best
/// final state is then chosen by "cost + final weight" and the path is
/// recovered by following those backpointers.
///
/// # Returns
///
/// A [`ViterbiGradResult`] whose `score` is `LogWeight::zero()` and whose
/// `path` is empty when the transducer has no states or no final state is
/// reachable.
///
/// # Cyclic inputs
///
/// When the transducer contains a cycle the state order falls back to plain
/// index order and a single relaxation pass is performed, so the result is
/// best-effort only. With a negative-weight cycle (or self-loop) the
/// backpointers themselves can form a loop; the backtrace stops as soon as it
/// would revisit a state, so the function always terminates, but the returned
/// path may then be truncated and need not begin at the start state.
pub fn viterbi_path_with_grad<L: Clone + Send + Sync>(
    grad_fst: &GradientWfst<L>,
) -> ViterbiGradResult {
    let num_states = grad_fst.num_states();
    if num_states == 0 {
        return ViterbiGradResult {
            score: LogWeight::zero(),
            path: Vec::new(),
            gradients: GradientAccumulator::new(),
        };
    }

    let start = grad_fst.start();
    // delta[s]: best cost found so far for reaching `s`; infinity = unreached.
    let mut delta = vec![f64::INFINITY; num_states];
    // backpointers[s]: (source state, arc index within that source) of the
    // arc that produced delta[s].
    let mut backpointers: Vec<Option<(StateId, usize)>> = vec![None; num_states];
    delta[start as usize] = 0.0;

    let topo_order = compute_topological_order(grad_fst);
    for &state in &topo_order {
        let delta_state = delta[state as usize];
        if delta_state.is_infinite() {
            continue;
        }
        for (arc_idx, trans) in grad_fst.transitions(state).iter().enumerate() {
            let to_state = trans.to;
            let new_delta = delta_state + trans.weight.value();
            // Strict comparison: on ties the earlier arc keeps the backpointer.
            if new_delta < delta[to_state as usize] {
                delta[to_state as usize] = new_delta;
                backpointers[to_state as usize] = Some((state, arc_idx));
            }
        }
    }

    let mut best_final: Option<StateId> = None;
    let mut best_score = f64::INFINITY;
    for s in 0..num_states as StateId {
        if grad_fst.is_final(s) {
            let total = delta[s as usize] + grad_fst.final_weight(s).value();
            if total < best_score {
                best_score = total;
                best_final = Some(s);
            }
        }
    }

    let mut path = Vec::new();
    if let Some(final_state) = best_final {
        // Guard against backpointer loops (possible with negative cycles):
        // never step onto a state that is already part of the backtrace.
        let mut on_path = vec![false; num_states];
        let mut current = final_state;
        on_path[current as usize] = true;
        while let Some((prev_state, arc_idx)) = backpointers[current as usize] {
            if on_path[prev_state as usize] {
                break;
            }
            on_path[prev_state as usize] = true;
            path.push(ArcIndex::new(prev_state, arc_idx));
            current = prev_state;
        }
        // Arcs were collected final-to-start; flip to start-to-final order.
        path.reverse();
    }

    let mut gradients = GradientAccumulator::new();
    for arc in &path {
        gradients.add_gradient(*arc, 1.0);
    }

    ViterbiGradResult {
        score: if best_score.is_infinite() {
            LogWeight::zero()
        } else {
            LogWeight::new(best_score)
        },
        path,
        gradients,
    }
}
/// Produces a processing order for the states of `grad_fst`.
///
/// Uses in-degree counting: states with no incoming arcs are seeded onto a
/// stack, and a state is released once all of its incoming arcs have been
/// consumed. If not every state can be released this way (the graph has a
/// cycle), the plain index order `0..num_states` is returned instead.
fn compute_topological_order<L: Clone + Send + Sync>(grad_fst: &GradientWfst<L>) -> Vec<StateId> {
    let state_count = grad_fst.num_states();
    let all_states = || 0..state_count as StateId;

    // Number of not-yet-consumed incoming arcs per state.
    let mut pending_inputs = vec![0usize; state_count];
    for src in all_states() {
        for arc in grad_fst.transitions(src) {
            pending_inputs[arc.to as usize] += 1;
        }
    }

    // Work stack (LIFO) of states whose inputs are all accounted for.
    let mut ready: Vec<StateId> = all_states()
        .filter(|&s| pending_inputs[s as usize] == 0)
        .collect();

    let mut sorted = Vec::with_capacity(state_count);
    while let Some(current) = ready.pop() {
        sorted.push(current);
        for arc in grad_fst.transitions(current) {
            let remaining = &mut pending_inputs[arc.to as usize];
            *remaining -= 1;
            if *remaining == 0 {
                ready.push(arc.to);
            }
        }
    }

    if sorted.len() < state_count {
        // Some states were never released: fall back to index order.
        all_states().collect()
    } else {
        sorted
    }
}
// Property-based tests for `viterbi_score` and `viterbi_path_with_grad`,
// run over randomly generated two-state "parallel" and linear "chain" WFSTs.
#[cfg(test)]
mod property_tests {
use super::*;
use crate::wfst::{MutableWfst, VectorWfst, Wfst};
use proptest::prelude::*;
/// Strategy for a two-state WFST: state 0 is the start, state 1 is final
/// with weight `LogWeight::one()`, and between 1 and `max_paths` parallel
/// arcs run from state 0 to state 1 with weights drawn from `-5.0..5.0`.
fn arb_parallel_wfst(max_paths: usize) -> impl Strategy<Value = VectorWfst<char, LogWeight>> {
proptest::collection::vec(-5.0f64..5.0, 1..=max_paths).prop_map(|weights| {
let mut fst = VectorWfst::new();
let s0 = fst.add_state();
let s1 = fst.add_state();
fst.set_start(s0);
fst.set_final(s1, LogWeight::one());
for (i, w) in weights.iter().enumerate() {
// Labels cycle through 'a'..='z' by arc position.
let label = (b'a' + (i % 26) as u8) as char;
fst.add_arc(s0, Some(label), Some(label), s1, LogWeight::new(*w));
}
fst
})
}
/// Strategy for a linear chain WFST with `len` arcs (`1..=max_length`) and
/// `len + 1` states: arc `i` goes from state `i` to state `i + 1`, state 0
/// is the start, and the last state is final with weight `LogWeight::one()`.
fn arb_chain_wfst(max_length: usize) -> impl Strategy<Value = VectorWfst<char, LogWeight>> {
(1..=max_length).prop_flat_map(|len| {
proptest::collection::vec(-5.0f64..5.0, len).prop_map(move |weights| {
let mut fst = VectorWfst::new();
// `0..=len` adds len + 1 states.
for _ in 0..=len {
fst.add_state();
}
fst.set_start(0);
fst.set_final(len as u32, LogWeight::one());
for (i, w) in weights.iter().enumerate() {
let label = (b'a' + (i % 26) as u8) as char;
fst.add_arc(
i as u32,
Some(label),
Some(label),
(i + 1) as u32,
LogWeight::new(*w),
);
}
fst
})
})
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(50))]
// With only parallel arcs, the Viterbi score must equal the smallest arc
// weight (the test expects the final weight `one()` to add nothing).
#[test]
fn viterbi_finds_min_weight(fst in arb_parallel_wfst(5)) {
let grad_fst = GradientWfst::from_wfst(&fst);
let score = viterbi_score(&grad_fst);
let min_weight = fst.transitions(0).iter()
.map(|arc| arc.weight.value())
.fold(f64::INFINITY, f64::min);
prop_assert!((score.value() - min_weight).abs() < 1e-6,
"Viterbi score {} != min weight {}", score.value(), min_weight);
}
// A chain has exactly one path, so the score must be the sum of all arc
// weights across all states.
#[test]
fn viterbi_chain_equals_sum(fst in arb_chain_wfst(5)) {
let grad_fst = GradientWfst::from_wfst(&fst);
let score = viterbi_score(&grad_fst);
let expected: f64 = fst.transitions(0).iter()
.chain((1..fst.num_states() as u32).flat_map(|s| fst.transitions(s).iter()))
.map(|arc| arc.weight.value())
.sum();
prop_assert!((score.value() - expected).abs() < 1e-6,
"Viterbi chain score {} != expected {}", score.value(), expected);
}
// Two gradient wrappers built from the same WFST must score identically.
#[test]
fn viterbi_deterministic(fst in arb_parallel_wfst(4)) {
let grad_fst1 = GradientWfst::from_wfst(&fst);
let grad_fst2 = GradientWfst::from_wfst(&fst);
let score1 = viterbi_score(&grad_fst1);
let score2 = viterbi_score(&grad_fst2);
prop_assert!((score1.value() - score2.value()).abs() < 1e-9,
"Viterbi scores differ: {} vs {}", score1.value(), score2.value());
}
// Asserts that the Viterbi value is >= the forward value (within 1e-6).
// NOTE(review): the name says "leq", presumably referring to the
// probability domain rather than the stored weight value -- confirm.
#[test]
fn viterbi_leq_forward(fst in arb_parallel_wfst(5)) {
let grad_fst = GradientWfst::from_wfst(&fst);
let viterbi = viterbi_score(&grad_fst);
// NOTE(review): presumably clears state on the wrapper before the forward
// pass; `reset` is defined elsewhere -- confirm.
grad_fst.reset();
let forward = super::super::forward_score::forward_score(&grad_fst);
prop_assert!(viterbi.value() >= forward.value() - 1e-6,
"Viterbi {} < forward {} (should be >=)", viterbi.value(), forward.value());
}
// On a chain the best path must use every arc: one fewer than the number
// of states.
#[test]
fn viterbi_path_correct_length(fst in arb_chain_wfst(4)) {
let grad_fst = GradientWfst::from_wfst(&fst);
let result = viterbi_path_with_grad(&grad_fst);
let expected_len = fst.num_states() - 1;
prop_assert_eq!(result.path.len(), expected_len,
"Path length {} != expected {}", result.path.len(), expected_len);
}
// Each path arc receives gradient 1.0, so the gradients must sum to the
// path length.
#[test]
fn viterbi_path_grad_sum(fst in arb_chain_wfst(4)) {
let grad_fst = GradientWfst::from_wfst(&fst);
let result = viterbi_path_with_grad(&grad_fst);
let grad_sum: f64 = result.gradients.arc_gradients.iter()
.map(|g| g.gradient)
.sum();
prop_assert!((grad_sum - result.path.len() as f64).abs() < 1e-6,
"Gradient sum {} != path length {}", grad_sum, result.path.len());
}
// With parallel arcs the path is a single arc, and it must be the one with
// the smallest weight.
#[test]
fn viterbi_path_selects_best(fst in arb_parallel_wfst(5)) {
let grad_fst = GradientWfst::from_wfst(&fst);
let result = viterbi_path_with_grad(&grad_fst);
prop_assert_eq!(result.path.len(), 1);
let min_idx = fst.transitions(0).iter()
.enumerate()
.min_by(|(_, a), (_, b)| {
a.weight.value().partial_cmp(&b.weight.value()).expect("differentiable/viterbi.rs: required value was None/Err")
})
.map(|(i, _)| i)
.expect("differentiable/viterbi.rs: required value was None/Err");
prop_assert_eq!(result.path[0].arc_idx, min_idx,
"Path arc {} != min arc {}", result.path[0].arc_idx, min_idx);
}
// `viterbi_score` and `viterbi_path_with_grad` must report the same score.
#[test]
fn viterbi_path_score_matches(fst in arb_parallel_wfst(4)) {
let grad_fst = GradientWfst::from_wfst(&fst);
let score = viterbi_score(&grad_fst);
let grad_fst2 = GradientWfst::from_wfst(&fst);
let result = viterbi_path_with_grad(&grad_fst2);
prop_assert!((score.value() - result.score.value()).abs() < 1e-9,
"viterbi_score {} != viterbi_path_with_grad score {}",
score.value(), result.score.value());
}
// On a chain the path must start at state 0 and arc `i` must leave state `i`.
#[test]
fn viterbi_path_valid_sequence(fst in arb_chain_wfst(4)) {
let grad_fst = GradientWfst::from_wfst(&fst);
let result = viterbi_path_with_grad(&grad_fst);
if !result.path.is_empty() {
prop_assert_eq!(result.path[0].from, 0,
"Path should start at state 0, got {}", result.path[0].from);
}
for i in 1..result.path.len() {
let curr_arc = &result.path[i];
let prev_arc = &result.path[i - 1];
prop_assert_eq!(curr_arc.from, prev_arc.from + 1,
"Arc {} should start one state past arc {}", i, i - 1);
prop_assert_eq!(curr_arc.from as usize, i,
"Arc {} should start at state {}", i, i);
}
}
// Gradients are sparse: 1.0 on arcs that are on the best path, 0.0 on every
// other arc leaving state 0.
#[test]
fn viterbi_gradient_sparse(fst in arb_parallel_wfst(5)) {
let grad_fst = GradientWfst::from_wfst(&fst);
let result = viterbi_path_with_grad(&grad_fst);
let num_arcs = fst.transitions(0).len();
for arc_idx in 0..num_arcs {
let arc = ArcIndex::new(0, arc_idx);
let grad = result.gradients.get_gradient(arc);
let on_path = result.path.iter().any(|p| *p == arc);
if on_path {
prop_assert!((grad - 1.0).abs() < 1e-6,
"Path arc gradient {} should be 1.0", grad);
} else {
prop_assert!((grad - 0.0).abs() < 1e-6,
"Non-path arc gradient {} should be 0.0", grad);
}
}
}
}
}
// Hand-written unit tests for `viterbi_score` and `viterbi_path_with_grad`
// on small, fixed WFSTs.
#[cfg(test)]
mod tests {
use super::*;
use crate::wfst::{MutableWfst, VectorWfst};
// A WFST with no states scores as the semiring zero.
#[test]
fn test_viterbi_empty() {
let fst = VectorWfst::<char, LogWeight>::new();
let grad_fst = GradientWfst::from_wfst(&fst);
let score = viterbi_score(&grad_fst);
assert!(score.is_zero());
}
// Start and final states exist but no arc connects them, so the final
// state is unreachable and the score is zero.
#[test]
fn test_viterbi_no_path() {
let mut fst = VectorWfst::<char, LogWeight>::new();
let s0 = fst.add_state();
let s1 = fst.add_state();
fst.set_start(s0);
fst.set_final(s1, LogWeight::one());
let grad_fst = GradientWfst::from_wfst(&fst);
let score = viterbi_score(&grad_fst);
assert!(score.is_zero());
}
// A single arc of weight -1.0 yields a score of -1.0.
#[test]
fn test_viterbi_single_path() {
let mut fst = VectorWfst::<char, LogWeight>::new();
let s0 = fst.add_state();
let s1 = fst.add_state();
fst.set_start(s0);
fst.set_final(s1, LogWeight::one());
fst.add_arc(s0, Some('a'), Some('a'), s1, LogWeight::new(-1.0));
let grad_fst = GradientWfst::from_wfst(&fst);
let score = viterbi_score(&grad_fst);
assert!((score.value() - (-1.0)).abs() < 1e-6);
}
// Of two parallel arcs (-1.0 and -2.0) the smaller value, -2.0, wins.
#[test]
fn test_viterbi_two_paths() {
let mut fst = VectorWfst::<char, LogWeight>::new();
let s0 = fst.add_state();
let s1 = fst.add_state();
fst.set_start(s0);
fst.set_final(s1, LogWeight::one());
fst.add_arc(s0, Some('a'), Some('a'), s1, LogWeight::new(-1.0));
fst.add_arc(s0, Some('b'), Some('b'), s1, LogWeight::new(-2.0));
let grad_fst = GradientWfst::from_wfst(&fst);
let score = viterbi_score(&grad_fst);
assert!((score.value() - (-2.0)).abs() < 1e-6);
}
// Chain score = arc weights plus the final weight: -1.0 + -2.0 + -0.5 = -3.5.
#[test]
fn test_viterbi_chain() {
let mut fst = VectorWfst::<char, LogWeight>::new();
let s0 = fst.add_state();
let s1 = fst.add_state();
let s2 = fst.add_state();
fst.set_start(s0);
fst.set_final(s2, LogWeight::new(-0.5));
fst.add_arc(s0, Some('a'), Some('a'), s1, LogWeight::new(-1.0));
fst.add_arc(s1, Some('b'), Some('b'), s2, LogWeight::new(-2.0));
let grad_fst = GradientWfst::from_wfst(&fst);
let score = viterbi_score(&grad_fst);
assert!((score.value() - (-3.5)).abs() < 1e-6);
}
// The best path is the second arc (index 1, weight -2.0) out of state 0,
// and that arc carries a gradient of 1.0.
#[test]
fn test_viterbi_path_with_grad() {
let mut fst = VectorWfst::<char, LogWeight>::new();
let s0 = fst.add_state();
let s1 = fst.add_state();
fst.set_start(s0);
fst.set_final(s1, LogWeight::one());
fst.add_arc(s0, Some('a'), Some('a'), s1, LogWeight::new(-1.0));
fst.add_arc(s0, Some('b'), Some('b'), s1, LogWeight::new(-2.0));
let grad_fst = GradientWfst::from_wfst(&fst);
let result = viterbi_path_with_grad(&grad_fst);
assert!((result.score.value() - (-2.0)).abs() < 1e-6);
assert_eq!(result.path.len(), 1);
assert_eq!(result.path[0].from, 0);
assert_eq!(result.path[0].arc_idx, 1);
assert!((result.gradients.get_gradient(result.path[0]) - 1.0).abs() < 1e-6);
}
// A two-arc chain yields a two-arc path leaving states 0 and 1 in that
// order, each arc with gradient 1.0.
#[test]
fn test_viterbi_path_chain() {
let mut fst = VectorWfst::<char, LogWeight>::new();
let s0 = fst.add_state();
let s1 = fst.add_state();
let s2 = fst.add_state();
fst.set_start(s0);
fst.set_final(s2, LogWeight::one());
fst.add_arc(s0, Some('a'), Some('a'), s1, LogWeight::new(-1.0));
fst.add_arc(s1, Some('b'), Some('b'), s2, LogWeight::new(-2.0));
let grad_fst = GradientWfst::from_wfst(&fst);
let result = viterbi_path_with_grad(&grad_fst);
assert_eq!(result.path.len(), 2);
assert_eq!(result.path[0].from, 0);
assert_eq!(result.path[1].from, 1);
for arc in &result.path {
assert!((result.gradients.get_gradient(*arc) - 1.0).abs() < 1e-6);
}
}
// Two routes to the final state: 0 -> 1 -> 2 totals -2.0 and the direct arc
// 0 -> 2 is -1.5; the two-hop route has the smaller value and wins.
#[test]
fn test_viterbi_diamond() {
let mut fst = VectorWfst::<char, LogWeight>::new();
let s0 = fst.add_state();
let s1 = fst.add_state();
let s2 = fst.add_state();
fst.set_start(s0);
fst.set_final(s2, LogWeight::one());
fst.add_arc(s0, Some('a'), Some('a'), s1, LogWeight::new(-1.0));
fst.add_arc(s1, Some('b'), Some('b'), s2, LogWeight::new(-1.0));
fst.add_arc(s0, Some('c'), Some('c'), s2, LogWeight::new(-1.5));
let grad_fst = GradientWfst::from_wfst(&fst);
let score = viterbi_score(&grad_fst);
assert!((score.value() - (-2.0)).abs() < 1e-6);
}
}