use scirs2_core::ndarray::{Array1, Array2};
use crate::error::{StatsError, StatsResult};
use super::types::{BNNConfig, BNNPosterior, CovarianceType, PredictiveDistribution};
/// Advances a xorshift64 PRNG and returns the next pseudo-random value.
///
/// The new state equals the returned value. Note that 0 is a fixed point
/// (every XOR/shift of 0 yields 0), so callers must seed with a nonzero
/// state.
fn xorshift64(state: &mut u64) -> u64 {
    let seed = *state;
    let a = seed ^ (seed << 13);
    let b = a ^ (a >> 7);
    let next = b ^ (b << 17);
    *state = next;
    next
}
/// Draws one standard-normal sample via the Box-Muller transform, consuming
/// two uniform draws from the xorshift64 stream.
fn randn(state: &mut u64) -> f64 {
    // Two uniforms in [0, 1] (division keeps bit-exact parity with the
    // generator's full u64 range).
    let uniform_a = (xorshift64(state) as f64) / (u64::MAX as f64);
    let uniform_b = (xorshift64(state) as f64) / (u64::MAX as f64);
    // Clamp away from 0 so ln() stays finite.
    let radius = (-2.0 * uniform_a.max(1e-300).ln()).sqrt();
    let angle = 2.0 * std::f64::consts::PI * uniform_b;
    radius * angle.cos()
}
/// Accumulates weight snapshots for SWAG (Stochastic Weight Averaging -
/// Gaussian): a running first and second moment plus a bounded window of
/// recent snapshots used to build a low-rank-plus-diagonal posterior.
#[derive(Debug, Clone)]
pub struct SWAGCollector {
    /// Running mean of all collected weight vectors (the SWA solution).
    mean: Array1<f64>,
    /// Running mean of element-wise squared weights (second moment).
    sq_mean: Array1<f64>,
    /// Most recent raw snapshots, at most `config.swag_rank` of them; they
    /// are centered against the running mean when building/sampling.
    deviations: Vec<Array1<f64>>,
    /// Number of snapshots folded into the running moments.
    n_collected: usize,
    /// Collector configuration (only `swag_rank` is read in this module).
    config: BNNConfig,
    /// Dimensionality of the weight vectors.
    n_params: usize,
}
impl SWAGCollector {
    /// Creates an empty collector for weight vectors of length `n_params`.
    pub fn new(n_params: usize, config: &BNNConfig) -> Self {
        Self {
            mean: Array1::zeros(n_params),
            sq_mean: Array1::zeros(n_params),
            deviations: Vec::new(),
            n_collected: 0,
            config: config.clone(),
            n_params,
        }
    }

    /// Folds one weight snapshot into the running moments and into the
    /// bounded FIFO window of snapshots kept for the low-rank term.
    pub fn collect(&mut self, weights: &Array1<f64>) {
        let n = self.n_collected as f64;
        let n1 = n + 1.0;
        // Incremental mean: m_{k+1} = m_k * k/(k+1) + w * 1/(k+1).
        self.mean = &self.mean * (n / n1) + &(weights * (1.0 / n1));
        // Same incremental update for the element-wise second moment.
        let w_sq = weights.mapv(|w| w * w);
        self.sq_mean = &self.sq_mean * (n / n1) + &(&w_sq * (1.0 / n1));
        self.n_collected += 1;
        // Keep at most `swag_rank` most-recent snapshots. The rank-0 guard
        // fixes a panic in the previous version, which unconditionally
        // called `remove(0)` on an empty Vec when `swag_rank == 0`.
        if self.config.swag_rank > 0 {
            if self.deviations.len() >= self.config.swag_rank {
                self.deviations.remove(0);
            }
            self.deviations.push(weights.clone());
        }
    }

    /// Builds a low-rank-plus-diagonal Gaussian posterior from the
    /// collected snapshots.
    ///
    /// # Errors
    /// Returns an error if fewer than 2 snapshots have been collected,
    /// since a single point carries no variance information.
    pub fn build_posterior(&self) -> StatsResult<BNNPosterior> {
        if self.n_collected < 2 {
            return Err(StatsError::invalid_argument(
                "SWAG requires at least 2 weight snapshots",
            ));
        }
        // Diagonal variance: E[w^2] - E[w]^2, clamped at 0 to absorb small
        // negatives produced by floating-point cancellation.
        let diag_var = (&self.sq_mean - &self.mean.mapv(|m| m * m)).mapv(|v| v.max(0.0));
        // Deviation matrix D: one column per stored snapshot, centered on
        // the final running mean.
        let k = self.deviations.len();
        let mut deviation = Array2::<f64>::zeros((self.n_params, k));
        for (col_idx, snapshot) in self.deviations.iter().enumerate() {
            for row_idx in 0..self.n_params {
                deviation[[row_idx, col_idx]] = snapshot[row_idx] - self.mean[row_idx];
            }
        }
        Ok(BNNPosterior {
            mean: self.mean.clone(),
            covariance_type: CovarianceType::LowRankPlusDiagonal {
                d_diag: diag_var,
                deviation,
            },
            // SWAG does not estimate the marginal likelihood; report 0.
            log_marginal_likelihood: 0.0,
        })
    }

    /// Draws one weight sample from the SWAG posterior:
    /// `w = mean + (Sigma_diag^{1/2} z1 + D z2 / sqrt(K-1)) / sqrt(2)`.
    ///
    /// # Errors
    /// Returns an error if fewer than 2 snapshots have been collected.
    pub fn sample_weights(&self, rng_state: &mut u64) -> StatsResult<Array1<f64>> {
        if self.n_collected < 2 {
            return Err(StatsError::invalid_argument(
                "SWAG requires at least 2 snapshots to sample",
            ));
        }
        // Diagonal part: element-wise std (clamped at 0) times a standard
        // normal vector.
        let diag_var = &self.sq_mean - &self.mean.mapv(|m| m * m);
        let diag_std = diag_var.mapv(|v| v.max(0.0).sqrt());
        let z1: Array1<f64> = Array1::from_shape_fn(self.n_params, |_| randn(rng_state));
        let diag_part = &diag_std * &z1;
        // Low-rank part: D z2 / sqrt(K - 1), guarding against K <= 1.
        let k = self.deviations.len();
        let k_minus_1 = if k > 1 { (k - 1) as f64 } else { 1.0 };
        let mut lr_part = Array1::zeros(self.n_params);
        if k > 0 {
            let z2: Array1<f64> = Array1::from_shape_fn(k, |_| randn(rng_state));
            for (col_idx, snapshot) in self.deviations.iter().enumerate() {
                let dev = snapshot - &self.mean;
                lr_part = lr_part + &(&dev * z2[col_idx]);
            }
            lr_part /= k_minus_1.sqrt();
        }
        // The 1/sqrt(2) factor splits variance between the two parts.
        let scale = 1.0 / 2.0_f64.sqrt();
        Ok(&self.mean + &((&diag_part + &lr_part) * scale))
    }

    /// Number of snapshots collected so far.
    pub fn n_collected(&self) -> usize {
        self.n_collected
    }

    /// Running mean of the collected weights (the SWA solution).
    pub fn mean(&self) -> &Array1<f64> {
        &self.mean
    }
}
/// Ensembles predictions from multiple SWAG posteriors (MultiSWAG).
///
/// Draws `n_samples_per_model` weight samples from each collector, applies
/// `predict_fn` to every sample, and returns the pooled predictive mean,
/// population variance, and the raw sample matrix (one row per draw).
///
/// # Errors
/// Returns an error if `models` is empty, `n_samples_per_model` is 0,
/// sampling or `predict_fn` fails, or `predict_fn` returns outputs of
/// inconsistent dimension.
pub fn multi_swag_predict(
    models: &[SWAGCollector],
    predict_fn: &dyn Fn(&Array1<f64>) -> StatsResult<Array1<f64>>,
    n_samples_per_model: usize,
    rng_state: &mut u64,
) -> StatsResult<PredictiveDistribution> {
    if models.is_empty() {
        return Err(StatsError::invalid_argument("Need at least one SWAG model"));
    }
    if n_samples_per_model == 0 {
        return Err(StatsError::invalid_argument(
            "Need at least 1 sample per model",
        ));
    }
    let total_samples = models.len() * n_samples_per_model;
    let mut all_predictions: Vec<Array1<f64>> = Vec::with_capacity(total_samples);
    for model in models {
        for _ in 0..n_samples_per_model {
            let w = model.sample_weights(rng_state)?;
            all_predictions.push(predict_fn(&w)?);
        }
    }
    let n_outputs = all_predictions[0].len();
    // Validate up front instead of panicking inside ndarray arithmetic if
    // predict_fn returns vectors of differing lengths.
    if all_predictions.iter().any(|p| p.len() != n_outputs) {
        return Err(StatsError::invalid_argument(
            "predict_fn returned outputs of inconsistent dimension",
        ));
    }
    let n_total = all_predictions.len();
    // Pooled predictive mean across all models and samples.
    let mut mean = Array1::zeros(n_outputs);
    for p in &all_predictions {
        mean = mean + p;
    }
    mean /= n_total as f64;
    // Population variance of the pooled samples around the pooled mean.
    let mut variance = Array1::zeros(n_outputs);
    for p in &all_predictions {
        let diff = p - &mean;
        variance = variance + &diff.mapv(|d| d * d);
    }
    variance /= n_total as f64;
    // Keep the raw draws for downstream diagnostics (one row per sample).
    let mut samples = Array2::zeros((n_total, n_outputs));
    for (i, p) in all_predictions.iter().enumerate() {
        for j in 0..n_outputs {
            samples[[i, j]] = p[j];
        }
    }
    Ok(PredictiveDistribution {
        mean,
        variance,
        samples: Some(samples),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    // Moved to the top of the module: this `use` was previously declared
    // mid-module, after the first use of `array!`.
    use scirs2_core::ndarray::array;

    fn make_config() -> BNNConfig {
        BNNConfig {
            swag_rank: 5,
            ..BNNConfig::default()
        }
    }

    #[test]
    fn test_swag_collector_mean_converges() {
        let config = make_config();
        let mut collector = SWAGCollector::new(3, &config);
        let target = array![1.0, 2.0, 3.0];
        let mut rng: u64 = 42;
        for _ in 0..100 {
            let noise = Array1::from_shape_fn(3, |_| randn(&mut rng) * 0.01);
            collector.collect(&(&target + &noise));
        }
        let mean = collector.mean();
        for i in 0..3 {
            assert!(
                (mean[i] - target[i]).abs() < 0.1,
                "Mean[{}] = {}, expected ~{}",
                i,
                mean[i],
                target[i]
            );
        }
    }

    #[test]
    fn test_swag_deviations_stored() {
        let config = BNNConfig {
            swag_rank: 3,
            ..BNNConfig::default()
        };
        let mut collector = SWAGCollector::new(2, &config);
        collector.collect(&array![1.0, 2.0]);
        collector.collect(&array![3.0, 4.0]);
        collector.collect(&array![5.0, 6.0]);
        assert_eq!(collector.deviations.len(), 3);
        collector.collect(&array![7.0, 8.0]);
        assert_eq!(collector.deviations.len(), 3);
    }

    #[test]
    fn test_swag_build_posterior() {
        let config = make_config();
        let mut collector = SWAGCollector::new(2, &config);
        collector.collect(&array![1.0, 2.0]);
        collector.collect(&array![3.0, 4.0]);
        collector.collect(&array![5.0, 6.0]);
        let posterior = collector.build_posterior().expect("build posterior");
        assert_eq!(posterior.mean.len(), 2);
        match &posterior.covariance_type {
            CovarianceType::LowRankPlusDiagonal { d_diag, deviation } => {
                assert_eq!(d_diag.len(), 2);
                assert_eq!(deviation.nrows(), 2);
                assert_eq!(deviation.ncols(), 3);
                for &v in d_diag.iter() {
                    assert!(v >= 0.0, "Diagonal variance should be >= 0, got {}", v);
                }
            }
            _ => panic!("Expected LowRankPlusDiagonal covariance"),
        }
    }

    #[test]
    fn test_swag_sample_correct_dimension() {
        let config = make_config();
        let mut collector = SWAGCollector::new(4, &config);
        for i in 0..5 {
            let w = Array1::from_shape_fn(4, |j| (i * 4 + j) as f64);
            collector.collect(&w);
        }
        let mut rng: u64 = 123;
        let sample = collector.sample_weights(&mut rng).expect("sample");
        assert_eq!(sample.len(), 4);
    }

    #[test]
    fn test_swag_insufficient_snapshots() {
        let config = make_config();
        let mut collector = SWAGCollector::new(2, &config);
        collector.collect(&array![1.0, 2.0]);
        assert!(collector.build_posterior().is_err());
        let mut rng: u64 = 1;
        assert!(collector.sample_weights(&mut rng).is_err());
    }

    #[test]
    fn test_multi_swag_predict() {
        let config = make_config();
        let mut c1 = SWAGCollector::new(2, &config);
        let mut c2 = SWAGCollector::new(2, &config);
        for i in 0..5 {
            let w = array![i as f64, (i as f64) * 0.5];
            c1.collect(&w);
            c2.collect(&(&w + &array![0.1, 0.1]));
        }
        let predict_fn = |w: &Array1<f64>| -> StatsResult<Array1<f64>> { Ok(array![w[0] + w[1]]) };
        let mut rng: u64 = 42;
        let result = multi_swag_predict(&[c1, c2], &predict_fn, 5, &mut rng).expect("multi swag");
        assert_eq!(result.mean.len(), 1);
        assert_eq!(result.variance.len(), 1);
        assert!(result.samples.is_some());
        let samples = result.samples.as_ref().expect("samples should exist");
        assert_eq!(samples.nrows(), 10);
    }

    #[test]
    fn test_multi_swag_empty_models() {
        let models: Vec<SWAGCollector> = vec![];
        let predict_fn = |_w: &Array1<f64>| -> StatsResult<Array1<f64>> { Ok(array![0.0]) };
        let mut rng: u64 = 1;
        assert!(multi_swag_predict(&models, &predict_fn, 5, &mut rng).is_err());
    }
}