use crate::backend::LatticeBackend;
use crate::lattice::{EdgeId, Lattice, LatticePath, NodeId};
use crate::semiring::Semiring;
/// Outcome of a [`viterbi`] search over a lattice.
///
/// `success` is `true` only when a path from the lattice's start node to its
/// end node was produced, in which case `path` holds that path. On failure
/// `path` is a fresh `LatticePath::new()` placeholder.
#[derive(Clone, Debug)]
pub struct ViterbiResult<W: Semiring> {
    /// The best path found, or an untouched `LatticePath::new()` on failure.
    pub path: LatticePath<W>,
    /// Whether the search produced a start-to-end path.
    pub success: bool,
}

impl<W: Semiring> ViterbiResult<W> {
    /// Wraps a found `path` in a successful result.
    fn success(path: LatticePath<W>) -> Self {
        Self { success: true, path }
    }

    /// Builds a failed result carrying a placeholder `LatticePath::new()`.
    fn failure() -> Self {
        Self {
            success: false,
            path: LatticePath::new(),
        }
    }
}
/// Runs the Viterbi algorithm, returning the best path from `lattice.start()`
/// to `lattice.end()`.
///
/// A path's score is the [`Semiring::times`] product of its edge weights,
/// starting from [`Semiring::one`]. Nodes are relaxed in topological order, so
/// every predecessor of a node has been processed before the node's own
/// outgoing edges are examined. A candidate score replaces a node's current
/// best only when [`improves_on`] says it is strictly better. Ties keep the
/// first edge found.
///
/// # Returns
///
/// A successful [`ViterbiResult`] holding a path marked complete when `end`
/// is reachable from `start`. When `start == end` this is an empty path with
/// weight `W::one()`.
///
/// A failed result when:
/// - the lattice is empty and `start != end`;
/// - `topological_order()` returns `None` (presumably a cyclic lattice);
/// - `end` is out of range;
/// - `end` is not reachable from `start`.
pub fn viterbi<W: Semiring, B: LatticeBackend>(lattice: &mut Lattice<W, B>) -> ViterbiResult<W> {
    if lattice.is_empty() {
        if lattice.start() == lattice.end() {
            let mut path = LatticePath::new();
            path.mark_complete();
            return ViterbiResult::success(path);
        }
        return ViterbiResult::failure();
    }

    // Copy the order out so the borrow taken by `topological_order` ends
    // before `outgoing_edges` borrows the lattice again below.
    let topo_order = match lattice.topological_order() {
        Some(order) => order.to_vec(),
        None => return ViterbiResult::failure(),
    };

    let n = lattice.num_nodes();
    let start = lattice.start();
    let end = lattice.end();

    // best[v] = (best score reaching v, last edge on that path, predecessor of v).
    // `None` means v has not (yet) been reached from `start`.
    let mut best: Vec<Option<(W, EdgeId, NodeId)>> = vec![None; n];

    for &node_id in &topo_order {
        let current_score = if node_id == start {
            W::one()
        } else {
            match &best[node_id.0 as usize] {
                Some((score, _, _)) => *score,
                // Not reachable from `start`: nothing to propagate.
                None => continue,
            }
        };

        for edge in lattice.outgoing_edges(node_id) {
            let target_idx = edge.target.0 as usize;
            let new_score = current_score.times(&edge.weight);
            let update = match &best[target_idx] {
                None => true,
                Some((existing_score, _, _)) => improves_on(&new_score, existing_score),
            };
            if update {
                best[target_idx] = Some((new_score, edge.id, node_id));
            }
        }
    }

    let end_idx = end.0 as usize;
    if end_idx >= n || (start != end && best[end_idx].is_none()) {
        return ViterbiResult::failure();
    }

    // Follow predecessor links back from `end` to `start`, then restore
    // forward order.
    let mut edges = Vec::new();
    let mut current = end;
    while current != start {
        match &best[current.0 as usize] {
            Some((_, edge_id, prev_node)) => {
                edges.push(*edge_id);
                current = *prev_node;
            }
            None => return ViterbiResult::failure(),
        }
    }
    edges.reverse();

    // An empty edge list means start == end, whose score is the identity.
    let final_weight = if edges.is_empty() {
        W::one()
    } else {
        best[end_idx]
            .as_ref()
            .map(|(w, _, _)| *w)
            .unwrap_or_else(W::one)
    };

    let mut path = LatticePath::with_weight(final_weight);
    path.edges.extend(edges);
    path.weight = final_weight;
    path.mark_complete();
    ViterbiResult::success(path)
}

/// Returns `true` when `candidate` should replace `existing` as a node's best score.
///
/// If the semiring's natural order can compare the two scores, the candidate
/// wins only when it is strictly better, so ties keep the incumbent.
///
/// If `natural_less` cannot order them, the one safe decision is to prefer a
/// real score over the semiring zero. Zero is the multiplicative annihilator
/// ("no path"). Overwriting a non-zero score with zero would discard a viable
/// path, and two incomparable non-zero scores keep the one found first.
fn improves_on<W: Semiring>(candidate: &W, existing: &W) -> bool {
    match candidate.natural_less(existing) {
        Some(is_better) => is_better,
        None => existing.is_zero() && !candidate.is_zero(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::HashMapBackend;
    use crate::lattice::{EdgeMetadata, LatticeBuilder};
    use crate::semiring::TropicalWeight;

    /// Builds a `TropicalWeight` lattice spanning `$positions` positions from a
    /// bracketed list of `(from, to, word, cost)` corrections. Each correction
    /// gets `EdgeMetadata::default()`.
    macro_rules! build_lattice {
        ($positions:expr; [$(($from:expr, $to:expr, $word:expr, $cost:expr)),* $(,)?]) => {{
            // `mut` goes unused when the correction list is empty.
            #[allow(unused_mut)]
            let mut builder: LatticeBuilder<TropicalWeight, _> =
                LatticeBuilder::new(HashMapBackend::new());
            $(
                builder.add_correction(
                    $from,
                    $to,
                    $word,
                    TropicalWeight::new($cost),
                    EdgeMetadata::default(),
                );
            )*
            builder.build($positions)
        }};
    }

    /// Two alternatives for one slot: the cheaper "the" (0.5) beats "a" (1.0).
    #[test]
    fn test_viterbi_simple() {
        let mut lattice = build_lattice!(1; [(0, 1, "the", 0.5), (0, 1, "a", 1.0)]);
        let outcome = viterbi(&mut lattice);
        assert!(outcome.success);
        assert_eq!(outcome.path.len(), 1);
        assert_eq!(outcome.path.weight.value(), 0.5);
        assert_eq!(outcome.path.to_words(&lattice), vec!["the"]);
    }

    /// Independent choices at two positions combine into the cheapest sequence.
    #[test]
    fn test_viterbi_multi_position() {
        let mut lattice = build_lattice!(2; [
            (0, 1, "the", 0.5),
            (0, 1, "a", 1.0),
            (1, 2, "quick", 0.3),
            (1, 2, "slow", 0.7),
        ]);
        let outcome = viterbi(&mut lattice);
        assert!(outcome.success);
        assert_eq!(outcome.path.len(), 2);
        assert_eq!(outcome.path.weight.value(), 0.8);
        assert_eq!(outcome.path.to_words(&lattice), vec!["the", "quick"]);
    }

    /// A lattice with no corrections yields a successful, empty path.
    #[test]
    fn test_viterbi_empty_lattice() {
        let mut lattice = build_lattice!(0; []);
        let outcome = viterbi(&mut lattice);
        assert!(outcome.success);
        assert!(outcome.path.is_empty());
    }

    /// With exactly one route, its weights are summed under the tropical semiring.
    #[test]
    fn test_viterbi_single_path() {
        let mut lattice = build_lattice!(2; [(0, 1, "hello", 1.0), (1, 2, "world", 2.0)]);
        let outcome = viterbi(&mut lattice);
        assert!(outcome.success);
        assert_eq!(outcome.path.len(), 2);
        assert_eq!(outcome.path.weight.value(), 3.0);
    }

    /// Two branches rejoin at node 3. The a+c route (2.0) beats b+d (2.5).
    #[test]
    fn test_viterbi_diamond() {
        let mut lattice = build_lattice!(3; [
            (0, 1, "a", 1.0),
            (0, 2, "b", 2.0),
            (1, 3, "c", 1.0),
            (2, 3, "d", 0.5),
        ]);
        let outcome = viterbi(&mut lattice);
        assert!(outcome.success);
        assert_eq!(outcome.path.len(), 2);
        assert_eq!(outcome.path.weight.value(), 2.0);
        assert_eq!(outcome.path.to_words(&lattice), vec!["a", "c"]);
    }

    /// Equal-cost alternatives still produce a single, correctly weighted path.
    #[test]
    fn test_viterbi_equal_weights() {
        let mut lattice = build_lattice!(1; [(0, 1, "a", 1.0), (0, 1, "b", 1.0)]);
        let outcome = viterbi(&mut lattice);
        assert!(outcome.success);
        assert_eq!(outcome.path.len(), 1);
        assert_eq!(outcome.path.weight.value(), 1.0);
    }

    /// A zero-cost edge (the tropical identity) is preferred over a positive one.
    #[test]
    fn test_viterbi_zero_weight() {
        let mut lattice = build_lattice!(1; [(0, 1, "zero", 0.0), (0, 1, "one", 1.0)]);
        let outcome = viterbi(&mut lattice);
        assert!(outcome.success);
        assert_eq!(outcome.path.weight.value(), 0.0);
        assert_eq!(outcome.path.to_words(&lattice), vec!["zero"]);
    }
}
// Property-based tests: run `viterbi` on randomly generated lattices from
// `crate::test_utils` and check invariants of the result.
#[cfg(test)]
mod property_tests {
use super::*;
use crate::path::nbest;
use crate::test_utils::{arb_diamond_lattice, arb_linear_lattice, arb_tropical_lattice};
use proptest::prelude::*;
proptest! {
// Each property runs on 50 generated lattices (presumably chosen to keep the
// suite fast).
#![proptest_config(ProptestConfig::with_cases(50))]
// A linear lattice (assumed to be a single 4-position chain; TODO confirm
// against `arb_linear_lattice`) must yield a successful 4-edge path.
#[test]
fn viterbi_linear_finds_only_path(
mut lattice in arb_linear_lattice(4)
) {
let result = viterbi(&mut lattice);
prop_assert!(result.success);
prop_assert_eq!(result.path.len(), 4);
}
// Assumes `arb_tropical_lattice` always connects start to end, so the
// search must succeed.
#[test]
fn viterbi_always_succeeds_on_connected(
mut lattice in arb_tropical_lattice(3, 2)
) {
let result = viterbi(&mut lattice);
prop_assert!(result.success);
}
// Assumes the first argument of `arb_tropical_lattice` is the position count
// and that every edge spans one position, so the best path has exactly 4
// edges. TODO confirm against test_utils.
#[test]
fn viterbi_path_length_matches_positions(
mut lattice in arb_tropical_lattice(4, 3)
) {
let result = viterbi(&mut lattice);
prop_assert!(result.success);
prop_assert_eq!(result.path.len(), 4);
}
// Cross-check against `nbest`: none of the (up to 100) enumerated paths may
// weigh less than the Viterbi path, allowing 1e-9 of float slack.
#[test]
fn viterbi_finds_optimal(
mut lattice in arb_diamond_lattice(3)
) {
let viterbi_result = viterbi(&mut lattice);
prop_assert!(viterbi_result.success);
let viterbi_weight = viterbi_result.path.weight.value();
let all_paths = nbest(&mut lattice, 100);
for path in &all_paths {
prop_assert!(
viterbi_weight <= path.weight.value() + 1e-9,
"Viterbi weight {} > path weight {}",
viterbi_weight,
path.weight.value()
);
}
}
// The generator presumably emits non-negative tropical costs, so their
// accumulated sum must also be non-negative.
#[test]
fn viterbi_weight_non_negative(
mut lattice in arb_tropical_lattice(3, 2)
) {
let result = viterbi(&mut lattice);
prop_assert!(result.success);
prop_assert!(result.path.weight.value() >= 0.0);
}
// `viterbi` calls `mark_complete` before every successful return, so a
// successful path must report `is_complete`.
#[test]
fn viterbi_path_is_complete(
mut lattice in arb_tropical_lattice(2, 2)
) {
let result = viterbi(&mut lattice);
prop_assert!(result.success);
prop_assert!(result.path.is_complete);
}
}
}