general_mcmc/lib.rs
//! # General MCMC
2//!
3//! A compact Rust library offering **Markov Chain Monte Carlo (MCMC)** methods, including
4//! **No-U-Turn Sampler (NUTS)**, **Hamiltonian Monte Carlo (HMC)**, **Metropolis–Hastings**, and
5//! **Gibbs Sampling** for both discrete and continuous targets.
6//!
7//! ## Getting Started
8//!
9//! To use this library, add it to your project:
10//! ```bash
11//! cargo add general-mcmc
12//! ```
13//!
//! The library provides four main sampling approaches:
15//! 1. **No-U-Turn Sampler (NUTS)**: For continuous distributions with gradients. You need to provide:
16//! - A target distribution implementing the `GradientTarget` trait
17//! 2. **Hamiltonian Monte Carlo (HMC)**: For continuous distributions with gradients. You need to provide:
18//! - A target distribution implementing the `BatchedGradientTarget` trait
19//! 3. **Metropolis-Hastings**: For general-purpose sampling. You need to provide:
20//! - A target distribution implementing the `Target` trait
21//! - A proposal distribution implementing the `Proposal` trait
22//! 4. **Gibbs Sampling**: For sampling when conditional distributions are available. You need to provide:
23//! - A distribution implementing the `Conditional` trait
24//!
25//! ## Example 1: Sampling a 2D Rosenbrock (NUTS)
26//!
27//! ```rust
28//! use burn::backend::Autodiff;
29//! use burn::prelude::Tensor;
30//! use general_mcmc::core::init;
31//! use general_mcmc::distributions::Rosenbrock2D;
32//! use general_mcmc::nuts::NUTS;
33//!
34//! // CPU backend with autodiff (NdArray).
35//! type BackendType = Autodiff<burn::backend::NdArray>;
36//!
37//! // 2D Rosenbrock target (a=1, b=100).
38//! let target = Rosenbrock2D { a: 1.0_f32, b: 100.0_f32 };
39//!
//! // 4 independent chains in 2 dimensions, with starting points generated by `init`.
41//! let initial_positions = init::<f32>(4, 2);
42//!
//! // NUTS with a 0.95 target acceptance rate and a fixed seed.
44//! let mut sampler = NUTS::new(target, initial_positions, 0.95)
45//! .set_seed(42);
46//!
47//! // 400 burn-in + 400 samples.
48//! let n_collect = 400;
49//! let n_discard = 400;
50//!
51//! let sample: Tensor<BackendType, 3> = sampler.run(n_collect, n_discard);
52//!
53//! println!(
54//! "Collected {} chains × {} samples × 2 dims",
55//! sample.dims()[0],
56//! sample.dims()[1],
57//! );
58//! ```
59//!
60//! ## Example 2: Sampling a 3D Rosenbrock (HMC)
61//!
62//! ```rust
63//! use burn::tensor::Element;
64//! use burn::{backend::Autodiff, prelude::Tensor};
65//! use general_mcmc::hmc::HMC;
66//! use general_mcmc::distributions::BatchedGradientTarget;
67//! use general_mcmc::core::init;
68//! use num_traits::Float;
69//!
70//! /// The 3D Rosenbrock distribution.
71//! ///
//! /// For a point x = (x₁, x₂, x₃), the unnormalized log density is -f(x), where
73//! ///
74//! /// f(x) = 100*(x₂ - x₁²)² + (1 - x₁)² + 100*(x₃ - x₂²)² + (1 - x₂)².
75//! ///
76//! /// This implementation generalizes to d dimensions, but here we use it for 3D.
77//! #[derive(Clone)]
78//! struct RosenbrockND {}
79//!
80//! impl<T, B> BatchedGradientTarget<T, B> for RosenbrockND
81//! where
82//! T: Float + std::fmt::Debug + Element,
83//! B: burn::tensor::backend::AutodiffBackend,
84//! {
85//! fn unnorm_logp_batch(&self, positions: Tensor<B, 2>) -> Tensor<B, 1> {
86//! // Assume positions has shape [n_chains, d] with d = 3.
87//! let k = positions.dims()[0];
88//! let n = positions.dims()[1];
89//! let low = positions.clone().slice([0..k, 0..n-1]);
90//! let high = positions.clone().slice([0..k, 1..n]);
91//! let term_1 = (high - low.clone().powi_scalar(2))
92//! .powi_scalar(2)
93//! .mul_scalar(100);
94//! let term_2 = low.neg().add_scalar(1).powi_scalar(2);
95//! -(term_1 + term_2).sum_dim(1).squeeze(1)
96//! }
97//! }
98//!
99//! // Use the CPU backend wrapped in Autodiff (e.g., NdArray).
100//! type BackendType = Autodiff<burn::backend::NdArray>;
101//!
102//! // Create the 3D Rosenbrock target.
103//! let target = RosenbrockND {};
104//!
105//! // Define initial positions for 6 chains (each a 3D point).
106//! let initial_positions = init(6, 3);
107//!
//! // Create the HMC sampler with a step size of 0.032 and 5 leapfrog steps.
109//! let mut sampler = HMC::<f32, BackendType, RosenbrockND>::new(
110//! target,
111//! initial_positions,
112//! 0.032,
113//! 5,
114//! );
115//!
//! // Run the sampler for 123 + 45 iterations, discarding the first 45 as burn-in.
117//! let sample = sampler.run(123, 45);
118//!
119//! // Print the shape of the collected sample.
120//! println!("Collected sample with shape: {:?}", sample.dims());
121//! ```
122//!
123//!
//! See [`examples/minimal_nuts.rs`](examples/minimal_nuts.rs) for a full version of the NUTS example (Example 1) with diagnostics.
125//!
126//! ## Example 3: Sampling a 2D Gaussian (Metropolis–Hastings, Continuous)
127//!
128//! ```rust
129//! use general_mcmc::core::{ChainRunner, init};
130//! use general_mcmc::distributions::{Gaussian2D, IsotropicGaussian};
131//! use general_mcmc::metropolis_hastings::MetropolisHastings;
132//! use ndarray::{arr1, arr2};
133//!
134//! let target = Gaussian2D {
135//! mean: arr1(&[0.0, 0.0]),
136//! cov: arr2(&[[1.0, 0.0], [0.0, 1.0]]),
137//! };
138//! let proposal = IsotropicGaussian::new(1.0);
//! // Start 4 chains in 2 dimensions; `init` generates their starting points.
140//!
141//! let mut mh = MetropolisHastings::new(target, proposal, init(4, 2));
142//! let sample = mh.run(1000, 100).unwrap();
143//! println!("Metropolis–Hastings sample shape: {:?}", sample.shape());
144//! ```
145//!
146//! ## Example 4: Sampling a Poisson Distribution (Metropolis-Hastings, Discrete)
147//!
148//! ```rust
149//! use general_mcmc::core::{ChainRunner, init};
150//! use general_mcmc::distributions::{Proposal, Target};
151//! use general_mcmc::metropolis_hastings::MetropolisHastings;
152//! use rand::Rng;
153//!
154//! #[derive(Clone)]
155//! struct PoissonTarget {
156//! lambda: f64,
157//! }
158//!
159//! impl Target<usize, f64> for PoissonTarget {
160//! /// unnorm_logp(k) = log( p(k) ), ignoring normalizing constants if you wish.
161//! /// For Poisson(k|lambda) = exp(-lambda) * (lambda^k / k!)
162//! /// so log p(k) = -lambda + k*ln(lambda) - ln(k!)
163//! /// which is enough to do MH acceptance.
164//! fn unnorm_logp(&self, theta: &[usize]) -> f64 {
165//! let k = theta[0];
166//! -self.lambda + (k as f64) * self.lambda.ln() - ln_factorial(k as u64)
167//! }
168//! }
169//!
170//! #[derive(Clone)]
171//! struct NonnegativeProposal;
172//!
173//! impl Proposal<usize, f64> for NonnegativeProposal {
174//! fn sample(&mut self, current: &[usize]) -> Vec<usize> {
175//! let x = current[0];
176//! if x == 0 {
177//! vec![1]
178//! } else {
//! let step_up = rand::rng().random_bool(0.5);
180//! vec![if step_up { x + 1 } else { x - 1 }]
181//! }
182//! }
183//!
184//! fn logp(&self, from: &[usize], to: &[usize]) -> f64 {
185//! let (x, y) = (from[0], to[0]);
186//! if x == 0 && y == 1 {
187//! 0.0
188//! } else if x > 0 && (y == x + 1 || y == x - 1) {
189//! (0.5_f64).ln()
190//! } else {
191//! f64::NEG_INFINITY
192//! }
193//! }
194//!
195//! fn set_seed(self, _seed: u64) -> Self {
196//! self
197//! }
198//! }
199//!
200//! fn ln_factorial(k: u64) -> f64 {
201//! (1..=k).map(|v| (v as f64).ln()).sum()
202//! }
203//!
204//! let target = PoissonTarget { lambda: 4.0 };
205//! let proposal = NonnegativeProposal;
206//! let initial_state = vec![vec![0]];
207//!
208//! let mut mh = MetropolisHastings::new(target, proposal, initial_state);
209//! let sample = mh.run(5000, 100).unwrap();
210//! println!("Poisson sample shape: {:?}", sample.shape());
211//! ```
212//!
213//! For more complete implementations (including Gibbs sampling and I/O helpers),
214//! see the `examples/` directory.
215//!
216//! ## Features
217//! - **Parallel Chains** for improved throughput
218//! - **Progress Indicators** (acceptance rates, max Rhat, iteration counts)
//! - **Common Distributions** (e.g. Gaussian) plus easy traits for custom log-probabilities
220//! - **Optional I/O** (CSV, Arrow, Parquet) and GPU support (WGPU)
//! - **Effective Sample Size (ESS)** estimation following Stan's methodology
222//! - **R-hat Diagnostics** for convergence monitoring
223//!
224//! ## Roadmap
225//! - Rank-Normalized R-hat diagnostics
226//! - Ensemble Slice Sampling (ESS)
227
228pub mod batched_hmc;
229pub mod core;
230mod dev_tools;
231mod diag_mass;
232pub mod distributions;
233pub mod euclidean;
234pub mod generic_hmc;
235pub mod generic_nuts;
236pub mod gibbs;
237#[cfg(feature = "burn")]
238pub mod hmc;
239#[cfg(not(feature = "burn"))]
240pub mod hmc {
241 pub use crate::generic_hmc::{GenericHMC, HamiltonianTarget};
242}
243pub mod io;
244pub mod metropolis_hastings;
245#[cfg(feature = "burn")]
246pub mod nuts;
247#[cfg(not(feature = "burn"))]
248pub mod nuts {
249 pub use crate::generic_nuts::GenericNUTS;
250}
251pub mod stats;