mini_mcmc/lib.rs

//! # Mini MCMC
//!
//! A compact Rust library offering **Markov Chain Monte Carlo (MCMC)** methods,
//! including **Hamiltonian Monte Carlo (HMC)**, **Metropolis–Hastings**, and
//! **Gibbs Sampling** for both discrete and continuous targets.
//!
//! ## Example 1: Sampling a 2D Gaussian (Metropolis–Hastings)
//!
//! ```rust
//! use mini_mcmc::core::{ChainRunner, init};
//! use mini_mcmc::distributions::{Gaussian2D, IsotropicGaussian};
//! use mini_mcmc::metropolis_hastings::MetropolisHastings;
//! use ndarray::{arr1, arr2};
//!
//! // Standard 2D Gaussian target with an isotropic Gaussian proposal.
//! let target = Gaussian2D {
//!     mean: arr1(&[0.0, 0.0]),
//!     cov: arr2(&[[1.0, 0.0], [0.0, 1.0]]),
//! };
//! let proposal = IsotropicGaussian::new(1.0);
//!
//! // Run 4 chains of 2D states: keep 1000 samples per chain after
//! // 100 burn-in iterations.
//! let mut mh = MetropolisHastings::new(target, proposal, init(4, 2));
//! let samples = mh.run(1000, 100).unwrap();
//! println!("Metropolis–Hastings samples shape: {:?}", samples.shape());
//! ```
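//!
//! A natural next step is to pool the chains and estimate the target mean.
//! Below is a minimal sketch using a zero-filled stand-in array, assuming a
//! `[n_chains, n_steps, dim]` layout (the sampler's actual return shape may
//! differ):
//!
//! ```rust
//! use ndarray::{Array3, Axis};
//!
//! // Stand-in for sampler output: 4 chains, 1000 kept steps, 2 dimensions.
//! let samples = Array3::<f64>::zeros((4, 1000, 2));
//!
//! let mean = samples
//!     .mean_axis(Axis(0)) // average over chains -> [1000, 2]
//!     .unwrap()
//!     .mean_axis(Axis(0)) // average over steps  -> [2]
//!     .unwrap();
//! println!("Estimated mean: {:?}", mean);
//! ```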
//!
//! ## Example 2: Sampling a 3D Rosenbrock (HMC)
//!
//! ```rust
//! use burn::tensor::Element;
//! use burn::{backend::Autodiff, prelude::Tensor};
//! use mini_mcmc::hmc::HMC;
//! use mini_mcmc::distributions::GradientTarget;
//! use mini_mcmc::core::init;
//! use num_traits::Float;
//!
//! /// The 3D Rosenbrock distribution.
//! ///
//! /// For a point x = (x₁, x₂, x₃), the log probability is -f(x), where f is
//! /// the sum of two Rosenbrock terms:
//! ///
//! ///   f(x) = 100*(x₂ - x₁²)² + (1 - x₁)² + 100*(x₃ - x₂²)² + (1 - x₂)²
//! ///
//! /// This implementation generalizes to d dimensions; here we use it in 3D.
//! struct RosenbrockND {}
//!
//! impl<T, B> GradientTarget<T, B> for RosenbrockND
//! where
//!     T: Float + std::fmt::Debug + Element,
//!     B: burn::tensor::backend::AutodiffBackend,
//! {
//!     fn log_prob_batch(&self, positions: Tensor<B, 2>) -> Tensor<B, 1> {
//!         // Assume positions has shape [n_chains, d] with d = 3.
//!         let k = positions.dims()[0] as i64;
//!         let n = positions.dims()[1] as i64;
//!         // Split into x₁..x_{d-1} ("low") and x₂..x_d ("high") so each
//!         // Rosenbrock term pairs consecutive coordinates.
//!         let low = positions.clone().slice([(0, k), (0, n - 1)]);
//!         let high = positions.clone().slice([(0, k), (1, n)]);
//!         let term_1 = (high - low.clone().powi_scalar(2))
//!             .powi_scalar(2)
//!             .mul_scalar(100);
//!         let term_2 = low.neg().add_scalar(1).powi_scalar(2);
//!         // Negate the summed terms to turn the cost into a log-density.
//!         -(term_1 + term_2).sum_dim(1).squeeze(1)
//!     }
//! }
//!
//! // Use the CPU backend wrapped in Autodiff (e.g., NdArray).
//! type BackendType = Autodiff<burn::backend::NdArray>;
//!
//! // Create the 3D Rosenbrock target.
//! let target = RosenbrockND {};
//!
//! // Define initial positions for 6 chains (each a 3D point).
//! let initial_positions = init(6, 3);
//!
//! // Create the HMC sampler with a step size of 0.032 and 5 leapfrog steps.
//! let mut sampler = HMC::<f32, BackendType, RosenbrockND>::new(
//!     target,
//!     initial_positions,
//!     0.032,
//!     5,
//! );
//!
//! // Run the sampler for 123 + 45 iterations, discarding the first 45 as
//! // burn-in and keeping 123 samples per chain.
//! let samples = sampler.run(123, 45);
//!
//! // Print the shape of the collected samples.
//! println!("Collected samples with shape: {:?}", samples.dims());
//! ```
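//!
//! When porting a density to tensor code like the above, it helps to keep a
//! scalar reference implementation around for sanity checks. A minimal sketch
//! of the same d-dimensional Rosenbrock log-density in plain Rust (the
//! function name is ours, not part of the crate):
//!
//! ```rust
//! /// Negative Rosenbrock cost, pairing consecutive coordinates.
//! fn rosenbrock_log_prob(x: &[f64]) -> f64 {
//!     -x.windows(2)
//!         .map(|w| 100.0 * (w[1] - w[0] * w[0]).powi(2) + (1.0 - w[0]).powi(2))
//!         .sum::<f64>()
//! }
//!
//! // The mode sits at (1, 1, 1), where the log-density is 0.
//! assert_eq!(rosenbrock_log_prob(&[1.0, 1.0, 1.0]), 0.0);
//! // At the origin both terms contribute (1 - 0)² = 1, so log p = -2.
//! assert_eq!(rosenbrock_log_prob(&[0.0, 0.0, 0.0]), -2.0);
//! ```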
//!
//! ## Example 3: Sampling a Poisson Distribution (Discrete)
//!
//! ```rust
//! use mini_mcmc::core::ChainRunner;
//! use mini_mcmc::distributions::{Proposal, Target};
//! use mini_mcmc::metropolis_hastings::MetropolisHastings;
//! use rand::Rng;
//!
//! #[derive(Clone)]
//! struct PoissonTarget {
//!     lambda: f64,
//! }
//!
//! impl Target<usize, f64> for PoissonTarget {
//!     /// Unnormalized log-probability of k.
//!     /// Since Poisson(k|lambda) = exp(-lambda) * (lambda^k / k!),
//!     /// log p(k) = -lambda + k*ln(lambda) - ln(k!),
//!     /// which is all MH needs to compute acceptance probabilities.
//!     fn unnorm_log_prob(&self, theta: &[usize]) -> f64 {
//!         let k = theta[0];
//!         -self.lambda + (k as f64) * self.lambda.ln() - ln_factorial(k as u64)
//!     }
//! }
//!
//! /// A random walk on the nonnegative integers: from x > 0 move to x - 1 or
//! /// x + 1 with equal probability; from 0 always move to 1.
//! #[derive(Clone)]
//! struct NonnegativeProposal;
//!
//! impl Proposal<usize, f64> for NonnegativeProposal {
//!     fn sample(&mut self, current: &[usize]) -> Vec<usize> {
//!         let x = current[0];
//!         if x == 0 {
//!             vec![1]
//!         } else {
//!             let step_up = rand::thread_rng().gen_bool(0.5);
//!             vec![if step_up { x + 1 } else { x - 1 }]
//!         }
//!     }
//!
//!     fn log_prob(&self, from: &[usize], to: &[usize]) -> f64 {
//!         let (x, y) = (from[0], to[0]);
//!         if x == 0 && y == 1 {
//!             0.0 // the move 0 -> 1 has probability 1
//!         } else if x > 0 && (y == x + 1 || y == x - 1) {
//!             (0.5_f64).ln()
//!         } else {
//!             f64::NEG_INFINITY
//!         }
//!     }
//!
//!     fn set_seed(self, _seed: u64) -> Self {
//!         self
//!     }
//! }
//!
//! /// ln(k!) computed as a sum of logs (fine for the small k used here).
//! fn ln_factorial(k: u64) -> f64 {
//!     (1..=k).map(|v| (v as f64).ln()).sum()
//! }
//!
//! let target = PoissonTarget { lambda: 4.0 };
//! let proposal = NonnegativeProposal;
//! let initial_state = vec![vec![0]];
//!
//! let mut mh = MetropolisHastings::new(target, proposal, initial_state);
//! let samples = mh.run(5000, 100).unwrap();
//! println!("Poisson samples shape: {:?}", samples.shape());
//! ```
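//!
//! The two log-probability pieces above are exactly what the
//! Metropolis–Hastings acceptance rule consumes. A minimal sketch of that
//! rule in log space (an illustration of the standard algorithm, not the
//! crate's internal code):
//!
//! ```rust
//! /// Accept a proposed move given unnormalized target log-probs, the
//! /// forward/reverse proposal log-probs, and a uniform draw u in (0, 1).
//! fn mh_accept(log_p_cur: f64, log_p_prop: f64,
//!              log_q_fwd: f64, log_q_rev: f64, u: f64) -> bool {
//!     // log alpha = log p(y) + log q(x|y) - log p(x) - log q(y|x)
//!     let log_alpha = (log_p_prop + log_q_rev) - (log_p_cur + log_q_fwd);
//!     u.ln() < log_alpha.min(0.0)
//! }
//!
//! // A move toward higher probability under a symmetric proposal always wins.
//! assert!(mh_accept(-3.0, -1.0, (0.5_f64).ln(), (0.5_f64).ln(), 0.999));
//! ```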
//!
//! For more complete implementations (including Gibbs sampling and I/O helpers),
//! see the `examples/` directory.
//!
//! ## Features
//! - **Parallel Chains** for improved throughput
//! - **Progress Indicators** (acceptance rates, iteration counts)
//! - **Common Distributions** (e.g. Gaussian) plus easy traits for custom log-prob
//! - **Optional I/O** (CSV, Arrow, Parquet, sketched below) and GPU sampling (WGPU)
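//!
//! As a taste of the I/O feature, here is a minimal sketch that writes 2D
//! samples to CSV with the `csv` crate directly; the crate's own `io` helpers
//! cover CSV, Arrow, and Parquet, and their exact signatures may differ from
//! this stand-in:
//!
//! ```no_run
//! let samples = vec![[0.1_f64, -0.2], [0.3, 0.4]];
//! let mut wtr = csv::Writer::from_path("samples.csv").unwrap();
//! wtr.write_record(&["x", "y"]).unwrap();
//! for s in &samples {
//!     wtr.write_record(&[s[0].to_string(), s[1].to_string()]).unwrap();
//! }
//! wtr.flush().unwrap();
//! ```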
//!
//! ## Roadmap
//! - No-U-Turn Sampler (NUTS)
//! - Rank-Normalized R-hat diagnostics
//! - Ensemble Slice Sampling (ESS)
//! - Effective Sample Size estimation

pub mod core;
mod dev_tools;
pub mod distributions;
pub mod gibbs;
pub mod hmc;
pub mod io;
pub mod ks_test;
pub mod metropolis_hastings;
pub mod stats;