use crate::error::{StatsError, StatsResult as Result};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
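/// Configuration for black-box variational inference (BBVI).
///
/// `n_samples` is the number of Monte Carlo samples per gradient estimate,
/// `learning_rate` the Adam step size, `n_iter` the maximum number of
/// iterations, and `use_vimco` selects the multi-sample VIMCO-style estimator
/// instead of plain REINFORCE. `baseline_decay` is reserved for an
/// exponential-moving-average baseline (not used by the estimators in this
/// module), `seed` makes runs reproducible, and `tol` is the ELBO-change
/// convergence tolerance.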
#[derive(Debug, Clone)]
pub struct BbviConfig {
pub n_samples: usize,
pub learning_rate: f64,
pub n_iter: usize,
pub use_vimco: bool,
pub baseline_decay: f64,
pub seed: u64,
pub tol: f64,
}
impl Default for BbviConfig {
fn default() -> Self {
BbviConfig {
n_samples: 10,
learning_rate: 0.01,
n_iter: 1000,
use_vimco: false,
baseline_decay: 0.9,
seed: 42,
tol: 1e-6,
}
}
}
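/// A Monte Carlo estimate of the evidence lower bound (ELBO) together with
/// its standard error and the variance of the per-sample gradient norms.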
#[derive(Debug, Clone)]
pub struct ElboEstimate<F: Float> {
pub value: F,
pub std_error: F,
pub gradient_variance: F,
}
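/// Result of a BBVI optimization run: per-iteration ELBO estimates, the final
/// variational parameters, the number of iterations performed, and whether the
/// ELBO-change convergence criterion was met.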
#[derive(Debug, Clone)]
pub struct BbviResult<F: Float> {
pub elbo_history: Vec<F>,
pub final_params: Vec<F>,
pub n_iters: usize,
pub converged: bool,
}
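/// A variational family `q(z; params)` that can be sampled, evaluated, and
/// differentiated through its score function `d log q(z; params) / d params`.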
pub trait VariationalDistribution<F: Float + FromPrimitive + Clone + Debug> {
fn sample(&self, seed: u64) -> Vec<F>;
fn log_prob(&self, z: &[F]) -> F;
fn params(&self) -> &[F];
fn params_mut(&mut self) -> &mut Vec<F>;
fn score_function(&self, z: &[F]) -> Vec<F>;
fn latent_dim(&self) -> usize;
}
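/// Fully factorized (mean-field) Gaussian `q(z) = prod_i N(z_i; mu_i, sigma_i^2)`.
///
/// The parameter vector stores the `dim` means first, followed by the `dim`
/// log standard deviations, so `params.len() == 2 * dim`.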
#[derive(Debug, Clone)]
pub struct MeanFieldGaussian<F: Float + FromPrimitive + Clone + Debug> {
params: Vec<F>,
dim: usize,
}
impl<F: Float + FromPrimitive + Clone + Debug> MeanFieldGaussian<F> {
pub fn new(dim: usize) -> Self {
// Means default to 0 and log-sigmas to 0, i.e. unit standard deviations.
let params = vec![F::zero(); 2 * dim];
MeanFieldGaussian { params, dim }
}
pub fn from_params(mu: Vec<F>, log_sigma: Vec<F>) -> Result<Self> {
if mu.len() != log_sigma.len() {
return Err(StatsError::DimensionMismatch(format!(
"mu length {} != log_sigma length {}",
mu.len(),
log_sigma.len()
)));
}
let dim = mu.len();
let mut params = Vec::with_capacity(2 * dim);
params.extend_from_slice(&mu);
params.extend_from_slice(&log_sigma);
Ok(MeanFieldGaussian { params, dim })
}
pub fn mean(&self) -> &[F] {
&self.params[..self.dim]
}
pub fn log_sigma(&self) -> &[F] {
&self.params[self.dim..]
}
}
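/// Minimal deterministic PRNG: a 64-bit linear congruential generator using
/// Knuth's MMIX multiplier/increment, with Box-Muller for standard normals.
/// It exists only to make BBVI runs reproducible from a seed; it is not meant
/// to be a high-quality random source.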
struct LcgPrng {
state: u64,
}
impl LcgPrng {
fn new(seed: u64) -> Self {
LcgPrng { state: seed ^ 6364136223846793005 }
}
fn next_u64(&mut self) -> u64 {
self.state = self.state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
self.state
}
fn uniform(&mut self) -> f64 {
let bits = self.next_u64() >> 11;
(bits as f64 + 0.5) / (1u64 << 53) as f64
}
fn normal_pair(&mut self) -> (f64, f64) {
use std::f64::consts::PI;
let u1 = self.uniform();
let u2 = self.uniform();
let r = (-2.0 * u1.ln()).sqrt();
let theta = 2.0 * PI * u2;
(r * theta.cos(), r * theta.sin())
}
}
impl<F: Float + FromPrimitive + Clone + Debug> VariationalDistribution<F>
for MeanFieldGaussian<F>
{
fn sample(&self, seed: u64) -> Vec<F> {
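// Box-Muller produces standard normals in pairs; each is scaled by sigma_i
// and shifted by mu_i, and the second of a pair is consumed when the
// dimension allows.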
let mut rng = LcgPrng::new(seed);
let mut z = Vec::with_capacity(self.dim);
let mut i = 0;
while i < self.dim {
let (n1, n2) = rng.normal_pair();
let sigma_i = self.params[self.dim + i].exp();
let mu_i = self.params[i];
z.push(mu_i + sigma_i * F::from_f64(n1).unwrap_or(F::zero()));
if i + 1 < self.dim {
let sigma_i1 = self.params[self.dim + i + 1].exp();
let mu_i1 = self.params[i + 1];
z.push(mu_i1 + sigma_i1 * F::from_f64(n2).unwrap_or(F::zero()));
}
i += 2;
}
z.truncate(self.dim);
z
}
fn log_prob(&self, z: &[F]) -> F {
if z.len() != self.dim {
return F::neg_infinity();
}
let two_pi = F::from_f64(2.0 * std::f64::consts::PI).unwrap_or(F::one());
let half = F::from_f64(0.5).unwrap_or(F::one());
let mut log_p = F::zero();
for i in 0..self.dim {
let mu = self.params[i];
let log_s = self.params[self.dim + i];
let sigma = log_s.exp();
let diff = z[i] - mu;
log_p = log_p - half * two_pi.ln() - log_s - half * (diff * diff) / (sigma * sigma);
}
log_p
}
fn params(&self) -> &[F] {
&self.params
}
fn params_mut(&mut self) -> &mut Vec<F> {
&mut self.params
}
fn score_function(&self, z: &[F]) -> Vec<F> {
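// Analytic score of a diagonal Gaussian:
//   d log q / d mu_i        = (z_i - mu_i) / sigma_i^2
//   d log q / d log_sigma_i = (z_i - mu_i)^2 / sigma_i^2 - 1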
let mut grad = vec![F::zero(); 2 * self.dim];
for i in 0..self.dim {
let mu = self.params[i];
let log_s = self.params[self.dim + i];
let sigma2 = (log_s * F::from_f64(2.0).unwrap_or(F::one())).exp();
let diff = z[i] - mu;
grad[i] = diff / sigma2;
grad[self.dim + i] = (diff * diff) / sigma2 - F::one();
}
grad
}
fn latent_dim(&self) -> usize {
self.dim
}
}
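/// Numerically stable log-sum-exp: subtracts the maximum before
/// exponentiating so that large log-weights do not overflow.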
fn log_sum_exp<F: Float + FromPrimitive>(values: &[F]) -> F {
if values.is_empty() {
return F::neg_infinity();
}
let max_val = values.iter().cloned().fold(F::neg_infinity(), F::max);
if max_val.is_infinite() {
// Either every value is -inf or the maximum overflowed to +inf; in both
// cases the log-sum-exp equals the maximum itself.
return max_val;
}
let sum_exp = values
.iter()
.fold(F::zero(), |acc, &v| acc + (v - max_val).exp());
max_val + sum_exp.ln()
}
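/// REINFORCE (score-function) estimator of the ELBO and its gradient.
///
/// Draws `n_samples` latents from `q`, forms the reward
/// `log p(x, z) - log q(z)`, subtracts the batch-mean reward as a baseline,
/// and averages `(reward - baseline) * score` over the samples. Using the
/// batch mean as the baseline introduces a small bias relative to a
/// leave-one-out baseline, but it is cheap and usually effective. Returns the
/// ELBO estimate (with standard error and a gradient-variance diagnostic) and
/// the gradient with respect to the variational parameters.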
pub fn reinforce_gradient<F>(
q: &dyn VariationalDistribution<F>,
log_joint_fn: &dyn Fn(&[F]) -> F,
config: &BbviConfig,
) -> Result<(ElboEstimate<F>, Vec<F>)>
where
F: Float + FromPrimitive + Clone + Debug,
{
let n = config.n_samples;
let n_params = q.params().len();
if n == 0 {
return Err(StatsError::InvalidArgument(
"n_samples must be > 0".to_string(),
));
}
let mut rewards: Vec<F> = Vec::with_capacity(n);
let mut samples: Vec<Vec<F>> = Vec::with_capacity(n);
let mut scores: Vec<Vec<F>> = Vec::with_capacity(n);
for i in 0..n {
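// Derive a distinct, deterministic per-sample seed via a multiplicative hash
// of the base seed.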
let seed = config.seed.wrapping_add(i as u64).wrapping_mul(2654435761);
let z = q.sample(seed);
let log_q = q.log_prob(&z);
let log_p = log_joint_fn(&z);
let reward = log_p - log_q;
let score = q.score_function(&z);
rewards.push(reward);
samples.push(z);
scores.push(score);
}
let n_f = F::from_usize(n).ok_or_else(|| {
StatsError::ComputationError("Cannot convert n_samples to F".to_string())
})?;
let reward_sum = rewards.iter().cloned().fold(F::zero(), |acc, v| acc + v);
let baseline = reward_sum / n_f;
let reward_var = {
let sq_sum = rewards
.iter()
.fold(F::zero(), |acc, &r| acc + (r - baseline) * (r - baseline));
if n > 1 {
sq_sum / F::from_usize(n - 1).unwrap_or(F::one())
} else {
F::zero()
}
};
let std_err = if n > 1 {
(reward_var / n_f).sqrt()
} else {
F::zero()
};
let mut gradient = vec![F::zero(); n_params];
let mut per_sample_norms: Vec<F> = Vec::with_capacity(n);
for i in 0..n {
let weight = rewards[i] - baseline;
let mut norm_sq = F::zero();
for j in 0..n_params {
let g_ij = weight * scores[i][j];
gradient[j] = gradient[j] + g_ij;
norm_sq = norm_sq + g_ij * g_ij;
}
per_sample_norms.push(norm_sq.sqrt());
}
for g in gradient.iter_mut() {
*g = *g / n_f;
}
let norm_mean = per_sample_norms.iter().cloned().fold(F::zero(), |a, v| a + v) / n_f;
let grad_var = per_sample_norms
.iter()
.fold(F::zero(), |acc, &v| acc + (v - norm_mean) * (v - norm_mean))
/ n_f;
let elbo = ElboEstimate {
value: baseline,
std_error: if std_err < F::zero() { F::zero() } else { std_err },
gradient_variance: if grad_var < F::zero() { F::zero() } else { grad_var },
};
Ok((elbo, gradient))
}
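/// Multi-sample estimator of the importance-weighted ELBO
/// `log((1/K) sum_j w_j)` with `w_j = p(x, z_j) / q(z_j)`, and its gradient.
///
/// Each sample's score is weighted by the multi-sample ELBO minus a
/// leave-one-out baseline built from the other K-1 samples, minus the
/// sample's normalized importance weight. This is a per-sample control
/// variate in the spirit of VIMCO (Mnih & Rezende, 2016), with an arithmetic
/// rather than geometric leave-one-out baseline. Falls back to
/// [`reinforce_gradient`] when `n_samples == 1`.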
pub fn vimco_gradient<F>(
q: &dyn VariationalDistribution<F>,
log_joint_fn: &dyn Fn(&[F]) -> F,
config: &BbviConfig,
) -> Result<(ElboEstimate<F>, Vec<F>)>
where
F: Float + FromPrimitive + Clone + Debug,
{
let k = config.n_samples;
let n_params = q.params().len();
if k == 0 {
return Err(StatsError::InvalidArgument(
"n_samples must be > 0".to_string(),
));
}
if k == 1 {
return reinforce_gradient(q, log_joint_fn, config);
}
let k_f = F::from_usize(k).ok_or_else(|| {
StatsError::ComputationError("Cannot convert k to F".to_string())
})?;
let mut log_weights: Vec<F> = Vec::with_capacity(k);
let mut scores: Vec<Vec<F>> = Vec::with_capacity(k);
for j in 0..k {
let seed = config.seed.wrapping_add(j as u64).wrapping_mul(2654435761);
let z = q.sample(seed);
let log_q = q.log_prob(&z);
let log_p = log_joint_fn(&z);
log_weights.push(log_p - log_q);
scores.push(q.score_function(&z));
}
let log_k = k_f.ln();
let lse_all = log_sum_exp(&log_weights);
let elbo_val = lse_all - log_k;
let log_w_normalized: Vec<F> = log_weights.iter().map(|&w| w - lse_all).collect();
let k_minus1 = k - 1;
let log_k_m1 = F::from_usize(k_minus1)
.ok_or_else(|| StatsError::ComputationError("k-1 overflow".to_string()))?
.ln();
let mut loo_baselines: Vec<F> = Vec::with_capacity(k);
for j in 0..k {
let loo_vals: Vec<F> = log_weights
.iter()
.enumerate()
.filter(|(i, _)| *i != j)
.map(|(_, &w)| w)
.collect();
let loo_lse = if loo_vals.is_empty() {
F::neg_infinity()
} else {
log_sum_exp(&loo_vals)
};
loo_baselines.push(loo_lse - log_k_m1);
}
let mut gradient = vec![F::zero(); n_params];
let mut per_sample_norms: Vec<F> = Vec::with_capacity(k);
for j in 0..k {
// Per-sample learning signal: the multi-sample ELBO minus a leave-one-out
// baseline built from the remaining K-1 samples (an arithmetic-mean variant
// of the VIMCO baseline), minus the sample's normalized importance weight,
// which accounts for the dependence of log w_j on the variational parameters.
let w_tilde = log_w_normalized[j].exp();
let learning_signal = elbo_val - loo_baselines[j] - w_tilde;
let mut norm_sq = F::zero();
for p_idx in 0..n_params {
let g = learning_signal * scores[j][p_idx];
gradient[p_idx] = gradient[p_idx] + g;
norm_sq = norm_sq + g * g;
}
per_sample_norms.push(norm_sq.sqrt());
}
// Average over the K samples to keep the scale consistent with the
// REINFORCE branch (the unscaled multi-sample gradient is the plain sum).
for g in gradient.iter_mut() {
*g = *g / k_f;
}
let norm_mean = per_sample_norms.iter().cloned().fold(F::zero(), |a, v| a + v) / k_f;
let grad_var = per_sample_norms
.iter()
.fold(F::zero(), |acc, &v| acc + (v - norm_mean) * (v - norm_mean))
/ k_f;
let lw_mean = log_weights.iter().cloned().fold(F::zero(), |a, v| a + v) / k_f;
let lw_var = log_weights
.iter()
.fold(F::zero(), |acc, &w| acc + (w - lw_mean) * (w - lw_mean))
/ k_f;
let std_err = (lw_var / k_f).sqrt();
let elbo = ElboEstimate {
value: elbo_val,
std_error: if std_err < F::zero() { F::zero() } else { std_err },
gradient_variance: if grad_var < F::zero() { F::zero() } else { grad_var },
};
Ok((elbo, gradient))
}
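/// Minimal Adam optimizer state (first/second moment estimates with bias
/// correction); `step` returns the update `lr * m_hat / (sqrt(v_hat) + eps)`.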
struct AdamState<F: Float> {
m: Vec<F>,
v: Vec<F>,
t: usize,
beta1: F,
beta2: F,
epsilon: F,
}
impl<F: Float + FromPrimitive> AdamState<F> {
fn new(n_params: usize) -> Self {
AdamState {
m: vec![F::zero(); n_params],
v: vec![F::zero(); n_params],
t: 0,
beta1: F::from_f64(0.9).unwrap_or(F::one()),
beta2: F::from_f64(0.999).unwrap_or(F::one()),
epsilon: F::from_f64(1e-8).unwrap_or(F::zero()),
}
}
fn step(&mut self, grad: &[F], lr: F) -> Vec<F> {
self.t += 1;
let t_f = F::from_usize(self.t).unwrap_or(F::one());
let one = F::one();
let mut delta = vec![F::zero(); grad.len()];
for i in 0..grad.len() {
self.m[i] = self.beta1 * self.m[i] + (one - self.beta1) * grad[i];
self.v[i] = self.beta2 * self.v[i] + (one - self.beta2) * grad[i] * grad[i];
let m_hat = self.m[i] / (one - self.beta1.powf(t_f));
let v_hat = self.v[i] / (one - self.beta2.powf(t_f));
delta[i] = lr * m_hat / (v_hat.sqrt() + self.epsilon);
}
delta
}
}
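/// Runs black-box variational inference: repeatedly estimates the ELBO
/// gradient with either REINFORCE or VIMCO and updates the variational
/// parameters with Adam, stopping when the ELBO change falls below `tol`
/// (after a short warm-up) or after `n_iter` iterations.
///
/// Illustrative sketch (not compiled; assumes the items of this module are in
/// scope and that the target log-joint is an ordinary function):
///
/// ```text
/// // Fit a 2-dimensional mean-field Gaussian to some log-joint density.
/// fn log_joint(z: &[f64]) -> f64 {
///     z.iter().map(|&zi| -0.5 * (zi - 1.0).powi(2)).sum()
/// }
///
/// let mut q = MeanFieldGaussian::<f64>::new(2);
/// let config = BbviConfig { n_samples: 20, n_iter: 500, ..Default::default() };
/// let result = bbvi_optimize(&mut q, &log_joint, &config).expect("BBVI failed");
/// println!("converged: {}, last ELBO: {:?}", result.converged, result.elbo_history.last());
/// ```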
pub fn bbvi_optimize<F>(
q: &mut dyn VariationalDistribution<F>,
log_joint_fn: &dyn Fn(&[F]) -> F,
config: &BbviConfig,
) -> Result<BbviResult<F>>
where
F: Float + FromPrimitive + Clone + Debug,
{
let n_params = q.params().len();
let lr = F::from_f64(config.learning_rate).ok_or_else(|| {
StatsError::InvalidArgument("learning_rate cannot be represented as F".to_string())
})?;
let tol = F::from_f64(config.tol).unwrap_or(F::from_f64(1e-6).unwrap_or(F::zero()));
let mut adam = AdamState::new(n_params);
let mut elbo_history: Vec<F> = Vec::with_capacity(config.n_iter);
let mut prev_elbo = F::neg_infinity();
let mut converged = false;
let mut iter_config = config.clone();
for iter in 0..config.n_iter {
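// Re-seed each iteration so gradient estimates use fresh samples while the
// whole run stays deterministic for a fixed `config.seed`.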
iter_config.seed = config
.seed
.wrapping_add(iter as u64)
.wrapping_mul(6364136223846793005);
let (elbo_est, gradient) = if config.use_vimco {
vimco_gradient(q, log_joint_fn, &iter_config)?
} else {
reinforce_gradient(q, log_joint_fn, &iter_config)?
};
elbo_history.push(elbo_est.value);
let delta = adam.step(&gradient, lr);
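// Gradient ascent on the ELBO: the estimators return d ELBO / d params,
// so the Adam step is added to the parameters.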
let params = q.params_mut();
for i in 0..n_params.min(delta.len()) {
params[i] = params[i] + delta[i];
}
let elbo_change = (elbo_est.value - prev_elbo).abs();
if iter > 10 && elbo_change < tol {
converged = true;
break;
}
prev_elbo = elbo_est.value;
}
let final_params = q.params().to_vec();
let n_iters = elbo_history.len();
Ok(BbviResult {
elbo_history,
final_params,
n_iters,
converged,
})
}
#[cfg(test)]
mod tests {
use super::*;
fn simple_gaussian_log_joint(z: &[f64]) -> f64 {
let mut lp = 0.0f64;
for &zi in z {
let diff = zi - 1.0;
lp -= 0.5 * (diff * diff) / 0.25 + 0.5 * (2.0 * std::f64::consts::PI * 0.25).ln();
}
lp
}
#[test]
fn test_mean_field_gaussian_sample_shape() {
let q = MeanFieldGaussian::<f64>::new(5);
let z = q.sample(42);
assert_eq!(z.len(), 5, "sample should have same dim as latent space");
}
#[test]
fn test_mean_field_gaussian_log_prob_finite() {
let q = MeanFieldGaussian::<f64>::new(3);
let z = q.sample(1);
let lp = q.log_prob(&z);
assert!(lp.is_finite(), "log_prob should be finite for a valid sample");
}
#[test]
fn test_mean_field_gaussian_params_length() {
let dim = 4;
let q = MeanFieldGaussian::<f64>::new(dim);
assert_eq!(q.params().len(), 2 * dim, "params should have length 2*dim");
}
#[test]
fn test_reinforce_gradient_output_shape() {
let q = MeanFieldGaussian::<f64>::new(3);
let config = BbviConfig { n_samples: 5, ..Default::default() };
let result = reinforce_gradient(&q, &simple_gaussian_log_joint, &config);
assert!(result.is_ok());
let (_, grad) = result.unwrap();
assert_eq!(grad.len(), q.params().len());
}
#[test]
fn test_elbo_estimate_std_error_non_negative() {
let q = MeanFieldGaussian::<f64>::new(2);
let config = BbviConfig { n_samples: 20, ..Default::default() };
let (elbo, _) = reinforce_gradient(&q, &simple_gaussian_log_joint, &config).unwrap();
assert!(elbo.std_error >= 0.0, "std_error must be non-negative");
}
#[test]
fn test_vimco_k1_equivalent_to_reinforce() {
let q1 = MeanFieldGaussian::<f64>::new(2);
let config = BbviConfig {
n_samples: 1,
use_vimco: false,
seed: 99,
..Default::default()
};
let (e1, g1) = reinforce_gradient(&q1, &simple_gaussian_log_joint, &config).unwrap();
let q2 = MeanFieldGaussian::<f64>::new(2);
let (e2, g2) = vimco_gradient(&q2, &simple_gaussian_log_joint, &config).unwrap();
assert!(
(e1.value - e2.value).abs() < 1e-10,
"k=1 VIMCO should match REINFORCE: {} vs {}",
e1.value,
e2.value
);
for (a, b) in g1.iter().zip(g2.iter()) {
assert!((a - b).abs() < 1e-10, "gradients should match for k=1");
}
}
#[test]
fn test_vimco_gradient_output_shape() {
let q = MeanFieldGaussian::<f64>::new(4);
let config = BbviConfig {
n_samples: 8,
use_vimco: true,
..Default::default()
};
let result = vimco_gradient(&q, &simple_gaussian_log_joint, &config);
assert!(result.is_ok());
let (_, grad) = result.unwrap();
assert_eq!(grad.len(), q.params().len());
}
#[test]
fn test_bbvi_optimize_elbo_increases() {
let mut q = MeanFieldGaussian::<f64>::new(2);
let config = BbviConfig {
n_samples: 20,
learning_rate: 0.05,
n_iter: 200,
use_vimco: false,
seed: 7,
tol: 1e-8,
..Default::default()
};
let result = bbvi_optimize(&mut q, &simple_gaussian_log_joint, &config).unwrap();
assert!(!result.elbo_history.is_empty());
let n = result.elbo_history.len();
let first_quarter_avg: f64 = result.elbo_history[..n / 4].iter().sum::<f64>() / (n / 4) as f64;
let last_quarter_avg: f64 =
result.elbo_history[3 * n / 4..].iter().sum::<f64>() / (n / 4).max(1) as f64;
assert!(
last_quarter_avg > first_quarter_avg - 5.0,
"ELBO should generally increase: early={}, late={}",
first_quarter_avg,
last_quarter_avg
);
}
#[test]
fn test_bbvi_config_defaults() {
let cfg = BbviConfig::default();
assert_eq!(cfg.n_samples, 10);
assert!((cfg.learning_rate - 0.01).abs() < 1e-12);
assert_eq!(cfg.n_iter, 1000);
assert!(!cfg.use_vimco);
assert!((cfg.baseline_decay - 0.9).abs() < 1e-12);
assert_eq!(cfg.seed, 42);
}
#[test]
fn test_vimco_variance_reduction() {
let config_reinforce = BbviConfig {
n_samples: 10,
use_vimco: false,
seed: 100,
..Default::default()
};
let config_vimco = BbviConfig {
n_samples: 10,
use_vimco: true,
seed: 100,
..Default::default()
};
let q = MeanFieldGaussian::<f64>::new(3);
let (e_r, _) =
reinforce_gradient(&q, &simple_gaussian_log_joint, &config_reinforce).unwrap();
let (e_v, _) =
vimco_gradient(&q, &simple_gaussian_log_joint, &config_vimco).unwrap();
assert!(e_r.gradient_variance.is_finite());
assert!(e_v.gradient_variance.is_finite());
assert!(e_v.gradient_variance >= 0.0);
}
#[test]
fn test_bbvi_result_elbo_history_length() {
let mut q = MeanFieldGaussian::<f64>::new(2);
let config = BbviConfig {
n_iter: 50,
tol: 0.0,
..Default::default()
};
let result = bbvi_optimize(&mut q, &simple_gaussian_log_joint, &config).unwrap();
assert_eq!(result.n_iters, result.elbo_history.len());
assert!(result.n_iters <= 50);
}
#[test]
fn test_log_sum_exp_stability() {
let vals = vec![1000.0f64, 1001.0, 999.0];
let result = log_sum_exp(&vals);
assert!(result.is_finite(), "logsumexp should be stable for large values");
let expected = 1001.0 + (1.0 + (-1.0f64).exp() + (-2.0f64).exp()).ln();
assert!((result - expected).abs() < 1e-10, "logsumexp value incorrect: {} vs {}", result, expected);
}
#[test]
fn test_mean_field_gaussian_from_params() {
let mu = vec![1.0f64, 2.0, 3.0];
let log_sigma = vec![0.0f64, -0.5, 0.5];
let q = MeanFieldGaussian::<f64>::from_params(mu.clone(), log_sigma.clone()).unwrap();
assert_eq!(q.latent_dim(), 3);
assert_eq!(q.mean(), &mu[..]);
assert_eq!(q.log_sigma(), &log_sigma[..]);
}
#[test]
fn test_reinforce_gradient_finite() {
let q = MeanFieldGaussian::<f64>::new(4);
let config = BbviConfig { n_samples: 15, ..Default::default() };
let (elbo, grad) = reinforce_gradient(&q, &simple_gaussian_log_joint, &config).unwrap();
assert!(elbo.value.is_finite(), "ELBO should be finite");
for g in &grad {
assert!(g.is_finite(), "gradient should be finite");
}
}
}