Skip to main content

scirs2_stats/variational/
svgd.rs

//! Stein Variational Gradient Descent (SVGD)
//!
//! Implements SVGD (Liu & Wang 2016), a particle-based variational inference method
//! that maintains a set of particles and iteratively transports them to approximate
//! the target posterior distribution.
//!
//! The update rule uses a kernelized Stein operator:
//! ```text
//! theta_i <- theta_i + epsilon * phi*(theta_i)
//! phi*(theta) = (1/n) sum_j [k(theta_j, theta) * grad_theta_j log p(theta_j | x)
//!              + grad_theta_j k(theta_j, theta)]
//! ```
//! - The first term drives particles toward high-probability regions.
//! - The second term acts as a repulsive force that maintains diversity.
//!
//! Uses an RBF kernel whose bandwidth comes from the median heuristic (unless a fixed
//! bandwidth is configured) and, optionally, the Adam optimizer for adaptive step sizes.

19use crate::error::{StatsError, StatsResult};
20use scirs2_core::ndarray::Array1;
21use std::f64::consts::PI;
22
23use super::{PosteriorResult, VariationalInference};
24
25// ============================================================================
26// RBF Kernel
27// ============================================================================
28
29/// Radial Basis Function (RBF) kernel: k(x, y) = exp(-||x - y||^2 / (2 h^2))
30#[derive(Debug, Clone)]
31pub struct RbfKernel {
32    /// Bandwidth parameter h; if None, use median heuristic
33    pub bandwidth: Option<f64>,
34}
35
36impl RbfKernel {
37    /// Compute the median heuristic bandwidth from pairwise distances.
38    /// h^2 = median(||x_i - x_j||^2) / log(n) for n particles.
39    fn median_bandwidth(particles: &[Array1<f64>]) -> f64 {
40        let n = particles.len();
41        if n <= 1 {
42            return 1.0;
43        }
44
45        let mut dists_sq: Vec<f64> = Vec::with_capacity(n * (n - 1) / 2);
46        for i in 0..n {
47            for j in (i + 1)..n {
48                let diff = &particles[i] - &particles[j];
49                dists_sq.push(diff.dot(&diff));
50            }
51        }
52
53        if dists_sq.is_empty() {
54            return 1.0;
55        }
56
57        dists_sq.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
58        let median_sq = dists_sq[dists_sq.len() / 2];
59
60        // h^2 = median / log(n), with floor to avoid degenerate bandwidth
61        let log_n = (n as f64).ln().max(1.0);
62        let h_sq = median_sq / log_n;
63        h_sq.max(1e-6).sqrt()
64    }
65
66    /// Evaluate kernel k(x, y) and its gradient w.r.t. x.
67    ///
68    /// Returns (k_val, grad_x_k) where:
69    /// - k_val = exp(-||x - y||^2 / (2 h^2))
70    /// - grad_x_k = -k_val * (x - y) / h^2
71    fn eval_with_grad(&self, x: &Array1<f64>, y: &Array1<f64>, h: f64) -> (f64, Array1<f64>) {
72        let diff = x - y;
73        let dist_sq = diff.dot(&diff);
74        let h_sq = h * h;
75        let k_val = (-dist_sq / (2.0 * h_sq)).exp();
76        let grad_x = &diff * (-k_val / h_sq);
77        (k_val, grad_x)
78    }
79}
80
81// ============================================================================
82// Adam Optimizer for SVGD (per-particle)
83// ============================================================================
84
85#[derive(Debug, Clone)]
86struct SvgdAdamState {
87    m: Vec<Array1<f64>>,
88    v: Vec<Array1<f64>>,
89    t: usize,
90    beta1: f64,
91    beta2: f64,
92    epsilon: f64,
93}
94
95impl SvgdAdamState {
96    fn new(n_particles: usize, dim: usize) -> Self {
97        Self {
98            m: vec![Array1::zeros(dim); n_particles],
99            v: vec![Array1::zeros(dim); n_particles],
100            t: 0,
101            beta1: 0.9,
102            beta2: 0.999,
103            epsilon: 1e-8,
104        }
105    }
106
107    /// Compute Adam update for each particle's gradient
108    fn update(&mut self, grads: &[Array1<f64>]) -> Vec<Array1<f64>> {
109        self.t += 1;
110        let n = grads.len();
111        let mut directions = Vec::with_capacity(n);
112
113        for i in 0..n {
114            let dim = grads[i].len();
115            let mut dir = Array1::zeros(dim);
116            for j in 0..dim {
117                self.m[i][j] = self.beta1 * self.m[i][j] + (1.0 - self.beta1) * grads[i][j];
118                self.v[i][j] =
119                    self.beta2 * self.v[i][j] + (1.0 - self.beta2) * grads[i][j] * grads[i][j];
120                let m_hat = self.m[i][j] / (1.0 - self.beta1.powi(self.t as i32));
121                let v_hat = self.v[i][j] / (1.0 - self.beta2.powi(self.t as i32));
122                dir[j] = m_hat / (v_hat.sqrt() + self.epsilon);
123            }
124            directions.push(dir);
125        }
126
127        directions
128    }
129}
130
// ============================================================================
// SVGD Configuration
// ============================================================================

/// Tuning knobs for [`Svgd`].
#[derive(Debug, Clone)]
pub struct SvgdConfig {
    /// How many particles represent the approximate posterior (at least 2).
    pub num_particles: usize,
    /// Learning rate applied to each update direction (must be positive).
    pub step_size: f64,
    /// Upper bound on the number of SVGD updates.
    pub max_iterations: usize,
    /// Fitting stops once the mean per-particle update norm drops below this value.
    pub tolerance: f64,
    /// Fixed RBF bandwidth; `None` selects the median heuristic on every iteration.
    pub kernel_bandwidth: Option<f64>,
    /// Offset for the deterministic quasi-random particle initialisation.
    pub seed: u64,
    /// Scale applied to the initial, roughly standard-normal particle positions.
    pub init_spread: f64,
    /// When true, update directions are rescaled per coordinate by Adam.
    pub use_adam: bool,
}

impl Default for SvgdConfig {
    /// Defaults: 100 particles, step size 0.1, at most 1000 iterations, tolerance 1e-4,
    /// median-heuristic bandwidth, seed 42, unit initial spread, Adam enabled.
    fn default() -> Self {
        SvgdConfig {
            use_adam: true,
            init_spread: 1.0,
            seed: 42,
            kernel_bandwidth: None,
            tolerance: 1e-4,
            max_iterations: 1000,
            step_size: 0.1,
            num_particles: 100,
        }
    }
}

171// ============================================================================
172// SVGD Struct
173// ============================================================================
174
175/// Stein Variational Gradient Descent
176///
177/// A particle-based method that maintains a set of particles {theta_i} and
178/// iteratively transports them to approximate the target posterior.
179///
180/// # Example
181/// ```no_run
182/// use scirs2_stats::variational::{Svgd, SvgdConfig};
183/// use scirs2_core::ndarray::Array1;
184///
185/// let config = SvgdConfig {
186///     num_particles: 50,
187///     step_size: 0.1,
188///     max_iterations: 500,
189///     ..Default::default()
190/// };
191///
192/// let mut svgd = Svgd::new(config);
193/// ```
194#[derive(Debug, Clone)]
195pub struct Svgd {
196    /// Configuration
197    pub config: SvgdConfig,
198    /// Kernel
199    kernel: RbfKernel,
200}
201
202impl Svgd {
203    /// Create a new SVGD instance
204    pub fn new(config: SvgdConfig) -> Self {
205        let kernel = RbfKernel {
206            bandwidth: config.kernel_bandwidth,
207        };
208        Self { config, kernel }
209    }
210
211    /// Initialize particles using quasi-random sequences
212    fn init_particles(&self, dim: usize) -> Vec<Array1<f64>> {
213        let n = self.config.num_particles;
214        let golden = 1.618033988749895_f64;
215        let plastic = 1.324717957244746_f64;
216
217        (0..n)
218            .map(|i| {
219                Array1::from_shape_fn(dim, |d| {
220                    let seed = self.config.seed.wrapping_add(i as u64 * 1000 + d as u64);
221                    let u1 = ((seed as f64 * golden + d as f64 * plastic) % 1.0).abs();
222                    let u2 = ((seed as f64 * plastic + d as f64 * golden + 0.5) % 1.0).abs();
223                    let u1 = u1.max(1e-10).min(1.0 - 1e-10);
224                    let u2 = u2.max(1e-10).min(1.0 - 1e-10);
225                    let r = (-2.0 * u1.ln()).sqrt();
226                    r * (2.0 * PI * u2).cos() * self.config.init_spread
227                })
228            })
229            .collect()
230    }
231
232    /// Compute the SVGD update direction for all particles.
233    ///
234    /// phi*(theta_i) = (1/n) sum_j [k(theta_j, theta_i) * grad log p(theta_j)
235    ///                              + grad_theta_j k(theta_j, theta_i)]
236    fn compute_phi_star<F>(
237        &self,
238        particles: &[Array1<f64>],
239        log_joint: &F,
240        bandwidth: f64,
241    ) -> StatsResult<Vec<Array1<f64>>>
242    where
243        F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>,
244    {
245        let n = particles.len();
246        let dim = particles[0].len();
247
248        // Compute gradients for all particles
249        let mut grad_log_p: Vec<Array1<f64>> = Vec::with_capacity(n);
250        for particle in particles {
251            let (_log_p, grad) = log_joint(particle)?;
252            grad_log_p.push(grad);
253        }
254
255        // Compute phi* for each particle
256        let mut phi_star: Vec<Array1<f64>> = vec![Array1::zeros(dim); n];
257
258        for i in 0..n {
259            for j in 0..n {
260                let (k_val, grad_k_j) =
261                    self.kernel
262                        .eval_with_grad(&particles[j], &particles[i], bandwidth);
263
264                // Attractive term: k(theta_j, theta_i) * grad log p(theta_j)
265                for d in 0..dim {
266                    phi_star[i][d] += k_val * grad_log_p[j][d];
267                }
268
269                // Repulsive term: grad_theta_j k(theta_j, theta_i)
270                // Note: grad_k_j = d k(theta_j, theta_i) / d theta_j
271                for d in 0..dim {
272                    phi_star[i][d] += grad_k_j[d];
273                }
274            }
275
276            // Average over particles
277            phi_star[i] /= n as f64;
278        }
279
280        Ok(phi_star)
281    }
282
283    /// Compute a proxy ELBO estimate for monitoring convergence.
284    /// Uses the kernel density estimate of the entropy plus the average log joint.
285    fn estimate_elbo<F>(
286        &self,
287        particles: &[Array1<f64>],
288        log_joint: &F,
289        bandwidth: f64,
290    ) -> StatsResult<f64>
291    where
292        F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>,
293    {
294        let n = particles.len();
295        let dim = particles[0].len();
296
297        // Average log p(theta_i)
298        let mut avg_log_p = 0.0;
299        for particle in particles {
300            let (log_p, _) = log_joint(particle)?;
301            avg_log_p += log_p;
302        }
303        avg_log_p /= n as f64;
304
305        // Kernel density entropy estimate:
306        // H approx -1/n sum_i log(1/n sum_j k(theta_i, theta_j))
307        let mut entropy_est = 0.0;
308        for i in 0..n {
309            let mut kde_sum = 0.0;
310            for j in 0..n {
311                let diff = &particles[i] - &particles[j];
312                let dist_sq = diff.dot(&diff);
313                kde_sum += (-dist_sq / (2.0 * bandwidth * bandwidth)).exp();
314            }
315            let norm_const = (2.0 * PI * bandwidth * bandwidth).powf(dim as f64 / 2.0);
316            let density = kde_sum / (n as f64 * norm_const);
317            if density > 1e-300 {
318                entropy_est -= density.ln();
319            }
320        }
321        entropy_est /= n as f64;
322
323        Ok(avg_log_p + entropy_est)
324    }
325}
326
327impl VariationalInference for Svgd {
328    fn fit<F>(&mut self, log_joint: F, dim: usize) -> StatsResult<PosteriorResult>
329    where
330        F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>,
331    {
332        if dim == 0 {
333            return Err(StatsError::InvalidArgument(
334                "Dimension must be at least 1".to_string(),
335            ));
336        }
337        if self.config.num_particles < 2 {
338            return Err(StatsError::InvalidArgument(
339                "num_particles must be at least 2".to_string(),
340            ));
341        }
342        if self.config.step_size <= 0.0 {
343            return Err(StatsError::InvalidArgument(
344                "step_size must be positive".to_string(),
345            ));
346        }
347
348        let n = self.config.num_particles;
349        let mut particles = self.init_particles(dim);
350        let mut elbo_history = Vec::with_capacity(self.config.max_iterations);
351        let mut converged = false;
352
353        let mut adam = if self.config.use_adam {
354            Some(SvgdAdamState::new(n, dim))
355        } else {
356            None
357        };
358
359        for _iter in 0..self.config.max_iterations {
360            // Determine bandwidth
361            let bandwidth = self
362                .config
363                .kernel_bandwidth
364                .unwrap_or_else(|| RbfKernel::median_bandwidth(&particles));
365
366            // Compute SVGD update directions
367            let phi_star = self.compute_phi_star(&particles, &log_joint, bandwidth)?;
368
369            // Update particles
370            let updates: Vec<Array1<f64>> = if let Some(ref mut adam_state) = adam {
371                let directions = adam_state.update(&phi_star);
372                directions
373                    .into_iter()
374                    .map(|d| &d * self.config.step_size)
375                    .collect()
376            } else {
377                phi_star
378                    .iter()
379                    .map(|phi| phi * self.config.step_size)
380                    .collect()
381            };
382
383            // Compute average update norm for convergence check
384            let avg_update_norm: f64 =
385                updates.iter().map(|u| u.dot(u).sqrt()).sum::<f64>() / n as f64;
386
387            for i in 0..n {
388                particles[i] = &particles[i] + &updates[i];
389            }
390
391            // Estimate ELBO periodically (every 10 iterations to save computation)
392            if _iter % 10 == 0 || _iter == self.config.max_iterations - 1 {
393                let elbo = self.estimate_elbo(&particles, &log_joint, bandwidth)?;
394                elbo_history.push(elbo);
395            }
396
397            // Check convergence
398            if avg_update_norm < self.config.tolerance {
399                converged = true;
400                break;
401            }
402        }
403
404        // Compute posterior statistics from particles
405        let mut mean = Array1::zeros(dim);
406        for p in &particles {
407            mean = &mean + p;
408        }
409        mean /= n as f64;
410
411        let mut var = Array1::zeros(dim);
412        for p in &particles {
413            let diff = p - &mean;
414            var = &var + &(&diff * &diff);
415        }
416        var /= (n - 1) as f64;
417        let std_devs = var.mapv(f64::sqrt);
418
419        Ok(PosteriorResult {
420            means: mean,
421            std_devs,
422            elbo_history,
423            iterations: self.config.max_iterations,
424            converged,
425            samples: Some(particles),
426        })
427    }
428}
429
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Test: SVGD particles converge to a known unimodal Gaussian posterior
    #[test]
    fn test_svgd_gaussian_convergence() {
        let target_mean = 2.0_f64;
        let target_var = 0.5_f64;

        let config = SvgdConfig {
            num_particles: 50,
            step_size: 0.1,
            max_iterations: 500,
            tolerance: 1e-5,
            seed: 42,
            init_spread: 2.0,
            use_adam: true,
            ..Default::default()
        };

        let mut svgd = Svgd::new(config);
        let result = svgd
            .fit(
                move |theta: &Array1<f64>| {
                    let x = theta[0];
                    let log_p = -0.5 * (x - target_mean).powi(2) / target_var;
                    let grad = Array1::from_vec(vec![-(x - target_mean) / target_var]);
                    Ok((log_p, grad))
                },
                1,
            )
            .expect("SVGD should succeed");

        assert!(
            (result.means[0] - target_mean).abs() < 0.5,
            "Mean should be near {}, got {}",
            target_mean,
            result.means[0]
        );
        assert!(
            result.samples.is_some(),
            "SVGD should return posterior samples"
        );
    }

    /// Test: SVGD on a bimodal target — particles should spread across both modes
    #[ignore = "slow: SVGD convergence test can exceed timeout"]
    #[test]
    fn test_svgd_bimodal() {
        // Bimodal: 0.5 * N(-3, 0.5) + 0.5 * N(3, 0.5)
        let config = SvgdConfig {
            num_particles: 100,
            step_size: 0.05,
            max_iterations: 1000,
            tolerance: 1e-6,
            seed: 123,
            init_spread: 5.0,
            use_adam: true,
            ..Default::default()
        };

        let mut svgd = Svgd::new(config);
        let result = svgd
            .fit(
                |theta: &Array1<f64>| {
                    let x = theta[0];
                    let var = 0.5;
                    // log of mixture: log(0.5 * N(x; -3, 0.5) + 0.5 * N(x; 3, 0.5))
                    let log_comp1 = -0.5 * (x + 3.0).powi(2) / var;
                    let log_comp2 = -0.5 * (x - 3.0).powi(2) / var;
                    let max_log = log_comp1.max(log_comp2);
                    let log_p =
                        max_log + ((log_comp1 - max_log).exp() + (log_comp2 - max_log).exp()).ln();

                    // Gradient of log mixture
                    let w1 = (log_comp1 - max_log).exp();
                    let w2 = (log_comp2 - max_log).exp();
                    let total = w1 + w2;
                    let grad_x = (w1 * (-(x + 3.0) / var) + w2 * (-(x - 3.0) / var)) / total;
                    Ok((log_p, Array1::from_vec(vec![grad_x])))
                },
                1,
            )
            .expect("SVGD should succeed");

        let samples = result.samples.as_ref().expect("should have samples");

        // Check that particles exist in both modes
        let left_count = samples.iter().filter(|p| p[0] < 0.0).count();
        let right_count = samples.iter().filter(|p| p[0] >= 0.0).count();
        assert!(
            left_count > 5 && right_count > 5,
            "Particles should spread across both modes: left={}, right={}",
            left_count,
            right_count
        );
    }

    /// Test: Repulsive kernel prevents particle collapse — particles should
    /// not all collapse to the same point even for a peaked target
    #[test]
    fn test_svgd_repulsive_prevents_collapse() {
        let config = SvgdConfig {
            num_particles: 30,
            step_size: 0.05,
            max_iterations: 200,
            tolerance: 1e-8,
            seed: 77,
            init_spread: 2.0,
            use_adam: true,
            ..Default::default()
        };

        let mut svgd = Svgd::new(config);
        let result = svgd
            .fit(
                |theta: &Array1<f64>| {
                    // Very peaked Gaussian: N(0, 0.01)
                    let x = theta[0];
                    let var = 0.01;
                    let log_p = -0.5 * x * x / var;
                    let grad = Array1::from_vec(vec![-x / var]);
                    Ok((log_p, grad))
                },
                1,
            )
            .expect("SVGD should succeed");

        let samples = result.samples.as_ref().expect("should have samples");

        // Compute variance of particle positions
        let mean = result.means[0];
        let var: f64 =
            samples.iter().map(|p| (p[0] - mean).powi(2)).sum::<f64>() / samples.len() as f64;

        // Particles should NOT all collapse to exactly the same point
        assert!(
            var > 1e-10,
            "Particle variance {} should be nonzero (repulsion prevents collapse)",
            var
        );
    }

    /// Test: SVGD 2D Gaussian — mean and std should be reasonable
    #[ignore = "slow: SVGD may exceed timeout on slow machines"]
    #[test]
    fn test_svgd_2d_gaussian() {
        let config = SvgdConfig {
            num_particles: 80,
            step_size: 0.1,
            max_iterations: 500,
            tolerance: 1e-5,
            seed: 55,
            init_spread: 3.0,
            use_adam: true,
            ..Default::default()
        };

        let mut svgd = Svgd::new(config);
        let result = svgd
            .fit(
                |theta: &Array1<f64>| {
                    // N([1, -1], I)
                    let d0 = theta[0] - 1.0;
                    let d1 = theta[1] + 1.0;
                    let log_p = -0.5 * (d0 * d0 + d1 * d1);
                    let grad = Array1::from_vec(vec![-d0, -d1]);
                    Ok((log_p, grad))
                },
                2,
            )
            .expect("SVGD should succeed");

        assert!(
            (result.means[0] - 1.0).abs() < 1.0,
            "Mean[0] should be near 1.0, got {}",
            result.means[0]
        );
        assert!(
            (result.means[1] - (-1.0)).abs() < 1.0,
            "Mean[1] should be near -1.0, got {}",
            result.means[1]
        );
    }

    /// Test: validation errors
    #[test]
    fn test_svgd_validation() {
        let mut svgd = Svgd::new(SvgdConfig {
            num_particles: 1, // too few
            ..Default::default()
        });
        let result = svgd.fit(|_: &Array1<f64>| Ok((0.0, Array1::zeros(1))), 1);
        assert!(result.is_err());
    }

    /// Test: a zero-dimensional problem is rejected before any work is done
    #[test]
    fn test_svgd_rejects_zero_dim() {
        let mut svgd = Svgd::new(SvgdConfig::default());
        let result = svgd.fit(|_: &Array1<f64>| Ok((0.0, Array1::zeros(0))), 0);
        assert!(result.is_err());
    }

    /// Test: zero and negative step sizes are rejected
    #[test]
    fn test_svgd_rejects_non_positive_step_size() {
        for &step_size in &[0.0, -0.1] {
            let mut svgd = Svgd::new(SvgdConfig {
                step_size,
                ..Default::default()
            });
            let result = svgd.fit(|_: &Array1<f64>| Ok((0.0, Array1::zeros(1))), 1);
            assert!(result.is_err(), "step_size {} should be rejected", step_size);
        }
    }

    /// Test: a short run with a fixed bandwidth returns well-formed, deterministic output
    #[test]
    fn test_svgd_fixed_bandwidth_shape_and_determinism() {
        let config = SvgdConfig {
            num_particles: 10,
            max_iterations: 20,
            kernel_bandwidth: Some(1.0),
            ..Default::default()
        };
        let target = |theta: &Array1<f64>| Ok((-0.5 * theta.dot(theta), theta.mapv(|t| -t)));

        let first = Svgd::new(config.clone())
            .fit(target, 2)
            .expect("SVGD should succeed");
        let second = Svgd::new(config).fit(target, 2).expect("SVGD should succeed");

        let samples = first.samples.as_ref().expect("should have samples");
        assert_eq!(samples.len(), 10);
        assert!(samples.iter().all(|p| p.len() == 2));
        assert_eq!(first.means.len(), 2);
        assert!(first.std_devs.iter().all(|s| s.is_finite()));
        assert!(first.iterations <= 20);
        assert!(!first.elbo_history.is_empty());
        assert_eq!(first.means, second.means, "initialisation should be deterministic");
    }

    /// Test: the analytic kernel gradient matches a central finite difference
    #[test]
    fn test_rbf_gradient_matches_finite_difference() {
        let kernel = RbfKernel { bandwidth: None };
        let x = Array1::from_vec(vec![0.3, -0.7]);
        let y = Array1::from_vec(vec![1.1, 0.2]);
        let h = 0.8;

        let (k, grad) = kernel.eval_with_grad(&x, &y, h);
        let diff = &x - &y;
        let expected_k = (-diff.dot(&diff) / (2.0 * h * h)).exp();
        assert!((k - expected_k).abs() < 1e-12);

        let eps = 1e-6;
        for d in 0..x.len() {
            let mut xp = x.clone();
            xp[d] += eps;
            let mut xm = x.clone();
            xm[d] -= eps;
            let fd = (kernel.eval_with_grad(&xp, &y, h).0 - kernel.eval_with_grad(&xm, &y, h).0)
                / (2.0 * eps);
            assert!(
                (grad[d] - fd).abs() < 1e-6,
                "coordinate {}: analytic {} vs finite difference {}",
                d,
                grad[d],
                fd
            );
        }

        // The kernel value itself is symmetric in its arguments.
        let (k_yx, _) = kernel.eval_with_grad(&y, &x, h);
        assert!((k - k_yx).abs() < 1e-12);
    }

    /// Test: the first Adam step is (almost exactly) the sign of the gradient
    #[test]
    fn test_adam_first_step_is_gradient_sign() {
        let mut adam = SvgdAdamState::new(1, 2);
        let dirs = adam.update(&[Array1::from_vec(vec![0.5, -2.0])]);
        assert_eq!(dirs.len(), 1);
        assert!((dirs[0][0] - 1.0).abs() < 1e-6, "got {}", dirs[0][0]);
        assert!((dirs[0][1] + 1.0).abs() < 1e-6, "got {}", dirs[0][1]);
    }

    /// Test: median bandwidth heuristic produces reasonable values
    #[test]
    fn test_median_bandwidth() {
        let particles = vec![
            Array1::from_vec(vec![0.0]),
            Array1::from_vec(vec![1.0]),
            Array1::from_vec(vec![2.0]),
            Array1::from_vec(vec![3.0]),
            Array1::from_vec(vec![4.0]),
        ];
        let h = RbfKernel::median_bandwidth(&particles);
        assert!(h > 0.0, "Bandwidth should be positive");
        assert!(h < 10.0, "Bandwidth should be reasonable, got {}", h);

        // Squared distances sorted: [1,1,1,1,4,4,4,9,9,16]; element 5 is 4, so
        // h = sqrt(4 / ln 5).
        let expected = (4.0_f64 / 5.0_f64.ln()).sqrt();
        assert!(
            (h - expected).abs() < 1e-12,
            "expected {}, got {}",
            expected,
            h
        );
    }

    /// Test: fewer than two particles fall back to a unit bandwidth
    #[test]
    fn test_median_bandwidth_degenerate_inputs() {
        let empty: Vec<Array1<f64>> = Vec::new();
        assert_eq!(RbfKernel::median_bandwidth(&empty), 1.0);
        let single = vec![Array1::from_vec(vec![3.5, -1.0])];
        assert_eq!(RbfKernel::median_bandwidth(&single), 1.0);
    }
}
645}