1use crate::error::{SeqError, SeqResult};
14
/// Digamma function ψ(x), the logarithmic derivative of the gamma function.
///
/// Non-negative arguments are shifted upward with the recurrence
/// ψ(x) = ψ(x + 1) − 1/x until `x >= 6`, where the truncated asymptotic
/// expansion ln x − 1/(2x) − 1/(12x²) + 1/(120x⁴) − 1/(252x⁶) is evaluated.
///
/// Negative arguments use the reflection formula
/// ψ(x) = ψ(1 − x) − π / tan(πx). This keeps the cost independent of `|x|`
/// and guarantees termination: a plain upward shift never finishes once
/// `|x|` is so large that `x + 1.0 == x`.
///
/// # Special values
/// - negative integers (the poles) and `-∞` return NaN;
/// - NaN returns NaN;
/// - `+0.0` returns `-∞` and `-0.0` returns `+∞`;
/// - `+∞` returns `+∞`.
pub fn digamma(mut x: f64) -> f64 {
    use std::f64::consts::PI;

    if x < 0.0 {
        // Every negative integer is a pole. `-inf` and all doubles beyond
        // 2^52 in magnitude are integers too, so they are caught here.
        if x.floor() == x {
            return f64::NAN;
        }
        // 1 - x > 1, so the recursive call takes the positive branch.
        return digamma(1.0 - x) - PI / (PI * x).tan();
    }

    // Recurrence: accumulate the -1/x corrections while shifting x up to >= 6.
    let mut result = 0.0;
    while x < 6.0 {
        result -= 1.0 / x;
        x += 1.0;
    }

    // Asymptotic series, accurate to roughly 1/(240 x^8) at this point.
    let x2 = x * x;
    let x4 = x2 * x2;
    let x6 = x4 * x2;
    result += x.ln() - 0.5 / x - 1.0 / (12.0 * x2) + 1.0 / (120.0 * x4) - 1.0 / (252.0 * x6);
    result
}
42
/// Natural logarithm of |Γ(x)| using the Lanczos approximation
/// (g = 7, nine coefficients).
///
/// For `x >= 0.5` the Lanczos series is evaluated directly. For `x < 0.5`
/// the reflection formula Γ(x)Γ(1 − x) = π / sin(πx) is applied in log space.
/// The absolute value of the sine is used there, so arguments where Γ(x) is
/// negative (for example `x = -0.5`) yield `ln|Γ(x)|` rather than NaN.
///
/// `log_gamma(0.0)` is `+∞` because `sin(0) == 0`.
pub fn log_gamma(x: f64) -> f64 {
    use std::f64::consts::PI;

    const G: f64 = 7.0;
    const C: [f64; 9] = [
        0.999_999_999_999_809_3,
        676.520_368_121_885_1,
        -1_259.139_216_722_402_8,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_572e-6,
        1.505_632_735_149_311_6e-7,
    ];

    if x < 0.5 {
        // Reflection: ln|Γ(x)| = ln π − ln|sin(πx)| − ln Γ(1 − x).
        // 1 − x > 0.5, so the recursion depth is exactly one.
        return PI.ln() - (PI * x).sin().abs().ln() - log_gamma(1.0 - x);
    }

    // Lanczos series in z = x − 1: C[0] + Σ_{i=1..8} C[i] / (z + i).
    let z = x - 1.0;
    let series = C[1..]
        .iter()
        .enumerate()
        .fold(C[0], |acc, (k, &ck)| acc + ck / (z + (k as f64 + 1.0)));

    let t = z + G + 0.5;
    (2.0 * PI).sqrt().ln() + series.ln() + (z + 0.5) * t.ln() - t
}
76
77pub fn dirichlet_log_normalizer(alpha: &[f64]) -> f64 {
80 let sum_alpha: f64 = alpha.iter().sum();
81 let sum_log_gamma: f64 = alpha.iter().map(|&a| log_gamma(a)).sum();
82 sum_log_gamma - log_gamma(sum_alpha)
83}
84
/// Settings for [`variational_hmm`].
#[derive(Debug, Clone)]
pub struct VbHmmConfig {
    /// Number of hidden states; must be greater than zero.
    pub n_states: usize,
    /// Size of the observation alphabet; every symbol must be `< n_obs`.
    pub n_obs: usize,
    /// Symmetric Dirichlet concentration for the initial-state distribution.
    pub alpha_prior: f64,
    /// Symmetric Dirichlet concentration for each transition-matrix row.
    pub beta_prior: f64,
    /// Symmetric Dirichlet concentration for each state's emission row.
    pub gamma_prior: f64,
    /// Upper bound on the number of variational iterations.
    pub max_iter: usize,
    /// Iteration stops once the absolute ELBO change falls below this value.
    pub tol: f64,
}

impl Default for VbHmmConfig {
    /// Two states, two symbols, flat (all-ones) priors, at most 200
    /// iterations and an ELBO tolerance of `1e-6`.
    fn default() -> Self {
        // Concentration 1.0 makes every Dirichlet prior uniform.
        const FLAT_PRIOR: f64 = 1.0;

        VbHmmConfig {
            n_states: 2,
            n_obs: 2,
            alpha_prior: FLAT_PRIOR,
            beta_prior: FLAT_PRIOR,
            gamma_prior: FLAT_PRIOR,
            max_iter: 200,
            tol: 1e-6,
        }
    }
}
119
120#[derive(Debug, Clone)]
122pub struct VbHmmResult {
123 pub alpha: Vec<f64>,
125 pub beta: Vec<f64>,
128 pub gamma: Vec<f64>,
131 pub elbo_history: Vec<f64>,
133 pub n_iter: usize,
135 pub converged: bool,
137}
138
139impl VbHmmResult {
140 pub fn expected_log_pi(&self) -> Vec<f64> {
142 let sum_alpha: f64 = self.alpha.iter().sum();
143 let psi_sum = digamma(sum_alpha);
144 self.alpha.iter().map(|&a| digamma(a) - psi_sum).collect()
145 }
146
147 pub fn mean_pi(&self) -> Vec<f64> {
149 let s: f64 = self.alpha.iter().sum();
150 self.alpha.iter().map(|&a| a / s).collect()
151 }
152
153 pub fn mean_a(&self) -> Vec<f64> {
155 let n = self.alpha.len(); let mut out = vec![0.0; n * n];
157 for i in 0..n {
158 let s: f64 = self.beta[i * n..(i + 1) * n].iter().sum();
159 for j in 0..n {
160 out[i * n + j] = if s > 0.0 {
161 self.beta[i * n + j] / s
162 } else {
163 1.0 / n as f64
164 };
165 }
166 }
167 out
168 }
169
170 pub fn mean_b(&self) -> Vec<f64> {
172 let n = self.alpha.len(); let k = self.gamma.len() / n; let mut out = vec![0.0; n * k];
175 for j in 0..n {
176 let s: f64 = self.gamma[j * k..(j + 1) * k].iter().sum();
177 for sym in 0..k {
178 out[j * k + sym] = if s > 0.0 {
179 self.gamma[j * k + sym] / s
180 } else {
181 1.0 / k as f64
182 };
183 }
184 }
185 out
186 }
187}
188
/// Numerically stable `ln Σ exp(x_i)`.
///
/// The maximum is subtracted before exponentiating so large magnitudes do
/// not overflow or underflow. An empty slice, or one holding only `-∞`,
/// yields `-∞`. If the maximum is `+∞` the result is `+∞`; without the
/// early return the shift would compute `∞ − ∞ = NaN`.
#[inline]
fn logsumexp(xs: &[f64]) -> f64 {
    let max = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if max.is_infinite() {
        // -inf: nothing contributes. +inf: the sum is dominated by infinity.
        return max;
    }
    let shifted_sum: f64 = xs.iter().map(|&x| (x - max).exp()).sum();
    max + shifted_sum.ln()
}
201
202fn vb_forward_backward(
218 log_pi_eff: &[f64],
219 log_a_eff: &[f64],
220 log_em_eff: &[f64],
221 n: usize,
222 t_max: usize,
223) -> (Vec<f64>, Vec<f64>, f64) {
224 let mut log_alpha = vec![f64::NEG_INFINITY; t_max * n];
226
227 for j in 0..n {
229 log_alpha[j] = log_pi_eff[j] + log_em_eff[j];
230 }
231
232 let mut tmp = vec![0.0f64; n];
233 for t in 1..t_max {
234 for j in 0..n {
235 for i in 0..n {
236 tmp[i] = log_alpha[(t - 1) * n + i] + log_a_eff[i * n + j];
237 }
238 log_alpha[t * n + j] = logsumexp(&tmp) + log_em_eff[t * n + j];
239 }
240 }
241
242 let ll = logsumexp(&log_alpha[(t_max - 1) * n..t_max * n]);
243
244 let mut log_beta = vec![f64::NEG_INFINITY; t_max * n];
246 for i in 0..n {
247 log_beta[(t_max - 1) * n + i] = 0.0;
248 }
249 for t in (0..t_max.saturating_sub(1)).rev() {
250 for i in 0..n {
251 for j in 0..n {
252 tmp[j] =
253 log_a_eff[i * n + j] + log_em_eff[(t + 1) * n + j] + log_beta[(t + 1) * n + j];
254 }
255 log_beta[t * n + i] = logsumexp(&tmp);
256 }
257 }
258
259 let mut gamma = vec![0.0f64; t_max * n];
261 for t in 0..t_max {
262 for i in 0..n {
263 gamma[t * n + i] = (log_alpha[t * n + i] + log_beta[t * n + i] - ll).exp();
264 }
265 let s: f64 = gamma[t * n..t * n + n].iter().sum();
267 if s > 0.0 {
268 for i in 0..n {
269 gamma[t * n + i] /= s;
270 }
271 }
272 }
273
274 let xi_len = t_max.saturating_sub(1) * n * n;
276 let mut xi = vec![0.0f64; xi_len];
277 for t in 0..t_max.saturating_sub(1) {
278 let mut s = 0.0;
279 for i in 0..n {
280 for j in 0..n {
281 let v = (log_alpha[t * n + i]
282 + log_a_eff[i * n + j]
283 + log_em_eff[(t + 1) * n + j]
284 + log_beta[(t + 1) * n + j]
285 - ll)
286 .exp();
287 xi[t * n * n + i * n + j] = v;
288 s += v;
289 }
290 }
291 if s > 0.0 {
292 for v in xi[t * n * n..(t + 1) * n * n].iter_mut() {
293 *v /= s;
294 }
295 }
296 }
297
298 (gamma, xi, ll)
299}
300
301fn kl_dirichlet(alpha: &[f64], alpha_0: &[f64]) -> f64 {
305 let log_b_alpha_0 = dirichlet_log_normalizer(alpha_0);
306 let log_b_alpha = dirichlet_log_normalizer(alpha);
307 let sum_alpha: f64 = alpha.iter().sum();
308 let psi_sum = digamma(sum_alpha);
309
310 let correction: f64 = alpha
311 .iter()
312 .zip(alpha_0.iter())
313 .map(|(&ai, &a0i)| (a0i - ai) * (digamma(ai) - psi_sum))
314 .sum();
315
316 log_b_alpha_0 - log_b_alpha + correction
317}
318
319pub fn variational_hmm(observations: &[&[usize]], cfg: &VbHmmConfig) -> SeqResult<VbHmmResult> {
326 if observations.is_empty() || observations.iter().all(|s| s.is_empty()) {
328 return Err(SeqError::EmptyInput);
329 }
330 if cfg.n_states == 0 || cfg.n_obs == 0 {
331 return Err(SeqError::InvalidConfiguration(
332 "n_states and n_obs must be > 0".to_string(),
333 ));
334 }
335 for seq in observations.iter() {
336 for &o in *seq {
337 if o >= cfg.n_obs {
338 return Err(SeqError::InvalidObservation(format!(
339 "observation {o} >= n_obs {}",
340 cfg.n_obs
341 )));
342 }
343 }
344 }
345 if observations.iter().all(|s| s.is_empty()) {
347 return Err(SeqError::EmptyInput);
348 }
349
350 let n = cfg.n_states;
351 let k = cfg.n_obs;
352
353 let mut alpha: Vec<f64> = (0..n)
356 .map(|i| cfg.alpha_prior + (i as f64 + 1.0) * 0.1 / n as f64)
357 .collect();
358
359 let mut beta: Vec<f64> = vec![0.0; n * n];
361 for i in 0..n {
362 for j in 0..n {
363 beta[i * n + j] = if i == j {
364 cfg.beta_prior + 0.5
365 } else if n > 1 {
366 cfg.beta_prior + 0.1 / (n as f64 - 1.0)
367 } else {
368 cfg.beta_prior
369 };
370 }
371 }
372
373 let mut gamma_dir: Vec<f64> = vec![cfg.gamma_prior + 0.1; n * k];
375
376 let mut elbo_history: Vec<f64> = Vec::with_capacity(cfg.max_iter + 1);
377 let mut prev_elbo = f64::NEG_INFINITY;
378 let mut converged = false;
379 let mut n_iter = 0usize;
380
381 for iter in 0..cfg.max_iter {
383 n_iter = iter + 1;
384
385 let sum_alpha: f64 = alpha.iter().sum();
388 let psi_sum_alpha = digamma(sum_alpha);
389 let log_pi_eff: Vec<f64> = alpha.iter().map(|&a| digamma(a) - psi_sum_alpha).collect();
390
391 let mut log_a_eff: Vec<f64> = vec![0.0; n * n];
393 for i in 0..n {
394 let sum_beta_i: f64 = beta[i * n..(i + 1) * n].iter().sum();
395 let psi_sum_beta_i = digamma(sum_beta_i);
396 for j in 0..n {
397 log_a_eff[i * n + j] = digamma(beta[i * n + j]) - psi_sum_beta_i;
398 }
399 }
400
401 let mut log_b_eff: Vec<f64> = vec![0.0; n * k];
403 for j in 0..n {
404 let sum_gamma_j: f64 = gamma_dir[j * k..(j + 1) * k].iter().sum();
405 let psi_sum_gamma_j = digamma(sum_gamma_j);
406 for sym in 0..k {
407 log_b_eff[j * k + sym] = digamma(gamma_dir[j * k + sym]) - psi_sum_gamma_j;
408 }
409 }
410
411 let mut ss_pi = vec![0.0f64; n];
417 let mut ss_a = vec![0.0f64; n * n];
418 let mut ss_b = vec![0.0f64; n * k];
419
420 for seq in observations.iter() {
421 if seq.is_empty() {
422 continue;
423 }
424 let t_max = seq.len();
425
426 let mut log_em_eff = vec![0.0f64; t_max * n];
428 for t in 0..t_max {
429 for j in 0..n {
430 log_em_eff[t * n + j] = log_b_eff[j * k + seq[t]];
431 }
432 }
433
434 let (gamma_seq, xi_seq, _ll_seq) =
435 vb_forward_backward(&log_pi_eff, &log_a_eff, &log_em_eff, n, t_max);
436
437 for i in 0..n {
439 ss_pi[i] += gamma_seq[i];
440 }
441
442 for t in 0..t_max.saturating_sub(1) {
444 for i in 0..n {
445 for j in 0..n {
446 ss_a[i * n + j] += xi_seq[t * n * n + i * n + j];
447 }
448 }
449 }
450
451 for t in 0..t_max {
453 for j in 0..n {
454 ss_b[j * k + seq[t]] += gamma_seq[t * n + j];
455 }
456 }
457 }
458
459 for i in 0..n {
461 alpha[i] = cfg.alpha_prior + ss_pi[i];
462 }
463 for i in 0..n {
464 for j in 0..n {
465 beta[i * n + j] = cfg.beta_prior + ss_a[i * n + j];
466 }
467 }
468 for j in 0..n {
469 for sym in 0..k {
470 gamma_dir[j * k + sym] = cfg.gamma_prior + ss_b[j * k + sym];
471 }
472 }
473
474 let sum_alpha_new: f64 = alpha.iter().sum();
480 let psi_sum_alpha_new = digamma(sum_alpha_new);
481 let log_pi_new: Vec<f64> = alpha
482 .iter()
483 .map(|&a| digamma(a) - psi_sum_alpha_new)
484 .collect();
485
486 let mut log_a_new: Vec<f64> = vec![0.0; n * n];
487 for i in 0..n {
488 let sum_beta_i: f64 = beta[i * n..(i + 1) * n].iter().sum();
489 let psi_sum_bi = digamma(sum_beta_i);
490 for j in 0..n {
491 log_a_new[i * n + j] = digamma(beta[i * n + j]) - psi_sum_bi;
492 }
493 }
494
495 let mut log_b_new: Vec<f64> = vec![0.0; n * k];
496 for j in 0..n {
497 let sum_gamma_j: f64 = gamma_dir[j * k..(j + 1) * k].iter().sum();
498 let psi_sum_gj = digamma(sum_gamma_j);
499 for sym in 0..k {
500 log_b_new[j * k + sym] = digamma(gamma_dir[j * k + sym]) - psi_sum_gj;
501 }
502 }
503
504 let mut elbo_ll = 0.0f64;
505 for seq in observations.iter() {
506 if seq.is_empty() {
507 continue;
508 }
509 let t_max = seq.len();
510 let mut log_em_new = vec![0.0f64; t_max * n];
511 for t in 0..t_max {
512 for j in 0..n {
513 log_em_new[t * n + j] = log_b_new[j * k + seq[t]];
514 }
515 }
516 let (_, _, ll_new) =
517 vb_forward_backward(&log_pi_new, &log_a_new, &log_em_new, n, t_max);
518 elbo_ll += ll_new;
519 }
520
521 let alpha_prior_vec = vec![cfg.alpha_prior; n];
522 let beta_prior_vec = vec![cfg.beta_prior; n];
523 let gamma_prior_vec = vec![cfg.gamma_prior; k];
524
525 let mut kl_total = kl_dirichlet(&alpha, &alpha_prior_vec);
526 for i in 0..n {
527 kl_total += kl_dirichlet(&beta[i * n..(i + 1) * n], &beta_prior_vec);
528 }
529 for j in 0..n {
530 kl_total += kl_dirichlet(&gamma_dir[j * k..(j + 1) * k], &gamma_prior_vec);
531 }
532
533 let elbo = elbo_ll - kl_total;
534 elbo_history.push(elbo);
535
536 if iter > 0 && (elbo - prev_elbo).abs() < cfg.tol {
538 converged = true;
539 break;
540 }
541 prev_elbo = elbo;
542 }
543
544 Ok(VbHmmResult {
545 alpha,
546 beta,
547 gamma: gamma_dir,
548 elbo_history,
549 n_iter,
550 converged,
551 })
552}
553
// Unit tests for the special functions and for the variational HMM fit.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- digamma ---------------------------------------------------------

    // ψ(1) equals minus the Euler–Mascheroni constant.
    #[test]
    fn digamma_at_one_is_neg_euler_mascheroni() {
        let d = digamma(1.0);
        assert!((d - (-0.577_215_664_9)).abs() < 1e-6, "digamma(1) = {d}");
    }

    // ψ(2) = ψ(1) + 1.
    #[test]
    fn digamma_at_two() {
        let d = digamma(2.0);
        assert!((d - 0.422_784_335_1).abs() < 1e-6, "digamma(2) = {d}");
    }

    // Checks the recurrence ψ(x + 1) = ψ(x) + 1/x at several points.
    #[test]
    fn digamma_recurrence() {
        for &x in &[0.5, 1.0, 2.0, 3.5, 7.0] {
            let lhs = digamma(x + 1.0);
            let rhs = digamma(x) + 1.0 / x;
            assert!(
                (lhs - rhs).abs() < 1e-9,
                "recurrence failed at x={x}: {lhs} vs {rhs}"
            );
        }
    }

    // For large x, ψ(x) ≈ ln x − 1/(2x); 1/(2·100) = 0.005.
    #[test]
    fn digamma_large_argument() {
        let d = digamma(100.0);
        let approx = 100.0_f64.ln() - 0.005;
        assert!((d - approx).abs() < 0.01, "digamma(100) = {d}");
    }

    // ---- log_gamma -------------------------------------------------------

    // Γ(1) = 1, so the log is zero.
    #[test]
    fn log_gamma_at_one() {
        assert!(log_gamma(1.0).abs() < 1e-10);
    }

    // Γ(2) = 1, so the log is zero.
    #[test]
    fn log_gamma_at_two() {
        assert!(log_gamma(2.0).abs() < 1e-10);
    }

    // Γ(1/2) = √π, so ln Γ(1/2) = ½ ln π.
    #[test]
    fn log_gamma_at_half() {
        let expected = 0.5 * std::f64::consts::PI.ln();
        let got = log_gamma(0.5);
        assert!((got - expected).abs() < 1e-9, "log_gamma(0.5) = {got}");
    }

    // Γ(4) = 3! = 6.
    #[test]
    fn log_gamma_integer_values() {
        let got = log_gamma(4.0);
        let expected = 6.0_f64.ln();
        assert!((got - expected).abs() < 1e-9, "log_gamma(4) = {got}");
    }

    // Γ(5) = 4! = 24.
    #[test]
    fn log_gamma_five() {
        let got = log_gamma(5.0);
        let expected = 24.0_f64.ln();
        assert!((got - expected).abs() < 1e-9, "log_gamma(5) = {got}");
    }

    // ---- variational_hmm -------------------------------------------------

    // Short binary sequence shared by most fitting tests.
    fn simple_obs() -> Vec<usize> {
        vec![0, 0, 1, 1, 0, 0, 1, 1, 0, 1]
    }

    // The default configuration runs and reports a sane iteration count.
    #[test]
    fn default_config_produces_valid_result() {
        let obs = simple_obs();
        let cfg = VbHmmConfig::default();
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("should succeed");
        assert!(r.n_iter > 0);
        assert!(r.n_iter <= cfg.max_iter);
        assert!(!r.elbo_history.is_empty());
    }

    // The posterior-mean initial distribution is a probability vector.
    #[test]
    fn mean_pi_sums_to_one() {
        let obs = simple_obs();
        let cfg = VbHmmConfig::default();
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
        let s: f64 = r.mean_pi().iter().sum();
        assert!((s - 1.0).abs() < 1e-10, "mean_pi sum = {s}");
    }

    // Every row of the posterior-mean transition matrix sums to one.
    #[test]
    fn mean_a_rows_sum_to_one() {
        let obs = simple_obs();
        let cfg = VbHmmConfig::default();
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
        let n = cfg.n_states;
        let a = r.mean_a();
        for i in 0..n {
            let s: f64 = a[i * n..(i + 1) * n].iter().sum();
            assert!((s - 1.0).abs() < 1e-10, "mean_a row {i} sums to {s}");
        }
    }

    // Every row of the posterior-mean emission matrix sums to one.
    #[test]
    fn mean_b_rows_sum_to_one() {
        let obs = simple_obs();
        let cfg = VbHmmConfig::default();
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
        let n = cfg.n_states;
        let k = cfg.n_obs;
        let b = r.mean_b();
        for j in 0..n {
            let s: f64 = b[j * k..(j + 1) * k].iter().sum();
            assert!((s - 1.0).abs() < 1e-10, "mean_b row {j} sums to {s}");
        }
    }

    // Despite the name, this check is tolerant: the final ELBO may be up to
    // 2 nats below the first one, and consecutive values may drop by up to
    // 1 nat.
    #[test]
    fn elbo_history_non_decreasing() {
        let obs = simple_obs();
        let cfg = VbHmmConfig::default();
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
        if r.elbo_history.len() >= 2 {
            let first = r.elbo_history[0];
            let last = *r.elbo_history.last().expect("non-empty");
            assert!(
                last >= first - 2.0,
                "Final ELBO ({last}) is much worse than initial ({first})"
            );
        }
        for w in r.elbo_history.windows(2) {
            assert!(
                w[1] >= w[0] - 1.0,
                "ELBO dropped by more than 1 nat: {} → {}",
                w[0],
                w[1]
            );
        }
    }

    // The iteration count never exceeds the configured maximum.
    #[test]
    fn n_iter_within_max_iter() {
        let obs = simple_obs();
        let cfg = VbHmmConfig {
            max_iter: 50,
            ..Default::default()
        };
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
        assert!(r.n_iter <= 50);
    }

    // Posterior parameters are prior + expected counts, so with data every
    // initial-state parameter must exceed its prior value.
    #[test]
    fn posteriors_exceed_prior_when_data_given() {
        let obs: Vec<usize> = (0..30).map(|i| i % 2).collect();
        let cfg = VbHmmConfig {
            alpha_prior: 1.0,
            beta_prior: 1.0,
            gamma_prior: 1.0,
            ..Default::default()
        };
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
        for &a in &r.alpha {
            assert!(
                a > cfg.alpha_prior,
                "alpha {a} not > prior {}",
                cfg.alpha_prior
            );
        }
    }

    // Several sequences of different lengths can be fitted together.
    #[test]
    fn multiple_sequences_accepted() {
        let seq1 = vec![0usize, 1, 0, 1];
        let seq2 = vec![1usize, 1, 0, 0];
        let seq3 = vec![0usize, 0, 1, 1, 0];
        let cfg = VbHmmConfig::default();
        let r = variational_hmm(&[&seq1, &seq2, &seq3], &cfg).expect("ok");
        assert!(!r.elbo_history.is_empty());
    }

    // An empty list of sequences is rejected.
    #[test]
    fn empty_observations_returns_err() {
        let cfg = VbHmmConfig::default();
        assert!(variational_hmm(&[], &cfg).is_err());
    }

    // Symbol 5 is outside the default alphabet (n_obs = 2).
    #[test]
    fn obs_out_of_range_returns_err() {
        let obs = vec![0usize, 5];
        let cfg = VbHmmConfig::default();
        assert!(variational_hmm(&[&obs], &cfg).is_err());
    }

    // A length-1 sequence has no transitions but must still be accepted.
    #[test]
    fn single_observation_length_one_works() {
        let obs = vec![0usize];
        let cfg = VbHmmConfig::default();
        let r = variational_hmm(&[&obs], &cfg).expect("length-1 seq should work");
        assert!(!r.elbo_history.is_empty());
    }

    // With a loose tolerance and many iterations the fit reports convergence.
    #[test]
    fn converged_flag_set_on_tight_convergence() {
        let obs: Vec<usize> = (0..50).map(|i| i % 2).collect();
        let cfg = VbHmmConfig {
            max_iter: 500,
            tol: 1e-3,
            ..Default::default()
        };
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
        assert!(
            r.converged,
            "expected convergence with tol=1e-3 and 500 iterations"
        );
    }

    // Output dimensions follow n_states and n_obs (4 states, 4 symbols).
    #[test]
    fn larger_state_space() {
        let obs: Vec<usize> = (0..40).map(|i| i % 4).collect();
        let cfg = VbHmmConfig {
            n_states: 4,
            n_obs: 4,
            max_iter: 100,
            ..Default::default()
        };
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
        assert_eq!(r.alpha.len(), 4);
        assert_eq!(r.beta.len(), 16);
        assert_eq!(r.gamma.len(), 16);
    }

    // Expected log probabilities are finite and non-positive, one per state.
    #[test]
    fn expected_log_pi_returns_correct_length() {
        let obs = simple_obs();
        let cfg = VbHmmConfig::default();
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
        let elp = r.expected_log_pi();
        assert_eq!(elp.len(), cfg.n_states);
        for &v in &elp {
            assert!(v.is_finite(), "expected_log_pi entry is not finite: {v}");
            assert!(v <= 0.0, "expected_log_pi entry should be ≤ 0: {v}");
        }
    }
}