use crate::error::{SeqError, SeqResult};
/// Digamma function ψ(x), the derivative of `ln Γ(x)`.
///
/// For `x >= 0` the argument is shifted upward with the recurrence
/// ψ(x) = ψ(x + 1) − 1/x until it reaches 6, where a truncated asymptotic
/// series is evaluated.  Negative arguments use the reflection formula
/// ψ(x) = ψ(1 − x) − π·cot(πx), so the cost no longer grows with `|x|` and
/// the function terminates for every input (the plain recurrence never
/// finished for `-inf` or for negative values too large for `x + 1.0` to
/// change `x`).
///
/// Special values:
/// * NaN yields NaN.
/// * Negative integers (poles) and `-inf` yield NaN.
/// * `0.0` yields `-inf` and `-0.0` yields `+inf` (one-sided limits produced
///   by the recurrence step `-1/x`).
pub fn digamma(mut x: f64) -> f64 {
    use std::f64::consts::PI;

    if x < 0.0 {
        let nearest = x.round();
        if x == nearest {
            // Pole at a negative integer; also catches -inf and every
            // negative float of magnitude >= 2^53 (all of them are integers).
            return f64::NAN;
        }
        // cot(πx) has period 1, so evaluate it at the exactly representable
        // offset from the nearest integer to keep precision for large |x|.
        return digamma(1.0 - x) - PI / (PI * (x - nearest)).tan();
    }

    // Recurrence: accumulate -1/x while shifting x into the asymptotic range.
    let mut result = 0.0;
    while x < 6.0 {
        result -= 1.0 / x;
        x += 1.0;
    }

    // Asymptotic expansion: ln x − 1/(2x) − 1/(12x²) + 1/(120x⁴) − 1/(252x⁶).
    let x2 = x * x;
    let x4 = x2 * x2;
    let x6 = x4 * x2;
    result += x.ln() - 0.5 / x - 1.0 / (12.0 * x2) + 1.0 / (120.0 * x4) - 1.0 / (252.0 * x6);
    result
}
/// Natural logarithm of the absolute value of the gamma function, `ln |Γ(x)|`.
///
/// Uses a Lanczos approximation (`g = 7`, nine coefficients) for `x >= 0.5`
/// and the reflection formula Γ(x)·Γ(1 − x) = π / sin(πx) below that.
///
/// The reflection branch takes `ln |sin(πx)|`; without the absolute value the
/// result was NaN for every negative `x` where Γ(x) is negative (for example
/// `x = -0.5`).  For `x` in `[0, 0.5)` the sine is non-negative, so results
/// there are unchanged.  `log_gamma(0.0)` is `+inf`.
pub fn log_gamma(x: f64) -> f64 {
    use std::f64::consts::PI;

    const G: f64 = 7.0;
    const C: [f64; 9] = [
        0.999_999_999_999_809_3,
        676.520_368_121_885_1,
        -1_259.139_216_722_402_8,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_572e-6,
        1.505_632_735_149_311_6e-7,
    ];

    if x < 0.5 {
        // Reflection: ln|Γ(x)| = ln π − ln|sin(πx)| − ln Γ(1 − x).
        // 1 − x > 0.5, so the recursion is at most one level deep.
        return PI.ln() - (PI * x).sin().abs().ln() - log_gamma(1.0 - x);
    }

    // Lanczos series evaluated at z = x − 1.
    let z = x - 1.0;
    let mut sum = C[0];
    for (k, &ck) in C[1..].iter().enumerate() {
        sum += ck / (z + (k as f64 + 1.0));
    }
    let t = z + G + 0.5;
    (2.0 * PI).sqrt().ln() + sum.ln() + (z + 0.5) * t.ln() - t
}
pub fn dirichlet_log_normalizer(alpha: &[f64]) -> f64 {
let sum_alpha: f64 = alpha.iter().sum();
let sum_log_gamma: f64 = alpha.iter().map(|&a| log_gamma(a)).sum();
sum_log_gamma - log_gamma(sum_alpha)
}
/// Settings for [`variational_hmm`].
#[derive(Debug, Clone)]
pub struct VbHmmConfig {
    /// Number of hidden states.
    pub n_states: usize,
    /// Number of distinct observation symbols; every observation must be
    /// strictly less than this value.
    pub n_obs: usize,
    /// Symmetric Dirichlet prior concentration for the initial-state
    /// distribution.
    pub alpha_prior: f64,
    /// Symmetric Dirichlet prior concentration for each row of the
    /// transition matrix.
    pub beta_prior: f64,
    /// Symmetric Dirichlet prior concentration for each row of the
    /// emission matrix.
    pub gamma_prior: f64,
    /// Upper bound on the number of variational iterations.
    pub max_iter: usize,
    /// Convergence threshold on the absolute ELBO change between two
    /// consecutive iterations.
    pub tol: f64,
}
impl Default for VbHmmConfig {
fn default() -> Self {
Self {
n_states: 2,
n_obs: 2,
alpha_prior: 1.0,
beta_prior: 1.0,
gamma_prior: 1.0,
max_iter: 200,
tol: 1e-6,
}
}
}
/// Posterior Dirichlet parameters and diagnostics returned by
/// [`variational_hmm`].
#[derive(Debug, Clone)]
pub struct VbHmmResult {
    /// Posterior Dirichlet parameters of the initial-state distribution,
    /// one entry per hidden state.
    pub alpha: Vec<f64>,
    /// Posterior Dirichlet parameters of the transition matrix, stored
    /// row-major (`beta[i * n + j]` for the transition `i -> j`).
    pub beta: Vec<f64>,
    /// Posterior Dirichlet parameters of the emission matrix, stored
    /// row-major (`gamma[j * k + sym]` for state `j` emitting `sym`).
    pub gamma: Vec<f64>,
    /// ELBO value recorded at the end of every iteration.
    pub elbo_history: Vec<f64>,
    /// Number of iterations that were executed.
    pub n_iter: usize,
    /// `true` when the ELBO change dropped below the configured tolerance
    /// before the iteration limit was reached.
    pub converged: bool,
}
impl VbHmmResult {
    /// Expected log initial-state probabilities under the posterior:
    /// `digamma(alpha_i) − digamma(Σ alpha)` for every state `i`.
    pub fn expected_log_pi(&self) -> Vec<f64> {
        let sum_alpha: f64 = self.alpha.iter().sum();
        let psi_sum = digamma(sum_alpha);
        self.alpha.iter().map(|&a| digamma(a) - psi_sum).collect()
    }

    /// Posterior mean of the initial-state distribution (`alpha` normalised
    /// to sum to one).
    ///
    /// When the parameters do not have a positive sum the uniform
    /// distribution is returned, matching [`Self::mean_a`] and
    /// [`Self::mean_b`] (previously this case produced NaN entries).
    pub fn mean_pi(&self) -> Vec<f64> {
        Self::dirichlet_row_means(&self.alpha, 1, self.alpha.len())
    }

    /// Posterior mean transition matrix, row-major with
    /// `alpha.len()` rows and columns.  Rows whose parameters do not have a
    /// positive sum become uniform.
    ///
    /// # Panics
    /// Panics if `beta` holds fewer than `alpha.len()²` values.
    pub fn mean_a(&self) -> Vec<f64> {
        let n = self.alpha.len();
        Self::dirichlet_row_means(&self.beta, n, n)
    }

    /// Posterior mean emission matrix, row-major with `alpha.len()` rows and
    /// `gamma.len() / alpha.len()` columns.  Rows whose parameters do not
    /// have a positive sum become uniform.
    ///
    /// Returns an empty vector when there are no states; the column count
    /// would otherwise require dividing by zero.
    pub fn mean_b(&self) -> Vec<f64> {
        let n = self.alpha.len();
        if n == 0 {
            return Vec::new();
        }
        let k = self.gamma.len() / n;
        Self::dirichlet_row_means(&self.gamma, n, k)
    }

    /// Normalises each of the `rows` consecutive length-`cols` slices of
    /// `params` to sum to one, substituting `1 / cols` for every entry of a
    /// row whose sum is not strictly positive.
    fn dirichlet_row_means(params: &[f64], rows: usize, cols: usize) -> Vec<f64> {
        let mut out = Vec::with_capacity(rows * cols);
        for r in 0..rows {
            let row = &params[r * cols..(r + 1) * cols];
            let total: f64 = row.iter().sum();
            if total > 0.0 {
                out.extend(row.iter().map(|&v| v / total));
            } else {
                out.extend(std::iter::repeat(1.0 / cols as f64).take(cols));
            }
        }
        out
    }
}
/// Numerically stable `ln Σ exp(x_i)`.
///
/// Returns `-inf` for an empty slice or when every finite-comparable entry
/// is `-inf`.  When the largest entry is `+inf` the result is `+inf`
/// (the shifted form would otherwise evaluate `inf − inf` and yield NaN);
/// a NaN among the inputs still propagates as NaN in that case.
#[inline]
fn logsumexp(xs: &[f64]) -> f64 {
    let peak = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if peak == f64::NEG_INFINITY {
        return f64::NEG_INFINITY;
    }
    if peak == f64::INFINITY {
        return if xs.iter().any(|v| v.is_nan()) {
            f64::NAN
        } else {
            f64::INFINITY
        };
    }
    // Shift by the maximum so the largest exponent is exactly zero.
    let shifted_sum: f64 = xs.iter().map(|&x| (x - peak).exp()).sum();
    peak + shifted_sum.ln()
}
/// Log-space forward–backward pass for a single sequence.
///
/// # Arguments
/// * `log_pi_eff` – initial-state log-scores, indexed `[j]` for `j < n`.
/// * `log_a_eff` – transition log-scores, row-major `[i * n + j]`.
/// * `log_em_eff` – per-step emission log-scores, `[t * n + j]`.
/// * `n` – number of hidden states.
/// * `t_max` – sequence length.
///
/// # Returns
/// `(gamma, xi, ll)` where
/// * `gamma[t * n + i]` is the state marginal at step `t` (each step is
///   renormalised to sum to one when its sum is positive),
/// * `xi[t * n * n + i * n + j]` is the pairwise marginal for the transition
///   `i -> j` between steps `t` and `t + 1` (renormalised the same way;
///   the vector has `(t_max − 1) · n²` entries),
/// * `ll` is the log-sum-exp of the final forward messages.
///
/// An empty sequence (`t_max == 0`) yields empty marginals and `ll = 0.0`
/// instead of panicking on the `t_max − 1` index arithmetic.
fn vb_forward_backward(
    log_pi_eff: &[f64],
    log_a_eff: &[f64],
    log_em_eff: &[f64],
    n: usize,
    t_max: usize,
) -> (Vec<f64>, Vec<f64>, f64) {
    if t_max == 0 {
        return (Vec::new(), Vec::new(), 0.0);
    }
    let last = t_max - 1;

    // Forward messages.
    let mut log_alpha = vec![f64::NEG_INFINITY; t_max * n];
    for j in 0..n {
        log_alpha[j] = log_pi_eff[j] + log_em_eff[j];
    }
    let mut tmp = vec![0.0f64; n];
    for t in 1..t_max {
        for j in 0..n {
            for i in 0..n {
                tmp[i] = log_alpha[(t - 1) * n + i] + log_a_eff[i * n + j];
            }
            log_alpha[t * n + j] = logsumexp(&tmp) + log_em_eff[t * n + j];
        }
    }
    let ll = logsumexp(&log_alpha[last * n..t_max * n]);

    // Backward messages; the final step is initialised to log(1) = 0.
    let mut log_beta = vec![f64::NEG_INFINITY; t_max * n];
    for i in 0..n {
        log_beta[last * n + i] = 0.0;
    }
    for t in (0..last).rev() {
        for i in 0..n {
            for j in 0..n {
                tmp[j] =
                    log_a_eff[i * n + j] + log_em_eff[(t + 1) * n + j] + log_beta[(t + 1) * n + j];
            }
            log_beta[t * n + i] = logsumexp(&tmp);
        }
    }

    // State marginals, renormalised per step to absorb rounding error.
    let mut gamma = vec![0.0f64; t_max * n];
    for t in 0..t_max {
        for i in 0..n {
            gamma[t * n + i] = (log_alpha[t * n + i] + log_beta[t * n + i] - ll).exp();
        }
        let s: f64 = gamma[t * n..t * n + n].iter().sum();
        if s > 0.0 {
            for i in 0..n {
                gamma[t * n + i] /= s;
            }
        }
    }

    // Pairwise marginals for consecutive steps, renormalised per step.
    let mut xi = vec![0.0f64; last * n * n];
    for t in 0..last {
        let mut s = 0.0;
        for i in 0..n {
            for j in 0..n {
                let v = (log_alpha[t * n + i]
                    + log_a_eff[i * n + j]
                    + log_em_eff[(t + 1) * n + j]
                    + log_beta[(t + 1) * n + j]
                    - ll)
                    .exp();
                xi[t * n * n + i * n + j] = v;
                s += v;
            }
        }
        if s > 0.0 {
            for v in xi[t * n * n..(t + 1) * n * n].iter_mut() {
                *v /= s;
            }
        }
    }
    (gamma, xi, ll)
}
/// Kullback–Leibler divergence `KL(Dir(alpha) ‖ Dir(alpha_0))`.
///
/// With `ln B(·)` the Dirichlet log-normaliser and ψ the digamma function:
///
/// `KL = ln B(alpha_0) − ln B(alpha) + Σ_i (alpha_i − alpha_0_i)·(ψ(alpha_i) − ψ(Σ alpha))`
///
/// The last sum is `E_q[ln q − ln p]` restricted to the `ln θ_i` terms, so
/// the coefficient is `alpha_i − alpha_0_i`; the reversed sign used before
/// produced values that were not a KL divergence (e.g. 1.193 instead of
/// 0.193 for `Dir(2, 1)` against `Dir(1, 1)`).
///
/// `alpha` and `alpha_0` are paired element-wise; extra entries in the
/// longer slice are ignored by the pairing.
fn kl_dirichlet(alpha: &[f64], alpha_0: &[f64]) -> f64 {
    let log_b_alpha_0 = dirichlet_log_normalizer(alpha_0);
    let log_b_alpha = dirichlet_log_normalizer(alpha);
    let sum_alpha: f64 = alpha.iter().sum();
    let psi_sum = digamma(sum_alpha);
    let expectation_term: f64 = alpha
        .iter()
        .zip(alpha_0.iter())
        .map(|(&ai, &a0i)| (ai - a0i) * (digamma(ai) - psi_sum))
        .sum();
    log_b_alpha_0 - log_b_alpha + expectation_term
}
/// Expected log-probabilities under row-wise Dirichlet posteriors.
///
/// `params` is split into consecutive rows of `row_len` entries and every
/// entry `p` of a row is mapped to `digamma(p) − digamma(row sum)`.
/// `row_len` must be non-zero.
fn expected_log_dirichlet_rows(params: &[f64], row_len: usize) -> Vec<f64> {
    let mut out = Vec::with_capacity(params.len());
    for row in params.chunks(row_len) {
        let psi_total = digamma(row.iter().sum::<f64>());
        out.extend(row.iter().map(|&p| digamma(p) - psi_total));
    }
    out
}

/// Builds the `[t * n + j]` table of emission log-scores for one sequence by
/// looking up `log_b[j * k + seq[t]]` for every step `t` and state `j`.
fn emission_log_scores(log_b: &[f64], seq: &[usize], n: usize, k: usize) -> Vec<f64> {
    let mut out = Vec::with_capacity(seq.len() * n);
    for &sym in seq {
        out.extend((0..n).map(|j| log_b[j * k + sym]));
    }
    out
}

/// Fits a discrete-observation HMM by variational Bayes with Dirichlet
/// posteriors over the initial, transition and emission distributions.
///
/// Every iteration
/// 1. converts the current Dirichlet parameters to expected log-probabilities,
/// 2. runs forward–backward on each non-empty sequence and accumulates
///    expected counts,
/// 3. sets each posterior parameter to `prior + expected count`,
/// 4. re-runs the forward pass with the updated parameters and records
///    `ELBO = Σ log-normaliser − Σ KL(posterior ‖ prior)`.
///
/// Iteration stops once the absolute ELBO change between two consecutive
/// iterations is below `cfg.tol`, or after `cfg.max_iter` iterations.
///
/// # Errors
/// * [`SeqError::EmptyInput`] – no sequences, or all sequences are empty.
/// * [`SeqError::InvalidConfiguration`] – `n_states` or `n_obs` is zero, or a
///   prior concentration is not a finite, strictly positive number (such
///   priors do not define a Dirichlet distribution and make the ELBO NaN).
/// * [`SeqError::InvalidObservation`] – an observation is `>= cfg.n_obs`.
pub fn variational_hmm(observations: &[&[usize]], cfg: &VbHmmConfig) -> SeqResult<VbHmmResult> {
    if observations.is_empty() || observations.iter().all(|s| s.is_empty()) {
        return Err(SeqError::EmptyInput);
    }
    if cfg.n_states == 0 || cfg.n_obs == 0 {
        return Err(SeqError::InvalidConfiguration(
            "n_states and n_obs must be > 0".to_string(),
        ));
    }
    for (name, value) in [
        ("alpha_prior", cfg.alpha_prior),
        ("beta_prior", cfg.beta_prior),
        ("gamma_prior", cfg.gamma_prior),
    ] {
        if !(value.is_finite() && value > 0.0) {
            return Err(SeqError::InvalidConfiguration(format!(
                "{name} must be finite and > 0, got {value}"
            )));
        }
    }
    if let Some(&o) = observations
        .iter()
        .flat_map(|seq| seq.iter())
        .find(|&&o| o >= cfg.n_obs)
    {
        return Err(SeqError::InvalidObservation(format!(
            "observation {o} >= n_obs {}",
            cfg.n_obs
        )));
    }

    let n = cfg.n_states;
    let k = cfg.n_obs;

    // Slightly asymmetric starting point so the states can differentiate.
    let mut alpha: Vec<f64> = (0..n)
        .map(|i| cfg.alpha_prior + (i as f64 + 1.0) * 0.1 / n as f64)
        .collect();
    let mut beta: Vec<f64> = vec![0.0; n * n];
    for i in 0..n {
        for j in 0..n {
            beta[i * n + j] = if i == j {
                cfg.beta_prior + 0.5
            } else if n > 1 {
                cfg.beta_prior + 0.1 / (n as f64 - 1.0)
            } else {
                cfg.beta_prior
            };
        }
    }
    let mut gamma_dir: Vec<f64> = vec![cfg.gamma_prior + 0.1; n * k];

    // The priors never change, so build their parameter vectors once.
    let alpha_prior_vec = vec![cfg.alpha_prior; n];
    let beta_prior_vec = vec![cfg.beta_prior; n];
    let gamma_prior_vec = vec![cfg.gamma_prior; k];

    // Cap the up-front reservation: `max_iter` may be far larger than the
    // number of iterations actually needed.
    let mut elbo_history: Vec<f64> = Vec::with_capacity(cfg.max_iter.min(1024));
    let mut prev_elbo = f64::NEG_INFINITY;
    let mut converged = false;
    let mut n_iter = 0usize;

    for iter in 0..cfg.max_iter {
        n_iter = iter + 1;

        // Expected log-parameters under the current posteriors.
        let log_pi_eff = expected_log_dirichlet_rows(&alpha, n);
        let log_a_eff = expected_log_dirichlet_rows(&beta, n);
        let log_b_eff = expected_log_dirichlet_rows(&gamma_dir, k);

        // Accumulate expected counts over all non-empty sequences.
        let mut ss_pi = vec![0.0f64; n];
        let mut ss_a = vec![0.0f64; n * n];
        let mut ss_b = vec![0.0f64; n * k];
        for seq in observations.iter().copied().filter(|s| !s.is_empty()) {
            let log_em_eff = emission_log_scores(&log_b_eff, seq, n, k);
            let (gamma_seq, xi_seq, _ll_seq) =
                vb_forward_backward(&log_pi_eff, &log_a_eff, &log_em_eff, n, seq.len());

            // Initial-state counts come from the first step's marginals.
            for (acc, &g) in ss_pi.iter_mut().zip(&gamma_seq) {
                *acc += g;
            }
            // Transition counts: one n×n slab of pairwise marginals per step.
            for step in xi_seq.chunks(n * n) {
                for (acc, &x) in ss_a.iter_mut().zip(step) {
                    *acc += x;
                }
            }
            // Emission counts: credit each state with the observed symbol.
            for (t, &sym) in seq.iter().enumerate() {
                for j in 0..n {
                    ss_b[j * k + sym] += gamma_seq[t * n + j];
                }
            }
        }

        // Posterior update: prior concentration plus expected counts.
        for (post, &count) in alpha.iter_mut().zip(&ss_pi) {
            *post = cfg.alpha_prior + count;
        }
        for (post, &count) in beta.iter_mut().zip(&ss_a) {
            *post = cfg.beta_prior + count;
        }
        for (post, &count) in gamma_dir.iter_mut().zip(&ss_b) {
            *post = cfg.gamma_prior + count;
        }

        // ELBO under the updated posteriors.
        let log_pi_new = expected_log_dirichlet_rows(&alpha, n);
        let log_a_new = expected_log_dirichlet_rows(&beta, n);
        let log_b_new = expected_log_dirichlet_rows(&gamma_dir, k);
        let mut elbo_ll = 0.0f64;
        for seq in observations.iter().copied().filter(|s| !s.is_empty()) {
            let log_em_new = emission_log_scores(&log_b_new, seq, n, k);
            let (_, _, ll_new) =
                vb_forward_backward(&log_pi_new, &log_a_new, &log_em_new, n, seq.len());
            elbo_ll += ll_new;
        }
        let mut kl_total = kl_dirichlet(&alpha, &alpha_prior_vec);
        for row in beta.chunks(n) {
            kl_total += kl_dirichlet(row, &beta_prior_vec);
        }
        for row in gamma_dir.chunks(k) {
            kl_total += kl_dirichlet(row, &gamma_prior_vec);
        }
        let elbo = elbo_ll - kl_total;
        elbo_history.push(elbo);

        // The first iteration has no predecessor to compare against.
        if iter > 0 && (elbo - prev_elbo).abs() < cfg.tol {
            converged = true;
            break;
        }
        prev_elbo = elbo;
    }

    Ok(VbHmmResult {
        alpha,
        beta,
        gamma: gamma_dir,
        elbo_history,
        n_iter,
        converged,
    })
}
// Unit tests for the special functions (digamma, log_gamma) and for the
// variational HMM entry point.
#[cfg(test)]
mod tests {
use super::*;
// digamma(1) must equal the negated Euler–Mascheroni constant to 1e-6.
#[test]
fn digamma_at_one_is_neg_euler_mascheroni() {
let d = digamma(1.0);
assert!((d - (-0.577_215_664_9)).abs() < 1e-6, "digamma(1) = {d}");
}
// digamma(2) is compared against the reference value 0.4227843351 to 1e-6.
#[test]
fn digamma_at_two() {
let d = digamma(2.0);
assert!((d - 0.422_784_335_1).abs() < 1e-6, "digamma(2) = {d}");
}
// Checks digamma(x + 1) == digamma(x) + 1/x to 1e-9 at several points,
// including x = 7.0 where both sides use the asymptotic series directly.
#[test]
fn digamma_recurrence() {
for &x in &[0.5, 1.0, 2.0, 3.5, 7.0] {
let lhs = digamma(x + 1.0);
let rhs = digamma(x) + 1.0 / x;
assert!(
(lhs - rhs).abs() < 1e-9,
"recurrence failed at x={x}: {lhs} vs {rhs}"
);
}
}
// For a large argument digamma(x) is compared with ln(x) - 1/(2x), here
// ln(100) - 0.005, within a loose tolerance of 0.01.
#[test]
fn digamma_large_argument() {
let d = digamma(100.0);
let approx = 100.0_f64.ln() - 0.005;
assert!((d - approx).abs() < 0.01, "digamma(100) = {d}");
}
// log_gamma(1) must be 0 (to 1e-10).
#[test]
fn log_gamma_at_one() {
assert!(log_gamma(1.0).abs() < 1e-10);
}
// log_gamma(2) must be 0 (to 1e-10).
#[test]
fn log_gamma_at_two() {
assert!(log_gamma(2.0).abs() < 1e-10);
}
// log_gamma(0.5) is compared with 0.5 * ln(pi), i.e. ln(sqrt(pi)).
#[test]
fn log_gamma_at_half() {
let expected = 0.5 * std::f64::consts::PI.ln();
let got = log_gamma(0.5);
assert!((got - expected).abs() < 1e-9, "log_gamma(0.5) = {got}");
}
// log_gamma(4) is compared with ln(6) = ln(3!).
#[test]
fn log_gamma_integer_values() {
let got = log_gamma(4.0);
let expected = 6.0_f64.ln();
assert!((got - expected).abs() < 1e-9, "log_gamma(4) = {got}");
}
// log_gamma(5) is compared with ln(24) = ln(4!).
#[test]
fn log_gamma_five() {
let got = log_gamma(5.0);
let expected = 24.0_f64.ln();
assert!((got - expected).abs() < 1e-9, "log_gamma(5) = {got}");
}
// Shared fixture: a short binary sequence with runs of repeated symbols.
fn simple_obs() -> Vec<usize> {
vec![0, 0, 1, 1, 0, 0, 1, 1, 0, 1]
}
// Smoke test: the default configuration runs, reports between 1 and
// max_iter iterations, and records at least one ELBO value.
#[test]
fn default_config_produces_valid_result() {
let obs = simple_obs();
let cfg = VbHmmConfig::default();
let r = variational_hmm(&[obs.as_slice()], &cfg).expect("should succeed");
assert!(r.n_iter > 0);
assert!(r.n_iter <= cfg.max_iter);
assert!(!r.elbo_history.is_empty());
}
// The posterior-mean initial distribution must sum to one.
#[test]
fn mean_pi_sums_to_one() {
let obs = simple_obs();
let cfg = VbHmmConfig::default();
let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
let s: f64 = r.mean_pi().iter().sum();
assert!((s - 1.0).abs() < 1e-10, "mean_pi sum = {s}");
}
// Every row of the posterior-mean transition matrix must sum to one.
#[test]
fn mean_a_rows_sum_to_one() {
let obs = simple_obs();
let cfg = VbHmmConfig::default();
let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
let n = cfg.n_states;
let a = r.mean_a();
for i in 0..n {
let s: f64 = a[i * n..(i + 1) * n].iter().sum();
assert!((s - 1.0).abs() < 1e-10, "mean_a row {i} sums to {s}");
}
}
// Every row of the posterior-mean emission matrix must sum to one.
#[test]
fn mean_b_rows_sum_to_one() {
let obs = simple_obs();
let cfg = VbHmmConfig::default();
let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
let n = cfg.n_states;
let k = cfg.n_obs;
let b = r.mean_b();
for j in 0..n {
let s: f64 = b[j * k..(j + 1) * k].iter().sum();
assert!((s - 1.0).abs() < 1e-10, "mean_b row {j} sums to {s}");
}
}
// Loose monotonicity check on the ELBO trace: the final value may be at
// most 2 nats below the first, and no single step may drop by more than
// 1 nat.
#[test]
fn elbo_history_non_decreasing() {
let obs = simple_obs();
let cfg = VbHmmConfig::default();
let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
if r.elbo_history.len() >= 2 {
let first = r.elbo_history[0];
let last = *r.elbo_history.last().expect("non-empty");
assert!(
last >= first - 2.0,
"Final ELBO ({last}) is much worse than initial ({first})"
);
}
for w in r.elbo_history.windows(2) {
assert!(
w[1] >= w[0] - 1.0,
"ELBO dropped by more than 1 nat: {} → {}",
w[0],
w[1]
);
}
}
// The reported iteration count must respect a custom max_iter.
#[test]
fn n_iter_within_max_iter() {
let obs = simple_obs();
let cfg = VbHmmConfig {
max_iter: 50,
..Default::default()
};
let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
assert!(r.n_iter <= 50);
}
// With data present every posterior alpha entry must exceed the prior
// concentration, since expected counts are added to the prior.
#[test]
fn posteriors_exceed_prior_when_data_given() {
let obs: Vec<usize> = (0..30).map(|i| i % 2).collect();
let cfg = VbHmmConfig {
alpha_prior: 1.0,
beta_prior: 1.0,
gamma_prior: 1.0,
..Default::default()
};
let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
for &a in &r.alpha {
assert!(
a > cfg.alpha_prior,
"alpha {a} not > prior {}",
cfg.alpha_prior
);
}
}
// Several sequences of different lengths can be fitted together.
#[test]
fn multiple_sequences_accepted() {
let seq1 = vec![0usize, 1, 0, 1];
let seq2 = vec![1usize, 1, 0, 0];
let seq3 = vec![0usize, 0, 1, 1, 0];
let cfg = VbHmmConfig::default();
let r = variational_hmm(&[&seq1, &seq2, &seq3], &cfg).expect("ok");
assert!(!r.elbo_history.is_empty());
}
// An empty list of sequences is rejected with an error.
#[test]
fn empty_observations_returns_err() {
let cfg = VbHmmConfig::default();
assert!(variational_hmm(&[], &cfg).is_err());
}
// Symbol 5 is outside the default alphabet (n_obs = 2) and must be rejected.
#[test]
fn obs_out_of_range_returns_err() {
let obs = vec![0usize, 5]; let cfg = VbHmmConfig::default();
assert!(variational_hmm(&[&obs], &cfg).is_err());
}
// A single sequence of length one (no transitions) must still be handled.
#[test]
fn single_observation_length_one_works() {
let obs = vec![0usize];
let cfg = VbHmmConfig::default();
let r = variational_hmm(&[&obs], &cfg).expect("length-1 seq should work");
assert!(!r.elbo_history.is_empty());
}
// With a loose tolerance (1e-3) and a generous iteration budget the
// converged flag is expected to be set on an alternating sequence.
#[test]
fn converged_flag_set_on_tight_convergence() {
let obs: Vec<usize> = (0..50).map(|i| i % 2).collect();
let cfg = VbHmmConfig {
max_iter: 500,
tol: 1e-3,
..Default::default()
};
let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
assert!(
r.converged,
"expected convergence with tol=1e-3 and 500 iterations"
);
}
// Result vectors must have the sizes implied by 4 states and 4 symbols:
// alpha has n entries, beta n*n, gamma n*k.
#[test]
fn larger_state_space() {
let obs: Vec<usize> = (0..40).map(|i| i % 4).collect();
let cfg = VbHmmConfig {
n_states: 4,
n_obs: 4,
max_iter: 100,
..Default::default()
};
let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
assert_eq!(r.alpha.len(), 4);
assert_eq!(r.beta.len(), 16);
assert_eq!(r.gamma.len(), 16);
}
// expected_log_pi returns one finite, non-positive value per state.
#[test]
fn expected_log_pi_returns_correct_length() {
let obs = simple_obs();
let cfg = VbHmmConfig::default();
let r = variational_hmm(&[obs.as_slice()], &cfg).expect("ok");
let elp = r.expected_log_pi();
assert_eq!(elp.len(), cfg.n_states);
for &v in &elp {
assert!(v.is_finite(), "expected_log_pi entry is not finite: {v}");
assert!(v <= 0.0, "expected_log_pi entry should be ≤ 0: {v}");
}
}
}