1use ndarray::prelude::*;
5use ndarray_stats::QuantileExt;
6use num_traits::Num;
7use std::{collections::VecDeque, error::Error};
8
9#[derive(Debug, Clone, PartialEq)]
10pub struct ChainTracker<T> {
11 n_params: usize,
12 n: u64,
13 p_accept: f32,
14 mean: Array1<f32>, mean_sq: Array1<f32>, last_state: Vec<T>,
17 accept_queue: VecDeque<bool>,
18}
19
20#[derive(Debug, Clone, PartialEq)]
21pub struct ChainStats {
22 pub n: u64,
23 pub p_accept: f32,
24 pub mean: Array1<f32>, pub sm2: Array1<f32>, }
27
28impl<T: Clone + Copy + PartialEq> ChainTracker<T> {
29 pub fn new(n_params: usize, initial_state: &[T]) -> Self {
30 let mean_sq = Array1::<f32>::zeros(n_params);
31 let mean = Array1::<f32>::zeros(n_params);
32 let accept_queue = VecDeque::new();
33 Self {
34 n_params,
35 n: 0,
36 p_accept: 0.0,
37 mean,
38 mean_sq,
39 last_state: Vec::<T>::from(initial_state),
40 accept_queue,
41 }
42 }
43
44 pub fn step(&mut self, x: &[T]) -> Result<(), Box<dyn Error>>
45 where
46 T: std::clone::Clone + num_traits::ToPrimitive, {
48 self.n += 1;
49
50 let accepted = self.last_state.iter().eq(x.iter());
52 let old_aq_len = self.accept_queue.len() as f32;
53 self.accept_queue.push_back(accepted);
54 let removed = if old_aq_len > 100.0 {
55 self.accept_queue.pop_front().unwrap()
56 } else {
57 false
58 };
59 let new_aq_len = self.accept_queue.len() as f32;
60 self.p_accept = (self.p_accept * old_aq_len + (accepted as i32) as f32
61 - (removed as i32) as f32)
62 / new_aq_len;
63 self.last_state.copy_from_slice(x);
64
65 let n = self.n as f32;
66 let x_arr =
67 ndarray::ArrayView1::<T>::from_shape(self.n_params, x)?.mapv(|x| x.to_f32().unwrap());
68
69 self.mean = (self.mean.clone() * (n - 1.0) + x_arr.clone()) / n;
70 if self.n == 1 {
71 self.mean_sq = x_arr.pow2();
72 } else {
73 self.mean_sq = (self.mean_sq.clone() * (n - 1.0) + (x_arr.pow2())) / n;
74 };
75
76 Ok(())
77 }
78
79 pub fn sm2(&self) -> Array1<f32> {
80 let n = self.n as f32;
81 (self.mean_sq.clone() - self.mean.pow2()) * n / (n - 1.0)
82 }
83
84 pub fn stats(&self) -> ChainStats {
85 ChainStats {
86 n: self.n,
87 p_accept: self.p_accept,
88 mean: self.mean.clone(),
89 sm2: self.sm2(),
90 }
91 }
92}
93
94pub fn collect_rhat(all_chain_stats: &[&ChainStats]) -> Array1<f32> {
95 let means: Vec<ArrayView1<f32>> = all_chain_stats.iter().map(|x| x.mean.view()).collect();
96 let means = ndarray::stack(Axis(0), &means).expect("Expected stacking means to succeed");
97 let sm2s: Vec<ArrayView1<f32>> = all_chain_stats.iter().map(|x| x.sm2.view()).collect();
98 let sm2s = ndarray::stack(Axis(0), &sm2s).expect("Expected stacking sm2 arrays to succeed");
99
100 let w = sm2s
101 .mean_axis(Axis(0))
102 .expect("Expected computing within-chain variances to succeed");
103 let global_means = means
104 .mean_axis(Axis(0))
105 .expect("Expected computing global means to succeed");
106 let diffs: Array2<f32> = (means.clone()
107 - global_means
108 .broadcast(means.shape())
109 .expect("Expected broadcasting to succeed"))
110 .into_dimensionality()
111 .expect("Expected casting dimensionality to Array1 to succeed");
112 let b = diffs.pow2().sum_axis(Axis(0)) / (diffs.len() - 1) as f32;
113
114 let n: f32 =
115 all_chain_stats.iter().map(|x| x.n as f32).sum::<f32>() / all_chain_stats.len() as f32;
116 ((b + w.clone() * ((n - 1.0) / n)) / w).sqrt()
117}
118
119#[derive(Debug, Clone, PartialEq)]
120pub struct RhatMulti {
121 n: usize,
122 mean: Array2<f64>, mean_sq: Array2<f64>, n_chains: usize,
125 n_params: usize,
126}
127
128impl RhatMulti {
129 pub fn new(n_chains: usize, n_params: usize) -> Self {
130 let mean_sq = Array2::<f64>::zeros((n_chains, n_params));
131 Self {
132 n: 0,
133 mean: Array2::<f64>::zeros((n_chains, n_params)),
134 mean_sq,
135 n_chains,
136 n_params,
137 }
138 }
139
140 pub fn step<T>(&mut self, x: &[T]) -> Result<(), Box<dyn Error>>
141 where
142 T: Num
143 + num_traits::ToPrimitive
144 + num_traits::FromPrimitive
145 + std::clone::Clone
146 + std::cmp::PartialOrd,
147 {
148 self.n += 1;
149
150 let n = self.n as f64;
151 let x_arr = ndarray::ArrayView2::<T>::from_shape((self.n_chains, self.n_params), x)?
152 .mapv(|x| x.to_f64().unwrap());
153
154 self.mean = (self.mean.clone() * (n - 1.0) + x_arr.clone()) / n;
155 if self.n == 1 {
156 self.mean_sq = x_arr.pow2();
157 } else {
158 self.mean_sq = (self.mean_sq.clone() * (n - 1.0) + (x_arr.pow2())) / n;
159 };
160 Ok(())
161 }
162
163 pub fn all(&self) -> Result<Array1<f64>, Box<dyn Error>> {
164 let mean_chain = self
165 .mean
166 .mean_axis(Axis(0))
167 .ok_or("Mean reduction across chains for mean failed.")?;
168 let n_chains = self.mean.shape()[0] as f64;
169 let n = self.n as f64;
170 let fac = n / (n_chains - 1.0);
171 let between = (self.mean.clone() - mean_chain.insert_axis(Axis(0)))
172 .pow2()
173 .sum_axis(Axis(0))
174 * fac;
175 let sm2 = (self.mean_sq.clone() - self.mean.pow2()) * n / (n - 1.0);
176 let within = sm2
177 .mean_axis(Axis(0))
178 .ok_or("Mean reduction across chains for mean of squares failed.")?;
179 let var = within.clone() * ((n - 1.0) / n) + between * (1.0 / n);
180 let rhat = (var / within).sqrt();
181 Ok(rhat)
182 }
183
184 pub fn max(&self) -> Result<f64, Box<dyn Error>> {
185 let all: Array1<f64> = self.all()?;
186 let max = *all.max()?;
187 Ok(max)
188 }
189}
190
#[cfg(test)]
mod tests {
    use std::f64;

    use super::*;

    /// Feeds two draws per chain into a [`RhatMulti`] sized from `data0`, then checks every
    /// resulting R-hat value against `expected` within `tol`.
    fn run_rhat_test_generic<T>(
        data0: Array2<T>,
        data1: Array2<T>,
        expected: Array1<f64>,
        tol: f64,
    ) where
        T: ndarray::NdFloat + num_traits::FromPrimitive,
    {
        let (n_chains, n_params) = data0.dim();
        let mut psr = RhatMulti::new(n_chains, n_params);
        psr.step(data0.as_slice().unwrap()).unwrap();
        psr.step(data1.as_slice().unwrap()).unwrap();
        let rhat = psr.all().unwrap();
        let diff = *(&rhat - &expected).abs().max().unwrap();
        assert!(
            diff < tol,
            "Mismatch in Rhat. Got {:?}, expected {:?}, diff = {:?}",
            rhat,
            expected,
            diff
        );
    }

    #[test]
    fn test_rhat_f64_1() {
        let data_step_0: Array2<f64> = arr2(&[
            [0.0, 1.0, 0.0, 1.0],
            [1.0, 2.0, 0.0, 2.0],
            [0.0, 0.0, 0.0, 2.0],
        ]);

        let data_step_1: Array2<f64> = arr2(&[
            [1.0, 2.0, 2.0, 0.0],
            [1.0, 1.0, 1.0, 1.0],
            [0.0, 1.0, 0.0, 0.0],
        ]);
        let expected = array![f64::consts::SQRT_2, 1.08012345, 0.89442719, 0.8660254];
        run_rhat_test_generic(data_step_0, data_step_1, expected, 1e-7);
    }

    // Unannotated float literals default to f64, so the element type is spelled out to make
    // sure the f32 code path is actually exercised.
    #[test]
    fn test_rhat_f32_1() {
        let data_step_0: Array2<f32> = arr2(&[
            [0.0, 1.0, 0.0, 1.0],
            [1.0, 2.0, 0.0, 2.0],
            [0.0, 0.0, 0.0, 2.0],
        ]);
        let data_step_1: Array2<f32> = arr2(&[
            [1.0, 2.0, 2.0, 0.0],
            [1.0, 1.0, 1.0, 1.0],
            [0.0, 1.0, 0.0, 0.0],
        ]);
        let expected = array![f64::consts::SQRT_2, 1.0801234, 0.8944271, 0.8660254];
        run_rhat_test_generic(data_step_0, data_step_1, expected, 1e-6);
    }

    #[test]
    fn test_rhat_f32_data() {
        let data_step_0: Array2<f32> = arr2(&[
            [1.0, 0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0, 1.0],
            [0.0, 1.0, 0.0, 2.0],
        ]);
        let data_step_1: Array2<f32> = arr2(&[
            [1.0, 2.0, 0.0, 2.0],
            [1.0, 2.0, 0.0, 0.0],
            [2.0, 0.0, 1.0, 2.0],
        ]);
        let expected = array![f64::consts::FRAC_1_SQRT_2, 0.74535599, 1.0, 1.5];
        run_rhat_test_generic(data_step_0, data_step_1, expected, 1e-7);
    }
}