Skip to main content

oxicuda_seq/hmm/
semimarkov.rs

//! Hidden Semi-Markov Model (HSMM) with explicit duration distributions.
//!
//! Reference: Yu 2010, "Hidden semi-Markov models", Artificial Intelligence
//! 174:215–243, §2.2.
//!
//! Unlike a standard HMM (which forces geometric sojourn times), an HSMM
//! models an explicit duration distribution d_j(τ) = P(stay in state j for
//! exactly τ consecutive steps), evaluated for τ = 1..=D_max.  The transition
//! matrix A has zeros on the diagonal (self-transitions are absorbed into the
//! duration model); for more than one state, each row sums to 1 over the
//! off-diagonal entries.
//!
//! The implementation follows the *state-completion* forward-backward
//! formulation of Yu 2010, computed in log space:
//!
//!   e_j(t)  = P(o_1..o_t, a segment in state j completes at time t)
//!   f_j(t)  = P(o_{t+1}..o_T | a segment in state j completed at time t)
//!
//! The sequence likelihood is P(o_1..o_T) = Σ_j e_j(T): the final segment must
//! end exactly at T (no right-censoring).
17
18use crate::error::{SeqError, SeqResult};
19
20// ─── logsumexp helper ─────────────────────────────────────────────────────────
21
/// Numerically stable `ln(Σ_i exp(xs[i]))`.
///
/// The largest element is factored out before exponentiating so large magnitudes do not
/// overflow. Returns −∞ for an empty slice or when every element is −∞.
#[inline]
fn logsumexp(xs: &[f64]) -> f64 {
    let peak = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if peak == f64::NEG_INFINITY {
        return f64::NEG_INFINITY;
    }
    let mut acc = 0.0_f64;
    for &x in xs {
        acc += (x - peak).exp();
    }
    peak + acc.ln()
}
31
/// Natural log that maps anything without a finite logarithm to −∞.
///
/// Zero, negative, NaN and ±∞ inputs all yield `f64::NEG_INFINITY`; positive finite
/// inputs yield `x.ln()`.
#[inline]
fn log_safe(x: f64) -> f64 {
    match x {
        v if v > 0.0 && v.is_finite() => v.ln(),
        _ => f64::NEG_INFINITY,
    }
}
40
41// ─── Duration distributions ───────────────────────────────────────────────────
42
43/// Duration distribution for a single state.
44///
45/// All distributions are truncated to `max_dur` in the forward-backward
46/// algorithm; `max_dur` is passed as an argument to `prob` / `log_prob` to
47/// facilitate normalisation where needed.
48#[derive(Debug, Clone)]
49pub enum DurationDistrib {
50    /// Poisson duration: P(dur = τ) ∝ λ^τ exp(−λ) / τ!, τ ≥ 1.
51    /// Truncated and renormalised over τ = 1..max_dur.
52    Poisson { lambda: f64 },
53    /// Geometric: P(dur = τ) = p · (1−p)^{τ−1}, τ ≥ 1.
54    Geometric { p: f64 },
55    /// Histogram: `probs[τ−1]` = P(dur = τ), τ = 1..len.
56    Histogram { probs: Vec<f64> },
57}
58
59impl DurationDistrib {
60    /// Probability of duration `tau` (≥ 1).
61    ///
62    /// For `Poisson`, the distribution is truncated and renormalised to the
63    /// range τ = 1..`max_dur` so that probabilities sum to 1 within the
64    /// algorithm's horizon.
65    pub fn prob(&self, tau: usize, max_dur: usize) -> f64 {
66        if tau == 0 {
67            return 0.0;
68        }
69        match self {
70            DurationDistrib::Poisson { lambda } => {
71                if *lambda <= 0.0 {
72                    return if tau == 1 { 1.0 } else { 0.0 };
73                }
74                // Compute unnormalised Poisson pmf at tau; normalise over 1..max_dur.
75                let raw = poisson_pmf(*lambda, tau);
76                let total: f64 = (1..=max_dur).map(|d| poisson_pmf(*lambda, d)).sum();
77                if total > 0.0 { raw / total } else { 0.0 }
78            }
79            DurationDistrib::Geometric { p } => {
80                let p = p.clamp(0.0, 1.0);
81                p * (1.0 - p).powi((tau - 1) as i32)
82            }
83            DurationDistrib::Histogram { probs } => {
84                if tau <= probs.len() {
85                    probs[tau - 1]
86                } else {
87                    0.0
88                }
89            }
90        }
91    }
92
93    /// Log probability of duration `tau`; returns −∞ if `prob` = 0.
94    pub fn log_prob(&self, tau: usize, max_dur: usize) -> f64 {
95        log_safe(self.prob(tau, max_dur))
96    }
97}
98
99/// Unnormalised Poisson pmf at non-negative integer k.
100fn poisson_pmf(lambda: f64, k: usize) -> f64 {
101    if lambda <= 0.0 {
102        return if k == 0 { 1.0 } else { 0.0 };
103    }
104    // Compute in log space then exponentiate to avoid overflow for large k.
105    let log_p = (k as f64) * lambda.ln() - lambda - log_factorial(k);
106    log_p.exp()
107}
108
109/// ln(k!) computed via Stirling / iterative sum for k ≤ ~20, else lgamma.
110fn log_factorial(k: usize) -> f64 {
111    if k <= 1 {
112        return 0.0;
113    }
114    (1..=k).map(|i| (i as f64).ln()).sum()
115}
116
117// ─── Model struct ─────────────────────────────────────────────────────────────
118
119/// Hidden Semi-Markov Model with explicit per-state duration distributions.
120#[derive(Debug, Clone)]
121pub struct Hsmm {
122    /// Number of hidden states.
123    pub n_states: usize,
124    /// Number of distinct observation symbols.
125    pub n_obs: usize,
126    /// Maximum segment duration (D_max).
127    pub max_dur: usize,
128    /// Initial state probability vector (n_states,).
129    pub pi: Vec<f64>,
130    /// Transition matrix without self-transitions, row-major (n_states × n_states).
131    /// Diagonal entries must be 0; each row sums to 1.
132    pub a: Vec<f64>,
133    /// Emission matrix, row-major (n_states × n_obs).  Each row sums to 1.
134    pub b: Vec<f64>,
135    /// Duration distribution for each state.
136    pub dur: Vec<DurationDistrib>,
137}
138
139impl Hsmm {
140    /// Construct and validate an HSMM.
141    pub fn new(
142        n_states: usize,
143        n_obs: usize,
144        max_dur: usize,
145        pi: Vec<f64>,
146        a: Vec<f64>,
147        b: Vec<f64>,
148        dur: Vec<DurationDistrib>,
149    ) -> SeqResult<Self> {
150        if n_states == 0 || n_obs == 0 || max_dur == 0 {
151            return Err(SeqError::InvalidConfiguration(
152                "n_states, n_obs, and max_dur must all be > 0".to_string(),
153            ));
154        }
155        if pi.len() != n_states {
156            return Err(SeqError::ShapeMismatch {
157                expected: n_states,
158                got: pi.len(),
159            });
160        }
161        if a.len() != n_states * n_states {
162            return Err(SeqError::ShapeMismatch {
163                expected: n_states * n_states,
164                got: a.len(),
165            });
166        }
167        if b.len() != n_states * n_obs {
168            return Err(SeqError::ShapeMismatch {
169                expected: n_states * n_obs,
170                got: b.len(),
171            });
172        }
173        if dur.len() != n_states {
174            return Err(SeqError::ShapeMismatch {
175                expected: n_states,
176                got: dur.len(),
177            });
178        }
179
180        // Validate diagonal = 0
181        for i in 0..n_states {
182            let diag = a[i * n_states + i];
183            if diag.abs() > 1e-9 {
184                return Err(SeqError::InvalidConfiguration(format!(
185                    "transition matrix diagonal A[{i},{i}] = {diag} must be 0"
186                )));
187            }
188        }
189
190        // Validate π sums to ~1.
191        let pi_sum: f64 = pi.iter().sum();
192        if (pi_sum - 1.0).abs() > 1e-5 {
193            return Err(SeqError::InvalidConfiguration(format!(
194                "pi sums to {pi_sum}, expected 1"
195            )));
196        }
197        // Validate A rows (only for n_states > 1; single-state has all-zero row).
198        if n_states > 1 {
199            for i in 0..n_states {
200                let s: f64 = a[i * n_states..(i + 1) * n_states].iter().sum();
201                if (s - 1.0).abs() > 1e-5 {
202                    return Err(SeqError::InvalidConfiguration(format!(
203                        "A row {i} sums to {s}, expected 1"
204                    )));
205                }
206            }
207        }
208        // Validate B rows.
209        for j in 0..n_states {
210            let s: f64 = b[j * n_obs..(j + 1) * n_obs].iter().sum();
211            if (s - 1.0).abs() > 1e-5 {
212                return Err(SeqError::InvalidConfiguration(format!(
213                    "B row {j} sums to {s}, expected 1"
214                )));
215            }
216        }
217
218        Ok(Self {
219            n_states,
220            n_obs,
221            max_dur,
222            pi,
223            a,
224            b,
225            dur,
226        })
227    }
228
229    /// Compute log P(o₁..o_T | model) using the HSMM forward algorithm.
230    pub fn log_likelihood(&self, obs: &[usize]) -> SeqResult<f64> {
231        if obs.is_empty() {
232            return Err(SeqError::EmptyInput);
233        }
234        for &o in obs {
235            if o >= self.n_obs {
236                return Err(SeqError::InvalidObservation(format!(
237                    "observation {o} >= n_obs {}",
238                    self.n_obs
239                )));
240            }
241        }
242        let (_, log_z) = hsmm_forward(self, obs);
243        Ok(log_z)
244    }
245
246    /// Viterbi decoding: return the most likely state sequence.
247    pub fn decode(&self, obs: &[usize]) -> SeqResult<Vec<usize>> {
248        if obs.is_empty() {
249            return Err(SeqError::EmptyInput);
250        }
251        for &o in obs {
252            if o >= self.n_obs {
253                return Err(SeqError::InvalidObservation(format!(
254                    "observation {o} >= n_obs {}",
255                    self.n_obs
256                )));
257            }
258        }
259        hsmm_viterbi(self, obs)
260    }
261}
262
263// ─── Pre-compute cumulative log-emission sums ──────────────────────────────────
264
265/// Build cumulative log-emission sums for each state and each time range.
266///
267/// `cum_log_b[j * (t_max+1) + t]` = Σ_{u=0}^{t-1} log B_j(o_u)
268/// so that the sum over t1..=t2 (0-indexed) = cum[t2+1] − cum[t1].
269fn build_cum_log_b(model: &Hsmm, obs: &[usize]) -> Vec<f64> {
270    let t_max = obs.len();
271    let n = model.n_states;
272    // Shape: n × (t_max+1)
273    let mut cum = vec![0.0f64; n * (t_max + 1)];
274    for j in 0..n {
275        for t in 0..t_max {
276            let le = log_safe(model.b[j * model.n_obs + obs[t]]);
277            cum[j * (t_max + 1) + t + 1] = cum[j * (t_max + 1) + t] + le;
278        }
279    }
280    cum
281}
282
283/// Log emission of state j for the segment obs[t1..=t2] (0-indexed, inclusive).
284#[inline]
285fn seg_log_em(cum: &[f64], j: usize, t1: usize, t2: usize, t_max: usize) -> f64 {
286    cum[j * (t_max + 1) + t2 + 1] - cum[j * (t_max + 1) + t1]
287}
288
289// ─── HSMM forward algorithm ───────────────────────────────────────────────────
290
291/// Compute the HSMM forward variables using the state-completion form.
292///
293/// Returns `(log_e, log_z)` where:
294///   `log_e[j * t_max + t]` = log e_j(t+1)  (1-indexed time, 0-indexed array)
295///   `log_z` = log p(o_1..o_T | model)
296fn hsmm_forward(model: &Hsmm, obs: &[usize]) -> (Vec<f64>, f64) {
297    let t_max = obs.len();
298    let n = model.n_states;
299    let d_max = model.max_dur.min(t_max);
300    let cum = build_cum_log_b(model, obs);
301
302    // log_e[j, t] = log e_j(t+1) where t is 0-based (t+1 is the 1-based time
303    // at which a segment in state j completes).
304    let mut log_e = vec![f64::NEG_INFINITY; n * t_max];
305
306    // log_pi[j] = log π_j
307    let log_pi: Vec<f64> = model.pi.iter().map(|&p| log_safe(p)).collect();
308
309    // log_a[i, j] = log A_{ij} (i→j off-diagonal)
310    let log_a: Vec<f64> = model.a.iter().map(|&v| log_safe(v)).collect();
311
312    // We buffer logsumexp terms.
313    let mut terms = Vec::with_capacity(d_max * n);
314
315    for t in 0..t_max {
316        // t is the 0-based end-time index; the 1-based version is t+1.
317        for j in 0..n {
318            terms.clear();
319            let d_end = (t + 1).min(d_max); // max segment duration ending at t
320            for d in 1..=d_end {
321                // Segment spans obs[t-d+1 .. t] (0-based, inclusive), length d.
322                let t_start = t + 1 - d; // 0-based start index
323                let log_dur = model.dur[j].log_prob(d, model.max_dur);
324                let log_em_seg = seg_log_em(&cum, j, t_start, t, t_max);
325
326                // Compute init term: log π_j (if d = t+1, i.e. segment starts at 0)
327                //                    + logsumexp_{i≠j}(log A_{ij} + log e_i(t-d))
328                let log_init = if d == t + 1 {
329                    // Segment starts at the very beginning of the sequence.
330                    log_pi[j]
331                } else {
332                    // t_start > 0 → must have come from a previous segment ending at t_start-1.
333                    let prev_t = t_start - 1; // 0-based
334                    let mut trans_terms: Vec<f64> = Vec::with_capacity(n);
335                    for i in 0..n {
336                        if i == j {
337                            continue;
338                        }
339                        let log_prev = log_e[i * t_max + prev_t];
340                        if log_prev == f64::NEG_INFINITY {
341                            continue;
342                        }
343                        trans_terms.push(log_a[i * n + j] + log_prev);
344                    }
345                    if trans_terms.is_empty() {
346                        f64::NEG_INFINITY
347                    } else {
348                        logsumexp(&trans_terms)
349                    }
350                };
351
352                if log_init == f64::NEG_INFINITY || log_dur == f64::NEG_INFINITY {
353                    continue;
354                }
355                terms.push(log_dur + log_em_seg + log_init);
356            }
357            log_e[j * t_max + t] = if terms.is_empty() {
358                f64::NEG_INFINITY
359            } else {
360                logsumexp(&terms)
361            };
362        }
363    }
364
365    // log Z = logsumexp_j log_e_j(T)  (last time step, 0-based index t_max-1)
366    let last_terms: Vec<f64> = (0..n).map(|j| log_e[j * t_max + t_max - 1]).collect();
367    let log_z = logsumexp(&last_terms);
368
369    (log_e, log_z)
370}
371
372// ─── HSMM backward algorithm ─────────────────────────────────────────────────
373
374/// Compute HSMM backward variables.
375///
376/// `log_f[j, t]` = log f_j(t+1) (probability of future observations given a
377/// segment in state j completed at 1-based time t+1).
378///
379/// Boundary: log_f[j, T-1] = 0 for all j (1-based T = t_max).
380fn hsmm_backward(model: &Hsmm, obs: &[usize], cum: &[f64]) -> Vec<f64> {
381    let t_max = obs.len();
382    let n = model.n_states;
383    let d_max = model.max_dur;
384
385    let log_a: Vec<f64> = model.a.iter().map(|&v| log_safe(v)).collect();
386
387    // log_f[j * t_max + t] = log f_j(t+1)
388    let mut log_f = vec![f64::NEG_INFINITY; n * t_max];
389
390    // Boundary: f_j(T) = 1 → log = 0.
391    for j in 0..n {
392        log_f[j * t_max + t_max - 1] = 0.0;
393    }
394
395    // Recursion: for t = T-2 down to 0 (0-based).
396    // f_j(t) = Σ_{k≠j} Σ_{d=1}^{min(D, T-(t+1))} A_{jk} * dur_k(d) * Π_em_k(t+1..t+d) * f_k(t+d)
397    // Note: t+1 to t+d are 0-based indices t+1-1..t+d-1 = t..t+d-1.
398    let mut terms: Vec<f64> = Vec::with_capacity(d_max * n);
399
400    for t in (0..t_max.saturating_sub(1)).rev() {
401        for j in 0..n {
402            terms.clear();
403            // Remaining time steps after t: from t+1 (0-based) to t_max-1.
404            let remaining = t_max - t - 1; // number of steps after t
405            let d_end = d_max.min(remaining);
406
407            for k in 0..n {
408                if k == j {
409                    continue;
410                }
411                let log_ajk = log_a[j * n + k];
412                if log_ajk == f64::NEG_INFINITY {
413                    continue;
414                }
415                for d in 1..=d_end {
416                    // Segment in k spans obs[t+1 .. t+d] (0-based), i.e. t_start=t+1, t_end=t+d.
417                    let t_start_new = t + 1;
418                    let t_end_new = t + d; // 0-based inclusive
419                    let log_dur = model.dur[k].log_prob(d, model.max_dur);
420                    let log_em_seg = seg_log_em(cum, k, t_start_new, t_end_new, t_max);
421                    let log_f_next = log_f[k * t_max + t_end_new];
422                    if log_dur == f64::NEG_INFINITY || log_f_next == f64::NEG_INFINITY {
423                        continue;
424                    }
425                    terms.push(log_ajk + log_dur + log_em_seg + log_f_next);
426                }
427            }
428
429            log_f[j * t_max + t] = if terms.is_empty() {
430                f64::NEG_INFINITY
431            } else {
432                logsumexp(&terms)
433            };
434        }
435    }
436
437    log_f
438}
439
440// ─── HSMM Viterbi ─────────────────────────────────────────────────────────────
441
442/// HSMM Viterbi: max-product version of the forward algorithm.
443/// Returns the most-likely state sequence of length T.
444fn hsmm_viterbi(model: &Hsmm, obs: &[usize]) -> SeqResult<Vec<usize>> {
445    let t_max = obs.len();
446    let n = model.n_states;
447    let d_max = model.max_dur.min(t_max);
448    let cum = build_cum_log_b(model, obs);
449
450    let log_pi: Vec<f64> = model.pi.iter().map(|&p| log_safe(p)).collect();
451    let log_a: Vec<f64> = model.a.iter().map(|&v| log_safe(v)).collect();
452
453    // v[j, t] = max log-prob of a path in which the segment containing time t
454    //           ends at t in state j.
455    let mut log_v = vec![f64::NEG_INFINITY; n * t_max];
456
457    // Backpointer: for each (j, t) the best (d, prev_state) that achieves log_v[j,t].
458    // We store (d, prev_j) where d = segment duration, prev_j = previous state (n if init).
459    let mut bp_d = vec![0usize; n * t_max];
460    let mut bp_prev = vec![n; n * t_max]; // n means "start of sequence"
461
462    for t in 0..t_max {
463        for j in 0..n {
464            let d_end = (t + 1).min(d_max);
465            let mut best_val = f64::NEG_INFINITY;
466            let mut best_d = 1;
467            let mut best_prev = n; // sentinel = sequence start
468
469            for d in 1..=d_end {
470                let t_start = t + 1 - d;
471                let log_dur = model.dur[j].log_prob(d, model.max_dur);
472                let log_em_seg = seg_log_em(&cum, j, t_start, t, t_max);
473
474                if log_dur == f64::NEG_INFINITY {
475                    continue;
476                }
477
478                let log_seg_cost = log_dur + log_em_seg;
479
480                if d == t + 1 {
481                    // Segment starts at beginning.
482                    let v = log_seg_cost + log_pi[j];
483                    if v > best_val {
484                        best_val = v;
485                        best_d = d;
486                        best_prev = n; // sentinel
487                    }
488                } else {
489                    let prev_t = t_start - 1;
490                    for i in 0..n {
491                        if i == j {
492                            continue;
493                        }
494                        let log_prev_v = log_v[i * t_max + prev_t];
495                        if log_prev_v == f64::NEG_INFINITY {
496                            continue;
497                        }
498                        let v = log_seg_cost + log_a[i * n + j] + log_prev_v;
499                        if v > best_val {
500                            best_val = v;
501                            best_d = d;
502                            best_prev = i;
503                        }
504                    }
505                }
506            }
507
508            log_v[j * t_max + t] = best_val;
509            bp_d[j * t_max + t] = best_d;
510            bp_prev[j * t_max + t] = best_prev;
511        }
512    }
513
514    // Termination: find best final state.
515    let last_t = t_max - 1;
516    let mut best_final = f64::NEG_INFINITY;
517    let mut best_j = 0;
518    for j in 0..n {
519        let v = log_v[j * t_max + last_t];
520        if v > best_final {
521            best_final = v;
522            best_j = j;
523        }
524    }
525
526    if best_final == f64::NEG_INFINITY {
527        // All paths have zero probability; fall back to uniform state 0.
528        return Ok(vec![0usize; t_max]);
529    }
530
531    // Backtrack to fill the state path.
532    let mut path = vec![0usize; t_max];
533    let mut cur_t = last_t as isize;
534    let mut cur_j = best_j;
535
536    while cur_t >= 0 {
537        let t = cur_t as usize;
538        let d = bp_d[cur_j * t_max + t];
539        let t_start = t + 1 - d;
540
541        // Fill states t_start..=t with cur_j.
542        for u in t_start..=t {
543            path[u] = cur_j;
544        }
545
546        if t_start == 0 {
547            break;
548        }
549        let prev_state = bp_prev[cur_j * t_max + t];
550        if prev_state == n {
551            // Sentinel: segment starts at sequence start.
552            break;
553        }
554        cur_t = (t_start as isize) - 1;
555        cur_j = prev_state;
556    }
557
558    Ok(path)
559}
560
561// ─── Configuration & result types ─────────────────────────────────────────────
562
563/// Configuration for HSMM EM training.
564#[derive(Debug, Clone)]
565pub struct HsmConfig {
566    /// Number of hidden states.
567    pub n_states: usize,
568    /// Number of distinct observation symbols.
569    pub n_obs: usize,
570    /// Maximum segment duration D_max (default 10).
571    pub max_dur: usize,
572    /// Maximum EM iterations (default 100).
573    pub max_iter: usize,
574    /// Convergence tolerance on log-likelihood change (default 1e-5).
575    pub tol: f64,
576}
577
578impl Default for HsmConfig {
579    fn default() -> Self {
580        Self {
581            n_states: 2,
582            n_obs: 2,
583            max_dur: 10,
584            max_iter: 100,
585            tol: 1e-5,
586        }
587    }
588}
589
590/// Result of HSMM EM fitting.
591#[derive(Debug, Clone)]
592pub struct HsmResult {
593    /// Fitted HSMM model.
594    pub model: Hsmm,
595    /// Log-likelihood at each iteration.
596    pub log_likelihood_history: Vec<f64>,
597    /// Number of EM iterations executed.
598    pub n_iter: usize,
599    /// Whether the algorithm converged within the tolerance.
600    pub converged: bool,
601}
602
603// ─── EM helper: build initial model ──────────────────────────────────────────
604
605fn build_initial_model(cfg: &HsmConfig) -> SeqResult<Hsmm> {
606    let n = cfg.n_states;
607    let k = cfg.n_obs;
608    let d_max = cfg.max_dur;
609
610    // π: uniform.
611    let pi: Vec<f64> = vec![1.0 / n as f64; n];
612
613    // A: uniform off-diagonal (diagonal = 0).
614    let mut a = vec![0.0f64; n * n];
615    if n > 1 {
616        for i in 0..n {
617            for j in 0..n {
618                a[i * n + j] = if i == j { 0.0 } else { 1.0 / (n as f64 - 1.0) };
619            }
620        }
621    }
622
623    // B: state-indexed perturbation so states are distinguishable.
624    let mut b = vec![0.0f64; n * k];
625    for j in 0..n {
626        let mut row_sum = 0.0f64;
627        for sym in 0..k {
628            // Assign slightly higher weight to symbol = (j % k) for state j.
629            let base = 1.0 / k as f64;
630            let bump = if sym == j % k { 0.2 / k as f64 } else { 0.0 };
631            b[j * k + sym] = base + bump;
632            row_sum += b[j * k + sym];
633        }
634        // Normalise row.
635        for sym in 0..k {
636            b[j * k + sym] /= row_sum;
637        }
638    }
639
640    // Duration: Geometric with p = 1 / max_dur for each state.
641    let p = 1.0 / d_max.max(1) as f64;
642    let dur: Vec<DurationDistrib> = (0..n).map(|_| DurationDistrib::Geometric { p }).collect();
643
644    Hsmm::new(n, k, d_max, pi, a, b, dur)
645}
646
647// ─── Main EM entry point ───────────────────────────────────────────────────────
648
649/// Fit an HSMM by EM on one or more observation sequences.
650pub fn hsm_fit(observations: &[&[usize]], cfg: &HsmConfig) -> SeqResult<HsmResult> {
651    if observations.is_empty() || observations.iter().all(|s| s.is_empty()) {
652        return Err(SeqError::EmptyInput);
653    }
654    if cfg.n_states == 0 || cfg.n_obs == 0 || cfg.max_dur == 0 {
655        return Err(SeqError::InvalidConfiguration(
656            "n_states, n_obs, and max_dur must all be > 0".to_string(),
657        ));
658    }
659    for seq in observations.iter() {
660        for &o in *seq {
661            if o >= cfg.n_obs {
662                return Err(SeqError::InvalidObservation(format!(
663                    "observation {o} >= n_obs {}",
664                    cfg.n_obs
665                )));
666            }
667        }
668    }
669
670    let n = cfg.n_states;
671    let k = cfg.n_obs;
672    let d_max = cfg.max_dur;
673
674    let mut model = build_initial_model(cfg)?;
675    let mut history: Vec<f64> = Vec::with_capacity(cfg.max_iter + 1);
676    let mut prev_ll = f64::NEG_INFINITY;
677    let mut converged = false;
678    let mut n_iter = 0usize;
679
680    for iter in 0..cfg.max_iter {
681        n_iter = iter + 1;
682
683        // ── E-step: collect sufficient statistics ──
684        // For π: Σ γ_j(segment starts at t=0 and ends at t=d-1)
685        let mut ss_pi = vec![0.0f64; n];
686        // For A: Σ_{t,d} γ_i→j (i≠j) transition at time t.
687        let mut ss_a = vec![0.0f64; n * n];
688        // For B: Σ_{t: o_t=sym} occupancy of state j at time t.
689        let mut ss_b = vec![0.0f64; n * k];
690        // For duration: ss_dur[j * (d_max+1) + d] = Σ expected # of segs of length d in state j.
691        let mut ss_dur = vec![0.0f64; n * (d_max + 1)];
692
693        let mut total_ll = 0.0f64;
694
695        for seq in observations.iter() {
696            if seq.is_empty() {
697                continue;
698            }
699            let t_max = seq.len();
700            let cum = build_cum_log_b(&model, seq);
701            let (log_e, log_z) = hsmm_forward(&model, seq);
702            let log_f = hsmm_backward(&model, seq, &cum);
703
704            if !log_z.is_finite() {
705                // Sequence had probability 0 under current model; skip accumulation.
706                continue;
707            }
708
709            total_ll += log_z;
710
711            // Compute segment posteriors γ_{j,t_start,d}.
712            // P(segment j, t_start, d | obs) ∝ init_or_trans * dur_j(d) * em_j * f_j(t_end)
713            //
714            // We iterate over all possible segments (j, t_start, d).
715            let log_pi_v: Vec<f64> = model.pi.iter().map(|&p| log_safe(p)).collect();
716            let log_a_v: Vec<f64> = model.a.iter().map(|&v| log_safe(v)).collect();
717
718            for j in 0..n {
719                for t_end in 0..t_max {
720                    for d in 1..=(t_end + 1).min(d_max) {
721                        let t_start = t_end + 1 - d;
722                        let log_dur = model.dur[j].log_prob(d, d_max);
723                        if log_dur == f64::NEG_INFINITY {
724                            continue;
725                        }
726                        let log_em_seg = seg_log_em(&cum, j, t_start, t_end, t_max);
727
728                        // Init term: from log_e perspective, e_j(t_end) = sum over d of terms;
729                        // here we need the "slice" contribution.
730                        let log_init = if t_start == 0 {
731                            log_pi_v[j]
732                        } else {
733                            let prev_t = t_start - 1;
734                            let mut terms: Vec<f64> = Vec::with_capacity(n);
735                            for i in 0..n {
736                                if i == j {
737                                    continue;
738                                }
739                                let lv = log_e[i * t_max + prev_t];
740                                if lv == f64::NEG_INFINITY {
741                                    continue;
742                                }
743                                terms.push(log_a_v[i * n + j] + lv);
744                            }
745                            if terms.is_empty() {
746                                f64::NEG_INFINITY
747                            } else {
748                                logsumexp(&terms)
749                            }
750                        };
751
752                        if log_init == f64::NEG_INFINITY {
753                            continue;
754                        }
755
756                        let log_f_val = log_f[j * t_max + t_end];
757                        if log_f_val == f64::NEG_INFINITY {
758                            continue;
759                        }
760
761                        let log_gamma_seg = log_init + log_dur + log_em_seg + log_f_val - log_z;
762                        let gamma_seg = log_gamma_seg.exp();
763
764                        if !gamma_seg.is_finite() || gamma_seg <= 0.0 {
765                            continue;
766                        }
767
768                        // Accumulate π.
769                        if t_start == 0 {
770                            ss_pi[j] += gamma_seg;
771                        }
772
773                        // Accumulate duration.
774                        ss_dur[j * (d_max + 1) + d] += gamma_seg;
775
776                        // Accumulate emission (each time step in the segment).
777                        for u in t_start..=t_end {
778                            ss_b[j * k + seq[u]] += gamma_seg;
779                        }
780
781                        // Accumulate transitions from i to j (for t_start > 0).
782                        if t_start > 0 {
783                            let prev_t = t_start - 1;
784                            for i in 0..n {
785                                if i == j {
786                                    continue;
787                                }
788                                let lv = log_e[i * t_max + prev_t];
789                                if lv == f64::NEG_INFINITY {
790                                    continue;
791                                }
792                                let log_xi =
793                                    log_a_v[i * n + j] + lv + log_dur + log_em_seg + log_f_val
794                                        - log_z;
795                                let xi_val = log_xi.exp();
796                                if xi_val.is_finite() && xi_val > 0.0 {
797                                    ss_a[i * n + j] += xi_val;
798                                }
799                            }
800                        }
801                    }
802                }
803            }
804        }
805
806        history.push(total_ll);
807
808        // Convergence check.
809        if iter > 0 && (total_ll - prev_ll).abs() < cfg.tol {
810            converged = true;
811            break;
812        }
813        prev_ll = total_ll;
814
815        // ── M-step ──
816
817        // Update π.
818        let pi_sum: f64 = ss_pi.iter().sum();
819        let new_pi: Vec<f64> = if pi_sum > 0.0 {
820            ss_pi.iter().map(|&v| v / pi_sum).collect()
821        } else {
822            vec![1.0 / n as f64; n]
823        };
824
825        // Update A (normalise each row, zero diagonal).
826        let mut new_a = vec![0.0f64; n * n];
827        if n > 1 {
828            for i in 0..n {
829                let row_sum: f64 = ss_a[i * n..(i + 1) * n].iter().sum();
830                for j in 0..n {
831                    if i == j {
832                        new_a[i * n + j] = 0.0;
833                    } else {
834                        new_a[i * n + j] = if row_sum > 0.0 {
835                            ss_a[i * n + j] / row_sum
836                        } else {
837                            1.0 / (n as f64 - 1.0)
838                        };
839                    }
840                }
841            }
842        }
843
844        // Update B (normalise each row).
845        let mut new_b = vec![0.0f64; n * k];
846        for j in 0..n {
847            let row_sum: f64 = ss_b[j * k..(j + 1) * k].iter().sum();
848            for sym in 0..k {
849                new_b[j * k + sym] = if row_sum > 0.0 {
850                    ss_b[j * k + sym] / row_sum
851                } else {
852                    1.0 / k as f64
853                };
854            }
855        }
856
857        // Update duration distributions (histogram form).
858        let mut new_dur: Vec<DurationDistrib> = Vec::with_capacity(n);
859        for j in 0..n {
860            let total: f64 = ss_dur[j * (d_max + 1) + 1..=(j * (d_max + 1) + d_max)]
861                .iter()
862                .sum();
863            let probs: Vec<f64> = if total > 0.0 {
864                (1..=d_max)
865                    .map(|d| ss_dur[j * (d_max + 1) + d] / total)
866                    .collect()
867            } else {
868                // Fall back to geometric if no data.
869                let p = 1.0 / d_max as f64;
870                (1..=d_max)
871                    .map(|d| {
872                        let geo = DurationDistrib::Geometric { p };
873                        geo.prob(d, d_max)
874                    })
875                    .collect()
876            };
877            new_dur.push(DurationDistrib::Histogram { probs });
878        }
879
880        // Install updated model (using unchecked path for diagonal since we built it correctly).
881        model = Hsmm {
882            n_states: n,
883            n_obs: k,
884            max_dur: d_max,
885            pi: new_pi,
886            a: new_a,
887            b: new_b,
888            dur: new_dur,
889        };
890    }
891
892    Ok(HsmResult {
893        model,
894        log_likelihood_history: history,
895        n_iter,
896        converged,
897    })
898}
899
900// ─── Tests ────────────────────────────────────────────────────────────────────
901
#[cfg(test)]
mod tests {
    // Unit tests for the duration distributions, `Hsmm` construction and
    // validation, likelihood evaluation, decoding, and the `hsm_fit` EM driver.
    use super::*;

    /// Builds the alternating observation sequence `0, 1, 0, 1, …` of length `len`.
    fn alternating(len: usize) -> Vec<usize> {
        (0..len).map(|i| i % 2).collect()
    }

    // ── DurationDistrib tests ──────────────────────────────────────────────────

    #[test]
    fn poisson_probs_sum_to_one() {
        // The Poisson variant is documented as truncated and renormalised over
        // τ = 1..=max_dur, so the truncated pmf must sum to one.
        let d = DurationDistrib::Poisson { lambda: 3.0 };
        let s: f64 = (1..=20).map(|t| d.prob(t, 20)).sum();
        assert!((s - 1.0).abs() < 1e-9, "Poisson prob sum = {s}");
    }

    #[test]
    fn geometric_probs_approx_one() {
        let d = DurationDistrib::Geometric { p: 0.3 };
        // With p = 0.3, the mass beyond τ = 1000 is 0.7^1000, which is far
        // below f64 resolution. So the sum over 1..=1000 equals 1 to within
        // 1e-9, whether or not `prob` renormalises.
        let s: f64 = (1..=1000).map(|t| d.prob(t, 1000)).sum();
        assert!((s - 1.0).abs() < 1e-9, "Geometric prob sum = {s}");
    }

    #[test]
    fn histogram_probs_sum_to_one() {
        // A histogram whose bins already sum to one must yield a pmf summing to one.
        let probs = vec![0.2, 0.5, 0.3];
        let d = DurationDistrib::Histogram { probs };
        let s: f64 = (1..=3).map(|t| d.prob(t, 3)).sum();
        assert!((s - 1.0).abs() < 1e-9, "Histogram prob sum = {s}");
    }

    #[test]
    fn poisson_log_prob_finite_for_positive_lambda() {
        let d = DurationDistrib::Poisson { lambda: 2.0 };
        let lp = d.log_prob(1, 10);
        assert!(lp.is_finite(), "Poisson log_prob(1, 10) = {lp}");
    }

    #[test]
    fn geometric_prob_decreasing() {
        // p · (1−p)^{τ−1} is strictly decreasing in τ for 0 < p < 1.
        let d = DurationDistrib::Geometric { p: 0.5 };
        for t in 1..=5 {
            assert!(
                d.prob(t, 20) > d.prob(t + 1, 20),
                "Geometric should be decreasing"
            );
        }
    }

    // ── Hsmm construction tests ────────────────────────────────────────────────

    /// A valid 2-state, 2-symbol model with `max_dur = 5`.
    ///
    /// - Initial distribution is uniform.
    /// - A has a zero diagonal, so segments must alternate between the two states.
    /// - Emissions favour symbol `j` in state `j` (0.9 vs 0.1).
    /// - Both states use Geometric durations with p = 0.3.
    fn two_state_model() -> Hsmm {
        Hsmm::new(
            2,
            2,
            5,
            vec![0.5, 0.5],
            vec![0.0, 1.0, 1.0, 0.0],
            vec![0.9, 0.1, 0.1, 0.9],
            vec![
                DurationDistrib::Geometric { p: 0.3 },
                DurationDistrib::Geometric { p: 0.3 },
            ],
        )
        .expect("valid model")
    }

    #[test]
    fn hsmm_new_validates_shapes() {
        // pi has length 1 but n_states is 2.
        assert!(
            Hsmm::new(
                2,
                2,
                5,
                vec![1.0],
                vec![0.0, 1.0, 1.0, 0.0],
                vec![0.5, 0.5, 0.5, 0.5],
                vec![DurationDistrib::Geometric { p: 0.5 }; 2],
            )
            .is_err()
        );
    }

    #[test]
    fn hsmm_new_rejects_nonzero_diagonal() {
        // Self-transitions belong to the duration model, so a non-zero A
        // diagonal must be rejected even though each row sums to one.
        assert!(
            Hsmm::new(
                2,
                2,
                5,
                vec![0.5, 0.5],
                vec![0.5, 0.5, 0.5, 0.5], // diagonal = 0.5 (invalid)
                vec![0.5, 0.5, 0.5, 0.5],
                vec![DurationDistrib::Geometric { p: 0.5 }; 2],
            )
            .is_err()
        );
    }

    #[test]
    fn log_likelihood_finite_for_valid_obs() {
        let m = two_state_model();
        let ll = m.log_likelihood(&[0, 1, 0, 1]).expect("should succeed");
        assert!(ll.is_finite(), "ll = {ll}");
    }

    #[test]
    fn log_likelihood_err_for_empty_obs() {
        let m = two_state_model();
        assert!(m.log_likelihood(&[]).is_err());
    }

    #[test]
    fn log_likelihood_err_for_obs_out_of_range() {
        // Symbol 5 is outside the model's 2-symbol alphabet.
        let m = two_state_model();
        assert!(m.log_likelihood(&[0, 5]).is_err());
    }

    #[test]
    fn decode_returns_sequence_of_correct_length() {
        let m = two_state_model();
        let obs = vec![0usize, 0, 1, 1, 0];
        let path = m.decode(&obs).expect("ok");
        assert_eq!(path.len(), obs.len());
    }

    #[test]
    fn decode_all_same_when_one_state_dominates() {
        // Single-state edge case. A must have a zero diagonal, so the only
        // valid 1×1 transition matrix is [0.0]: no transition can ever fire,
        // and the whole sequence has to be explained by one segment in state
        // 0. The decoded path is therefore all zeros by construction. This
        // checks that `decode` handles `n_states = 1` and a sequence shorter
        // than `max_dur`, not that emissions "dominate" the decision.
        let m = Hsmm::new(
            1, // single state
            2,
            3,
            vec![1.0],          // pi
            vec![0.0],          // A (1×1, diagonal=0)
            vec![0.999, 0.001], // B: state 0 strongly emits symbol 0
            vec![DurationDistrib::Geometric { p: 0.5 }],
        )
        .expect("ok");
        let obs = vec![0usize; 4];
        let path = m.decode(&obs).expect("ok");
        assert!(
            path.iter().all(|&s| s == 0),
            "expected all state 0, got {:?}",
            path
        );
    }

    // ── hsm_fit tests ──────────────────────────────────────────────────────────

    #[test]
    fn hsm_fit_runs_without_error() {
        let obs = vec![0usize, 0, 1, 1, 0, 1, 0, 0, 1, 1];
        let cfg = HsmConfig::default();
        assert!(hsm_fit(&[&obs], &cfg).is_ok());
    }

    #[test]
    fn hsm_fit_ll_non_decreasing() {
        // EM should never decrease the log-likelihood. The 1e-4 slack absorbs
        // floating-point noise from the log-space accumulation.
        let obs = alternating(20);
        let cfg = HsmConfig {
            max_iter: 30,
            ..Default::default()
        };
        let r = hsm_fit(&[&obs], &cfg).expect("ok");
        for w in r.log_likelihood_history.windows(2) {
            assert!(w[1] >= w[0] - 1e-4, "LL decreased: {} → {}", w[0], w[1]);
        }
    }

    #[test]
    fn hsm_fit_converged_flag() {
        // Convergence means two consecutive log-likelihoods differ by less
        // than `tol`. A generous max_iter leaves room for that to happen.
        let obs = alternating(50);
        let cfg = HsmConfig {
            max_iter: 500,
            tol: 1e-3,
            ..Default::default()
        };
        let r = hsm_fit(&[&obs], &cfg).expect("ok");
        assert!(r.converged, "expected convergence");
    }

    #[test]
    fn hsm_fit_result_pi_sums_to_one() {
        // The M-step normalises the initial-state statistics, or falls back
        // to a uniform distribution when they are all zero.
        let obs = vec![0usize, 1, 0, 1, 0, 0];
        let cfg = HsmConfig::default();
        let r = hsm_fit(&[&obs], &cfg).expect("ok");
        let s: f64 = r.model.pi.iter().sum();
        assert!((s - 1.0).abs() < 1e-9, "pi sums to {s}");
    }

    #[test]
    fn hsm_fit_result_b_rows_sum_to_one() {
        // Each emission row is normalised in the M-step, with a uniform
        // fallback when the row received no expected counts.
        let obs = vec![0usize, 1, 0, 1, 0, 0];
        let cfg = HsmConfig::default();
        let r = hsm_fit(&[&obs], &cfg).expect("ok");
        let n = cfg.n_states;
        let k = cfg.n_obs;
        for j in 0..n {
            let s: f64 = r.model.b[j * k..(j + 1) * k].iter().sum();
            assert!((s - 1.0).abs() < 1e-9, "B row {j} sums to {s}");
        }
    }

    #[test]
    fn hsm_fit_n_iter_within_max_iter() {
        let obs = vec![0usize, 1, 0, 1, 0, 0];
        let cfg = HsmConfig {
            max_iter: 10,
            ..Default::default()
        };
        let r = hsm_fit(&[&obs], &cfg).expect("ok");
        assert!(r.n_iter <= 10);
    }

    #[test]
    fn hsm_fit_multiple_sequences() {
        // Sequences of different lengths are pooled into one set of statistics.
        let s1 = vec![0usize, 0, 1, 1];
        let s2 = vec![1usize, 0, 1, 0, 0];
        let s3 = vec![0usize, 1, 1, 0, 1, 0];
        let cfg = HsmConfig::default();
        assert!(hsm_fit(&[&s1, &s2, &s3], &cfg).is_ok());
    }

    #[test]
    fn hsm_fit_short_sequence_length_one() {
        let obs = vec![0usize];
        let cfg = HsmConfig::default();
        let r = hsm_fit(&[&obs], &cfg).expect("length-1 sequence should work");
        assert!(!r.log_likelihood_history.is_empty());
    }

    #[test]
    fn hsm_fit_max_dur_one() {
        // With max_dur = 1 every segment lasts exactly one step. A keeps a
        // zero diagonal, so consecutive steps must occupy *different* states.
        // This is a stricter model than a standard HMM (which allows
        // self-transitions), not an equivalent one. The test only checks that
        // fitting still runs end to end and records a likelihood history.
        let obs = alternating(10);
        let cfg = HsmConfig {
            max_dur: 1,
            ..Default::default()
        };
        let r = hsm_fit(&[&obs], &cfg).expect("max_dur=1 should work");
        assert!(!r.log_likelihood_history.is_empty());
    }

    #[test]
    fn hsmm_a_rows_zero_diagonal() {
        // The M-step writes 0.0 on the diagonal of every A row, so the fitted
        // model keeps the HSMM "no self-transition" invariant.
        let obs = alternating(10);
        let cfg = HsmConfig::default();
        let r = hsm_fit(&[&obs], &cfg).expect("ok");
        let n = cfg.n_states;
        for i in 0..n {
            let diag = r.model.a[i * n + i];
            assert!(diag.abs() < 1e-9, "diagonal A[{i},{i}] = {diag}");
        }
    }

    #[test]
    fn hsm_fit_empty_input_err() {
        // No sequences means nothing to fit: this must be an error, not a panic.
        let cfg = HsmConfig::default();
        assert!(hsm_fit(&[], &cfg).is_err());
    }
}