1use crate::diagnostics::{McmcDiagnosticSummary, McmcDiagnostics};
51use crate::error::{BayesError, Result};
52use crate::samplers::Sampler;
53use nalgebra::DVector;
54
55#[cfg_attr(feature = "serde", derive(serde::Serialize))]
57#[derive(Debug, Clone, PartialEq)]
58pub struct MultiChainOutput {
59 pub chains: Vec<Vec<DVector<f64>>>,
61 pub diagnostics: McmcDiagnostics,
63 pub summary: McmcDiagnosticSummary,
65}
66
67pub fn run_multiple_chains<S>(
88 samplers: &mut [S],
89 n_warmup: usize,
90 n_samples: usize,
91) -> Result<MultiChainOutput>
92where
93 S: Sampler,
94{
95 if samplers.len() < 2 {
96 return Err(BayesError::invalid_parameter(
97 "At least two samplers are required for multi-chain diagnostics",
98 ));
99 }
100
101 if n_samples < 2 {
102 return Err(BayesError::invalid_parameter(
103 "At least two samples per chain are required for multi-chain diagnostics",
104 ));
105 }
106
107 let chains: Vec<Vec<DVector<f64>>> = samplers
108 .iter_mut()
109 .map(|sampler| sampler.sample_with_warmup(n_warmup, n_samples))
110 .collect();
111 let diagnostics = McmcDiagnostics::from_multiple_chains(&chains)?;
112 let summary = diagnostics.summary();
113
114 Ok(MultiChainOutput {
115 chains,
116 diagnostics,
117 summary,
118 })
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double whose draws are the deterministic sequence `start + 1, start + 2, ...`.
    /// It counts the steps taken and the number of statistics resets.
    #[derive(Debug, Clone)]
    struct DeterministicSampler {
        state: DVector<f64>,
        steps: usize,
        resets: usize,
    }

    impl DeterministicSampler {
        fn new(start: f64) -> Self {
            Self {
                state: DVector::from_vec(vec![start]),
                steps: 0,
                resets: 0,
            }
        }
    }

    impl Sampler for DeterministicSampler {
        fn sample(&mut self, n_samples: usize) -> Vec<DVector<f64>> {
            let mut draws = Vec::with_capacity(n_samples);
            for _ in 0..n_samples {
                draws.push(self.step());
            }
            draws
        }

        fn step(&mut self) -> DVector<f64> {
            // Advance the single coordinate by one unit and record the step.
            self.state[0] += 1.0;
            self.steps += 1;
            self.state.clone()
        }

        fn current_state(&self) -> &DVector<f64> {
            &self.state
        }

        fn reset_statistics(&mut self) {
            self.resets += 1;
        }
    }

    /// Two samplers starting at 0 and 10, so their chains never overlap.
    fn two_samplers() -> [DeterministicSampler; 2] {
        [DeterministicSampler::new(0.0), DeterministicSampler::new(10.0)]
    }

    /// Builds a one-dimensional chain from plain scalar values.
    fn scalar_chain(values: &[f64]) -> Vec<DVector<f64>> {
        values.iter().map(|&v| DVector::from_vec(vec![v])).collect()
    }

    /// Asserts that a sampler was never stepped or reset and still sits at `start`.
    fn assert_untouched(sampler: &DeterministicSampler, start: f64) {
        assert_eq!(sampler.steps, 0);
        assert_eq!(sampler.resets, 0);
        assert_eq!(sampler.current_state(), &DVector::from_vec(vec![start]));
    }

    #[test]
    fn run_multiple_chains_returns_raw_chains_diagnostics_and_summary() {
        let mut pair = two_samplers();

        let result = run_multiple_chains(&mut pair, 2, 4).unwrap();

        // Two warmup steps are consumed first, so each kept chain begins at start + 3.
        let expected = vec![
            scalar_chain(&[3.0, 4.0, 5.0, 6.0]),
            scalar_chain(&[13.0, 14.0, 15.0, 16.0]),
        ];
        assert_eq!(result.chains, expected);
        assert!(result.diagnostics.r_hat.is_some());

        let params = &result.summary.parameters;
        assert_eq!(params.len(), 1);
        let first = &params[0];
        assert_eq!(first.parameter_index, 0);
        assert_eq!(
            first.effective_sample_size,
            result.diagnostics.effective_sample_size[0]
        );
        assert_eq!(first.mc_se, result.diagnostics.mc_se[0]);
        assert_eq!(
            first.r_hat,
            Some(result.diagnostics.r_hat.as_ref().unwrap()[0])
        );

        // Each sampler ran 2 warmup + 4 kept steps and was reset exactly once.
        for sampler in &pair {
            assert_eq!(sampler.steps, 6);
            assert_eq!(sampler.resets, 1);
        }
    }

    #[test]
    fn run_multiple_chains_rejects_empty_sampler_list() {
        let mut none: [DeterministicSampler; 0] = [];

        let outcome = run_multiple_chains(&mut none, 0, 4);

        assert!(outcome.is_err());
    }

    #[test]
    fn run_multiple_chains_rejects_single_sampler_without_mutating_it() {
        let mut lone = [DeterministicSampler::new(0.0)];

        assert!(run_multiple_chains(&mut lone, 2, 4).is_err());
        assert_untouched(&lone[0], 0.0);
    }

    #[test]
    fn run_multiple_chains_rejects_too_few_samples_without_mutating_samplers() {
        let mut pair = two_samplers();

        assert!(run_multiple_chains(&mut pair, 0, 1).is_err());
        assert_untouched(&pair[0], 0.0);
        assert_untouched(&pair[1], 10.0);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn multi_chain_output_serializes_to_json() {
        let mut pair = two_samplers();
        let result = run_multiple_chains(&mut pair, 2, 4).unwrap();

        let value = serde_json::to_value(&result).unwrap();
        let array_len = |node: &serde_json::Value| node.as_array().unwrap().len();

        assert_eq!(array_len(&value["chains"]), 2);
        assert_eq!(array_len(&value["summary"]["parameters"]), 1);
        assert_eq!(array_len(&value["diagnostics"]["mean"]), 1);
    }
}