Skip to main content

oxicuda_seq/hmm/
scaling.rs

1use super::hmm::{HmmDiscrete, log_safe};
2use crate::error::{SeqError, SeqResult};
3
/// Output of [`scaled_forward`]: normalised forward variables together with
/// the per-step scaling coefficients that were applied to them.
///
/// `PartialEq` is derived so results can be compared directly (for example in
/// regression tests). As with any `f64` comparison, a `NaN` entry never
/// compares equal to anything, including itself.
#[derive(Debug, Clone, PartialEq)]
pub struct ScaledForwardResult {
    /// Scaled forward variables α̂_t(j) as a flat, row-major `T × n_states`
    /// buffer: entry `(t, j)` is stored at index `t * n_states + j`.
    pub alpha: Vec<f64>,
    /// `T` scaling coefficients c_t, one per time step.  Each c_t is the
    /// reciprocal of the unnormalised α row sum at step `t`.
    pub scales: Vec<f64>,
    /// Sequence log-likelihood, Σ_t log(1 / c_t) = -Σ_t log(c_t).
    pub log_likelihood: f64,
}
14
/// Output of [`scaled_backward`]: the scaled backward variables.
///
/// `PartialEq` is derived so results can be compared directly (for example in
/// regression tests). As with any `f64` comparison, a `NaN` entry never
/// compares equal to anything, including itself.
#[derive(Debug, Clone, PartialEq)]
pub struct ScaledBackwardResult {
    /// Scaled backward variables β̂_t(i) as a flat, row-major `T × n_states`
    /// buffer: entry `(t, i)` is stored at index `t * n_states + i`.
    pub beta: Vec<f64>,
}
21
/// Output of [`scaled_forward_backward`]: the scaled α/β trellises, the scale
/// factors, and the state/edge posteriors derived from them.
///
/// All matrices are flat, row-major buffers.  `PartialEq` is derived so
/// results can be compared directly (for example in regression tests); as
/// with any `f64` comparison, a `NaN` entry never compares equal.
#[derive(Debug, Clone, PartialEq)]
pub struct ScaledForwardBackwardResult {
    /// `T × n_states` scaled forward variables; `(t, i)` is at `t * n_states + i`.
    pub alpha: Vec<f64>,
    /// `T × n_states` scaled backward variables; `(t, i)` is at `t * n_states + i`.
    pub beta: Vec<f64>,
    /// `T` scaling coefficients c_t produced by the forward pass.
    pub scales: Vec<f64>,
    /// `T × n_states` state posteriors γ_t(i) = P(q_t=i | O, λ);
    /// `(t, i)` is at `t * n_states + i`.
    pub gamma: Vec<f64>,
    /// `(T-1) × n_states × n_states` edge posteriors ξ_t(i,j);
    /// `(t, i, j)` is at `t * n_states² + i * n_states + j`.
    pub xi: Vec<f64>,
    /// Sequence log-likelihood, taken unchanged from the scaled forward pass.
    pub log_likelihood: f64,
}
37
38/// Rabiner (1989) §VI: scaled forward pass for a discrete HMM.
39///
40/// Prevents arithmetic underflow on long sequences by normalising each time
41/// step so that the α row sums to 1.  The per-step scale factors carry the
42/// magnitude information needed to recover the log-likelihood exactly.
43pub fn scaled_forward(hmm: &HmmDiscrete, obs: &[usize]) -> SeqResult<ScaledForwardResult> {
44    if obs.is_empty() {
45        return Err(SeqError::EmptyInput);
46    }
47    let t_max = obs.len();
48    let n = hmm.n_states;
49
50    let mut alpha = vec![0.0f64; t_max * n];
51    let mut scales = vec![0.0f64; t_max];
52
53    // t=0: α̂_0(j) = π_j · b_j(o_0); then normalise.
54    for j in 0..n {
55        let em = hmm.b[j * hmm.n_obs + obs[0]];
56        alpha[j] = hmm.pi[j] * em;
57    }
58    let c0: f64 = alpha[..n].iter().sum();
59    if c0 < f64::MIN_POSITIVE {
60        return Err(SeqError::NumericalInstability(
61            "all initial emissions are zero for obs[0]".to_string(),
62        ));
63    }
64    let c0 = 1.0 / c0;
65    scales[0] = c0;
66    for j in 0..n {
67        alpha[j] *= c0;
68    }
69
70    // t>0: recursive step.
71    let mut tmp_row = vec![0.0f64; n];
72    for t in 1..t_max {
73        // Copy previous alpha row to avoid simultaneous borrow.
74        tmp_row.copy_from_slice(&alpha[(t - 1) * n..t * n]);
75        for j in 0..n {
76            let em = hmm.b[j * hmm.n_obs + obs[t]];
77            let sum: f64 = (0..n).map(|i| tmp_row[i] * hmm.a[i * n + j]).sum();
78            alpha[t * n + j] = sum * em;
79        }
80        let row_sum: f64 = alpha[t * n..t * n + n].iter().sum();
81        if row_sum < f64::MIN_POSITIVE {
82            return Err(SeqError::NumericalInstability(format!(
83                "all scaled forward values vanished at t={t}"
84            )));
85        }
86        let ct = 1.0 / row_sum;
87        scales[t] = ct;
88        for j in 0..n {
89            alpha[t * n + j] *= ct;
90        }
91    }
92
93    // log p(O|λ) = -Σ_t log(c_t) = Σ_t log(1/c_t).
94    let log_likelihood: f64 = scales.iter().map(|&c| -log_safe(c)).sum();
95
96    Ok(ScaledForwardResult {
97        alpha,
98        scales,
99        log_likelihood,
100    })
101}
102
103/// Scaled backward pass using the scales from a prior scaled forward pass.
104///
105/// Reusing the same c_t as forward ensures the combined α·β products remain
106/// in a numerically safe range without a separate normalisation pass.
107pub fn scaled_backward(
108    hmm: &HmmDiscrete,
109    obs: &[usize],
110    scales: &[f64],
111) -> SeqResult<ScaledBackwardResult> {
112    if obs.is_empty() {
113        return Err(SeqError::EmptyInput);
114    }
115    let t_max = obs.len();
116    if scales.len() != t_max {
117        return Err(SeqError::LengthMismatch {
118            a: scales.len(),
119            b: t_max,
120        });
121    }
122    let n = hmm.n_states;
123    let mut beta = vec![0.0f64; t_max * n];
124
125    // β_{T-1}(i) is set to c_{T-1} (scaled by the last scale factor).
126    let last_c = scales[t_max - 1];
127    for i in 0..n {
128        beta[(t_max - 1) * n + i] = last_c;
129    }
130
131    // β_t(i) = c_t · Σ_j a_{ij} · b_j(o_{t+1}) · β_{t+1}(j)
132    let mut tmp_next = vec![0.0f64; n];
133    for t in (0..t_max - 1).rev() {
134        // Copy next beta row to avoid simultaneous borrow.
135        tmp_next.copy_from_slice(&beta[(t + 1) * n..(t + 2) * n]);
136        let ct = scales[t];
137        for i in 0..n {
138            let mut s = 0.0f64;
139            for j in 0..n {
140                let em = hmm.b[j * hmm.n_obs + obs[t + 1]];
141                s += hmm.a[i * n + j] * em * tmp_next[j];
142            }
143            beta[t * n + i] = ct * s;
144        }
145    }
146
147    Ok(ScaledBackwardResult { beta })
148}
149
150/// Full scaled forward-backward, yielding posteriors γ and ξ.
151pub fn scaled_forward_backward(
152    hmm: &HmmDiscrete,
153    obs: &[usize],
154) -> SeqResult<ScaledForwardBackwardResult> {
155    let sf = scaled_forward(hmm, obs)?;
156    let sb = scaled_backward(hmm, obs, &sf.scales)?;
157
158    let t_max = obs.len();
159    let n = hmm.n_states;
160
161    // γ_t(i) = α_t(i) · β_t(i), normalised per time step.
162    let mut gamma = vec![0.0f64; t_max * n];
163    for t in 0..t_max {
164        let mut row_sum = 0.0f64;
165        for i in 0..n {
166            let v = sf.alpha[t * n + i] * sb.beta[t * n + i];
167            gamma[t * n + i] = v;
168            row_sum += v;
169        }
170        if row_sum > 0.0 {
171            for i in 0..n {
172                gamma[t * n + i] /= row_sum;
173            }
174        }
175    }
176
177    // ξ_t(i,j) = α_t(i) · a_{ij} · b_j(o_{t+1}) · β_{t+1}(j), normalised per t.
178    let xi_len = t_max.saturating_sub(1) * n * n;
179    let mut xi = vec![0.0f64; xi_len];
180    for t in 0..t_max.saturating_sub(1) {
181        let mut total = 0.0f64;
182        for i in 0..n {
183            for j in 0..n {
184                let em = hmm.b[j * hmm.n_obs + obs[t + 1]];
185                let v = sf.alpha[t * n + i] * hmm.a[i * n + j] * em * sb.beta[(t + 1) * n + j];
186                xi[t * n * n + i * n + j] = v;
187                total += v;
188            }
189        }
190        if total > 0.0 {
191            for v in xi[t * n * n..(t + 1) * n * n].iter_mut() {
192                *v /= total;
193            }
194        }
195    }
196
197    Ok(ScaledForwardBackwardResult {
198        alpha: sf.alpha,
199        beta: sb.beta,
200        scales: sf.scales,
201        gamma,
202        xi,
203        log_likelihood: sf.log_likelihood,
204    })
205}
206
207/// Compute Baum-Welch parameter updates using scaled forward-backward.
208///
209/// Returns unnormalised `(new_pi, A_numerator, B_numerator)`.
210/// Caller is responsible for row-normalising A and B before creating a new HMM.
211pub fn scaled_baum_welch_step(
212    hmm: &HmmDiscrete,
213    obs: &[usize],
214    sfb: &ScaledForwardBackwardResult,
215) -> SeqResult<(Vec<f64>, Vec<f64>, Vec<f64>)> {
216    let t_max = obs.len();
217    let n = hmm.n_states;
218    let n_obs = hmm.n_obs;
219
220    // new_pi = γ_0
221    let new_pi: Vec<f64> = sfb.gamma[..n].to_vec();
222
223    // A_num[i*n+j] = Σ_{t=0}^{T-2} ξ_t(i,j)
224    let mut a_num = vec![0.0f64; n * n];
225    for t in 0..t_max.saturating_sub(1) {
226        for i in 0..n {
227            for j in 0..n {
228                a_num[i * n + j] += sfb.xi[t * n * n + i * n + j];
229            }
230        }
231    }
232
233    // B_num[j*n_obs+o] = Σ_{t: obs[t]==o} γ_t(j)
234    let mut b_num = vec![0.0f64; n * n_obs];
235    for (t, &o) in obs.iter().enumerate() {
236        if o >= n_obs {
237            return Err(SeqError::IndexOutOfBounds {
238                index: o,
239                len: n_obs,
240            });
241        }
242        for j in 0..n {
243            b_num[j * n_obs + o] += sfb.gamma[t * n + j];
244        }
245    }
246
247    Ok((new_pi, a_num, b_num))
248}
249
250/// Log-space Viterbi decoding, provided as a companion to the scaled algorithms.
251///
252/// The scaling approach only helps the forward-backward pass; Viterbi's
253/// max-product recursion is already numerically safe in log-space.
254pub fn scaled_viterbi(hmm: &HmmDiscrete, obs: &[usize]) -> SeqResult<Vec<usize>> {
255    if obs.is_empty() {
256        return Err(SeqError::EmptyInput);
257    }
258    let t_max = obs.len();
259    let n = hmm.n_states;
260
261    let mut delta = vec![f64::NEG_INFINITY; t_max * n];
262    let mut psi = vec![0usize; t_max * n];
263
264    // δ_0(j) = log π_j + log b_j(o_0)
265    for j in 0..n {
266        delta[j] = log_safe(hmm.pi[j]) + log_safe(hmm.b[j * hmm.n_obs + obs[0]]);
267    }
268
269    for t in 1..t_max {
270        for j in 0..n {
271            let log_em = log_safe(hmm.b[j * hmm.n_obs + obs[t]]);
272            let mut best = f64::NEG_INFINITY;
273            let mut argmax = 0usize;
274            for i in 0..n {
275                let v = delta[(t - 1) * n + i] + log_safe(hmm.a[i * n + j]);
276                if v > best {
277                    best = v;
278                    argmax = i;
279                }
280            }
281            delta[t * n + j] = best + log_em;
282            psi[t * n + j] = argmax;
283        }
284    }
285
286    // Termination: find the state with the highest score at T-1.
287    let mut best = f64::NEG_INFINITY;
288    let mut last = 0usize;
289    for j in 0..n {
290        let v = delta[(t_max - 1) * n + j];
291        if v > best {
292            best = v;
293            last = j;
294        }
295    }
296
297    // Traceback.
298    let mut path = vec![0usize; t_max];
299    path[t_max - 1] = last;
300    for t in (1..t_max).rev() {
301        path[t - 1] = psi[t * n + path[t]];
302    }
303
304    Ok(path)
305}
306
307// ─── inline tests ────────────────────────────────────────────────────────────
308
// Unit tests for the scaled forward/backward routines, the Baum-Welch
// accumulator and the companion Viterbi decoder.  Likelihoods and decoded
// paths are cross-checked against the crate's `forward_backward` and
// `viterbi` implementations.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hmm::forward_backward::forward_backward;
    use crate::hmm::viterbi::viterbi;

    // Two-state, two-symbol fixture with a non-uniform start distribution and
    // asymmetric transition/emission rows.
    fn small_hmm() -> HmmDiscrete {
        HmmDiscrete::new(
            2,
            2,
            vec![0.6, 0.4],
            vec![0.7, 0.3, 0.4, 0.6],
            vec![0.1, 0.9, 0.8, 0.2],
        )
        .expect("small_hmm ok")
    }

    // Symmetric two-state fixture: uniform start, 0.9 self-transition, and
    // each state emits "its own" symbol with probability 0.9.
    fn hmm_2s_2o() -> HmmDiscrete {
        HmmDiscrete::new(
            2,
            2,
            vec![0.5, 0.5],
            vec![0.9, 0.1, 0.1, 0.9],
            vec![0.9, 0.1, 0.1, 0.9],
        )
        .expect("hmm_2s_2o ok")
    }

    // Degenerate fixture: a single state emitting both symbols with equal
    // probability.
    fn single_state_hmm() -> HmmDiscrete {
        HmmDiscrete::new(1, 2, vec![1.0], vec![1.0], vec![0.5, 0.5]).expect("single ok")
    }

    // The scaled log-likelihood must agree with the reference
    // `forward_backward` implementation.
    #[test]
    fn scaled_forward_likelihood_matches_log_space() {
        let h = small_hmm();
        let obs = vec![0usize, 1, 0, 1, 0];
        let sf = scaled_forward(&h, &obs).expect("ok");
        let fb = forward_backward(&h, &obs).expect("ok");
        assert!(
            (sf.log_likelihood - fb.log_likelihood).abs() < 1e-6,
            "scaled ll={} log-space ll={}",
            sf.log_likelihood,
            fb.log_likelihood
        );
    }

    // Every scale factor is a reciprocal of a positive row sum.
    #[test]
    fn scaled_forward_scales_all_positive() {
        let h = small_hmm();
        let sf = scaled_forward(&h, &[0, 1, 0, 1]).expect("ok");
        for (t, &c) in sf.scales.iter().enumerate() {
            assert!(c > 0.0, "c[{t}]={c} not positive");
        }
    }

    // After scaling, each α row is a probability distribution over states.
    #[test]
    fn scaled_forward_alpha_rows_sum_to_one() {
        let h = small_hmm();
        let obs = vec![0, 1, 0, 1];
        let sf = scaled_forward(&h, &obs).expect("ok");
        let n = h.n_states;
        for t in 0..obs.len() {
            let s: f64 = sf.alpha[t * n..(t + 1) * n].iter().sum();
            assert!((s - 1.0).abs() < 1e-12, "t={t} row sum={s}");
        }
    }

    // Scaled β values must not overflow or become NaN on a short sequence.
    #[test]
    fn scaled_backward_beta_finite() {
        let h = small_hmm();
        let obs = vec![0, 1, 0];
        let sf = scaled_forward(&h, &obs).expect("ok");
        let sb = scaled_backward(&h, &obs, &sf.scales).expect("ok");
        for &v in &sb.beta {
            assert!(v.is_finite(), "beta value not finite: {v}");
        }
    }

    // State posteriors are normalised per time step.
    #[test]
    fn scaled_forward_backward_gamma_sum() {
        let h = small_hmm();
        let obs = vec![0, 1, 0, 1];
        let sfb = scaled_forward_backward(&h, &obs).expect("ok");
        let n = h.n_states;
        for t in 0..obs.len() {
            let s: f64 = sfb.gamma[t * n..(t + 1) * n].iter().sum();
            assert!((s - 1.0).abs() < 1e-9, "gamma t={t} sum={s}");
        }
    }

    // Edge posteriors are normalised per transition (T-1 slabs of n×n).
    #[test]
    fn scaled_forward_backward_xi_sum() {
        let h = small_hmm();
        let obs = vec![0, 1, 0, 1];
        let sfb = scaled_forward_backward(&h, &obs).expect("ok");
        let n = h.n_states;
        for t in 0..obs.len() - 1 {
            let s: f64 = sfb.xi[t * n * n..(t + 1) * n * n].iter().sum();
            assert!((s - 1.0).abs() < 1e-9, "xi t={t} sum={s}");
        }
    }

    // Same likelihood cross-check as above, on the symmetric fixture.
    #[test]
    fn scaled_ll_equals_log_space_ll() {
        let h = hmm_2s_2o();
        let obs = vec![0, 0, 1, 1, 0];
        let sf = scaled_forward(&h, &obs).expect("ok");
        let fb = forward_backward(&h, &obs).expect("ok");
        assert!(
            (sf.log_likelihood - fb.log_likelihood).abs() < 1e-6,
            "scaled={} log-space={}",
            sf.log_likelihood,
            fb.log_likelihood
        );
    }

    // An empty observation sequence is rejected rather than producing an
    // empty result.
    #[test]
    fn scaled_forward_empty_obs_err() {
        let h = small_hmm();
        let res = scaled_forward(&h, &[]);
        assert!(matches!(res, Err(SeqError::EmptyInput)));
    }

    // The decoder in this module must return the same path as the crate's
    // reference Viterbi.
    #[test]
    fn scaled_viterbi_consistent_with_standard_viterbi() {
        let h = hmm_2s_2o();
        let obs = vec![0, 0, 1, 1];
        let sv = scaled_viterbi(&h, &obs).expect("ok");
        let lv = viterbi(&h, &obs).expect("ok");
        assert_eq!(
            sv, lv.path,
            "scaled_viterbi path diverges from log-space viterbi"
        );
    }

    // Boundary case T=1: only the initialisation step runs.
    #[test]
    fn scaled_forward_single_obs() {
        let h = small_hmm();
        let sf = scaled_forward(&h, &[0]).expect("ok");
        assert_eq!(sf.alpha.len(), h.n_states);
        assert_eq!(sf.scales.len(), 1);
        let s: f64 = sf.alpha.iter().sum();
        assert!((s - 1.0).abs() < 1e-12);
    }

    // A length-1000 alternating sequence must still yield a finite, negative
    // log-likelihood.
    #[test]
    fn scaled_forward_long_sequence_no_underflow() {
        let h = hmm_2s_2o();
        let obs: Vec<usize> = (0..1000).map(|i| i % 2).collect();
        let sf = scaled_forward(&h, &obs);
        assert!(sf.is_ok(), "scaled_forward failed on length-1000 sequence");
        let sf = sf.expect("ok");
        assert!(sf.log_likelihood.is_finite());
        assert!(sf.log_likelihood < 0.0, "log-likelihood must be negative");
    }

    // `scales` must have exactly one entry per observation.
    #[test]
    fn scaled_backward_wrong_scales_len_err() {
        let h = small_hmm();
        let obs = vec![0, 1, 0];
        let bad_scales = vec![1.0, 1.0];
        let res = scaled_backward(&h, &obs, &bad_scales);
        assert!(
            matches!(res, Err(SeqError::LengthMismatch { .. })),
            "expected LengthMismatch"
        );
    }

    // The re-estimated start distribution is γ_0 and therefore sums to 1.
    #[test]
    fn scaled_baum_welch_step_pi_sums_to_1() {
        let h = small_hmm();
        let obs = vec![0, 1, 0, 1];
        let sfb = scaled_forward_backward(&h, &obs).expect("ok");
        let (new_pi, _, _) = scaled_baum_welch_step(&h, &obs, &sfb).expect("ok");
        let s: f64 = new_pi.iter().sum();
        assert!((s - 1.0).abs() < 1e-9, "new_pi sum={s}");
    }

    // Output buffers have the shapes of π, A and B respectively.
    #[test]
    fn scaled_baum_welch_step_shapes_correct() {
        let h = small_hmm();
        let obs = vec![0, 1, 0, 1];
        let sfb = scaled_forward_backward(&h, &obs).expect("ok");
        let (pi, a_num, b_num) = scaled_baum_welch_step(&h, &obs, &sfb).expect("ok");
        assert_eq!(pi.len(), h.n_states);
        assert_eq!(a_num.len(), h.n_states * h.n_states);
        assert_eq!(b_num.len(), h.n_states * h.n_obs);
    }

    // Sanity check of the posteriors on a model whose hidden path is
    // essentially determined by the observations.
    #[test]
    fn scaled_forward_backward_2state_2obs() {
        // Simple known HMM: deterministic transitions, near-deterministic emissions.
        // State 0 emits obs 0 with probability ~1, state 1 emits obs 1 with probability ~1.
        let h = HmmDiscrete::new(
            2,
            2,
            vec![1.0, 0.0],
            vec![0.0, 1.0, 1.0, 0.0],
            vec![0.99, 0.01, 0.01, 0.99],
        )
        .expect("ok");
        let obs = vec![0, 1, 0, 1];
        let sfb = scaled_forward_backward(&h, &obs).expect("ok");
        // At t=0 state 0 is strongly preferred (γ_0(0) ≈ 1).
        assert!(sfb.gamma[0] > 0.9, "gamma[0][0]={}", sfb.gamma[0]);
        // At t=1 state 1 is strongly preferred.
        let n = h.n_states;
        assert!(sfb.gamma[n + 1] > 0.9, "gamma[1][1]={}", sfb.gamma[n + 1]);
    }

    // With a single state every normalised α row is exactly [1.0].
    #[test]
    fn scaled_forward_single_state() {
        let h = single_state_hmm();
        let obs = vec![0, 1, 0];
        let sf = scaled_forward(&h, &obs).expect("ok");
        assert_eq!(sf.scales.len(), 3);
        assert_eq!(sf.alpha.len(), 3);
        for &a in &sf.alpha {
            assert!(
                (a - 1.0).abs() < 1e-12,
                "single-state alpha must be 1.0, got {a}"
            );
        }
    }

    // With a single state the only possible decoded path is all zeros.
    #[test]
    fn scaled_viterbi_single_state() {
        let h = single_state_hmm();
        let obs = vec![0, 1, 0, 1];
        let path = scaled_viterbi(&h, &obs).expect("ok");
        assert_eq!(
            path,
            vec![0, 0, 0, 0],
            "single-state path must be all zeros"
        );
    }
}