use crate::error::{StatsError, StatsResult};
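/// Conjugate exponential-family models supported by online variational Bayes.
///
/// The enum is `#[non_exhaustive]` so that new conjugate pairs can be added
/// without breaking downstream matches; the internal dispatch functions keep
/// defensive wildcard arms for the same reason.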
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum ConjugateModel {
BetaBernoulli,
GammaPoisson,
NormalNormal,
DirichletMultinomial,
}
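/// Configuration for the online (stochastic) variational Bayes estimator.
///
/// `learning_rate`, `delay` (tau), and `forgetting_factor` (kappa) define the
/// Robbins-Monro step-size schedule `rho_t = learning_rate * (t + tau)^(-kappa)`;
/// kappa must lie in (0.5, 1.0] for the usual stochastic-approximation
/// convergence conditions to hold. `mini_batch_size` is advisory: `update`
/// accepts any non-empty batch.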
#[derive(Debug, Clone)]
pub struct OnlineVbConfig {
pub model: ConjugateModel,
pub learning_rate: f64,
pub forgetting_factor: f64,
pub delay: f64,
pub mini_batch_size: usize,
}
impl Default for OnlineVbConfig {
fn default() -> Self {
Self {
model: ConjugateModel::BetaBernoulli,
learning_rate: 1.0,
forgetting_factor: 0.7,
delay: 1.0,
mini_batch_size: 32,
}
}
}
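/// Mutable estimator state: the variational natural parameters, the number of
/// observations consumed so far, and an exponentially smoothed ELBO estimate.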
#[derive(Debug, Clone)]
pub struct OnlineVbState {
pub natural_params: Vec<f64>,
pub n_processed: usize,
pub elbo_estimate: f64,
}
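/// Online variational Bayes estimator for conjugate models, updated one
/// mini-batch at a time in natural-parameter space.
///
/// A minimal usage sketch (illustrative only, not a doctest):
///
/// ```ignore
/// let config = OnlineVbConfig::default(); // Beta-Bernoulli
/// // Uniform Beta(1, 1) prior; expect roughly 1000 observations in total.
/// let mut est = OnlineVbEstimator::new(config, &[1.0, 1.0], 1000)?;
/// est.update(&[1.0, 0.0, 1.0])?; // one mini-batch of 0/1 outcomes
/// let p = est.predict(); // posterior mean success probability
/// ```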
pub struct OnlineVbEstimator {
config: OnlineVbConfig,
state: OnlineVbState,
prior_natural: Vec<f64>,
total_n: usize,
step: usize,
}
impl OnlineVbEstimator {
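    /// Creates an estimator from a conjugate prior given in moment
    /// parameterization (see `moment_to_natural` for the layout expected by
    /// each model). `total_n` is the anticipated total stream size, used to
    /// scale mini-batch sufficient statistics up to the full data set.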
pub fn new(config: OnlineVbConfig, prior_params: &[f64], total_n: usize) -> StatsResult<Self> {
if total_n == 0 {
return Err(StatsError::InvalidArgument(
"total_n must be > 0".to_string(),
));
}
if config.forgetting_factor <= 0.5 || config.forgetting_factor > 1.0 {
return Err(StatsError::InvalidArgument(
"forgetting_factor (kappa) must be in (0.5, 1.0]".to_string(),
));
}
if config.delay < 0.0 {
return Err(StatsError::InvalidArgument(
"delay (tau) must be >= 0".to_string(),
));
}
let prior_natural = moment_to_natural(&config.model, prior_params)?;
let natural_params = prior_natural.clone();
let state = OnlineVbState {
natural_params,
n_processed: 0,
elbo_estimate: f64::NEG_INFINITY,
};
Ok(Self {
config,
state,
prior_natural,
total_n,
step: 0,
})
}
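    /// Runs one stochastic variational inference step on a mini-batch and
    /// returns the mini-batch ELBO estimate.
    ///
    /// Following the SVI recipe of Hoffman et al., the batch statistics are
    /// scaled as if they covered the whole stream,
    /// `lambda_tilde = prior + (total_n / |batch|) * ESS`, and the state is
    /// blended as `lambda <- (1 - rho_t) * lambda + rho_t * lambda_tilde`.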
pub fn update(&mut self, data: &[f64]) -> StatsResult<f64> {
if data.is_empty() {
return Err(StatsError::InsufficientData(
"mini-batch must not be empty".to_string(),
));
}
self.step += 1;
let rho = self.learning_rate_at(self.step);
let scale = self.total_n as f64 / data.len() as f64;
let ess = expected_sufficient_stats(&self.config.model, &self.state.natural_params, data)?;
let n_variational = match self.config.model {
ConjugateModel::NormalNormal => 2,
_ => self.prior_natural.len(),
};
        let mut lambda_tilde = self.prior_natural.clone();
        for i in 0..n_variational {
            let ess_i = *ess.get(i).unwrap_or(&0.0);
            lambda_tilde[i] = self.prior_natural[i] + scale * ess_i;
        }
for i in 0..n_variational {
self.state.natural_params[i] =
(1.0 - rho) * self.state.natural_params[i] + rho * lambda_tilde[i];
}
for i in n_variational..self.prior_natural.len() {
self.state.natural_params[i] = self.prior_natural[i];
}
self.state.n_processed += data.len();
let elbo = estimate_elbo(
&self.config.model,
&self.state.natural_params,
&self.prior_natural,
data,
scale,
)?;
if self.state.elbo_estimate.is_finite() {
self.state.elbo_estimate = 0.9 * self.state.elbo_estimate + 0.1 * elbo;
} else {
self.state.elbo_estimate = elbo;
}
Ok(elbo)
}
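    /// Returns the current posterior in moment parameterization,
    /// e.g. `[alpha, beta]` for Beta-Bernoulli or `[mu, sigma^2]` for
    /// Normal-Normal.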
pub fn posterior_params(&self) -> Vec<f64> {
natural_to_moment(&self.config.model, &self.state.natural_params)
}
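    /// Posterior-mean point prediction: the success probability
    /// (Beta-Bernoulli), the rate (Gamma-Poisson), the mean (Normal-Normal),
    /// or the probability of category 0 (Dirichlet-Multinomial).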
pub fn predict(&self) -> f64 {
let mp = self.posterior_params();
match self.config.model {
ConjugateModel::BetaBernoulli => {
let alpha = mp.first().copied().unwrap_or(1.0);
let beta = mp.get(1).copied().unwrap_or(1.0);
let denom = alpha + beta;
if denom > f64::EPSILON {
alpha / denom
} else {
0.5
}
}
ConjugateModel::GammaPoisson => {
let a = mp.first().copied().unwrap_or(1.0);
let b = mp.get(1).copied().unwrap_or(1.0);
if b > f64::EPSILON {
a / b
} else {
0.0
}
}
ConjugateModel::NormalNormal => mp.first().copied().unwrap_or(0.0),
ConjugateModel::DirichletMultinomial => {
if mp.is_empty() {
return 0.0;
}
let sum: f64 = mp.iter().sum();
if sum > f64::EPSILON {
mp[0] / sum
} else {
1.0 / mp.len() as f64
}
}
_ => 0.0,
}
}
pub fn state(&self) -> &OnlineVbState {
&self.state
}
pub fn config(&self) -> &OnlineVbConfig {
&self.config
}
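    /// Robbins-Monro step size `rho_t = lr * (t + tau)^(-kappa)`, clamped to
    /// guard against degenerate steps.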
fn learning_rate_at(&self, t: usize) -> f64 {
let t_f = t as f64;
let tau = self.config.delay;
let kappa = self.config.forgetting_factor;
let rho = self.config.learning_rate * (t_f + tau).powf(-kappa);
rho.clamp(1e-10, 1.0)
}
}
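/// Converts moment-space prior parameters to exponential-family natural
/// parameters. Layouts: Beta `[alpha, beta] -> [alpha - 1, beta - 1]`;
/// Gamma `[a, b] -> [a - 1, -b]`; Normal `[mu0, sigma0^2, sigma^2] ->
/// [mu0 / sigma0^2, -1 / (2 * sigma0^2), sigma^2]`, where the known
/// observation variance rides along in the last slot; Dirichlet
/// `alpha -> alpha - 1` componentwise.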
fn moment_to_natural(model: &ConjugateModel, params: &[f64]) -> StatsResult<Vec<f64>> {
match model {
ConjugateModel::BetaBernoulli => {
if params.len() < 2 {
return Err(StatsError::InvalidArgument(
"BetaBernoulli prior requires [alpha, beta]".to_string(),
));
}
let alpha = params[0];
let beta = params[1];
if alpha <= 0.0 || beta <= 0.0 {
return Err(StatsError::InvalidArgument(
"BetaBernoulli prior: alpha, beta must be > 0".to_string(),
));
}
Ok(vec![alpha - 1.0, beta - 1.0])
}
ConjugateModel::GammaPoisson => {
if params.len() < 2 {
return Err(StatsError::InvalidArgument(
"GammaPoisson prior requires [a, b]".to_string(),
));
}
let a = params[0];
let b = params[1];
if a <= 0.0 || b <= 0.0 {
return Err(StatsError::InvalidArgument(
"GammaPoisson prior: a, b must be > 0".to_string(),
));
}
Ok(vec![a - 1.0, -b])
}
ConjugateModel::NormalNormal => {
if params.len() < 3 {
return Err(StatsError::InvalidArgument(
"NormalNormal prior requires [mu0, sigma0^2, sigma^2]".to_string(),
));
}
let mu0 = params[0];
let sigma0_sq = params[1];
let sigma_sq = params[2];
if sigma0_sq <= 0.0 || sigma_sq <= 0.0 {
return Err(StatsError::InvalidArgument(
"NormalNormal prior: sigma0^2 and sigma^2 must be > 0".to_string(),
));
}
Ok(vec![mu0 / sigma0_sq, -1.0 / (2.0 * sigma0_sq), sigma_sq])
}
ConjugateModel::DirichletMultinomial => {
if params.is_empty() {
return Err(StatsError::InvalidArgument(
"DirichletMultinomial prior requires at least one alpha".to_string(),
));
}
for (k, &a) in params.iter().enumerate() {
if a <= 0.0 {
return Err(StatsError::InvalidArgument(format!(
"DirichletMultinomial prior: alpha[{}] must be > 0",
k
)));
}
}
Ok(params.iter().map(|&a| a - 1.0).collect())
}
_ => Err(StatsError::NotImplementedError(
"Unsupported conjugate model variant".to_string(),
)),
}
}
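/// Inverse of `moment_to_natural`, with small floors so the recovered moment
/// parameters stay inside their valid domains.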
fn natural_to_moment(model: &ConjugateModel, eta: &[f64]) -> Vec<f64> {
match model {
ConjugateModel::BetaBernoulli => {
let alpha = (eta.first().copied().unwrap_or(0.0) + 1.0).max(1e-10);
let beta = (eta.get(1).copied().unwrap_or(0.0) + 1.0).max(1e-10);
vec![alpha, beta]
}
ConjugateModel::GammaPoisson => {
let a = (eta.first().copied().unwrap_or(0.0) + 1.0).max(1e-10);
let b = (-eta.get(1).copied().unwrap_or(-1.0)).max(1e-10);
vec![a, b]
}
ConjugateModel::NormalNormal => {
let eta1 = eta.first().copied().unwrap_or(0.0);
let eta2 = eta.get(1).copied().unwrap_or(-0.5);
let sigma_sq = if eta2 < -f64::EPSILON {
-1.0 / (2.0 * eta2)
} else {
1.0
};
let mu = eta1 * sigma_sq;
vec![mu, sigma_sq]
}
ConjugateModel::DirichletMultinomial => {
eta.iter().map(|&e| (e + 1.0).max(1e-10)).collect()
}
_ => vec![],
}
}
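/// Mini-batch sufficient statistics, expressed as additive contributions to
/// the variational natural parameters.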
fn expected_sufficient_stats(
model: &ConjugateModel,
eta: &[f64],
data: &[f64],
) -> StatsResult<Vec<f64>> {
let n = data.len() as f64;
match model {
ConjugateModel::BetaBernoulli => {
let sum_x: f64 = data.iter().sum();
Ok(vec![sum_x, n - sum_x])
}
ConjugateModel::GammaPoisson => {
for &x in data {
if x < 0.0 {
return Err(StatsError::DomainError(
"GammaPoisson: observations must be non-negative".to_string(),
));
}
}
let sum_x: f64 = data.iter().sum();
Ok(vec![sum_x, n])
}
ConjugateModel::NormalNormal => {
let sigma_sq = if eta.len() >= 3 {
eta[2].max(f64::EPSILON)
} else {
1.0
};
let sum_x: f64 = data.iter().sum();
Ok(vec![sum_x / sigma_sq, -n / (2.0 * sigma_sq)])
}
ConjugateModel::DirichletMultinomial => {
let k = eta.len();
if k == 0 {
return Err(StatsError::InvalidArgument(
"DirichletMultinomial: variational params have length 0".to_string(),
));
}
let mut counts = vec![0.0_f64; k];
for &x in data {
let idx = x as usize;
if idx >= k {
return Err(StatsError::DomainError(format!(
"DirichletMultinomial: category index {} >= K={}",
idx, k
)));
}
counts[idx] += 1.0;
}
Ok(counts)
}
_ => Err(StatsError::NotImplementedError(
"Unsupported model variant".to_string(),
)),
}
}
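/// Mini-batch ELBO estimate: `scale * E_q[log p(batch | theta)] - KL(q || prior)`.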
fn estimate_elbo(
model: &ConjugateModel,
eta: &[f64],
prior_eta: &[f64],
data: &[f64],
scale: f64,
) -> StatsResult<f64> {
let expected_log_likelihood = compute_expected_log_likelihood(model, eta, data)?;
let kl = compute_kl(model, eta, prior_eta)?;
Ok(scale * expected_log_likelihood - kl)
}
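/// Expected log-likelihood of the mini-batch under the variational posterior,
/// using identities such as `E[ln theta] = psi(alpha) - psi(alpha + beta)`.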
fn compute_expected_log_likelihood(
model: &ConjugateModel,
eta: &[f64],
data: &[f64],
) -> StatsResult<f64> {
match model {
ConjugateModel::BetaBernoulli => {
let alpha = (eta.first().copied().unwrap_or(0.0) + 1.0).max(1e-10);
let beta = (eta.get(1).copied().unwrap_or(0.0) + 1.0).max(1e-10);
let e_log_theta = digamma(alpha) - digamma(alpha + beta);
let e_log_1m_theta = digamma(beta) - digamma(alpha + beta);
let sum_x: f64 = data.iter().sum();
let n = data.len() as f64;
Ok(sum_x * e_log_theta + (n - sum_x) * e_log_1m_theta)
}
ConjugateModel::GammaPoisson => {
let a = (eta.first().copied().unwrap_or(0.0) + 1.0).max(1e-10);
let b = (-eta.get(1).copied().unwrap_or(-1.0)).max(1e-10);
let e_log_lambda = digamma(a) - b.ln();
let e_lambda = a / b;
let n = data.len() as f64;
let sum_x: f64 = data.iter().sum();
let sum_log_fact: f64 = data.iter().map(|&x| log_factorial(x as u64)).sum();
Ok(sum_x * e_log_lambda - n * e_lambda - sum_log_fact)
}
ConjugateModel::NormalNormal => {
let sigma_sq = if eta.len() >= 3 {
eta[2].max(f64::EPSILON)
} else {
1.0
};
let eta1 = eta.first().copied().unwrap_or(0.0);
let eta2 = eta.get(1).copied().unwrap_or(-0.5);
let sigma_q_sq = if eta2 < -f64::EPSILON {
-1.0 / (2.0 * eta2)
} else {
1.0
};
let mu = eta1 * sigma_q_sq;
let n = data.len() as f64;
let sum_x: f64 = data.iter().sum();
let sum_x_sq: f64 = data.iter().map(|&x| x * x).sum();
let e_theta_sq = mu * mu + sigma_q_sq;
let ll = -0.5 * n * (2.0 * std::f64::consts::PI * sigma_sq).ln()
- 1.0 / (2.0 * sigma_sq) * (sum_x_sq - 2.0 * sum_x * mu + n * e_theta_sq);
Ok(ll)
}
ConjugateModel::DirichletMultinomial => {
let alpha: Vec<f64> = eta.iter().map(|&e| (e + 1.0).max(1e-10)).collect();
let sum_alpha: f64 = alpha.iter().sum();
let k = alpha.len();
let mut counts = vec![0.0_f64; k];
for &x in data {
let idx = x as usize;
if idx < k {
counts[idx] += 1.0;
}
}
let ll: f64 = alpha
.iter()
.enumerate()
.map(|(i, &a)| counts[i] * (digamma(a) - digamma(sum_alpha)))
.sum();
Ok(ll)
}
_ => Ok(0.0),
}
}
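/// KL divergence from the variational posterior to the prior, dispatched to
/// the closed forms below.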
fn compute_kl(model: &ConjugateModel, eta: &[f64], prior_eta: &[f64]) -> StatsResult<f64> {
match model {
ConjugateModel::BetaBernoulli => {
let alpha_q = (eta.first().copied().unwrap_or(0.0) + 1.0).max(1e-10);
let beta_q = (eta.get(1).copied().unwrap_or(0.0) + 1.0).max(1e-10);
let alpha_0 = (prior_eta.first().copied().unwrap_or(0.0) + 1.0).max(1e-10);
let beta_0 = (prior_eta.get(1).copied().unwrap_or(0.0) + 1.0).max(1e-10);
Ok(kl_beta(alpha_q, beta_q, alpha_0, beta_0))
}
ConjugateModel::GammaPoisson => {
let a_q = (eta.first().copied().unwrap_or(0.0) + 1.0).max(1e-10);
let b_q = (-eta.get(1).copied().unwrap_or(-1.0)).max(1e-10);
let a_0 = (prior_eta.first().copied().unwrap_or(0.0) + 1.0).max(1e-10);
let b_0 = (-prior_eta.get(1).copied().unwrap_or(-1.0)).max(1e-10);
Ok(kl_gamma(a_q, b_q, a_0, b_0))
}
ConjugateModel::NormalNormal => {
let eta1_q = eta.first().copied().unwrap_or(0.0);
let eta2_q = eta.get(1).copied().unwrap_or(-0.5);
let sigma_q_sq = if eta2_q < -f64::EPSILON {
-1.0 / (2.0 * eta2_q)
} else {
1.0
};
let mu_q = eta1_q * sigma_q_sq;
let eta1_0 = prior_eta.first().copied().unwrap_or(0.0);
let eta2_0 = prior_eta.get(1).copied().unwrap_or(-0.5);
let sigma_0_sq = if eta2_0 < -f64::EPSILON {
-1.0 / (2.0 * eta2_0)
} else {
1.0
};
let mu_0 = eta1_0 * sigma_0_sq;
Ok(kl_normal(mu_q, sigma_q_sq, mu_0, sigma_0_sq))
}
ConjugateModel::DirichletMultinomial => {
let alpha_q: Vec<f64> = eta.iter().map(|&e| (e + 1.0).max(1e-10)).collect();
let alpha_0: Vec<f64> = prior_eta.iter().map(|&e| (e + 1.0).max(1e-10)).collect();
Ok(kl_dirichlet(&alpha_q, &alpha_0))
}
_ => Ok(0.0),
}
}
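/// Closed-form KL(Beta(a1, b1) || Beta(a2, b2)).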
fn kl_beta(a1: f64, b1: f64, a2: f64, b2: f64) -> f64 {
lgamma(a1 + b1) - lgamma(a1) - lgamma(b1) - lgamma(a2 + b2)
+ lgamma(a2)
+ lgamma(b2)
+ (a1 - a2) * digamma(a1)
+ (b1 - b2) * digamma(b1)
+ (a2 - a1 + b2 - b1) * digamma(a1 + b1)
}
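/// Closed-form KL(Gamma(a1, b1) || Gamma(a2, b2)) in shape/rate form:
/// `(a1 - a2) psi(a1) - ln G(a1) + ln G(a2) + a2 ln(b1 / b2) + a1 (b2 - b1) / b1`.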
fn kl_gamma(a1: f64, b1: f64, a2: f64, b2: f64) -> f64 {
    (a1 - a2) * digamma(a1) - lgamma(a1) + lgamma(a2)
        + a2 * (b1 / b2).ln()
        + a1 * (b2 - b1) / b1
}
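/// Closed-form KL(N(mu1, s1) || N(mu2, s2)), where `s1`, `s2` are variances.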
fn kl_normal(mu1: f64, s1: f64, mu2: f64, s2: f64) -> f64 {
    let s1 = s1.max(f64::EPSILON);
    let s2 = s2.max(f64::EPSILON);
    0.5 * ((s2 / s1).ln() + s1 / s2 + (mu1 - mu2).powi(2) / s2 - 1.0)
}
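/// Closed-form KL(Dir(alpha_q) || Dir(alpha_0)), clamped at zero to absorb
/// floating-point error (the true KL is non-negative).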
fn kl_dirichlet(alpha_q: &[f64], alpha_0: &[f64]) -> f64 {
let sum_q: f64 = alpha_q.iter().sum();
let sum_0: f64 = alpha_0.iter().sum();
let mut kl = lgamma(sum_q) - lgamma(sum_0);
for i in 0..alpha_q.len().min(alpha_0.len()) {
kl += lgamma(alpha_0[i]) - lgamma(alpha_q[i]);
kl += (alpha_q[i] - alpha_0[i]) * (digamma(alpha_q[i]) - digamma(sum_q));
}
    kl.max(0.0)
}
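/// Log-gamma for x > 0 (returns +inf otherwise).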
fn lgamma(x: f64) -> f64 {
if x <= 0.0 {
return f64::INFINITY;
}
lanczos_lgamma(x)
}
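/// Lanczos approximation to ln Gamma(x) with g = 7 and nine coefficients,
/// using the reflection formula for x < 0.5.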
fn lanczos_lgamma(x: f64) -> f64 {
const G: f64 = 7.0;
const C: [f64; 9] = [
0.999_999_999_999_809_93,
676.520_368_121_885_10,
-1_259.139_216_722_402_9,
771.323_428_777_653_1,
-176.615_029_162_140_6,
12.507_343_278_686_905,
-0.138_571_095_265_720_12,
9.984_369_578_019_572e-6,
1.505_632_735_149_312e-7,
];
if x < 0.5 {
return std::f64::consts::PI.ln()
- (std::f64::consts::PI * x).sin().abs().ln()
- lanczos_lgamma(1.0 - x);
}
let x = x - 1.0;
let mut a = C[0];
let t = x + G + 0.5;
for (i, &c) in C.iter().enumerate().skip(1) {
a += c / (x + i as f64);
}
0.5 * (2.0 * std::f64::consts::PI).ln() + (x + 0.5) * t.ln() - t + a.ln()
}
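/// Digamma via the recurrence `psi(x) = psi(x + 1) - 1/x` for x < 6 and the
/// asymptotic series thereafter.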
fn digamma(x: f64) -> f64 {
if x <= 0.0 {
return f64::NEG_INFINITY;
}
if x < 6.0 {
return digamma(x + 1.0) - 1.0 / x;
}
let x2 = x * x;
x.ln() - 0.5 / x - 1.0 / (12.0 * x2) + 1.0 / (120.0 * x2 * x2) - 1.0 / (252.0 * x2 * x2 * x2)
}
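/// ln(n!) computed as ln Gamma(n + 1).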
fn log_factorial(n: u64) -> f64 {
if n <= 1 {
return 0.0;
}
lgamma(n as f64 + 1.0)
}
#[cfg(test)]
mod tests {
use super::*;
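    // Tests draw pseudo-random data from a fixed 64-bit LCG (Knuth's MMIX
    // multiplier/increment) so that runs are deterministic without pulling in
    // an RNG dependency.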
fn make_beta_bernoulli(alpha0: f64, beta0: f64) -> OnlineVbEstimator {
let config = OnlineVbConfig {
model: ConjugateModel::BetaBernoulli,
forgetting_factor: 0.7,
delay: 1.0,
learning_rate: 1.0,
mini_batch_size: 32,
};
OnlineVbEstimator::new(config, &[alpha0, beta0], 1000).expect("valid estimator")
}
#[test]
fn test_online_vb_beta_bernoulli_converges() {
let mut est = make_beta_bernoulli(1.0, 1.0);
let mut rng_state: u64 = 42;
for _ in 0..200 {
let batch: Vec<f64> = (0..32)
.map(|_| {
rng_state = rng_state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
if (rng_state >> 33) % 100 < 70 {
1.0
} else {
0.0
}
})
.collect();
est.update(&batch).expect("update ok");
}
let pred = est.predict();
assert!((pred - 0.7).abs() < 0.1, "predict={pred}, expected ~0.7");
}
#[test]
fn test_online_vb_elbo_finite() {
let mut est = make_beta_bernoulli(2.0, 2.0);
let batch = vec![1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0];
let elbo = est.update(&batch).expect("update ok");
assert!(elbo.is_finite(), "ELBO should be finite");
}
#[test]
fn test_online_vb_forgetting_factor_stability() {
let config_stable = OnlineVbConfig {
model: ConjugateModel::BetaBernoulli,
forgetting_factor: 1.0,
delay: 1.0,
learning_rate: 1.0,
mini_batch_size: 16,
};
let mut est =
OnlineVbEstimator::new(config_stable, &[1.0, 1.0], 500).expect("valid estimator");
let batch = vec![1.0; 16];
for _ in 0..50 {
est.update(&batch).expect("update ok");
}
let pp = est.posterior_params();
assert!(
pp[0] > pp[1],
"alpha should dominate after all-ones batches"
);
}
#[test]
fn test_online_vb_posterior_params_shape() {
let est = make_beta_bernoulli(2.0, 3.0);
let pp = est.posterior_params();
assert_eq!(pp.len(), 2);
assert!(pp[0] > 0.0 && pp[1] > 0.0, "alpha, beta must be > 0");
}
#[test]
fn test_online_vb_empty_batch_error() {
let mut est = make_beta_bernoulli(1.0, 1.0);
let result = est.update(&[]);
assert!(result.is_err(), "empty batch should return error");
}
#[test]
fn test_online_vb_invalid_prior_error() {
let config = OnlineVbConfig {
model: ConjugateModel::BetaBernoulli,
..Default::default()
};
let result = OnlineVbEstimator::new(config, &[-1.0, 1.0], 100);
assert!(result.is_err(), "negative alpha should fail");
}
#[test]
fn test_online_vb_normal_normal_mean_convergence() {
let config = OnlineVbConfig {
model: ConjugateModel::NormalNormal,
forgetting_factor: 0.7,
delay: 1.0,
learning_rate: 1.0,
mini_batch_size: 20,
};
let mut est =
OnlineVbEstimator::new(config, &[0.0, 10.0, 1.0], 2000).expect("valid estimator");
let mut rng_state: u64 = 123;
for _ in 0..100 {
let batch: Vec<f64> = (0..20)
.map(|_| {
rng_state = rng_state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
let u = (rng_state >> 11) as f64 / (1u64 << 53) as f64;
                5.0 + (u - 0.5) * 2.0
            })
.collect();
est.update(&batch).expect("update ok");
}
let pred = est.predict();
assert!((pred - 5.0).abs() < 1.0, "predict={pred}, expected ~5.0");
}
#[test]
fn test_online_vb_normal_normal_posterior_shape() {
let config = OnlineVbConfig {
model: ConjugateModel::NormalNormal,
..Default::default()
};
let est = OnlineVbEstimator::new(config, &[0.0, 1.0, 0.5], 100).expect("valid estimator");
let pp = est.posterior_params();
assert_eq!(pp.len(), 2, "NormalNormal posterior has [mu, sigma^2]");
assert!(pp[1] > 0.0, "posterior variance must be positive");
}
#[test]
fn test_online_vb_gamma_poisson_rate_estimation() {
let config = OnlineVbConfig {
model: ConjugateModel::GammaPoisson,
forgetting_factor: 0.7,
delay: 1.0,
learning_rate: 1.0,
mini_batch_size: 20,
};
let mut est = OnlineVbEstimator::new(config, &[1.0, 1.0], 2000).expect("valid estimator");
let mut rng_state: u64 = 999;
for _ in 0..100 {
let batch: Vec<f64> = (0..20)
.map(|_| {
rng_state = rng_state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
                let u: f64 = (rng_state >> 11) as f64 / (1u64 << 53) as f64;
                // Crude non-negative integer counts in {0, ..., 6} (mean 3);
                // the assertions below only require positive posterior params.
                (u * 7.0).floor().min(6.0)
.collect();
let _ = est.update(&batch);
}
let pp = est.posterior_params();
assert!(pp[0] > 0.0, "a > 0");
assert!(pp[1] > 0.0, "b > 0");
}
#[test]
fn test_online_vb_dirichlet_multinomial_proportions() {
let config = OnlineVbConfig {
model: ConjugateModel::DirichletMultinomial,
forgetting_factor: 0.7,
delay: 1.0,
learning_rate: 1.0,
mini_batch_size: 30,
};
let mut est =
OnlineVbEstimator::new(config, &[1.0, 1.0, 1.0], 3000).expect("valid estimator");
let mut rng_state: u64 = 77;
for _ in 0..100 {
let batch: Vec<f64> = (0..30)
.map(|_| {
rng_state = rng_state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
let u = (rng_state >> 11) % 100;
if u < 60 {
0.0
} else if u < 90 {
1.0
} else {
2.0
}
})
.collect();
est.update(&batch).expect("update ok");
}
let pp = est.posterior_params();
assert_eq!(pp.len(), 3);
assert!(pp[0] > pp[1], "alpha[0] > alpha[1] for 60% vs 30%");
assert!(pp[1] > pp[2], "alpha[1] > alpha[2] for 30% vs 10%");
}
#[test]
fn test_online_vb_invalid_kappa_error() {
let config = OnlineVbConfig {
model: ConjugateModel::BetaBernoulli,
            forgetting_factor: 0.3,
            ..Default::default()
};
let result = OnlineVbEstimator::new(config, &[1.0, 1.0], 100);
assert!(result.is_err());
}
#[test]
fn test_online_vb_elbo_trend() {
let mut est = make_beta_bernoulli(1.0, 1.0);
        let batch: Vec<f64> = vec![1.0; 16];
        let mut first_elbo = f64::NEG_INFINITY;
let mut last_elbo = f64::NEG_INFINITY;
for i in 0..100 {
let elbo = est.update(&batch).expect("update ok");
if i == 0 {
first_elbo = elbo;
}
last_elbo = elbo;
}
        assert!(
            last_elbo >= first_elbo - 1.0,
            "ELBO should trend upward, allowing small fluctuations"
        );
}
#[test]
fn test_online_vb_n_processed_tracking() {
let mut est = make_beta_bernoulli(1.0, 1.0);
let batch = vec![1.0; 10];
est.update(&batch).expect("update ok");
est.update(&batch).expect("update ok");
assert_eq!(est.state().n_processed, 20);
}
#[test]
fn test_digamma_known_values() {
let psi1 = digamma(1.0);
assert!((psi1 - (-0.5772156649)).abs() < 1e-6, "ψ(1) ≈ -0.5772");
let psi2 = digamma(2.0);
assert!((psi2 - 0.4227843351).abs() < 1e-6, "ψ(2) ≈ 0.4228");
}
#[test]
fn test_lgamma_known_values() {
assert!(lgamma(1.0).abs() < 1e-10);
assert!(lgamma(2.0).abs() < 1e-10);
assert!((lgamma(3.0) - 2.0f64.ln()).abs() < 1e-10);
assert!((lgamma(0.5) - 0.5 * std::f64::consts::PI.ln()).abs() < 1e-8);
}
}