use crate::error::{StatsError, StatsResult};
use super::types::{AdviConfig, AdviResult};
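/// Minimal linear congruential generator for reproducible sampling without
/// external dependencies. Uses Knuth's MMIX multiplier/increment; the top
/// 53 bits of the state are mapped to a uniform value in [0, 1).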
struct Lcg {
state: u64,
}
impl Lcg {
fn new(seed: u64) -> Self {
Self {
state: seed.wrapping_add(1),
}
}
fn next_f64(&mut self) -> f64 {
self.state = self
.state
.wrapping_mul(6_364_136_223_846_793_005)
.wrapping_add(1_442_695_040_888_963_407);
((self.state >> 11) as f64) * (1.0 / (1u64 << 53) as f64)
}
fn randn(&mut self) -> f64 {
let u1 = self.next_f64().max(1e-300);
let u2 = self.next_f64();
(-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
}
}
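/// Minimal Adam optimizer state: exponential moving averages of the gradient
/// (`m`) and of the squared gradient (`v`) with bias correction.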
#[derive(Debug, Clone)]
struct Adam {
m: Vec<f64>,
v: Vec<f64>,
t: usize,
beta1: f64,
beta2: f64,
eps: f64,
}
impl Adam {
fn new(n: usize) -> Self {
Self {
m: vec![0.0; n],
v: vec![0.0; n],
t: 0,
beta1: 0.9,
beta2: 0.999,
eps: 1e-8,
}
}
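    /// Returns per-coordinate step sizes `lr * m_hat / (sqrt(v_hat) + eps)`
    /// computed from `grad`; the caller subtracts them from the parameters,
    /// i.e. this performs descent on whatever gradient is supplied.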
fn step(&mut self, grad: &[f64], lr: f64) -> Vec<f64> {
self.t += 1;
let t = self.t as f64;
let bc1 = 1.0 - self.beta1.powf(t);
let bc2 = 1.0 - self.beta2.powf(t);
let mut delta = vec![0.0; grad.len()];
for i in 0..grad.len() {
self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * grad[i];
self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * grad[i] * grad[i];
let m_hat = self.m[i] / bc1;
let v_hat = self.v[i] / bc2;
delta[i] = lr * m_hat / (v_hat.sqrt() + self.eps);
}
delta
}
}
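/// Monte Carlo estimate of the ELBO and its gradients with respect to the
/// variational parameters of a mean-field Gaussian
/// q(theta) = N(mu, diag(exp(2 * log_sigma))).
///
/// Uses the reparameterization trick theta = mu + exp(log_sigma) * eps with
/// eps ~ N(0, I); the gradient of `log_joint_fn` is approximated by central
/// finite differences with step `config.fd_step`. The Gaussian entropy term is
/// added in closed form, so only E_q[log p(theta)] is estimated by sampling.
/// Returns `(elbo, grad_mu, grad_log_sigma)`.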
fn elbo_and_gradients(
log_joint_fn: &dyn Fn(&[f64]) -> f64,
mu: &[f64],
log_sigma: &[f64],
config: &AdviConfig,
rng: &mut Lcg,
) -> (f64, Vec<f64>, Vec<f64>) {
let n = mu.len();
let s = config.n_samples;
let h = config.fd_step;
let mut elbo_sum = 0.0;
let mut grad_mu = vec![0.0f64; n];
let mut grad_omega = vec![0.0f64; n];
for _ in 0..s {
let eps: Vec<f64> = (0..n).map(|_| rng.randn()).collect();
let theta: Vec<f64> = (0..n)
.map(|i| mu[i] + log_sigma[i].exp() * eps[i])
.collect();
let log_p = log_joint_fn(&theta);
elbo_sum += log_p;
let mut grad_logp = vec![0.0f64; n];
for j in 0..n {
let mut theta_fwd = theta.clone();
let mut theta_bwd = theta.clone();
theta_fwd[j] += h;
theta_bwd[j] -= h;
grad_logp[j] = (log_joint_fn(&theta_fwd) - log_joint_fn(&theta_bwd)) / (2.0 * h);
}
        for i in 0..n {
            grad_mu[i] += grad_logp[i];
            // Chain rule through theta = mu + exp(log_sigma) * eps:
            // d(theta_i)/d(log_sigma_i) = exp(log_sigma_i) * eps_i.
            grad_omega[i] += eps[i] * log_sigma[i].exp() * grad_logp[i];
        }
}
let s_f = s as f64;
let log2pi = (2.0 * std::f64::consts::PI).ln();
    for i in 0..n {
        grad_mu[i] /= s_f;
        grad_omega[i] /= s_f;
        // The entropy term log_sigma_i + 0.5 * (1 + ln(2*pi)) contributes a
        // constant gradient of 1 with respect to each log_sigma_i.
        grad_omega[i] += 1.0;
    }
    let entropy: f64 = log_sigma.iter().map(|&w| w + 0.5 * (1.0 + log2pi)).sum();
let elbo = elbo_sum / s_f + entropy;
(elbo, grad_mu, grad_omega)
}
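/// Mean-field ADVI: fits a diagonal Gaussian approximation to a posterior
/// given only its unnormalized log joint density.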
pub struct AdviOptimizer {
pub config: AdviConfig,
}
impl AdviOptimizer {
pub fn new(config: AdviConfig) -> Self {
Self { config }
}
pub fn default_config() -> Self {
Self {
config: AdviConfig::default(),
}
}
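    /// Maximizes the ELBO over `(mu, log_sigma)` with Adam, using stochastic
    /// gradients from `elbo_and_gradients`. Stops early when the ELBO change
    /// falls below `tol` or the ELBO estimate becomes non-finite.
    ///
    /// A minimal usage sketch (the quadratic log joint below is illustrative,
    /// not part of this module):
    ///
    /// ```ignore
    /// let opt = AdviOptimizer::default_config();
    /// let log_joint = |theta: &[f64]| -0.5 * theta[0] * theta[0];
    /// let result = opt.fit(&log_joint, 1).unwrap();
    /// let draws = sample_posterior(&result, 100, 42).unwrap();
    /// ```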
pub fn fit(
&self,
log_joint_fn: &dyn Fn(&[f64]) -> f64,
n_params: usize,
) -> StatsResult<AdviResult> {
if n_params == 0 {
return Err(StatsError::invalid_argument(
"n_params must be > 0 for ADVI",
));
}
let cfg = &self.config;
let mut mu = vec![0.0f64; n_params];
let mut log_sigma = vec![0.0f64; n_params];
let mut adam_mu = Adam::new(n_params);
let mut adam_omega = Adam::new(n_params);
let mut rng = Lcg::new(cfg.seed);
let mut elbo_history: Vec<f64> = Vec::with_capacity(cfg.n_iter);
let mut prev_elbo = f64::NEG_INFINITY;
let mut converged = false;
let mut n_iter_performed = 0;
for _iter in 0..cfg.n_iter {
let (elbo, grad_mu, grad_omega) =
elbo_and_gradients(log_joint_fn, &mu, &log_sigma, cfg, &mut rng);
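            // If the ELBO estimate diverges, record it for diagnostics and stop.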
if !elbo.is_finite() {
elbo_history.push(elbo);
n_iter_performed += 1;
break;
}
elbo_history.push(elbo);
n_iter_performed += 1;
let neg_grad_mu: Vec<f64> = grad_mu.iter().map(|&g| -g).collect();
let neg_grad_omega: Vec<f64> = grad_omega.iter().map(|&g| -g).collect();
let delta_mu = adam_mu.step(&neg_grad_mu, cfg.lr);
let delta_omega = adam_omega.step(&neg_grad_omega, cfg.lr);
for i in 0..n_params {
                mu[i] -= delta_mu[i];
                log_sigma[i] -= delta_omega[i];
}
if (elbo - prev_elbo).abs() < cfg.tol {
converged = true;
break;
}
prev_elbo = elbo;
}
Ok(AdviResult {
elbo_history,
mu,
log_sigma,
converged,
n_iter_performed,
})
}
}
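/// Draws `n` independent samples from the fitted mean-field Gaussian
/// q(theta) = N(mu, diag(exp(2 * log_sigma))), returned as `n` vectors of
/// length `n_params`.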
pub fn sample_posterior(result: &AdviResult, n: usize, seed: u64) -> StatsResult<Vec<Vec<f64>>> {
let n_params = result.mu.len();
if n_params == 0 {
return Err(StatsError::invalid_argument(
"AdviResult has zero parameters",
));
}
if n == 0 {
return Ok(Vec::new());
}
let mut rng = Lcg::new(seed);
let sigma: Vec<f64> = result.log_sigma.iter().map(|&w| w.exp()).collect();
let samples = (0..n)
.map(|_| {
(0..n_params)
.map(|i| result.mu[i] + sigma[i] * rng.randn())
.collect()
})
.collect();
Ok(samples)
}
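/// Entropy of a diagonal Gaussian parameterized by log standard deviations:
/// H = sum_i (log_sigma_i + 0.5 * (1 + ln(2*pi))).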
pub fn mean_field_entropy(log_sigma: &[f64]) -> f64 {
let log2pi = (2.0 * std::f64::consts::PI).ln();
    log_sigma.iter().map(|&w| w + 0.5 * (1.0 + log2pi)).sum()
}
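/// Builds an unnormalized log joint for Bayesian linear regression:
/// y_i ~ N(x_i . beta, noise_var) with independent N(0, 1 / prior_precision)
/// priors on the coefficients. Normalization constants are dropped, since
/// ADVI only needs the log joint up to an additive constant.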
pub fn make_linear_regression_log_joint(
x_data: Vec<Vec<f64>>,
y_data: Vec<f64>,
noise_var: f64,
prior_precision: f64,
) -> impl Fn(&[f64]) -> f64 {
move |beta: &[f64]| {
let n = y_data.len();
let mut log_lik = 0.0;
for i in 0..n {
let mut pred = 0.0;
for (j, &bj) in beta.iter().enumerate() {
if j < x_data[i].len() {
pred += x_data[i][j] * bj;
}
}
let r = y_data[i] - pred;
log_lik -= 0.5 * r * r / noise_var;
}
let log_prior: f64 = beta.iter().map(|&b| -0.5 * prior_precision * b * b).sum();
log_lik + log_prior
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_lcg_uniform_range() {
let mut rng = Lcg::new(123);
for _ in 0..1000 {
let v = rng.next_f64();
assert!(v >= 0.0 && v < 1.0, "out of range: {v}");
}
}
#[test]
fn test_adam_moment_update_first_iter() {
let mut adam = Adam::new(2);
let grad = vec![1.0, -2.0];
let delta = adam.step(&grad, 0.01);
assert!(
(delta[0] - 0.01).abs() < 1e-6,
"Adam step[0] = {} ≠ 0.01",
delta[0]
);
assert!((delta[1] - (-0.01)).abs() < 1e-6);
}
#[test]
fn test_advi_zero_params_error() {
let opt = AdviOptimizer::default_config();
let result = opt.fit(&|_: &[f64]| 0.0, 0);
assert!(result.is_err());
}
#[test]
fn test_advi_elbo_increases() {
let log_joint = |theta: &[f64]| {
let t = theta[0];
-0.5 * (t - 3.0) * (t - 3.0)
};
let config = AdviConfig {
n_iter: 200,
n_samples: 5,
lr: 0.05,
            tol: 1e-8,
            seed: 1,
..AdviConfig::default()
};
let opt = AdviOptimizer::new(config);
let result = opt.fit(&log_joint, 1).expect("fit ok");
let n = result.elbo_history.len();
        assert!(n > 5, "Should have more than 5 iterations, got {n}");
let init_elbo = result.elbo_history[..5]
.iter()
.copied()
.filter(|e| e.is_finite())
.fold(f64::NEG_INFINITY, f64::max);
let final_elbo = result.elbo_history[(n - 5)..]
.iter()
.copied()
.filter(|e| e.is_finite())
.fold(f64::NEG_INFINITY, f64::max);
assert!(
            final_elbo >= init_elbo - 1.0,
            "ELBO should not decrease significantly: init={init_elbo:.4}, final={final_elbo:.4}"
);
}
#[test]
fn test_advi_recovers_gaussian_mean() {
let log_joint = |theta: &[f64]| {
let t = theta[0];
-0.5 * (3.0 - t) * (3.0 - t) - 0.5 * t * t
};
let config = AdviConfig {
n_iter: 500,
n_samples: 10,
lr: 0.05,
tol: 1e-7,
seed: 7,
..AdviConfig::default()
};
let opt = AdviOptimizer::new(config);
let result = opt.fit(&log_joint, 1).expect("fit ok");
let mu = result.mu[0];
assert!((mu - 1.5).abs() < 0.5, "Expected mu ≈ 1.5, got {mu:.4}");
}
#[test]
fn test_advi_posterior_samples_shape() {
let config = AdviConfig {
n_iter: 10,
n_samples: 1,
seed: 42,
..AdviConfig::default()
};
let opt = AdviOptimizer::new(config);
let result = opt.fit(&|_: &[f64]| -1.0, 3).expect("fit ok");
let samples = sample_posterior(&result, 50, 99).expect("samples ok");
assert_eq!(samples.len(), 50);
for s in &samples {
assert_eq!(s.len(), 3);
}
}
#[test]
fn test_advi_converged_flag() {
let log_joint = |_: &[f64]| 0.0;
let config = AdviConfig {
n_iter: 2000,
n_samples: 1,
lr: 0.01,
tol: 1e-3,
seed: 5,
..AdviConfig::default()
};
let opt = AdviOptimizer::new(config);
let result = opt.fit(&log_joint, 2).expect("fit ok");
assert!(
result.converged || result.n_iter_performed == 2000,
"Should converge or exhaust iterations"
);
}
#[test]
fn test_advi_mean_field_entropy() {
let log_sigma = vec![0.0, 1.0, -1.0];
let entropy = mean_field_entropy(&log_sigma);
let log2pi = (2.0 * std::f64::consts::PI).ln();
        let expected: f64 = log_sigma.iter().map(|&w| w + 0.5 * (1.0 + log2pi)).sum();
assert!((entropy - expected).abs() < 1e-12);
}
#[test]
fn test_sample_posterior_empty_params_error() {
let result = AdviResult {
elbo_history: vec![],
mu: vec![],
log_sigma: vec![],
converged: false,
n_iter_performed: 0,
};
assert!(sample_posterior(&result, 10, 1).is_err());
}
#[test]
fn test_sample_posterior_zero_samples() {
let result = AdviResult {
elbo_history: vec![-1.0],
mu: vec![0.0],
log_sigma: vec![0.0],
converged: false,
n_iter_performed: 1,
};
let samples = sample_posterior(&result, 0, 1).expect("ok");
assert!(samples.is_empty());
}
#[test]
fn test_make_linear_regression_log_joint() {
let x = vec![vec![1.0]];
let y = vec![2.0];
let log_joint = make_linear_regression_log_joint(x, y, 1.0, 0.01);
let lp = log_joint(&[2.0]);
assert!((lp - (-0.02)).abs() < 1e-10, "log_joint(β=2) = {lp}");
let lp0 = log_joint(&[0.0]);
assert!((lp0 - (-2.0)).abs() < 1e-10, "log_joint(β=0) = {lp0}");
}
}