use std::cmp::Ord;
use std::fmt::Debug;
use std::ops::Add;
use approx::{abs_diff_eq, AbsDiffEq};
use ndarray::prelude::*;
use ndarray_stats::{MaybeNan, QuantileExt};
use realfft::num_traits::Float;
/// Decodes the most likely hidden-state sequence of a discrete HMM (Viterbi algorithm).
///
/// All scores are accumulated in the log domain. `A::min_positive_value()` is added to every
/// probability before taking the logarithm, so zero probabilities map to a large finite
/// negative value instead of `-inf`.
///
/// # Arguments
///
/// * `prob` - emission probabilities with shape `(n_states, n_steps)`: `prob[[s, t]]` is the
///   probability of the observation at step `t` given hidden state `s`. Every value must lie
///   in `[0, 1]`.
/// * `transition` - square `(n_states, n_states)` matrix where `transition[[i, j]]` is the
///   probability of moving from state `i` to state `j`. Entries must be non-negative, and each
///   row must sum to 1 (within `1e-6`).
/// * `p_init` - optional initial state distribution of length `n_states`. It must be
///   non-negative and sum to 1 (within `1e-6`). Defaults to the uniform distribution.
///
/// # Returns
///
/// A tuple `(states, log_prob)`:
/// * `states` holds the decoded state index for each of the `n_steps` steps.
/// * `log_prob` is the log-probability of that path jointly with the observations.
///
/// When `n_steps == 0`, the path is empty and `log_prob` is `0` (the log of the empty product).
///
/// # Panics
///
/// Panics if any shape or value precondition above is violated, or if there are observation
/// steps but `n_states == 0`.
pub(crate) fn viterbi<A>(
    prob: ArrayView2<A>,
    transition: ArrayView2<A>,
    p_init: Option<CowArray<A, Ix1>>,
) -> (Array1<usize>, A)
where
    A: Float + AbsDiffEq<Epsilon = A> + Add + MaybeNan + Debug,
    <A as MaybeNan>::NotNan: Ord,
{
    assert_eq!(prob.raw_dim()[0], transition.raw_dim()[0]);
    assert_eq!(transition.raw_dim()[0], transition.raw_dim()[1]);
    assert!(
        transition.iter().all(|&x| x >= A::zero())
            && transition.sum_axis(Axis(1)).abs_diff_eq(
                &Array::from_elem(transition.shape()[0], A::one()),
                A::from(1e-6).unwrap(),
            ),
        "Invalid transition matrix: must be non-negative and sum to 1 on each row.",
    );
    // `x >= 0 && x <= 1` is false for NaN, so this also rejects NaN emissions.
    assert!(
        prob.iter().all(|&x| x >= A::zero() && x <= A::one()),
        "Invalid probability values: must be between 0 and 1."
    );
    let n_states = prob.shape()[0];
    let n_steps = prob.shape()[1];
    let p_init = match p_init {
        Some(p_init) => {
            assert!(
                p_init.raw_dim() == Dim(n_states)
                    && p_init.iter().all(|&x| x >= A::zero())
                    && abs_diff_eq!(p_init.sum(), A::one(), epsilon = A::from(1e-6).unwrap()),
                "Invalid initial state distribution: p_init={:?}",
                p_init
            );
            p_init
        }
        None => Array1::from_elem(n_states, A::one() / A::from(n_states).unwrap()).into(),
    };

    // An empty observation sequence has exactly one (empty) path, with probability 1.
    if n_steps == 0 {
        return (Array1::zeros(0), A::zero());
    }
    assert!(
        n_states > 0,
        "Invalid model: at least one hidden state is required to decode observations."
    );

    // Offset keeps ln() finite for zero probabilities.
    let epsilon = A::min_positive_value();
    let log_trans = transition.mapv(|x| (x + epsilon).ln());
    // Transposed so it is indexed as log_prob[[t, state]], matching `values` and `ptr`.
    let log_prob = prob.t().mapv(|x| (x + epsilon).ln());
    let log_p_init = p_init.mapv(|x| (x + epsilon).ln());

    // values[[t, j]]: best log-score of any path ending in state j at step t.
    // ptr[[t, j]]: predecessor state (at step t - 1) on that best path.
    let mut values = Array2::<A>::zeros((n_steps, n_states));
    let mut ptr = Array2::<usize>::zeros((n_steps, n_states));
    values
        .slice_mut(s![0, ..])
        .assign(&(log_p_init + log_prob.slice(s![0, ..])));
    for t in 1..n_steps {
        // Broadcasting a (1, n) row against the (n, n) transposed transitions gives
        // trans_out[[j, i]] = values[[t - 1, i]] + log_trans[[i, j]],
        // i.e. the score of reaching state j from state i.
        let trans_out = &values.slice(s![t - 1..t, ..]) + &log_trans.t();
        values
            .slice_mut(s![t, ..])
            .iter_mut()
            .enumerate()
            .for_each(|(j, x)| {
                let best_prev = trans_out
                    .slice(s![j, ..])
                    .argmax_skipnan()
                    .expect("scores are non-empty and NaN-free after input validation");
                ptr[[t, j]] = best_prev;
                *x = log_prob[[t, j]] + trans_out[[j, best_prev]];
            });
    }

    // Backtrack from the best final state through the stored predecessors.
    let last_state = values
        .slice(s![n_steps - 1, ..])
        .argmax_skipnan()
        .expect("final scores are non-empty and NaN-free after input validation");
    let mut states = Array1::<usize>::zeros(n_steps);
    states[n_steps - 1] = last_state;
    for t in (0..n_steps - 1).rev() {
        states[t] = ptr[[t + 1, states[t + 1]]];
    }
    let best_log_prob = values[[n_steps - 1, last_state]];
    (states, best_log_prob)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Three-step, two-state example with an explicit initial distribution.
    #[test]
    fn test_viterbi() {
        let p_init = Array1::from_vec(vec![0.6f64, 0.4]);
        let p_emit =
            Array2::from_shape_vec((2, 3), vec![0.5f64, 0.4, 0.1, 0.1, 0.3, 0.6]).unwrap();
        let p_trans = Array2::from_shape_vec((2, 2), vec![0.7f64, 0.3, 0.4, 0.6]).unwrap();
        let (path, logp) = viterbi(p_emit.view(), p_trans.view(), Some(p_init.into()));
        // Compare whole vectors: zipping would silently accept a path of the wrong length.
        assert_eq!(path.to_vec(), vec![0, 0, 1]);
        assert!((logp - (-4.19173690823075)).abs() < 1e-14);
    }

    /// Without `p_init` the initial distribution is uniform.
    #[test]
    fn test_viterbi_uniform_init() {
        let p_emit =
            Array2::from_shape_vec((2, 3), vec![0.5f64, 0.4, 0.1, 0.1, 0.3, 0.6]).unwrap();
        let p_trans = Array2::from_shape_vec((2, 2), vec![0.7f64, 0.3, 0.4, 0.6]).unwrap();
        let (path, logp) = viterbi(p_emit.view(), p_trans.view(), None);
        assert_eq!(path.to_vec(), vec![0, 0, 1]);
        // Best path 0 -> 0 -> 1: 0.5 * 0.5 * 0.7 * 0.4 * 0.3 * 0.6 = 0.0126.
        assert!((logp - 0.0126f64.ln()).abs() < 1e-10);
    }

    /// A single observation step has no transitions: pick argmax of p_init * emission.
    #[test]
    fn test_viterbi_single_step() {
        let p_init = Array1::from_vec(vec![0.6f64, 0.4]);
        let p_emit = Array2::from_shape_vec((2, 1), vec![0.5f64, 0.1]).unwrap();
        let p_trans = Array2::from_shape_vec((2, 2), vec![0.7f64, 0.3, 0.4, 0.6]).unwrap();
        let (path, logp) = viterbi(p_emit.view(), p_trans.view(), Some(p_init.into()));
        assert_eq!(path.to_vec(), vec![0]);
        assert!((logp - 0.3f64.ln()).abs() < 1e-12);
    }
}