Skip to main content

oxicuda_seq/hmm/
forward_backward.rs

1//! Log-space forward-backward for discrete HMMs and Gaussian HMMs.
2
3use super::hmm::{HmmDiscrete, HmmGaussian, log_safe};
4use crate::error::{SeqError, SeqResult};
5
/// Everything produced by one forward-backward pass over a sequence.
///
/// Matrices are stored flat; `T` is the sequence length and `n_states`
/// the number of hidden states.
#[derive(Clone, Debug)]
pub struct ForwardBackward {
    /// Forward log-probabilities log α, shaped T × n_states.
    pub log_alpha: Vec<f64>,
    /// Backward log-probabilities log β, shaped T × n_states.
    pub log_beta: Vec<f64>,
    /// State posteriors γ, shaped T × n_states.
    pub gamma: Vec<f64>,
    /// Edge (transition) posteriors ξ, shaped (T-1) × n_states × n_states.
    pub xi: Vec<f64>,
    /// Log-likelihood of the whole sequence, log p(o₁…o_T | model).
    pub log_likelihood: f64,
}
20
/// Computes `ln(Σᵢ exp(xs[i]))` in a numerically stable way.
///
/// The largest element is subtracted before exponentiating so that large
/// magnitudes neither overflow nor underflow to an all-zero sum.
///
/// Special cases:
/// * an empty slice, or one whose largest element is `-∞`, yields `-∞`;
/// * a slice containing `+∞` yields `+∞` (shifting by an infinite maximum
///   would otherwise evaluate `∞ − ∞` and return NaN);
/// * a NaN element yields NaN when the largest non-NaN element is finite.
#[inline]
pub fn logsumexp(xs: &[f64]) -> f64 {
    // `f64::max` ignores NaN operands, so `m` is the largest non-NaN value
    // (or `-∞` when there is none).
    let m = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    // An infinite maximum already is the answer, and `x - m` would be NaN
    // for any element equal to it.
    if m.is_infinite() {
        return m;
    }
    let s: f64 = xs.iter().map(|&x| (x - m).exp()).sum();
    m + s.ln()
}
31
32/// Run forward-backward on a discrete HMM.
33pub fn forward_backward(hmm: &HmmDiscrete, obs: &[usize]) -> SeqResult<ForwardBackward> {
34    if obs.is_empty() {
35        return Err(SeqError::EmptyInput);
36    }
37    let t_max = obs.len();
38    let n = hmm.n_states;
39
40    // Pre-compute log-emissions per (t, j).
41    let mut log_em = vec![f64::NEG_INFINITY; t_max * n];
42    for t in 0..t_max {
43        for j in 0..n {
44            log_em[t * n + j] = hmm.log_emission(j, obs[t])?;
45        }
46    }
47
48    // Pre-compute log-transitions and log-init.
49    let mut log_a = vec![f64::NEG_INFINITY; n * n];
50    for i in 0..n {
51        for j in 0..n {
52            log_a[i * n + j] = log_safe(hmm.a[i * n + j]);
53        }
54    }
55    let log_pi: Vec<f64> = hmm.pi.iter().map(|&p| log_safe(p)).collect();
56
57    forward_backward_log(&log_pi, &log_a, &log_em, n, t_max)
58}
59
60/// Run forward-backward on a Gaussian HMM with observation sequence `x`
61/// (T × dim row-major).
62pub fn forward_backward_gaussian(hmm: &HmmGaussian, x: &[f64]) -> SeqResult<ForwardBackward> {
63    if x.is_empty() {
64        return Err(SeqError::EmptyInput);
65    }
66    if x.len() % hmm.dim != 0 {
67        return Err(SeqError::DimensionMismatch {
68            a: x.len(),
69            b: hmm.dim,
70        });
71    }
72    let t_max = x.len() / hmm.dim;
73    let n = hmm.n_states;
74    let mut log_em = vec![f64::NEG_INFINITY; t_max * n];
75    for t in 0..t_max {
76        let row = &x[t * hmm.dim..(t + 1) * hmm.dim];
77        for j in 0..n {
78            log_em[t * n + j] = hmm.log_emission(j, row)?;
79        }
80    }
81    let mut log_a = vec![f64::NEG_INFINITY; n * n];
82    for i in 0..n {
83        for j in 0..n {
84            log_a[i * n + j] = log_safe(hmm.a[i * n + j]);
85        }
86    }
87    let log_pi: Vec<f64> = hmm.pi.iter().map(|&p| log_safe(p)).collect();
88    forward_backward_log(&log_pi, &log_a, &log_em, n, t_max)
89}
90
91/// Generic log-space forward-backward given pre-computed log arrays.
92fn forward_backward_log(
93    log_pi: &[f64],
94    log_a: &[f64],
95    log_em: &[f64],
96    n: usize,
97    t_max: usize,
98) -> SeqResult<ForwardBackward> {
99    let mut log_alpha = vec![f64::NEG_INFINITY; t_max * n];
100    let mut log_beta = vec![f64::NEG_INFINITY; t_max * n];
101
102    // α₀(j) = log π_j + log B_j(o₀)
103    for j in 0..n {
104        log_alpha[j] = log_pi[j] + log_em[j];
105    }
106
107    // α_t(j) = logsumexp_i(α_{t-1}(i) + log A_{ij}) + log B_j(o_t)
108    let mut tmp = vec![0.0; n];
109    for t in 1..t_max {
110        for j in 0..n {
111            for i in 0..n {
112                tmp[i] = log_alpha[(t - 1) * n + i] + log_a[i * n + j];
113            }
114            log_alpha[t * n + j] = logsumexp(&tmp) + log_em[t * n + j];
115        }
116    }
117
118    // β_{T-1}(i) = 0
119    for i in 0..n {
120        log_beta[(t_max - 1) * n + i] = 0.0;
121    }
122    // β_t(i) = logsumexp_j(log A_{ij} + log B_j(o_{t+1}) + β_{t+1}(j))
123    for t in (0..t_max - 1).rev() {
124        for i in 0..n {
125            for j in 0..n {
126                tmp[j] = log_a[i * n + j] + log_em[(t + 1) * n + j] + log_beta[(t + 1) * n + j];
127            }
128            log_beta[t * n + i] = logsumexp(&tmp);
129        }
130    }
131
132    // log_likelihood = logsumexp_j α_{T-1}(j)
133    let last_alpha = &log_alpha[(t_max - 1) * n..t_max * n];
134    let ll = logsumexp(last_alpha);
135
136    // γ_t(i) = α_t(i) + β_t(i) − ll
137    let mut gamma = vec![0.0; t_max * n];
138    for t in 0..t_max {
139        for i in 0..n {
140            gamma[t * n + i] = (log_alpha[t * n + i] + log_beta[t * n + i] - ll).exp();
141        }
142        // Re-normalise to keep numerical purity
143        let s: f64 = gamma[t * n..t * n + n].iter().sum();
144        if s > 0.0 {
145            for i in 0..n {
146                gamma[t * n + i] /= s;
147            }
148        }
149    }
150
151    // ξ_t(i,j) = α_t(i) + log A_{ij} + log B_j(o_{t+1}) + β_{t+1}(j) − ll
152    let mut xi = vec![0.0; (t_max.saturating_sub(1)) * n * n];
153    for t in 0..t_max.saturating_sub(1) {
154        let mut s = 0.0;
155        for i in 0..n {
156            for j in 0..n {
157                let v = (log_alpha[t * n + i]
158                    + log_a[i * n + j]
159                    + log_em[(t + 1) * n + j]
160                    + log_beta[(t + 1) * n + j]
161                    - ll)
162                    .exp();
163                xi[t * n * n + i * n + j] = v;
164                s += v;
165            }
166        }
167        if s > 0.0 {
168            for v in xi[t * n * n..(t + 1) * n * n].iter_mut() {
169                *v /= s;
170            }
171        }
172    }
173
174    Ok(ForwardBackward {
175        log_alpha,
176        log_beta,
177        gamma,
178        xi,
179        log_likelihood: ll,
180    })
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-state, two-symbol model shared by the tests below.
    fn small_hmm() -> HmmDiscrete {
        let pi = vec![0.6, 0.4];
        let a = vec![0.7, 0.3, 0.4, 0.6];
        let b = vec![0.1, 0.9, 0.8, 0.2];
        HmmDiscrete::new(2, 2, pi, a, b).expect("ok")
    }

    /// Three observations and two states give 3×2 α/γ tables and a 2×2×2 ξ table.
    #[test]
    fn forward_alpha_dimensions() {
        let fb = forward_backward(&small_hmm(), &[0, 1, 0]).expect("ok");
        assert_eq!(fb.log_alpha.len(), 6);
        assert_eq!(fb.gamma.len(), 6);
        assert_eq!(fb.xi.len(), 8);
    }

    /// Every row of γ is a probability distribution over the two states.
    #[test]
    fn gamma_rows_sum_to_one() {
        let model = small_hmm();
        let fb = forward_backward(&model, &[0, 1, 0, 1]).expect("ok");
        for t in 0..4 {
            let row = &fb.gamma[t * 2..(t + 1) * 2];
            let s = row.iter().sum::<f64>();
            assert!((s - 1.0).abs() < 1e-9, "γ_{t} sums to {s}");
        }
    }

    /// An all-(-∞) input stays infinite instead of turning into NaN.
    #[test]
    fn logsumexp_neg_inf() {
        let result = logsumexp(&[f64::NEG_INFINITY, f64::NEG_INFINITY]);
        assert!(result.is_infinite());
    }

    /// A single element is returned unchanged (up to rounding).
    #[test]
    fn logsumexp_single() {
        let result = logsumexp(&[5.0]);
        assert!((result - 5.0).abs() < 1e-12);
    }
}
228}