mini_mcmc/
core.rs

/*!
# Core MCMC Utilities

This module provides core functionality for running Markov Chain Monte Carlo (MCMC) chains in parallel.
It includes:
- The [`MarkovChain<S>`] trait, which abstracts a single MCMC chain.
- Utility functions [`run_chain`] and [`run_chain_with_progress`] for executing a single chain and collecting its states.
- The [`HasChains<S>`] trait for types that own multiple Markov chains.
- The [`ChainRunner<S>`] trait, which extends `HasChains<S>` with methods that run all chains in parallel (using Rayon), discard an initial burn-in, and optionally display progress bars.

Any type implementing `HasChains<S>` (with the required trait bounds) automatically implements `ChainRunner<S>` via a blanket implementation.

The state element type `S` is generic. The runner utilities bound it by [`num_traits::Zero`] and
nalgebra's `Scalar` trait, and `ChainRunner` additionally requires `Send + Sync` so that chains can
run in parallel.
*/
15
16use indicatif::ProgressBar;
17use indicatif::{MultiProgress, ProgressStyle};
18use nalgebra as na;
19use num_traits::Zero;
20use rayon::prelude::*;
21
/// Abstraction over a single Markov chain Monte Carlo sampler.
///
/// `S` is the element type of the chain's state vector. Implementors expose two operations:
/// - [`step`](MarkovChain::step) advances the chain by one iteration and hands back the new state.
/// - [`current_state`](MarkovChain::current_state) peeks at the state without advancing the chain.
pub trait MarkovChain<S> {
    /// Advances the chain by one iteration and returns a reference to the updated state.
    fn step(&mut self) -> &Vec<S>;

    /// Returns a reference to the chain's current state; the chain is not advanced.
    fn current_state(&self) -> &Vec<S>;
}
34
35/// Runs a single MCMC chain for a specified number of steps.
36///
37/// This function repeatedly calls the chain's `step()` method and collects each state into a
38/// [`nalgebra::DMatrix`], where each row corresponds to one iteration of the chain.
39///
40/// # Arguments
41///
42/// * `chain` - A mutable reference to an object implementing [`MarkovChain<S>`].
43/// * `n_steps` - The total number of iterations to run.
44///
45/// # Returns
46///
47/// A [`nalgebra::DMatrix<S>`] where the number of rows equals `n_steps` and the number of columns equals
48/// the dimensionality of the chain's state.
49pub fn run_chain<S, M>(chain: &mut M, n_steps: usize) -> na::DMatrix<S>
50where
51    M: MarkovChain<S>,
52    S: Clone + na::Scalar + Zero,
53{
54    let dim = chain.current_state().len();
55    let mut out = na::DMatrix::<S>::zeros(n_steps, dim);
56
57    for i in 0..n_steps {
58        let state = chain.step();
59        out.row_mut(i).copy_from_slice(state);
60    }
61
62    out
63}
64
65/// Runs a single MCMC chain for a specified number of steps while displaying progress.
66///
67/// This function is similar to [`run_chain`], but it accepts an [`indicatif::ProgressBar`]
68/// that is updated as the chain advances.
69///
70/// # Arguments
71///
72/// * `chain` - A mutable reference to an object implementing [`MarkovChain<S>`].
73/// * `n_steps` - The total number of iterations to run.
74/// * `pb` - A progress bar used to display progress.
75///
76/// # Returns
77///
78/// A [`nalgebra::DMatrix<S>`] containing the chain's states (one row per iteration).
79pub fn run_chain_with_progress<S, M>(
80    chain: &mut M,
81    n_steps: usize,
82    pb: &ProgressBar,
83) -> na::DMatrix<S>
84where
85    M: MarkovChain<S>,
86    S: Clone + na::Scalar + Zero,
87{
88    let dim = chain.current_state().len();
89    let mut out = na::DMatrix::<S>::zeros(n_steps, dim);
90
91    pb.set_length(n_steps as u64);
92
93    for i in 0..n_steps {
94        let state = chain.step();
95        out.row_mut(i).copy_from_slice(state);
96
97        // Update progress bar
98        pb.inc(1);
99    }
100
101    out
102}
103
104/// A trait for types that own multiple MCMC chains.
105///
106/// - `S` is the type of the state elements (e.g., `f64`).
107/// - `Chain` is the concrete type of the individual chain, which must implement [`MarkovChain<S>`]
108///   and be `Send`.
109///
110/// Implementors must provide a method to access the internal vector of chains.
111pub trait HasChains<S> {
112    type Chain: MarkovChain<S> + std::marker::Send;
113
114    /// Returns a mutable reference to the vector of chains.
115    fn chains_mut(&mut self) -> &mut Vec<Self::Chain>;
116}
117
118/// An extension trait for types that own multiple MCMC chains.
119///
120/// `ChainRunner<S>` extends [`HasChains<S>`] by providing default methods to run all chains
121/// in parallel using Rayon. These methods allow you to:
122/// - Run all chains for a specified number of iterations and discard an initial burn-in period.
123/// - Optionally display progress bars for each chain during execution.
124///
125/// Any type that implements [`HasChains<S>`] (with appropriate bounds on `S`) automatically implements
126/// `ChainRunner<S>`.
127pub trait ChainRunner<S>: HasChains<S>
128where
129    S: std::clone::Clone
130        + num_traits::Zero
131        + std::marker::Send
132        + std::cmp::PartialEq
133        + std::marker::Sync
134        + std::fmt::Debug
135        + 'static,
136{
137    /// Runs all chains in parallel, discarding the first `discard` iterations (burn-in).
138    ///
139    /// # Arguments
140    ///
141    /// * `n_steps` - The total number of iterations to run for each chain.
142    /// * `discard` - The number of initial iterations to discard from each chain.
143    ///
144    /// # Returns
145    ///
146    /// A vector of [`nalgebra::DMatrix<S>`] matrices, one for each chain, containing the samples
147    /// after burn-in.
148    fn run(&mut self, n_steps: usize, discard: usize) -> Vec<na::DMatrix<S>> {
149        // Run them all in parallel
150        let results: Vec<na::DMatrix<S>> = self
151            .chains_mut()
152            .par_iter_mut()
153            .map(|chain| run_chain(chain, n_steps))
154            .collect();
155
156        // Now discard the burn-in rows from each matrix
157        results
158            .into_iter()
159            .map(|mat| {
160                let nrows = mat.nrows();
161                let keep = nrows - discard;
162                mat.rows(discard, keep).into()
163            })
164            .collect()
165    }
166
167    /// Runs all chains in parallel with progress bars, discarding the burn-in.
168    ///
169    /// Each chain is run concurrently with its own progress bar. After execution, the first `discard`
170    /// iterations are discarded.
171    ///
172    /// # Arguments
173    ///
174    /// * `n_steps` - The total number of iterations to run for each chain.
175    /// * `discard` - The number of initial iterations to discard.
176    ///
177    /// # Returns
178    ///
179    /// A vector of sample matrices (one per chain) containing only the samples after burn-in.
180    fn run_with_progress(&mut self, n_steps: usize, discard: usize) -> Vec<na::DMatrix<S>> {
181        let multi = MultiProgress::new();
182        let pb_style = ProgressStyle::default_bar()
183            .template("{prefix} [{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")
184            .unwrap()
185            .progress_chars("##-");
186
187        // Run each chain in parallel
188        let results: Vec<(Vec<S>, na::DMatrix<S>)> = self
189            .chains_mut()
190            .par_iter_mut()
191            .enumerate()
192            .map(|(i, chain)| {
193                let pb = multi.add(ProgressBar::new(n_steps as u64));
194                pb.set_prefix(format!("Chain {i}"));
195                pb.set_style(pb_style.clone());
196
197                let samples = run_chain_with_progress(chain, n_steps, &pb);
198
199                pb.finish_with_message("Done!");
200
201                (chain.current_state().clone(), samples)
202            })
203            .collect();
204
205        results
206            .into_par_iter()
207            .map(|(_, samples)| {
208                let keep_rows = samples.nrows().saturating_sub(discard);
209                samples.rows(discard, keep_rows).into()
210            })
211            .collect()
212    }
213}
214
215impl<
216        S: std::fmt::Debug
217            + std::marker::Sync
218            + std::cmp::PartialEq
219            + std::marker::Send
220            + num_traits::Zero
221            + std::clone::Clone
222            + 'static,
223        T: HasChains<S>,
224    > ChainRunner<S> for T
225{
226}