use burn::prelude::{Backend, Tensor};
use core::fmt;
use ndarray::{concatenate, prelude::*, stack};
use ndarray_stats::QuantileExt;
use num_traits::{Num, ToPrimitive};
use rayon::prelude::*;
use rustfft::{num_complex::Complex, FftPlanner};
use std::{cmp::Ordering, error::Error};
/// Smoothing factor for the exponentially-weighted moving average of the
/// acceptance probability (used by both trackers' `step` methods).
const ALPHA: f32 = 0.01;
/// Streaming summary statistics for a single MCMC chain.
///
/// Tracks running first and second moments of the parameter vector plus an
/// exponentially-weighted estimate of the acceptance probability.
#[derive(Debug, Clone, PartialEq)]
pub struct ChainTracker {
    // Number of parameters in each state vector.
    n_params: usize,
    // Number of steps folded in so far.
    n: u64,
    // EWMA acceptance probability; -1.0 sentinel until the first step().
    p_accept: f32,
    // Most recently observed state, used to detect accepted moves.
    last_state: Array1<f32>,
    // Running mean and mean-of-squares per parameter.
    mean: Array1<f32>,
    mean_sq: Array1<f32>,
}
/// Immutable snapshot of a single chain's running statistics, produced by
/// `ChainTracker::stats`.
#[derive(Debug, Clone, PartialEq)]
pub struct ChainStats {
    /// Number of steps the statistics were accumulated over.
    pub n: u64,
    /// EWMA acceptance probability (-1.0 if no step was ever recorded).
    pub p_accept: f32,
    /// Per-parameter running mean.
    pub mean: Array1<f32>,
    /// Per-parameter sample variance, Bessel-corrected from the running
    /// moments: (E[x^2] - E[x]^2) * n / (n - 1).
    pub sm2: Array1<f32>,
}
impl ChainTracker {
    /// Creates a tracker for one chain of `n_params` parameters, seeded with
    /// the chain's initial state.
    ///
    /// # Panics
    /// Panics when `initial_state` cannot be viewed as `n_params` elements or
    /// when an element cannot be converted to `f32`.
    pub fn new<T>(n_params: usize, initial_state: &[T]) -> Self
    where
        T: num_traits::ToPrimitive + Clone,
    {
        // Convert the caller-supplied state into an f32 ndarray up front.
        let state = ArrayView1::from_shape(n_params, initial_state)
            .expect("Expected being able to convert initial state to a NdArray")
            .mapv(|v| {
                v.to_f32()
                    .expect("Expected conversion of elements to f32's to succeed")
            });
        Self {
            n_params,
            n: 0,
            // Negative sentinel: no acceptance estimate before the first step.
            p_accept: -1.0,
            last_state: state,
            mean: Array1::<f32>::zeros(n_params),
            mean_sq: Array1::<f32>::zeros(n_params),
        }
    }

    /// Folds one new state vector into the running statistics.
    ///
    /// # Errors
    /// Returns an error when `x` cannot be viewed as `n_params` elements.
    pub fn step<T>(&mut self, x: &[T]) -> Result<(), Box<dyn Error>>
    where
        T: num_traits::ToPrimitive + Clone,
    {
        self.n += 1;
        let count = self.n as f32;
        let current =
            ndarray::ArrayView1::<T>::from_shape(self.n_params, x)?.mapv(|v| v.to_f32().unwrap());
        // Streaming mean: m_n = (m_{n-1} * (n-1) + x) / n.
        self.mean = (self.mean.clone() * (count - 1.0) + current.clone()) / count;
        // Same recurrence for the mean of squares; the first step copies x^2.
        self.mean_sq = if self.n == 1 {
            current.pow2()
        } else {
            (self.mean_sq.clone() * (count - 1.0) + current.pow2()) / count
        };
        // Seed for the acceptance EWMA: reuse the previous estimate when it is
        // valid, otherwise derive a 0/1 value from whether the first element moved.
        let seed = if self.p_accept >= 0.0 {
            self.p_accept
        } else {
            current
                .index_axis(Axis(0), 0)
                .ne(&self.last_state.index_axis(Axis(0), 0)) as i32 as f32
        };
        // A step counts as "accepted" when the state changed; fold it into the
        // exponentially-weighted acceptance estimate.
        self.p_accept = ndarray::Zip::from(current.rows())
            .and(self.last_state.rows())
            .fold(seed, |acc, a, b| {
                let moved = (a.ne(&b) as i32) as f32;
                (1.0 - ALPHA) * acc + ALPHA * moved
            });
        self.last_state = current;
        Ok(())
    }

    /// Returns a snapshot of the running statistics.
    ///
    /// NOTE(review): with n == 1 the Bessel correction divides by zero and
    /// `sm2` becomes non-finite — callers appear expected to take >= 2 steps.
    pub fn stats(&self) -> ChainStats {
        let count = self.n as f32;
        ChainStats {
            n: self.n,
            p_accept: self.p_accept,
            mean: self.mean.clone(),
            // Bessel-corrected sample variance from the two running moments.
            sm2: (self.mean_sq.clone() - self.mean.pow2()) * count / (count - 1.0),
        }
    }
}
/// Gathers the Gelman-Rubin R-hat per parameter from per-chain summary
/// statistics: sqrt(var^+ / W).
pub fn collect_rhat(chain_stats: &[&ChainStats]) -> Array1<f32> {
    let (w, v) = withinvar_from_cs(chain_stats);
    let ratio = v / w;
    ratio.sqrt()
}
/// Computes the per-parameter within-chain variance `W` and the pooled
/// variance estimate `var^+ = (n-1)/n * W + B/n` from per-chain summaries,
/// where `n` is the average step count across chains.
///
/// # Panics
/// Panics if the chain summaries have mismatched shapes or `chain_stats` is
/// empty; with a single chain the between-term divides by zero (NaN/inf).
fn withinvar_from_cs(chain_stats: &[&ChainStats]) -> (Array1<f32>, Array1<f32>) {
    let means: Vec<ArrayView1<f32>> = chain_stats.iter().map(|x| x.mean.view()).collect();
    let means = ndarray::stack(Axis(0), &means).expect("Expected stacking means to succeed");
    let sm2s: Vec<ArrayView1<f32>> = chain_stats.iter().map(|x| x.sm2.view()).collect();
    let sm2s = ndarray::stack(Axis(0), &sm2s).expect("Expected stacking sm2 arrays to succeed");
    // W: average of the per-chain sample variances, per parameter.
    let within = sm2s
        .mean_axis(Axis(0))
        .expect("Expected computing within-chain variances to succeed");
    let global_means = means
        .mean_axis(Axis(0))
        .expect("Expected computing global means to succeed");
    let diffs: Array2<f32> = (means.clone()
        - global_means
            .broadcast(means.shape())
            .expect("Expected broadcasting to succeed"))
    .into_dimensionality()
    .expect("Expected casting dimensionality to Array2 to succeed");
    // B/n = sum_j (mean_j - grand_mean)^2 / (m - 1), with m = number of chains.
    // BUG FIX: the divisor must be the chain count minus one; the previous
    // code used `diffs.len() - 1` (total element count, n_chains * n_params),
    // which deflated the between-chain term whenever n_params > 1.
    let between = diffs.pow2().sum_axis(Axis(0)) / (chain_stats.len() - 1) as f32;
    let n: f32 = chain_stats.iter().map(|x| x.n as f32).sum::<f32>() / chain_stats.len() as f32;
    let var = between + within.clone() * ((n - 1.0) / n);
    (within, var)
}
/// Streaming summary statistics for several MCMC chains tracked jointly.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiChainTracker {
    // Number of steps folded in so far.
    n: usize,
    /// EWMA acceptance probability across all chains (starts at 0.0).
    pub p_accept: f32,
    // Most recent state per chain, shape (n_chains, n_params).
    last_state: Array2<f32>,
    // Running mean and mean-of-squares per chain and parameter.
    mean: Array2<f32>,
    mean_sq: Array2<f32>,
    n_chains: usize,
    n_params: usize,
}
impl MultiChainTracker {
    /// Creates a tracker for `n_chains` parallel chains with `n_params`
    /// parameters each; all running statistics start at zero.
    pub fn new(n_chains: usize, n_params: usize) -> Self {
        Self {
            n: 0,
            p_accept: 0.0,
            last_state: Array2::<f32>::zeros((n_chains, n_params)),
            mean: Array2::<f32>::zeros((n_chains, n_params)),
            mean_sq: Array2::<f32>::zeros((n_chains, n_params)),
            n_chains,
            n_params,
        }
    }

    /// Folds one step of all chains (row-major `n_chains * n_params` slice)
    /// into the running mean, mean-of-squares and acceptance-rate EWMA.
    ///
    /// # Errors
    /// Returns an error when `x` cannot be viewed as `(n_chains, n_params)`.
    pub fn step<T>(&mut self, x: &[T]) -> Result<(), Box<dyn Error>>
    where
        // Relaxed from `Num + ToPrimitive + FromPrimitive + Clone + PartialOrd`:
        // the body only converts to f32 and clones, and this now matches
        // `ChainTracker::step`. Loosening bounds is backward compatible.
        T: num_traits::ToPrimitive + Clone,
    {
        self.n += 1;
        let n = self.n as f32;
        let x_arr = ndarray::ArrayView2::<T>::from_shape((self.n_chains, self.n_params), x)?
            .mapv(|x| x.to_f32().unwrap());
        // Streaming mean: m_n = (m_{n-1} * (n-1) + x) / n.
        self.mean = (self.mean.clone() * (n - 1.0) + x_arr.clone()) / n;
        if self.n == 1 {
            self.mean_sq = x_arr.pow2();
        } else {
            self.mean_sq = (self.mean_sq.clone() * (n - 1.0) + (x_arr.pow2())) / n;
        };
        // A chain counts as "accepted" this step when its row changed since
        // the previous step; fold every chain into the acceptance EWMA.
        self.p_accept = ndarray::Zip::from(x_arr.rows())
            .and(self.last_state.rows())
            .fold(self.p_accept, |p_accept, a, b| {
                let accepted = (a.ne(&b) as i32) as f32;
                (1.0 - ALPHA) * p_accept + ALPHA * accepted
            });
        self.last_state = x_arr;
        Ok(())
    }

    /// Summarizes a full sample tensor (`chains x steps x params`) into
    /// `RunStats`. Note: the tracker's own running state is not consulted.
    ///
    /// # Errors
    /// Returns an error if the tensor data cannot be viewed with `sample.dims()`.
    pub fn stats<B: Backend>(&self, sample: Tensor<B, 3>) -> Result<RunStats, Box<dyn Error>> {
        let sample_data = sample.to_data();
        let sample_ndarray =
            ArrayView3::from_shape(sample.dims(), sample_data.as_slice().unwrap())?;
        Ok(RunStats::from_f32_view(sample_ndarray))
    }

    /// Largest R-hat across all parameters.
    pub fn max_rhat(&self) -> Result<f32, Box<dyn Error>> {
        let all: Array1<f32> = self.rhat()?;
        let max = *all.max()?;
        Ok(max)
    }

    /// Gelman-Rubin R-hat per parameter: sqrt(var^+ / W).
    pub fn rhat(&self) -> Result<Array1<f32>, Box<dyn Error>> {
        let (within, var) = self.within_and_var()?;
        let rhat = (var / within).sqrt();
        Ok(rhat)
    }

    /// Per-parameter within-chain variance `W` and pooled estimate
    /// `var^+ = (n-1)/n * W + B/n`, computed from the running moments.
    fn within_and_var(&self) -> Result<(Array1<f32>, Array1<f32>), Box<dyn Error>> {
        let mean_chain = self
            .mean
            .mean_axis(Axis(0))
            .ok_or("Mean reduction across chains for mean failed.")?;
        let n_chains = self.mean.shape()[0] as f32;
        let n = self.n as f32;
        // B = n / (m - 1) * sum_j (mean_j - grand_mean)^2.
        let fac = n / (n_chains - 1.0);
        let between = (self.mean.clone() - mean_chain.insert_axis(Axis(0)))
            .pow2()
            .sum_axis(Axis(0))
            * fac;
        // Per-chain sample variance from the two running moments (ddof = 1).
        let sm2 = (self.mean_sq.clone() - self.mean.pow2()) * n / (n - 1.0);
        let within = sm2
            .mean_axis(Axis(0))
            .ok_or("Mean reduction across chains for mean of squares failed.")?;
        let var = within.clone() * ((n - 1.0) / n) + between * (1.0 / n);
        Ok((within, var))
    }
}
/// Summarizes a 1-D metric under `name`: min, median, max, mean and sample
/// standard deviation (ddof = 1).
///
/// The data is sorted descending; NaN comparisons fall back to `Equal`, so
/// NaN inputs produce an unspecified (but non-panicking) order. The median is
/// the element at index `len / 2` of the descending array, i.e. the
/// lower-middle value for even-length input.
///
/// # Panics
/// Panics if `data` is empty or not contiguous in memory.
pub fn basic_stats(name: &str, mut data: Array1<f32>) -> BasicStats {
    // Sort descending; `b.partial_cmp(a)` reverses the natural order.
    data.as_slice_mut()
        .unwrap()
        .sort_by(|a, b| match b.partial_cmp(a) {
            Some(x) => x,
            None => Ordering::Equal,
        });
    let (min, median, max) = (
        // Descending order: the minimum is the LAST element ...
        *data
            .last()
            .expect("Expected getting last element from data array to succeed"),
        data[data.len() / 2],
        // ... and the maximum is the FIRST. (Previous expect messages had
        // these two swapped and referred specifically to "ess array".)
        *data
            .first()
            .expect("Expected getting first element from data array to succeed"),
    );
    let mean = data.mean().expect("Expected computing mean to succeed");
    let std = data.std(1.0);
    BasicStats {
        name: name.to_string(),
        min,
        median,
        max,
        mean,
        std,
    }
}
/// Aggregate convergence diagnostics for one sampling run.
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct RunStats {
    /// Summary of the per-parameter effective sample sizes.
    pub ess: BasicStats,
    /// Summary of the per-parameter split R-hat values.
    pub rhat: BasicStats,
}
impl RunStats {
    /// Builds run statistics (ESS and split R-hat summaries) from an f32
    /// sample shaped `chains x steps x params`.
    fn from_f32_view(sample: ArrayView3<f32>) -> Self {
        let (rhat_values, ess_values) = split_rhat_mean_ess(sample);
        RunStats {
            ess: basic_stats("ESS", ess_values),
            rhat: basic_stats("Split R-hat", rhat_values),
        }
    }
}
impl fmt::Display for RunStats {
    /// Writes the ESS summary line, a newline, then the split R-hat summary.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "{}", self.ess)?;
        write!(f, "{}", self.rhat)
    }
}
impl<T> From<ArrayView3<'_, T>> for RunStats
where
    T: ToPrimitive + std::clone::Clone,
{
    /// Converts a generic numeric sample (`chains x steps x params`) to f32
    /// and summarizes it.
    ///
    /// # Panics
    /// Panics if any element cannot be converted to `f32`.
    fn from(sample: ArrayView3<T>) -> Self {
        let as_f32 = sample.mapv(|v| v.to_f32().unwrap());
        let (rhat_values, ess_values) = split_rhat_mean_ess(as_f32.view());
        RunStats {
            ess: basic_stats("ESS", ess_values),
            rhat: basic_stats("Split R-hat", rhat_values),
        }
    }
}
/// Five-number-style summary of a one-dimensional metric.
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct BasicStats {
    /// Label used when displaying the summary (e.g. "ESS").
    pub name: String,
    pub min: f32,
    pub median: f32,
    pub max: f32,
    pub mean: f32,
    /// Sample standard deviation (ddof = 1), see `basic_stats`.
    pub std: f32,
}
impl fmt::Display for BasicStats {
    /// One-line human-readable summary: range, median, and mean ± std.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            name,
            min,
            median,
            max,
            mean,
            std,
        } = self;
        write!(
            f,
            "{} in [{:.2}, {:.2}], median: {:.2}, mean: {:.2} ± {:.2}",
            name, min, max, median, mean, std
        )
    }
}
/// Splits each chain in half along the step axis and stacks the halves as
/// separate chains (for an odd step count the middle sample is dropped).
fn splitcat(sample: ArrayView3<f32>) -> Array3<f32> {
    let n_steps = sample.shape()[1];
    let half = (n_steps / 2) as i32;
    let first_halves = sample.slice(s![.., ..half, ..]);
    let second_halves = sample.slice(s![.., -half.., ..]);
    concatenate(Axis(0), &[first_halves, second_halves])
        .expect("Expected stacking two halves to succeed")
}
/// Returns (split R-hat, ESS) per parameter for a sample shaped
/// `chains x steps x params`, after splitting each chain in half.
pub fn split_rhat_mean_ess(sample: ArrayView3<f32>) -> (Array1<f32>, Array1<f32>) {
    let halves = splitcat(sample);
    let (within, var) = withinvar(halves.view());
    let rhat_values = rhat(within.view(), var.view());
    let ess_values = ess(halves.view(), within.view(), var.view());
    (rhat_values, ess_values)
}
/// Gelman-Rubin R-hat per parameter: sqrt(var^+ / W).
///
/// BUG FIX: this previously computed `sqrt(within / var)` — the inverse of
/// the standard potential-scale-reduction statistic and inconsistent with
/// both `collect_rhat` and `MultiChainTracker::rhat`, which use
/// `sqrt(var / within)`.
fn rhat(within: ArrayView1<f32>, var: ArrayView1<f32>) -> Array1<f32> {
    (var.to_owned() / within).sqrt()
}
/// Per-parameter within-chain variance `W` and pooled variance estimate
/// `var^+ = (n-1)/n * W + B/n` for a sample shaped `chains x steps x params`.
///
/// Each parameter is reduced independently in parallel. Rayon's `unzip`
/// preserves index order, replacing the previous collect-then-sequential-fold
/// round trip with a direct parallel collection (same values, same order).
fn withinvar(sample: ArrayView3<f32>) -> (Array1<f32>, Array1<f32>) {
    let c = sample.shape()[0];
    let n = sample.shape()[1];
    let p = sample.shape()[2];
    let (within, var): (Vec<f32>, Vec<f32>) = (0..p)
        .into_par_iter()
        .map(|param_idx| {
            let data_p = sample.slice(s![.., .., param_idx]);
            let chain_means = data_p.mean_axis(Axis(1)).unwrap();
            let overall_mean = chain_means.mean().unwrap();
            // B = n / (c - 1) * sum_j (mean_j - overall_mean)^2.
            let diff = &chain_means - overall_mean;
            let b = diff.pow2().sum() * ((n as f32) / ((c - 1) as f32));
            // Per-chain biased variances (divide by n), averaged into W below.
            let mut squares = Vec::with_capacity(c);
            for chain_i in 0..c {
                let row = data_p.slice(s![chain_i, ..]);
                let cm = chain_means[chain_i];
                let sq = row.iter().map(|v| (v - cm) * (v - cm)).sum::<f32>() / (n as f32);
                squares.push(sq);
            }
            let squares = Array1::from(squares);
            let w = squares.mean().unwrap();
            let v = ((n as f32 - 1.0) / (n as f32)) * w + b / (n as f32);
            (w, v)
        })
        .unzip();
    (Array1::from_vec(within), Array1::from_vec(var))
}
/// Effective sample size per parameter from the pooled sample
/// (`chains x steps x params`), the within-chain variance `W`, and the
/// pooled variance estimate `var`.
///
/// Follows the combined-autocorrelation approach: rho_t is estimated from the
/// chain-averaged autocovariance, then truncated with Geyer-style pairwise
/// sums over consecutive lags before summing into tau.
fn ess(sample: ArrayView3<f32>, within: ArrayView1<f32>, var: ArrayView1<f32>) -> Array1<f32> {
    let shape = sample.shape();
    let (n_chains, n_steps, n_params) = (shape[0], shape[1], shape[2]);
    // Per-chain autocovariance matrices, each shaped (lag, param).
    let chain_rho: Vec<Array2<f32>> = (0..n_chains)
        .map(|c| {
            let chain_sample = sample.index_axis(Axis(0), c);
            autocov(chain_sample)
        })
        .collect();
    let chain_rho: Vec<ArrayView2<f32>> = chain_rho.iter().map(|x| x.view()).collect();
    let chain_rho = stack(Axis(0), &chain_rho)
        .expect("Expected stacking chain-specific autocovariance matrices to succeed");
    let avg_rho = chain_rho.mean_axis(Axis(0)).unwrap();
    // rho_t = 1 - (W - avg_autocov_t) / var, broadcast over lags and params.
    let diff = -avg_rho
        + within
            .broadcast((n_steps, n_params))
            .expect("Expected broadcasting to succeed");
    let rho = -(diff
        / var
            .broadcast((n_steps, n_params))
            .expect("Expected broadcasting to succeed"))
        + 1.0;
    let tau: Vec<f32> = (0..n_params)
        .into_par_iter()
        .map(|d| {
            let rho_d = rho.index_axis(Axis(1), d).to_owned();
            // Running minimum of the pair sums enforces monotone decrease.
            let mut min = if rho_d.len() >= 2 {
                rho_d[[0]] + rho_d[[1]]
            } else {
                0.0
            };
            let mut out = 0.0;
            // Non-overlapping lag pairs (rho_{2t}, rho_{2t+1}); stop at the
            // first non-positive pair sum (initial positive sequence rule).
            for rho_t in rho_d.windows_with_stride(2, 2) {
                let mut p_t = rho_t[0] + rho_t[1];
                if p_t <= 0.0 {
                    break;
                }
                // Clamp to the running minimum (monotone estimator).
                if p_t > min {
                    p_t = min;
                }
                min = p_t;
                out += p_t;
            }
            // tau = -1 + 2 * sum of the retained pair sums.
            -1.0 + 2.0 * out
        })
        .collect();
    let tau = Array1::from_vec(tau);
    // ESS = (n_chains * n_steps) / tau, per parameter.
    tau.recip() * n_chains as f32 * n_steps as f32
}
/// Dispatches to an autocovariance implementation by trajectory length:
/// brute force for short series (<= 100 steps), FFT-based otherwise.
fn autocov(sample: ArrayView2<f32>) -> Array2<f32> {
    match sample.nrows() {
        0..=100 => autocov_bf(sample),
        _ => autocov_fft(sample),
    }
}
/// FFT-based biased autocovariance per column of `sample` (`n x d`).
///
/// Uses the FFT autocorrelation identity: each mean-centered trajectory is
/// zero-padded to the next power of two >= 2n-1 (avoiding circular-correlation
/// wraparound), forward-transformed, multiplied by its complex conjugate,
/// inverse-transformed, and the first `n` lags are kept. rustfft's inverse is
/// unnormalized, hence the `/ n_padded`; the further `/ n` matches
/// `autocov_bf`'s biased (divide-by-n) estimator.
fn autocov_fft(sample: ArrayView2<f32>) -> Array2<f32> {
    let (n, d) = (sample.shape()[0], sample.shape()[1]);
    let mut planner = FftPlanner::new();
    // Next power of two >= 2n - 1.
    let mut n_padded = 1;
    while n_padded < 2 * n - 1 {
        n_padded <<= 1;
    }
    let fft = planner.plan_fft_forward(n_padded);
    let ffti = planner.plan_fft_inverse(n_padded);
    // Each column (one parameter's trajectory) is processed in parallel.
    let out: Vec<f32> = sample
        .axis_iter(Axis(1))
        .into_par_iter()
        .map(|traj| {
            let traj_mean = traj.sum() / traj.len() as f32;
            // Mean-centered trajectory, zero-padded to n_padded complex values.
            let mut x: Vec<Complex<f32>> = traj
                .iter()
                .map(|xi| Complex {
                    re: (*xi - traj_mean),
                    im: 0.0f32,
                })
                .chain(
                    [Complex {
                        re: 0.0f32,
                        im: 0.0f32,
                    }]
                    .repeat(n_padded - n),
                )
                .collect();
            fft.process(x.as_mut_slice());
            // Power spectrum: X * conj(X).
            x.iter_mut().for_each(|xi| {
                *xi *= xi.conj();
            });
            ffti.process(x.as_mut_slice());
            // Keep the first n lags; only the real part is meaningful here.
            x.iter_mut()
                .take(n)
                .map(|xi| xi.re / n_padded as f32 / n as f32)
                .collect::<Vec<f32>>()
        })
        .flatten_iter()
        .collect();
    // Results were collected as d rows of n lags; transpose back to n x d.
    let out = Array2::from_shape_vec((d, n), out).expect("Expected creating dxn array to succeed");
    out.t().to_owned()
}
/// Brute-force biased autocovariance per column of `data` (`n x d`): every
/// lag's cross-product sum is divided by `n` (not `n - lag`).
fn autocov_bf(data: ArrayView2<f32>) -> Array2<f32> {
    let (n, d) = data.dim();
    let mut out = Array2::<f32>::zeros((n, d));
    // Each output column is independent, so compute them in parallel.
    out.axis_iter_mut(Axis(1))
        .into_par_iter()
        .enumerate()
        .for_each(|(col_idx, mut out_col)| {
            // Center the column on its mean before correlating.
            let centered = data.column(col_idx).to_owned() - data.column(col_idx).mean().unwrap();
            for lag in 0..n {
                // Sum of element-wise products between the series and its
                // lag-shifted self, accumulated in the same left-to-right order
                // as an index loop.
                let sum_lag: f32 = centered
                    .iter()
                    .zip(centered.iter().skip(lag))
                    .map(|(a, b)| a * b)
                    .sum();
                out_col[lag] = sum_lag / n as f32;
            }
        });
    out
}
/// Effective sample size per parameter, computed from a raw sample together
/// with per-chain summary statistics.
pub fn ess_from_chainstats(sample: ArrayView3<f32>, chain_stats: &[&ChainStats]) -> Array1<f32> {
    let (w, v) = withinvar_from_cs(chain_stats);
    ess(sample, w.view(), v.view())
}
#[cfg(test)]
mod tests {
    use std::io::Write;
    use std::{f32, fs::File, time::Instant};
    use approx::assert_abs_diff_eq;
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};
    use super::*;

    // Shared harness: feed two steps of a 3-chain x 4-parameter tracker and
    // compare the resulting R-hat vector element-wise against `expected`.
    fn run_rhat_test_generic<T>(data0: Array2<T>, data1: Array2<T>, expected: Array1<f32>, tol: f32)
    where
        T: ndarray::NdFloat + num_traits::FromPrimitive,
    {
        let mut psr = MultiChainTracker::new(3, 4);
        psr.step(data0.as_slice().unwrap()).unwrap();
        psr.step(data1.as_slice().unwrap()).unwrap();
        let rhat = psr.rhat().unwrap();
        // Largest absolute deviation from the expected vector.
        let diff = *(rhat.clone() - expected.clone()).abs().max().unwrap();
        assert!(
            diff < tol,
            "Mismatch in Rhat. Got {:?}, expected {:?}, diff = {:?}",
            rhat,
            expected,
            diff
        );
    }

    #[test]
    fn test_rhat_f32_1() {
        let data_step_0: Array2<f32> = arr2(&[
            [0.0, 1.0, 0.0, 1.0],
            [1.0, 2.0, 0.0, 2.0],
            [0.0, 0.0, 0.0, 2.0],
        ]);
        let data_step_1: Array2<f32> = arr2(&[
            [1.0, 2.0, 2.0, 0.0],
            [1.0, 1.0, 1.0, 1.0],
            [0.0, 1.0, 0.0, 0.0],
        ]);
        // Reference values for this 2-step, 3-chain fixture.
        let expected = array![f32::consts::SQRT_2, 1.080_123_4, 0.894_427_3, 0.8660254];
        run_rhat_test_generic(data_step_0, data_step_1, expected, f32::EPSILON * 10.0);
    }

    #[test]
    fn test_rhat_f64_1() {
        // Same fixture as test_rhat_f32_1, exercising the f64 input path.
        let data_step_0: Array2<f64> = arr2(&[
            [0.0, 1.0, 0.0, 1.0],
            [1.0, 2.0, 0.0, 2.0],
            [0.0, 0.0, 0.0, 2.0],
        ]);
        let data_step_1: Array2<f64> = arr2(&[
            [1.0, 2.0, 2.0, 0.0],
            [1.0, 1.0, 1.0, 1.0],
            [0.0, 1.0, 0.0, 0.0],
        ]);
        let expected = array![f32::consts::SQRT_2, 1.0801234, 0.8944271, 0.8660254];
        run_rhat_test_generic(data_step_0, data_step_1, expected, f32::EPSILON * 10.0);
    }

    #[test]
    fn test_rhat_f64_2() {
        let data_step_0 = arr2(&[
            [1.0, 0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0, 1.0],
            [0.0, 1.0, 0.0, 2.0],
        ]);
        let data_step_1 = arr2(&[
            [1.0, 2.0, 0.0, 2.0],
            [1.0, 2.0, 0.0, 0.0],
            [2.0, 0.0, 1.0, 2.0],
        ]);
        let expected = array![f32::consts::FRAC_1_SQRT_2, 0.74535599, 1.0, 1.5];
        run_rhat_test_generic(data_step_0, data_step_1, expected, f32::EPSILON * 10.0);
    }

    // Compares one autocovariance implementation against precomputed expected
    // values: shape first, then element-wise to within 1e-6.
    fn run_test_case(
        autocov_func: &dyn Fn(ArrayView2<f32>) -> Array2<f32>,
        data: &Array2<f32>,
        expected: &Array2<f32>,
        test_name: &str,
    ) {
        let result = autocov_func(data.view());
        assert_eq!(
            result.dim(),
            expected.dim(),
            "{}: shape mismatch; got {:?}, expected {:?}",
            test_name,
            result.dim(),
            expected.dim()
        );
        assert_abs_diff_eq!(result, *expected, epsilon = 1e-6);
        println!("Test: {test_name} succeeded");
    }

    #[test]
    fn test_single_param() {
        let data = array![[1.0], [2.0], [3.0], [4.0],];
        let expected = array![[1.25], [0.3125], [-0.375], [-0.5625]];
        // Both implementations must agree with the same reference values.
        run_test_case(&autocov_bf, &data, &expected, "BF: single_param_small");
        println!("Doing FFT test");
        run_test_case(&autocov_fft, &data, &expected, "FFT: single_param_small");
        println!("FFT test succeeded");
    }

    #[test]
    fn test_two_params_1() {
        let data = array![[1.0, 0.3], [2.0, 2.0], [3.0, -2.0], [4.0, 5.0],];
        let expected = array![
            [1.25, 6.516875],
            [0.3125, -3.7889063],
            [-0.375, 1.4721875],
            [-0.5625, -0.94171875],
        ];
        run_test_case(&autocov_bf, &data, &expected, "BF: two_params_small");
        run_test_case(&autocov_fft, &data, &expected, "FFT: two_params_small");
    }

    #[test]
    fn ess_1() {
        // 4 chains x 1000 iid uniform samples: ESS should be close to the
        // total sample count and split R-hat close to 1.
        let m = 4;
        let n = 1000;
        let mut data = Array2::<f32>::zeros((m, n));
        let mut rng = SmallRng::seed_from_u64(42);
        for mut row in data.rows_mut() {
            for elem in row.iter_mut() {
                *elem = rng.random::<f32>();
            }
        }
        // Append a trailing parameter axis of size 1.
        let data = data
            .to_shape((data.shape()[0], data.shape()[1], 1))
            .unwrap();
        let run_stats = RunStats::from(data.view());
        println!("Samples: {}\n{run_stats}", m * n);
        assert!(run_stats.ess.min > 3800.0);
        assert!(run_stats.rhat.max < 1.01);
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_autocov_perf_comp() {
        // Benchmarks FFT vs brute-force autocovariance over increasing
        // trajectory lengths and writes the timings to a CSV file.
        let mut file =
            File::create("runtime_results.csv").expect("Unable to create runtime_results.csv");
        writeln!(file, "length,rep,time,algorithm").expect("Unable to write CSV header");
        let mut rng = rand::rng();
        for exp in 0..10 {
            let n = 1 << exp;
            for rep in 1..=10 {
                let sample_data: Vec<f32> = (0..n * 1000).map(|_| rng.random()).collect();
                let sample = Array2::from_shape_vec((n, 1000), sample_data)
                    .expect("Failed to create Array2");
                let start_fft = Instant::now();
                autocov_fft(sample.view());
                let fft_time = start_fft.elapsed().as_nanos();
                let start_brute = Instant::now();
                autocov_bf(sample.view());
                let brute_time = start_brute.elapsed().as_nanos();
                writeln!(file, "{},{},{},fft", n, rep, fft_time)
                    .expect("Unable to write test results to CSV");
                writeln!(file, "{},{},{},brute force", n, rep, brute_time)
                    .expect("Unable to write test results to CSV");
                println!(
                    "Length: {} | Rep: {} | FFT: {} ns | Brute: {} ns",
                    n, rep, fft_time, brute_time
                );
            }
        }
    }
}