mini_mcmc/
core.rs

/*!
# Core MCMC Utilities

This module provides the core functionality for running Markov Chain Monte Carlo (MCMC) chains in parallel.
It includes:
- The [`MarkovChain<T>`] trait, which abstracts a single MCMC chain.
- The utility functions [`run_chain`] and [`run_chain_progress`], which execute a single chain and collect its states.
- The [`HasChains<T>`] trait for types that own multiple Markov chains.
- The [`ChainRunner<T>`] trait, which extends [`HasChains<T>`] with methods that run all chains in parallel, discard the burn-in samples and optionally display progress bars.

Any type implementing [`HasChains<T>`] (with the required trait bounds) automatically implements [`ChainRunner<T>`] via a blanket implementation.

The state element type `T` is generic and bounded by [`ndarray::LinalgScalar`].
*/
15
16use crate::stats::{collect_rhat, ChainStats, ChainTracker};
17use indicatif::ProgressBar;
18use indicatif::{MultiProgress, ProgressStyle};
19use ndarray::stack;
20use ndarray::{prelude::*, LinalgScalar, ShapeError};
21use ndarray_stats::QuantileExt;
22use rayon::prelude::*;
23use std::cmp::PartialEq;
24use std::error::Error;
25use std::marker::Send;
26use std::sync::mpsc::{self, Receiver, Sender};
27use std::thread::{self};
28use std::time::{Duration, Instant};
29
/// Abstraction over a single MCMC chain whose state is a vector of `T`.
///
/// Implementors supply two operations:
/// - [`MarkovChain::step`] moves the chain forward by one iteration and hands back the
///   resulting state.
/// - [`MarkovChain::current_state`] exposes the present state and leaves the chain untouched.
pub trait MarkovChain<T> {
    /// Advances the chain by exactly one iteration.
    ///
    /// Returns a reference to the state the chain is in after the iteration.
    fn step(&mut self) -> &Vec<T>;

    /// Borrows the chain's present state; the chain is not advanced.
    fn current_state(&self) -> &Vec<T>;
}
42
43/// Runs a single MCMC chain for a specified number of steps.
44///
45/// This function repeatedly calls the chain's `step()` method and collects each state into a
46/// [`ndarray::Array2<T>`], where each row corresponds to one collected state of the chain.
47///
48/// # Arguments
49///
50/// * `chain` - A mutable reference to an object implementing [`MarkovChain<T>`].
51/// * `n_collect` - The number of samples to collect and return.
52/// * `n_discard` - The number of samples to discard (burn-in).
53///
54/// # Returns
55///
56/// A [`ndarray::Array2<T>`] where the number of rows equals `n_collect` and the number of columns equals
57/// the dimensionality of the chain's state.
58pub fn run_chain<T, M>(chain: &mut M, n_collect: usize, n_discard: usize) -> Array2<T>
59where
60    M: MarkovChain<T>,
61    T: LinalgScalar,
62{
63    let dim = chain.current_state().len();
64    let mut out = Array2::<T>::zeros((n_collect, dim));
65    let total = n_collect + n_discard;
66
67    for i in 0..total {
68        let state = chain.step();
69        if i >= n_discard {
70            let state_arr = ArrayView::from_shape(state.len(), state.as_slice()).unwrap();
71            out.row_mut(i - n_discard).assign(&state_arr);
72        }
73    }
74
75    out
76}
77
78/// Runs a single MCMC chain for a `n_collect` + `n_discard` number of steps while displaying progress.
79///
80/// This function is similar to [`run_chain`], but it accepts an [`indicatif::ProgressBar`]
81/// that is updated as the chain advances.
82///
83/// # Arguments
84///
85/// * `chain` - A mutable reference to an object implementing [`MarkovChain<T>`].
86/// * `n_collect` - The number of samples to collect and return.
87/// * `n_discard` - The number of samples to discard (burn-in).
88/// * `tx` - A [`Sender<ChainStats>`] object for communication with chains-managing parent thread.
89///
90/// # Returns
91///
92/// A [`ndarray::Array2<T>`] containing the chain's states with `n_collect` number of rows.
93pub fn run_chain_progress<T, M>(
94    chain: &mut M,
95    n_collect: usize,
96    n_discard: usize,
97    tx: Sender<ChainStats>,
98) -> Result<Array2<T>, String>
99where
100    M: MarkovChain<T>,
101    T: LinalgScalar + PartialEq + num_traits::ToPrimitive,
102{
103    let n_params = chain.current_state().len();
104    let mut out = Array2::<T>::zeros((n_collect, n_params));
105
106    let mut tracker = ChainTracker::new(n_params, chain.current_state());
107    let mut last = Instant::now();
108    let freq = Duration::from_secs(1);
109    let total = n_discard + n_collect;
110
111    for i in 0..total {
112        let current_state = chain.step();
113        tracker.step(current_state).map_err(|e| {
114            let msg = format!(
115            "Chain statistics tracker caused error: {}.\nAborting generation of further samples.",
116            e
117            );
118            println!("{}", msg);
119            msg
120        })?;
121
122        let now = Instant::now();
123        if (now >= last + freq) | (i == total - 1) {
124            if let Err(e) = tx.send(tracker.stats()) {
125                eprintln!("Sending chain statistics failed: {e}");
126            }
127            last = now;
128        }
129
130        if i >= n_discard {
131            out.row_mut(i - n_discard).assign(
132                &ArrayView1::from_shape(current_state.len(), current_state.as_slice()).unwrap(),
133            );
134        }
135    }
136
137    // TODO: Somehow save state of the chains and enable continuing runs
138    Ok(out)
139}
140
141/// A trait for types that own multiple MCMC chains.
142///
143/// - `T` is the type of the state elements (e.g., `f64`).
144/// - `Chain` is the concrete type of the individual chain, which must implement [`MarkovChain<T>`]
145///   and be [`Send`].
146///
147/// Implementors must provide a method to access the internal vector of chains.
148pub trait HasChains<S> {
149    type Chain: MarkovChain<S> + Send;
150
151    /// Returns a mutable reference to the vector of chains.
152    fn chains_mut(&mut self) -> &mut Vec<Self::Chain>;
153}
154
155/// An extension trait for types that own multiple MCMC chains.
156///
157/// [`ChainRunner<T>`] extends [`HasChains<T>`] by providing default methods to run all chains
158/// in parallel. These methods allow you to:
159/// - Run all chains, collect `n_collect` samples and discard `n_discard` initial burn-in samples.
160/// - Optionally display progress bars for each chain during execution.
161///
162/// Any type that implements [`HasChains<T>`] (with appropriate bounds on `T`) automatically implements
163/// [`ChainRunner<T>`].
164pub trait ChainRunner<T>: HasChains<T>
165where
166    T: LinalgScalar + PartialEq + Send + num_traits::ToPrimitive,
167{
168    /// Runs all chains in parallel, discarding the first `discard` iterations (burn-in).
169    ///
170    /// # Arguments
171    ///
172    /// * `n_collect` - The number of samples to collect and return.
173    /// * `n_discard` - The number of samples to discard (burn-in).
174    ///
175    /// # Returns
176    ///
177    /// A [`ndarray::Array3`] tensor with the first axis representing the chain, the second one the
178    /// step and the last one the parameter dimension.
179    fn run(&mut self, n_collect: usize, n_discard: usize) -> Result<Array3<T>, ShapeError> {
180        // Run them all in parallel
181        let results: Vec<Array2<T>> = self
182            .chains_mut()
183            .par_iter_mut()
184            .map(|chain| run_chain(chain, n_collect, n_discard))
185            .collect();
186        let views: Vec<ArrayView2<T>> = results.iter().map(|x| x.view()).collect();
187        let out: Array3<T> = stack(Axis(0), &views)?;
188        Ok(out)
189    }
190
191    /// Runs all chains in parallel with progress bars, discarding the burn-in.
192    ///
193    /// Each chain is run in parallel with its own progress bar. After execution, the first `discard`
194    /// iterations are discarded.
195    ///
196    /// # Arguments
197    ///
198    /// * `n_collect` - The number of samples to collect and return.
199    /// * `n_discard` - The number of samples to discard (burn-in).
200    ///
201    /// # Returns
202    ///
203    /// Returns a [`ndarray::Array3`] tensor with the first axis representing the chain, the second one the
204    /// step and the last one the parameter dimension.
205    fn run_progress(
206        &mut self,
207        n_collect: usize,
208        n_discard: usize,
209    ) -> Result<Array3<T>, Box<dyn Error>> {
210        // Channels.
211        // Each chain gets its own channel. Hence, we have `n_chains` channels.
212        // The objects sent over channels are Array2<f32>s ($s_m^2$, $\bar{\theta}_m^{(\bullet)}$).
213        // The child thread sends it's respective one to the parent thread.
214        // The parent thread assemples the tuples it receives to compute Rhat.
215
216        let chains = self.chains_mut();
217
218        let mut rxs: Vec<Receiver<ChainStats>> = vec![];
219        let mut txs: Vec<Sender<ChainStats>> = vec![];
220        (0..chains.len()).for_each(|_| {
221            let (tx, rx) = mpsc::channel();
222            rxs.push(rx);
223            txs.push(tx);
224        });
225
226        let progress_handle = thread::spawn(move || {
227            let sleep_ms = Duration::from_millis(250);
228            let timeout_ms = Duration::from_millis(0);
229            let multi = MultiProgress::new();
230
231            let pb_style = ProgressStyle::default_bar()
232                .template("{prefix:8} {bar:40.white} ETA {eta:3} | {msg}")
233                .unwrap()
234                .progress_chars("=>-");
235            let total: u64 = (n_collect + n_discard).try_into().unwrap();
236
237            // Global Progress bar
238            let global_pb = multi.add(ProgressBar::new((rxs.len() as u64) * total));
239            global_pb.set_style(pb_style.clone());
240            global_pb.set_prefix("Global");
241
242            let mut active: Vec<(usize, ProgressBar)> = (0..rxs.len().min(5))
243                .map(|chain_idx| {
244                    let pb = multi.add(ProgressBar::new(total));
245                    pb.set_style(pb_style.clone());
246                    pb.set_prefix(format!("Chain {chain_idx}"));
247                    (chain_idx, pb)
248                })
249                .collect();
250            let mut next_active = active.len();
251            let mut n_finished = 0;
252            let mut most_recent = vec![None; rxs.len()];
253            let mut total_progress;
254
255            loop {
256                for (i, rx) in rxs.iter().enumerate() {
257                    while let Ok(stats) = rx.recv_timeout(timeout_ms) {
258                        most_recent[i] = Some(stats)
259                    }
260                }
261
262                // Update chain progress bar messages
263                // and compute average acceptance probability
264                let mut to_replace = vec![false; active.len()];
265                let mut avg_p_accept = 0.0;
266                let mut n_available_stats = 0.0;
267                for (vec_idx, (i, pb)) in active.iter().enumerate() {
268                    if let Some(stats) = &most_recent[*i] {
269                        pb.set_position(stats.n);
270                        pb.set_message(format!("p(accept)≈{:.2}", stats.p_accept));
271                        avg_p_accept += stats.p_accept;
272                        n_available_stats += 1.0;
273
274                        if stats.n == total {
275                            to_replace[vec_idx] = true;
276                            n_finished += 1;
277                        }
278                    }
279                }
280                avg_p_accept /= n_available_stats;
281
282                // Update global progress bar
283                total_progress = 0;
284                for stats in most_recent.iter().flatten() {
285                    total_progress += stats.n;
286                }
287                global_pb.set_position(total_progress);
288                let valid: Vec<&ChainStats> = most_recent.iter().flatten().collect();
289                if valid.len() >= 2 {
290                    let rhats = collect_rhat(valid.as_slice());
291                    let max = rhats.max_skipnan();
292                    global_pb.set_message(format!(
293                        "p(accept)≈{:.2} max(rhat)≈{:.2}",
294                        avg_p_accept, max
295                    ))
296                }
297
298                let mut to_remove = vec![];
299                for (i, replace) in to_replace.iter().enumerate() {
300                    if *replace && next_active < most_recent.len() {
301                        let pb = multi.add(ProgressBar::new(total));
302                        pb.set_style(pb_style.clone());
303                        pb.set_prefix(format!("Chain {next_active}"));
304                        active[i] = (next_active, pb);
305                        next_active += 1;
306                    } else if *replace {
307                        to_remove.push(i);
308                    }
309                }
310
311                to_remove.sort();
312                for i in to_remove.iter().rev() {
313                    active.remove(*i);
314                }
315
316                if n_finished >= most_recent.len() {
317                    break;
318                }
319                std::thread::sleep(sleep_ms);
320            }
321        });
322
323        let samples: Vec<Array2<T>> = thread::scope(|s| {
324            let handles: Vec<thread::ScopedJoinHandle<Array2<T>>> = chains
325                .iter_mut()
326                .zip(txs)
327                .map(|(chain, tx)| {
328                    s.spawn(|| {
329                        run_chain_progress(chain, n_collect, n_discard, tx)
330                            .expect("Expected running chain to succeed.")
331                    })
332                })
333                .collect();
334            handles
335                .into_iter()
336                .map(|h| {
337                    h.join()
338                        .expect("Expected thread to succeed in generating sample.")
339                })
340                .collect()
341        });
342        let out: Array3<T> = stack(
343            Axis(0),
344            &samples
345                .iter()
346                .map(|x| x.view())
347                .collect::<Vec<ArrayView2<T>>>(),
348        )?;
349
350        if let Err(e) = progress_handle.join() {
351            eprintln!("Progress bar thread emitted error message: {:?}", e);
352        }
353        Ok(out)
354    }
355}
356
357impl<T: LinalgScalar + Send + PartialEq + num_traits::ToPrimitive, R: HasChains<T>> ChainRunner<T>
358    for R
359{
360}