mini_mcmc/core.rs
/*!
# Core MCMC Utilities

This module provides core functionality for running Markov Chain Monte Carlo (MCMC) chains in parallel.
It includes:
- The [`MarkovChain<T>`] trait, which abstracts a single MCMC chain.
- Utility functions [`run_chain`] and [`run_chain_progress`] for executing a single chain and collecting its states.
- The [`HasChains<T>`] trait for types that own multiple Markov chains.
- The [`ChainRunner<T>`] trait that extends [`HasChains<T>`] with methods to run chains in parallel (using Rayon), discarding burn-in and optionally displaying progress bars.

Any type implementing [`HasChains<T>`] (with the required trait bounds) automatically implements [`ChainRunner<T>`] via a blanket implementation.

This module is generic over the state type using [`ndarray::LinalgScalar`].
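
# Example

A minimal end-to-end sketch: the toy `CounterChain` and `CounterSampler` types below are purely
illustrative, and the import path assumes this module is reachable as `mini_mcmc::core`.

```ignore
use mini_mcmc::core::{ChainRunner, HasChains, MarkovChain};

// A deterministic toy chain: each step increments its single parameter by one.
struct CounterChain {
    state: Vec<f64>,
}

impl MarkovChain<f64> for CounterChain {
    fn step(&mut self) -> &Vec<f64> {
        self.state[0] += 1.0;
        &self.state
    }

    fn current_state(&self) -> &Vec<f64> {
        &self.state
    }
}

// A sampler owning several chains; `ChainRunner` is available via the blanket implementation.
struct CounterSampler {
    chains: Vec<CounterChain>,
}

impl HasChains<f64> for CounterSampler {
    type Chain = CounterChain;

    fn chains_mut(&mut self) -> &mut Vec<Self::Chain> {
        &mut self.chains
    }
}

let mut sampler = CounterSampler {
    chains: (0..4).map(|_| CounterChain { state: vec![0.0] }).collect(),
};

// Keep 100 samples per chain after discarding 50 burn-in steps.
let samples = sampler.run(100, 50).unwrap();
assert_eq!(samples.shape(), &[4, 100, 1]);
```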
*/

use crate::stats::{collect_rhat, ChainStats, ChainTracker};
use indicatif::ProgressBar;
use indicatif::{MultiProgress, ProgressStyle};
use ndarray::stack;
use ndarray::{prelude::*, LinalgScalar, ShapeError};
use ndarray_stats::QuantileExt;
use rayon::prelude::*;
use std::cmp::PartialEq;
use std::error::Error;
use std::marker::Send;
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread::{self};
use std::time::{Duration, Instant};

/// A trait that abstracts a single MCMC chain.
///
/// A type implementing [`MarkovChain<T>`] must provide:
/// - `step()`: advances the chain one iteration and returns a reference to the updated state.
/// - `current_state()`: returns a reference to the current state without modifying the chain.
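///
/// # Example
///
/// A minimal sketch of a (purely illustrative) deterministic toy chain:
///
/// ```ignore
/// struct CounterChain {
///     state: Vec<f64>,
/// }
///
/// impl MarkovChain<f64> for CounterChain {
///     fn step(&mut self) -> &Vec<f64> {
///         self.state[0] += 1.0;
///         &self.state
///     }
///
///     fn current_state(&self) -> &Vec<f64> {
///         &self.state
///     }
/// }
/// ```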
pub trait MarkovChain<T> {
    /// Performs one iteration of the chain and returns a reference to the new state.
    fn step(&mut self) -> &Vec<T>;

    /// Returns a reference to the current state of the chain without advancing it.
    fn current_state(&self) -> &Vec<T>;
}

/// Runs a single MCMC chain for a specified number of steps.
///
/// This function repeatedly calls the chain's `step()` method and collects each state into a
/// [`ndarray::Array2<T>`], where each row corresponds to one collected state of the chain.
///
/// # Arguments
///
/// * `chain` - A mutable reference to an object implementing [`MarkovChain<T>`].
/// * `n_collect` - The number of samples to collect and return.
/// * `n_discard` - The number of samples to discard (burn-in).
///
/// # Returns
///
/// A [`ndarray::Array2<T>`] where the number of rows equals `n_collect` and the number of columns equals
/// the dimensionality of the chain's state.
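///
/// # Example
///
/// A sketch assuming `chain` implements [`MarkovChain<f64>`] with a two-dimensional state:
///
/// ```ignore
/// // Keep 1_000 samples after discarding 100 burn-in steps.
/// let samples = run_chain(&mut chain, 1_000, 100);
/// assert_eq!(samples.shape(), &[1_000, 2]);
/// ```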
pub fn run_chain<T, M>(chain: &mut M, n_collect: usize, n_discard: usize) -> Array2<T>
where
    M: MarkovChain<T>,
    T: LinalgScalar,
{
    let dim = chain.current_state().len();
    let mut out = Array2::<T>::zeros((n_collect, dim));
    let total = n_collect + n_discard;

    for i in 0..total {
        let state = chain.step();
        if i >= n_discard {
            let state_arr = ArrayView::from_shape(state.len(), state.as_slice()).unwrap();
            out.row_mut(i - n_discard).assign(&state_arr);
        }
    }

    out
}

/// Runs a single MCMC chain for `n_collect + n_discard` steps while reporting progress.
///
/// This function is similar to [`run_chain`], but it periodically sends [`ChainStats`]
/// snapshots over the provided channel so that a parent thread can display progress.
///
/// # Arguments
///
/// * `chain` - A mutable reference to an object implementing [`MarkovChain<T>`].
/// * `n_collect` - The number of samples to collect and return.
/// * `n_discard` - The number of samples to discard (burn-in).
/// * `tx` - A [`Sender<ChainStats>`] used to report progress to the parent thread managing the chains.
///
/// # Returns
///
/// A [`ndarray::Array2<T>`] containing the collected states, with `n_collect` rows. Returns an
/// error message if the chain statistics tracker fails.
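///
/// # Example
///
/// A sketch assuming `chain` implements [`MarkovChain<f64>`]; the receiver is drained after the
/// run here, but it would normally be polled by the thread managing the chains:
///
/// ```ignore
/// use std::sync::mpsc;
///
/// let (tx, rx) = mpsc::channel();
/// let samples = run_chain_progress(&mut chain, 1_000, 500, tx).unwrap();
/// for stats in rx.try_iter() {
///     println!("steps completed: {}", stats.n);
/// }
/// ```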
pub fn run_chain_progress<T, M>(
    chain: &mut M,
    n_collect: usize,
    n_discard: usize,
    tx: Sender<ChainStats>,
) -> Result<Array2<T>, String>
where
    M: MarkovChain<T>,
    T: LinalgScalar + PartialEq + num_traits::ToPrimitive,
{
    let n_params = chain.current_state().len();
    let mut out = Array2::<T>::zeros((n_collect, n_params));

    let mut tracker = ChainTracker::new(n_params, chain.current_state());
    let mut last = Instant::now();
    let freq = Duration::from_secs(1);
    let total = n_discard + n_collect;

    for i in 0..total {
        let current_state = chain.step();
        tracker.step(current_state).map_err(|e| {
            let msg = format!(
                "Chain statistics tracker caused error: {}.\nAborting generation of further samples.",
                e
            );
            println!("{}", msg);
            msg
        })?;

        let now = Instant::now();
        if (now >= last + freq) || (i == total - 1) {
            if let Err(e) = tx.send(tracker.stats()) {
                eprintln!("Sending chain statistics failed: {e}");
            }
            last = now;
        }

        if i >= n_discard {
            out.row_mut(i - n_discard).assign(
                &ArrayView1::from_shape(current_state.len(), current_state.as_slice()).unwrap(),
            );
        }
    }

    // TODO: Somehow save state of the chains and enable continuing runs
    Ok(out)
}

/// A trait for types that own multiple MCMC chains.
///
/// - `T` is the type of the state elements (e.g., `f64`).
/// - `Chain` is the concrete type of the individual chain, which must implement [`MarkovChain<T>`]
///   and be [`Send`].
///
/// Implementors must provide a method to access the internal vector of chains.
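///
/// # Example
///
/// A sketch of a sampler owning several chains, where `CounterChain` is the toy chain from the
/// [`MarkovChain`] example:
///
/// ```ignore
/// struct CounterSampler {
///     chains: Vec<CounterChain>,
/// }
///
/// impl HasChains<f64> for CounterSampler {
///     type Chain = CounterChain;
///
///     fn chains_mut(&mut self) -> &mut Vec<Self::Chain> {
///         &mut self.chains
///     }
/// }
/// ```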
pub trait HasChains<T> {
    type Chain: MarkovChain<T> + Send;

    /// Returns a mutable reference to the vector of chains.
    fn chains_mut(&mut self) -> &mut Vec<Self::Chain>;
}

/// An extension trait for types that own multiple MCMC chains.
///
/// [`ChainRunner<T>`] extends [`HasChains<T>`] by providing default methods to run all chains
/// in parallel. These methods allow you to:
/// - Run all chains, collect `n_collect` samples and discard `n_discard` initial burn-in samples.
/// - Optionally display progress bars for each chain during execution.
///
/// Any type that implements [`HasChains<T>`] (with appropriate bounds on `T`) automatically implements
/// [`ChainRunner<T>`].
pub trait ChainRunner<T>: HasChains<T>
where
    T: LinalgScalar + PartialEq + Send + num_traits::ToPrimitive,
{
    /// Runs all chains in parallel, discarding the first `n_discard` iterations (burn-in).
    ///
    /// # Arguments
    ///
    /// * `n_collect` - The number of samples to collect and return.
    /// * `n_discard` - The number of samples to discard (burn-in).
    ///
    /// # Returns
    ///
    /// A [`ndarray::Array3`] of shape `(n_chains, n_collect, n_params)`: the first axis indexes
    /// the chain, the second the step, and the third the parameter dimension.
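    ///
    /// # Example
    ///
    /// A sketch assuming `sampler` implements [`HasChains<f64>`] (see the module-level example):
    ///
    /// ```ignore
    /// use ndarray::Axis;
    ///
    /// // Keep 1_000 samples per chain after 500 burn-in steps.
    /// let samples = sampler.run(1_000, 500).unwrap();
    /// // Samples of the first chain, with shape (n_collect, n_params).
    /// let first_chain = samples.index_axis(Axis(0), 0);
    /// ```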
    fn run(&mut self, n_collect: usize, n_discard: usize) -> Result<Array3<T>, ShapeError> {
        // Run them all in parallel
        let results: Vec<Array2<T>> = self
            .chains_mut()
            .par_iter_mut()
            .map(|chain| run_chain(chain, n_collect, n_discard))
            .collect();
        let views: Vec<ArrayView2<T>> = results.iter().map(|x| x.view()).collect();
        let out: Array3<T> = stack(Axis(0), &views)?;
        Ok(out)
    }

    /// Runs all chains in parallel with progress bars, discarding the burn-in.
    ///
    /// Each chain runs in its own thread and reports statistics to a dedicated progress thread,
    /// which renders per-chain progress bars (up to five at a time) and a global bar showing
    /// overall progress, the average acceptance probability, and the maximum R-hat across chains.
    /// The first `n_discard` iterations of each chain are discarded.
    ///
    /// # Arguments
    ///
    /// * `n_collect` - The number of samples to collect and return.
    /// * `n_discard` - The number of samples to discard (burn-in).
    ///
    /// # Returns
    ///
    /// A [`ndarray::Array3`] of shape `(n_chains, n_collect, n_params)`: the first axis indexes
    /// the chain, the second the step, and the third the parameter dimension.
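    ///
    /// # Example
    ///
    /// A sketch assuming `sampler` implements [`HasChains<f64>`] (see the module-level example):
    ///
    /// ```ignore
    /// // Identical to `run`, but renders progress bars while sampling.
    /// let samples = sampler.run_progress(1_000, 500).unwrap();
    /// ```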
    fn run_progress(
        &mut self,
        n_collect: usize,
        n_discard: usize,
    ) -> Result<Array3<T>, Box<dyn Error>> {
        // Channels: each chain gets its own channel, so there are `n_chains` channels in total.
        // Each worker thread periodically sends a `ChainStats` snapshot to the progress thread,
        // which assembles the per-chain statistics it receives to compute R-hat.

        let chains = self.chains_mut();

        let mut rxs: Vec<Receiver<ChainStats>> = vec![];
        let mut txs: Vec<Sender<ChainStats>> = vec![];
        (0..chains.len()).for_each(|_| {
            let (tx, rx) = mpsc::channel();
            rxs.push(rx);
            txs.push(tx);
        });

        let progress_handle = thread::spawn(move || {
            let sleep_ms = Duration::from_millis(250);
            let timeout_ms = Duration::from_millis(0);
            let multi = MultiProgress::new();

            let pb_style = ProgressStyle::default_bar()
                .template("{prefix:8} {bar:40.white} ETA {eta:3} | {msg}")
                .unwrap()
                .progress_chars("=>-");
            let total: u64 = (n_collect + n_discard).try_into().unwrap();

            // Global progress bar
            let global_pb = multi.add(ProgressBar::new((rxs.len() as u64) * total));
            global_pb.set_style(pb_style.clone());
            global_pb.set_prefix("Global");

            let mut active: Vec<(usize, ProgressBar)> = (0..rxs.len().min(5))
                .map(|chain_idx| {
                    let pb = multi.add(ProgressBar::new(total));
                    pb.set_style(pb_style.clone());
                    pb.set_prefix(format!("Chain {chain_idx}"));
                    (chain_idx, pb)
                })
                .collect();
            let mut next_active = active.len();
            let mut n_finished = 0;
            let mut most_recent = vec![None; rxs.len()];
            let mut total_progress;

            loop {
                for (i, rx) in rxs.iter().enumerate() {
                    while let Ok(stats) = rx.recv_timeout(timeout_ms) {
                        most_recent[i] = Some(stats)
                    }
                }

                // Update chain progress bar messages
                // and compute average acceptance probability
                let mut to_replace = vec![false; active.len()];
                let mut avg_p_accept = 0.0;
                let mut n_available_stats = 0.0;
                for (vec_idx, (i, pb)) in active.iter().enumerate() {
                    if let Some(stats) = &most_recent[*i] {
                        pb.set_position(stats.n);
                        pb.set_message(format!("p(accept)≈{:.2}", stats.p_accept));
                        avg_p_accept += stats.p_accept;
                        n_available_stats += 1.0;

                        if stats.n == total {
                            to_replace[vec_idx] = true;
                            n_finished += 1;
                        }
                    }
                }
                avg_p_accept /= n_available_stats;

                // Update global progress bar
                total_progress = 0;
                for stats in most_recent.iter().flatten() {
                    total_progress += stats.n;
                }
                global_pb.set_position(total_progress);
                let valid: Vec<&ChainStats> = most_recent.iter().flatten().collect();
                if valid.len() >= 2 {
                    let rhats = collect_rhat(valid.as_slice());
                    let max = rhats.max_skipnan();
                    global_pb.set_message(format!(
                        "p(accept)≈{:.2} max(rhat)≈{:.2}",
                        avg_p_accept, max
                    ))
                }

                let mut to_remove = vec![];
                for (i, replace) in to_replace.iter().enumerate() {
                    if *replace && next_active < most_recent.len() {
                        let pb = multi.add(ProgressBar::new(total));
                        pb.set_style(pb_style.clone());
                        pb.set_prefix(format!("Chain {next_active}"));
                        active[i] = (next_active, pb);
                        next_active += 1;
                    } else if *replace {
                        to_remove.push(i);
                    }
                }

                to_remove.sort();
                for i in to_remove.iter().rev() {
                    active.remove(*i);
                }

                if n_finished >= most_recent.len() {
                    break;
                }
                std::thread::sleep(sleep_ms);
            }
        });

        let samples: Vec<Array2<T>> = thread::scope(|s| {
            let handles: Vec<thread::ScopedJoinHandle<Array2<T>>> = chains
                .iter_mut()
                .zip(txs)
                .map(|(chain, tx)| {
                    s.spawn(|| {
                        run_chain_progress(chain, n_collect, n_discard, tx)
                            .expect("Expected running chain to succeed.")
                    })
                })
                .collect();
            handles
                .into_iter()
                .map(|h| {
                    h.join()
                        .expect("Expected thread to succeed in generating sample.")
                })
                .collect()
        });
        let out: Array3<T> = stack(
            Axis(0),
            &samples
                .iter()
                .map(|x| x.view())
                .collect::<Vec<ArrayView2<T>>>(),
        )?;

        if let Err(e) = progress_handle.join() {
            eprintln!("Progress bar thread emitted error message: {:?}", e);
        }
        Ok(out)
    }
}

impl<T: LinalgScalar + Send + PartialEq + num_traits::ToPrimitive, R: HasChains<T>> ChainRunner<T>
    for R
{
}