mini_mcmc/lib.rs
1//! # Mini MCMC
2//!
3//! A compact Rust library offering **Markov Chain Monte Carlo (MCMC)** methods, including
4//! **No-U-Turn Sampler (NUTS)**, **Hamiltonian Monte Carlo (HMC)**, **Metropolis–Hastings**, and
5//! **Gibbs Sampling** for both discrete and continuous targets.
6//!
7//! ## Getting Started
8//!
9//! To use this library, add it to your project:
10//! ```bash
11//! cargo add mini-mcmc
12//! ```
13//!
//! The library provides four main sampling approaches:
15//! 1. **No-U-Turn Sampler (NUTS)**: For continuous distributions with gradients. You need to provide:
16//! - A target distribution implementing the `GradientTarget` trait
17//! 2. **Hamiltonian Monte Carlo (HMC)**: For continuous distributions with gradients. You need to provide:
18//! - A target distribution implementing the `BatchedGradientTarget` trait
19//! 3. **Metropolis-Hastings**: For general-purpose sampling. You need to provide:
20//! - A target distribution implementing the `Target` trait
21//! - A proposal distribution implementing the `Proposal` trait
22//! 4. **Gibbs Sampling**: For sampling when conditional distributions are available. You need to provide:
23//! - A distribution implementing the `Conditional` trait
24//!
25//! ## Example 1: Sampling a 2D Rosenbrock (NUTS)
26//!
27//! ```rust
28//! use burn::backend::Autodiff;
29//! use burn::prelude::Tensor;
30//! use mini_mcmc::core::init;
31//! use mini_mcmc::distributions::Rosenbrock2D;
32//! use mini_mcmc::nuts::NUTS;
33//!
34//! // CPU backend with autodiff (NdArray).
35//! type BackendType = Autodiff<burn::backend::NdArray>;
36//!
37//! // 2D Rosenbrock target (a=1, b=100).
38//! let target = Rosenbrock2D { a: 1.0_f32, b: 100.0_f32 };
39//!
//! // Initial positions for 4 independent chains in 2 dimensions.
41//! let initial_positions = init::<f32>(4, 2);
42//!
43//! // NUTS with 0.95 target‐accept and a fixed seed.
44//! let mut sampler = NUTS::new(target, initial_positions, 0.95)
45//! .set_seed(42);
46//!
47//! // 400 burn-in + 400 samples.
48//! let n_collect = 400;
49//! let n_discard = 400;
50//!
51//! let sample: Tensor<BackendType, 3> = sampler.run(n_collect, n_discard);
52//!
53//! println!(
54//! "Collected {} chains × {} samples × 2 dims",
55//! sample.dims()[0],
56//! sample.dims()[1],
57//! );
58//! ```
59//!
60//! ## Example 2: Sampling a 3D Rosenbrock (HMC)
61//!
62//! ```rust
63//! use burn::tensor::Element;
64//! use burn::{backend::Autodiff, prelude::Tensor};
65//! use mini_mcmc::hmc::HMC;
66//! use mini_mcmc::distributions::BatchedGradientTarget;
67//! use mini_mcmc::core::init;
68//! use num_traits::Float;
69//!
70//! /// The 3D Rosenbrock distribution.
71//! ///
//! /// For a point x = (x₁, x₂, x₃), the unnormalized log density is -f(x), where
//! ///
//! /// f(x) = 100*(x₂ - x₁²)² + (1 - x₁)² + 100*(x₃ - x₂²)² + (1 - x₂)².
75//! ///
76//! /// This implementation generalizes to d dimensions, but here we use it for 3D.
77//! struct RosenbrockND {}
78//!
79//! impl<T, B> BatchedGradientTarget<T, B> for RosenbrockND
80//! where
81//! T: Float + std::fmt::Debug + Element,
82//! B: burn::tensor::backend::AutodiffBackend,
83//! {
84//! fn unnorm_logp_batch(&self, positions: Tensor<B, 2>) -> Tensor<B, 1> {
85//! // Assume positions has shape [n_chains, d] with d = 3.
86//! let k = positions.dims()[0];
87//! let n = positions.dims()[1];
88//! let low = positions.clone().slice([0..k, 0..n-1]);
89//! let high = positions.clone().slice([0..k, 1..n]);
90//! let term_1 = (high - low.clone().powi_scalar(2))
91//! .powi_scalar(2)
92//! .mul_scalar(100);
93//! let term_2 = low.neg().add_scalar(1).powi_scalar(2);
94//! -(term_1 + term_2).sum_dim(1).squeeze(1)
95//! }
96//! }
97//!
98//! // Use the CPU backend wrapped in Autodiff (e.g., NdArray).
99//! type BackendType = Autodiff<burn::backend::NdArray>;
100//!
101//! // Create the 3D Rosenbrock target.
102//! let target = RosenbrockND {};
103//!
104//! // Define initial positions for 6 chains (each a 3D point).
105//! let initial_positions = init(6, 3);
106//!
//! // Create the HMC sampler with a step size of 0.032 and 5 leapfrog steps.
108//! let mut sampler = HMC::<f32, BackendType, RosenbrockND>::new(
109//! target,
110//! initial_positions,
111//! 0.032,
112//! 5,
113//! );
114//!
115//! // Run the sampler for 123+45 iterations, discard 45 burnin observations
116//! let sample = sampler.run(123, 45);
117//!
118//! // Print the shape of the collected sample.
119//! println!("Collected sample with shape: {:?}", sample.dims());
120//! ```
121//!
122//!
123//! See [`examples/minimal_nuts.rs`](examples/minimal_nuts.rs) for the full version with diagnostics.
124//!
125//! ## Example 3: Sampling a 2D Gaussian (Metropolis–Hastings, Continuous)
126//!
127//! ```rust
128//! use mini_mcmc::core::{ChainRunner, init};
129//! use mini_mcmc::distributions::{Gaussian2D, IsotropicGaussian};
130//! use mini_mcmc::metropolis_hastings::MetropolisHastings;
131//! use ndarray::{arr1, arr2};
132//!
133//! let target = Gaussian2D {
134//! mean: arr1(&[0.0, 0.0]),
135//! cov: arr2(&[[1.0, 0.0], [0.0, 1.0]]),
136//! };
137//! let proposal = IsotropicGaussian::new(1.0);
//!
140//! let mut mh = MetropolisHastings::new(target, proposal, init(4, 2));
141//! let sample = mh.run(1000, 100).unwrap();
142//! println!("Metropolis–Hastings sample shape: {:?}", sample.shape());
143//! ```
144//!
145//! ## Example 4: Sampling a Poisson Distribution (Metropolis-Hastings, Discrete)
146//!
147//! ```rust
148//! use mini_mcmc::core::{ChainRunner, init};
149//! use mini_mcmc::distributions::{Proposal, Target};
150//! use mini_mcmc::metropolis_hastings::MetropolisHastings;
151//! use rand::Rng;
152//!
153//! #[derive(Clone)]
154//! struct PoissonTarget {
155//! lambda: f64,
156//! }
157//!
158//! impl Target<usize, f64> for PoissonTarget {
159//! /// unnorm_logp(k) = log( p(k) ), ignoring normalizing constants if you wish.
160//! /// For Poisson(k|lambda) = exp(-lambda) * (lambda^k / k!)
161//! /// so log p(k) = -lambda + k*ln(lambda) - ln(k!)
162//! /// which is enough to do MH acceptance.
163//! fn unnorm_logp(&self, theta: &[usize]) -> f64 {
164//! let k = theta[0];
165//! -self.lambda + (k as f64) * self.lambda.ln() - ln_factorial(k as u64)
166//! }
167//! }
168//!
169//! #[derive(Clone)]
170//! struct NonnegativeProposal;
171//!
172//! impl Proposal<usize, f64> for NonnegativeProposal {
173//! fn sample(&mut self, current: &[usize]) -> Vec<usize> {
174//! let x = current[0];
175//! if x == 0 {
176//! vec![1]
177//! } else {
//! let step_up = rand::rng().random_bool(0.5);
179//! vec![if step_up { x + 1 } else { x - 1 }]
180//! }
181//! }
182//!
183//! fn logp(&self, from: &[usize], to: &[usize]) -> f64 {
184//! let (x, y) = (from[0], to[0]);
185//! if x == 0 && y == 1 {
186//! 0.0
187//! } else if x > 0 && (y == x + 1 || y == x - 1) {
188//! (0.5_f64).ln()
189//! } else {
190//! f64::NEG_INFINITY
191//! }
192//! }
193//!
194//! fn set_seed(self, _seed: u64) -> Self {
195//! self
196//! }
197//! }
198//!
199//! fn ln_factorial(k: u64) -> f64 {
200//! (1..=k).map(|v| (v as f64).ln()).sum()
201//! }
202//!
203//! let target = PoissonTarget { lambda: 4.0 };
204//! let proposal = NonnegativeProposal;
205//! let initial_state = vec![vec![0]];
206//!
207//! let mut mh = MetropolisHastings::new(target, proposal, initial_state);
208//! let sample = mh.run(5000, 100).unwrap();
209//! println!("Poisson sample shape: {:?}", sample.shape());
210//! ```
211//!
212//! For more complete implementations (including Gibbs sampling and I/O helpers),
213//! see the `examples/` directory.
214//!
215//! ## Features
216//! - **Parallel Chains** for improved throughput
217//! - **Progress Indicators** (acceptance rates, max Rhat, iteration counts)
218//! - **Common Distributions** (e.g. Gaussian) plus easy traits for custom log‐prob
219//! - **Optional I/O** (CSV, Arrow, Parquet) and GPU support (WGPU)
220//! - **Effective Sample Size (ESS)** estimation following STAN's methodology
221//! - **R-hat Diagnostics** for convergence monitoring
222//!
223//! ## Roadmap
224//! - Rank-Normalized R-hat diagnostics
225//! - Ensemble Slice Sampling (ESS)
226
/// Core utilities shared by the samplers: chain initialization via `init`
/// and the `ChainRunner` trait used to run multiple chains.
pub mod core;
// Internal development helpers; not part of the public API.
mod dev_tools;
/// Distribution traits the samplers consume (`Target`, `Proposal`,
/// `Conditional`, `BatchedGradientTarget`) plus ready-made targets such as
/// `Gaussian2D`, `IsotropicGaussian`, and `Rosenbrock2D`.
pub mod distributions;
/// Gibbs sampling for targets with available conditional distributions.
pub mod gibbs;
/// Hamiltonian Monte Carlo (HMC) for continuous targets with gradients.
pub mod hmc;
/// Optional I/O helpers (CSV, Arrow, Parquet).
pub mod io;
/// Metropolis–Hastings sampling for both continuous and discrete targets.
pub mod metropolis_hastings;
/// The No-U-Turn Sampler (NUTS) for continuous targets with gradients.
pub mod nuts;
/// Sampling diagnostics: effective sample size (ESS) and R-hat.
pub mod stats;