use super::linear_chain_crf::LinearChainCrf;
use crate::error::{SeqError, SeqResult};
/// Decodes the highest-scoring label sequence for `x` under `crf` using the Viterbi algorithm.
///
/// `x` is a flat buffer of `T` frames, each `crf.n_features` values long; frame `t` is
/// `x[t * k..(t + 1) * k]` with `k = crf.n_features`. With `n = crf.n_labels`, the score of a
/// label path `y` is the sum of `crf.emit_score(y[t], frame_t)` over every frame plus
/// `crf.transitions[y[t - 1] * n + y[t]]` for every consecutive pair of frames.
///
/// Returns one label index in `0..n` per frame. Ties are broken in favour of the lowest label
/// index; if no candidate score exceeds `-inf` (all `-inf` or NaN), label `0` is chosen.
///
/// # Errors
///
/// - [`SeqError::EmptyInput`] if `x` is empty or the model has no labels (`n == 0`).
/// - [`SeqError::DimensionMismatch`] if `crf.n_features` is zero or `x.len()` is not a
///   multiple of it.
/// - Any error returned by `crf.emit_score`.
///
/// # Panics
///
/// Panics if indexing `crf.transitions` at `i * n + j` goes out of bounds, i.e. presumably
/// when it holds fewer than `n * n` entries (TODO confirm the model guarantees this).
pub fn viterbi_decode(crf: &LinearChainCrf, x: &[f64]) -> SeqResult<Vec<usize>> {
    let n = crf.n_labels;
    let k = crf.n_features;
    // A model with no labels admits no path; without this guard the backtrace below would
    // index into an empty `psi` and panic (or return the invalid label 0 for a single frame).
    if x.is_empty() || n == 0 {
        return Err(SeqError::EmptyInput);
    }
    // `k == 0` is rejected explicitly because `x.len() % 0` panics instead of erroring.
    if k == 0 || x.len() % k != 0 {
        return Err(SeqError::DimensionMismatch { a: x.len(), b: k });
    }
    let t_max = x.len() / k;

    // delta[t * n + j]: best score of any path that ends in label `j` at frame `t`.
    // psi[t * n + j]:   label at frame `t - 1` on that best path (unused for `t == 0`).
    let mut delta = vec![f64::NEG_INFINITY; t_max * n];
    let mut psi = vec![0usize; t_max * n];

    // `chunks_exact` yields exactly `t_max` frames of length `k`, since `k` divides `x.len()`.
    for (t, frame) in x.chunks_exact(k).enumerate() {
        let row = t * n;
        for j in 0..n {
            let emit = crf.emit_score(j, frame)?;
            if t == 0 {
                // The first frame has no predecessor, so only the emission contributes.
                delta[j] = emit;
                continue;
            }
            let prev = row - n;
            let (best, argmax) =
                argmax_first((0..n).map(|i| delta[prev + i] + crf.transitions[i * n + j]));
            delta[row + j] = best + emit;
            psi[row + j] = argmax;
        }
    }

    // Pick the best final label, then follow the back-pointers to recover the full path.
    let (_, last) = argmax_first(delta[(t_max - 1) * n..].iter().copied());
    let mut path = vec![0usize; t_max];
    path[t_max - 1] = last;
    for t in (1..t_max).rev() {
        path[t - 1] = psi[t * n + path[t]];
    }
    Ok(path)
}

/// Returns `(value, index)` of the first value in `scores` that is strictly greater than every
/// value before it (so ties keep the earliest index), or `(f64::NEG_INFINITY, 0)` when no value
/// exceeds `-inf` (empty input, all `-inf`, or all NaN).
fn argmax_first(scores: impl Iterator<Item = f64>) -> (f64, usize) {
    scores
        .enumerate()
        .fold((f64::NEG_INFINITY, 0), |(best, arg), (i, v)| {
            if v > best {
                (v, i)
            } else {
                (best, arg)
            }
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two labels and two features with zero transitions: the decoded label for each frame
    /// follows the active feature.
    #[test]
    fn decode_trivial() {
        let mut crf = LinearChainCrf::zeros(2, 2).expect("ok");
        crf.emissions = vec![1.0, 0.0, 0.0, 1.0];
        let x = [5.0, 0.0, 0.0, 5.0, 5.0, 0.0];
        let y = viterbi_decode(&crf, &x).expect("ok");
        assert_eq!(y, vec![0, 1, 0]);
    }

    /// An empty observation buffer is rejected before any scoring happens.
    #[test]
    fn decode_rejects_empty_input() {
        let crf = LinearChainCrf::zeros(2, 2).expect("ok");
        assert!(matches!(
            viterbi_decode(&crf, &[]),
            Err(SeqError::EmptyInput)
        ));
    }

    /// A buffer whose length is not a whole number of frames reports both lengths.
    #[test]
    fn decode_rejects_partial_frame() {
        let crf = LinearChainCrf::zeros(2, 2).expect("ok");
        let x = [1.0, 2.0, 3.0];
        assert!(matches!(
            viterbi_decode(&crf, &x),
            Err(SeqError::DimensionMismatch { a: 3, b: 2 })
        ));
    }
}