scirs2_stats/variational/mod.rs
1//! Variational Inference Methods
2//!
3//! This module provides modern variational inference algorithms for approximate
4//! Bayesian posterior computation:
5//!
6//! - **ADVI**: Automatic Differentiation Variational Inference (Kucukelbir et al. 2017)
7//! with mean-field and full-rank Gaussian approximations, automatic parameter
8//! transformations, ELBO optimization via reparameterization trick + Adam optimizer.
9//!
10//! - **SVGD**: Stein Variational Gradient Descent (Liu & Wang 2016) — a particle-based
11//! method that transports a set of particles to approximate the posterior using
12//! kernelized Stein discrepancy with RBF kernel and median bandwidth heuristic.
13//!
14//! - **Normalizing Flows**: Flexible posterior approximations via invertible
15//! transformations (planar and radial flows) with log-determinant Jacobian tracking.
16
17mod advi;
18pub mod bbvi;
19mod normalizing_flow;
20mod svgd;
21
22pub use advi::*;
23pub use normalizing_flow::*;
24pub use svgd::*;
25
26use crate::error::StatsResult;
27use scirs2_core::ndarray::Array1;
28
29// ============================================================================
// Common Result Type and Trait
31// ============================================================================
32
33/// Result of variational inference
34#[derive(Debug, Clone)]
35pub struct PosteriorResult {
36 /// Posterior means (in constrained space)
37 pub means: Array1<f64>,
38 /// Posterior standard deviations (in constrained space)
39 pub std_devs: Array1<f64>,
40 /// ELBO history over iterations
41 pub elbo_history: Vec<f64>,
42 /// Number of iterations performed
43 pub iterations: usize,
44 /// Whether the algorithm converged
45 pub converged: bool,
46 /// Optional: posterior samples (for particle-based methods like SVGD)
47 pub samples: Option<Vec<Array1<f64>>>,
48}
49
50/// Common trait for variational inference methods
51pub trait VariationalInference {
52 /// Fit the variational approximation to a target log-joint distribution.
53 ///
54 /// # Arguments
55 /// * `log_joint` - Function computing `(log p(x, theta), grad_theta log p(x, theta))`
56 /// * `dim` - Dimensionality of the parameter space
57 ///
58 /// # Returns
59 /// A `PosteriorResult` with posterior statistics and convergence info
60 fn fit<F>(&mut self, log_joint: F, dim: usize) -> StatsResult<PosteriorResult>
61 where
62 F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>;
63}