// mini_mcmc/core.rs
1/*!
2# Core MCMC Utilities.
3
4This module provides core functionality for running Markov Chain Monte Carlo (MCMC) chains in parallel.
5It includes:
6- The [`MarkovChain<T>`] trait, which abstracts a single MCMC chain.
7- Utility functions [`run_chain`] and [`run_chain_progress`] for executing a single chain and collecting its states.
8- The [`HasChains<T>`] trait for types that own multiple Markov chains.
9- The [`ChainRunner<T>`] trait that extends [`HasChains<T>`] with methods to run chains in parallel (using Rayon), discarding burn-in and optionally displaying progress bars.
10
11Any type implementing [`HasChains<T>`] (with the required trait bounds) automatically implements [`ChainRunner<T>`] via a blanket implementation.
12
13This module is generic over the state type using [`ndarray::LinalgScalar`].
14*/
15
16use crate::stats::{collect_rhat, ChainStats, ChainTracker, RunStats};
17use indicatif::ProgressBar;
18use indicatif::{MultiProgress, ProgressStyle};
19use ndarray::stack;
20use ndarray::{prelude::*, LinalgScalar, ShapeError};
21use ndarray_stats::QuantileExt;
22use num_traits::{Float, FromPrimitive};
23use rand::rngs::SmallRng;
24use rand::SeedableRng;
25use rand_distr::{Distribution, StandardNormal};
26use rayon::prelude::*;
27use std::cmp::PartialEq;
28use std::error::Error;
29use std::marker::Send;
30use std::sync::mpsc::{self, Receiver, Sender};
31use std::thread::{self};
32use std::time::{Duration, Instant};
33
/// A trait that abstracts a single MCMC chain.
///
/// A type implementing [`MarkovChain<T>`] must provide:
/// - `step()`: advances the chain one iteration and returns a reference to the updated state.
/// - `current_state()`: returns a reference to the current state without modifying the chain.
pub trait MarkovChain<T> {
    /// Performs one iteration of the chain and returns a reference to the new state.
    ///
    /// NOTE(review): returning `&Vec<T>` forces implementors to store the state in a
    /// `Vec`; `&[T]` would be more flexible, but changing it here would break every
    /// existing implementor and caller.
    fn step(&mut self) -> &Vec<T>;

    /// Returns a reference to the current state of the chain without advancing it.
    fn current_state(&self) -> &Vec<T>;
}
46
47/// Runs a single MCMC chain for a specified number of steps.
48///
49/// This function repeatedly calls the chain's `step()` method and collects each state into a
50/// [`ndarray::Array2<T>`] of shape `[n_collect, D]` where:
51/// - `n_collect`: number of samples to collect
52/// - `D`: dimensionality of the state space
53///
54/// Each row corresponds to one collected state of the chain.
55pub fn run_chain<T, M>(chain: &mut M, n_collect: usize, n_discard: usize) -> Array2<T>
56where
57 M: MarkovChain<T>,
58 T: LinalgScalar,
59{
60 let dim = chain.current_state().len();
61 let mut out = Array2::<T>::zeros((n_collect, dim));
62 let total = n_collect + n_discard;
63
64 for i in 0..total {
65 let state = chain.step();
66 if i >= n_discard {
67 let state_arr = ArrayView::from_shape(state.len(), state.as_slice()).unwrap();
68 out.row_mut(i - n_discard).assign(&state_arr);
69 }
70 }
71
72 out
73}
74
75/// Runs a single MCMC chain for a `n_collect` + `n_discard` number of steps while displaying progress.
76///
77/// This function is similar to [`run_chain`], but it accepts an [`indicatif::ProgressBar`]
78/// that is updated as the chain advances.
79///
80/// # Arguments
81///
82/// * `chain` - A mutable reference to an object implementing [`MarkovChain<T>`].
83/// * `n_collect` - The number of samples to collect and return.
84/// * `n_discard` - The number of samples to discard (burn-in).
85/// * `tx` - A [`Sender<ChainStats>`] object for communication with chains-managing parent thread.
86///
87/// # Returns
88///
89/// A [`ndarray::Array2<T>`] containing the chain's states with `n_collect` number of rows.
90pub fn run_chain_progress<T, M>(
91 chain: &mut M,
92 n_collect: usize,
93 n_discard: usize,
94 tx: Sender<ChainStats>,
95) -> Result<Array2<T>, String>
96where
97 M: MarkovChain<T>,
98 T: LinalgScalar + PartialEq + num_traits::ToPrimitive,
99{
100 let n_params = chain.current_state().len();
101 let mut out = Array2::<T>::zeros((n_collect, n_params));
102
103 let mut tracker = ChainTracker::new(n_params, chain.current_state());
104 let mut last = Instant::now();
105 let freq = Duration::from_secs(1);
106 let total = n_discard + n_collect;
107
108 for i in 0..total {
109 let current_state = chain.step();
110 tracker.step(current_state).map_err(|e| {
111 let msg = format!(
112 "Chain statistics tracker caused error: {}.\nAborting generation of further samples.",
113 e
114 );
115 println!("{}", msg);
116 msg
117 })?;
118
119 let now = Instant::now();
120 if (now >= last + freq) | (i == total - 1) {
121 if let Err(e) = tx.send(tracker.stats()) {
122 eprintln!("Sending chain statistics failed: {e}");
123 }
124 last = now;
125 }
126
127 if i >= n_discard {
128 out.row_mut(i - n_discard).assign(
129 &ArrayView1::from_shape(current_state.len(), current_state.as_slice()).unwrap(),
130 );
131 }
132 }
133
134 // TODO: Somehow save state of the chains and enable continuing runs
135 Ok(out)
136}
137
138/// A trait for types that own multiple MCMC chains.
139///
140/// - `T` is the type of the state elements (e.g., `f64`).
141/// - `Chain` is the concrete type of the individual chain, which must implement [`MarkovChain<T>`]
142/// and be [`Send`].
143///
144/// Implementors must provide a method to access the internal vector of chains.
145pub trait HasChains<S> {
146 type Chain: MarkovChain<S> + Send;
147
148 /// Returns a mutable reference to the vector of chains.
149 fn chains_mut(&mut self) -> &mut Vec<Self::Chain>;
150}
151
152/// An extension trait for types that own multiple MCMC chains.
153///
154/// [`ChainRunner<T>`] extends [`HasChains<T>`] by providing default methods to run all chains
155/// in parallel. These methods allow you to:
156/// - Run all chains, collect `n_collect` samples and discard `n_discard` initial burn-in samples.
157/// - Optionally display progress bars for each chain during execution.
158///
159/// Any type that implements [`HasChains<T>`] (with appropriate bounds on `T`) automatically implements
160/// [`ChainRunner<T>`].
161pub trait ChainRunner<T>: HasChains<T>
162where
163 T: LinalgScalar + PartialEq + Send + num_traits::ToPrimitive,
164{
165 /// Runs all chains in parallel, discarding the first `discard` iterations (burn-in).
166 ///
167 /// # Arguments
168 ///
169 /// * `n_collect` - The number of samples to collect and return.
170 /// * `n_discard` - The number of samples to discard (burn-in).
171 ///
172 /// # Returns
173 ///
174 /// A [`ndarray::Array3`] tensor with the first axis representing the chain, the second one the
175 /// step and the last one the parameter dimension.
176 fn run(&mut self, n_collect: usize, n_discard: usize) -> Result<Array3<T>, ShapeError> {
177 // Run them all in parallel
178 let results: Vec<Array2<T>> = self
179 .chains_mut()
180 .par_iter_mut()
181 .map(|chain| run_chain(chain, n_collect, n_discard))
182 .collect();
183 let views: Vec<ArrayView2<T>> = results.iter().map(|x| x.view()).collect();
184 let out: Array3<T> = stack(Axis(0), &views)?;
185 Ok(out)
186 }
187
188 /// Runs all chains in parallel with progress bars, discarding the burn-in.
189 ///
190 /// Each chain is run in parallel with its own progress bar. After execution, the first `discard`
191 /// iterations are discarded.
192 ///
193 /// # Arguments
194 ///
195 /// * `n_collect` - The number of samples to collect and return.
196 /// * `n_discard` - The number of samples to discard (burn-in).
197 ///
198 /// # Returns
199 ///
200 /// Returns a tuple containing:
201 /// - A [`ndarray::Array3`] tensor with the first axis representing the chain, the second one the
202 /// step and the last one the parameter dimension.
203 /// - A `RunStats` object containing convergence statistics including:
204 /// - Acceptance probability
205 /// - Potential scale reduction factor (R-hat)
206 /// - Effective sample size (ESS)
207 /// - Other convergence diagnostics
208 fn run_progress(
209 &mut self,
210 n_collect: usize,
211 n_discard: usize,
212 ) -> Result<(Array3<T>, RunStats), Box<dyn Error>> {
213 // Channels.
214 // Each chain gets its own channel. Hence, we have `n_chains` channels.
215 // The objects sent over channels are Array2<f32>s ($s_m^2$, $\bar{\theta}_m^{(\bullet)}$).
216 // The child thread sends it's respective one to the parent thread.
217 // The parent thread assemples the tuples it receives to compute Rhat.
218
219 let chains = self.chains_mut();
220
221 let mut rxs: Vec<Receiver<ChainStats>> = vec![];
222 let mut txs: Vec<Sender<ChainStats>> = vec![];
223 (0..chains.len()).for_each(|_| {
224 let (tx, rx) = mpsc::channel();
225 rxs.push(rx);
226 txs.push(tx);
227 });
228
229 let progress_handle = thread::spawn(move || {
230 let sleep_ms = Duration::from_millis(250);
231 let timeout_ms = Duration::from_millis(0);
232 let multi = MultiProgress::new();
233
234 let pb_style = ProgressStyle::default_bar()
235 .template("{prefix:8} {bar:40.cyan/blue} {pos}/{len} ({eta}) | {msg}")
236 .unwrap()
237 .progress_chars("=>-");
238 let total: u64 = (n_collect + n_discard).try_into().unwrap();
239
240 // Global Progress bar
241 let global_pb = multi.add(ProgressBar::new((rxs.len() as u64) * total));
242 global_pb.set_style(pb_style.clone());
243 global_pb.set_prefix("Global");
244
245 let mut active: Vec<(usize, ProgressBar)> = (0..rxs.len().min(5))
246 .map(|chain_idx| {
247 let pb = multi.add(ProgressBar::new(total));
248 pb.set_style(pb_style.clone());
249 pb.set_prefix(format!("Chain {chain_idx}"));
250 (chain_idx, pb)
251 })
252 .collect();
253 let mut next_active = active.len();
254 let mut n_finished = 0;
255 let mut most_recent = vec![None; rxs.len()];
256 let mut total_progress;
257
258 loop {
259 for (i, rx) in rxs.iter().enumerate() {
260 while let Ok(stats) = rx.recv_timeout(timeout_ms) {
261 most_recent[i] = Some(stats)
262 }
263 }
264
265 // Update chain progress bar messages
266 // and compute average acceptance probability
267 let mut to_replace = vec![false; active.len()];
268 let mut avg_p_accept = 0.0;
269 let mut n_available_stats = 0.0;
270 for (vec_idx, (i, pb)) in active.iter().enumerate() {
271 if let Some(stats) = &most_recent[*i] {
272 pb.set_position(stats.n);
273 pb.set_message(format!("p(accept)≈{:.2}", stats.p_accept));
274 avg_p_accept += stats.p_accept;
275 n_available_stats += 1.0;
276
277 if stats.n == total {
278 to_replace[vec_idx] = true;
279 n_finished += 1;
280 }
281 }
282 }
283 avg_p_accept /= n_available_stats;
284
285 // Update global progress bar
286 total_progress = 0;
287 for stats in most_recent.iter().flatten() {
288 total_progress += stats.n;
289 }
290 global_pb.set_position(total_progress);
291 let valid: Vec<&ChainStats> = most_recent.iter().flatten().collect();
292 if valid.len() >= 2 {
293 let rhats = collect_rhat(valid.as_slice());
294 let max = rhats.max_skipnan();
295 global_pb.set_message(format!(
296 "p(accept)≈{:.2} max(rhat)≈{:.2}",
297 avg_p_accept, max
298 ))
299 }
300
301 let mut to_remove = vec![];
302 for (i, replace) in to_replace.iter().enumerate() {
303 if *replace && next_active < most_recent.len() {
304 let pb = multi.add(ProgressBar::new(total));
305 pb.set_style(pb_style.clone());
306 pb.set_prefix(format!("Chain {next_active}"));
307 active[i] = (next_active, pb);
308 next_active += 1;
309 } else if *replace {
310 to_remove.push(i);
311 }
312 }
313
314 to_remove.sort();
315 for i in to_remove.iter().rev() {
316 active.remove(*i);
317 }
318
319 if n_finished >= most_recent.len() {
320 break;
321 }
322 std::thread::sleep(sleep_ms);
323 }
324 });
325
326 let chain_samples: Vec<Array2<T>> = thread::scope(|s| {
327 let handles: Vec<thread::ScopedJoinHandle<Array2<T>>> = chains
328 .iter_mut()
329 .zip(txs)
330 .map(|(chain, tx)| {
331 s.spawn(|| {
332 run_chain_progress(chain, n_collect, n_discard, tx)
333 .expect("Expected running chain to succeed.")
334 })
335 })
336 .collect();
337 handles
338 .into_iter()
339 .map(|h| {
340 h.join()
341 .expect("Expected thread to succeed in generating sample.")
342 })
343 .collect()
344 });
345 let sample: Array3<T> = stack(
346 Axis(0),
347 &chain_samples
348 .iter()
349 .map(|x| x.view())
350 .collect::<Vec<ArrayView2<T>>>(),
351 )?;
352
353 if let Err(e) = progress_handle.join() {
354 eprintln!("Progress bar thread emitted error message: {:?}", e);
355 }
356
357 let run_stats = RunStats::from(sample.view());
358
359 Ok((sample, run_stats))
360 }
361}
362
363impl<T: LinalgScalar + Send + PartialEq + num_traits::ToPrimitive, R: HasChains<T>> ChainRunner<T>
364 for R
365{
366}
367
368/// Generates a vector of random initial positions from a standard normal distribution.
369///
370/// Each position is a `Vec<T>` of length `d` representing a point in `d`-dimensional space.
371/// The function returns `n` such positions.
372///
373/// # Type Parameters
374/// - `T`: The numeric type (e.g., `f32`, `f64`). Must implement `Float + FromPrimitive`.
375///
376/// # Parameters
377/// - `n`: Number of positions to generate.
378/// - `d`: Dimensionality of each position.
379///
380/// # Returns
381/// A `Vec<Vec<T>>` where each inner vector is a position in `d`-dimensional space.
382///
383/// # Panics
384/// Panics if a sample cannot be converted from `f64` to `T` (should never happen for `f32` or `f64`).
385///
386/// # Examples
387/// ```
388/// # use mini_mcmc::core::init;
389/// let positions: Vec<Vec<f32>> = init(5, 3);
390/// for pos in positions {
391/// println!("{:?}", pos);
392/// }
393/// ```
394pub fn init<T>(n: usize, d: usize) -> Vec<Vec<T>>
395where
396 T: Float + FromPrimitive,
397{
398 let rng = SmallRng::from_entropy();
399 _init(n, d, rng)
400}
401
402/// Generates `n` pseudo-random vectors from the `d` dimensional standard normal distribution.
403/// This function calls [`init_with_seed`] with the same parameters and seed 42.
404pub fn init_det<T>(n: usize, d: usize) -> Vec<Vec<T>>
405where
406 T: Float + FromPrimitive,
407{
408 init_with_seed(n, d, 42)
409}
410
411/// Generates `n` pseudo-random vectors from the `d` dimensional standard normal distribution.
412/// Same as [`init`] except this function returns a deterministic sample.
413pub fn init_with_seed<T>(n: usize, d: usize, seed: u64) -> Vec<Vec<T>>
414where
415 T: Float + FromPrimitive,
416{
417 let rng = SmallRng::seed_from_u64(seed);
418 _init(n, d, rng)
419}
420
421fn _init<T>(n: usize, d: usize, mut rng: SmallRng) -> Vec<Vec<T>>
422where
423 T: Float + FromPrimitive,
424{
425 (0..n)
426 .map(|_| {
427 (0..d)
428 .map(|_| {
429 let sample: f64 = StandardNormal.sample(&mut rng);
430 T::from_f64(sample).unwrap()
431 })
432 .collect()
433 })
434 .collect()
435}