Skip to main content

bayes_rs/
multi_chain.rs

1//! Helpers for running multiple MCMC chains and summarizing diagnostics.
2//!
3//! The primary entry point is [`run_multiple_chains`], which consumes a mutable
4//! slice of already-configured samplers, runs the same warmup/sample schedule for
5//! each chain, and returns both raw draws and [`McmcDiagnostics`]. Construct each
6//! sampler with its own seed (for samplers that support seeded constructors) to
7//! make multi-chain runs reproducible.
8//! At least two samplers and at least two retained samples per sampler are
9//! required because this helper always computes multi-chain diagnostics,
10//! including R-hat; these preconditions are checked before any sampler is run.
11//!
12//! ```rust
13//! use bayes_rs::{multi_chain::run_multiple_chains, samplers::Sampler};
14//! use nalgebra::DVector;
15//!
16//! struct CounterSampler {
17//!     state: DVector<f64>,
18//! }
19//!
20//! impl CounterSampler {
21//!     fn new(start: f64) -> Self {
22//!         Self { state: DVector::from_vec(vec![start]) }
23//!     }
24//! }
25//!
26//! impl Sampler for CounterSampler {
27//!     fn sample(&mut self, n_samples: usize) -> Vec<DVector<f64>> {
28//!         (0..n_samples).map(|_| self.step()).collect()
29//!     }
30//!
31//!     fn step(&mut self) -> DVector<f64> {
32//!         self.state[0] += 1.0;
33//!         self.state.clone()
34//!     }
35//!
36//!     fn current_state(&self) -> &DVector<f64> {
37//!         &self.state
38//!     }
39//! }
40//!
41//! let mut samplers = [CounterSampler::new(0.0), CounterSampler::new(1.0)];
42//! let output = run_multiple_chains(&mut samplers, 5, 10).unwrap();
43//!
44//! assert_eq!(output.chains.len(), 2);
45//! assert_eq!(output.chains[0].len(), 10);
46//! assert!(output.diagnostics.r_hat.is_some());
47//! assert_eq!(output.summary.parameters.len(), 1);
48//! ```
49
50use crate::diagnostics::{McmcDiagnosticSummary, McmcDiagnostics};
51use crate::error::{BayesError, Result};
52use crate::samplers::Sampler;
53use nalgebra::DVector;
54
55/// Raw multi-chain samples plus their diagnostics summary.
56#[cfg_attr(feature = "serde", derive(serde::Serialize))]
57#[derive(Debug, Clone, PartialEq)]
58pub struct MultiChainOutput {
59    /// Samples for each chain, preserving sampler order.
60    pub chains: Vec<Vec<DVector<f64>>>,
61    /// Full diagnostics computed from [`Self::chains`].
62    pub diagnostics: McmcDiagnostics,
63    /// Compact per-parameter summary of R-hat, ESS, and MCSE.
64    pub summary: McmcDiagnosticSummary,
65}
66
67/// Run multiple already-constructed samplers with a shared warmup/sample schedule.
68///
69/// Each sampler is run with [`Sampler::sample_with_warmup`], so sampler-level
70/// statistics are reset after warmup and describe only the returned samples. The returned chains
71/// are then passed to [`McmcDiagnostics::from_multiple_chains`], so the same
72/// validation rules apply: at least two samplers and at least two retained
73/// samples per chain are required for R-hat, and the generated chains must be
74/// non-empty, equal length, finite, and have consistent dimensions. The sampler
75/// count and retained sample count are checked before execution, so invalid
76/// inputs do not advance or otherwise mutate the samplers.
77///
78/// For reproducible stochastic runs, construct each sampler with an explicit and
79/// distinct seed before calling this helper.
80///
81/// # Errors
82///
83/// Returns [`BayesError::InvalidParameter`] if fewer than two samplers are
84/// provided or if fewer than two posterior samples are requested. Other
85/// validation or numerical errors are propagated from
86/// [`McmcDiagnostics::from_multiple_chains`].
87pub fn run_multiple_chains<S>(
88    samplers: &mut [S],
89    n_warmup: usize,
90    n_samples: usize,
91) -> Result<MultiChainOutput>
92where
93    S: Sampler,
94{
95    if samplers.len() < 2 {
96        return Err(BayesError::invalid_parameter(
97            "At least two samplers are required for multi-chain diagnostics",
98        ));
99    }
100
101    if n_samples < 2 {
102        return Err(BayesError::invalid_parameter(
103            "At least two samples per chain are required for multi-chain diagnostics",
104        ));
105    }
106
107    let chains: Vec<Vec<DVector<f64>>> = samplers
108        .iter_mut()
109        .map(|sampler| sampler.sample_with_warmup(n_warmup, n_samples))
110        .collect();
111    let diagnostics = McmcDiagnostics::from_multiple_chains(&chains)?;
112    let summary = diagnostics.summary();
113
114    Ok(MultiChainOutput {
115        chains,
116        diagnostics,
117        summary,
118    })
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double whose single coordinate grows by exactly 1.0 per step.
    ///
    /// It counts calls to `step` and `reset_statistics` so the tests can check
    /// how far `run_multiple_chains` advanced each sampler.
    #[derive(Debug, Clone)]
    struct DeterministicSampler {
        state: DVector<f64>,
        steps: usize,
        resets: usize,
    }

    impl DeterministicSampler {
        fn new(start: f64) -> Self {
            let state = scalar(start);
            Self {
                state,
                steps: 0,
                resets: 0,
            }
        }
    }

    impl Sampler for DeterministicSampler {
        fn sample(&mut self, n_samples: usize) -> Vec<DVector<f64>> {
            let mut draws = Vec::with_capacity(n_samples);
            for _ in 0..n_samples {
                draws.push(self.step());
            }
            draws
        }

        fn step(&mut self) -> DVector<f64> {
            self.state[0] += 1.0;
            self.steps += 1;
            self.state.clone()
        }

        fn current_state(&self) -> &DVector<f64> {
            &self.state
        }

        fn reset_statistics(&mut self) {
            self.resets += 1;
        }
    }

    /// Builds a one-dimensional vector holding `value`.
    fn scalar(value: f64) -> DVector<f64> {
        DVector::from_vec(vec![value])
    }

    /// Builds a chain of one-dimensional draws from plain numbers.
    fn chain_of(values: &[f64]) -> Vec<DVector<f64>> {
        values.iter().copied().map(scalar).collect()
    }

    /// Two fresh samplers starting at 0.0 and 10.0.
    fn sampler_pair() -> [DeterministicSampler; 2] {
        [
            DeterministicSampler::new(0.0),
            DeterministicSampler::new(10.0),
        ]
    }

    /// Asserts that `sampler` was never stepped or reset and still sits at `start`.
    fn assert_untouched(sampler: &DeterministicSampler, start: f64) {
        assert_eq!(sampler.steps, 0);
        assert_eq!(sampler.resets, 0);
        assert_eq!(sampler.current_state(), &scalar(start));
    }

    #[test]
    fn run_multiple_chains_returns_raw_chains_diagnostics_and_summary() {
        let mut samplers = sampler_pair();

        let output = run_multiple_chains(&mut samplers, 2, 4).unwrap();

        // Two warmup steps come first, so the returned draws start three above
        // each sampler's starting value.
        let expected_chains = vec![
            chain_of(&[3.0, 4.0, 5.0, 6.0]),
            chain_of(&[13.0, 14.0, 15.0, 16.0]),
        ];
        assert_eq!(output.chains, expected_chains);

        assert!(output.diagnostics.r_hat.is_some());
        assert_eq!(output.summary.parameters.len(), 1);

        // The summary row must mirror the full diagnostics for parameter 0.
        let row = &output.summary.parameters[0];
        let full = &output.diagnostics;
        assert_eq!(row.parameter_index, 0);
        assert_eq!(row.effective_sample_size, full.effective_sample_size[0]);
        assert_eq!(row.mc_se, full.mc_se[0]);
        assert_eq!(row.r_hat, Some(full.r_hat.as_ref().unwrap()[0]));

        // 2 warmup + 4 retained steps per sampler, with a single statistics reset.
        for sampler in &samplers {
            assert_eq!(sampler.steps, 6);
            assert_eq!(sampler.resets, 1);
        }
    }

    #[test]
    fn run_multiple_chains_rejects_empty_sampler_list() {
        let mut samplers: [DeterministicSampler; 0] = [];

        let outcome = run_multiple_chains(&mut samplers, 0, 4);

        assert!(outcome.is_err());
    }

    #[test]
    fn run_multiple_chains_rejects_single_sampler_without_mutating_it() {
        let mut samplers = [DeterministicSampler::new(0.0)];

        let outcome = run_multiple_chains(&mut samplers, 2, 4);

        assert!(outcome.is_err());
        assert_untouched(&samplers[0], 0.0);
    }

    #[test]
    fn run_multiple_chains_rejects_too_few_samples_without_mutating_samplers() {
        let mut samplers = sampler_pair();

        let outcome = run_multiple_chains(&mut samplers, 0, 1);

        assert!(outcome.is_err());
        assert_untouched(&samplers[0], 0.0);
        assert_untouched(&samplers[1], 10.0);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn multi_chain_output_serializes_to_json() {
        let mut samplers = sampler_pair();
        let output = run_multiple_chains(&mut samplers, 2, 4).unwrap();

        let json = serde_json::to_value(&output).unwrap();

        let array_len = |value: &serde_json::Value| value.as_array().unwrap().len();
        assert_eq!(array_len(&json["chains"]), 2);
        assert_eq!(array_len(&json["summary"]["parameters"]), 1);
        assert_eq!(array_len(&json["diagnostics"]["mean"]), 1);
    }
}