mini_mcmc/
core.rs

/*!
# Core MCMC Utilities.

This module provides core functionality for running Markov Chain Monte Carlo (MCMC) chains in parallel.
It includes:
- The [`MarkovChain<T>`] trait, which abstracts a single MCMC chain.
- Utility functions [`run_chain`] and [`run_chain_progress`] for executing a single chain and collecting its states.
- The [`HasChains<T>`] trait for types that own multiple Markov chains.
- The [`ChainRunner<T>`] trait that extends [`HasChains<T>`] with methods to run chains in parallel (using Rayon), discarding burn-in and optionally displaying progress bars.
- The helpers [`init`], [`init_det`], and [`init_with_seed`] for generating random initial positions.

Any type implementing [`HasChains<T>`] (with the required trait bounds) automatically implements [`ChainRunner<T>`] via a blanket implementation.

This module is generic over the state type using [`ndarray::LinalgScalar`].
*/

use crate::stats::{collect_rhat, ChainStats, ChainTracker, RunStats};
use indicatif::ProgressBar;
use indicatif::{MultiProgress, ProgressStyle};
use ndarray::stack;
use ndarray::{prelude::*, LinalgScalar, ShapeError};
use ndarray_stats::QuantileExt;
use num_traits::{Float, FromPrimitive};
use rand::rngs::SmallRng;
use rand::SeedableRng;
use rand_distr::{Distribution, StandardNormal};
use rayon::prelude::*;
use std::cmp::PartialEq;
use std::error::Error;
use std::marker::Send;
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;
use std::time::{Duration, Instant};

/// A trait that abstracts a single MCMC chain.
///
/// A type implementing [`MarkovChain<T>`] must provide:
/// - `step()`: advances the chain one iteration and returns a reference to the updated state.
/// - `current_state()`: returns a reference to the current state without modifying the chain.
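///
/// # Examples
///
/// A minimal sketch with a toy deterministic chain; the `Counter` type below is
/// purely illustrative and not part of this crate:
///
/// ```
/// use mini_mcmc::core::MarkovChain;
///
/// struct Counter {
///     state: Vec<i32>,
/// }
///
/// impl MarkovChain<i32> for Counter {
///     fn step(&mut self) -> &Vec<i32> {
///         // Deterministically increment the single coordinate.
///         self.state[0] += 1;
///         &self.state
///     }
///
///     fn current_state(&self) -> &Vec<i32> {
///         &self.state
///     }
/// }
///
/// let mut chain = Counter { state: vec![0] };
/// chain.step();
/// assert_eq!(chain.current_state(), &vec![1]);
/// ```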
pub trait MarkovChain<T> {
    /// Performs one iteration of the chain and returns a reference to the new state.
    fn step(&mut self) -> &Vec<T>;

    /// Returns a reference to the current state of the chain without advancing it.
    fn current_state(&self) -> &Vec<T>;
}

/// Runs a single MCMC chain for `n_collect + n_discard` steps.
///
/// This function repeatedly calls the chain's `step()` method, discards the first
/// `n_discard` states (burn-in), and collects the remaining states into an
/// [`ndarray::Array2<T>`] of shape `[n_collect, D]` where:
/// - `n_collect`: number of observations to collect
/// - `D`: dimensionality of the state space
///
/// Each row corresponds to one collected state of the chain.
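///
/// # Examples
///
/// A minimal sketch with a toy deterministic chain; the `Counter` type below is
/// purely illustrative and not part of this crate:
///
/// ```
/// use mini_mcmc::core::{run_chain, MarkovChain};
///
/// struct Counter {
///     state: Vec<f64>,
/// }
///
/// impl MarkovChain<f64> for Counter {
///     fn step(&mut self) -> &Vec<f64> {
///         self.state[0] += 1.0;
///         &self.state
///     }
///
///     fn current_state(&self) -> &Vec<f64> {
///         &self.state
///     }
/// }
///
/// let mut chain = Counter { state: vec![0.0] };
/// // Take 3 + 2 steps in total; keep the last 3 states.
/// let sample = run_chain(&mut chain, 3, 2);
/// assert_eq!(sample.shape(), &[3, 1]);
/// assert_eq!(sample[[0, 0]], 3.0);
/// ```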
pub fn run_chain<T, M>(chain: &mut M, n_collect: usize, n_discard: usize) -> Array2<T>
where
    M: MarkovChain<T>,
    T: LinalgScalar,
{
    let dim = chain.current_state().len();
    let mut out = Array2::<T>::zeros((n_collect, dim));
    let total = n_collect + n_discard;

    for i in 0..total {
        let state = chain.step();
        if i >= n_discard {
            let state_arr = ArrayView::from_shape(state.len(), state.as_slice()).unwrap();
            out.row_mut(i - n_discard).assign(&state_arr);
        }
    }

    out
}

/// Runs a single MCMC chain for `n_collect + n_discard` steps while reporting progress.
///
/// This function is similar to [`run_chain`], but it periodically sends per-chain
/// statistics ([`ChainStats`]) over the provided channel (roughly once per second and
/// once more at the final step), so that a parent thread can display progress.
///
/// # Arguments
///
/// * `chain` - A mutable reference to an object implementing [`MarkovChain<T>`].
/// * `n_collect` - The number of observations to collect and return.
/// * `n_discard` - The number of observations to discard (burn-in).
/// * `tx` - A [`Sender<ChainStats>`] used to communicate with the chain-managing parent thread.
///
/// # Returns
///
/// An [`ndarray::Array2<T>`] containing the chain's states, with `n_collect` rows,
/// or an error message if the statistics tracker fails.
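///
/// # Examples
///
/// A minimal sketch with a toy deterministic chain and an in-process channel that is
/// simply drained after the run; the `Counter` type below is purely illustrative and
/// not part of this crate:
///
/// ```
/// use std::sync::mpsc;
/// use mini_mcmc::core::{run_chain_progress, MarkovChain};
///
/// struct Counter {
///     state: Vec<f64>,
/// }
///
/// impl MarkovChain<f64> for Counter {
///     fn step(&mut self) -> &Vec<f64> {
///         self.state[0] += 1.0;
///         &self.state
///     }
///
///     fn current_state(&self) -> &Vec<f64> {
///         &self.state
///     }
/// }
///
/// let (tx, rx) = mpsc::channel();
/// let mut chain = Counter { state: vec![0.0] };
/// let sample = run_chain_progress(&mut chain, 5, 2, tx).unwrap();
/// assert_eq!(sample.shape(), &[5, 1]);
/// // At least one statistics update was sent (the one from the final step).
/// assert!(rx.try_recv().is_ok());
/// ```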
pub fn run_chain_progress<T, M>(
    chain: &mut M,
    n_collect: usize,
    n_discard: usize,
    tx: Sender<ChainStats>,
) -> Result<Array2<T>, String>
where
    M: MarkovChain<T>,
    T: LinalgScalar + PartialEq + num_traits::ToPrimitive,
{
    let n_params = chain.current_state().len();
    let mut out = Array2::<T>::zeros((n_collect, n_params));

    let mut tracker = ChainTracker::new(n_params, chain.current_state());
    let mut last = Instant::now();
    let freq = Duration::from_secs(1);
    let total = n_discard + n_collect;

    for i in 0..total {
        let current_state = chain.step();
        tracker.step(current_state).map_err(|e| {
            let msg = format!(
                "Chain statistics tracker caused error: {}.\nAborting generation of further observations.",
                e
            );
            println!("{}", msg);
            msg
        })?;

        let now = Instant::now();
        if (now >= last + freq) || (i == total - 1) {
            if let Err(e) = tx.send(tracker.stats()) {
                eprintln!("Sending chain statistics failed: {e}");
            }
            last = now;
        }

        if i >= n_discard {
            out.row_mut(i - n_discard).assign(
                &ArrayView1::from_shape(current_state.len(), current_state.as_slice()).unwrap(),
            );
        }
    }

    // TODO: Somehow save state of the chains and enable continuing runs
    Ok(out)
}

/// A trait for types that own multiple MCMC chains.
///
/// - `S` is the type of the state elements (e.g., `f64`).
/// - `Chain` is the concrete type of the individual chain, which must implement [`MarkovChain<S>`]
///   and be [`Send`].
///
/// Implementors must provide a method to access the internal vector of chains; see the
/// example on [`ChainRunner<T>`] for a complete sketch.
pub trait HasChains<S> {
    type Chain: MarkovChain<S> + Send;

    /// Returns a mutable reference to the vector of chains.
    fn chains_mut(&mut self) -> &mut Vec<Self::Chain>;
}

/// An extension trait for types that own multiple MCMC chains.
///
/// [`ChainRunner<T>`] extends [`HasChains<T>`] by providing default methods to run all chains
/// in parallel. These methods allow you to:
/// - Run all chains, collecting `n_collect` observations after discarding `n_discard` initial burn-in observations.
/// - Optionally display progress bars for each chain during execution.
///
/// Any type that implements [`HasChains<T>`] (with appropriate bounds on `T`) automatically implements
/// [`ChainRunner<T>`].
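///
/// # Examples
///
/// A minimal sketch of wiring a toy sampler into [`ChainRunner<T>`]; the `Shift` and
/// `Ensemble` types below are purely illustrative and not part of this crate:
///
/// ```
/// use mini_mcmc::core::{ChainRunner, HasChains, MarkovChain};
///
/// // Toy chain: adds a fixed increment to every coordinate on each step.
/// struct Shift {
///     state: Vec<f64>,
/// }
///
/// impl MarkovChain<f64> for Shift {
///     fn step(&mut self) -> &Vec<f64> {
///         for x in self.state.iter_mut() {
///             *x += 0.1;
///         }
///         &self.state
///     }
///
///     fn current_state(&self) -> &Vec<f64> {
///         &self.state
///     }
/// }
///
/// // An ensemble owning several chains.
/// struct Ensemble {
///     chains: Vec<Shift>,
/// }
///
/// impl HasChains<f64> for Ensemble {
///     type Chain = Shift;
///
///     fn chains_mut(&mut self) -> &mut Vec<Self::Chain> {
///         &mut self.chains
///     }
/// }
///
/// // `ChainRunner` is available via the blanket implementation.
/// let mut ensemble = Ensemble {
///     chains: (0..4).map(|_| Shift { state: vec![0.0, 0.0] }).collect(),
/// };
/// let sample = ensemble.run(100, 10).unwrap();
/// assert_eq!(sample.shape(), &[4, 100, 2]); // [chain, step, dimension]
/// ```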
pub trait ChainRunner<T>: HasChains<T>
where
    T: LinalgScalar + PartialEq + Send + num_traits::ToPrimitive,
{
    /// Runs all chains in parallel, discarding the first `n_discard` iterations (burn-in).
    ///
    /// # Arguments
    ///
    /// * `n_collect` - The number of observations to collect and return.
    /// * `n_discard` - The number of observations to discard (burn-in).
    ///
    /// # Returns
    ///
    /// An [`ndarray::Array3`] tensor whose first axis indexes the chain, whose second axis
    /// indexes the step, and whose last axis indexes the parameter dimension.
    fn run(&mut self, n_collect: usize, n_discard: usize) -> Result<Array3<T>, ShapeError> {
        // Run all chains in parallel and stack the per-chain samples.
        let results: Vec<Array2<T>> = self
            .chains_mut()
            .par_iter_mut()
            .map(|chain| run_chain(chain, n_collect, n_discard))
            .collect();
        let views: Vec<ArrayView2<T>> = results.iter().map(|x| x.view()).collect();
        let out: Array3<T> = stack(Axis(0), &views)?;
        Ok(out)
    }

    /// Runs all chains in parallel with progress bars, discarding the burn-in.
    ///
    /// Each chain is run in parallel with its own progress bar. The first `n_discard`
    /// iterations of each chain are discarded.
    ///
    /// # Arguments
    ///
    /// * `n_collect` - The number of observations to collect and return.
    /// * `n_discard` - The number of observations to discard (burn-in).
    ///
    /// # Returns
    ///
    /// Returns a tuple containing:
    /// - An [`ndarray::Array3`] tensor whose first axis indexes the chain, whose second axis
    ///   indexes the step, and whose last axis indexes the parameter dimension.
    /// - A [`RunStats`] object containing convergence statistics including:
    ///   - Acceptance probability
    ///   - Potential scale reduction factor (R-hat)
    ///   - Effective sample size (ESS)
    ///   - Other convergence diagnostics
    fn run_progress(
        &mut self,
        n_collect: usize,
        n_discard: usize,
    ) -> Result<(Array3<T>, RunStats), Box<dyn Error>> {
        // Channels.
        // Each chain gets its own channel, so we have `n_chains` channels in total.
        // The objects sent over a channel are `ChainStats` summaries carrying the
        // per-chain quantities ($s_m^2$, $\bar{\theta}_m^{(\bullet)}$) needed for R-hat.
        // Each child thread sends its statistics to the parent thread, which assembles
        // them to compute R-hat.

        let chains = self.chains_mut();

        let mut rxs: Vec<Receiver<ChainStats>> = vec![];
        let mut txs: Vec<Sender<ChainStats>> = vec![];
        (0..chains.len()).for_each(|_| {
            let (tx, rx) = mpsc::channel();
            rxs.push(rx);
            txs.push(tx);
        });

        let progress_handle = thread::spawn(move || {
            let sleep_ms = Duration::from_millis(250);
            let timeout_ms = Duration::from_millis(0);
            let multi = MultiProgress::new();

            let pb_style = ProgressStyle::default_bar()
                .template("{prefix:8} {bar:40.cyan/blue} {pos}/{len} ({eta}) | {msg}")
                .unwrap()
                .progress_chars("=>-");
            let total: u64 = (n_collect + n_discard).try_into().unwrap();

            // Global progress bar
            let global_pb = multi.add(ProgressBar::new((rxs.len() as u64) * total));
            global_pb.set_style(pb_style.clone());
            global_pb.set_prefix("Global");

            // Show at most 5 per-chain progress bars at a time.
            let mut active: Vec<(usize, ProgressBar)> = (0..rxs.len().min(5))
                .map(|chain_idx| {
                    let pb = multi.add(ProgressBar::new(total));
                    pb.set_style(pb_style.clone());
                    pb.set_prefix(format!("Chain {chain_idx}"));
                    (chain_idx, pb)
                })
                .collect();
            let mut next_active = active.len();
            let mut n_finished = 0;
            let mut most_recent = vec![None; rxs.len()];
            let mut total_progress;

            loop {
                // Drain all pending statistics updates without blocking.
                for (i, rx) in rxs.iter().enumerate() {
                    while let Ok(stats) = rx.recv_timeout(timeout_ms) {
                        most_recent[i] = Some(stats)
                    }
                }

                // Update chain progress bar messages
                // and compute average acceptance probability
                let mut to_replace = vec![false; active.len()];
                let mut avg_p_accept = 0.0;
                let mut n_available_stats = 0.0;
                for (vec_idx, (i, pb)) in active.iter().enumerate() {
                    if let Some(stats) = &most_recent[*i] {
                        pb.set_position(stats.n);
                        pb.set_message(format!("p(accept)≈{:.2}", stats.p_accept));
                        avg_p_accept += stats.p_accept;
                        n_available_stats += 1.0;

                        if stats.n == total {
                            to_replace[vec_idx] = true;
                            n_finished += 1;
                        }
                    }
                }
                avg_p_accept /= n_available_stats;

                // Update global progress bar
                total_progress = 0;
                for stats in most_recent.iter().flatten() {
                    total_progress += stats.n;
                }
                global_pb.set_position(total_progress);
                let valid: Vec<&ChainStats> = most_recent.iter().flatten().collect();
                if valid.len() >= 2 {
                    let rhats = collect_rhat(valid.as_slice());
                    let max = rhats.max_skipnan();
                    global_pb.set_message(format!(
                        "p(accept)≈{:.2} max(rhat)≈{:.2}",
                        avg_p_accept, max
                    ))
                }

                // Replace finished progress bars with bars for not-yet-displayed chains,
                // or drop them once every chain has been shown.
                let mut to_remove = vec![];
                for (i, replace) in to_replace.iter().enumerate() {
                    if *replace && next_active < most_recent.len() {
                        let pb = multi.add(ProgressBar::new(total));
                        pb.set_style(pb_style.clone());
                        pb.set_prefix(format!("Chain {next_active}"));
                        active[i] = (next_active, pb);
                        next_active += 1;
                    } else if *replace {
                        to_remove.push(i);
                    }
                }

                to_remove.sort();
                for i in to_remove.iter().rev() {
                    active.remove(*i);
                }

                if n_finished >= most_recent.len() {
                    break;
                }
                std::thread::sleep(sleep_ms);
            }
        });

        let chain_sample: Vec<Array2<T>> = thread::scope(|s| {
            let handles: Vec<thread::ScopedJoinHandle<Array2<T>>> = chains
                .iter_mut()
                .zip(txs)
                .map(|(chain, tx)| {
                    s.spawn(|| {
                        run_chain_progress(chain, n_collect, n_discard, tx)
                            .expect("Expected running chain to succeed.")
                    })
                })
                .collect();
            handles
                .into_iter()
                .map(|h| {
                    h.join()
                        .expect("Expected thread to succeed in generating observation.")
                })
                .collect()
        });
        let sample: Array3<T> = stack(
            Axis(0),
            &chain_sample
                .iter()
                .map(|x| x.view())
                .collect::<Vec<ArrayView2<T>>>(),
        )?;

        if let Err(e) = progress_handle.join() {
            eprintln!("Progress bar thread emitted error message: {:?}", e);
        }

        let run_stats = RunStats::from(sample.view());

        Ok((sample, run_stats))
    }
}

// Blanket implementation: any `HasChains<T>` with suitable bounds on `T` is a `ChainRunner<T>`.
impl<T: LinalgScalar + Send + PartialEq + num_traits::ToPrimitive, R: HasChains<T>> ChainRunner<T>
    for R
{
}

/// Generates a vector of random initial positions from a standard normal distribution.
///
/// Each position is a `Vec<T>` of length `d` representing a point in `d`-dimensional space.
/// The function returns `n` such positions.
///
/// # Type Parameters
/// - `T`: The numeric type (e.g., `f32`, `f64`). Must implement `Float + FromPrimitive`.
///
/// # Parameters
/// - `n`: Number of positions to generate.
/// - `d`: Dimensionality of each position.
///
/// # Returns
/// A `Vec<Vec<T>>` where each inner vector is a position in `d`-dimensional space.
///
/// # Panics
/// Panics if an observation cannot be converted from `f64` to `T` (should never happen for `f32` or `f64`).
///
/// # Examples
/// ```
/// # use mini_mcmc::core::init;
/// let positions: Vec<Vec<f32>> = init(5, 3);
/// for pos in positions {
///     println!("{:?}", pos);
/// }
/// ```
pub fn init<T>(n: usize, d: usize) -> Vec<Vec<T>>
where
    T: Float + FromPrimitive,
{
    let rng = SmallRng::from_os_rng();
    _init(n, d, rng)
}

/// Generates `n` pseudo-random vectors from the `d`-dimensional standard normal distribution.
/// This function calls [`init_with_seed`] with the same parameters and seed 42.
pub fn init_det<T>(n: usize, d: usize) -> Vec<Vec<T>>
where
    T: Float + FromPrimitive,
{
    init_with_seed(n, d, 42)
}

/// Generates `n` pseudo-random vectors from the `d`-dimensional standard normal distribution.
/// Same as [`init`], except this function returns a deterministic (seeded) sample.
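///
/// # Examples
///
/// A small sketch illustrating determinism (the concrete values depend on the RNG):
///
/// ```
/// # use mini_mcmc::core::init_with_seed;
/// let a: Vec<Vec<f64>> = init_with_seed(2, 3, 7);
/// let b: Vec<Vec<f64>> = init_with_seed(2, 3, 7);
/// assert_eq!(a, b); // same seed, same positions
/// ```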
pub fn init_with_seed<T>(n: usize, d: usize, seed: u64) -> Vec<Vec<T>>
where
    T: Float + FromPrimitive,
{
    let rng = SmallRng::seed_from_u64(seed);
    _init(n, d, rng)
}

/// Draws `n` positions of dimension `d` from the standard normal distribution using the given RNG.
fn _init<T>(n: usize, d: usize, mut rng: SmallRng) -> Vec<Vec<T>>
where
    T: Float + FromPrimitive,
{
    (0..n)
        .map(|_| {
            (0..d)
                .map(|_| {
                    let obs: f64 = StandardNormal.sample(&mut rng);
                    T::from_f64(obs).unwrap()
                })
                .collect()
        })
        .collect()
}