1use super::hmm::{HmmDiscrete, log_safe};
2use crate::error::{SeqError, SeqResult};
3
/// Output of [`scaled_forward`].
#[derive(Debug, Clone)]
pub struct ScaledForwardResult {
    /// Scaled forward variables, row-major with shape `(obs.len(), n_states)`.
    /// Row `t` (`alpha[t * n..(t + 1) * n]`) is rescaled so that it sums to one.
    pub alpha: Vec<f64>,
    /// Per-step scaling factors: `scales[t]` is the reciprocal of the unscaled
    /// forward row sum at step `t` (the multiplier applied to that row).
    pub scales: Vec<f64>,
    /// Log-likelihood of the observation sequence, computed as
    /// `-Σ_t log_safe(scales[t])`.
    pub log_likelihood: f64,
}
14
/// Output of [`scaled_backward`].
#[derive(Debug, Clone)]
pub struct ScaledBackwardResult {
    /// Scaled backward variables, row-major with shape `(obs.len(), n_states)`.
    /// The last row equals `scales[T - 1]`; earlier rows are
    /// `scales[t] * Σ_j a[i][j] * b[j][obs[t + 1]] * beta[t + 1][j]`, using the
    /// scaling factors produced by the forward pass.
    pub beta: Vec<f64>,
}
21
/// Output of [`scaled_forward_backward`]: the scaled passes plus posterior statistics.
#[derive(Debug, Clone)]
pub struct ScaledForwardBackwardResult {
    /// Scaled forward variables (row-major `T × n_states`, each row sums to one).
    pub alpha: Vec<f64>,
    /// Scaled backward variables (row-major `T × n_states`), built with `scales`.
    pub beta: Vec<f64>,
    /// Forward-pass scaling factors, one per time step.
    pub scales: Vec<f64>,
    /// Per-step state posteriors: `alpha * beta` renormalised so each row of the
    /// row-major `T × n_states` layout sums to one (rows with a non-positive
    /// total are left as-is).
    pub gamma: Vec<f64>,
    /// Per-step transition posteriors, row-major `(T - 1) × n_states × n_states`;
    /// each `n_states × n_states` slab is renormalised to sum to one (slabs with a
    /// non-positive total are left as-is). Empty when `T == 1`.
    pub xi: Vec<f64>,
    /// Log-likelihood of the observations, taken from the forward pass.
    pub log_likelihood: f64,
}
37
38pub fn scaled_forward(hmm: &HmmDiscrete, obs: &[usize]) -> SeqResult<ScaledForwardResult> {
44 if obs.is_empty() {
45 return Err(SeqError::EmptyInput);
46 }
47 let t_max = obs.len();
48 let n = hmm.n_states;
49
50 let mut alpha = vec![0.0f64; t_max * n];
51 let mut scales = vec![0.0f64; t_max];
52
53 for j in 0..n {
55 let em = hmm.b[j * hmm.n_obs + obs[0]];
56 alpha[j] = hmm.pi[j] * em;
57 }
58 let c0: f64 = alpha[..n].iter().sum();
59 if c0 < f64::MIN_POSITIVE {
60 return Err(SeqError::NumericalInstability(
61 "all initial emissions are zero for obs[0]".to_string(),
62 ));
63 }
64 let c0 = 1.0 / c0;
65 scales[0] = c0;
66 for j in 0..n {
67 alpha[j] *= c0;
68 }
69
70 let mut tmp_row = vec![0.0f64; n];
72 for t in 1..t_max {
73 tmp_row.copy_from_slice(&alpha[(t - 1) * n..t * n]);
75 for j in 0..n {
76 let em = hmm.b[j * hmm.n_obs + obs[t]];
77 let sum: f64 = (0..n).map(|i| tmp_row[i] * hmm.a[i * n + j]).sum();
78 alpha[t * n + j] = sum * em;
79 }
80 let row_sum: f64 = alpha[t * n..t * n + n].iter().sum();
81 if row_sum < f64::MIN_POSITIVE {
82 return Err(SeqError::NumericalInstability(format!(
83 "all scaled forward values vanished at t={t}"
84 )));
85 }
86 let ct = 1.0 / row_sum;
87 scales[t] = ct;
88 for j in 0..n {
89 alpha[t * n + j] *= ct;
90 }
91 }
92
93 let log_likelihood: f64 = scales.iter().map(|&c| -log_safe(c)).sum();
95
96 Ok(ScaledForwardResult {
97 alpha,
98 scales,
99 log_likelihood,
100 })
101}
102
103pub fn scaled_backward(
108 hmm: &HmmDiscrete,
109 obs: &[usize],
110 scales: &[f64],
111) -> SeqResult<ScaledBackwardResult> {
112 if obs.is_empty() {
113 return Err(SeqError::EmptyInput);
114 }
115 let t_max = obs.len();
116 if scales.len() != t_max {
117 return Err(SeqError::LengthMismatch {
118 a: scales.len(),
119 b: t_max,
120 });
121 }
122 let n = hmm.n_states;
123 let mut beta = vec![0.0f64; t_max * n];
124
125 let last_c = scales[t_max - 1];
127 for i in 0..n {
128 beta[(t_max - 1) * n + i] = last_c;
129 }
130
131 let mut tmp_next = vec![0.0f64; n];
133 for t in (0..t_max - 1).rev() {
134 tmp_next.copy_from_slice(&beta[(t + 1) * n..(t + 2) * n]);
136 let ct = scales[t];
137 for i in 0..n {
138 let mut s = 0.0f64;
139 for j in 0..n {
140 let em = hmm.b[j * hmm.n_obs + obs[t + 1]];
141 s += hmm.a[i * n + j] * em * tmp_next[j];
142 }
143 beta[t * n + i] = ct * s;
144 }
145 }
146
147 Ok(ScaledBackwardResult { beta })
148}
149
150pub fn scaled_forward_backward(
152 hmm: &HmmDiscrete,
153 obs: &[usize],
154) -> SeqResult<ScaledForwardBackwardResult> {
155 let sf = scaled_forward(hmm, obs)?;
156 let sb = scaled_backward(hmm, obs, &sf.scales)?;
157
158 let t_max = obs.len();
159 let n = hmm.n_states;
160
161 let mut gamma = vec![0.0f64; t_max * n];
163 for t in 0..t_max {
164 let mut row_sum = 0.0f64;
165 for i in 0..n {
166 let v = sf.alpha[t * n + i] * sb.beta[t * n + i];
167 gamma[t * n + i] = v;
168 row_sum += v;
169 }
170 if row_sum > 0.0 {
171 for i in 0..n {
172 gamma[t * n + i] /= row_sum;
173 }
174 }
175 }
176
177 let xi_len = t_max.saturating_sub(1) * n * n;
179 let mut xi = vec![0.0f64; xi_len];
180 for t in 0..t_max.saturating_sub(1) {
181 let mut total = 0.0f64;
182 for i in 0..n {
183 for j in 0..n {
184 let em = hmm.b[j * hmm.n_obs + obs[t + 1]];
185 let v = sf.alpha[t * n + i] * hmm.a[i * n + j] * em * sb.beta[(t + 1) * n + j];
186 xi[t * n * n + i * n + j] = v;
187 total += v;
188 }
189 }
190 if total > 0.0 {
191 for v in xi[t * n * n..(t + 1) * n * n].iter_mut() {
192 *v /= total;
193 }
194 }
195 }
196
197 Ok(ScaledForwardBackwardResult {
198 alpha: sf.alpha,
199 beta: sb.beta,
200 scales: sf.scales,
201 gamma,
202 xi,
203 log_likelihood: sf.log_likelihood,
204 })
205}
206
207pub fn scaled_baum_welch_step(
212 hmm: &HmmDiscrete,
213 obs: &[usize],
214 sfb: &ScaledForwardBackwardResult,
215) -> SeqResult<(Vec<f64>, Vec<f64>, Vec<f64>)> {
216 let t_max = obs.len();
217 let n = hmm.n_states;
218 let n_obs = hmm.n_obs;
219
220 let new_pi: Vec<f64> = sfb.gamma[..n].to_vec();
222
223 let mut a_num = vec![0.0f64; n * n];
225 for t in 0..t_max.saturating_sub(1) {
226 for i in 0..n {
227 for j in 0..n {
228 a_num[i * n + j] += sfb.xi[t * n * n + i * n + j];
229 }
230 }
231 }
232
233 let mut b_num = vec![0.0f64; n * n_obs];
235 for (t, &o) in obs.iter().enumerate() {
236 if o >= n_obs {
237 return Err(SeqError::IndexOutOfBounds {
238 index: o,
239 len: n_obs,
240 });
241 }
242 for j in 0..n {
243 b_num[j * n_obs + o] += sfb.gamma[t * n + j];
244 }
245 }
246
247 Ok((new_pi, a_num, b_num))
248}
249
250pub fn scaled_viterbi(hmm: &HmmDiscrete, obs: &[usize]) -> SeqResult<Vec<usize>> {
255 if obs.is_empty() {
256 return Err(SeqError::EmptyInput);
257 }
258 let t_max = obs.len();
259 let n = hmm.n_states;
260
261 let mut delta = vec![f64::NEG_INFINITY; t_max * n];
262 let mut psi = vec![0usize; t_max * n];
263
264 for j in 0..n {
266 delta[j] = log_safe(hmm.pi[j]) + log_safe(hmm.b[j * hmm.n_obs + obs[0]]);
267 }
268
269 for t in 1..t_max {
270 for j in 0..n {
271 let log_em = log_safe(hmm.b[j * hmm.n_obs + obs[t]]);
272 let mut best = f64::NEG_INFINITY;
273 let mut argmax = 0usize;
274 for i in 0..n {
275 let v = delta[(t - 1) * n + i] + log_safe(hmm.a[i * n + j]);
276 if v > best {
277 best = v;
278 argmax = i;
279 }
280 }
281 delta[t * n + j] = best + log_em;
282 psi[t * n + j] = argmax;
283 }
284 }
285
286 let mut best = f64::NEG_INFINITY;
288 let mut last = 0usize;
289 for j in 0..n {
290 let v = delta[(t_max - 1) * n + j];
291 if v > best {
292 best = v;
293 last = j;
294 }
295 }
296
297 let mut path = vec![0usize; t_max];
299 path[t_max - 1] = last;
300 for t in (1..t_max).rev() {
301 path[t - 1] = psi[t * n + path[t]];
302 }
303
304 Ok(path)
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hmm::forward_backward::forward_backward;
    use crate::hmm::viterbi::viterbi;

    /// Asymmetric two-state, two-symbol model.
    fn small_hmm() -> HmmDiscrete {
        HmmDiscrete::new(
            2,
            2,
            vec![0.6, 0.4],
            vec![0.7, 0.3, 0.4, 0.6],
            vec![0.1, 0.9, 0.8, 0.2],
        )
        .expect("small_hmm ok")
    }

    /// Symmetric two-state model: sticky transitions, state-specific emissions.
    fn hmm_2s_2o() -> HmmDiscrete {
        HmmDiscrete::new(
            2,
            2,
            vec![0.5, 0.5],
            vec![0.9, 0.1, 0.1, 0.9],
            vec![0.9, 0.1, 0.1, 0.9],
        )
        .expect("hmm_2s_2o ok")
    }

    /// One state that emits either symbol with probability 0.5.
    fn single_state_hmm() -> HmmDiscrete {
        HmmDiscrete::new(1, 2, vec![1.0], vec![1.0], vec![0.5, 0.5]).expect("single ok")
    }

    #[test]
    fn scaled_forward_likelihood_matches_log_space() {
        let h = small_hmm();
        let obs = vec![0usize, 1, 0, 1, 0];
        let sf = scaled_forward(&h, &obs).expect("ok");
        let fb = forward_backward(&h, &obs).expect("ok");
        assert!(
            (sf.log_likelihood - fb.log_likelihood).abs() < 1e-6,
            "scaled ll={} log-space ll={}",
            sf.log_likelihood,
            fb.log_likelihood
        );
    }

    #[test]
    fn scaled_forward_scales_all_positive() {
        let h = small_hmm();
        let sf = scaled_forward(&h, &[0, 1, 0, 1]).expect("ok");
        for (t, &c) in sf.scales.iter().enumerate() {
            assert!(c > 0.0, "c[{t}]={c} not positive");
        }
    }

    #[test]
    fn scaled_forward_alpha_rows_sum_to_one() {
        let h = small_hmm();
        let obs = vec![0, 1, 0, 1];
        let sf = scaled_forward(&h, &obs).expect("ok");
        let n = h.n_states;
        for t in 0..obs.len() {
            let s: f64 = sf.alpha[t * n..(t + 1) * n].iter().sum();
            assert!((s - 1.0).abs() < 1e-12, "t={t} row sum={s}");
        }
    }

    #[test]
    fn scaled_backward_beta_finite() {
        let h = small_hmm();
        let obs = vec![0, 1, 0];
        let sf = scaled_forward(&h, &obs).expect("ok");
        let sb = scaled_backward(&h, &obs, &sf.scales).expect("ok");
        for &v in &sb.beta {
            assert!(v.is_finite(), "beta value not finite: {v}");
        }
    }

    #[test]
    fn scaled_backward_last_row_is_last_scale() {
        // The recursion is seeded with beta_{T-1}(i) = c_{T-1}.
        let h = small_hmm();
        let obs = vec![1, 0, 1];
        let sf = scaled_forward(&h, &obs).expect("ok");
        let sb = scaled_backward(&h, &obs, &sf.scales).expect("ok");
        let n = h.n_states;
        let last = obs.len() - 1;
        for i in 0..n {
            let b = sb.beta[last * n + i];
            assert!(
                (b - sf.scales[last]).abs() < 1e-15,
                "beta[T-1][{i}]={b} scale={}",
                sf.scales[last]
            );
        }
    }

    #[test]
    fn scaled_backward_empty_obs_err() {
        let h = small_hmm();
        let res = scaled_backward(&h, &[], &[]);
        assert!(matches!(res, Err(SeqError::EmptyInput)));
    }

    #[test]
    fn scaled_forward_backward_gamma_sum() {
        let h = small_hmm();
        let obs = vec![0, 1, 0, 1];
        let sfb = scaled_forward_backward(&h, &obs).expect("ok");
        let n = h.n_states;
        for t in 0..obs.len() {
            let s: f64 = sfb.gamma[t * n..(t + 1) * n].iter().sum();
            assert!((s - 1.0).abs() < 1e-9, "gamma t={t} sum={s}");
        }
    }

    #[test]
    fn scaled_forward_backward_xi_sum() {
        let h = small_hmm();
        let obs = vec![0, 1, 0, 1];
        let sfb = scaled_forward_backward(&h, &obs).expect("ok");
        let n = h.n_states;
        for t in 0..obs.len() - 1 {
            let s: f64 = sfb.xi[t * n * n..(t + 1) * n * n].iter().sum();
            assert!((s - 1.0).abs() < 1e-9, "xi t={t} sum={s}");
        }
    }

    #[test]
    fn scaled_forward_backward_xi_marginalises_to_gamma() {
        // Σ_j xi[t][i][j] must equal gamma[t][i] for every t < T - 1.
        let h = small_hmm();
        let obs = vec![0, 1, 1, 0, 1];
        let sfb = scaled_forward_backward(&h, &obs).expect("ok");
        let n = h.n_states;
        for t in 0..obs.len() - 1 {
            for i in 0..n {
                let start = t * n * n + i * n;
                let marginal: f64 = sfb.xi[start..start + n].iter().sum();
                let g = sfb.gamma[t * n + i];
                assert!(
                    (marginal - g).abs() < 1e-9,
                    "t={t} i={i}: sum_j xi={marginal} gamma={g}"
                );
            }
        }
    }

    #[test]
    fn scaled_ll_equals_log_space_ll() {
        let h = hmm_2s_2o();
        let obs = vec![0, 0, 1, 1, 0];
        let sf = scaled_forward(&h, &obs).expect("ok");
        let fb = forward_backward(&h, &obs).expect("ok");
        assert!(
            (sf.log_likelihood - fb.log_likelihood).abs() < 1e-6,
            "scaled={} log-space={}",
            sf.log_likelihood,
            fb.log_likelihood
        );
    }

    #[test]
    fn scaled_forward_empty_obs_err() {
        let h = small_hmm();
        let res = scaled_forward(&h, &[]);
        assert!(matches!(res, Err(SeqError::EmptyInput)));
    }

    #[test]
    fn scaled_viterbi_consistent_with_standard_viterbi() {
        let h = hmm_2s_2o();
        let obs = vec![0, 0, 1, 1];
        let sv = scaled_viterbi(&h, &obs).expect("ok");
        let lv = viterbi(&h, &obs).expect("ok");
        assert_eq!(
            sv, lv.path,
            "scaled_viterbi path diverges from log-space viterbi"
        );
    }

    #[test]
    fn scaled_viterbi_empty_obs_err() {
        let h = small_hmm();
        assert!(matches!(scaled_viterbi(&h, &[]), Err(SeqError::EmptyInput)));
    }

    #[test]
    fn scaled_forward_single_obs() {
        let h = small_hmm();
        let sf = scaled_forward(&h, &[0]).expect("ok");
        assert_eq!(sf.alpha.len(), h.n_states);
        assert_eq!(sf.scales.len(), 1);
        let s: f64 = sf.alpha.iter().sum();
        assert!((s - 1.0).abs() < 1e-12);
    }

    #[test]
    fn scaled_forward_long_sequence_no_underflow() {
        let h = hmm_2s_2o();
        let obs: Vec<usize> = (0..1000).map(|i| i % 2).collect();
        let sf = scaled_forward(&h, &obs);
        assert!(sf.is_ok(), "scaled_forward failed on length-1000 sequence");
        let sf = sf.expect("ok");
        assert!(sf.log_likelihood.is_finite());
        assert!(sf.log_likelihood < 0.0, "log-likelihood must be negative");
    }

    #[test]
    fn scaled_backward_wrong_scales_len_err() {
        let h = small_hmm();
        let obs = vec![0, 1, 0];
        let bad_scales = vec![1.0, 1.0];
        let res = scaled_backward(&h, &obs, &bad_scales);
        assert!(
            matches!(res, Err(SeqError::LengthMismatch { .. })),
            "expected LengthMismatch"
        );
    }

    #[test]
    fn scaled_baum_welch_step_pi_sums_to_1() {
        let h = small_hmm();
        let obs = vec![0, 1, 0, 1];
        let sfb = scaled_forward_backward(&h, &obs).expect("ok");
        let (new_pi, _, _) = scaled_baum_welch_step(&h, &obs, &sfb).expect("ok");
        let s: f64 = new_pi.iter().sum();
        assert!((s - 1.0).abs() < 1e-9, "new_pi sum={s}");
    }

    #[test]
    fn scaled_baum_welch_step_shapes_correct() {
        let h = small_hmm();
        let obs = vec![0, 1, 0, 1];
        let sfb = scaled_forward_backward(&h, &obs).expect("ok");
        let (pi, a_num, b_num) = scaled_baum_welch_step(&h, &obs, &sfb).expect("ok");
        assert_eq!(pi.len(), h.n_states);
        assert_eq!(a_num.len(), h.n_states * h.n_states);
        assert_eq!(b_num.len(), h.n_states * h.n_obs);
    }

    #[test]
    fn scaled_baum_welch_step_numerator_totals() {
        // Every gamma row and every xi slab sums to one, so the unnormalised
        // numerators must total T (emissions) and T - 1 (transitions).
        let h = small_hmm();
        let obs = vec![0, 1, 0, 1];
        let sfb = scaled_forward_backward(&h, &obs).expect("ok");
        let (_, a_num, b_num) = scaled_baum_welch_step(&h, &obs, &sfb).expect("ok");
        let a_total: f64 = a_num.iter().sum();
        let b_total: f64 = b_num.iter().sum();
        assert!(
            (a_total - (obs.len() - 1) as f64).abs() < 1e-9,
            "a_num total={a_total}"
        );
        assert!(
            (b_total - obs.len() as f64).abs() < 1e-9,
            "b_num total={b_total}"
        );
    }

    #[test]
    fn scaled_forward_backward_2state_2obs() {
        // Starts in state 0 and deterministically alternates 0 -> 1 -> 0 ...,
        // with near-noise-free emissions matching the observed symbols.
        let h = HmmDiscrete::new(
            2,
            2,
            vec![1.0, 0.0],
            vec![0.0, 1.0, 1.0, 0.0],
            vec![0.99, 0.01, 0.01, 0.99],
        )
        .expect("ok");
        let obs = vec![0, 1, 0, 1];
        let sfb = scaled_forward_backward(&h, &obs).expect("ok");
        // At t = 0 the model must be in state 0.
        assert!(sfb.gamma[0] > 0.9, "gamma[0][0]={}", sfb.gamma[0]);
        let n = h.n_states;
        // At t = 1 it must have switched to state 1.
        assert!(sfb.gamma[n + 1] > 0.9, "gamma[1][1]={}", sfb.gamma[n + 1]);
    }

    #[test]
    fn scaled_forward_single_state() {
        let h = single_state_hmm();
        let obs = vec![0, 1, 0];
        let sf = scaled_forward(&h, &obs).expect("ok");
        assert_eq!(sf.scales.len(), 3);
        assert_eq!(sf.alpha.len(), 3);
        for &a in &sf.alpha {
            assert!(
                (a - 1.0).abs() < 1e-12,
                "single-state alpha must be 1.0, got {a}"
            );
        }
    }

    #[test]
    fn scaled_viterbi_single_state() {
        let h = single_state_hmm();
        let obs = vec![0, 1, 0, 1];
        let path = scaled_viterbi(&h, &obs).expect("ok");
        assert_eq!(
            path,
            vec![0, 0, 0, 0],
            "single-state path must be all zeros"
        );
    }
}