# Mini MCMC
A compact Rust library for Markov Chain Monte Carlo (MCMC) methods with GPU support.
## Installation
To add the latest version of mini-mcmc to your project:
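```sh
cargo add mini-mcmc
```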
Then use `mini_mcmc` in your Rust code.
## Example: Sampling From a 2D Gaussian
```rust
// NOTE: the module paths below are reconstructed and may differ slightly from
// the crate; see examples/minimal_mh.rs for the exact imports.
use mini_mcmc::core::ChainRunner;
use mini_mcmc::distributions::{Gaussian2D, IsotropicGaussian};
use mini_mcmc::metropolis_hastings::MetropolisHastings;
use mini_mcmc::stats::RunStats;
use ndarray::{arr1, arr2};
```
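The remainder of the example, as a sketch: the constructor signatures, the mean/covariance arguments, and the shape of the returned samples are assumptions here, not verbatim API (the exact code is in `examples/minimal_mh.rs`):

```rust
fn main() {
    // Target: a 2D Gaussian with the given mean and covariance.
    let target = Gaussian2D::new(arr1(&[0.0, 0.0]), arr2(&[[4.0, 2.0], [2.0, 3.0]]));

    // Proposal: an isotropic Gaussian random walk with standard deviation 1.0.
    let proposal = IsotropicGaussian::new(1.0);

    // Four parallel chains, all started at the origin (signature assumed).
    let mut mh = MetropolisHastings::new(target, proposal, &[0.0, 0.0], 4);

    // 1,000 iterations per chain, discarding the first 100 as burn-in.
    let (samples, stats) = mh.run_progress(1_000, 100).unwrap();
    println!("Sample array shape: {:?}", samples.shape());
    println!("{:?}", stats);
}
```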
You can also find this example at `examples/minimal_mh.rs`.
## Example: Sampling From a Custom Distribution
Below we define a custom Poisson distribution for nonnegative integer states $\{0,1,2,\dots\}$ and a basic random-walk proposal. We then run Metropolis–Hastings to sample from this distribution, collecting the frequencies of $k$ after some burn-in:
```rust
// NOTE: the module paths below are reconstructed; see examples/poisson_mh.rs
// for the exact imports.
use mini_mcmc::core::ChainRunner;
use mini_mcmc::distributions::{Proposal, Target};
use mini_mcmc::metropolis_hastings::MetropolisHastings;
use rand::rngs::ThreadRng;
use rand::Rng; // for thread_rng
use std::error::Error;
```
```rust
/// A Poisson(λ) distribution, seen as a discrete target over k = 0, 1, 2, ...
struct PoissonTarget { lambda: f64 }

/// A simple random-walk proposal on the nonnegative integers:
/// - If current_state == 0, propose 0 -> 1 always.
/// - Otherwise propose x -> x+1 or x -> x-1 with p = 0.5 each.
struct NonnegativeProposal { rng: ThreadRng }

// The Target/Proposal impls, a small helper for computing ln(k!), and a helper
// for the Poisson PMF are elided here (struct shapes assumed); see
// examples/poisson_mh.rs for the complete version.
```
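A sketch of the driver code follows; the constructor signature and the tallying step are abridged and assumed, not verbatim API (see `examples/poisson_mh.rs` for the exact version):

```rust
fn main() -> Result<(), Box<dyn Error>> {
    let target = PoissonTarget { lambda: 4.0 };
    let proposal = NonnegativeProposal { rng: rand::thread_rng() };

    // One chain, started at k = 0 (constructor signature assumed).
    let mut mh = MetropolisHastings::new(target, proposal, &[0_usize], 1);

    // 11,000 iterations, discarding the first 1,000 as burn-in.
    let (_samples, _stats) = mh.run_progress(11_000, 1_000)?;

    // Tally how often each k in 0..=20 occurs; the histogram should
    // approximate the Poisson(4.0) PMF. (Tallying code elided.)
    Ok(())
}
```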
You can also find this example at `examples/poisson_mh.rs`.
### Explanation
- `PoissonTarget` implements `Target<usize, f64>` for a discrete Poisson($\lambda$) distribution:

  $$p(k) = e^{-\lambda} \frac{\lambda^k}{k!}, \quad k = 0, 1, 2, \ldots$$

  In log form, $\log p(k) = -\lambda + k \log \lambda - \log k!$ (see the sketch after this list).
- `NonnegativeProposal` provides a random walk on the set $\{0, 1, 2, \dots\}$:
  - If $x = 0$, propose $1$ with probability $1$.
  - If $x > 0$, propose $x + 1$ or $x - 1$ with probability $0.5$ each.
  - `logp` returns $\ln(0.5)$ for possible moves and $-\infty$ for impossible moves.
- Usage: we start the chain at $k = 0$, run 11,000 iterations discarding the first 1,000 as burn-in, and tally the frequencies of $k = 0, \dots, 20$ among the retained samples. These should approximate the Poisson(4.0) distribution (peaking around $k = 4$).
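For concreteness, here is a standalone sketch of that log-PMF; `ln_factorial` and `poisson_logp` are hypothetical helpers mirroring the "ln(k!)" and PMF comments in the example code:

```rust
/// ln(k!) computed as a sum of logarithms; accurate enough for moderate k.
fn ln_factorial(k: usize) -> f64 {
    (1..=k).map(|i| (i as f64).ln()).sum()
}

/// log p(k) = -lambda + k * ln(lambda) - ln(k!) for a Poisson(lambda) target.
fn poisson_logp(k: usize, lambda: f64) -> f64 {
    -lambda + (k as f64) * lambda.ln() - ln_factorial(k)
}
```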
With this example, you can see how to use `mini_mcmc` for unbounded discrete distributions via a custom random-walk proposal and a log-PMF.
## Example: Sampling From a 3D Rosenbrock Distribution Using HMC
The following minimal example demonstrates how to create and run an HMC sampler to sample from a 3D Rosenbrock distribution. We construct the sampler, run it for a fixed number of iterations, and print the shape of the collected samples. The corresponding file can also be found at `examples/minimal_hmc.rs`. For a complete example, including interactive 3D plotting with Plotly, refer to `examples/rosenbrock3d_hmc.rs`.
```rust
// NOTE: the module paths below are reconstructed; see examples/minimal_hmc.rs
// for the exact imports.
use burn::backend::{Autodiff, NdArray};
use burn::tensor::Element;
use mini_mcmc::core::init_det;
use mini_mcmc::distributions::GradientTarget;
use mini_mcmc::hmc::HMC;
use num_traits::Float;
```
```rust
/// The 3D Rosenbrock distribution.
///
/// For a point x = (x₁, x₂, x₃), the log probability is defined as the
/// negative of the sum of two Rosenbrock terms:
///
/// f(x) = 100*(x₂ - x₁²)² + (1 - x₁)² + 100*(x₃ - x₂²)² + (1 - x₂)²
///
/// This implementation generalizes to d dimensions, but here we use it for 3D.
struct RosenbrockND {}

// The GradientTarget impl and the main function are elided here (struct name
// assumed); see examples/minimal_hmc.rs for the complete, runnable version.
```
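Independent of the crate's tensor types, the log density itself is easy to state. A plain-`f64` sketch (a standalone illustration, not the crate's API):

```rust
/// f(x): the sum of Rosenbrock terms from the doc comment above,
/// generalized to d dimensions via a sliding window over coordinate pairs.
fn rosenbrock_f(x: &[f64]) -> f64 {
    x.windows(2)
        .map(|w| 100.0 * (w[1] - w[0] * w[0]).powi(2) + (1.0 - w[0]).powi(2))
        .sum()
}

/// The (unnormalized) log probability is the negative of f.
fn rosenbrock_logp(x: &[f64]) -> f64 {
    -rosenbrock_f(x)
}
```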
## Overview

This library provides implementations of:
- Hamiltonian Monte Carlo (HMC): an MCMC method that efficiently samples by simulating Hamiltonian dynamics using gradients of the target distribution.
- Metropolis-Hastings: an MCMC algorithm that samples from a distribution by proposing candidates and probabilistically accepting or rejecting them.
- Gibbs Sampling: an MCMC method that iteratively samples each variable from its conditional distribution given all other variables.
Additional features:
- Implementations of Common Distributions: featuring handy Gaussian and isotropic Gaussian implementations, along with traits for defining custom log-prob functions.
- Parallelization: for running multiple Markov chains in parallel.
- Progress Bars: show the progress of MCMC runs along with convergence statistics and acceptance rates.
- Support for Discrete & Continuous Distributions: for example, the Metropolis-Hastings and Gibbs samplers can sample from both continuous and discrete target distributions.
- Generic Datatypes: enable sampling of vectors with various integer or floating point types.
- Standard Convergence Diagnostics: estimate the effective sample size (ESS) and R-hat.
## Roadmap
- No-U-Turn Sampler (NUTS): An extension of HMC that removes the need to choose path lengths.
- Rank-Normalized R-hat: Modern convergence diagnostic, see paper.
- Ensemble Slice Sampling (ESS): Efficient gradient-free sampler, see paper.
- Effective Sample Size Estimation: Online estimation of effective sample size for early stopping.
## Structure

- `src/lib.rs`: The main library entry point; exports the MCMC functionality.
- `src/distributions.rs`: Target distributions (e.g., multivariate Gaussians) and proposal distributions.
- `src/metropolis_hastings.rs`: The Metropolis-Hastings algorithm implementation.
- `src/gibbs.rs`: The Gibbs sampling algorithm implementation.
- `examples/`: Examples of how to use this library.
- `src/io/arrow.rs`: Helper functions for saving samples as Apache Arrow files. Enable via the `arrow` feature.
- `src/io/parquet.rs`: Helper functions for saving samples as Apache Parquet files. Enable via the `parquet` feature.
- `src/io/csv.rs`: Helper functions for saving samples as CSV files. Enable via the `csv` feature.
## Usage (Local)
- Build (Library + Demo):
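  ```sh
  cargo build --release
  ```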
- Run the Demo:
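  ```sh
  # Invocation sketch; feature flags (e.g., --features parquet) may be needed.
  cargo run --release
  ```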
  Prints basic statistics of the MCMC chain (e.g., the estimated mean) and saves a scatter plot of the sampled points to `scatter_plot.png` plus a Parquet file `samples.parquet`.
## Optional Features
- `csv`: Enables CSV I/O for samples.
- `arrow` / `parquet`: Enables Apache Arrow / Parquet I/O.
- `wgpu`: Enables GPU-accelerated sampling for gradient-based samplers using burn's WGPU backend. In the HMC example above, you only have to replace the line defining the backend

  ```rust
  // Exact type paths are reconstructed here and may differ slightly.
  type BackendType = burn::backend::Autodiff<burn::backend::NdArray>;
  ```

  with

  ```rust
  type BackendType = burn::backend::Autodiff<burn::backend::Wgpu>;
  ```

  Depending on the number of parallel chains, the dimensionality of your sample space, and the cost of evaluating the unnormalized log density of your target distribution, it may be more efficient to stick with the CPU (`NdArray`) backend.
- By default, all features are disabled.
## Progress Tracking and Statistics
The `run_progress` method returns both the samples and a convergence-statistics object, and it displays progress bars during the run.
```rust
// Sketch; the iteration/burn-in arguments are illustrative and the exact
// signature is assumed.
let (samples, stats) = mh.run_progress(1_000, 100).unwrap();
println!("{:?}", stats);
```
The `RunStats` object contains statistics about:
- Potential scale reduction factor (R-hat)
- Effective sample size (ESS)
and potentially further metrics in the future.
## License
Licensed under the Apache License, Version 2.0. See LICENSE for details.
This project includes code from the kolmogorov_smirnov project, licensed under Apache 2.0 as noted in NOTICE.