//! Automatic Differentiation Variational Inference (ADVI): Gaussian
//! variational approximations fitted in an unconstrained space via
//! reparameterized Monte Carlo ELBO gradients and Adam.

use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{Array1, Array2};
use std::f64::consts::PI;
use super::{PosteriorResult, VariationalInference};
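/// Differentiable bijections from the unconstrained space ADVI optimizes over
/// (`eta`) to the constrained parameter space of the model (`theta`): `Log`
/// maps onto (0, ∞), `Logit` onto (0, 1), and `Bounded { lower, upper }` onto
/// (lower, upper).
///
/// A minimal roundtrip sketch; the import path is an assumption, so the
/// example is not compiled:
///
/// ```ignore
/// let t = AdviTransform::Bounded { lower: -2.0, upper: 5.0 };
/// let eta = t.inverse(1.0).unwrap(); // unconstrained coordinate
/// assert!((t.forward(eta) - 1.0).abs() < 1e-12);
/// ```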
#[derive(Debug, Clone)]
pub enum AdviTransform {
Identity,
Log,
Logit,
Bounded {
lower: f64,
upper: f64,
},
}
impl AdviTransform {
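    /// Maps an unconstrained value `eta` to the constrained space.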
pub fn forward(&self, eta: f64) -> f64 {
match self {
AdviTransform::Identity => eta,
AdviTransform::Log => eta.exp(),
AdviTransform::Logit => 1.0 / (1.0 + (-eta).exp()),
AdviTransform::Bounded { lower, upper } => {
let s = 1.0 / (1.0 + (-eta).exp());
lower + (upper - lower) * s
}
}
}
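    /// Maps a constrained value `theta` back to the unconstrained space,
    /// rejecting values outside the transform's open range.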
pub fn inverse(&self, theta: f64) -> StatsResult<f64> {
match self {
AdviTransform::Identity => Ok(theta),
AdviTransform::Log => {
if theta <= 0.0 {
return Err(StatsError::InvalidArgument(format!(
"Log transform requires positive value, got {}",
theta
)));
}
Ok(theta.ln())
}
AdviTransform::Logit => {
if theta <= 0.0 || theta >= 1.0 {
return Err(StatsError::InvalidArgument(format!(
"Logit transform requires value in (0, 1), got {}",
theta
)));
}
Ok((theta / (1.0 - theta)).ln())
}
AdviTransform::Bounded { lower, upper } => {
if theta <= *lower || theta >= *upper {
return Err(StatsError::InvalidArgument(format!(
"Bounded transform requires value in ({}, {}), got {}",
lower, upper, theta
)));
}
let s = (theta - lower) / (upper - lower);
Ok((s / (1.0 - s)).ln())
}
}
}
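    /// Log absolute Jacobian determinant `ln|dθ/dη|` of the forward map.
    /// For `Log` this is `eta`; for `Logit` it is
    /// `ln σ(η) + ln(1 − σ(η)) = η − 2·softplus(η)`.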
pub fn log_det_jacobian(&self, eta: f64) -> f64 {
match self {
AdviTransform::Identity => 0.0,
AdviTransform::Log => eta,
AdviTransform::Logit => {
let sp = softplus(eta);
eta - 2.0 * sp
}
AdviTransform::Bounded { lower, upper } => {
let log_range = (upper - lower).ln();
let sp = softplus(eta);
log_range + eta - 2.0 * sp
}
}
}
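    /// Derivative of `log_det_jacobian` with respect to `eta`.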
pub fn grad_log_det_jacobian(&self, eta: f64) -> f64 {
match self {
AdviTransform::Identity => 0.0,
AdviTransform::Log => 1.0,
AdviTransform::Logit | AdviTransform::Bounded { .. } => {
let s = sigmoid(eta);
1.0 - 2.0 * s
}
}
}
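    /// Derivative `dθ/dη` of the forward map, used in the chain rule when
    /// pulling model gradients back to the unconstrained space.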
pub fn forward_grad(&self, eta: f64) -> f64 {
match self {
AdviTransform::Identity => 1.0,
AdviTransform::Log => eta.exp(),
AdviTransform::Logit => {
let s = sigmoid(eta);
s * (1.0 - s)
}
AdviTransform::Bounded { lower, upper } => {
let s = sigmoid(eta);
(upper - lower) * s * (1.0 - s)
}
}
}
}
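/// Numerically stable softplus `ln(1 + e^x)`, using the linear and exponential
/// asymptotes beyond |x| > 20 to avoid overflow.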
fn softplus(x: f64) -> f64 {
if x > 20.0 {
x
} else if x < -20.0 {
x.exp()
} else {
(1.0 + x.exp()).ln()
}
}
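/// Numerically stable logistic sigmoid, branching on the sign of `x` so that
/// `exp` is only ever evaluated at non-positive arguments.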
fn sigmoid(x: f64) -> f64 {
if x >= 0.0 {
1.0 / (1.0 + (-x).exp())
} else {
let ex = x.exp();
ex / (1.0 + ex)
}
}
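/// Family of the Gaussian variational approximation: diagonal covariance
/// (`MeanField`) or full covariance parameterized by a Cholesky factor
/// (`FullRank`).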
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AdviApproximation {
MeanField,
FullRank,
}
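/// Adam optimizer state (bias-corrected first and second moment estimates)
/// over the flattened variational parameters; `update` returns an ascent
/// direction for the ELBO.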
#[derive(Debug, Clone)]
struct AdviAdamState {
m: Array1<f64>,
v: Array1<f64>,
t: usize,
beta1: f64,
beta2: f64,
epsilon: f64,
}
impl AdviAdamState {
fn new(n_params: usize) -> Self {
Self {
m: Array1::zeros(n_params),
v: Array1::zeros(n_params),
t: 0,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
}
}
fn update(&mut self, grad: &Array1<f64>) -> Array1<f64> {
self.t += 1;
let n = grad.len();
let mut direction = Array1::zeros(n);
for i in 0..n {
self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * grad[i];
self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * grad[i] * grad[i];
let m_hat = self.m[i] / (1.0 - self.beta1.powi(self.t as i32));
let v_hat = self.v[i] / (1.0 - self.beta2.powi(self.t as i32));
direction[i] = m_hat / (v_hat.sqrt() + self.epsilon);
}
direction
}
}
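/// ADVI configuration. Dimensions beyond `transforms.len()` default to
/// `Identity`; `num_samples` Monte Carlo draws are used per gradient estimate;
/// convergence compares the mean ELBO over the two halves of the trailing
/// `convergence_window` iterations against `tolerance`.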
#[derive(Debug, Clone)]
pub struct AdviConfig {
pub approximation: AdviApproximation,
pub transforms: Vec<AdviTransform>,
pub num_samples: usize,
pub learning_rate: f64,
pub max_iterations: usize,
pub tolerance: f64,
pub seed: u64,
pub convergence_window: usize,
}
impl Default for AdviConfig {
fn default() -> Self {
Self {
approximation: AdviApproximation::MeanField,
transforms: Vec::new(),
num_samples: 10,
learning_rate: 0.01,
max_iterations: 10000,
tolerance: 1e-4,
seed: 42,
convergence_window: 50,
}
}
}
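/// Automatic Differentiation Variational Inference: maximizes the ELBO of a
/// Gaussian approximation in an unconstrained space reached through the
/// configured transforms, using reparameterized Monte Carlo gradients and Adam.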
#[derive(Debug, Clone)]
pub struct Advi {
pub config: AdviConfig,
}
impl Advi {
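    /// Creates an ADVI engine with the given configuration.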
pub fn new(config: AdviConfig) -> Self {
Self { config }
}
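    /// Deterministic standard-normal draws: quasi-uniforms from golden- and
    /// plastic-ratio recurrences pushed through a Box–Muller transform. This
    /// trades statistical quality for reproducibility without an RNG dependency.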
fn generate_epsilon(&self, dim: usize, seed: u64) -> Array1<f64> {
let golden = 1.618033988749895_f64;
let plastic = 1.324717957244746_f64;
Array1::from_shape_fn(dim, |i| {
let u1 = ((seed as f64 * golden + i as f64 * plastic) % 1.0).abs();
let u2 = ((seed as f64 * plastic + i as f64 * golden) % 1.0).abs();
            let u1 = u1.clamp(1e-10, 1.0 - 1e-10);
            let u2 = u2.clamp(1e-10, 1.0 - 1e-10);
let r = (-2.0 * u1.ln()).sqrt();
r * (2.0 * PI * u2).cos()
})
}
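    /// Per-dimension transform, falling back to `Identity` past `transforms.len()`.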
fn get_transform(&self, i: usize) -> &AdviTransform {
if i < self.config.transforms.len() {
&self.config.transforms[i]
} else {
&AdviTransform::Identity
}
}
fn transform_to_constrained(&self, eta: &Array1<f64>) -> Array1<f64> {
Array1::from_shape_fn(eta.len(), |i| self.get_transform(i).forward(eta[i]))
}
fn total_log_det_jacobian(&self, eta: &Array1<f64>) -> f64 {
(0..eta.len())
.map(|i| self.get_transform(i).log_det_jacobian(eta[i]))
.sum()
}
fn grad_log_det_jacobian(&self, eta: &Array1<f64>) -> Array1<f64> {
Array1::from_shape_fn(eta.len(), |i| {
self.get_transform(i).grad_log_det_jacobian(eta[i])
})
}
fn forward_grad(&self, eta: &Array1<f64>) -> Array1<f64> {
Array1::from_shape_fn(eta.len(), |i| self.get_transform(i).forward_grad(eta[i]))
}
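    /// Mean-field ADVI: `q(η) = N(μ, diag(σ²))` with `σ = exp(log_sigma)`.
    /// Each iteration draws `ε ~ N(0, I)`, sets `η = μ + σ ⊙ ε`
    /// (the reparameterization trick), and estimates
    /// `ELBO = E[log p(θ) + ln|J|] + H[q]`; the Gaussian entropy `H[q]`
    /// contributes a constant `+1` to each `log_sigma` gradient coordinate.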
fn fit_mean_field<F>(&self, log_joint: F, dim: usize) -> StatsResult<PosteriorResult>
where
F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>,
{
let n_params = 2 * dim;
let mut mu = Array1::zeros(dim);
let mut log_sigma = Array1::zeros(dim);
let mut adam = AdviAdamState::new(n_params);
let mut elbo_history = Vec::with_capacity(self.config.max_iterations);
let mut converged = false;
for iter in 0..self.config.max_iterations {
let mut elbo_sum = 0.0;
let mut grad_mu_sum = Array1::zeros(dim);
let mut grad_log_sigma_sum = Array1::zeros(dim);
for s in 0..self.config.num_samples {
let seed = self
.config
.seed
.wrapping_add(iter as u64 * 1000)
.wrapping_add(s as u64);
let epsilon = self.generate_epsilon(dim, seed);
let sigma = log_sigma.mapv(f64::exp);
let eta = &mu + &(&sigma * &epsilon);
let theta = self.transform_to_constrained(&eta);
let (log_p, grad_theta) = log_joint(&theta)?;
let ldj = self.total_log_det_jacobian(&eta);
let grad_ldj = self.grad_log_det_jacobian(&eta);
let fwd_grad = self.forward_grad(&eta);
let grad_eta: Array1<f64> =
Array1::from_shape_fn(dim, |i| grad_theta[i] * fwd_grad[i] + grad_ldj[i]);
let elbo_s = log_p + ldj;
elbo_sum += elbo_s;
for i in 0..dim {
grad_mu_sum[i] += grad_eta[i];
grad_log_sigma_sum[i] += grad_eta[i] * sigma[i] * epsilon[i];
}
}
let n_s = self.config.num_samples as f64;
elbo_sum /= n_s;
grad_mu_sum /= n_s;
grad_log_sigma_sum /= n_s;
for i in 0..dim {
grad_log_sigma_sum[i] += 1.0;
}
let entropy: f64 = (0..dim)
.map(|i| 0.5 * (1.0 + (2.0 * PI).ln()) + log_sigma[i])
.sum();
elbo_sum += entropy;
elbo_history.push(elbo_sum);
let mut full_grad = Array1::zeros(n_params);
for i in 0..dim {
full_grad[i] = grad_mu_sum[i];
full_grad[dim + i] = grad_log_sigma_sum[i];
}
let direction = adam.update(&full_grad);
let lr = self.config.learning_rate;
for i in 0..dim {
mu[i] += lr * direction[i];
log_sigma[i] += lr * direction[dim + i];
log_sigma[i] = log_sigma[i].max(-10.0).min(10.0);
}
if elbo_history.len() >= self.config.convergence_window {
let n = elbo_history.len();
let w = self.config.convergence_window;
let recent_avg: f64 =
elbo_history[n - w / 2..n].iter().sum::<f64>() / (w / 2) as f64;
let earlier_avg: f64 =
elbo_history[n - w..n - w / 2].iter().sum::<f64>() / (w / 2) as f64;
if (recent_avg - earlier_avg).abs() < self.config.tolerance {
converged = true;
break;
}
}
}
let sigma = log_sigma.mapv(f64::exp);
let constrained_means = self.transform_to_constrained(&mu);
let fwd_grad = self.forward_grad(&mu);
let constrained_stds = Array1::from_shape_fn(dim, |i| (fwd_grad[i] * sigma[i]).abs());
        let iterations = elbo_history.len();
        Ok(PosteriorResult {
            means: constrained_means,
            std_devs: constrained_stds,
            elbo_history,
            iterations,
            converged,
            samples: None,
        })
}
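    /// Full-rank ADVI: `q(η) = N(μ, LLᵀ)` with `L` a lower-triangular Cholesky
    /// factor stored row-major. Samples use `η = μ + Lε`; the entropy term
    /// contributes `Σᵢ ln|Lᵢᵢ|`, whose gradient `1/Lᵢᵢ` is added to the
    /// diagonal entries of `L`.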
fn fit_full_rank<F>(&self, log_joint: F, dim: usize) -> StatsResult<PosteriorResult>
where
F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>,
{
let n_tril = dim * (dim + 1) / 2;
let n_params = dim + n_tril;
let mut mu = Array1::zeros(dim);
let mut l_entries = Array1::zeros(n_tril);
        // Start L at the identity so q begins as a standard normal.
        {
            let mut idx = 0;
            for row in 0..dim {
                for col in 0..=row {
                    if row == col {
                        l_entries[idx] = 1.0;
                    }
                    idx += 1;
                }
            }
        }
let mut adam = AdviAdamState::new(n_params);
let mut elbo_history = Vec::with_capacity(self.config.max_iterations);
let mut converged = false;
for iter in 0..self.config.max_iterations {
let l_mat = tril_to_matrix(dim, &l_entries);
let mut elbo_sum = 0.0;
let mut grad_mu_sum = Array1::zeros(dim);
let mut grad_l_sum = Array1::zeros(n_tril);
for s in 0..self.config.num_samples {
let seed = self
.config
.seed
.wrapping_add(iter as u64 * 1000)
.wrapping_add(s as u64);
let epsilon = self.generate_epsilon(dim, seed);
let l_eps = l_mat.dot(&epsilon);
let eta = &mu + &l_eps;
let theta = self.transform_to_constrained(&eta);
let (log_p, grad_theta) = log_joint(&theta)?;
let ldj = self.total_log_det_jacobian(&eta);
let grad_ldj = self.grad_log_det_jacobian(&eta);
let fwd_grad = self.forward_grad(&eta);
let grad_eta: Array1<f64> =
Array1::from_shape_fn(dim, |i| grad_theta[i] * fwd_grad[i] + grad_ldj[i]);
let elbo_s = log_p + ldj;
elbo_sum += elbo_s;
for i in 0..dim {
grad_mu_sum[i] += grad_eta[i];
}
let mut idx = 0;
for row in 0..dim {
for col in 0..=row {
grad_l_sum[idx] += grad_eta[row] * epsilon[col];
idx += 1;
}
}
}
let n_s = self.config.num_samples as f64;
elbo_sum /= n_s;
grad_mu_sum /= n_s;
grad_l_sum /= n_s;
let mut entropy = 0.5 * dim as f64 * (1.0 + (2.0 * PI).ln());
{
let mut idx = 0;
for row in 0..dim {
for col in 0..=row {
if row == col {
entropy += l_entries[idx].abs().max(1e-15).ln();
let l_ii = l_entries[idx];
if l_ii.abs() > 1e-15 {
grad_l_sum[idx] += 1.0 / l_ii;
}
}
idx += 1;
}
}
}
elbo_sum += entropy;
elbo_history.push(elbo_sum);
let mut full_grad = Array1::zeros(n_params);
for i in 0..dim {
full_grad[i] = grad_mu_sum[i];
}
for i in 0..n_tril {
full_grad[dim + i] = grad_l_sum[i];
}
let direction = adam.update(&full_grad);
let lr = self.config.learning_rate;
for i in 0..dim {
mu[i] += lr * direction[i];
}
for i in 0..n_tril {
l_entries[i] += lr * direction[dim + i];
}
{
let mut idx = 0;
for row in 0..dim {
for col in 0..=row {
if row == col {
l_entries[idx] = l_entries[idx].abs().max(1e-6);
}
                    l_entries[idx] = l_entries[idx].clamp(-10.0, 10.0);
idx += 1;
}
}
}
if elbo_history.len() >= self.config.convergence_window {
let n = elbo_history.len();
let w = self.config.convergence_window;
let recent_avg: f64 =
elbo_history[n - w / 2..n].iter().sum::<f64>() / (w / 2) as f64;
let earlier_avg: f64 =
elbo_history[n - w..n - w / 2].iter().sum::<f64>() / (w / 2) as f64;
if (recent_avg - earlier_avg).abs() < self.config.tolerance {
converged = true;
break;
}
}
}
let l_mat = tril_to_matrix(dim, &l_entries);
let constrained_means = self.transform_to_constrained(&mu);
let cov = l_mat.dot(&l_mat.t());
let fwd_grad = self.forward_grad(&mu);
let constrained_stds =
Array1::from_shape_fn(dim, |i| (fwd_grad[i] * fwd_grad[i] * cov[[i, i]]).sqrt());
        let iterations = elbo_history.len();
        Ok(PosteriorResult {
            means: constrained_means,
            std_devs: constrained_stds,
            elbo_history,
            iterations,
            converged,
            samples: None,
        })
}
}
impl VariationalInference for Advi {
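    /// Fits the configured Gaussian approximation to a `dim`-dimensional model.
    /// `log_joint` receives a point on the constrained scale and must return
    /// the log joint density together with its gradient at that point.
    ///
    /// A minimal call-site sketch; import paths are assumptions, so the
    /// example is not compiled:
    ///
    /// ```ignore
    /// // Standard-normal target: log p(θ) = -θ²/2 (up to a constant), ∇ log p = -θ.
    /// let mut advi = Advi::new(AdviConfig::default());
    /// let result = advi.fit(
    ///     |theta: &Array1<f64>| Ok((-0.5 * theta[0] * theta[0], Array1::from_vec(vec![-theta[0]]))),
    ///     1,
    /// )?;
    /// assert!(result.means[0].abs() < 0.5);
    /// ```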
fn fit<F>(&mut self, log_joint: F, dim: usize) -> StatsResult<PosteriorResult>
where
F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>,
{
if dim == 0 {
return Err(StatsError::InvalidArgument(
"Dimension must be at least 1".to_string(),
));
}
if self.config.num_samples == 0 {
return Err(StatsError::InvalidArgument(
"num_samples must be at least 1".to_string(),
));
}
if self.config.learning_rate <= 0.0 {
return Err(StatsError::InvalidArgument(
"learning_rate must be positive".to_string(),
));
}
match self.config.approximation {
AdviApproximation::MeanField => self.fit_mean_field(log_joint, dim),
AdviApproximation::FullRank => self.fit_full_rank(log_joint, dim),
}
}
}
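/// Unpacks a row-major lower-triangular parameter vector into a dense
/// `dim × dim` matrix with zeros above the diagonal.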
fn tril_to_matrix(dim: usize, entries: &Array1<f64>) -> Array2<f64> {
let mut mat = Array2::zeros((dim, dim));
let mut idx = 0;
for row in 0..dim {
for col in 0..=row {
mat[[row, col]] = entries[idx];
idx += 1;
}
}
mat
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_advi_gaussian_posterior_recovery() {
let data_mean = 3.0_f64;
let n_data = 10.0_f64;
let prior_mean = 0.0_f64;
let prior_var = 1.0_f64;
let lik_var = 1.0_f64;
let config = AdviConfig {
approximation: AdviApproximation::MeanField,
transforms: vec![AdviTransform::Identity],
num_samples: 20,
learning_rate: 0.05,
max_iterations: 3000,
tolerance: 1e-5,
seed: 123,
convergence_window: 100,
};
let mut advi = Advi::new(config);
let result = advi
.fit(
move |theta: &Array1<f64>| {
let mu = theta[0];
let log_prior = -0.5 * (mu - prior_mean).powi(2) / prior_var;
let log_lik = -n_data / 2.0 * (mu - data_mean).powi(2) / lik_var;
let log_p = log_prior + log_lik;
let grad_prior = -(mu - prior_mean) / prior_var;
let grad_lik = -n_data * (mu - data_mean) / lik_var;
let grad = Array1::from_vec(vec![grad_prior + grad_lik]);
Ok((log_p, grad))
},
1,
)
.expect("ADVI should not fail");
let expected_mean = (n_data * data_mean / lik_var + prior_mean / prior_var)
/ (n_data / lik_var + 1.0 / prior_var);
let expected_std = (1.0 / (n_data / lik_var + 1.0 / prior_var)).sqrt();
assert!(
(result.means[0] - expected_mean).abs() < 0.3,
"Mean should be close to {}, got {}",
expected_mean,
result.means[0]
);
assert!(
(result.std_devs[0] - expected_std).abs() < 0.2,
"Std should be close to {}, got {}",
expected_std,
result.std_devs[0]
);
}
#[test]
fn test_advi_elbo_increases() {
let config = AdviConfig {
approximation: AdviApproximation::MeanField,
transforms: vec![AdviTransform::Identity, AdviTransform::Identity],
num_samples: 15,
learning_rate: 0.02,
max_iterations: 500,
tolerance: 1e-6,
seed: 77,
convergence_window: 50,
};
let mut advi = Advi::new(config);
let result = advi
.fit(
|theta: &Array1<f64>| {
let diff0 = theta[0] - 1.0;
let diff1 = theta[1] - 2.0;
let log_p = -0.5 * (diff0 * diff0 + diff1 * diff1);
let grad = Array1::from_vec(vec![-diff0, -diff1]);
Ok((log_p, grad))
},
2,
)
.expect("ADVI should succeed");
let n = result.elbo_history.len();
        assert!(n > 100, "Should run more than 100 iterations");
let early_avg: f64 = result.elbo_history[..50].iter().sum::<f64>() / 50.0;
let late_avg: f64 = result.elbo_history[n - 50..].iter().sum::<f64>() / 50.0;
assert!(
late_avg > early_avg - 1.0,
"Late ELBO ({}) should be higher than early ({})",
late_avg,
early_avg
);
}
#[test]
fn test_advi_mean_field_vs_full_rank() {
let rho = 0.8_f64;
let log_joint = move |theta: &Array1<f64>| {
let x = theta[0];
let y = theta[1];
let det = 1.0 - rho * rho;
            // Bivariate standard normal with correlation rho:
            // log p = -(x² - 2ρxy + y²)/(2(1-ρ²)) - ln(2π) - ½ ln(1-ρ²).
            let log_p = -0.5 / det * (x * x - 2.0 * rho * x * y + y * y)
                - (2.0 * PI).ln()
                - 0.5 * det.ln();
let gx = -1.0 / det * (x - rho * y);
let gy = -1.0 / det * (y - rho * x);
Ok((log_p, Array1::from_vec(vec![gx, gy])))
};
let mf_config = AdviConfig {
approximation: AdviApproximation::MeanField,
num_samples: 20,
learning_rate: 0.02,
max_iterations: 2000,
tolerance: 1e-5,
seed: 42,
convergence_window: 100,
..Default::default()
};
let mut mf_advi = Advi::new(mf_config);
let mf_result = mf_advi.fit(log_joint, 2).expect("MF should succeed");
let fr_config = AdviConfig {
approximation: AdviApproximation::FullRank,
num_samples: 20,
learning_rate: 0.02,
max_iterations: 2000,
tolerance: 1e-5,
seed: 42,
convergence_window: 100,
..Default::default()
};
let mut fr_advi = Advi::new(fr_config);
let fr_result = fr_advi.fit(log_joint, 2).expect("FR should succeed");
let mf_final_elbo = mf_result
.elbo_history
.last()
.copied()
.unwrap_or(f64::NEG_INFINITY);
let fr_final_elbo = fr_result
.elbo_history
.last()
.copied()
.unwrap_or(f64::NEG_INFINITY);
assert!(
fr_final_elbo > mf_final_elbo - 1.0,
"Full-rank ELBO ({}) should be >= mean-field ELBO ({}) minus tolerance",
fr_final_elbo,
mf_final_elbo
);
}
#[test]
fn test_advi_log_transform() {
let config = AdviConfig {
approximation: AdviApproximation::MeanField,
transforms: vec![AdviTransform::Log],
num_samples: 20,
learning_rate: 0.01,
max_iterations: 3000,
tolerance: 1e-5,
seed: 55,
convergence_window: 100,
};
let mut advi = Advi::new(config);
let result = advi
.fit(
|theta: &Array1<f64>| {
let x = theta[0];
if x <= 0.0 {
return Ok((f64::NEG_INFINITY, Array1::zeros(1)));
}
                    // Gamma(3, 1) log-density: log p(x) = 2 ln x - x - ln Γ(3), with Γ(3) = 2.
                    let log_p = 2.0 * x.ln() - x - (2.0_f64).ln();
                    let grad = Array1::from_vec(vec![2.0 / x - 1.0]);
Ok((log_p, grad))
},
1,
)
.expect("ADVI with log transform should succeed");
assert!(
result.means[0] > 0.0,
"Mean should be positive with log transform"
);
assert!(
(result.means[0] - 3.0).abs() < 1.5,
"Mean should be near 3 (Gamma(3,1) mean), got {}",
result.means[0]
);
}
#[test]
fn test_advi_zero_dim_error() {
let mut advi = Advi::new(AdviConfig::default());
let result = advi.fit(|_theta: &Array1<f64>| Ok((0.0, Array1::zeros(0))), 0);
assert!(result.is_err());
}
#[test]
fn test_transform_roundtrip() {
let transforms = vec![
AdviTransform::Identity,
AdviTransform::Log,
AdviTransform::Logit,
AdviTransform::Bounded {
lower: -2.0,
upper: 5.0,
},
];
let test_vals = vec![1.5, 2.0, 0.3, 1.0];
for (t, v) in transforms.iter().zip(test_vals.iter()) {
let eta = t.inverse(*v).expect("inverse should succeed");
let recovered = t.forward(eta);
assert!(
(recovered - v).abs() < 1e-10,
"Roundtrip failed for {:?}: {} -> {} -> {}",
t,
v,
eta,
recovered
);
}
}
#[test]
fn test_log_det_jacobian_nonzero() {
let transforms = vec![
AdviTransform::Log,
AdviTransform::Logit,
AdviTransform::Bounded {
lower: 0.0,
upper: 10.0,
},
];
for t in &transforms {
let ldj = t.log_det_jacobian(0.5);
assert!(
ldj.is_finite(),
"Log-det-Jacobian should be finite for {:?}",
t
);
}
}
}