use crate::{DVector, Float};
use nalgebra::Complex;
use rustfft::FftPlanner;
use serde::{Deserialize, Serialize};
/// Estimates the integrated autocorrelation time of each parameter from a set of MCMC
/// chains, using the FFT-based estimator with Sokal's adaptive windowing (the method
/// popularized by `emcee`).
///
/// All chains are flattened into one sequence before the autocorrelation is computed.
/// `c` is the windowing constant (defaults to `7.0`, matching `emcee`); the window is the
/// first lag `m` satisfying `m >= c * tau(m)`.
///
/// Returns one integrated autocorrelation time per parameter, or an empty vector when
/// `samples` contains no draws (previously this panicked on empty input).
pub fn integrated_autocorrelation_times(
    samples: Vec<Vec<DVector<Float>>>,
    c: Option<Float>,
) -> DVector<Float> {
    let c = c.unwrap_or(7.0);
    // Robustness: bail out instead of indexing into an empty chain list.
    if samples.is_empty() || samples[0].is_empty() {
        return DVector::zeros(0);
    }
    let n_parameters = samples[0][0].len();
    let samples: Vec<DVector<Float>> = samples.into_iter().flatten().collect();
    // Pad to the next power of two (doubled below) so the circular FFT convolution
    // produces the linear autocorrelation without wrap-around.
    let n = samples.len().next_power_of_two();
    let mut planner = FftPlanner::new();
    let fft = planner.plan_fft_forward(2 * n);
    let ifft = planner.plan_fft_inverse(2 * n);
    DVector::from_iterator(
        n_parameters,
        (0..n_parameters).map(|i_parameter| {
            // Mean-centered series for this parameter.
            let x: Vec<Float> = samples.iter().map(|sample| sample[i_parameter]).collect();
            let mean = x.iter().sum::<Float>() / x.len() as Float;
            let mut input: Vec<Complex<Float>> =
                x.iter().map(|&val| Complex::new(val - mean, 0.0)).collect();
            input.resize(2 * n, Complex::new(0.0, 0.0));
            // Autocorrelation via the Wiener–Khinchin theorem: IFFT(|FFT(x)|^2).
            fft.process(&mut input);
            for val in &mut input {
                *val *= val.conj();
            }
            ifft.process(&mut input);
            // rustfft leaves transforms unnormalized; the 4n scale follows the emcee
            // convention, and any constant factor cancels in the acf[0] division below.
            let mut acf: Vec<Float> = input
                .iter()
                .take(x.len())
                .map(|value| value.re / (4.0 * n as Float))
                .collect();
            // Normalize so acf[0] == 1 (skipped for a degenerate all-constant series).
            if !acf.is_empty() && acf[0] != 0.0 {
                let norm_factor = acf[0];
                acf.iter_mut().for_each(|value| *value /= norm_factor);
            }
            // Running estimate tau(m) = 2 * sum_{k<=m} rho(k) - 1 for each window size m.
            let taus: Vec<Float> = acf
                .iter()
                .scan(0.0, |acc, &value| {
                    *acc += value;
                    Some(*acc)
                })
                .map(|value| Float::mul_add(2.0, value, -1.0))
                .collect();
            // Sokal's windowing criterion: first lag m with m >= c * tau(m); fall back
            // to the last lag if the criterion is never met.
            let ind = taus
                .iter()
                .enumerate()
                .position(|(idx, &tau)| (idx as Float) >= c * tau)
                .unwrap_or(taus.len() - 1);
            taus[ind]
        }),
    )
}
/// Summary convergence diagnostics computed from a set of MCMC chains.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct MCMCDiagnostics {
    /// Split potential scale reduction factor (split R-hat) per parameter.
    pub r_hat: DVector<Float>,
    /// Effective sample size per parameter.
    pub ess: DVector<Float>,
    /// Estimated acceptance rate per chain (fraction of consecutive draws that differ).
    pub acceptance_rates: DVector<Float>,
    /// Average of `acceptance_rates` (0.0 when there are no chains).
    pub mean_acceptance_rate: Float,
}
/// Splits every chain into its first and second halves (the middle draw of an
/// odd-length chain is dropped). Chains too short to split meaningfully
/// (2 or 3 draws) are kept whole; chains with fewer than 2 draws are discarded.
fn split_chains(chains: &[Vec<DVector<Float>>]) -> Vec<Vec<DVector<Float>>> {
    chains
        .iter()
        .flat_map(|chain| {
            let half = chain.len() / 2;
            if half >= 2 {
                // Long enough: emit both halves as separate chains.
                vec![chain[..half].to_vec(), chain[chain.len() - half..].to_vec()]
            } else if chain.len() >= 2 {
                // Too short to split; keep the chain as-is.
                vec![chain.clone()]
            } else {
                Vec::new()
            }
        })
        .collect()
}
/// Unbiased (Bessel-corrected) sample variance of `values`.
/// Returns 0.0 when fewer than two values are given.
fn sample_variance(values: &[Float]) -> Float {
    let n = values.len();
    if n < 2 {
        return 0.0;
    }
    let mean = values.iter().sum::<Float>() / n as Float;
    let sum_sq: Float = values.iter().map(|&v| (v - mean) * (v - mean)).sum();
    sum_sq / (n - 1) as Float
}
/// Computes the split potential scale reduction factor (split R-hat) for each
/// parameter, Gelman–Rubin style: chains are halved via `split_chains`, then the
/// between-chain variance `B` is compared against the mean within-chain variance
/// `W`. Values near 1.0 indicate well-mixed chains.
///
/// Returns an empty vector when there are no usable split chains.
pub(crate) fn split_r_hat(chains: &[Vec<DVector<Float>>]) -> DVector<Float> {
    let split = split_chains(chains);
    if split.is_empty() || split[0].is_empty() {
        return DVector::zeros(0);
    }
    let n_chains = split.len();
    // NOTE(review): the first split chain's length is used for every chain; the
    // formula assumes all split chains are equally long, which may not hold when
    // input chains have differing lengths — confirm with upstream callers.
    let n_samples = split[0].len();
    let n_params = split[0][0].len();
    DVector::from_iterator(
        n_params,
        (0..n_params).map(|param| {
            // Per-chain means of this parameter.
            let means: Vec<Float> = split
                .iter()
                .map(|chain| {
                    chain.iter().map(|sample| sample[param]).sum::<Float>() / n_samples as Float
                })
                .collect();
            // Per-chain (Bessel-corrected) variances of this parameter.
            let variances: Vec<Float> = split
                .iter()
                .map(|chain| {
                    let vals: Vec<Float> = chain.iter().map(|sample| sample[param]).collect();
                    sample_variance(&vals)
                })
                .collect();
            // W: mean within-chain variance.
            let w = variances.iter().sum::<Float>() / n_chains as Float;
            // B: between-chain variance (zero when only one chain is available).
            let b = if n_chains > 1 {
                n_samples as Float * sample_variance(&means)
            } else {
                0.0
            };
            if w <= Float::EPSILON {
                // Degenerate (near-constant) chains are treated as converged.
                1.0
            } else {
                // Pooled variance estimate: ((n - 1) / n) * W + B / n.
                let var_hat = ((n_samples - 1) as Float / n_samples as Float)
                    .mul_add(w, b / n_samples as Float);
                // R-hat = sqrt(var_hat / W), clamped below at 1.
                (var_hat / w).sqrt().max(1.0)
            }
        }),
    )
}
/// Effective sample size per parameter: the total draw count divided by the
/// integrated autocorrelation time. Falls back to the raw draw count when the
/// estimated time is non-finite or non-positive, and returns an empty vector
/// for empty input.
pub(crate) fn effective_sample_size(chains: &[Vec<DVector<Float>>]) -> DVector<Float> {
    match chains.first() {
        None => return DVector::zeros(0),
        Some(first) if first.is_empty() => return DVector::zeros(0),
        Some(_) => {}
    }
    let total_samples = chains.iter().map(Vec::len).sum::<usize>() as Float;
    let taus = integrated_autocorrelation_times(chains.to_vec(), None);
    taus.map(|tau| {
        if tau.is_finite() && tau > 0.0 {
            total_samples / tau
        } else {
            // Degenerate autocorrelation estimate: report the raw draw count.
            total_samples
        }
    })
}
/// Estimates the per-chain acceptance rate as the fraction of consecutive draw
/// pairs that differ (a repeated draw is counted as a rejected proposal).
/// Chains with fewer than two draws get a rate of 0.0.
pub(crate) fn acceptance_rates(chains: &[Vec<DVector<Float>>]) -> DVector<Float> {
    let rates = chains.iter().map(|chain| {
        let transitions = chain.len().saturating_sub(1);
        if transitions == 0 {
            0.0
        } else {
            let moved = chain
                .windows(2)
                .filter(|pair| pair[0] != pair[1])
                .count();
            moved as Float / transitions as Float
        }
    });
    DVector::from_iterator(chains.len(), rates)
}
/// Bundles all convergence diagnostics (split R-hat, ESS, acceptance rates and
/// their mean) for a set of chains into a single [`MCMCDiagnostics`] value.
pub(crate) fn diagnostics_from_chain(chains: &[Vec<DVector<Float>>]) -> MCMCDiagnostics {
    let rates = acceptance_rates(chains);
    let n_rates = rates.len();
    // Mean rate is 0.0 when there are no chains to average over.
    let mean_acceptance_rate = match n_rates {
        0 => 0.0,
        _ => rates.iter().sum::<Float>() / n_rates as Float,
    };
    MCMCDiagnostics {
        r_hat: split_r_hat(chains),
        ess: effective_sample_size(chains),
        acceptance_rates: rates,
        mean_acceptance_rate,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::dvector;
    // Two identical chains (each splitting into matching halves) should yield
    // R-hat == 1, the value indicating perfect mixing.
    #[test]
    fn test_split_r_hat_is_one_for_identical_split_chains() {
        let chains = vec![
            vec![dvector![0.0], dvector![1.0], dvector![0.0], dvector![1.0]],
            vec![dvector![0.0], dvector![1.0], dvector![0.0], dvector![1.0]],
        ];
        let r_hat = split_r_hat(&chains);
        assert_eq!(r_hat.len(), 1);
        assert!((r_hat[0] - 1.0).abs() < 1e-12);
    }
    // Repeated consecutive draws count as rejections: the first chain has one
    // distinct transition out of three, the second moves on every step.
    #[test]
    fn test_acceptance_rates_detect_repeated_samples() {
        let chains = vec![
            vec![dvector![0.0], dvector![0.0], dvector![1.0], dvector![1.0]],
            vec![dvector![0.0], dvector![1.0], dvector![2.0], dvector![3.0]],
        ];
        let rates = acceptance_rates(&chains);
        assert!((rates[0] - (1.0 / 3.0)).abs() < 1e-12);
        assert!((rates[1] - 1.0).abs() < 1e-12);
    }
    // ESS must be strictly positive for non-empty, non-constant chains.
    #[test]
    fn test_effective_sample_size_is_positive() {
        let chains = vec![
            vec![dvector![0.0], dvector![1.0], dvector![0.0], dvector![1.0]],
            vec![dvector![1.0], dvector![0.0], dvector![1.0], dvector![0.0]],
        ];
        let ess = effective_sample_size(&chains);
        assert_eq!(ess.len(), 1);
        assert!(ess[0] > 0.0);
    }
}