1use crate::error::{SeqError, SeqResult};
19
/// Computes `ln(Σ exp(x))` over `xs` in a numerically stable way.
///
/// The largest element is factored out before exponentiating so that large
/// magnitudes do not overflow. An empty slice, or one whose largest element
/// is `-∞`, yields `-∞`.
#[inline]
fn logsumexp(xs: &[f64]) -> f64 {
    let peak = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if peak == f64::NEG_INFINITY {
        return peak;
    }
    // Every shifted exponent is <= 0, so each summand lies in [0, 1].
    let shifted_total = xs.iter().fold(0.0f64, |acc, &x| acc + (x - peak).exp());
    peak + shifted_total.ln()
}
31
/// Natural logarithm that never produces NaN or `+∞`.
///
/// Any input that is zero, negative, NaN or infinite maps to `-∞`; every
/// other input maps to its ordinary `ln`.
#[inline]
fn log_safe(x: f64) -> f64 {
    let usable = x.is_finite() && x > 0.0;
    if usable {
        x.ln()
    } else {
        f64::NEG_INFINITY
    }
}
40
/// Distribution over the number of consecutive time steps a state is occupied.
///
/// Durations are counted from 1; a duration of 0 always has probability 0
/// when evaluated through `DurationDistrib::prob`.
#[derive(Debug, Clone)]
pub enum DurationDistrib {
    /// Poisson-shaped durations with rate `lambda`, renormalised over
    /// `1..=max_dur` at evaluation time.
    Poisson { lambda: f64 },
    /// Geometric durations `p · (1 − p)^(τ − 1)`; `p` is clamped to `[0, 1]`
    /// at evaluation time.
    Geometric { p: f64 },
    /// Explicit table: `probs[τ − 1]` is the probability of duration `τ`.
    /// Durations beyond the end of the table have probability 0.
    Histogram { probs: Vec<f64> },
}
58
59impl DurationDistrib {
60 pub fn prob(&self, tau: usize, max_dur: usize) -> f64 {
66 if tau == 0 {
67 return 0.0;
68 }
69 match self {
70 DurationDistrib::Poisson { lambda } => {
71 if *lambda <= 0.0 {
72 return if tau == 1 { 1.0 } else { 0.0 };
73 }
74 let raw = poisson_pmf(*lambda, tau);
76 let total: f64 = (1..=max_dur).map(|d| poisson_pmf(*lambda, d)).sum();
77 if total > 0.0 { raw / total } else { 0.0 }
78 }
79 DurationDistrib::Geometric { p } => {
80 let p = p.clamp(0.0, 1.0);
81 p * (1.0 - p).powi((tau - 1) as i32)
82 }
83 DurationDistrib::Histogram { probs } => {
84 if tau <= probs.len() {
85 probs[tau - 1]
86 } else {
87 0.0
88 }
89 }
90 }
91 }
92
93 pub fn log_prob(&self, tau: usize, max_dur: usize) -> f64 {
95 log_safe(self.prob(tau, max_dur))
96 }
97}
98
99fn poisson_pmf(lambda: f64, k: usize) -> f64 {
101 if lambda <= 0.0 {
102 return if k == 0 { 1.0 } else { 0.0 };
103 }
104 let log_p = (k as f64) * lambda.ln() - lambda - log_factorial(k);
106 log_p.exp()
107}
108
/// Returns `ln(k!)` as a running sum of logarithms; `ln(0!) = ln(1!) = 0`.
fn log_factorial(k: usize) -> f64 {
    // ln(1) contributes nothing, so start at 2; for k <= 1 the range is empty
    // and the fold returns its zero seed.
    (2..=k).fold(0.0f64, |acc, i| acc + (i as f64).ln())
}
116
117#[derive(Debug, Clone)]
121pub struct Hsmm {
122 pub n_states: usize,
124 pub n_obs: usize,
126 pub max_dur: usize,
128 pub pi: Vec<f64>,
130 pub a: Vec<f64>,
133 pub b: Vec<f64>,
135 pub dur: Vec<DurationDistrib>,
137}
138
139impl Hsmm {
140 pub fn new(
142 n_states: usize,
143 n_obs: usize,
144 max_dur: usize,
145 pi: Vec<f64>,
146 a: Vec<f64>,
147 b: Vec<f64>,
148 dur: Vec<DurationDistrib>,
149 ) -> SeqResult<Self> {
150 if n_states == 0 || n_obs == 0 || max_dur == 0 {
151 return Err(SeqError::InvalidConfiguration(
152 "n_states, n_obs, and max_dur must all be > 0".to_string(),
153 ));
154 }
155 if pi.len() != n_states {
156 return Err(SeqError::ShapeMismatch {
157 expected: n_states,
158 got: pi.len(),
159 });
160 }
161 if a.len() != n_states * n_states {
162 return Err(SeqError::ShapeMismatch {
163 expected: n_states * n_states,
164 got: a.len(),
165 });
166 }
167 if b.len() != n_states * n_obs {
168 return Err(SeqError::ShapeMismatch {
169 expected: n_states * n_obs,
170 got: b.len(),
171 });
172 }
173 if dur.len() != n_states {
174 return Err(SeqError::ShapeMismatch {
175 expected: n_states,
176 got: dur.len(),
177 });
178 }
179
180 for i in 0..n_states {
182 let diag = a[i * n_states + i];
183 if diag.abs() > 1e-9 {
184 return Err(SeqError::InvalidConfiguration(format!(
185 "transition matrix diagonal A[{i},{i}] = {diag} must be 0"
186 )));
187 }
188 }
189
190 let pi_sum: f64 = pi.iter().sum();
192 if (pi_sum - 1.0).abs() > 1e-5 {
193 return Err(SeqError::InvalidConfiguration(format!(
194 "pi sums to {pi_sum}, expected 1"
195 )));
196 }
197 if n_states > 1 {
199 for i in 0..n_states {
200 let s: f64 = a[i * n_states..(i + 1) * n_states].iter().sum();
201 if (s - 1.0).abs() > 1e-5 {
202 return Err(SeqError::InvalidConfiguration(format!(
203 "A row {i} sums to {s}, expected 1"
204 )));
205 }
206 }
207 }
208 for j in 0..n_states {
210 let s: f64 = b[j * n_obs..(j + 1) * n_obs].iter().sum();
211 if (s - 1.0).abs() > 1e-5 {
212 return Err(SeqError::InvalidConfiguration(format!(
213 "B row {j} sums to {s}, expected 1"
214 )));
215 }
216 }
217
218 Ok(Self {
219 n_states,
220 n_obs,
221 max_dur,
222 pi,
223 a,
224 b,
225 dur,
226 })
227 }
228
229 pub fn log_likelihood(&self, obs: &[usize]) -> SeqResult<f64> {
231 if obs.is_empty() {
232 return Err(SeqError::EmptyInput);
233 }
234 for &o in obs {
235 if o >= self.n_obs {
236 return Err(SeqError::InvalidObservation(format!(
237 "observation {o} >= n_obs {}",
238 self.n_obs
239 )));
240 }
241 }
242 let (_, log_z) = hsmm_forward(self, obs);
243 Ok(log_z)
244 }
245
246 pub fn decode(&self, obs: &[usize]) -> SeqResult<Vec<usize>> {
248 if obs.is_empty() {
249 return Err(SeqError::EmptyInput);
250 }
251 for &o in obs {
252 if o >= self.n_obs {
253 return Err(SeqError::InvalidObservation(format!(
254 "observation {o} >= n_obs {}",
255 self.n_obs
256 )));
257 }
258 }
259 hsmm_viterbi(self, obs)
260 }
261}
262
263fn build_cum_log_b(model: &Hsmm, obs: &[usize]) -> Vec<f64> {
270 let t_max = obs.len();
271 let n = model.n_states;
272 let mut cum = vec![0.0f64; n * (t_max + 1)];
274 for j in 0..n {
275 for t in 0..t_max {
276 let le = log_safe(model.b[j * model.n_obs + obs[t]]);
277 cum[j * (t_max + 1) + t + 1] = cum[j * (t_max + 1) + t] + le;
278 }
279 }
280 cum
281}
282
283#[inline]
285fn seg_log_em(cum: &[f64], j: usize, t1: usize, t2: usize, t_max: usize) -> f64 {
286 cum[j * (t_max + 1) + t2 + 1] - cum[j * (t_max + 1) + t1]
287}
288
289fn hsmm_forward(model: &Hsmm, obs: &[usize]) -> (Vec<f64>, f64) {
297 let t_max = obs.len();
298 let n = model.n_states;
299 let d_max = model.max_dur.min(t_max);
300 let cum = build_cum_log_b(model, obs);
301
302 let mut log_e = vec![f64::NEG_INFINITY; n * t_max];
305
306 let log_pi: Vec<f64> = model.pi.iter().map(|&p| log_safe(p)).collect();
308
309 let log_a: Vec<f64> = model.a.iter().map(|&v| log_safe(v)).collect();
311
312 let mut terms = Vec::with_capacity(d_max * n);
314
315 for t in 0..t_max {
316 for j in 0..n {
318 terms.clear();
319 let d_end = (t + 1).min(d_max); for d in 1..=d_end {
321 let t_start = t + 1 - d; let log_dur = model.dur[j].log_prob(d, model.max_dur);
324 let log_em_seg = seg_log_em(&cum, j, t_start, t, t_max);
325
326 let log_init = if d == t + 1 {
329 log_pi[j]
331 } else {
332 let prev_t = t_start - 1; let mut trans_terms: Vec<f64> = Vec::with_capacity(n);
335 for i in 0..n {
336 if i == j {
337 continue;
338 }
339 let log_prev = log_e[i * t_max + prev_t];
340 if log_prev == f64::NEG_INFINITY {
341 continue;
342 }
343 trans_terms.push(log_a[i * n + j] + log_prev);
344 }
345 if trans_terms.is_empty() {
346 f64::NEG_INFINITY
347 } else {
348 logsumexp(&trans_terms)
349 }
350 };
351
352 if log_init == f64::NEG_INFINITY || log_dur == f64::NEG_INFINITY {
353 continue;
354 }
355 terms.push(log_dur + log_em_seg + log_init);
356 }
357 log_e[j * t_max + t] = if terms.is_empty() {
358 f64::NEG_INFINITY
359 } else {
360 logsumexp(&terms)
361 };
362 }
363 }
364
365 let last_terms: Vec<f64> = (0..n).map(|j| log_e[j * t_max + t_max - 1]).collect();
367 let log_z = logsumexp(&last_terms);
368
369 (log_e, log_z)
370}
371
372fn hsmm_backward(model: &Hsmm, obs: &[usize], cum: &[f64]) -> Vec<f64> {
381 let t_max = obs.len();
382 let n = model.n_states;
383 let d_max = model.max_dur;
384
385 let log_a: Vec<f64> = model.a.iter().map(|&v| log_safe(v)).collect();
386
387 let mut log_f = vec![f64::NEG_INFINITY; n * t_max];
389
390 for j in 0..n {
392 log_f[j * t_max + t_max - 1] = 0.0;
393 }
394
395 let mut terms: Vec<f64> = Vec::with_capacity(d_max * n);
399
400 for t in (0..t_max.saturating_sub(1)).rev() {
401 for j in 0..n {
402 terms.clear();
403 let remaining = t_max - t - 1; let d_end = d_max.min(remaining);
406
407 for k in 0..n {
408 if k == j {
409 continue;
410 }
411 let log_ajk = log_a[j * n + k];
412 if log_ajk == f64::NEG_INFINITY {
413 continue;
414 }
415 for d in 1..=d_end {
416 let t_start_new = t + 1;
418 let t_end_new = t + d; let log_dur = model.dur[k].log_prob(d, model.max_dur);
420 let log_em_seg = seg_log_em(cum, k, t_start_new, t_end_new, t_max);
421 let log_f_next = log_f[k * t_max + t_end_new];
422 if log_dur == f64::NEG_INFINITY || log_f_next == f64::NEG_INFINITY {
423 continue;
424 }
425 terms.push(log_ajk + log_dur + log_em_seg + log_f_next);
426 }
427 }
428
429 log_f[j * t_max + t] = if terms.is_empty() {
430 f64::NEG_INFINITY
431 } else {
432 logsumexp(&terms)
433 };
434 }
435 }
436
437 log_f
438}
439
440fn hsmm_viterbi(model: &Hsmm, obs: &[usize]) -> SeqResult<Vec<usize>> {
445 let t_max = obs.len();
446 let n = model.n_states;
447 let d_max = model.max_dur.min(t_max);
448 let cum = build_cum_log_b(model, obs);
449
450 let log_pi: Vec<f64> = model.pi.iter().map(|&p| log_safe(p)).collect();
451 let log_a: Vec<f64> = model.a.iter().map(|&v| log_safe(v)).collect();
452
453 let mut log_v = vec![f64::NEG_INFINITY; n * t_max];
456
457 let mut bp_d = vec![0usize; n * t_max];
460 let mut bp_prev = vec![n; n * t_max]; for t in 0..t_max {
463 for j in 0..n {
464 let d_end = (t + 1).min(d_max);
465 let mut best_val = f64::NEG_INFINITY;
466 let mut best_d = 1;
467 let mut best_prev = n; for d in 1..=d_end {
470 let t_start = t + 1 - d;
471 let log_dur = model.dur[j].log_prob(d, model.max_dur);
472 let log_em_seg = seg_log_em(&cum, j, t_start, t, t_max);
473
474 if log_dur == f64::NEG_INFINITY {
475 continue;
476 }
477
478 let log_seg_cost = log_dur + log_em_seg;
479
480 if d == t + 1 {
481 let v = log_seg_cost + log_pi[j];
483 if v > best_val {
484 best_val = v;
485 best_d = d;
486 best_prev = n; }
488 } else {
489 let prev_t = t_start - 1;
490 for i in 0..n {
491 if i == j {
492 continue;
493 }
494 let log_prev_v = log_v[i * t_max + prev_t];
495 if log_prev_v == f64::NEG_INFINITY {
496 continue;
497 }
498 let v = log_seg_cost + log_a[i * n + j] + log_prev_v;
499 if v > best_val {
500 best_val = v;
501 best_d = d;
502 best_prev = i;
503 }
504 }
505 }
506 }
507
508 log_v[j * t_max + t] = best_val;
509 bp_d[j * t_max + t] = best_d;
510 bp_prev[j * t_max + t] = best_prev;
511 }
512 }
513
514 let last_t = t_max - 1;
516 let mut best_final = f64::NEG_INFINITY;
517 let mut best_j = 0;
518 for j in 0..n {
519 let v = log_v[j * t_max + last_t];
520 if v > best_final {
521 best_final = v;
522 best_j = j;
523 }
524 }
525
526 if best_final == f64::NEG_INFINITY {
527 return Ok(vec![0usize; t_max]);
529 }
530
531 let mut path = vec![0usize; t_max];
533 let mut cur_t = last_t as isize;
534 let mut cur_j = best_j;
535
536 while cur_t >= 0 {
537 let t = cur_t as usize;
538 let d = bp_d[cur_j * t_max + t];
539 let t_start = t + 1 - d;
540
541 for u in t_start..=t {
543 path[u] = cur_j;
544 }
545
546 if t_start == 0 {
547 break;
548 }
549 let prev_state = bp_prev[cur_j * t_max + t];
550 if prev_state == n {
551 break;
553 }
554 cur_t = (t_start as isize) - 1;
555 cur_j = prev_state;
556 }
557
558 Ok(path)
559}
560
/// Settings for fitting an HSMM with `hsm_fit`.
#[derive(Debug, Clone)]
pub struct HsmConfig {
    /// Number of hidden states to fit.
    pub n_states: usize,
    /// Number of distinct observation symbols; observations must be `< n_obs`.
    pub n_obs: usize,
    /// Longest segment duration considered.
    pub max_dur: usize,
    /// Upper bound on the number of EM iterations.
    pub max_iter: usize,
    /// Convergence threshold on the absolute change in total log-likelihood
    /// between consecutive iterations.
    pub tol: f64,
}
577
578impl Default for HsmConfig {
579 fn default() -> Self {
580 Self {
581 n_states: 2,
582 n_obs: 2,
583 max_dur: 10,
584 max_iter: 100,
585 tol: 1e-5,
586 }
587 }
588}
589
590#[derive(Debug, Clone)]
592pub struct HsmResult {
593 pub model: Hsmm,
595 pub log_likelihood_history: Vec<f64>,
597 pub n_iter: usize,
599 pub converged: bool,
601}
602
603fn build_initial_model(cfg: &HsmConfig) -> SeqResult<Hsmm> {
606 let n = cfg.n_states;
607 let k = cfg.n_obs;
608 let d_max = cfg.max_dur;
609
610 let pi: Vec<f64> = vec![1.0 / n as f64; n];
612
613 let mut a = vec![0.0f64; n * n];
615 if n > 1 {
616 for i in 0..n {
617 for j in 0..n {
618 a[i * n + j] = if i == j { 0.0 } else { 1.0 / (n as f64 - 1.0) };
619 }
620 }
621 }
622
623 let mut b = vec![0.0f64; n * k];
625 for j in 0..n {
626 let mut row_sum = 0.0f64;
627 for sym in 0..k {
628 let base = 1.0 / k as f64;
630 let bump = if sym == j % k { 0.2 / k as f64 } else { 0.0 };
631 b[j * k + sym] = base + bump;
632 row_sum += b[j * k + sym];
633 }
634 for sym in 0..k {
636 b[j * k + sym] /= row_sum;
637 }
638 }
639
640 let p = 1.0 / d_max.max(1) as f64;
642 let dur: Vec<DurationDistrib> = (0..n).map(|_| DurationDistrib::Geometric { p }).collect();
643
644 Hsmm::new(n, k, d_max, pi, a, b, dur)
645}
646
647pub fn hsm_fit(observations: &[&[usize]], cfg: &HsmConfig) -> SeqResult<HsmResult> {
651 if observations.is_empty() || observations.iter().all(|s| s.is_empty()) {
652 return Err(SeqError::EmptyInput);
653 }
654 if cfg.n_states == 0 || cfg.n_obs == 0 || cfg.max_dur == 0 {
655 return Err(SeqError::InvalidConfiguration(
656 "n_states, n_obs, and max_dur must all be > 0".to_string(),
657 ));
658 }
659 for seq in observations.iter() {
660 for &o in *seq {
661 if o >= cfg.n_obs {
662 return Err(SeqError::InvalidObservation(format!(
663 "observation {o} >= n_obs {}",
664 cfg.n_obs
665 )));
666 }
667 }
668 }
669
670 let n = cfg.n_states;
671 let k = cfg.n_obs;
672 let d_max = cfg.max_dur;
673
674 let mut model = build_initial_model(cfg)?;
675 let mut history: Vec<f64> = Vec::with_capacity(cfg.max_iter + 1);
676 let mut prev_ll = f64::NEG_INFINITY;
677 let mut converged = false;
678 let mut n_iter = 0usize;
679
680 for iter in 0..cfg.max_iter {
681 n_iter = iter + 1;
682
683 let mut ss_pi = vec![0.0f64; n];
686 let mut ss_a = vec![0.0f64; n * n];
688 let mut ss_b = vec![0.0f64; n * k];
690 let mut ss_dur = vec![0.0f64; n * (d_max + 1)];
692
693 let mut total_ll = 0.0f64;
694
695 for seq in observations.iter() {
696 if seq.is_empty() {
697 continue;
698 }
699 let t_max = seq.len();
700 let cum = build_cum_log_b(&model, seq);
701 let (log_e, log_z) = hsmm_forward(&model, seq);
702 let log_f = hsmm_backward(&model, seq, &cum);
703
704 if !log_z.is_finite() {
705 continue;
707 }
708
709 total_ll += log_z;
710
711 let log_pi_v: Vec<f64> = model.pi.iter().map(|&p| log_safe(p)).collect();
716 let log_a_v: Vec<f64> = model.a.iter().map(|&v| log_safe(v)).collect();
717
718 for j in 0..n {
719 for t_end in 0..t_max {
720 for d in 1..=(t_end + 1).min(d_max) {
721 let t_start = t_end + 1 - d;
722 let log_dur = model.dur[j].log_prob(d, d_max);
723 if log_dur == f64::NEG_INFINITY {
724 continue;
725 }
726 let log_em_seg = seg_log_em(&cum, j, t_start, t_end, t_max);
727
728 let log_init = if t_start == 0 {
731 log_pi_v[j]
732 } else {
733 let prev_t = t_start - 1;
734 let mut terms: Vec<f64> = Vec::with_capacity(n);
735 for i in 0..n {
736 if i == j {
737 continue;
738 }
739 let lv = log_e[i * t_max + prev_t];
740 if lv == f64::NEG_INFINITY {
741 continue;
742 }
743 terms.push(log_a_v[i * n + j] + lv);
744 }
745 if terms.is_empty() {
746 f64::NEG_INFINITY
747 } else {
748 logsumexp(&terms)
749 }
750 };
751
752 if log_init == f64::NEG_INFINITY {
753 continue;
754 }
755
756 let log_f_val = log_f[j * t_max + t_end];
757 if log_f_val == f64::NEG_INFINITY {
758 continue;
759 }
760
761 let log_gamma_seg = log_init + log_dur + log_em_seg + log_f_val - log_z;
762 let gamma_seg = log_gamma_seg.exp();
763
764 if !gamma_seg.is_finite() || gamma_seg <= 0.0 {
765 continue;
766 }
767
768 if t_start == 0 {
770 ss_pi[j] += gamma_seg;
771 }
772
773 ss_dur[j * (d_max + 1) + d] += gamma_seg;
775
776 for u in t_start..=t_end {
778 ss_b[j * k + seq[u]] += gamma_seg;
779 }
780
781 if t_start > 0 {
783 let prev_t = t_start - 1;
784 for i in 0..n {
785 if i == j {
786 continue;
787 }
788 let lv = log_e[i * t_max + prev_t];
789 if lv == f64::NEG_INFINITY {
790 continue;
791 }
792 let log_xi =
793 log_a_v[i * n + j] + lv + log_dur + log_em_seg + log_f_val
794 - log_z;
795 let xi_val = log_xi.exp();
796 if xi_val.is_finite() && xi_val > 0.0 {
797 ss_a[i * n + j] += xi_val;
798 }
799 }
800 }
801 }
802 }
803 }
804 }
805
806 history.push(total_ll);
807
808 if iter > 0 && (total_ll - prev_ll).abs() < cfg.tol {
810 converged = true;
811 break;
812 }
813 prev_ll = total_ll;
814
815 let pi_sum: f64 = ss_pi.iter().sum();
819 let new_pi: Vec<f64> = if pi_sum > 0.0 {
820 ss_pi.iter().map(|&v| v / pi_sum).collect()
821 } else {
822 vec![1.0 / n as f64; n]
823 };
824
825 let mut new_a = vec![0.0f64; n * n];
827 if n > 1 {
828 for i in 0..n {
829 let row_sum: f64 = ss_a[i * n..(i + 1) * n].iter().sum();
830 for j in 0..n {
831 if i == j {
832 new_a[i * n + j] = 0.0;
833 } else {
834 new_a[i * n + j] = if row_sum > 0.0 {
835 ss_a[i * n + j] / row_sum
836 } else {
837 1.0 / (n as f64 - 1.0)
838 };
839 }
840 }
841 }
842 }
843
844 let mut new_b = vec![0.0f64; n * k];
846 for j in 0..n {
847 let row_sum: f64 = ss_b[j * k..(j + 1) * k].iter().sum();
848 for sym in 0..k {
849 new_b[j * k + sym] = if row_sum > 0.0 {
850 ss_b[j * k + sym] / row_sum
851 } else {
852 1.0 / k as f64
853 };
854 }
855 }
856
857 let mut new_dur: Vec<DurationDistrib> = Vec::with_capacity(n);
859 for j in 0..n {
860 let total: f64 = ss_dur[j * (d_max + 1) + 1..=(j * (d_max + 1) + d_max)]
861 .iter()
862 .sum();
863 let probs: Vec<f64> = if total > 0.0 {
864 (1..=d_max)
865 .map(|d| ss_dur[j * (d_max + 1) + d] / total)
866 .collect()
867 } else {
868 let p = 1.0 / d_max as f64;
870 (1..=d_max)
871 .map(|d| {
872 let geo = DurationDistrib::Geometric { p };
873 geo.prob(d, d_max)
874 })
875 .collect()
876 };
877 new_dur.push(DurationDistrib::Histogram { probs });
878 }
879
880 model = Hsmm {
882 n_states: n,
883 n_obs: k,
884 max_dur: d_max,
885 pi: new_pi,
886 a: new_a,
887 b: new_b,
888 dur: new_dur,
889 };
890 }
891
892 Ok(HsmResult {
893 model,
894 log_likelihood_history: history,
895 n_iter,
896 converged,
897 })
898}
899
#[cfg(test)]
mod tests {
    use super::*;

    /// Total probability a duration distribution assigns to `1..=upto`,
    /// evaluated with `max_dur == upto`.
    fn mass_up_to(dist: &DurationDistrib, upto: usize) -> f64 {
        (1..=upto).map(|tau| dist.prob(tau, upto)).sum()
    }

    /// Observation sequence `0, 1, 0, 1, ...` of the given length.
    fn alternating(len: usize) -> Vec<usize> {
        (0..len).map(|i| i % 2).collect()
    }

    /// Two states that swap deterministically and each strongly prefer one
    /// symbol.
    fn two_state_model() -> Hsmm {
        let durations = vec![
            DurationDistrib::Geometric { p: 0.3 },
            DurationDistrib::Geometric { p: 0.3 },
        ];
        Hsmm::new(
            2,
            2,
            5,
            vec![0.5, 0.5],
            vec![0.0, 1.0, 1.0, 0.0],
            vec![0.9, 0.1, 0.1, 0.9],
            durations,
        )
        .expect("valid model")
    }

    // ---- duration distributions -------------------------------------------

    #[test]
    fn poisson_probs_sum_to_one() {
        let s = mass_up_to(&DurationDistrib::Poisson { lambda: 3.0 }, 20);
        assert!((s - 1.0).abs() < 1e-9, "Poisson prob sum = {s}");
    }

    #[test]
    fn geometric_probs_approx_one() {
        let s = mass_up_to(&DurationDistrib::Geometric { p: 0.3 }, 1000);
        assert!((s - 1.0).abs() < 1e-9, "Geometric prob sum = {s}");
    }

    #[test]
    fn histogram_probs_sum_to_one() {
        let table = DurationDistrib::Histogram {
            probs: vec![0.2, 0.5, 0.3],
        };
        let s = mass_up_to(&table, 3);
        assert!((s - 1.0).abs() < 1e-9, "Histogram prob sum = {s}");
    }

    #[test]
    fn poisson_log_prob_finite_for_positive_lambda() {
        let lp = DurationDistrib::Poisson { lambda: 2.0 }.log_prob(1, 10);
        assert!(lp.is_finite(), "Poisson log_prob(1, 10) = {lp}");
    }

    #[test]
    fn geometric_prob_decreasing() {
        let d = DurationDistrib::Geometric { p: 0.5 };
        let strictly_decreasing = (1..=5).all(|t| d.prob(t, 20) > d.prob(t + 1, 20));
        assert!(strictly_decreasing, "Geometric should be decreasing");
    }

    // ---- model construction -----------------------------------------------

    #[test]
    fn hsmm_new_validates_shapes() {
        // `pi` has one entry for a two-state model.
        let built = Hsmm::new(
            2,
            2,
            5,
            vec![1.0],
            vec![0.0, 1.0, 1.0, 0.0],
            vec![0.5, 0.5, 0.5, 0.5],
            vec![DurationDistrib::Geometric { p: 0.5 }; 2],
        );
        assert!(built.is_err());
    }

    #[test]
    fn hsmm_new_rejects_nonzero_diagonal() {
        // Rows of `a` sum to one, but the diagonal is 0.5 instead of 0.
        let built = Hsmm::new(
            2,
            2,
            5,
            vec![0.5, 0.5],
            vec![0.5, 0.5, 0.5, 0.5],
            vec![0.5, 0.5, 0.5, 0.5],
            vec![DurationDistrib::Geometric { p: 0.5 }; 2],
        );
        assert!(built.is_err());
    }

    // ---- likelihood and decoding ------------------------------------------

    #[test]
    fn log_likelihood_finite_for_valid_obs() {
        let ll = two_state_model()
            .log_likelihood(&[0, 1, 0, 1])
            .expect("should succeed");
        assert!(ll.is_finite(), "ll = {ll}");
    }

    #[test]
    fn log_likelihood_err_for_empty_obs() {
        let outcome = two_state_model().log_likelihood(&[]);
        assert!(outcome.is_err());
    }

    #[test]
    fn log_likelihood_err_for_obs_out_of_range() {
        let outcome = two_state_model().log_likelihood(&[0, 5]);
        assert!(outcome.is_err());
    }

    #[test]
    fn decode_returns_sequence_of_correct_length() {
        let obs = vec![0usize, 0, 1, 1, 0];
        let path = two_state_model().decode(&obs).expect("ok");
        assert_eq!(path.len(), obs.len());
    }

    #[test]
    fn decode_all_same_when_one_state_dominates() {
        // A single-state model can only ever decode to state 0.
        let m = Hsmm::new(
            1,
            2,
            3,
            vec![1.0],
            vec![0.0],
            vec![0.999, 0.001],
            vec![DurationDistrib::Geometric { p: 0.5 }],
        )
        .expect("ok");
        let obs = vec![0usize; 4];
        let path = m.decode(&obs).expect("ok");
        assert!(
            path.iter().all(|&s| s == 0),
            "expected all state 0, got {:?}",
            path
        );
    }

    // ---- EM fitting ---------------------------------------------------------

    #[test]
    fn hsm_fit_runs_without_error() {
        let obs = vec![0usize, 0, 1, 1, 0, 1, 0, 0, 1, 1];
        let outcome = hsm_fit(&[&obs], &HsmConfig::default());
        assert!(outcome.is_ok());
    }

    #[test]
    fn hsm_fit_ll_non_decreasing() {
        let obs = alternating(20);
        let cfg = HsmConfig {
            max_iter: 30,
            ..Default::default()
        };
        let r = hsm_fit(&[&obs], &cfg).expect("ok");
        for w in r.log_likelihood_history.windows(2) {
            assert!(w[1] >= w[0] - 1e-4, "LL decreased: {} → {}", w[0], w[1]);
        }
    }

    #[test]
    fn hsm_fit_converged_flag() {
        let obs = alternating(50);
        let cfg = HsmConfig {
            max_iter: 500,
            tol: 1e-3,
            ..Default::default()
        };
        let r = hsm_fit(&[&obs], &cfg).expect("ok");
        assert!(r.converged, "expected convergence");
    }

    #[test]
    fn hsm_fit_result_pi_sums_to_one() {
        let obs = vec![0usize, 1, 0, 1, 0, 0];
        let r = hsm_fit(&[&obs], &HsmConfig::default()).expect("ok");
        let s: f64 = r.model.pi.iter().sum();
        assert!((s - 1.0).abs() < 1e-9, "pi sums to {s}");
    }

    #[test]
    fn hsm_fit_result_b_rows_sum_to_one() {
        let obs = vec![0usize, 1, 0, 1, 0, 0];
        let cfg = HsmConfig::default();
        let r = hsm_fit(&[&obs], &cfg).expect("ok");
        let k = cfg.n_obs;
        for j in 0..cfg.n_states {
            let s: f64 = r.model.b[j * k..(j + 1) * k].iter().sum();
            assert!((s - 1.0).abs() < 1e-9, "B row {j} sums to {s}");
        }
    }

    #[test]
    fn hsm_fit_n_iter_within_max_iter() {
        let obs = vec![0usize, 1, 0, 1, 0, 0];
        let cfg = HsmConfig {
            max_iter: 10,
            ..Default::default()
        };
        let r = hsm_fit(&[&obs], &cfg).expect("ok");
        assert!(r.n_iter <= 10);
    }

    #[test]
    fn hsm_fit_multiple_sequences() {
        let s1 = vec![0usize, 0, 1, 1];
        let s2 = vec![1usize, 0, 1, 0, 0];
        let s3 = vec![0usize, 1, 1, 0, 1, 0];
        let outcome = hsm_fit(&[&s1, &s2, &s3], &HsmConfig::default());
        assert!(outcome.is_ok());
    }

    #[test]
    fn hsm_fit_short_sequence_length_one() {
        let obs = vec![0usize];
        let r = hsm_fit(&[&obs], &HsmConfig::default()).expect("length-1 sequence should work");
        assert!(!r.log_likelihood_history.is_empty());
    }

    #[test]
    fn hsm_fit_max_dur_one() {
        let obs = alternating(10);
        let cfg = HsmConfig {
            max_dur: 1,
            ..Default::default()
        };
        let r = hsm_fit(&[&obs], &cfg).expect("max_dur=1 should work");
        assert!(!r.log_likelihood_history.is_empty());
    }

    #[test]
    fn hsmm_a_rows_zero_diagonal() {
        let obs = alternating(10);
        let cfg = HsmConfig::default();
        let r = hsm_fit(&[&obs], &cfg).expect("ok");
        let n = cfg.n_states;
        for i in 0..n {
            let diag = r.model.a[i * n + i];
            assert!(diag.abs() < 1e-9, "diagonal A[{i},{i}] = {diag}");
        }
    }

    #[test]
    fn hsm_fit_empty_input_err() {
        let outcome = hsm_fit(&[], &HsmConfig::default());
        assert!(outcome.is_err());
    }
}