use crate::error::{StatsError, StatsResult};
use scirs2_core::random::rngs::SmallRng;
use scirs2_core::random::{RngExt, SeedableRng};
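/// Gradient estimator used by the black-box variational inference (BBVI)
/// solver: the score-function (REINFORCE) estimator with an optional
/// baseline, a multi-sample leave-one-out (VIMCO-style) estimator, or a
/// pathwise estimator based on reparameterization and finite differences.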
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GradientEstimator {
Reinforce,
Vimco,
PathwiseDifferentiable,
}
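/// Configuration for the BBVI solver: Monte Carlo sample counts, Adam
/// hyperparameters, the convergence tolerance on the ELBO, the finite
/// difference step for the pathwise estimator, and the RNG seed.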
#[derive(Debug, Clone)]
pub struct BbviConfig {
pub n_samples: usize,
pub n_vimco_samples: usize,
pub learning_rate: f64,
pub max_iter: usize,
pub tol: f64,
pub use_baseline: bool,
pub estimator: GradientEstimator,
pub seed: u64,
pub fd_step: f64,
pub baseline_decay: f64,
pub adam_beta1: f64,
pub adam_beta2: f64,
pub adam_eps: f64,
}
impl Default for BbviConfig {
fn default() -> Self {
BbviConfig {
n_samples: 10,
n_vimco_samples: 5,
learning_rate: 0.01,
max_iter: 1000,
tol: 1e-5,
use_baseline: true,
estimator: GradientEstimator::Reinforce,
seed: 42,
fd_step: 1e-4,
baseline_decay: 0.9,
adam_beta1: 0.9,
adam_beta2: 0.999,
adam_eps: 1e-8,
}
}
}
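/// Interface a model must implement for BBVI: the (unnormalized) log joint
/// density of latents and data, sampling from and evaluating the log density
/// of the variational distribution, and the number of variational parameters.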
pub trait BbviModel {
fn log_joint(&self, z: &[f64], data: &[f64]) -> f64;
fn sample_variational(&self, params: &[f64], rng: &mut dyn RngCore64) -> Vec<f64>;
fn log_variational(&self, z: &[f64], params: &[f64]) -> f64;
fn n_params(&self) -> usize;
}
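/// Object-safe source of randomness passed to [`BbviModel::sample_variational`],
/// providing uniform and standard-normal draws.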
pub trait RngCore64 {
fn next_f64(&mut self) -> f64;
fn next_normal(&mut self) -> f64;
}
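/// Adapter that exposes any `RngExt` generator as an [`RngCore64`];
/// standard-normal samples are generated with the Box–Muller transform.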
pub struct RngAdapter<R: RngExt>(pub R);
impl<R: RngExt> RngCore64 for RngAdapter<R> {
fn next_f64(&mut self) -> f64 {
self.0.random::<f64>()
}
fn next_normal(&mut self) -> f64 {
use std::f64::consts::TAU;
let u1 = (self.0.random::<f64>()).max(1e-300);
let u2 = self.0.random::<f64>();
let r = (-2.0_f64 * u1.ln()).sqrt();
r * (TAU * u2).cos()
}
}
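/// Output of a BBVI fit: the optimized variational parameters, the ELBO
/// estimate recorded at every iteration, the number of iterations run,
/// whether the ELBO change fell below the tolerance, and the per-coordinate
/// sample variance of the gradient estimates across iterations.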
#[derive(Debug, Clone)]
pub struct BbviResult {
pub variational_params: Vec<f64>,
pub elbo_history: Vec<f64>,
pub n_iter: usize,
pub converged: bool,
pub gradient_variance: Vec<f64>,
}
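/// Internal Adam optimizer state; `step` performs gradient *ascent* on the
/// ELBO with bias-corrected first and second moment estimates.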
struct AdamState {
m: Vec<f64>,
v: Vec<f64>,
t: usize,
beta1: f64,
beta2: f64,
eps: f64,
lr: f64,
}
impl AdamState {
fn new(n_params: usize, lr: f64, beta1: f64, beta2: f64, eps: f64) -> Self {
AdamState {
m: vec![0.0; n_params],
v: vec![0.0; n_params],
t: 0,
beta1,
beta2,
eps,
lr,
}
}
fn step(&mut self, params: &mut [f64], grad: &[f64]) {
self.t += 1;
let bc1 = 1.0 - self.beta1.powi(self.t as i32);
let bc2 = 1.0 - self.beta2.powi(self.t as i32);
for i in 0..params.len() {
self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * grad[i];
self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * grad[i] * grad[i];
let m_hat = self.m[i] / bc1;
let v_hat = self.v[i] / bc2;
params[i] += self.lr * m_hat / (v_hat.sqrt() + self.eps);
}
}
}
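/// Black-box variational inference solver pairing a [`BbviModel`] with a
/// [`BbviConfig`]; `fit` maximizes the ELBO with the configured gradient
/// estimator and Adam updates.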
pub struct BbviSolver<M: BbviModel> {
pub model: M,
pub config: BbviConfig,
}
impl<M: BbviModel> BbviSolver<M> {
pub fn new(model: M, config: BbviConfig) -> Self {
BbviSolver { model, config }
}
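/// Runs BBVI from `init_params`, maximizing the ELBO with stochastic gradient
/// estimates and Adam updates, and returns the fitted variational parameters
/// together with per-iteration diagnostics.
///
/// Returns a `DimensionMismatch` error if `init_params.len()` differs from
/// [`BbviModel::n_params`].
///
/// A minimal usage sketch (marked `ignore` so it is not run as a doctest):
/// ```ignore
/// let model = GaussianBbviModel::new(2.0, 0.5, 5.0);
/// let mut solver = BbviSolver::new(model, BbviConfig::default());
/// let result = solver.fit(&[], &[0.0, 0.0]).expect("BBVI fit failed");
/// println!("fitted params: {:?}", result.variational_params);
/// ```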
pub fn fit(&mut self, data: &[f64], init_params: &[f64]) -> StatsResult<BbviResult> {
let n_params = self.model.n_params();
if init_params.len() != n_params {
return Err(StatsError::DimensionMismatch(format!(
"init_params length {} != model.n_params() {}",
init_params.len(),
n_params
)));
}
let mut params = init_params.to_vec();
let mut adam = AdamState::new(
n_params,
self.config.learning_rate,
self.config.adam_beta1,
self.config.adam_beta2,
self.config.adam_eps,
);
let mut rng: RngAdapter<SmallRng> = RngAdapter(SmallRng::seed_from_u64(self.config.seed));
let mut elbo_history = Vec::with_capacity(self.config.max_iter);
let mut baseline = 0.0_f64;
let mut prev_elbo = f64::NEG_INFINITY;
let mut converged = false;
let mut grad_mean = vec![0.0_f64; n_params];
let mut grad_m2 = vec![0.0_f64; n_params];
for iter in 0..self.config.max_iter {
let (grad, elbo) = match &self.config.estimator {
GradientEstimator::Reinforce => self.reinforce_gradient(
&params,
data,
&mut rng,
&mut baseline,
self.config.n_samples,
),
GradientEstimator::Vimco => {
self.vimco_gradient(&params, data, &mut rng, self.config.n_vimco_samples)
}
GradientEstimator::PathwiseDifferentiable => {
self.pathwise_gradient(&params, data, &mut rng, self.config.n_samples)
}
};
elbo_history.push(elbo);
// Welford update of the running mean and squared deviations of each
// gradient coordinate, so a sample variance can be reported at the end.
for i in 0..n_params {
let delta = grad[i] - grad_mean[i];
grad_mean[i] += delta / (iter + 1) as f64;
let delta2 = grad[i] - grad_mean[i];
grad_m2[i] += delta * delta2;
}
adam.step(&mut params, &grad);
let elbo_change = (elbo - prev_elbo).abs();
if iter > 5 && elbo_change < self.config.tol {
converged = true;
break;
}
prev_elbo = elbo;
}
let n_iter_actual = elbo_history.len();
// Per-coordinate sample variance of the gradient estimates across iterations.
let gradient_variance: Vec<f64> = if n_iter_actual > 1 {
grad_m2
.iter()
.map(|&m2| m2 / (n_iter_actual - 1) as f64)
.collect()
} else {
vec![0.0_f64; n_params]
};
Ok(BbviResult {
variational_params: params,
elbo_history,
n_iter: n_iter_actual,
converged,
gradient_variance,
})
}
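/// Score-function (REINFORCE) estimator of the ELBO gradient:
/// `E_q[(log p(z, data) - log q(z)) * grad_lambda log q(z | lambda)]`,
/// averaged over `n_samples` Monte Carlo draws. The score term assumes the
/// mean-field Gaussian parameterization handled by `score_function_gaussian`.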
fn reinforce_gradient(
&self,
params: &[f64],
data: &[f64],
rng: &mut dyn RngCore64,
baseline: &mut f64,
n_samples: usize,
) -> (Vec<f64>, f64) {
let n_params = params.len();
let k = n_samples.max(1);
let mut rewards: Vec<f64> = Vec::with_capacity(k);
let mut scores: Vec<Vec<f64>> = Vec::with_capacity(k);
let mut samples: Vec<Vec<f64>> = Vec::with_capacity(k);
for _ in 0..k {
let z = self.model.sample_variational(params, rng);
let log_p = self.model.log_joint(&z, data);
let log_q = self.model.log_variational(&z, params);
let reward = log_p - log_q;
let score = score_function_gaussian(params, &z);
rewards.push(reward);
scores.push(score);
samples.push(z);
}
let elbo: f64 = rewards.iter().sum::<f64>() / k as f64;
if self.config.use_baseline {
*baseline = self.config.baseline_decay * (*baseline)
+ (1.0 - self.config.baseline_decay) * elbo;
}
let mut grad = vec![0.0_f64; n_params];
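// When the baseline is enabled, center each reward with the leave-one-out
// mean of the other rewards (falling back to the EMA baseline when only one
// sample is drawn); this reduces the variance of the score-function
// estimator without biasing it.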
for i in 0..k {
let centered_reward = if self.config.use_baseline {
let other_mean = if k > 1 {
let sum_other: f64 = rewards
.iter()
.enumerate()
.filter(|&(j, _)| j != i)
.map(|(_, &r)| r)
.sum();
sum_other / (k - 1) as f64
} else {
*baseline
};
rewards[i] - other_mean
} else {
rewards[i]
};
for p in 0..n_params {
grad[p] += scores[i][p] * centered_reward;
}
}
for g in grad.iter_mut() {
*g /= k as f64;
}
(grad, elbo)
}
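/// Multi-sample gradient in the spirit of VIMCO: the objective is the
/// importance-weighted bound `log((1/k) * sum_j w_j)` with `w_j = p/q`, and
/// each sample's score is weighted by a leave-one-out learning signal
/// `log w_i` minus the log-mean of the remaining weights. This is a
/// simplified variant of the full VIMCO estimator.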
fn vimco_gradient(
&self,
params: &[f64],
data: &[f64],
rng: &mut dyn RngCore64,
k: usize,
) -> (Vec<f64>, f64) {
let k = k.max(2);
let n_params = params.len();
let mut log_ws: Vec<f64> = Vec::with_capacity(k);
let mut scores: Vec<Vec<f64>> = Vec::with_capacity(k);
for _ in 0..k {
let z = self.model.sample_variational(params, rng);
let log_p = self.model.log_joint(&z, data);
let log_q = self.model.log_variational(&z, params);
log_ws.push(log_p - log_q);
scores.push(score_function_gaussian(params, &z));
}
let log_sum_w = logsumexp(&log_ws);
let elbo = log_sum_w - (k as f64).ln();
let mut grad = vec![0.0_f64; n_params];
for i in 0..k {
let log_loo = {
let mut loo_ws: Vec<f64> = Vec::with_capacity(k - 1);
for (j, &lw) in log_ws.iter().enumerate() {
if j != i {
loo_ws.push(lw);
}
}
logsumexp(&loo_ws) - ((k - 1) as f64).ln()
};
let signal = log_ws[i] - log_loo;
for p in 0..n_params {
grad[p] += scores[i][p] * signal;
}
}
for g in grad.iter_mut() {
*g /= k as f64;
}
(grad, elbo)
}
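/// Pathwise (reparameterization) gradient: the ELBO is evaluated on a fixed
/// set of standard-normal draws and differentiated with respect to each
/// variational parameter by a forward finite difference of step `fd_step`,
/// using common random numbers across the perturbed evaluations.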
fn pathwise_gradient(
&self,
params: &[f64],
data: &[f64],
rng: &mut dyn RngCore64,
n_samples: usize,
) -> (Vec<f64>, f64) {
let n_params = params.len();
let k = n_samples.max(1);
let h = self.config.fd_step;
let mut epsilons: Vec<Vec<f64>> = Vec::with_capacity(k);
for _ in 0..k {
let eps: Vec<f64> = (0..n_params / 2).map(|_| rng.next_normal()).collect();
epsilons.push(eps);
}
let elbo_base = self.eval_elbo_reparam(params, data, &epsilons);
let mut grad = vec![0.0_f64; n_params];
for i in 0..n_params {
let mut params_plus = params.to_vec();
params_plus[i] += h;
let elbo_plus = self.eval_elbo_reparam(&params_plus, data, &epsilons);
grad[i] = (elbo_plus - elbo_base) / h;
}
(grad, elbo_base)
}
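/// Reparameterized ELBO estimate for a mean-field Gaussian with
/// `params = [mu_1..mu_d, log_sigma_1..log_sigma_d]`: each latent is
/// `z_i = mu_i + sigma_i * eps_i`, and the ELBO is averaged over the
/// supplied `eps` draws.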
fn eval_elbo_reparam(&self, params: &[f64], data: &[f64], epsilons: &[Vec<f64>]) -> f64 {
let k = epsilons.len();
let d = params.len() / 2;
let mut total = 0.0;
for eps in epsilons {
let z: Vec<f64> = (0..d.min(eps.len()))
.map(|i| {
let mu = params[i];
let log_sigma = if i + d < params.len() {
params[i + d]
} else {
0.0
};
let sigma = log_sigma.exp().max(1e-10);
mu + sigma * eps[i]
})
.collect();
let log_p = self.model.log_joint(&z, data);
let log_q = self.model.log_variational(&z, params);
total += log_p - log_q;
}
if k > 0 {
total / k as f64
} else {
f64::NEG_INFINITY
}
}
}
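/// Score function `grad_lambda log q(z | lambda)` of a mean-field Gaussian
/// with parameters `[mu_1..mu_d, log_sigma_1..log_sigma_d]`:
/// `d/d mu_i = (z_i - mu_i) / sigma_i^2` and
/// `d/d log_sigma_i = (z_i - mu_i)^2 / sigma_i^2 - 1`.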
fn score_function_gaussian(params: &[f64], z: &[f64]) -> Vec<f64> {
let n_params = params.len();
let d = n_params / 2;
let mut score = vec![0.0_f64; n_params];
for i in 0..d.min(z.len()) {
let mu = params[i];
let log_sigma = params[i + d];
let sigma = log_sigma.exp().max(1e-10);
let sigma_sq = sigma * sigma;
let diff = z[i] - mu;
score[i] = diff / sigma_sq;
score[i + d] = (diff * diff / sigma_sq) - 1.0;
}
score
}
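/// Numerically stable `log(sum_i exp(x_i))`.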
fn logsumexp(xs: &[f64]) -> f64 {
if xs.is_empty() {
return f64::NEG_INFINITY;
}
let max_x = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
if !max_x.is_finite() {
return f64::NEG_INFINITY;
}
let sum: f64 = xs.iter().map(|&x| (x - max_x).exp()).sum();
max_x + sum.ln()
}
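/// One-dimensional Gaussian test model: a zero-mean Gaussian prior with
/// standard deviation `prior_std` on the latent `z`, a Gaussian likelihood of
/// `target_mean` around `z` with standard deviation `target_std`, and a
/// Gaussian variational family parameterized by `[mu, log_sigma]`.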
#[derive(Debug, Clone)]
pub struct GaussianBbviModel {
pub target_mean: f64,
pub target_std: f64,
pub prior_std: f64,
}
impl GaussianBbviModel {
pub fn new(target_mean: f64, target_std: f64, prior_std: f64) -> Self {
GaussianBbviModel {
target_mean,
target_std,
prior_std,
}
}
}
impl BbviModel for GaussianBbviModel {
fn log_joint(&self, z: &[f64], _data: &[f64]) -> f64 {
if z.is_empty() {
return f64::NEG_INFINITY;
}
let z0 = z[0];
let log_prior = -0.5 * (z0 / self.prior_std).powi(2)
- self.prior_std.ln()
- 0.5 * std::f64::consts::LN_2
- 0.5 * std::f64::consts::PI.ln();
let log_lik = -0.5 * ((self.target_mean - z0) / self.target_std).powi(2)
- self.target_std.ln()
- 0.5 * std::f64::consts::LN_2
- 0.5 * std::f64::consts::PI.ln();
log_prior + log_lik
}
fn sample_variational(&self, params: &[f64], rng: &mut dyn RngCore64) -> Vec<f64> {
if params.len() < 2 {
return vec![0.0];
}
let mu = params[0];
let sigma = params[1].exp().max(1e-10);
let eps = rng.next_normal();
vec![mu + sigma * eps]
}
fn log_variational(&self, z: &[f64], params: &[f64]) -> f64 {
if z.is_empty() || params.len() < 2 {
return f64::NEG_INFINITY;
}
let mu = params[0];
let log_sigma = params[1];
let sigma = log_sigma.exp().max(1e-10);
let diff = z[0] - mu;
-0.5 * (diff / sigma).powi(2)
- log_sigma
- 0.5 * std::f64::consts::LN_2
- 0.5 * std::f64::consts::PI.ln()
}
fn n_params(&self) -> usize {
2
}
}
#[cfg(test)]
mod tests {
use super::*;
fn default_model() -> GaussianBbviModel {
GaussianBbviModel::new(2.0, 0.5, 5.0)
}
fn make_solver(estimator: GradientEstimator) -> BbviSolver<GaussianBbviModel> {
let config = BbviConfig {
n_samples: 20,
n_vimco_samples: 5,
learning_rate: 0.05,
max_iter: 600,
tol: 1e-6,
use_baseline: true,
estimator,
seed: 123,
..BbviConfig::default()
};
BbviSolver::new(default_model(), config)
}
#[test]
fn test_reinforce_elbo_increases() {
let mut solver = make_solver(GradientEstimator::Reinforce);
let init = vec![0.0, 0.0];
let result = solver.fit(&[], &init).expect("BBVI fit failed");
let n = result.elbo_history.len();
assert!(n > 0, "No ELBO history recorded");
let first_elbo = result.elbo_history[0];
let last_elbo = result.elbo_history[n - 1];
assert!(
last_elbo >= first_elbo - 0.5,
"ELBO degraded significantly: {:.4} -> {:.4}",
first_elbo,
last_elbo
);
}
#[test]
fn test_reinforce_returns_valid_params() {
let mut solver = make_solver(GradientEstimator::Reinforce);
let result = solver.fit(&[], &[0.0, 0.0]).expect("fit failed");
assert_eq!(result.variational_params.len(), 2);
for &p in &result.variational_params {
assert!(p.is_finite(), "Non-finite parameter: {}", p);
}
}
#[test]
fn test_vimco_multi_sample() {
let config = BbviConfig {
n_vimco_samples: 5,
learning_rate: 0.05,
max_iter: 200,
estimator: GradientEstimator::Vimco,
seed: 77,
..BbviConfig::default()
};
let mut solver = BbviSolver::new(default_model(), config);
let result = solver.fit(&[], &[0.0, 0.0]).expect("VIMCO fit failed");
assert!(!result.elbo_history.is_empty());
for &elbo in &result.elbo_history {
assert!(elbo.is_finite(), "Non-finite ELBO: {}", elbo);
}
}
#[test]
fn test_vimco_elbo_finite() {
let config = BbviConfig {
n_vimco_samples: 10,
learning_rate: 0.01,
max_iter: 100,
estimator: GradientEstimator::Vimco,
seed: 999,
..BbviConfig::default()
};
let mut solver = BbviSolver::new(default_model(), config);
let result = solver.fit(&[], &[1.0, -0.5]).expect("VIMCO fit failed");
assert!(result
.elbo_history
.last()
.map(|&e| e.is_finite())
.unwrap_or(false));
}
#[test]
fn test_bbvi_gaussian_approx() {
let config = BbviConfig {
n_samples: 30,
learning_rate: 0.03,
max_iter: 1500,
tol: 1e-7,
use_baseline: true,
estimator: GradientEstimator::Reinforce,
seed: 42,
..BbviConfig::default()
};
let mut solver = BbviSolver::new(default_model(), config);
let result = solver.fit(&[], &[0.0, 0.0]).expect("fit failed");
let mu = result.variational_params[0];
assert!(
(mu - 2.0).abs() < 1.5,
"Mean too far from target: μ={:.3}",
mu
);
}
#[test]
fn test_bbvi_baseline_reduces_variance() {
let config_with = BbviConfig {
n_samples: 50,
max_iter: 30,
use_baseline: true,
estimator: GradientEstimator::Reinforce,
seed: 10,
..BbviConfig::default()
};
let config_without = BbviConfig {
n_samples: 50,
max_iter: 30,
use_baseline: false,
estimator: GradientEstimator::Reinforce,
seed: 10,
..BbviConfig::default()
};
let mut solver_with = BbviSolver::new(default_model(), config_with);
let mut solver_without = BbviSolver::new(default_model(), config_without);
let result_with = solver_with.fit(&[], &[0.0, 0.0]).expect("fit failed");
let result_without = solver_without.fit(&[], &[0.0, 0.0]).expect("fit failed");
assert!(result_with.elbo_history.iter().all(|&e| e.is_finite()));
assert!(result_without.elbo_history.iter().all(|&e| e.is_finite()));
}
#[test]
fn test_bbvi_convergence() {
let config = BbviConfig {
n_samples: 20,
learning_rate: 0.05,
max_iter: 500,
tol: 1e-5,
use_baseline: true,
estimator: GradientEstimator::Reinforce,
seed: 55,
..BbviConfig::default()
};
let model = GaussianBbviModel::new(1.5, 0.3, 3.0);
let mut solver = BbviSolver::new(model, config);
let result = solver.fit(&[], &[0.0, 0.0]).expect("fit failed");
assert!(
result.converged || result.elbo_history.len() == 500,
"Solver returned inconsistent state"
);
for &p in &result.variational_params {
assert!(p.is_finite(), "Non-finite parameter after convergence");
}
}
#[test]
fn test_pathwise_gradient() {
let config = BbviConfig {
n_samples: 10,
learning_rate: 0.01,
max_iter: 100,
estimator: GradientEstimator::PathwiseDifferentiable,
seed: 7,
fd_step: 1e-4,
..BbviConfig::default()
};
let mut solver = BbviSolver::new(default_model(), config);
let result = solver.fit(&[], &[0.0, 0.0]).expect("pathwise fit failed");
assert!(!result.elbo_history.is_empty());
for &e in &result.elbo_history {
assert!(e.is_finite(), "Non-finite ELBO in pathwise: {}", e);
}
}
#[test]
fn test_bbvi_mismatched_params_error() {
let config = BbviConfig::default();
let mut solver = BbviSolver::new(default_model(), config);
let res = solver.fit(&[], &[0.0, 0.0, 0.0]);
assert!(res.is_err(), "Expected error for mismatched params");
}
#[test]
fn test_bbvi_single_sample_reinforce() {
let config = BbviConfig {
n_samples: 1,
max_iter: 50,
estimator: GradientEstimator::Reinforce,
seed: 1,
..BbviConfig::default()
};
let mut solver = BbviSolver::new(default_model(), config);
let result = solver
.fit(&[], &[0.0, 0.0])
.expect("single-sample fit failed");
assert_eq!(result.variational_params.len(), 2);
}
#[test]
fn test_bbvi_elbo_history_length() {
let config = BbviConfig {
max_iter: 50,
tol: 0.0,
..BbviConfig::default()
};
let mut solver = BbviSolver::new(default_model(), config);
let result = solver.fit(&[], &[0.0, 0.0]).expect("fit failed");
assert!(result.elbo_history.len() <= 50);
}
}