1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
//! Bayesian Neural Network posterior approximations.
//!
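//! This module provides Bayesian approximations to the posterior over
//! neural-network weights, enabling uncertainty quantification without
//! modifying the base training objective or optimiser. The Laplace
//! approximation is applied **post hoc** at the trained (MAP) weights. SWAG
//! only records weight snapshots along an ordinary SGD run.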
//!
//! Two families of methods are provided:
//!
//! ## Laplace Approximation
//!
//! Fits a Gaussian posterior `p(θ|D) ≈ N(θ*, H⁻¹)` centred at the MAP
//! estimate θ* using the curvature (Hessian) of the loss surface.
//!
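//! In practice the Hessian diagonal is replaced by the diagonal of the
//! empirical Fisher: the sum of squared per-sample loss gradients, a cheap
//! diagonal GGN-style curvature estimate. A prior precision λ is added
//! before inverting:
//! ```text
//! H_ii ≈ Σₙ (∂lossₙ/∂θᵢ)²      (squared per-sample gradients)
//! σ²ᵢ  = 1 / (H_ii + λ)        (marginal posterior variance of θᵢ)
//! ```
//!
//! Here λ acts as the precision of an isotropic Gaussian prior on θ. Any
//! λ > 0 also keeps σ²ᵢ finite for parameters whose gradients are all zero.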
//!
//! See [`laplace::fit_laplace`] for the high-level API.
//!
//! ## SWAG (Stochastic Weight Averaging Gaussian)
//!
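//! Collects weight snapshots along an SGD trajectory and fits a Gaussian with
//! diagonal plus low-rank covariance (Maddox et al. 2019, "A Simple Baseline
//! for Bayesian Uncertainty in Deep Learning"):
//! ```text
//! Σ ≈ diag(σ²_diag) / 2 + D̂ D̂ᵀ / (2(C−1))
//! ```
//! The symbols in this formula are defined as follows:
//!
//! - σ²_diag is the per-parameter variance of the snapshots: the mean of θ²
//!   minus the square of the mean of θ.
//! - D̂ is the deviation matrix. In the original method, its columns are the
//!   most recent snapshots minus the running mean.
//! - C is the number of stored deviation columns, which is the rank of the
//!   low-rank term.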
//!
//! See [`swag::SwagCollector`] and [`swag::sample_weights`] for the API.
//!
//! # Example
//!
//! ```rust
//! use scirs2_stats::bayesian_approx::{
//!     laplace::fit_laplace,
//!     swag::{SwagCollector, sample_weights},
//!     types::{LaplaceConfig, SwagConfig},
//! };
//!
//! // --- Laplace ---
//! let map_weights = vec![1.0f64, 0.5];
//! let loss_fn = |w: &[f64]| -> Vec<f64> {
//!     vec![(w[0] - 1.0).powi(2), (w[1] - 0.5).powi(2)]
//! };
//! let config = LaplaceConfig::default();
//! let lap = fit_laplace(&map_weights, &loss_fn, &config).expect("laplace");
//! println!("Laplace uncertainty: {:?}", lap.uncertainty);
//!
//! // --- SWAG ---
//! let mut collector = SwagCollector::new(2, 5);
//! for t in 0..20usize {
//!     collector.update(&[1.0 + t as f64 * 0.01, 0.5 - t as f64 * 0.005]);
//! }
//! let state = collector.finalize().expect("finalize");
//! let samples = sample_weights(&state, 10, 42).expect("sample");
//! println!("SWAG samples: {}", samples.len());
//! ```
pub use ;
pub use ;
pub use ;