use super::hmm::{HmmDiscrete, HmmGaussian, log_safe};
use crate::error::{SeqError, SeqResult};
/// Result of a forward–backward pass over one observation sequence of `t_max` steps
/// and `n` hidden states. All matrices are flattened row-major.
#[derive(Debug, Clone)]
pub struct ForwardBackward {
/// Forward log-probabilities, length `t_max * n`; entry `t * n + j` is log α_t(j).
pub log_alpha: Vec<f64>,
/// Backward log-probabilities, length `t_max * n`; entry `t * n + i` is log β_t(i)
/// (the last time step is fixed at 0.0, i.e. β = 1).
pub log_beta: Vec<f64>,
/// Per-step state posteriors in linear space, length `t_max * n`; entry `t * n + i` is
/// γ_t(i). Each time slice is renormalised to sum to 1 when its raw sum is positive.
pub gamma: Vec<f64>,
/// Pairwise transition posteriors in linear space, length `(t_max - 1) * n * n`; entry
/// `t * n * n + i * n + j` is ξ_t(i, j). Each time slice is renormalised to sum to 1
/// when its raw sum is positive.
pub xi: Vec<f64>,
/// Sequence log-likelihood, computed as the log-sum-exp of the last row of `log_alpha`.
pub log_likelihood: f64,
}
/// Numerically stable `ln(Σ exp(x_i))`.
///
/// The maximum `m` is factored out before exponentiating, so large finite inputs do not
/// overflow: `m + ln(Σ exp(x_i - m))`.
///
/// Edge cases:
/// * empty slice, or all entries `-inf` → `-inf` (log of a zero sum);
/// * any entry `+inf` → `+inf`;
/// * NaN entries are skipped when choosing the maximum (`f64::max` ignores NaN). If any
///   finite maximum exists, a NaN entry still propagates through the sum. If every entry
///   is NaN, the result is `-inf`.
#[inline]
pub fn logsumexp(xs: &[f64]) -> f64 {
    let m = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    // An infinite maximum must short-circuit. Shifting by it would evaluate
    // `inf - inf = NaN` inside the exponent. `-inf` means every term is zero;
    // `+inf` dominates the sum.
    if m.is_infinite() {
        return m;
    }
    let s: f64 = xs.iter().map(|&x| (x - m).exp()).sum();
    m + s.ln()
}
/// Runs the forward–backward algorithm for a discrete-emission HMM over the symbol
/// sequence `obs`.
///
/// Emission scores come from `HmmDiscrete::log_emission`. Transition (`hmm.a`, read as a
/// row-major `n x n` matrix) and initial (`hmm.pi`) probabilities are converted to log
/// space with `log_safe` before the shared log-space pass.
///
/// # Errors
/// * [`SeqError::EmptyInput`] if `obs` is empty.
/// * Any error returned by `HmmDiscrete::log_emission`. Presumably this covers
///   out-of-range symbols; confirm against that method.
pub fn forward_backward(hmm: &HmmDiscrete, obs: &[usize]) -> SeqResult<ForwardBackward> {
    if obs.is_empty() {
        return Err(SeqError::EmptyInput);
    }
    let n = hmm.n_states;
    let steps = obs.len();

    // Emission log-scores, row-major `steps x n`: one row per observed symbol.
    let mut emissions = Vec::with_capacity(steps * n);
    for &symbol in obs {
        for state in 0..n {
            emissions.push(hmm.log_emission(state, symbol)?);
        }
    }

    // Flat index `k = i * n + j` walks the transition matrix in the same row-major order.
    let transitions: Vec<f64> = (0..n * n).map(|k| log_safe(hmm.a[k])).collect();
    let initial: Vec<f64> = hmm.pi.iter().map(|&prob| log_safe(prob)).collect();

    forward_backward_log(&initial, &transitions, &emissions, n, steps)
}
/// Runs the forward–backward algorithm for a Gaussian-emission HMM.
///
/// `x` holds the observations flattened row-major: `x.len() / hmm.dim` consecutive rows of
/// `hmm.dim` values each, row `t` being the observation at time step `t`.
///
/// # Errors
/// * [`SeqError::EmptyInput`] if `x` is empty.
/// * [`SeqError::DimensionMismatch`] (`a = x.len()`, `b = hmm.dim`) if `hmm.dim` is zero
///   or does not divide `x.len()`.
/// * Any error returned by `HmmGaussian::log_emission`.
pub fn forward_backward_gaussian(hmm: &HmmGaussian, x: &[f64]) -> SeqResult<ForwardBackward> {
    if x.is_empty() {
        return Err(SeqError::EmptyInput);
    }
    let dim = hmm.dim;
    // Test `dim == 0` first. `x.len() % 0` (and `chunks_exact(0)` below) would panic, and a
    // zero-dimensional model cannot describe non-empty data anyway.
    if dim == 0 || x.len() % dim != 0 {
        return Err(SeqError::DimensionMismatch { a: x.len(), b: dim });
    }
    let t_max = x.len() / dim;
    let n = hmm.n_states;

    // Emission log-scores, row-major `t_max x n`.
    let mut log_em = Vec::with_capacity(t_max * n);
    for row in x.chunks_exact(dim) {
        for state in 0..n {
            log_em.push(hmm.log_emission(state, row)?);
        }
    }

    // Transition matrix in log space, row-major `n x n` (`hmm.a[i * n + j]`).
    let log_a: Vec<f64> = (0..n * n).map(|k| log_safe(hmm.a[k])).collect();
    let log_pi: Vec<f64> = hmm.pi.iter().map(|&p| log_safe(p)).collect();
    forward_backward_log(&log_pi, &log_a, &log_em, n, t_max)
}
/// Shared log-space forward–backward pass used by both public entry points.
///
/// Inputs (all row-major):
/// * `log_pi`: length `n`, log initial-state weights.
/// * `log_a`: `n * n`, where `log_a[i * n + j]` scores moving from state `i` to state `j`.
/// * `log_em`: `t_max * n`, where `log_em[t * n + j]` scores observation `t` in state `j`.
///
/// Precondition: `t_max >= 1`. Both public callers reject empty input before calling.
///
/// NOTE(review): if every path has zero probability (`log_likelihood == -inf`), γ and ξ come
/// out as NaN, because their exponents evaluate `(-inf) - (-inf)`. The per-slice
/// renormalisation does not catch this (`NaN > 0.0` is false).
fn forward_backward_log(
    log_pi: &[f64],
    log_a: &[f64],
    log_em: &[f64],
    n: usize,
    t_max: usize,
) -> SeqResult<ForwardBackward> {
    let mut log_alpha = vec![f64::NEG_INFINITY; t_max * n];
    let mut log_beta = vec![f64::NEG_INFINITY; t_max * n];
    // One scratch buffer reused for every set of terms passed to `logsumexp`.
    let mut terms = vec![0.0; n];

    // Forward pass.
    // alpha_0(j) = pi_j + e_0(j)
    // alpha_t(j) = lse_i(alpha_{t-1}(i) + a_ij) + e_t(j)
    for (state, slot) in log_alpha[..n].iter_mut().enumerate() {
        *slot = log_pi[state] + log_em[state];
    }
    for t in 1..t_max {
        let prev = (t - 1) * n;
        let cur = t * n;
        for j in 0..n {
            for (i, term) in terms.iter_mut().enumerate() {
                *term = log_alpha[prev + i] + log_a[i * n + j];
            }
            log_alpha[cur + j] = logsumexp(&terms) + log_em[cur + j];
        }
    }

    // Backward pass.
    // beta_{T-1}(i) = 0
    // beta_t(i) = lse_j(a_ij + e_{t+1}(j) + beta_{t+1}(j))
    for slot in &mut log_beta[(t_max - 1) * n..t_max * n] {
        *slot = 0.0;
    }
    for t in (0..t_max - 1).rev() {
        let next = (t + 1) * n;
        for i in 0..n {
            for (j, term) in terms.iter_mut().enumerate() {
                *term = log_a[i * n + j] + log_em[next + j] + log_beta[next + j];
            }
            log_beta[t * n + i] = logsumexp(&terms);
        }
    }

    let log_likelihood = logsumexp(&log_alpha[(t_max - 1) * n..]);

    // State posteriors gamma_t(i) = exp(alpha + beta - ll), renormalised per time step
    // to absorb rounding error.
    let mut gamma: Vec<f64> = log_alpha
        .iter()
        .zip(&log_beta)
        .map(|(&a, &b)| (a + b - log_likelihood).exp())
        .collect();
    for t in 0..t_max {
        let row = &mut gamma[t * n..(t + 1) * n];
        let total: f64 = row.iter().sum();
        if total > 0.0 {
            row.iter_mut().for_each(|g| *g /= total);
        }
    }

    // Pairwise posteriors xi_t(i, j), one n x n slice per transition, renormalised per slice.
    let transitions = t_max.saturating_sub(1);
    let mut xi = vec![0.0; transitions * n * n];
    for t in 0..transitions {
        let block = &mut xi[t * n * n..(t + 1) * n * n];
        let mut total = 0.0;
        for i in 0..n {
            let alpha_ti = log_alpha[t * n + i];
            for j in 0..n {
                let next = (t + 1) * n + j;
                let v = (alpha_ti + log_a[i * n + j] + log_em[next] + log_beta[next]
                    - log_likelihood)
                    .exp();
                block[i * n + j] = v;
                total += v;
            }
        }
        if total > 0.0 {
            block.iter_mut().for_each(|v| *v /= total);
        }
    }

    Ok(ForwardBackward {
        log_alpha,
        log_beta,
        gamma,
        xi,
        log_likelihood,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-state, two-symbol model shared by the tests below.
    fn small_hmm() -> HmmDiscrete {
        HmmDiscrete::new(
            2,
            2,
            vec![0.6, 0.4],
            vec![0.7, 0.3, 0.4, 0.6],
            vec![0.1, 0.9, 0.8, 0.2],
        )
        .expect("ok")
    }

    #[test]
    fn forward_alpha_dimensions() {
        let h = small_hmm();
        let fb = forward_backward(&h, &[0, 1, 0]).expect("ok");
        assert_eq!(fb.log_alpha.len(), 6);
        assert_eq!(fb.gamma.len(), 6);
        assert_eq!(fb.xi.len(), 8);
    }

    #[test]
    fn gamma_rows_sum_to_one() {
        let h = small_hmm();
        let fb = forward_backward(&h, &[0, 1, 0, 1]).expect("ok");
        for t in 0..4 {
            let s: f64 = fb.gamma[t * 2..(t + 1) * 2].iter().sum();
            assert!((s - 1.0).abs() < 1e-9, "γ_{t} sums to {s}");
        }
    }

    #[test]
    fn xi_slices_sum_to_one_and_marginalise_to_gamma() {
        let h = small_hmm();
        let fb = forward_backward(&h, &[0, 1, 0, 1]).expect("ok");
        for t in 0..3 {
            let slice = &fb.xi[t * 4..(t + 1) * 4];
            let s: f64 = slice.iter().sum();
            assert!((s - 1.0).abs() < 1e-9, "ξ_{t} sums to {s}");
            // Summing ξ_t(i, ·) over the destination state must recover γ_t(i).
            for i in 0..2 {
                let row: f64 = slice[i * 2..(i + 1) * 2].iter().sum();
                assert!((row - fb.gamma[t * 2 + i]).abs() < 1e-9);
            }
        }
    }

    #[test]
    fn log_likelihood_matches_brute_force() {
        let h = small_hmm();
        let obs = [0usize, 1, 0];
        let n = h.n_states;
        let t_len = obs.len();
        // Enumerate every hidden-state path, encoded as a base-n number.
        let mut total = 0.0;
        for code in 0..n.pow(t_len as u32) {
            let path: Vec<usize> = (0..t_len).map(|t| (code / n.pow(t as u32)) % n).collect();
            let mut p = h.pi[path[0]] * h.log_emission(path[0], obs[0]).expect("ok").exp();
            for t in 1..t_len {
                p *= h.a[path[t - 1] * n + path[t]]
                    * h.log_emission(path[t], obs[t]).expect("ok").exp();
            }
            total += p;
        }
        let fb = forward_backward(&h, &obs).expect("ok");
        assert!(
            (fb.log_likelihood - total.ln()).abs() < 1e-9,
            "forward ll {} vs brute force {}",
            fb.log_likelihood,
            total.ln()
        );
    }

    #[test]
    fn empty_observations_rejected() {
        let h = small_hmm();
        assert!(matches!(
            forward_backward(&h, &[]),
            Err(SeqError::EmptyInput)
        ));
    }

    #[test]
    fn logsumexp_neg_inf() {
        let xs = vec![f64::NEG_INFINITY, f64::NEG_INFINITY];
        assert!(logsumexp(&xs).is_infinite());
    }

    #[test]
    fn logsumexp_single() {
        let xs = vec![5.0];
        assert!((logsumexp(&xs) - 5.0).abs() < 1e-12);
    }

    #[test]
    fn logsumexp_pair() {
        assert!((logsumexp(&[0.0, 0.0]) - 2f64.ln()).abs() < 1e-12);
    }
}