Skip to main content

oxicuda_seq/hmm/
variational.rs

//! Variational Bayes EM for Hidden Markov Models with Dirichlet priors.
//!
//! Reference: M. J. Beal (2003), "Variational Algorithms for Approximate Bayesian
//! Inference", PhD thesis, University College London, §3.4.
//!
//! Standard Baum-Welch places maximum-likelihood (ML) point estimates on π, A, B.
//! VB-EM instead places conjugate Dirichlet priors on those parameters and
//! maintains a factored variational posterior
//! q(π, A, B) = q(π) · ∏_i q(A_i) · ∏_i q(B_i) whose factors are themselves
//! Dirichlet distributions.  Forward-backward is run with *expected*
//! log-parameters (derived via the digamma function), and the resulting
//! sufficient statistics update the Dirichlet concentration parameters in the
//! M-step.  The evidence lower bound (ELBO) is tracked for convergence.
12
13use crate::error::{SeqError, SeqResult};
14
// ─── Special functions ────────────────────────────────────────────────────────

/// Scalar digamma function ψ(x) = d/dx ln Γ(x).
///
/// Algorithm for x > 0:
///   Shift x up by adding integers until x ≥ 6, accumulating the recursion
///   ψ(x) = ψ(x+1) − 1/x.  Then apply the asymptotic series
///     ψ(x) ≈ ln(x) − 1/(2x) − 1/(12x²) + 1/(120x⁴) − 1/(252x⁶),
///   whose first omitted term, 1/(240x⁸), is below 3e-9 for x ≥ 6.
///
/// Non-positive and special arguments:
///   * `x = 0` and negative integers are poles of ψ; they return
///     `f64::NEG_INFINITY`, the limit approached from the right.
///   * Other negative `x` use the reflection formula
///     ψ(x) = ψ(1 − x) − π / tan(πx), so the cost does not grow with |x|.
///     The plain upward recursion would never terminate for |x| ≥ 2⁵³,
///     where `x + 1.0 == x`.
///   * `−∞` and `NaN` return `NaN`; `+∞` returns `+∞`.
pub fn digamma(mut x: f64) -> f64 {
    use std::f64::consts::PI;

    if x.is_nan() || x == f64::NEG_INFINITY {
        return f64::NAN;
    }
    if x <= 0.0 {
        if x == x.floor() {
            // Pole at 0, −1, −2, …
            return f64::NEG_INFINITY;
        }
        // 1 − x > 1, so the recursive call takes the positive-argument path below.
        return digamma(1.0 - x) - PI / (PI * x).tan();
    }

    let mut result = 0.0;

    // Shift argument into the asymptotic region (x ≥ 6).
    while x < 6.0 {
        result -= 1.0 / x;
        x += 1.0;
    }

    // Asymptotic Stirling expansion.
    let x2 = x * x;
    let x4 = x2 * x2;
    let x6 = x4 * x2;
    result += x.ln() - 0.5 / x - 1.0 / (12.0 * x2) + 1.0 / (120.0 * x4) - 1.0 / (252.0 * x6);
    result
}
42
/// Natural log of the absolute value of the Gamma function, ln |Γ(x)|.
///
/// For x ≥ 0.5 this uses the Lanczos approximation with g = 7 and 9
/// coefficients. These are the same values used by the reference implementation
/// in the Wikipedia "Lanczos approximation" article.
///
/// Arguments below 0.5 use the reflection formula Γ(x) Γ(1 − x) = π / sin(πx).
/// The absolute value of sin(πx) is taken so that negative non-integers, where
/// Γ(x) may be negative, yield ln |Γ(x)|.
///
/// Special values follow C's `lgamma`:
///   * the poles x = 0, −1, −2, … and ±∞ return `+∞`;
///   * `NaN` returns `NaN`.
pub fn log_gamma(x: f64) -> f64 {
    use std::f64::consts::PI;

    // Lanczos coefficients for g = 7, n = 9.
    const G: f64 = 7.0;
    const C: [f64; 9] = [
        0.999_999_999_999_809_3,
        676.520_368_121_885_1,
        -1_259.139_216_722_402_8,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_572e-6,
        1.505_632_735_149_311_6e-7,
    ];

    if x.is_infinite() || (x <= 0.0 && x == x.floor()) {
        return f64::INFINITY;
    }

    if x < 0.5 {
        // sin(πx) < 0 on (−1, 0), (−3, −2), …; without `abs` the log would be NaN.
        return PI.ln() - (PI * x).sin().abs().ln() - log_gamma(1.0 - x);
    }

    let z = x - 1.0;
    let mut sum = C[0];
    for (k, &ck) in C[1..].iter().enumerate() {
        sum += ck / (z + (k as f64 + 1.0));
    }

    let t = z + G + 0.5;
    (2.0 * PI).sqrt().ln() + sum.ln() + (z + 0.5) * t.ln() - t
}
76
77/// Log-normaliser of a Dirichlet distribution:
78///   log B(α) = Σ_i ln Γ(α_i) − ln Γ(Σ_i α_i).
79pub fn dirichlet_log_normalizer(alpha: &[f64]) -> f64 {
80    let sum_alpha: f64 = alpha.iter().sum();
81    let sum_log_gamma: f64 = alpha.iter().map(|&a| log_gamma(a)).sum();
82    sum_log_gamma - log_gamma(sum_alpha)
83}
84
// ─── Configuration & result types ─────────────────────────────────────────────

/// Model size, Dirichlet hyper-parameters and stopping criteria for
/// [`variational_hmm`].
///
/// Each `*_prior` field is the single value shared by every component of a
/// symmetric Dirichlet prior.
#[derive(Debug, Clone)]
pub struct VbHmmConfig {
    /// Number of hidden states.
    pub n_states: usize,
    /// Number of distinct observation symbols; observations must lie in `0..n_obs`.
    pub n_obs: usize,
    /// Symmetric Dirichlet concentration for the initial-state distribution π (default 1.0).
    pub alpha_prior: f64,
    /// Symmetric Dirichlet concentration for every row of the transition matrix A (default 1.0).
    pub beta_prior: f64,
    /// Symmetric Dirichlet concentration for every row of the emission matrix B (default 1.0).
    pub gamma_prior: f64,
    /// Upper bound on the number of VB-EM iterations (default 200).
    pub max_iter: usize,
    /// Stop once two consecutive ELBO values differ by less than this (default 1e-6).
    pub tol: f64,
}

impl Default for VbHmmConfig {
    /// Two states, two symbols, flat (all-ones) priors, 200 iterations, tol 1e-6.
    fn default() -> Self {
        let flat_prior = 1.0;
        VbHmmConfig {
            n_states: 2,
            n_obs: 2,
            alpha_prior: flat_prior,
            beta_prior: flat_prior,
            gamma_prior: flat_prior,
            max_iter: 200,
            tol: 1.0e-6,
        }
    }
}
119
120/// Result of Variational Bayes HMM training.
121#[derive(Debug, Clone)]
122pub struct VbHmmResult {
123    /// Dirichlet concentration parameters for the initial-state posterior (n_states,).
124    pub alpha: Vec<f64>,
125    /// Dirichlet concentration parameters for the transition posterior, row-major
126    /// (n_states × n_states,).
127    pub beta: Vec<f64>,
128    /// Dirichlet concentration parameters for the emission posterior, row-major
129    /// (n_states × n_obs,).
130    pub gamma: Vec<f64>,
131    /// ELBO (Evidence Lower BOund) at each iteration.
132    pub elbo_history: Vec<f64>,
133    /// Number of VB-EM iterations executed.
134    pub n_iter: usize,
135    /// Whether the algorithm converged within the tolerance.
136    pub converged: bool,
137}
138
139impl VbHmmResult {
140    /// Expected log initial-state probabilities: E[log π_i] = ψ(α_i) − ψ(Σ_j α_j).
141    pub fn expected_log_pi(&self) -> Vec<f64> {
142        let sum_alpha: f64 = self.alpha.iter().sum();
143        let psi_sum = digamma(sum_alpha);
144        self.alpha.iter().map(|&a| digamma(a) - psi_sum).collect()
145    }
146
147    /// Posterior mean of the initial-state distribution: α_i / Σ_j α_j.
148    pub fn mean_pi(&self) -> Vec<f64> {
149        let s: f64 = self.alpha.iter().sum();
150        self.alpha.iter().map(|&a| a / s).collect()
151    }
152
153    /// Posterior mean of the transition matrix (n_states × n_states, row-major).
154    pub fn mean_a(&self) -> Vec<f64> {
155        let n = self.alpha.len(); // = n_states
156        let mut out = vec![0.0; n * n];
157        for i in 0..n {
158            let s: f64 = self.beta[i * n..(i + 1) * n].iter().sum();
159            for j in 0..n {
160                out[i * n + j] = if s > 0.0 {
161                    self.beta[i * n + j] / s
162                } else {
163                    1.0 / n as f64
164                };
165            }
166        }
167        out
168    }
169
170    /// Posterior mean of the emission matrix (n_states × n_obs, row-major).
171    pub fn mean_b(&self) -> Vec<f64> {
172        let n = self.alpha.len(); // = n_states
173        let k = self.gamma.len() / n; // = n_obs
174        let mut out = vec![0.0; n * k];
175        for j in 0..n {
176            let s: f64 = self.gamma[j * k..(j + 1) * k].iter().sum();
177            for sym in 0..k {
178                out[j * k + sym] = if s > 0.0 {
179                    self.gamma[j * k + sym] / s
180                } else {
181                    1.0 / k as f64
182                };
183            }
184        }
185        out
186    }
187}
188
// ─── Internal helpers ──────────────────────────────────────────────────────────

/// Numerically stable log(Σ_i exp(x_i)) using the max-shift trick.
///
/// Returns `−∞` for an empty slice or when every element is `−∞`.  Returns
/// `+∞` when any element is `+∞`; the shift would otherwise evaluate
/// `∞ − ∞ = NaN`.
#[inline]
fn logsumexp(xs: &[f64]) -> f64 {
    let m = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if m.is_infinite() {
        return m;
    }
    let s: f64 = xs.iter().map(|&x| (x - m).exp()).sum();
    m + s.ln()
}
201
202/// VB forward-backward pass.
203///
204/// Takes pre-computed expected log-parameters directly (not a `HmmDiscrete`).
205///
206/// # Arguments
207/// * `log_pi_eff` — E[log π_i] for i=0..n
208/// * `log_a_eff`  — E[log A_{ij}] row-major (n×n)
209/// * `log_em_eff` — E[log B_{j, o_t}] for each t, row-major (T×n); caller
210///   pre-indexes the emission by the observation symbol.
211///
212/// # Returns
213/// `(gamma, xi, log_likelihood)` where
214///   * `gamma` — T×n state posteriors (probability domain, renormalised)
215///   * `xi`    — (T-1)×n×n edge posteriors (probability domain)
216///   * `log_likelihood` — log p(o | effective model)
217fn vb_forward_backward(
218    log_pi_eff: &[f64],
219    log_a_eff: &[f64],
220    log_em_eff: &[f64],
221    n: usize,
222    t_max: usize,
223) -> (Vec<f64>, Vec<f64>, f64) {
224    // ── Forward ──
225    let mut log_alpha = vec![f64::NEG_INFINITY; t_max * n];
226
227    // α_0(j) = log_pi[j] + log_em[0, j]
228    for j in 0..n {
229        log_alpha[j] = log_pi_eff[j] + log_em_eff[j];
230    }
231
232    let mut tmp = vec![0.0f64; n];
233    for t in 1..t_max {
234        for j in 0..n {
235            for i in 0..n {
236                tmp[i] = log_alpha[(t - 1) * n + i] + log_a_eff[i * n + j];
237            }
238            log_alpha[t * n + j] = logsumexp(&tmp) + log_em_eff[t * n + j];
239        }
240    }
241
242    let ll = logsumexp(&log_alpha[(t_max - 1) * n..t_max * n]);
243
244    // ── Backward ──
245    let mut log_beta = vec![f64::NEG_INFINITY; t_max * n];
246    for i in 0..n {
247        log_beta[(t_max - 1) * n + i] = 0.0;
248    }
249    for t in (0..t_max.saturating_sub(1)).rev() {
250        for i in 0..n {
251            for j in 0..n {
252                tmp[j] =
253                    log_a_eff[i * n + j] + log_em_eff[(t + 1) * n + j] + log_beta[(t + 1) * n + j];
254            }
255            log_beta[t * n + i] = logsumexp(&tmp);
256        }
257    }
258
259    // ── γ posteriors ──
260    let mut gamma = vec![0.0f64; t_max * n];
261    for t in 0..t_max {
262        for i in 0..n {
263            gamma[t * n + i] = (log_alpha[t * n + i] + log_beta[t * n + i] - ll).exp();
264        }
265        // Row-normalise to guard against floating-point drift.
266        let s: f64 = gamma[t * n..t * n + n].iter().sum();
267        if s > 0.0 {
268            for i in 0..n {
269                gamma[t * n + i] /= s;
270            }
271        }
272    }
273
274    // ── ξ edge posteriors ──
275    let xi_len = t_max.saturating_sub(1) * n * n;
276    let mut xi = vec![0.0f64; xi_len];
277    for t in 0..t_max.saturating_sub(1) {
278        let mut s = 0.0;
279        for i in 0..n {
280            for j in 0..n {
281                let v = (log_alpha[t * n + i]
282                    + log_a_eff[i * n + j]
283                    + log_em_eff[(t + 1) * n + j]
284                    + log_beta[(t + 1) * n + j]
285                    - ll)
286                    .exp();
287                xi[t * n * n + i * n + j] = v;
288                s += v;
289            }
290        }
291        if s > 0.0 {
292            for v in xi[t * n * n..(t + 1) * n * n].iter_mut() {
293                *v /= s;
294            }
295        }
296    }
297
298    (gamma, xi, ll)
299}
300
301// ─── KL divergence between two Dirichlet distributions ───────────────────────
302
303/// KL(Dir(alpha) || Dir(alpha_0)) for vectors of equal length.
304fn kl_dirichlet(alpha: &[f64], alpha_0: &[f64]) -> f64 {
305    let log_b_alpha_0 = dirichlet_log_normalizer(alpha_0);
306    let log_b_alpha = dirichlet_log_normalizer(alpha);
307    let sum_alpha: f64 = alpha.iter().sum();
308    let psi_sum = digamma(sum_alpha);
309
310    let correction: f64 = alpha
311        .iter()
312        .zip(alpha_0.iter())
313        .map(|(&ai, &a0i)| (a0i - ai) * (digamma(ai) - psi_sum))
314        .sum();
315
316    log_b_alpha_0 - log_b_alpha + correction
317}
318
319// ─── Main entry point ──────────────────────────────────────────────────────────
320
321/// Run Variational Bayes EM for a discrete HMM with Dirichlet priors.
322///
323/// Accepts one or more observation sequences.  Each sequence element must lie
324/// in `0..cfg.n_obs`.
325pub fn variational_hmm(observations: &[&[usize]], cfg: &VbHmmConfig) -> SeqResult<VbHmmResult> {
326    // ── Validation ──
327    if observations.is_empty() || observations.iter().all(|s| s.is_empty()) {
328        return Err(SeqError::EmptyInput);
329    }
330    if cfg.n_states == 0 || cfg.n_obs == 0 {
331        return Err(SeqError::InvalidConfiguration(
332            "n_states and n_obs must be > 0".to_string(),
333        ));
334    }
335    for seq in observations.iter() {
336        for &o in *seq {
337            if o >= cfg.n_obs {
338                return Err(SeqError::InvalidObservation(format!(
339                    "observation {o} >= n_obs {}",
340                    cfg.n_obs
341                )));
342            }
343        }
344    }
345    // Reject entirely-empty inputs (but allow mixed non-empty / skip empty seqs below).
346    if observations.iter().all(|s| s.is_empty()) {
347        return Err(SeqError::EmptyInput);
348    }
349
350    let n = cfg.n_states;
351    let k = cfg.n_obs;
352
353    // ── Initialise Dirichlet parameters ──
354    // α_i = alpha_prior + deterministic perturbation
355    let mut alpha: Vec<f64> = (0..n)
356        .map(|i| cfg.alpha_prior + (i as f64 + 1.0) * 0.1 / n as f64)
357        .collect();
358
359    // β_{ij}: higher on diagonal to prefer self-persistence initially.
360    let mut beta: Vec<f64> = vec![0.0; n * n];
361    for i in 0..n {
362        for j in 0..n {
363            beta[i * n + j] = if i == j {
364                cfg.beta_prior + 0.5
365            } else if n > 1 {
366                cfg.beta_prior + 0.1 / (n as f64 - 1.0)
367            } else {
368                cfg.beta_prior
369            };
370        }
371    }
372
373    // γ_{jk}: uniform + small perturbation
374    let mut gamma_dir: Vec<f64> = vec![cfg.gamma_prior + 0.1; n * k];
375
376    let mut elbo_history: Vec<f64> = Vec::with_capacity(cfg.max_iter + 1);
377    let mut prev_elbo = f64::NEG_INFINITY;
378    let mut converged = false;
379    let mut n_iter = 0usize;
380
381    // ── VB-EM iterations ──
382    for iter in 0..cfg.max_iter {
383        n_iter = iter + 1;
384
385        // ── E-step: compute expected log-parameters ──
386        // E[log π_i]
387        let sum_alpha: f64 = alpha.iter().sum();
388        let psi_sum_alpha = digamma(sum_alpha);
389        let log_pi_eff: Vec<f64> = alpha.iter().map(|&a| digamma(a) - psi_sum_alpha).collect();
390
391        // E[log A_{ij}] — row-major (n×n)
392        let mut log_a_eff: Vec<f64> = vec![0.0; n * n];
393        for i in 0..n {
394            let sum_beta_i: f64 = beta[i * n..(i + 1) * n].iter().sum();
395            let psi_sum_beta_i = digamma(sum_beta_i);
396            for j in 0..n {
397                log_a_eff[i * n + j] = digamma(beta[i * n + j]) - psi_sum_beta_i;
398            }
399        }
400
401        // E[log B_{j, k}] — row-major (n×k)
402        let mut log_b_eff: Vec<f64> = vec![0.0; n * k];
403        for j in 0..n {
404            let sum_gamma_j: f64 = gamma_dir[j * k..(j + 1) * k].iter().sum();
405            let psi_sum_gamma_j = digamma(sum_gamma_j);
406            for sym in 0..k {
407                log_b_eff[j * k + sym] = digamma(gamma_dir[j * k + sym]) - psi_sum_gamma_j;
408            }
409        }
410
411        // ── Accumulate sufficient statistics over all sequences ──
412        // Sufficient stats for M-step:
413        //   ss_pi[i]        = Σ_seq γ_{0}^{seq}(i)
414        //   ss_a[i,j]       = Σ_seq Σ_t ξ_t(i,j)
415        //   ss_b[j, sym]    = Σ_seq Σ_{t: o_t=sym} γ_t(j)
416        let mut ss_pi = vec![0.0f64; n];
417        let mut ss_a = vec![0.0f64; n * n];
418        let mut ss_b = vec![0.0f64; n * k];
419
420        for seq in observations.iter() {
421            if seq.is_empty() {
422                continue;
423            }
424            let t_max = seq.len();
425
426            // Build log_em_eff for this sequence: (T × n)
427            let mut log_em_eff = vec![0.0f64; t_max * n];
428            for t in 0..t_max {
429                for j in 0..n {
430                    log_em_eff[t * n + j] = log_b_eff[j * k + seq[t]];
431                }
432            }
433
434            let (gamma_seq, xi_seq, _ll_seq) =
435                vb_forward_backward(&log_pi_eff, &log_a_eff, &log_em_eff, n, t_max);
436
437            // Accumulate π sufficient stat from t=0.
438            for i in 0..n {
439                ss_pi[i] += gamma_seq[i];
440            }
441
442            // Accumulate transition sufficient stat.
443            for t in 0..t_max.saturating_sub(1) {
444                for i in 0..n {
445                    for j in 0..n {
446                        ss_a[i * n + j] += xi_seq[t * n * n + i * n + j];
447                    }
448                }
449            }
450
451            // Accumulate emission sufficient stat.
452            for t in 0..t_max {
453                for j in 0..n {
454                    ss_b[j * k + seq[t]] += gamma_seq[t * n + j];
455                }
456            }
457        }
458
459        // ── M-step: update Dirichlet parameters ──
460        for i in 0..n {
461            alpha[i] = cfg.alpha_prior + ss_pi[i];
462        }
463        for i in 0..n {
464            for j in 0..n {
465                beta[i * n + j] = cfg.beta_prior + ss_a[i * n + j];
466            }
467        }
468        for j in 0..n {
469            for sym in 0..k {
470                gamma_dir[j * k + sym] = cfg.gamma_prior + ss_b[j * k + sym];
471            }
472        }
473
474        // ── ELBO (computed with POST-M-step parameters for monotonicity) ──
475        // The ELBO after one full VB-EM step (E + M) is guaranteed to be
476        // non-decreasing under coordinate-ascent.  We recompute E[log θ] with
477        // the updated parameters and run a fresh forward-backward to get the
478        // data term under the new variational posterior.
479        let sum_alpha_new: f64 = alpha.iter().sum();
480        let psi_sum_alpha_new = digamma(sum_alpha_new);
481        let log_pi_new: Vec<f64> = alpha
482            .iter()
483            .map(|&a| digamma(a) - psi_sum_alpha_new)
484            .collect();
485
486        let mut log_a_new: Vec<f64> = vec![0.0; n * n];
487        for i in 0..n {
488            let sum_beta_i: f64 = beta[i * n..(i + 1) * n].iter().sum();
489            let psi_sum_bi = digamma(sum_beta_i);
490            for j in 0..n {
491                log_a_new[i * n + j] = digamma(beta[i * n + j]) - psi_sum_bi;
492            }
493        }
494
495        let mut log_b_new: Vec<f64> = vec![0.0; n * k];
496        for j in 0..n {
497            let sum_gamma_j: f64 = gamma_dir[j * k..(j + 1) * k].iter().sum();
498            let psi_sum_gj = digamma(sum_gamma_j);
499            for sym in 0..k {
500                log_b_new[j * k + sym] = digamma(gamma_dir[j * k + sym]) - psi_sum_gj;
501            }
502        }
503
504        let mut elbo_ll = 0.0f64;
505        for seq in observations.iter() {
506            if seq.is_empty() {
507                continue;
508            }
509            let t_max = seq.len();
510            let mut log_em_new = vec![0.0f64; t_max * n];
511            for t in 0..t_max {
512                for j in 0..n {
513                    log_em_new[t * n + j] = log_b_new[j * k + seq[t]];
514                }
515            }
516            let (_, _, ll_new) =
517                vb_forward_backward(&log_pi_new, &log_a_new, &log_em_new, n, t_max);
518            elbo_ll += ll_new;
519        }
520
521        let alpha_prior_vec = vec![cfg.alpha_prior; n];
522        let beta_prior_vec = vec![cfg.beta_prior; n];
523        let gamma_prior_vec = vec![cfg.gamma_prior; k];
524
525        let mut kl_total = kl_dirichlet(&alpha, &alpha_prior_vec);
526        for i in 0..n {
527            kl_total += kl_dirichlet(&beta[i * n..(i + 1) * n], &beta_prior_vec);
528        }
529        for j in 0..n {
530            kl_total += kl_dirichlet(&gamma_dir[j * k..(j + 1) * k], &gamma_prior_vec);
531        }
532
533        let elbo = elbo_ll - kl_total;
534        elbo_history.push(elbo);
535
536        // ── Convergence check ──
537        if iter > 0 && (elbo - prev_elbo).abs() < cfg.tol {
538            converged = true;
539            break;
540        }
541        prev_elbo = elbo;
542    }
543
544    Ok(VbHmmResult {
545        alpha,
546        beta,
547        gamma: gamma_dir,
548        elbo_history,
549        n_iter,
550        converged,
551    })
552}
553
// ─── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    // ── digamma tests ──────────────────────────────────────────────────────────

    #[test]
    fn digamma_at_one_is_neg_euler_mascheroni() {
        // ψ(1) = −γ ≈ −0.5772156649
        let d = digamma(1.0);
        assert!((d - (-0.577_215_664_9)).abs() < 1e-6, "digamma(1) = {d}");
    }

    #[test]
    fn digamma_at_two() {
        // ψ(2) = 1 − γ ≈ 0.4227843351
        let d = digamma(2.0);
        assert!((d - 0.422_784_335_1).abs() < 1e-6, "digamma(2) = {d}");
    }

    #[test]
    fn digamma_recurrence() {
        // ψ(x+1) = ψ(x) + 1/x  for any x > 0
        for &x in &[0.5, 1.0, 2.0, 3.5, 7.0] {
            let lhs = digamma(x + 1.0);
            let rhs = digamma(x) + 1.0 / x;
            assert!(
                (lhs - rhs).abs() < 1e-9,
                "recurrence failed at x={x}: {lhs} vs {rhs}"
            );
        }
    }

    #[test]
    fn digamma_large_argument() {
        // For large x, ψ(x) ≈ ln(x) − 1/(2x).  At x=100 the correction is tiny.
        let d = digamma(100.0);
        let approx = 100.0_f64.ln() - 0.005;
        assert!((d - approx).abs() < 0.01, "digamma(100) = {d}");
    }

    // ── log_gamma tests ────────────────────────────────────────────────────────

    #[test]
    fn log_gamma_at_one() {
        // Γ(1) = 1  →  ln Γ(1) = 0
        assert!(log_gamma(1.0).abs() < 1e-10);
    }

    #[test]
    fn log_gamma_at_two() {
        // Γ(2) = 1  →  ln Γ(2) = 0
        assert!(log_gamma(2.0).abs() < 1e-10);
    }

    #[test]
    fn log_gamma_at_half() {
        // Γ(1/2) = √π  →  ln Γ(1/2) = 0.5 ln π ≈ 0.5723649...
        let expected = 0.5 * std::f64::consts::PI.ln();
        let got = log_gamma(0.5);
        assert!((got - expected).abs() < 1e-9, "log_gamma(0.5) = {got}");
    }

    #[test]
    fn log_gamma_integer_values() {
        // Γ(n) = (n-1)!  →  ln Γ(n) = ln((n-1)!)
        // n=4 → Γ(4) = 6 → ln 6 ≈ 1.7917594...
        let got = log_gamma(4.0);
        let expected = 6.0_f64.ln();
        assert!((got - expected).abs() < 1e-9, "log_gamma(4) = {got}");
    }

    #[test]
    fn log_gamma_five() {
        // Γ(5) = 24 → ln 24
        let got = log_gamma(5.0);
        let expected = 24.0_f64.ln();
        assert!((got - expected).abs() < 1e-9, "log_gamma(5) = {got}");
    }

    // ── Dirichlet helper tests ─────────────────────────────────────────────────

    #[test]
    fn dirichlet_log_normalizer_of_ones() {
        // log B(1, 1, 1) = 3·lnΓ(1) − lnΓ(3) = −ln 2
        let got = dirichlet_log_normalizer(&[1.0, 1.0, 1.0]);
        assert!((got + 2.0_f64.ln()).abs() < 1e-9, "log B(1,1,1) = {got}");
    }

    #[test]
    fn kl_dirichlet_zero_for_identical_parameters() {
        let a = [2.0, 3.0, 0.5];
        let kl = kl_dirichlet(&a, &a);
        assert!(kl.abs() < 1e-12, "KL(a || a) = {kl}");
    }

    // ── logsumexp / forward-backward tests ─────────────────────────────────────

    #[test]
    fn logsumexp_basic_values() {
        assert_eq!(logsumexp(&[]), f64::NEG_INFINITY);
        assert_eq!(logsumexp(&[f64::NEG_INFINITY; 3]), f64::NEG_INFINITY);
        assert!((logsumexp(&[0.0, 0.0]) - 2.0_f64.ln()).abs() < 1e-12);
        assert!((logsumexp(&[1000.0, 1000.0]) - (1000.0 + 2.0_f64.ln())).abs() < 1e-9);
    }

    #[test]
    fn forward_backward_uniform_model() {
        // Every probability is 1/2 with n = 2, T = 3: each of the 2^T paths has
        // weight (1/2)^(2T), so p(o) = (1/2)^T and all posteriors are uniform.
        let h = 0.5_f64.ln();
        let n = 2;
        let t_max = 3;
        let (gamma, xi, ll) = vb_forward_backward(&[h; 2], &[h; 4], &[h; 6], n, t_max);
        assert!((ll - 3.0 * h).abs() < 1e-12, "ll = {ll}");
        assert_eq!(gamma.len(), t_max * n);
        assert_eq!(xi.len(), (t_max - 1) * n * n);
        assert!(gamma.iter().all(|&g| (g - 0.5).abs() < 1e-12));
        assert!(xi.iter().all(|&x| (x - 0.25).abs() < 1e-12));
    }

    // ── VB-HMM convergence & structural tests ─────────────────────────────────

    fn simple_obs() -> Vec<usize> {
        vec![0, 0, 1, 1, 0, 0, 1, 1, 0, 1]
    }

    #[test]
    fn default_config_produces_valid_result() {
        let obs = simple_obs();
        let cfg = VbHmmConfig::default();
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("should succeed");
        assert!(r.n_iter > 0);
        assert!(r.n_iter <= cfg.max_iter);
        assert!(!r.elbo_history.is_empty());
    }

    #[test]
    fn mean_pi_sums_to_one() {
        let obs = simple_obs();
        let cfg = VbHmmConfig::default();
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
        let s: f64 = r.mean_pi().iter().sum();
        assert!((s - 1.0).abs() < 1e-10, "mean_pi sum = {s}");
    }

    #[test]
    fn mean_a_rows_sum_to_one() {
        let obs = simple_obs();
        let cfg = VbHmmConfig::default();
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
        let n = cfg.n_states;
        let a = r.mean_a();
        for i in 0..n {
            let s: f64 = a[i * n..(i + 1) * n].iter().sum();
            assert!((s - 1.0).abs() < 1e-10, "mean_a row {i} sums to {s}");
        }
    }

    #[test]
    fn mean_b_rows_sum_to_one() {
        let obs = simple_obs();
        let cfg = VbHmmConfig::default();
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
        let n = cfg.n_states;
        let k = cfg.n_obs;
        let b = r.mean_b();
        for j in 0..n {
            let s: f64 = b[j * k..(j + 1) * k].iter().sum();
            assert!((s - 1.0).abs() < 1e-10, "mean_b row {j} sums to {s}");
        }
    }

    #[test]
    fn elbo_history_non_decreasing() {
        // VB-EM guarantees a non-decreasing ELBO under exact coordinate ascent.
        // This test only guards against gross regressions: it tolerates
        // per-step dips of up to 1 nat and an overall drop of up to 2 nats.
        let obs = simple_obs();
        let cfg = VbHmmConfig::default();
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
        // Overall trend: final ELBO should not be much worse than the first one.
        if r.elbo_history.len() >= 2 {
            let first = r.elbo_history[0];
            let last = *r.elbo_history.last().expect("non-empty");
            assert!(
                last >= first - 2.0,
                "Final ELBO ({last}) is much worse than initial ({first})"
            );
        }
        // Fine-grained: consecutive decreases of > 1.0 nats are not acceptable.
        for w in r.elbo_history.windows(2) {
            assert!(
                w[1] >= w[0] - 1.0,
                "ELBO dropped by more than 1 nat: {} → {}",
                w[0],
                w[1]
            );
        }
    }

    #[test]
    fn elbo_history_len_matches_n_iter() {
        let obs = simple_obs();
        let cfg = VbHmmConfig::default();
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
        assert_eq!(r.elbo_history.len(), r.n_iter);
    }

    #[test]
    fn n_iter_within_max_iter() {
        let obs = simple_obs();
        let cfg = VbHmmConfig {
            max_iter: 50,
            ..Default::default()
        };
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
        assert!(r.n_iter <= 50);
    }

    #[test]
    fn posteriors_exceed_prior_when_data_given() {
        // After observing data each concentration param should exceed the prior.
        let obs: Vec<usize> = (0..30).map(|i| i % 2).collect();
        let cfg = VbHmmConfig {
            alpha_prior: 1.0,
            beta_prior: 1.0,
            gamma_prior: 1.0,
            ..Default::default()
        };
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
        for &a in &r.alpha {
            assert!(
                a > cfg.alpha_prior,
                "alpha {a} not > prior {}",
                cfg.alpha_prior
            );
        }
    }

    #[test]
    fn multiple_sequences_accepted() {
        let seq1 = vec![0usize, 1, 0, 1];
        let seq2 = vec![1usize, 1, 0, 0];
        let seq3 = vec![0usize, 0, 1, 1, 0];
        let cfg = VbHmmConfig::default();
        let r = variational_hmm(&[&seq1, &seq2, &seq3], &cfg).expect("ok");
        assert!(!r.elbo_history.is_empty());
    }

    #[test]
    fn empty_sequences_are_skipped() {
        let empty: Vec<usize> = Vec::new();
        let obs = simple_obs();
        let cfg = VbHmmConfig::default();
        let r = variational_hmm(&[&empty, &obs], &cfg).expect("empty sequences are skipped");
        assert!(!r.elbo_history.is_empty());
    }

    #[test]
    fn empty_observations_returns_err() {
        let cfg = VbHmmConfig::default();
        assert!(variational_hmm(&[], &cfg).is_err());
    }

    #[test]
    fn all_empty_sequences_return_err() {
        let empty: Vec<usize> = Vec::new();
        let cfg = VbHmmConfig::default();
        assert!(variational_hmm(&[&empty, &empty], &cfg).is_err());
    }

    #[test]
    fn zero_states_returns_err() {
        let obs = simple_obs();
        let cfg = VbHmmConfig {
            n_states: 0,
            ..Default::default()
        };
        assert!(variational_hmm(&[obs.as_slice()], &cfg).is_err());
    }

    #[test]
    fn obs_out_of_range_returns_err() {
        let obs = vec![0usize, 5]; // n_obs = 2, so 5 is invalid
        let cfg = VbHmmConfig::default();
        assert!(variational_hmm(&[&obs], &cfg).is_err());
    }

    #[test]
    fn single_observation_length_one_works() {
        let obs = vec![0usize];
        let cfg = VbHmmConfig::default();
        let r = variational_hmm(&[&obs], &cfg).expect("length-1 seq should work");
        assert!(!r.elbo_history.is_empty());
    }

    #[test]
    fn converged_flag_set_on_tight_convergence() {
        // Run with loose tol so it converges quickly.
        let obs: Vec<usize> = (0..50).map(|i| i % 2).collect();
        let cfg = VbHmmConfig {
            max_iter: 500,
            tol: 1e-3,
            ..Default::default()
        };
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
        assert!(
            r.converged,
            "expected convergence with tol=1e-3 and 500 iterations"
        );
    }

    #[test]
    fn larger_state_space() {
        let obs: Vec<usize> = (0..40).map(|i| i % 4).collect();
        let cfg = VbHmmConfig {
            n_states: 4,
            n_obs: 4,
            max_iter: 100,
            ..Default::default()
        };
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
        assert_eq!(r.alpha.len(), 4);
        assert_eq!(r.beta.len(), 16);
        assert_eq!(r.gamma.len(), 16);
    }

    #[test]
    fn expected_log_pi_returns_correct_length() {
        let obs = simple_obs();
        let cfg = VbHmmConfig::default();
        let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
        let elp = r.expected_log_pi();
        assert_eq!(elp.len(), cfg.n_states);
        for &v in &elp {
            assert!(v.is_finite(), "expected_log_pi entry is not finite: {v}");
            assert!(v <= 0.0, "expected_log_pi entry should be ≤ 0: {v}");
        }
    }
}