mini_mcmc/
stats.rs

1//! Computation and tracking of MCMC statistics like acceptance probability and Potential Scale
2//! Reduction.
3
4use ndarray::prelude::*;
5use ndarray_stats::QuantileExt;
6use num_traits::Num;
7use std::{collections::VecDeque, error::Error};
8
9#[derive(Debug, Clone, PartialEq)]
10pub struct ChainTracker<T> {
11    n_params: usize,
12    n: u64,
13    p_accept: f32,
14    mean: Array1<f32>,    // n_params
15    mean_sq: Array1<f32>, // n_params
16    last_state: Vec<T>,
17    accept_queue: VecDeque<bool>,
18}
19
20#[derive(Debug, Clone, PartialEq)]
21pub struct ChainStats {
22    pub n: u64,
23    pub p_accept: f32,
24    pub mean: Array1<f32>, // n_params
25    pub sm2: Array1<f32>,  // n_params
26}
27
28impl<T: Clone + Copy + PartialEq> ChainTracker<T> {
29    pub fn new(n_params: usize, initial_state: &[T]) -> Self {
30        let mean_sq = Array1::<f32>::zeros(n_params);
31        let mean = Array1::<f32>::zeros(n_params);
32        let accept_queue = VecDeque::new();
33        Self {
34            n_params,
35            n: 0,
36            p_accept: 0.0,
37            mean,
38            mean_sq,
39            last_state: Vec::<T>::from(initial_state),
40            accept_queue,
41        }
42    }
43
44    pub fn step(&mut self, x: &[T]) -> Result<(), Box<dyn Error>>
45    where
46        T: std::clone::Clone + num_traits::ToPrimitive, //+ num_traits::FromPrimitive , // + std::cmp::PartialOrd,
47    {
48        self.n += 1;
49
50        // TODO: Update p_accept and last_state
51        let accepted = self.last_state.iter().eq(x.iter());
52        let old_aq_len = self.accept_queue.len() as f32;
53        self.accept_queue.push_back(accepted);
54        let removed = if old_aq_len > 100.0 {
55            self.accept_queue.pop_front().unwrap()
56        } else {
57            false
58        };
59        let new_aq_len = self.accept_queue.len() as f32;
60        self.p_accept = (self.p_accept * old_aq_len + (accepted as i32) as f32
61            - (removed as i32) as f32)
62            / new_aq_len;
63        self.last_state.copy_from_slice(x);
64
65        let n = self.n as f32;
66        let x_arr =
67            ndarray::ArrayView1::<T>::from_shape(self.n_params, x)?.mapv(|x| x.to_f32().unwrap());
68
69        self.mean = (self.mean.clone() * (n - 1.0) + x_arr.clone()) / n;
70        if self.n == 1 {
71            self.mean_sq = x_arr.pow2();
72        } else {
73            self.mean_sq = (self.mean_sq.clone() * (n - 1.0) + (x_arr.pow2())) / n;
74        };
75
76        Ok(())
77    }
78
79    pub fn sm2(&self) -> Array1<f32> {
80        let n = self.n as f32;
81        (self.mean_sq.clone() - self.mean.pow2()) * n / (n - 1.0)
82    }
83
84    pub fn stats(&self) -> ChainStats {
85        ChainStats {
86            n: self.n,
87            p_accept: self.p_accept,
88            mean: self.mean.clone(),
89            sm2: self.sm2(),
90        }
91    }
92}
93
94pub fn collect_rhat(all_chain_stats: &[&ChainStats]) -> Array1<f32> {
95    let means: Vec<ArrayView1<f32>> = all_chain_stats.iter().map(|x| x.mean.view()).collect();
96    let means = ndarray::stack(Axis(0), &means).expect("Expected stacking means to succeed");
97    let sm2s: Vec<ArrayView1<f32>> = all_chain_stats.iter().map(|x| x.sm2.view()).collect();
98    let sm2s = ndarray::stack(Axis(0), &sm2s).expect("Expected stacking sm2 arrays to succeed");
99
100    let w = sm2s
101        .mean_axis(Axis(0))
102        .expect("Expected computing within-chain variances to succeed");
103    let global_means = means
104        .mean_axis(Axis(0))
105        .expect("Expected computing global means to succeed");
106    let diffs: Array2<f32> = (means.clone()
107        - global_means
108            .broadcast(means.shape())
109            .expect("Expected broadcasting to succeed"))
110    .into_dimensionality()
111    .expect("Expected casting dimensionality to Array1 to succeed");
112    let b = diffs.pow2().sum_axis(Axis(0)) / (diffs.len() - 1) as f32;
113
114    let n: f32 =
115        all_chain_stats.iter().map(|x| x.n as f32).sum::<f32>() / all_chain_stats.len() as f32;
116    ((b + w.clone() * ((n - 1.0) / n)) / w).sqrt()
117}
118
119#[derive(Debug, Clone, PartialEq)]
120pub struct RhatMulti {
121    n: usize,
122    mean: Array2<f64>,    // n_chains x n_params
123    mean_sq: Array2<f64>, // n_chains x n_params
124    n_chains: usize,
125    n_params: usize,
126}
127
128impl RhatMulti {
129    pub fn new(n_chains: usize, n_params: usize) -> Self {
130        let mean_sq = Array2::<f64>::zeros((n_chains, n_params));
131        Self {
132            n: 0,
133            mean: Array2::<f64>::zeros((n_chains, n_params)),
134            mean_sq,
135            n_chains,
136            n_params,
137        }
138    }
139
140    pub fn step<T>(&mut self, x: &[T]) -> Result<(), Box<dyn Error>>
141    where
142        T: Num
143            + num_traits::ToPrimitive
144            + num_traits::FromPrimitive
145            + std::clone::Clone
146            + std::cmp::PartialOrd,
147    {
148        self.n += 1;
149
150        let n = self.n as f64;
151        let x_arr = ndarray::ArrayView2::<T>::from_shape((self.n_chains, self.n_params), x)?
152            .mapv(|x| x.to_f64().unwrap());
153
154        self.mean = (self.mean.clone() * (n - 1.0) + x_arr.clone()) / n;
155        if self.n == 1 {
156            self.mean_sq = x_arr.pow2();
157        } else {
158            self.mean_sq = (self.mean_sq.clone() * (n - 1.0) + (x_arr.pow2())) / n;
159        };
160        Ok(())
161    }
162
163    pub fn all(&self) -> Result<Array1<f64>, Box<dyn Error>> {
164        let mean_chain = self
165            .mean
166            .mean_axis(Axis(0))
167            .ok_or("Mean reduction across chains for mean failed.")?;
168        let n_chains = self.mean.shape()[0] as f64;
169        let n = self.n as f64;
170        let fac = n / (n_chains - 1.0);
171        let between = (self.mean.clone() - mean_chain.insert_axis(Axis(0)))
172            .pow2()
173            .sum_axis(Axis(0))
174            * fac;
175        let sm2 = (self.mean_sq.clone() - self.mean.pow2()) * n / (n - 1.0);
176        let within = sm2
177            .mean_axis(Axis(0))
178            .ok_or("Mean reduction across chains for mean of squares failed.")?;
179        let var = within.clone() * ((n - 1.0) / n) + between * (1.0 / n);
180        let rhat = (var / within).sqrt();
181        Ok(rhat)
182    }
183
184    pub fn max(&self) -> Result<f64, Box<dyn Error>> {
185        let all: Array1<f64> = self.all()?;
186        let max = *all.max()?;
187        Ok(max)
188    }
189}
190
#[cfg(test)]
mod tests {
    use std::f64;

    use super::*;

    /// Feeds two steps of `chains x params` data into a fresh [`RhatMulti`]
    /// and checks the resulting R-hat values against `expected` within `tol`.
    ///
    /// The accumulator is sized from `data0`, so the helper works for any
    /// matrix shape.
    fn run_rhat_test_generic<T>(data0: Array2<T>, data1: Array2<T>, expected: Array1<f64>, tol: f64)
    where
        T: ndarray::NdFloat + num_traits::FromPrimitive,
    {
        let (n_chains, n_params) = data0.dim();
        let mut psr = RhatMulti::new(n_chains, n_params);
        psr.step(data0.as_slice().unwrap()).unwrap();
        psr.step(data1.as_slice().unwrap()).unwrap();
        let rhat = psr.all().unwrap();
        let diff = *(rhat.clone() - expected.clone()).abs().max().unwrap();
        assert!(
            diff < tol,
            "Mismatch in Rhat. Got {:?}, expected {:?}, diff = {:?}",
            rhat,
            expected,
            diff
        );
    }

    #[test]
    fn test_rhat_f64_1() {
        // Step 0 data (chains x params)
        let data_step_0: Array2<f64> = arr2(&[
            [0.0, 1.0, 0.0, 1.0], // chain 0
            [1.0, 2.0, 0.0, 2.0], // chain 1
            [0.0, 0.0, 0.0, 2.0], // chain 2
        ]);

        // Step 1 data (chains x params)
        let data_step_1: Array2<f64> = arr2(&[
            [1.0, 2.0, 2.0, 0.0], // chain 0
            [1.0, 1.0, 1.0, 1.0], // chain 1
            [0.0, 1.0, 0.0, 0.0], // chain 2
        ]);
        let expected = array![f64::consts::SQRT_2, 1.08012345, 0.89442719, 0.8660254];
        run_rhat_test_generic(data_step_0, data_step_1, expected, 1e-7);
    }

    #[test]
    fn test_rhat_f32_1() {
        // Explicit `f32` annotations: unsuffixed float literals would
        // otherwise default to `f64` and the `f32` path would go untested.
        let data_step_0: Array2<f32> = arr2(&[
            [0.0, 1.0, 0.0, 1.0], // chain 0
            [1.0, 2.0, 0.0, 2.0], // chain 1
            [0.0, 0.0, 0.0, 2.0], // chain 2
        ]);
        let data_step_1: Array2<f32> = arr2(&[
            [1.0, 2.0, 2.0, 0.0], // chain 0
            [1.0, 1.0, 1.0, 1.0], // chain 1
            [0.0, 1.0, 0.0, 0.0], // chain 2
        ]);
        let expected = array![f64::consts::SQRT_2, 1.0801234, 0.8944271, 0.8660254];
        run_rhat_test_generic(data_step_0, data_step_1, expected, 1e-6);
    }

    #[test]
    fn test_rhat_f32_data() {
        // Explicit `f32` annotations, see `test_rhat_f32_1`.
        let data_step_0: Array2<f32> = arr2(&[
            [1.0, 0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0, 1.0],
            [0.0, 1.0, 0.0, 2.0],
        ]);
        let data_step_1: Array2<f32> = arr2(&[
            [1.0, 2.0, 0.0, 2.0],
            [1.0, 2.0, 0.0, 0.0],
            [2.0, 0.0, 1.0, 2.0],
        ]);
        let expected = array![f64::consts::FRAC_1_SQRT_2, 0.74535599, 1.0, 1.5];
        run_rhat_test_generic(data_step_0, data_step_1, expected, 1e-7);
    }
}