use crate::error::{StatsError, StatsResult as Result};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::validation::*;
use std::f64::consts::PI;
use super::svi::{AdamState, LearningRateSchedule};
use super::{FullRankGaussian, MeanFieldGaussian, VariationalDiagnostics};
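/// Smooth, invertible transforms between an unconstrained real value and a
/// constrained parameter support, in the style of ADVI (Kucukelbir et al., 2017).
/// Each variant supplies `forward`, `inverse`, and the log-Jacobian needed for
/// the ELBO correction. Note that `Simplex` is handled element-wise with a
/// sigmoid here rather than a full stick-breaking transform.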
#[derive(Debug, Clone)]
pub enum ParameterConstraint {
Real,
Positive,
UnitInterval,
Bounded {
lower: f64,
upper: f64,
},
Simplex {
dim: usize,
},
LowerBounded {
lower: f64,
},
UpperBounded {
upper: f64,
},
}
impl ParameterConstraint {
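    /// Map an unconstrained value η ∈ ℝ to the constrained parameter space.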
pub fn forward(&self, unconstrained: f64) -> f64 {
match self {
ParameterConstraint::Real => unconstrained,
ParameterConstraint::Positive => unconstrained.exp(),
ParameterConstraint::UnitInterval => 1.0 / (1.0 + (-unconstrained).exp()),
ParameterConstraint::Bounded { lower, upper } => {
let sigmoid = 1.0 / (1.0 + (-unconstrained).exp());
lower + (upper - lower) * sigmoid
}
ParameterConstraint::LowerBounded { lower } => lower + unconstrained.exp(),
ParameterConstraint::UpperBounded { upper } => upper - (-unconstrained).exp(),
            ParameterConstraint::Simplex { .. } => {
                // Element-wise sigmoid stand-in; a full stick-breaking simplex
                // transform is not implemented here.
                1.0 / (1.0 + (-unconstrained).exp())
            }
}
}
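    /// Map a constrained value back to ℝ, erroring if it lies outside the support.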
pub fn inverse(&self, constrained: f64) -> Result<f64> {
match self {
ParameterConstraint::Real => Ok(constrained),
ParameterConstraint::Positive => {
if constrained <= 0.0 {
return Err(StatsError::InvalidArgument(format!(
"Positive constraint requires value > 0, got {}",
constrained
)));
}
Ok(constrained.ln())
}
ParameterConstraint::UnitInterval => {
if constrained <= 0.0 || constrained >= 1.0 {
return Err(StatsError::InvalidArgument(format!(
"Unit interval constraint requires 0 < value < 1, got {}",
constrained
)));
}
Ok((constrained / (1.0 - constrained)).ln())
}
ParameterConstraint::Bounded { lower, upper } => {
if constrained <= *lower || constrained >= *upper {
return Err(StatsError::InvalidArgument(format!(
"Bounded constraint requires {} < value < {}, got {}",
lower, upper, constrained
)));
}
let normalized = (constrained - lower) / (upper - lower);
Ok((normalized / (1.0 - normalized)).ln())
}
ParameterConstraint::LowerBounded { lower } => {
if constrained <= *lower {
return Err(StatsError::InvalidArgument(format!(
"Lower-bounded constraint requires value > {}, got {}",
lower, constrained
)));
}
Ok((constrained - lower).ln())
}
ParameterConstraint::UpperBounded { upper } => {
if constrained >= *upper {
return Err(StatsError::InvalidArgument(format!(
"Upper-bounded constraint requires value < {}, got {}",
upper, constrained
)));
}
Ok(-((*upper - constrained).ln()))
}
ParameterConstraint::Simplex { .. } => {
if constrained <= 0.0 || constrained >= 1.0 {
return Err(StatsError::InvalidArgument(format!(
"Simplex element must be in (0, 1), got {}",
constrained
)));
}
Ok((constrained / (1.0 - constrained)).ln())
}
}
}
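    /// Log absolute Jacobian determinant log|dθ/dη| of the forward transform,
    /// added to the log-joint so the ELBO is evaluated in the unconstrained space.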
pub fn log_det_jacobian(&self, unconstrained: f64) -> f64 {
match self {
ParameterConstraint::Real => 0.0,
            ParameterConstraint::Positive => unconstrained,
ParameterConstraint::UnitInterval => {
let s = 1.0 / (1.0 + (-unconstrained).exp());
(s * (1.0 - s)).ln()
}
ParameterConstraint::Bounded { lower, upper } => {
let s = 1.0 / (1.0 + (-unconstrained).exp());
((upper - lower) * s * (1.0 - s)).ln()
}
            ParameterConstraint::LowerBounded { .. } => unconstrained,
            // d/dη [upper - exp(-η)] = exp(-η), so log|J| = -η.
            ParameterConstraint::UpperBounded { .. } => -unconstrained,
ParameterConstraint::Simplex { .. } => {
let s = 1.0 / (1.0 + (-unconstrained).exp());
(s * (1.0 - s)).ln()
}
}
}
}
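/// Configuration for ADVI optimization: iteration budget, relative-ELBO
/// convergence tolerance and window, Monte Carlo samples per gradient estimate,
/// learning-rate schedule, and gradient clipping (disabled when `grad_clip <= 0`).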
#[derive(Debug, Clone)]
pub struct AdviConfig {
pub max_iter: usize,
pub tol: f64,
pub n_mc_samples: usize,
pub lr_schedule: LearningRateSchedule,
pub grad_clip: f64,
pub diagnostic_interval: usize,
pub seed: u64,
pub convergence_window: usize,
}
impl Default for AdviConfig {
fn default() -> Self {
Self {
max_iter: 10000,
tol: 1e-4,
n_mc_samples: 1,
lr_schedule: LearningRateSchedule::default_adam(),
grad_clip: 10.0,
diagnostic_interval: 100,
seed: 42,
convergence_window: 50,
}
}
}
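/// Mean-field (fully factorized Gaussian) ADVI over the unconstrained space,
/// with per-parameter constraint transforms mapping samples back to the model's
/// support.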
#[derive(Debug, Clone)]
pub struct AdviMeanField {
pub variational: MeanFieldGaussian,
pub constraints: Vec<ParameterConstraint>,
pub config: AdviConfig,
pub diagnostics: VariationalDiagnostics,
pub dim: usize,
}
impl AdviMeanField {
pub fn new(constraints: Vec<ParameterConstraint>, config: AdviConfig) -> Result<Self> {
let dim = constraints.len();
if dim == 0 {
return Err(StatsError::InvalidArgument(
"Must have at least one parameter".to_string(),
));
}
let variational = MeanFieldGaussian::new(dim)?;
Ok(Self {
variational,
constraints,
config,
diagnostics: VariationalDiagnostics::new(),
dim,
})
}
pub fn new_unconstrained(dim: usize, config: AdviConfig) -> Result<Self> {
let constraints = vec![ParameterConstraint::Real; dim];
Self::new(constraints, config)
}
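    /// Initialize the variational means from a constrained parameter vector by
    /// applying each constraint's inverse transform; log-stds are reset to -1.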
pub fn initialize_from_constrained(&mut self, theta: &Array1<f64>) -> Result<()> {
if theta.len() != self.dim {
return Err(StatsError::DimensionMismatch(format!(
"theta length ({}) must match dimension ({})",
theta.len(),
self.dim
)));
}
let mut eta = Array1::zeros(self.dim);
for i in 0..self.dim {
eta[i] = self.constraints[i].inverse(theta[i])?;
}
self.variational.means = eta;
self.variational.log_stds = Array1::from_elem(self.dim, -1.0);
Ok(())
}
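    /// Maximize the ELBO by stochastic gradient ascent with reparameterization
    /// gradients. `log_joint` must return the log density and its gradient,
    /// both evaluated at a point in the constrained space.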
pub fn fit<F>(&mut self, log_joint: F) -> Result<AdviResult>
where
F: Fn(&Array1<f64>) -> Result<(f64, Array1<f64>)>,
{
let n_params = self.variational.n_params();
let mut adam_state = if let LearningRateSchedule::Adam {
lr,
beta1,
beta2,
epsilon,
} = &self.config.lr_schedule
{
Some(AdamState::new(n_params, *lr, *beta1, *beta2, *epsilon)?)
} else {
None
};
for iter in 0..self.config.max_iter {
let (elbo, grad) = self.compute_elbo_gradient(&log_joint, iter as u64)?;
self.diagnostics.record_elbo(elbo);
let grad_norm = grad.dot(&grad).sqrt();
self.diagnostics.record_gradient_norm(grad_norm);
let clipped_grad = if self.config.grad_clip > 0.0 && grad_norm > self.config.grad_clip {
&grad * (self.config.grad_clip / grad_norm)
} else {
grad
};
let current_params = self.variational.get_params();
let new_params = if let Some(ref mut adam) = adam_state {
let update = adam.compute_update(&clipped_grad)?;
                &current_params + &update
} else {
let lr = self.config.lr_schedule.get_lr(iter);
                &current_params + &(&clipped_grad * lr)
};
            let param_change = (&new_params - &current_params).mapv(|x| x * x).sum().sqrt();
self.diagnostics.record_param_change(param_change);
self.variational.set_params(&new_params)?;
if iter > self.config.convergence_window {
if let Some(rel_change) = self
.diagnostics
.relative_elbo_change(self.config.convergence_window)
{
if rel_change < self.config.tol {
self.diagnostics.converged = true;
break;
}
}
}
}
let constrained_means = self.transform_to_constrained(&self.variational.means)?;
Ok(AdviResult {
variational: self.variational.clone(),
constraints: self.constraints.clone(),
constrained_means,
diagnostics: self.diagnostics.clone(),
dim: self.dim,
})
}
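    /// Reparameterization estimator of the ELBO and its gradient:
    /// ELBO ≈ (1/S) Σ_s [log p(θ(η_s)) + log|J(η_s)|] + H[q], with η_s = μ + σ ⊙ ε_s.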
fn compute_elbo_gradient<F>(&self, log_joint: &F, seed: u64) -> Result<(f64, Array1<f64>)>
where
F: Fn(&Array1<f64>) -> Result<(f64, Array1<f64>)>,
{
let dim = self.dim;
let n_samples = self.config.n_mc_samples.max(1);
let n_params = 2 * dim;
let mut total_elbo = 0.0;
let mut total_grad = Array1::zeros(n_params);
let stds = self.variational.stds();
for s in 0..n_samples {
let epsilon = generate_standard_normal_advi(dim, seed * 1000 + s as u64);
let eta = self.variational.sample(&epsilon)?;
let theta = self.transform_to_constrained(&eta)?;
let (log_p, grad_theta) = log_joint(&theta)?;
let mut log_det_j = 0.0;
for i in 0..dim {
log_det_j += self.constraints[i].log_det_jacobian(eta[i]);
}
total_elbo += log_p + log_det_j;
let grad_eta = self.compute_grad_eta(&eta, &grad_theta)?;
let grad_log_det_j = self.compute_grad_log_det_j(&eta)?;
let grad_combined = &grad_eta + &grad_log_det_j;
            // Chain rule for the reparameterization η = μ + σ ⊙ ε:
            // ∂η_i/∂μ_i = 1 and ∂η_i/∂ log σ_i = σ_i ε_i.
            for i in 0..dim {
                total_grad[i] += grad_combined[i];
                total_grad[dim + i] += grad_combined[i] * epsilon[i] * stds[i];
            }
}
total_elbo /= n_samples as f64;
total_grad /= n_samples as f64;
        // The entropy of the mean-field Gaussian enters in closed form; its
        // gradient with respect to each log-std is exactly 1.
        let entropy = self.variational.entropy();
        total_elbo += entropy;
        for i in 0..dim {
            total_grad[dim + i] += 1.0;
        }
Ok((total_elbo, total_grad))
}
fn transform_to_constrained(&self, eta: &Array1<f64>) -> Result<Array1<f64>> {
let mut theta = Array1::zeros(self.dim);
for i in 0..self.dim {
theta[i] = self.constraints[i].forward(eta[i]);
}
Ok(theta)
}
fn compute_grad_eta(&self, eta: &Array1<f64>, grad_theta: &Array1<f64>) -> Result<Array1<f64>> {
let mut grad_eta = Array1::zeros(self.dim);
for i in 0..self.dim {
let dtheta_deta = self.compute_transform_derivative(i, eta[i]);
grad_eta[i] = grad_theta[i] * dtheta_deta;
}
Ok(grad_eta)
}
fn compute_transform_derivative(&self, i: usize, unconstrained: f64) -> f64 {
match &self.constraints[i] {
ParameterConstraint::Real => 1.0,
ParameterConstraint::Positive => unconstrained.exp(),
ParameterConstraint::UnitInterval => {
let s = 1.0 / (1.0 + (-unconstrained).exp());
s * (1.0 - s)
}
ParameterConstraint::Bounded { lower, upper } => {
let s = 1.0 / (1.0 + (-unconstrained).exp());
(upper - lower) * s * (1.0 - s)
}
ParameterConstraint::LowerBounded { .. } => unconstrained.exp(),
ParameterConstraint::UpperBounded { .. } => (-unconstrained).exp(),
ParameterConstraint::Simplex { .. } => {
let s = 1.0 / (1.0 + (-unconstrained).exp());
s * (1.0 - s)
}
}
}
fn compute_grad_log_det_j(&self, eta: &Array1<f64>) -> Result<Array1<f64>> {
let mut grad = Array1::zeros(self.dim);
for i in 0..self.dim {
grad[i] = self.compute_grad_log_det_j_single(i, eta[i]);
}
Ok(grad)
}
fn compute_grad_log_det_j_single(&self, i: usize, unconstrained: f64) -> f64 {
match &self.constraints[i] {
ParameterConstraint::Real => 0.0,
ParameterConstraint::Positive => 1.0,
ParameterConstraint::UnitInterval => {
let s = 1.0 / (1.0 + (-unconstrained).exp());
1.0 - 2.0 * s
}
ParameterConstraint::Bounded { .. } => {
let s = 1.0 / (1.0 + (-unconstrained).exp());
1.0 - 2.0 * s
}
            ParameterConstraint::LowerBounded { .. } => 1.0,
            // log|J| = -η for the upper-bounded transform, so its η-derivative is -1.
            ParameterConstraint::UpperBounded { .. } => -1.0,
ParameterConstraint::Simplex { .. } => {
let s = 1.0 / (1.0 + (-unconstrained).exp());
1.0 - 2.0 * s
}
}
}
}
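/// Full-rank Gaussian ADVI: like `AdviMeanField`, but the variational posterior
/// carries a dense lower-triangular Cholesky factor and can therefore capture
/// posterior correlations at O(d²) parameter cost.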
#[derive(Debug, Clone)]
pub struct AdviFullRank {
pub variational: FullRankGaussian,
pub constraints: Vec<ParameterConstraint>,
pub config: AdviConfig,
pub diagnostics: VariationalDiagnostics,
pub dim: usize,
}
impl AdviFullRank {
pub fn new(constraints: Vec<ParameterConstraint>, config: AdviConfig) -> Result<Self> {
let dim = constraints.len();
if dim == 0 {
return Err(StatsError::InvalidArgument(
"Must have at least one parameter".to_string(),
));
}
let variational = FullRankGaussian::new(dim)?;
Ok(Self {
variational,
constraints,
config,
diagnostics: VariationalDiagnostics::new(),
dim,
})
}
pub fn new_unconstrained(dim: usize, config: AdviConfig) -> Result<Self> {
let constraints = vec![ParameterConstraint::Real; dim];
Self::new(constraints, config)
}
pub fn initialize_from_constrained(&mut self, theta: &Array1<f64>) -> Result<()> {
if theta.len() != self.dim {
return Err(StatsError::DimensionMismatch(format!(
"theta length ({}) must match dimension ({})",
theta.len(),
self.dim
)));
}
let mut eta = Array1::zeros(self.dim);
for i in 0..self.dim {
eta[i] = self.constraints[i].inverse(theta[i])?;
}
self.variational.mean = eta;
self.variational.chol_factor = Array2::eye(self.dim) * 0.1;
Ok(())
}
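    /// Maximize the ELBO for the full-rank approximation; see
    /// `AdviMeanField::fit` for the expected `log_joint` contract.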
pub fn fit<F>(&mut self, log_joint: F) -> Result<AdviFullRankResult>
where
F: Fn(&Array1<f64>) -> Result<(f64, Array1<f64>)>,
{
let n_params = self.variational.n_params();
let mut adam_state = if let LearningRateSchedule::Adam {
lr,
beta1,
beta2,
epsilon,
} = &self.config.lr_schedule
{
Some(AdamState::new(n_params, *lr, *beta1, *beta2, *epsilon)?)
} else {
None
};
for iter in 0..self.config.max_iter {
let (elbo, grad) = self.compute_elbo_gradient_full_rank(&log_joint, iter as u64)?;
self.diagnostics.record_elbo(elbo);
let grad_norm = grad.dot(&grad).sqrt();
self.diagnostics.record_gradient_norm(grad_norm);
let clipped_grad = if self.config.grad_clip > 0.0 && grad_norm > self.config.grad_clip {
&grad * (self.config.grad_clip / grad_norm)
} else {
grad
};
let current_params = self.variational.get_params();
let new_params = if let Some(ref mut adam) = adam_state {
let update = adam.compute_update(&clipped_grad)?;
                &current_params + &update
} else {
let lr = self.config.lr_schedule.get_lr(iter);
                &current_params + &(&clipped_grad * lr)
};
            let param_change = (&new_params - &current_params).mapv(|x| x * x).sum().sqrt();
self.diagnostics.record_param_change(param_change);
self.variational.set_params(&new_params)?;
if iter > self.config.convergence_window {
if let Some(rel_change) = self
.diagnostics
.relative_elbo_change(self.config.convergence_window)
{
if rel_change < self.config.tol {
self.diagnostics.converged = true;
break;
}
}
}
}
let constrained_means = self.transform_to_constrained(&self.variational.mean)?;
Ok(AdviFullRankResult {
variational: self.variational.clone(),
constraints: self.constraints.clone(),
constrained_means,
diagnostics: self.diagnostics.clone(),
dim: self.dim,
})
}
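    /// Reparameterization ELBO gradient for the full-rank case, η = μ + Lε, with
    /// the flattened parameter vector laid out as [μ, then the lower triangle of
    /// L in row-major order].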
fn compute_elbo_gradient_full_rank<F>(
&self,
log_joint: &F,
seed: u64,
) -> Result<(f64, Array1<f64>)>
where
F: Fn(&Array1<f64>) -> Result<(f64, Array1<f64>)>,
{
let dim = self.dim;
let n_samples = self.config.n_mc_samples.max(1);
let n_params = self.variational.n_params();
let mut total_elbo = 0.0;
let mut total_grad = Array1::zeros(n_params);
        // The dim * (dim + 1) / 2 lower-triangular entries of L follow the mean
        // block in the flattened parameter vector.
for s in 0..n_samples {
let epsilon = generate_standard_normal_advi(dim, seed * 1000 + s as u64);
let eta = self.variational.sample(&epsilon)?;
let theta = self.transform_to_constrained(&eta)?;
let (log_p, grad_theta) = log_joint(&theta)?;
let mut log_det_j = 0.0;
for i in 0..dim {
log_det_j += compute_log_det_jacobian(&self.constraints[i], eta[i]);
}
total_elbo += log_p + log_det_j;
let grad_eta = compute_grad_eta_from_theta(dim, &eta, &grad_theta, &self.constraints)?;
let grad_log_det = compute_grad_log_det(dim, &eta, &self.constraints)?;
let grad_combined: Array1<f64> = &grad_eta + &grad_log_det;
for i in 0..dim {
total_grad[i] += grad_combined[i];
}
            // Chain rule for η = μ + Lε: ∂η_i/∂L_ij = ε_j for j ≤ i, walking the
            // lower triangle of L in row-major order.
            let mut l_idx = dim;
            for i in 0..dim {
                for j in 0..=i {
                    total_grad[l_idx] += grad_combined[i] * epsilon[j];
                    l_idx += 1;
                }
            }
}
total_elbo /= n_samples as f64;
total_grad /= n_samples as f64;
        // The entropy of the full-rank Gaussian enters in closed form; only the
        // diagonal Cholesky entries contribute to its gradient, with ∂H/∂L_ii = 1/L_ii.
        let entropy = self.variational.entropy();
        total_elbo += entropy;
        let mut l_idx = dim;
for i in 0..dim {
for j in 0..=i {
if i == j {
let l_ii = self.variational.chol_factor[[i, i]];
if l_ii.abs() > 1e-15 {
total_grad[l_idx] += 1.0 / l_ii;
}
}
l_idx += 1;
}
}
Ok((total_elbo, total_grad))
}
fn transform_to_constrained(&self, eta: &Array1<f64>) -> Result<Array1<f64>> {
let mut theta = Array1::zeros(self.dim);
for i in 0..self.dim {
theta[i] = self.constraints[i].forward(eta[i]);
}
Ok(theta)
}
}
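/// Fitted mean-field ADVI posterior plus the constraint transforms needed to map
/// samples and summaries back to the constrained space.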
#[derive(Debug, Clone)]
pub struct AdviResult {
pub variational: MeanFieldGaussian,
pub constraints: Vec<ParameterConstraint>,
pub constrained_means: Array1<f64>,
pub diagnostics: VariationalDiagnostics,
pub dim: usize,
}
impl AdviResult {
pub fn unconstrained_means(&self) -> &Array1<f64> {
&self.variational.means
}
pub fn unconstrained_stds(&self) -> Array1<f64> {
self.variational.stds()
}
pub fn constrained_means(&self) -> &Array1<f64> {
&self.constrained_means
}
pub fn sample_constrained(&self, epsilon: &Array1<f64>) -> Result<Array1<f64>> {
let eta = self.variational.sample(epsilon)?;
let mut theta = Array1::zeros(self.dim);
for i in 0..self.dim {
theta[i] = self.constraints[i].forward(eta[i]);
}
Ok(theta)
}
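    /// Gaussian credible intervals computed on the unconstrained scale and then
    /// pushed through the forward transforms; endpoints are sorted defensively
    /// afterwards.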
pub fn approximate_credible_intervals(&self, confidence: f64) -> Result<Array2<f64>> {
check_probability(confidence, "confidence")?;
let alpha = (1.0 - confidence) / 2.0;
let z_critical = super::normal_ppf(1.0 - alpha)?;
let stds = self.variational.stds();
let mut intervals = Array2::zeros((self.dim, 2));
for i in 0..self.dim {
let eta_low = self.variational.means[i] - z_critical * stds[i];
let eta_high = self.variational.means[i] + z_critical * stds[i];
let theta_low = self.constraints[i].forward(eta_low);
let theta_high = self.constraints[i].forward(eta_high);
intervals[[i, 0]] = theta_low.min(theta_high);
intervals[[i, 1]] = theta_low.max(theta_high);
}
Ok(intervals)
}
}
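/// Fitted full-rank ADVI posterior with its constraint transforms.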
#[derive(Debug, Clone)]
pub struct AdviFullRankResult {
pub variational: FullRankGaussian,
pub constraints: Vec<ParameterConstraint>,
pub constrained_means: Array1<f64>,
pub diagnostics: VariationalDiagnostics,
pub dim: usize,
}
impl AdviFullRankResult {
pub fn unconstrained_means(&self) -> &Array1<f64> {
&self.variational.mean
}
pub fn unconstrained_covariance(&self) -> Array2<f64> {
self.variational.covariance()
}
pub fn constrained_means(&self) -> &Array1<f64> {
&self.constrained_means
}
pub fn sample_constrained(&self, epsilon: &Array1<f64>) -> Result<Array1<f64>> {
let eta = self.variational.sample(epsilon)?;
let mut theta = Array1::zeros(self.dim);
for i in 0..self.dim {
theta[i] = self.constraints[i].forward(eta[i]);
}
Ok(theta)
}
pub fn approximate_credible_intervals(&self, confidence: f64) -> Result<Array2<f64>> {
check_probability(confidence, "confidence")?;
let alpha = (1.0 - confidence) / 2.0;
let z_critical = super::normal_ppf(1.0 - alpha)?;
let cov = self.variational.covariance();
let mut intervals = Array2::zeros((self.dim, 2));
for i in 0..self.dim {
let std_i = cov[[i, i]].sqrt();
let eta_low = self.variational.mean[i] - z_critical * std_i;
let eta_high = self.variational.mean[i] + z_critical * std_i;
let theta_low = self.constraints[i].forward(eta_low);
let theta_high = self.constraints[i].forward(eta_high);
intervals[[i, 0]] = theta_low.min(theta_high);
intervals[[i, 1]] = theta_low.max(theta_high);
}
Ok(intervals)
}
}
fn compute_log_det_jacobian(constraint: &ParameterConstraint, unconstrained: f64) -> f64 {
constraint.log_det_jacobian(unconstrained)
}
fn compute_grad_eta_from_theta(
dim: usize,
eta: &Array1<f64>,
grad_theta: &Array1<f64>,
constraints: &[ParameterConstraint],
) -> Result<Array1<f64>> {
let mut grad_eta = Array1::zeros(dim);
for i in 0..dim {
let dtheta_deta = compute_transform_deriv(&constraints[i], eta[i]);
grad_eta[i] = grad_theta[i] * dtheta_deta;
}
Ok(grad_eta)
}
fn compute_grad_log_det(
dim: usize,
eta: &Array1<f64>,
constraints: &[ParameterConstraint],
) -> Result<Array1<f64>> {
let mut grad = Array1::zeros(dim);
for i in 0..dim {
grad[i] = compute_grad_log_det_single(&constraints[i], eta[i]);
}
Ok(grad)
}
fn compute_transform_deriv(constraint: &ParameterConstraint, unconstrained: f64) -> f64 {
match constraint {
ParameterConstraint::Real => 1.0,
ParameterConstraint::Positive => unconstrained.exp(),
ParameterConstraint::UnitInterval => {
let s = 1.0 / (1.0 + (-unconstrained).exp());
s * (1.0 - s)
}
ParameterConstraint::Bounded { lower, upper } => {
let s = 1.0 / (1.0 + (-unconstrained).exp());
(upper - lower) * s * (1.0 - s)
}
ParameterConstraint::LowerBounded { .. } => unconstrained.exp(),
ParameterConstraint::UpperBounded { .. } => (-unconstrained).exp(),
ParameterConstraint::Simplex { .. } => {
let s = 1.0 / (1.0 + (-unconstrained).exp());
s * (1.0 - s)
}
}
}
fn compute_grad_log_det_single(constraint: &ParameterConstraint, unconstrained: f64) -> f64 {
match constraint {
ParameterConstraint::Real => 0.0,
ParameterConstraint::Positive => 1.0,
ParameterConstraint::UnitInterval => {
let s = 1.0 / (1.0 + (-unconstrained).exp());
1.0 - 2.0 * s
}
ParameterConstraint::Bounded { .. } => {
let s = 1.0 / (1.0 + (-unconstrained).exp());
1.0 - 2.0 * s
}
        ParameterConstraint::LowerBounded { .. } => 1.0,
        // log|J| = -η for the upper-bounded transform, so its η-derivative is -1.
        ParameterConstraint::UpperBounded { .. } => -1.0,
ParameterConstraint::Simplex { .. } => {
let s = 1.0 / (1.0 + (-unconstrained).exp());
1.0 - 2.0 * s
}
}
}
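/// Deterministic, seed-indexed standard-normal draws: a golden-ratio
/// low-discrepancy sequence fed through the Box-Muller transform. This trades
/// statistical quality for reproducibility without an RNG dependency and is not
/// a substitute for a proper random number generator.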
fn generate_standard_normal_advi(dim: usize, seed: u64) -> Array1<f64> {
let mut result = Array1::zeros(dim);
let golden_ratio = 1.618033988749895;
for i in 0..dim {
let u1 = ((seed as f64 * golden_ratio + i as f64 * 0.7548776662466927) % 1.0).abs();
let u2 = ((seed as f64 * 0.5698402909980532 + i as f64 * golden_ratio) % 1.0).abs();
let u1_safe = u1.max(1e-10).min(1.0 - 1e-10);
let u2_safe = u2.max(1e-10).min(1.0 - 1e-10);
let r = (-2.0 * u1_safe.ln()).sqrt();
let theta = 2.0 * PI * u2_safe;
result[i] = r * theta.cos();
}
result
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array1;
#[test]
fn test_constraint_real() {
let c = ParameterConstraint::Real;
assert!((c.forward(1.5) - 1.5).abs() < 1e-10);
let inv = c.inverse(1.5).expect("should invert");
assert!((inv - 1.5).abs() < 1e-10);
assert!((c.log_det_jacobian(1.5)).abs() < 1e-10);
}
#[test]
fn test_constraint_positive() {
let c = ParameterConstraint::Positive;
assert!((c.forward(0.0) - 1.0).abs() < 1e-10);
assert!((c.forward(1.0) - 1.0_f64.exp()).abs() < 1e-10);
let inv = c.inverse(1.0_f64.exp()).expect("should invert");
assert!((inv - 1.0).abs() < 1e-10);
assert!(c.inverse(-1.0).is_err());
}
#[test]
fn test_constraint_unit_interval() {
let c = ParameterConstraint::UnitInterval;
assert!((c.forward(0.0) - 0.5).abs() < 1e-10);
let inv = c.inverse(0.5).expect("should invert");
assert!(inv.abs() < 1e-10);
assert!(c.inverse(0.0).is_err());
assert!(c.inverse(1.0).is_err());
}
#[test]
fn test_constraint_bounded() {
let c = ParameterConstraint::Bounded {
lower: -1.0,
upper: 1.0,
};
assert!((c.forward(0.0)).abs() < 1e-10);
let inv = c.inverse(0.0).expect("should invert");
assert!(inv.abs() < 1e-10);
}
#[test]
fn test_constraint_lower_bounded() {
let c = ParameterConstraint::LowerBounded { lower: 2.0 };
assert!((c.forward(0.0) - 3.0).abs() < 1e-10);
let inv = c.inverse(3.0).expect("should invert");
assert!(inv.abs() < 1e-10);
assert!(c.inverse(1.0).is_err());
}
#[test]
fn test_constraint_roundtrip() {
let constraints = vec![
ParameterConstraint::Real,
ParameterConstraint::Positive,
ParameterConstraint::UnitInterval,
ParameterConstraint::Bounded {
lower: 0.0,
upper: 10.0,
},
];
let unconstrained_values = vec![0.5, 1.0, -0.5, 2.0];
for (c, &eta) in constraints.iter().zip(unconstrained_values.iter()) {
let theta = c.forward(eta);
let eta_back = c.inverse(theta).expect("should invert");
assert!(
(eta_back - eta).abs() < 1e-8,
"Roundtrip failed for {:?}: {} -> {} -> {}",
c,
eta,
theta,
eta_back
);
}
}
#[test]
fn test_advi_mean_field_creation() {
let constraints = vec![ParameterConstraint::Real, ParameterConstraint::Positive];
let config = AdviConfig::default();
let advi = AdviMeanField::new(constraints, config).expect("should create");
assert_eq!(advi.dim, 2);
}
#[test]
fn test_advi_mean_field_simple_gaussian() {
let target_mean = Array1::from_vec(vec![1.0, -2.0]);
let target_precision = 2.0;
let constraints = vec![ParameterConstraint::Real, ParameterConstraint::Real];
let config = AdviConfig {
max_iter: 500,
n_mc_samples: 1,
lr_schedule: LearningRateSchedule::Adam {
lr: 0.05,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
},
tol: 1e-6,
convergence_window: 20,
..AdviConfig::default()
};
let mut advi = AdviMeanField::new(constraints, config).expect("should create");
let tm = target_mean.clone();
let result = advi
.fit(move |theta: &Array1<f64>| {
let diff = theta - &tm;
let log_p = -0.5 * target_precision * diff.dot(&diff);
let grad = &diff * (-target_precision);
Ok((log_p, grad))
})
.expect("should fit");
assert!(
result.diagnostics.n_iterations > 0,
"Should have performed iterations"
);
assert!(
result.diagnostics.final_elbo.is_finite(),
"ELBO should be finite"
);
}
#[test]
fn test_advi_full_rank_creation() {
let constraints = vec![
ParameterConstraint::Real,
ParameterConstraint::Positive,
ParameterConstraint::UnitInterval,
];
let config = AdviConfig::default();
let advi = AdviFullRank::new(constraints, config).expect("should create");
assert_eq!(advi.dim, 3);
}
#[test]
fn test_advi_full_rank_simple() {
let constraints = vec![ParameterConstraint::Real, ParameterConstraint::Real];
let config = AdviConfig {
max_iter: 200,
n_mc_samples: 1,
lr_schedule: LearningRateSchedule::Adam {
lr: 0.02,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
},
tol: 1e-5,
convergence_window: 20,
..AdviConfig::default()
};
let mut advi = AdviFullRank::new(constraints, config).expect("should create");
let result = advi
.fit(|theta: &Array1<f64>| {
let log_p = -0.5 * theta.dot(theta);
let grad = theta * (-1.0);
Ok((log_p, grad))
})
.expect("should fit");
assert!(result.diagnostics.n_iterations > 0);
assert!(result.diagnostics.final_elbo.is_finite());
}
#[test]
fn test_advi_with_constrained_params() {
let constraints = vec![
            ParameterConstraint::Real,
            ParameterConstraint::Positive,
        ];
let config = AdviConfig {
max_iter: 300,
n_mc_samples: 1,
lr_schedule: LearningRateSchedule::Adam {
lr: 0.01,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
},
tol: 1e-5,
convergence_window: 30,
..AdviConfig::default()
};
let mut advi = AdviMeanField::new(constraints, config).expect("should create");
let result = advi
.fit(|theta: &Array1<f64>| {
let log_p = -0.5 * (theta[0] - 1.0).powi(2) - 2.0 * (theta[1] - 2.0).powi(2);
let mut grad = Array1::zeros(2);
grad[0] = -(theta[0] - 1.0);
grad[1] = -4.0 * (theta[1] - 2.0);
Ok((log_p, grad))
})
.expect("should fit");
assert!(
result.constrained_means[1] > 0.0,
"Positive-constrained parameter should be > 0, got {}",
result.constrained_means[1]
);
}
#[test]
fn test_advi_result_credible_intervals() {
let constraints = vec![ParameterConstraint::Real, ParameterConstraint::Positive];
let config = AdviConfig {
max_iter: 100,
..AdviConfig::default()
};
let mut advi = AdviMeanField::new(constraints, config).expect("should create");
let result = advi
.fit(|theta: &Array1<f64>| {
let log_p = -0.5 * theta.dot(theta);
let grad = theta * (-1.0);
Ok((log_p, grad))
})
.expect("should fit");
let intervals = result
.approximate_credible_intervals(0.95)
.expect("should compute intervals");
assert_eq!(intervals.nrows(), 2);
assert_eq!(intervals.ncols(), 2);
for i in 0..2 {
assert!(
intervals[[i, 0]] <= intervals[[i, 1]],
"Lower bound should be <= upper bound at dim {}",
i
);
}
}
#[test]
fn test_log_det_jacobian_positive() {
let c = ParameterConstraint::Positive;
assert!((c.log_det_jacobian(0.0)).abs() < 1e-10);
assert!((c.log_det_jacobian(1.0) - 1.0).abs() < 1e-10);
assert!((c.log_det_jacobian(-1.0) - (-1.0)).abs() < 1e-10);
}
#[test]
fn test_log_det_jacobian_unit_interval() {
let c = ParameterConstraint::UnitInterval;
let expected = (0.25_f64).ln();
assert!(
(c.log_det_jacobian(0.0) - expected).abs() < 1e-10,
"log det J at 0 should be {}, got {}",
expected,
c.log_det_jacobian(0.0)
);
}
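    // Verifies the upper-bounded transform: with θ = upper - exp(-η) the
    // Jacobian is exp(-η), so log|J| = -η for any bound, and the inverse
    // rejects values at or above the bound.
    #[test]
    fn test_constraint_upper_bounded() {
        let c = ParameterConstraint::UpperBounded { upper: 5.0 };
        assert!((c.log_det_jacobian(0.0)).abs() < 1e-10);
        assert!((c.log_det_jacobian(1.0) - (-1.0)).abs() < 1e-10);
        assert!((c.log_det_jacobian(-2.0) - 2.0).abs() < 1e-10);
        let inv = c.inverse(4.0).expect("should invert");
        assert!((c.forward(inv) - 4.0).abs() < 1e-10);
        assert!(c.inverse(6.0).is_err());
    }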
}