mini_mcmc/
stats.rs

//! Computation and tracking of MCMC statistics such as the acceptance probability, the
//! Potential Scale Reduction Factor (R-hat) and the Effective Sample Size (ESS).
3
4use burn::prelude::{Backend, Tensor};
5use core::fmt;
6use ndarray::{concatenate, prelude::*, stack};
7use ndarray_stats::QuantileExt;
8use num_traits::{Num, ToPrimitive};
9use rayon::prelude::*;
10use rustfft::{num_complex::Complex, FftPlanner};
11use std::{cmp::Ordering, error::Error};
12
/// Smoothing factor of the exponential moving average used for acceptance-rate estimates:
/// each new observation is weighted by `ALPHA` and the previous estimate by `1 - ALPHA`.
const ALPHA: f32 = 1e-2;
14
/// Tracks running statistics for a single MCMC chain.
///
/// # Fields
/// - `n_params`: Number of parameters in the chain.
/// - `n`: Number of steps recorded so far.
/// - `p_accept`: Exponential-moving-average estimate of the acceptance probability (a step
///   counts as accepted when the state changed). Holds the sentinel `-1.0` until the first step.
/// - `last_state`: Most recently recorded state, used to detect whether the next step moved.
/// - `mean`: Running mean of each parameter.
/// - `mean_sq`: Running mean of each squared parameter.
#[derive(Debug, Clone, PartialEq)]
pub struct ChainTracker {
    n_params: usize,
    n: u64,
    p_accept: f32,
    last_state: Array1<f32>,
    mean: Array1<f32>,    // n_params
    mean_sq: Array1<f32>, // n_params
}
34
/// Snapshot of the statistics of a single MCMC chain, as produced by `ChainTracker::stats`.
///
/// # Fields
/// - `n`: Number of steps taken.
/// - `p_accept`: Estimated acceptance probability (exponential moving average).
/// - `mean`: Mean of each parameter.
/// - `sm2`: Sample variance of each parameter, bias-corrected by the factor `n / (n - 1)`.
#[derive(Debug, Clone, PartialEq)]
pub struct ChainStats {
    pub n: u64,
    pub p_accept: f32,
    pub mean: Array1<f32>, // n_params
    pub sm2: Array1<f32>,  // n_params
}
49
50impl ChainTracker {
51    /// Creates a new `ChainTracker` with the given number of parameters and initial state.
52    ///
53    /// # Arguments
54    /// - `n_params`: Number of parameters in the chain.
55    /// - `initial_state`: Initial state of the chain.
56    ///
57    /// # Returns
58    /// A new `ChainTracker` instance.
59    pub fn new<T>(n_params: usize, initial_state: &[T]) -> Self
60    where
61        T: num_traits::ToPrimitive + Clone,
62    {
63        let mean_sq = Array1::<f32>::zeros(n_params);
64        let mean = Array1::<f32>::zeros(n_params);
65        let last_state = ArrayView1::from_shape(n_params, initial_state)
66            .expect("Expected being able to convert initial state to a NdArray")
67            .mapv(|x| {
68                x.to_f32()
69                    .expect("Expected conversion of elements to f32's to succeed")
70            });
71
72        Self {
73            n_params,
74            n: 0,
75            p_accept: -1.0,
76            last_state,
77            mean,
78            mean_sq,
79        }
80    }
81
82    /// Updates the tracker with a new state.
83    ///
84    /// # Arguments
85    /// - `x`: New state of the chain.
86    ///
87    /// # Returns
88    /// `Ok(())` if successful; an error otherwise.
89    pub fn step<T>(&mut self, x: &[T]) -> Result<(), Box<dyn Error>>
90    where
91        T: num_traits::ToPrimitive + Clone,
92    {
93        self.n += 1;
94
95        let n = self.n as f32;
96        let x_arr =
97            ndarray::ArrayView1::<T>::from_shape(self.n_params, x)?.mapv(|x| x.to_f32().unwrap());
98
99        self.mean = (self.mean.clone() * (n - 1.0) + x_arr.clone()) / n;
100        if self.n == 1 {
101            self.mean_sq = x_arr.pow2();
102        } else {
103            self.mean_sq = (self.mean_sq.clone() * (n - 1.0) + (x_arr.pow2())) / n;
104        };
105
106        //  x_1 = (1 - a) x_0 + a x_1
107        // <=> x_1 (1 - a) = (1 - a) x_0
108        // <=> x_1 = x_0
109        // So set initial p_accept to 1 if the transition state was an 'accept' and 0 otherwise
110        let p_start = if self.p_accept >= 0.0 {
111            self.p_accept
112        } else {
113            x_arr
114                .index_axis(Axis(0), 0)
115                .ne(&self.last_state.index_axis(Axis(0), 0)) as i32 as f32
116        };
117        self.p_accept = ndarray::Zip::from(x_arr.rows())
118            .and(self.last_state.rows())
119            .fold(p_start, |p_accept, a, b| {
120                let accepted = (a.ne(&b) as i32) as f32;
121                (1.0 - ALPHA) * p_accept + ALPHA * accepted
122            });
123        self.last_state = x_arr;
124
125        Ok(())
126    }
127
128    /// Retrieves the current statistics of the chain.
129    ///
130    /// # Returns
131    /// A `ChainStats` struct containing the current statistics.
132    pub fn stats(&self) -> ChainStats {
133        let n = self.n as f32;
134        ChainStats {
135            n: self.n,
136            p_accept: self.p_accept,
137            mean: self.mean.clone(),
138            sm2: (self.mean_sq.clone() - self.mean.pow2()) * n / (n - 1.0),
139        }
140    }
141}
142
143/// Computes the Potential Scale Reduction Factor (R-hat) for multiple chains.
144///
145/// # Arguments
146/// - `chain_stats`: Slice of references to `ChainStats` from multiple chains.
147///
148/// # Returns
149/// An array containing the R-hat values for each parameter.
150pub fn collect_rhat(chain_stats: &[&ChainStats]) -> Array1<f32> {
151    let (within, var) = withinvar_from_cs(chain_stats);
152    (var / within).sqrt()
153}
154
155fn withinvar_from_cs(chain_stats: &[&ChainStats]) -> (Array1<f32>, Array1<f32>) {
156    let means: Vec<ArrayView1<f32>> = chain_stats.iter().map(|x| x.mean.view()).collect();
157    let means = ndarray::stack(Axis(0), &means).expect("Expected stacking means to succeed");
158    let sm2s: Vec<ArrayView1<f32>> = chain_stats.iter().map(|x| x.sm2.view()).collect();
159    let sm2s = ndarray::stack(Axis(0), &sm2s).expect("Expected stacking sm2 arrays to succeed");
160
161    let within = sm2s
162        .mean_axis(Axis(0))
163        .expect("Expected computing within-chain variances to succeed");
164    let global_means = means
165        .mean_axis(Axis(0))
166        .expect("Expected computing global means to succeed");
167    let diffs: Array2<f32> = (means.clone()
168        - global_means
169            .broadcast(means.shape())
170            .expect("Expected broadcasting to succeed"))
171    .into_dimensionality()
172    .expect("Expected casting dimensionality to Array1 to succeed");
173    let between = diffs.pow2().sum_axis(Axis(0)) / (diffs.len() - 1) as f32;
174
175    let n: f32 = chain_stats.iter().map(|x| x.n as f32).sum::<f32>() / chain_stats.len() as f32;
176    let var = between + within.clone() * ((n - 1.0) / n);
177    (within, var)
178}
179
/// Tracks running statistics across multiple MCMC chains.
///
/// # Fields
/// - `n`: Number of steps taken.
/// - `p_accept`: Exponential-moving-average estimate of the acceptance probability, updated
///   once per chain and step (a chain counts as accepted when its state changed). Starts at `0.0`.
/// - `last_state`: Most recent state of every chain (`n_chains x n_params`).
/// - `mean`: Running mean of each parameter, per chain.
/// - `mean_sq`: Running mean of each squared parameter, per chain.
/// - `n_chains`: Number of chains.
/// - `n_params`: Number of parameters.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiChainTracker {
    n: usize,
    pub p_accept: f32,
    last_state: Array2<f32>,
    mean: Array2<f32>,    // n_chains x n_params
    mean_sq: Array2<f32>, // n_chains x n_params
    n_chains: usize,
    n_params: usize,
}
198
199impl MultiChainTracker {
200    /// Creates a new `MultiChainTracker` for the given number of chains and parameters.
201    ///
202    /// # Arguments
203    /// - `n_chains`: Number of chains.
204    /// - `n_params`: Number of parameters.
205    ///
206    /// # Returns
207    /// A new `MultiChainTracker` instance.
208    pub fn new(n_chains: usize, n_params: usize) -> Self {
209        let mean_sq = Array2::<f32>::zeros((n_chains, n_params));
210        Self {
211            n: 0,
212            p_accept: 0.0,
213            last_state: Array2::<f32>::zeros((n_chains, n_params)),
214            mean: Array2::<f32>::zeros((n_chains, n_params)),
215            mean_sq,
216            n_chains,
217            n_params,
218        }
219    }
220
221    /// Updates the tracker with new states from all chains.
222    ///
223    /// # Arguments
224    /// - `x`: New states of the chains, flattened into a single slice.
225    ///
226    /// # Returns
227    /// `Ok(())` if successful; an error otherwise.
228    pub fn step<T>(&mut self, x: &[T]) -> Result<(), Box<dyn Error>>
229    where
230        T: Num
231            + num_traits::ToPrimitive
232            + num_traits::FromPrimitive
233            + std::clone::Clone
234            + std::cmp::PartialOrd,
235    {
236        self.n += 1;
237
238        let n = self.n as f32;
239        let x_arr = ndarray::ArrayView2::<T>::from_shape((self.n_chains, self.n_params), x)?
240            .mapv(|x| x.to_f32().unwrap());
241
242        self.mean = (self.mean.clone() * (n - 1.0) + x_arr.clone()) / n;
243        if self.n == 1 {
244            self.mean_sq = x_arr.pow2();
245        } else {
246            self.mean_sq = (self.mean_sq.clone() * (n - 1.0) + (x_arr.pow2())) / n;
247        };
248
249        // Update self.p_accept and last state
250        self.p_accept = ndarray::Zip::from(x_arr.rows())
251            .and(self.last_state.rows())
252            .fold(self.p_accept, |p_accept, a, b| {
253                let accepted = (a.ne(&b) as i32) as f32;
254                (1.0 - ALPHA) * p_accept + ALPHA * accepted
255            });
256        self.last_state = x_arr;
257
258        Ok(())
259    }
260
261    pub fn stats<B: Backend>(&self, sample: Tensor<B, 3>) -> Result<RunStats, Box<dyn Error>> {
262        let sample_data = sample.to_data();
263        let sample_ndarray =
264            ArrayView3::from_shape(sample.dims(), sample_data.as_slice().unwrap())?;
265        Ok(RunStats::from_f32_view(sample_ndarray))
266    }
267
268    /// Computes the maximum R-hat value across all parameters.
269    ///
270    /// # Returns
271    /// The maximum R-hat value, or an error if computation fails.
272    pub fn max_rhat(&self) -> Result<f32, Box<dyn Error>> {
273        let all: Array1<f32> = self.rhat()?;
274        let max = *all.max()?;
275        Ok(max)
276    }
277
278    /// Computes the R-hat values for all parameters.
279    ///
280    /// # Returns
281    /// An array containing the R-hat values for each parameter, or an error if computation fails.
282    pub fn rhat(&self) -> Result<Array1<f32>, Box<dyn Error>> {
283        let (within, var) = self.within_and_var()?;
284        let rhat = (var / within).sqrt();
285        Ok(rhat)
286    }
287
288    fn within_and_var(&self) -> Result<(Array1<f32>, Array1<f32>), Box<dyn Error>> {
289        let mean_chain = self
290            .mean
291            .mean_axis(Axis(0))
292            .ok_or("Mean reduction across chains for mean failed.")?;
293        let n_chains = self.mean.shape()[0] as f32;
294        let n = self.n as f32;
295        let fac = n / (n_chains - 1.0);
296        let between = (self.mean.clone() - mean_chain.insert_axis(Axis(0)))
297            .pow2()
298            .sum_axis(Axis(0))
299            * fac;
300        let sm2 = (self.mean_sq.clone() - self.mean.pow2()) * n / (n - 1.0);
301        let within = sm2
302            .mean_axis(Axis(0))
303            .ok_or("Mean reduction across chains for mean of squares failed.")?;
304        let var = within.clone() * ((n - 1.0) / n) + between * (1.0 / n);
305        Ok((within, var))
306    }
307}
308
309/// Computes basic statistics from
310pub fn basic_stats(name: &str, mut data: Array1<f32>) -> BasicStats {
311    data.as_slice_mut()
312        .unwrap()
313        .sort_by(|a, b| match b.partial_cmp(a) {
314            Some(x) => x,
315            None => Ordering::Equal,
316        });
317    let (min, median, max) = (
318        *data
319            .last()
320            .expect("Expected getting first element from ess array succeed"),
321        data[data.len() / 2],
322        *data
323            .first()
324            .expect("Expected getting last element from ess array succeed"),
325    );
326    let mean = data.mean().expect("Expected computing mean ess to succeed");
327    let std = data.std(1.0);
328    BasicStats {
329        name: name.to_string(),
330        min,
331        median,
332        max,
333        mean,
334        std,
335    }
336}
337
/// Run-level convergence diagnostics: summaries of the per-parameter effective sample sizes
/// and split R-hat values.
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct RunStats {
    /// Summary of the per-parameter effective sample sizes.
    pub ess: BasicStats,
    /// Summary of the per-parameter split R-hat values.
    pub rhat: BasicStats,
}
343
344impl RunStats {
345    fn from_f32_view(sample: ArrayView3<f32>) -> Self {
346        let (rhat, ess) = split_rhat_mean_ess(sample);
347        let ess = basic_stats("ESS", ess);
348        let rhat = basic_stats("Split R-hat", rhat);
349        RunStats { ess, rhat }
350    }
351}
352
353impl fmt::Display for RunStats {
354    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
355        // Using the Display implementation of BasicStats
356        write!(f, "{}\n{}", self.ess, self.rhat)
357    }
358}
359
360impl<T> From<ArrayView3<'_, T>> for RunStats
361where
362    T: ToPrimitive + std::clone::Clone,
363{
364    fn from(sample: ArrayView3<T>) -> Self {
365        let f32_sample = sample.mapv(|x| x.to_f32().unwrap());
366        let (rhat, ess) = split_rhat_mean_ess(f32_sample.view());
367        let ess = basic_stats("ESS", ess);
368        let rhat = basic_stats("Split R-hat", rhat);
369        RunStats { ess, rhat }
370    }
371}
372
/// Summary statistics of a collection of values, labelled with `name`.
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct BasicStats {
    /// Label shown in the human-readable summary (e.g. `"ESS"`).
    pub name: String,
    /// Smallest value.
    pub min: f32,
    /// Median value.
    pub median: f32,
    /// Largest value.
    pub max: f32,
    /// Arithmetic mean.
    pub mean: f32,
    /// Standard deviation.
    pub std: f32,
}

impl fmt::Display for BasicStats {
    /// Formats as `<name> in [<min>, <max>], median: <median>, mean: <mean> ± <std>`,
    /// printing every number with two decimals.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            name,
            min,
            median,
            max,
            mean,
            std: spread,
        } = self;
        write!(
            f,
            "{} in [{:.2}, {:.2}], median: {:.2}, mean: {:.2} ± {:.2}",
            name, min, max, median, mean, spread
        )
    }
}
393
394/// Takes a (chains, observations, parameters) view and returns a new
395/// (2*chains, observations/2, parameters) array by splitting each chain in half.
396fn splitcat(sample: ArrayView3<f32>) -> Array3<f32> {
397    let n = sample.shape()[1];
398    let half = (n / 2) as i32;
399    let half_1 = sample.slice(s![.., ..half, ..]);
400    let half_2 = sample.slice(s![.., -half.., ..]);
401    concatenate(Axis(0), &[half_1, half_2]).expect("Expected stacking two halves to succeed")
402}
403
404/// Computes both split-R-hat and ESS metrics following STAN's methodology.
405///
406/// # Arguments
407/// - `sample`: 3D array of observations with shape (chains, observations, parameters).
408///
409/// # Returns
410/// A tuple containing:
411/// - An array of split-R-hat values for each parameter
412/// - An array of ESS values for each parameter
413///
414/// # References
415/// - STAN Reference Manual, Section on R-hat and Effective Sample Size
416pub fn split_rhat_mean_ess(sample: ArrayView3<f32>) -> (Array1<f32>, Array1<f32>) {
417    let splitted = splitcat(sample); // shape: (2c, n/2, p)
418    let (within, var) = withinvar(splitted.view());
419    (
420        rhat(within.view(), var.view()),
421        ess(splitted.view(), within.view(), var.view()),
422    )
423}
424
425fn rhat(within: ArrayView1<f32>, var: ArrayView1<f32>) -> Array1<f32> {
426    (within.to_owned() / var).sqrt()
427}
428
429fn withinvar(sample: ArrayView3<f32>) -> (Array1<f32>, Array1<f32>) {
430    let c = sample.shape()[0];
431    let n = sample.shape()[1];
432    let p = sample.shape()[2];
433
434    // 2) For each parameter, compute the chain-split stats
435    //    shape for data_p is (2c, n/2)
436    let (within, var): (Vec<f32>, Vec<f32>) = (0..p)
437        .into_par_iter()
438        .map(|param_idx| {
439            let data_p = sample.slice(s![.., .., param_idx]);
440
441            // chain means => shape (2c,)
442            let chain_means = data_p.mean_axis(Axis(1)).unwrap();
443            let overall_mean = chain_means.mean().unwrap();
444
445            // B => between chain var
446            //   = (n/2 / (2c - 1)) * sum( (chain_means - overall_mean)^2 )
447            // (We assume 2c > 1)
448            let diff = &chain_means - overall_mean;
449            let b = diff.pow2().sum() * ((n as f32) / ((c - 1) as f32));
450
451            // within => we broadcast chain_means to shape (2c, n/2) to subtract
452            // but an easier approach might be:
453            // squares[i] = mean over that chain's (x_i - chain_mean[i])^2
454            // then average squares across all chains
455            let mut squares = Vec::with_capacity(c);
456            for chain_i in 0..c {
457                let row = data_p.slice(s![chain_i, ..]);
458                let cm = chain_means[chain_i];
459                let sq = row.iter().map(|v| (v - cm) * (v - cm)).sum::<f32>() / (n as f32);
460                squares.push(sq);
461            }
462            let squares = Array1::from(squares); // shape (2c,)
463            let w = squares.mean().unwrap(); // within = average across chains
464                                             // var => ((n/2 - 1)/(n/2)) * w + b/(n/2)
465            let v = ((n as f32 - 1.0) / (n as f32)) * w + b / (n as f32);
466
467            (w, v)
468        })
469        .collect::<Vec<(f32, f32)>>()
470        .into_iter()
471        .fold((vec![], vec![]), |(mut within, mut var), (w, v)| {
472            within.push(w);
473            var.push(v);
474            (within, var)
475        });
476    (Array1::from_vec(within), Array1::from_vec(var))
477}
478
479/// Computes the Effective Sample Size (ESS) for each parameter.
480///
481/// This function implements the ESS calculation as described in the STAN documentation.
482/// It computes the ESS based on the autocorrelation of the chains and the ratio of
483/// within-chain to total variance.
484///
485/// # Arguments
486/// - `sample`: 3D array of observations with shape (chains, observations, parameters).
487/// - `within`: Array of within-chain variances for each parameter.
488/// - `var`: Array of total variances for each parameter.
489///
490/// # Returns
491/// An array containing the ESS for each parameter.
492///
493/// # References
494/// - STAN Reference Manual, Section on Effective Sample Size
495///   (https://mc-stan.org/docs/2_18/reference-manual/effective-sample-size-section.html)
496fn ess(sample: ArrayView3<f32>, within: ArrayView1<f32>, var: ArrayView1<f32>) -> Array1<f32> {
497    let shape = sample.shape();
498    let (n_chains, n_steps, n_params) = (shape[0], shape[1], shape[2]);
499    let chain_rho: Vec<Array2<f32>> = (0..n_chains)
500        .map(|c| {
501            let chain_sample = sample.index_axis(Axis(0), c);
502            autocov(chain_sample)
503        })
504        .collect();
505    let chain_rho: Vec<ArrayView2<f32>> = chain_rho.iter().map(|x| x.view()).collect();
506    let chain_rho = stack(Axis(0), &chain_rho)
507        .expect("Expected stacking chain-specific autocovariance matrices to succeed");
508    let avg_rho = chain_rho.mean_axis(Axis(0)).unwrap();
509    let diff = -avg_rho
510        + within
511            .broadcast((n_steps, n_params))
512            .expect("Expected broadcasting to succeed");
513    let rho = -(diff
514        / var
515            .broadcast((n_steps, n_params))
516            .expect("Expected broadcasting to succeed"))
517        + 1.0;
518    let tau: Vec<f32> = (0..n_params)
519        .into_par_iter()
520        .map(|d| {
521            let rho_d = rho.index_axis(Axis(1), d).to_owned();
522
523            let mut min = if rho_d.len() >= 2 {
524                rho_d[[0]] + rho_d[[1]]
525            } else {
526                0.0
527            };
528
529            let mut out = 0.0;
530            for rho_t in rho_d.windows_with_stride(2, 2) {
531                let mut p_t = rho_t[0] + rho_t[1];
532                if p_t <= 0.0 {
533                    break;
534                }
535                if p_t > min {
536                    p_t = min;
537                }
538                min = p_t;
539                out += p_t;
540            }
541            -1.0 + 2.0 * out
542        })
543        .collect();
544    let tau = Array1::from_vec(tau);
545    tau.recip() * n_chains as f32 * n_steps as f32
546}
547
548fn autocov(sample: ArrayView2<f32>) -> Array2<f32> {
549    if sample.nrows() <= 100 {
550        autocov_bf(sample)
551    } else {
552        autocov_fft(sample)
553    }
554}
555
556/// Compute the autocovariance of multiple sequences (each column represents a distinct sequence)
557/// using FFT for efficient calculation.
558///
559/// # Arguments
560///
561/// * `sample` - A 2-dimensional array view (`ArrayView2<f32>`) of shape `(n, d)`, where:
562///     - `n`: length of each sequence.
563///     - `d`: number of sequences (each column is treated independently).
564///
565/// # Returns
566///
567/// An `Array2<f32>` of shape `(n, d)` containing the autocovariance results.
568/// Each column contains the autocovariance values for the corresponding input sequence.
569///
570/// # Notes
571///
572/// * Uses zero-padding to avoid circular convolution effects (wrap-around).
573/// * FFT and inverse FFT are performed using the `rustfft` crate.
574/// * Computation is parallelized across sequences using Rayon.
575/// * Normalization (`1/n_padded`) is applied explicitly, as `rustfft` does not normalize results.
576fn autocov_fft(sample: ArrayView2<f32>) -> Array2<f32> {
577    let (n, d) = (sample.shape()[0], sample.shape()[1]);
578    let mut planner = FftPlanner::new();
579
580    // Next power of 2 >= 2*n - 1 for zero-padding to avoid wrap-around.
581    let mut n_padded = 1;
582    while n_padded < 2 * n - 1 {
583        n_padded <<= 1;
584    }
585    let fft = planner.plan_fft_forward(n_padded);
586    let ffti = planner.plan_fft_inverse(n_padded);
587    let out: Vec<f32> = sample
588        .axis_iter(Axis(1))
589        .into_par_iter()
590        .map(|traj| {
591            let traj_mean = traj.sum() / traj.len() as f32;
592            let mut x: Vec<Complex<f32>> = traj
593                .iter()
594                .map(|xi| Complex {
595                    re: (*xi - traj_mean),
596                    im: 0.0f32,
597                })
598                .chain(
599                    [Complex {
600                        re: 0.0f32,
601                        im: 0.0f32,
602                    }]
603                    .repeat(n_padded - n),
604                )
605                .collect();
606            fft.process(x.as_mut_slice());
607            x.iter_mut().for_each(|xi| {
608                *xi *= xi.conj();
609            });
610            ffti.process(x.as_mut_slice());
611            x.iter_mut()
612                .take(n)
613                .map(|xi| xi.re / n_padded as f32 / n as f32) // rustfft doens't normalize for us
614                .collect::<Vec<f32>>()
615        })
616        .flatten_iter()
617        .collect();
618    let out = Array2::from_shape_vec((d, n), out).expect("Expected creating dxn array to succeed");
619    out.t().to_owned()
620}
621
622/// Brute force autocovariance on a 2D array of shape (n, d).
623/// - `n` = number of time points (rows)
624/// - `d` = number of parameters (columns)
625///
626/// For each column `col` and each lag `lag` (0..n), the function
627/// computes:
628/// $$
629///    sum_{t=0..(n - lag - 1)} [ data[t, col] * data[t + lag, col] ]
630/// $$
631/// and stores it in `out[lag, col]`.
632fn autocov_bf(data: ArrayView2<f32>) -> Array2<f32> {
633    let (n, d) = data.dim();
634    let mut out = Array2::<f32>::zeros((n, d));
635
636    out.axis_iter_mut(Axis(1)) // mutable view of each column in `out`
637        .into_par_iter() // make it parallel
638        .enumerate() // get (col_index, col_view_mut)
639        .for_each(|(col_idx, mut out_col)| {
640            let col_data = data.column(col_idx);
641            let col_data = col_data.to_owned() - col_data.mean().unwrap();
642
643            // For each lag, compute sum_{t=0..(n-lag-1)} [ data[t, col] * data[t + lag, col] ]
644            for lag in 0..n {
645                let mut sum_lag = 0.0;
646                for t in 0..(n - lag) {
647                    sum_lag += col_data[t] * col_data[t + lag];
648                }
649                // Write result into the current column
650                out_col[lag] = sum_lag / n as f32;
651            }
652        });
653    out
654}
655
656/// Computes the Effective Sample Size (ESS) from chain statistics.
657/// We don't split chains here.
658///
659/// # Arguments
660/// - `sample`: 3D array of observations with shape (chains, observations, parameters).
661/// - `chain_stats`: Slice of references to `ChainStats` from multiple chains.
662///
663/// # Returns
664/// An array containing the ESS for each parameter.
665///
666/// # References
667/// - STAN Reference Manual, Section on Effective Sample Size
668pub fn ess_from_chainstats(sample: ArrayView3<f32>, chain_stats: &[&ChainStats]) -> Array1<f32> {
669    let (within, var) = withinvar_from_cs(chain_stats);
670    ess(sample, within.view(), var.view())
671}
672
#[cfg(test)]
mod tests {
    use std::io::Write;
    use std::{f32, fs::File, time::Instant};

    use approx::assert_abs_diff_eq;
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use super::*;

    // Feeds two steps of a 3-chain x 4-parameter run into a `MultiChainTracker` and checks that
    // every R-hat value is within `tol` of `expected`.
    fn run_rhat_test_generic<T>(data0: Array2<T>, data1: Array2<T>, expected: Array1<f32>, tol: f32)
    where
        T: ndarray::NdFloat + num_traits::FromPrimitive,
    {
        let mut psr = MultiChainTracker::new(3, 4);
        psr.step(data0.as_slice().unwrap()).unwrap();
        psr.step(data1.as_slice().unwrap()).unwrap();
        let rhat = psr.rhat().unwrap();
        let diff = *(rhat.clone() - expected.clone()).abs().max().unwrap();
        assert!(
            diff < tol,
            "Mismatch in Rhat. Got {:?}, expected {:?}, diff = {:?}",
            rhat,
            expected,
            diff
        );
    }

    #[test]
    fn test_rhat_f32_1() {
        // Step 0 data (chains x params)
        let data_step_0: Array2<f32> = arr2(&[
            [0.0, 1.0, 0.0, 1.0], // chain 0
            [1.0, 2.0, 0.0, 2.0], // chain 1
            [0.0, 0.0, 0.0, 2.0], // chain 2
        ]);

        // Step 1 data (chains x params)
        let data_step_1: Array2<f32> = arr2(&[
            [1.0, 2.0, 2.0, 0.0], // chain 0
            [1.0, 1.0, 1.0, 1.0], // chain 1
            [0.0, 1.0, 0.0, 0.0], // chain 2
        ]);
        let expected = array![f32::consts::SQRT_2, 1.080_123_4, 0.894_427_3, 0.8660254];
        run_rhat_test_generic(data_step_0, data_step_1, expected, f32::EPSILON * 10.0);
    }

    #[test]
    fn test_rhat_f64_1() {
        let data_step_0: Array2<f64> = arr2(&[
            [0.0, 1.0, 0.0, 1.0], // chain 0
            [1.0, 2.0, 0.0, 2.0], // chain 1
            [0.0, 0.0, 0.0, 2.0], // chain 2
        ]);
        let data_step_1: Array2<f64> = arr2(&[
            [1.0, 2.0, 2.0, 0.0], // chain 0
            [1.0, 1.0, 1.0, 1.0], // chain 1
            [0.0, 1.0, 0.0, 0.0], // chain 2
        ]);
        let expected = array![f32::consts::SQRT_2, 1.0801234, 0.8944271, 0.8660254];
        run_rhat_test_generic(data_step_0, data_step_1, expected, f32::EPSILON * 10.0);
    }

    #[test]
    fn test_rhat_f64_2() {
        let data_step_0 = arr2(&[
            [1.0, 0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0, 1.0],
            [0.0, 1.0, 0.0, 2.0],
        ]);
        let data_step_1 = arr2(&[
            [1.0, 2.0, 0.0, 2.0],
            [1.0, 2.0, 0.0, 0.0],
            [2.0, 0.0, 1.0, 2.0],
        ]);
        let expected = array![f32::consts::FRAC_1_SQRT_2, 0.74535599, 1.0, 1.5];
        run_rhat_test_generic(data_step_0, data_step_1, expected, f32::EPSILON * 10.0);
    }

    // Runs one autocovariance implementation on `data` and compares the result (shape and
    // values, up to 1e-6) with `expected`.
    fn run_test_case(
        autocov_func: &dyn Fn(ArrayView2<f32>) -> Array2<f32>,
        data: &Array2<f32>,
        expected: &Array2<f32>,
        test_name: &str,
    ) {
        let result = autocov_func(data.view());
        assert_eq!(
            result.dim(),
            expected.dim(),
            "{}: shape mismatch; got {:?}, expected {:?}",
            test_name,
            result.dim(),
            expected.dim()
        );

        assert_abs_diff_eq!(result, *expected, epsilon = 1e-6);
        println!("Test: {test_name} succeeded");
    }

    // ----------------------------------------------------------
    // Test: single parameter, small integer sequence
    // ----------------------------------------------------------
    #[test]
    fn test_single_param() {
        let data = array![[1.0], [2.0], [3.0], [4.0],];
        let expected = array![[1.25], [0.3125], [-0.375], [-0.5625]];

        // Compare brute force
        run_test_case(&autocov_bf, &data, &expected, "BF: single_param_small");

        println!("Doing FFT test");
        // Compare FFT-based
        run_test_case(&autocov_fft, &data, &expected, "FFT: single_param_small");
        println!("FFT test succeeded");
    }

    // ----------------------------------------------------------
    // Test: two parameters, 4 time points
    // ----------------------------------------------------------
    #[test]
    fn test_two_params_1() {
        let data = array![[1.0, 0.3], [2.0, 2.0], [3.0, -2.0], [4.0, 5.0],];
        let expected = array![
            [1.25, 6.516875],
            [0.3125, -3.7889063],
            [-0.375, 1.4721875],
            [-0.5625, -0.94171875],
        ];

        // Compare brute force
        run_test_case(&autocov_bf, &data, &expected, "BF: two_params_small");
        // Compare FFT-based
        run_test_case(&autocov_fft, &data, &expected, "FFT: two_params_small");
    }

    #[test]
    fn ess_1() {
        let m = 4;
        let n = 1000;

        // Zero-initialised m x n (4 x 1000) array: m chains of n draws each.
        let mut data = Array2::<f32>::zeros((m, n));

        // Fill every element with a draw from a seeded RNG so the test is reproducible.
        let mut rng = SmallRng::seed_from_u64(42);
        for mut row in data.rows_mut() {
            for elem in row.iter_mut() {
                *elem = rng.random::<f32>(); // generates uniform random number between 0 and 1
            }
        }
        // Reshape to (chains, observations, parameters) with a single parameter.
        let data = data
            .to_shape((data.shape()[0], data.shape()[1], 1))
            .unwrap();
        let run_stats = RunStats::from(data.view());

        println!("Samples: {}\n{run_stats}", m * n);

        // Independent draws: ESS should be close to m * n and R-hat close to 1.
        assert!(run_stats.ess.min > 3800.0);
        assert!(run_stats.rhat.max < 1.01);
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_autocov_perf_comp() {
        // Create output CSV
        let mut file =
            File::create("runtime_results.csv").expect("Unable to create runtime_results.csv");
        // Write header row
        writeln!(file, "length,rep,time,algorithm").expect("Unable to write CSV header");

        let mut rng = rand::rng();

        for exp in 0..10 {
            let n = 1 << exp; // 2^exp
            for rep in 1..=10 {
                // Generate random data of size (n x 1000): 1000 sequences of length n
                let sample_data: Vec<f32> = (0..n * 1000).map(|_| rng.random()).collect();
                let sample = Array2::from_shape_vec((n, 1000), sample_data)
                    .expect("Failed to create Array2");

                // Measure FFT-based implementation
                let start_fft = Instant::now();
                autocov_fft(sample.view());
                let fft_time = start_fft.elapsed().as_nanos();

                // Measure brute-force implementation
                let start_brute = Instant::now();
                autocov_bf(sample.view());
                let brute_time = start_brute.elapsed().as_nanos();

                // Log results to CSV
                writeln!(file, "{},{},{},fft", n, rep, fft_time)
                    .expect("Unable to write test results to CSV");
                writeln!(file, "{},{},{},brute force", n, rep, brute_time)
                    .expect("Unable to write test results to CSV");

                // Print results for convenience
                println!(
                    "Length: {} | Rep: {} | FFT: {} ns | Brute: {} ns",
                    n, rep, fft_time, brute_time
                );
            }
        }
    }
}