mini_mcmc/core.rs
1/*!
2# Core MCMC Utilities
3
4This module provides core functionality for running Markov Chain Monte Carlo (MCMC) chains in parallel.
5It includes:
6- The [`MarkovChain<S>`] trait, which abstracts a single MCMC chain.
7- Utility functions [`run_chain`] and [`run_chain_with_progress`] for executing a single chain and collecting its states.
8- The [`HasChains<S>`] trait for types that own multiple Markov chains.
9- The [`ChainRunner<S>`] trait that extends `HasChains<S>` with methods to run chains in parallel (using Rayon), discarding burn-in and optionally displaying progress bars.
10
11Any type implementing `HasChains<S>` (with the required trait bounds) automatically implements `ChainRunner<S>` via a blanket implementation.
12
This module is generic over the state element type `S`, which must implement [`num_traits::Zero`] and nalgebra's `Scalar`.
14*/
15
16use indicatif::ProgressBar;
17use indicatif::{MultiProgress, ProgressStyle};
18use nalgebra as na;
19use num_traits::Zero;
20use rayon::prelude::*;
21
/// Abstraction over a single Markov chain whose state is a vector of `S`.
///
/// Implementors expose two operations:
/// - [`MarkovChain::current_state`]: borrow the state as it is right now.
/// - [`MarkovChain::step`]: advance by one iteration and borrow the updated state.
pub trait MarkovChain<S> {
    /// Borrows the chain's current state; the chain is not advanced.
    fn current_state(&self) -> &Vec<S>;

    /// Advances the chain by one iteration and borrows the resulting state.
    fn step(&mut self) -> &Vec<S>;
}
34
35/// Runs a single MCMC chain for a specified number of steps.
36///
37/// This function repeatedly calls the chain's `step()` method and collects each state into a
38/// [`nalgebra::DMatrix`], where each row corresponds to one iteration of the chain.
39///
40/// # Arguments
41///
42/// * `chain` - A mutable reference to an object implementing [`MarkovChain<S>`].
43/// * `n_steps` - The total number of iterations to run.
44///
45/// # Returns
46///
47/// A [`nalgebra::DMatrix<S>`] where the number of rows equals `n_steps` and the number of columns equals
48/// the dimensionality of the chain's state.
49pub fn run_chain<S, M>(chain: &mut M, n_steps: usize) -> na::DMatrix<S>
50where
51 M: MarkovChain<S>,
52 S: Clone + na::Scalar + Zero,
53{
54 let dim = chain.current_state().len();
55 let mut out = na::DMatrix::<S>::zeros(n_steps, dim);
56
57 for i in 0..n_steps {
58 let state = chain.step();
59 out.row_mut(i).copy_from_slice(state);
60 }
61
62 out
63}
64
65/// Runs a single MCMC chain for a specified number of steps while displaying progress.
66///
67/// This function is similar to [`run_chain`], but it accepts an [`indicatif::ProgressBar`]
68/// that is updated as the chain advances.
69///
70/// # Arguments
71///
72/// * `chain` - A mutable reference to an object implementing [`MarkovChain<S>`].
73/// * `n_steps` - The total number of iterations to run.
74/// * `pb` - A progress bar used to display progress.
75///
76/// # Returns
77///
78/// A [`nalgebra::DMatrix<S>`] containing the chain's states (one row per iteration).
79pub fn run_chain_with_progress<S, M>(
80 chain: &mut M,
81 n_steps: usize,
82 pb: &ProgressBar,
83) -> na::DMatrix<S>
84where
85 M: MarkovChain<S>,
86 S: Clone + na::Scalar + Zero,
87{
88 let dim = chain.current_state().len();
89 let mut out = na::DMatrix::<S>::zeros(n_steps, dim);
90
91 pb.set_length(n_steps as u64);
92
93 for i in 0..n_steps {
94 let state = chain.step();
95 out.row_mut(i).copy_from_slice(state);
96
97 // Update progress bar
98 pb.inc(1);
99 }
100
101 out
102}
103
104/// A trait for types that own multiple MCMC chains.
105///
106/// - `S` is the type of the state elements (e.g., `f64`).
107/// - `Chain` is the concrete type of the individual chain, which must implement [`MarkovChain<S>`]
108/// and be `Send`.
109///
110/// Implementors must provide a method to access the internal vector of chains.
111pub trait HasChains<S> {
112 type Chain: MarkovChain<S> + std::marker::Send;
113
114 /// Returns a mutable reference to the vector of chains.
115 fn chains_mut(&mut self) -> &mut Vec<Self::Chain>;
116}
117
118/// An extension trait for types that own multiple MCMC chains.
119///
120/// `ChainRunner<S>` extends [`HasChains<S>`] by providing default methods to run all chains
121/// in parallel using Rayon. These methods allow you to:
122/// - Run all chains for a specified number of iterations and discard an initial burn-in period.
123/// - Optionally display progress bars for each chain during execution.
124///
125/// Any type that implements [`HasChains<S>`] (with appropriate bounds on `S`) automatically implements
126/// `ChainRunner<S>`.
127pub trait ChainRunner<S>: HasChains<S>
128where
129 S: std::clone::Clone
130 + num_traits::Zero
131 + std::marker::Send
132 + std::cmp::PartialEq
133 + std::marker::Sync
134 + std::fmt::Debug
135 + 'static,
136{
137 /// Runs all chains in parallel, discarding the first `discard` iterations (burn-in).
138 ///
139 /// # Arguments
140 ///
141 /// * `n_steps` - The total number of iterations to run for each chain.
142 /// * `discard` - The number of initial iterations to discard from each chain.
143 ///
144 /// # Returns
145 ///
146 /// A vector of [`nalgebra::DMatrix<S>`] matrices, one for each chain, containing the samples
147 /// after burn-in.
148 fn run(&mut self, n_steps: usize, discard: usize) -> Vec<na::DMatrix<S>> {
149 // Run them all in parallel
150 let results: Vec<na::DMatrix<S>> = self
151 .chains_mut()
152 .par_iter_mut()
153 .map(|chain| run_chain(chain, n_steps))
154 .collect();
155
156 // Now discard the burn-in rows from each matrix
157 results
158 .into_iter()
159 .map(|mat| {
160 let nrows = mat.nrows();
161 let keep = nrows - discard;
162 mat.rows(discard, keep).into()
163 })
164 .collect()
165 }
166
167 /// Runs all chains in parallel with progress bars, discarding the burn-in.
168 ///
169 /// Each chain is run concurrently with its own progress bar. After execution, the first `discard`
170 /// iterations are discarded.
171 ///
172 /// # Arguments
173 ///
174 /// * `n_steps` - The total number of iterations to run for each chain.
175 /// * `discard` - The number of initial iterations to discard.
176 ///
177 /// # Returns
178 ///
179 /// A vector of sample matrices (one per chain) containing only the samples after burn-in.
180 fn run_with_progress(&mut self, n_steps: usize, discard: usize) -> Vec<na::DMatrix<S>> {
181 let multi = MultiProgress::new();
182 let pb_style = ProgressStyle::default_bar()
183 .template("{prefix} [{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")
184 .unwrap()
185 .progress_chars("##-");
186
187 // Run each chain in parallel
188 let results: Vec<(Vec<S>, na::DMatrix<S>)> = self
189 .chains_mut()
190 .par_iter_mut()
191 .enumerate()
192 .map(|(i, chain)| {
193 let pb = multi.add(ProgressBar::new(n_steps as u64));
194 pb.set_prefix(format!("Chain {i}"));
195 pb.set_style(pb_style.clone());
196
197 let samples = run_chain_with_progress(chain, n_steps, &pb);
198
199 pb.finish_with_message("Done!");
200
201 (chain.current_state().clone(), samples)
202 })
203 .collect();
204
205 results
206 .into_par_iter()
207 .map(|(_, samples)| {
208 let keep_rows = samples.nrows().saturating_sub(discard);
209 samples.rows(discard, keep_rows).into()
210 })
211 .collect()
212 }
213}
214
215impl<
216 S: std::fmt::Debug
217 + std::marker::Sync
218 + std::cmp::PartialEq
219 + std::marker::Send
220 + num_traits::Zero
221 + std::clone::Clone
222 + 'static,
223 T: HasChains<S>,
224 > ChainRunner<S> for T
225{
226}