mini_mcmc/hmc.rs
//! A simple Hamiltonian (Hybrid) Monte Carlo sampler using the `burn` crate for autodiff.
//!
//! This is modeled similarly to a Metropolis–Hastings sampler but uses gradient-based proposals
//! for improved efficiency. The sampler works in a data-parallel fashion and can update multiple
//! chains simultaneously.
//!
//! The code relies on a target distribution provided via the `GradientTarget` trait, which computes
//! the unnormalized log probability for a batch of positions. The HMC implementation uses the leapfrog
//! integrator to simulate Hamiltonian dynamics, and the standard accept/reject step for proposal
//! validation.
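//!
//! # Examples
//!
//! A minimal end-to-end sketch (not compiled as a doctest). `StdNormal` is a toy target used
//! for illustration only; see the `GradientTarget` docs for a sketch of its implementation.
//!
//! ```ignore
//! use burn::backend::{Autodiff, NdArray};
//!
//! let mut sampler = HMC::<f32, Autodiff<NdArray>, StdNormal>::new(
//!     StdNormal,
//!     vec![vec![0.0_f32, 0.0], vec![1.0, 1.0]], // two chains in 2D
//!     0.1, // leapfrog step size
//!     10,  // leapfrog steps per proposal
//! )
//! .set_seed(42);
//!
//! // `samples` has shape [n_steps, n_chains, D] = [100, 2, 2].
//! let samples = sampler.run(100, 50);
//! ```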

use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
use burn::tensor::cast::ToElement;
use burn::tensor::Tensor;
use indicatif::{ProgressBar, ProgressStyle};
use num_traits::Float;
use rand::prelude::*;
use rand_distr::StandardNormal;
use std::collections::VecDeque;

/// A batched target trait for computing the unnormalized log probability (and gradients) for a
/// collection of positions.
///
/// Implement this trait for your target distribution to enable gradient-based sampling.
///
/// # Type Parameters
///
/// * `T`: The floating-point type (e.g., `f32` or `f64`).
/// * `B`: The autodiff backend from the `burn` crate.
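///
/// # Examples
///
/// A minimal sketch (not compiled as a doctest) of an isotropic Gaussian target; the name
/// `StdNormal` is made up for this example:
///
/// ```ignore
/// struct StdNormal;
///
/// impl<T, B> GradientTarget<T, B> for StdNormal
/// where
///     T: num_traits::Float,
///     B: burn::tensor::backend::AutodiffBackend,
/// {
///     fn log_prob_batch(&self, positions: &Tensor<B, 2>) -> Tensor<B, 1> {
///         // log p(x) = -0.5 * ||x||^2, up to an additive constant.
///         positions.clone().powi_scalar(2).sum_dim(1).squeeze(1).mul_scalar(-0.5)
///     }
/// }
/// ```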
pub trait GradientTarget<T: Float, B: AutodiffBackend> {
    /// Compute the log probability for a batch of positions.
    ///
    /// # Parameters
    ///
    /// * `positions`: A tensor of shape `[n_chains, D]` representing the current positions for each chain.
    ///
    /// # Returns
    ///
    /// A 1D tensor of shape `[n_chains]` containing the log probabilities for each chain.
    fn log_prob_batch(&self, positions: &Tensor<B, 2>) -> Tensor<B, 1>;
}

/// A data-parallel Hamiltonian Monte Carlo (HMC) sampler.
///
/// This struct encapsulates the HMC algorithm, including the leapfrog integrator and the
/// accept/reject mechanism, for sampling from a target distribution in a batched manner.
///
/// # Type Parameters
///
/// * `T`: Floating-point type for numerical calculations.
/// * `B`: Autodiff backend from the `burn` crate.
/// * `GTarget`: The target distribution type implementing the `GradientTarget` trait.
#[derive(Debug, Clone)]
pub struct HMC<T, B, GTarget>
where
    B: AutodiffBackend,
{
    /// The target distribution which provides log probability evaluations and gradients.
    pub target: GTarget,
    /// The step size for the leapfrog integrator.
    pub step_size: T,
    /// The number of leapfrog steps to take per HMC update.
    pub n_leapfrog: usize,
    /// The current positions for all chains, stored as a tensor of shape `[n_chains, D]`.
    pub positions: Tensor<B, 2>,
    /// A random number generator used to draw the uniform random numbers for the Metropolis
    /// acceptance test. (Momenta are drawn through the backend's own RNG.)
    pub rng: SmallRng,
}

impl<T, B, GTarget> HMC<T, B, GTarget>
where
    T: Float
        + burn::tensor::ElementConversion
        + burn::tensor::Element
        + rand_distr::uniform::SampleUniform,
    B: AutodiffBackend,
    GTarget: GradientTarget<T, B> + Sync,
    StandardNormal: rand::distributions::Distribution<T>,
    rand_distr::Standard: rand_distr::Distribution<T>,
{
    /// Create a new data-parallel HMC sampler.
    ///
    /// This method initializes the sampler with the target distribution, initial positions,
    /// step size, and number of leapfrog steps. The internal RNG is seeded from entropy; call
    /// [`set_seed`](Self::set_seed) afterwards for reproducibility.
    ///
    /// # Parameters
    ///
    /// * `target`: The target distribution implementing the `GradientTarget` trait.
    /// * `initial_positions`: A vector of vectors containing the initial positions for each chain,
    ///   with shape `[n_chains][D]`.
    /// * `step_size`: The step size used in the leapfrog integrator.
    /// * `n_leapfrog`: The number of leapfrog steps per update.
    ///
    /// # Returns
    ///
    /// A new instance of `HMC`.
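    ///
    /// # Examples
    ///
    /// A hedged sketch, reusing the toy `StdNormal` target from the trait docs:
    ///
    /// ```ignore
    /// let sampler = HMC::<f32, Autodiff<NdArray>, StdNormal>::new(
    ///     StdNormal,
    ///     vec![vec![0.0_f32; 2]; 4], // 4 chains, D = 2
    ///     0.05, // step size
    ///     20,   // leapfrog steps
    /// )
    /// .set_seed(7);
    /// ```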
    pub fn new(
        target: GTarget,
        initial_positions: Vec<Vec<T>>,
        step_size: T,
        n_leapfrog: usize,
    ) -> Self {
        // Build a [n_chains, D] tensor from the flattened initial positions.
        let (n_chains, dim) = (initial_positions.len(), initial_positions[0].len());
        let td: TensorData = TensorData::new(
            initial_positions.into_iter().flatten().collect(),
            [n_chains, dim],
        );
        let positions = Tensor::<B, 2>::from_data(td, &B::Device::default());
        let rng = SmallRng::seed_from_u64(thread_rng().gen::<u64>());
        Self {
            target,
            step_size,
            n_leapfrog,
            positions,
            rng,
        }
    }

    /// Sets a new random seed.
    ///
    /// This method ensures reproducibility across runs.
    ///
    /// # Arguments
    ///
    /// * `seed` - The new random seed value.
    pub fn set_seed(mut self, seed: u64) -> Self {
        self.rng = SmallRng::seed_from_u64(seed);
        self
    }

    /// Run the HMC sampler for a specified number of steps.
    ///
    /// First, the sampler discards a specified number of steps (burn-in), then collects
    /// samples for the remaining steps. The samples for each step are stored in a 3D tensor of
    /// shape `[n_steps, n_chains, D]`.
    ///
    /// # Parameters
    ///
    /// * `n_steps`: The number of sampling steps to run (after burn-in).
    /// * `discard`: The number of initial steps to discard as burn-in.
    ///
    /// # Returns
    ///
    /// A tensor containing the collected samples.
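    ///
    /// # Examples
    ///
    /// A sketch, assuming `sampler` was constructed as in the module docs:
    ///
    /// ```ignore
    /// // 500 kept steps after 100 burn-in steps; shape [500, n_chains, D].
    /// let samples = sampler.run(500, 100);
    /// ```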
    pub fn run(&mut self, n_steps: usize, discard: usize) -> Tensor<B, 3> {
        let (n_chains, dim) = (self.positions.dims()[0], self.positions.dims()[1]);
        let mut out = Tensor::<B, 3>::empty([n_steps, n_chains, dim], &B::Device::default());

        // Discard the first `discard` positions as burn-in.
        (0..discard).for_each(|_| self.step());

        // Collect samples.
        for step in 0..n_steps {
            self.step();
            out.inplace(|_out| {
                _out.slice_assign(
                    [step..step + 1, 0..n_chains, 0..dim],
                    self.positions.clone().unsqueeze_dim(0),
                )
            });
        }
        out
    }

    /// Runs the HMC sampler for a specified number of steps with progress updates,
    /// discarding the first `discard` iterations as burn-in.
    ///
    /// This function displays a progress bar (using the `indicatif` crate) that is updated
    /// with an approximate acceptance probability computed over a sliding window of 50 iterations.
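    /// Concretely, if `a_i` denotes the fraction of chains whose state changed at iteration `i`,
    /// the displayed estimate is the mean of `a_i` over the most recent (up to) 50 iterations.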
    ///
    /// # Parameters
    ///
    /// * `n_steps` - The number of sampling steps to run (after burn-in).
    /// * `discard` - The number of initial iterations to discard.
    ///
    /// # Returns
    ///
    /// A tensor of shape `[n_steps, n_chains, D]` containing the collected samples.
    pub fn run_progress(&mut self, n_steps: usize, discard: usize) -> Tensor<B, 3> {
        // Discard initial burn-in samples.
        (0..discard).for_each(|_| self.step());

        let (n_chains, dim) = (self.positions.dims()[0], self.positions.dims()[1]);
        let mut out = Tensor::<B, 3>::empty([n_steps, n_chains, dim], &B::Device::default());

        let pb = ProgressBar::new(n_steps as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("{prefix} [{elapsed_precise}, {eta}] {bar:40.white} {pos}/{len} {msg}")
                .unwrap()
                .progress_chars("=>-"),
        );
        pb.set_prefix("HMC");

        // Use a sliding window of 50 iterations to estimate the acceptance probability.
        let window_size = 50;
        let mut accept_window: VecDeque<f32> = VecDeque::with_capacity(window_size);

        let mut last_state = self.positions.clone();

        for i in 0..n_steps {
            self.step();
            let current_state = self.positions.clone();

            // A chain's proposal was accepted if any coordinate of its state changed.
            let accepted_count = last_state
                .clone()
                .not_equal(current_state.clone())
                .any_dim(1)
                .int()
                .sum()
                .into_scalar()
                .to_f32();

            let iter_accept_rate = accepted_count / n_chains as f32;

            // Update the sliding window.
            accept_window.push_front(iter_accept_rate);
            if accept_window.len() > window_size {
                accept_window.pop_back();
            }

            // Compute the average acceptance rate over the sliding window.
            let avg_accept_rate: f32 =
                accept_window.iter().sum::<f32>() / accept_window.len() as f32;
            pb.set_message(format!("p(accept) ≈ {:.2}", avg_accept_rate));

            // Store the current state.
            out.inplace(|_out| {
                _out.slice_assign(
                    [i..i + 1, 0..n_chains, 0..dim],
                    current_state.clone().unsqueeze_dim(0),
                )
            });
            pb.inc(1);
            last_state = current_state;
        }
        pb.finish_with_message("Done!");
        out
    }

    /// Perform one batched HMC update for all chains in parallel.
    ///
    /// The update consists of:
    /// 1) Sampling momenta from a standard normal distribution.
    /// 2) Running the leapfrog integrator to propose new positions.
    /// 3) Performing an accept/reject step for each chain.
    ///
    /// This method updates `self.positions` in-place.
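    ///
    /// In log space, a proposal `(q', p')` is accepted when `ln(u) <= H(q, p) - H(q', p')`
    /// with `u ~ Uniform(0, 1)`, where `H(q, p) = -log p(q) + 0.5 * ||p||^2` is the Hamiltonian.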
    pub fn step(&mut self) {
        let shape = self.positions.shape();
        let (n_chains, dim) = (shape.dims[0], shape.dims[1]);

        // 1) Sample momenta: shape [n_chains, D].
        let momentum_0 = Tensor::<B, 2>::random(
            Shape::new([n_chains, dim]),
            burn::tensor::Distribution::Normal(0., 1.),
            &B::Device::default(),
        );

        // Current log probability: shape [n_chains].
        let logp_current = self.target.log_prob_batch(&self.positions);

        // Compute kinetic energy: 0.5 * sum_d p_d^2 for each chain.
        let ke_current = momentum_0
            .clone()
            .powf_scalar(2.0)
            .sum_dim(1) // Sum over dimension 1 => shape [n_chains, 1]
            .squeeze(1)
            .mul_scalar(T::from(0.5).unwrap());

        // Compute the Hamiltonian: -logp + kinetic energy, shape [n_chains].
        let h_current: Tensor<B, 1> = -logp_current + ke_current;

        // 2) Run the leapfrog integrator.
        let (proposed_positions, proposed_momenta, logp_proposed) =
            self.leapfrog(self.positions.clone(), momentum_0);

        // Compute the proposed kinetic energy.
        let ke_proposed = proposed_momenta
            .powf_scalar(2.0)
            .sum_dim(1)
            .squeeze(1)
            .mul_scalar(T::from(0.5).unwrap());

        let h_proposed = -logp_proposed + ke_proposed;

        // 3) Accept/reject each proposal.
        let accept_logp = h_current.sub(h_proposed);

        // Draw one uniform random number per chain from the sampler's seeded RNG.
        let uniform_data: Vec<T> = (0..n_chains).map(|_| self.rng.gen::<T>()).collect();
        let uniform = Tensor::<B, 1>::from_data(
            TensorData::new(uniform_data, [n_chains]),
            &B::Device::default(),
        );

        // Accept the proposal if accept_logp >= ln(u).
        let ln_u = uniform.log(); // shape [n_chains]
        let accept_mask = accept_logp.greater_equal(ln_u); // Boolean mask of shape [n_chains]
        let mut accept_mask_big: Tensor<B, 2, Bool> = accept_mask.clone().unsqueeze_dim(1);
        accept_mask_big = accept_mask_big.expand([n_chains, dim]);

        // Update positions: for accepted chains, replace current positions with proposed ones.
        self.positions.inplace(|x| {
            x.clone()
                .mask_where(accept_mask_big, proposed_positions)
                .detach()
        });
    }

    /// Perform the leapfrog integrator steps in a batched manner.
    ///
    /// This method performs `n_leapfrog` iterations of the leapfrog update:
    /// - A half-step update of the momentum.
    /// - A full-step update of the positions.
    /// - Another half-step update of the momentum.
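    ///
    /// With step size `ε` and target log density `log p`, one iteration computes:
    ///
    /// ```text
    /// p_half = p + (ε / 2) * ∇ log p(q)
    /// q_new  = q + ε * p_half
    /// p_new  = p_half + (ε / 2) * ∇ log p(q_new)
    /// ```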
    ///
    /// # Parameters
    ///
    /// * `pos`: The current positions, a tensor of shape `[n_chains, D]`.
    /// * `mom`: The initial momenta, a tensor of shape `[n_chains, D]`.
    ///
    /// # Returns
    ///
    /// A tuple containing:
    /// - The new positions (tensor of shape `[n_chains, D]`),
    /// - The new momenta (tensor of shape `[n_chains, D]`),
    /// - The log probability evaluated at the new positions (tensor of shape `[n_chains]`).
    fn leapfrog(
        &mut self,
        mut pos: Tensor<B, 2>,
        mut mom: Tensor<B, 2>,
    ) -> (Tensor<B, 2>, Tensor<B, 2>, Tensor<B, 1>) {
        let half = T::from(0.5).unwrap();
        for _step_i in 0..self.n_leapfrog {
            // Detach pos from the old graph and re-attach it to a fresh one so that the
            // gradient of the log probability can be computed with respect to it.
            pos = pos.detach().require_grad();

            // Compute the gradient of the log probability with respect to pos (batched over chains).
            let logp = self.target.log_prob_batch(&pos); // shape [n_chains]
            let grads = pos.grad(&logp.backward()).unwrap();

            // Update the momentum by a half-step using the computed gradients.
            mom.inplace(|_mom| {
                _mom.add(Tensor::<B, 2>::from_inner(
                    grads.mul_scalar(self.step_size * half),
                ))
            });

            // Full-step update for the positions.
            pos.inplace(|_pos| {
                _pos.add(mom.clone().mul_scalar(self.step_size))
                    .detach()
                    .require_grad()
            });

            // Compute the gradient at the new positions.
            let logp2 = self.target.log_prob_batch(&pos);
            let grads2 = pos.grad(&logp2.backward()).unwrap();

            // Update the momentum by another half-step using the new gradients.
            mom.inplace(|_mom| {
                _mom.add(Tensor::<B, 2>::from_inner(
                    grads2.mul_scalar(self.step_size * half),
                ))
            });
        }

        // Compute the final log probability at the updated positions.
        let logp_final = self.target.log_prob_batch(&pos);
        (pos.detach(), mom.detach(), logp_final.detach())
    }
}

#[cfg(test)]
mod tests {
    use crate::dev_tools::Timer;

    use super::*;
    use burn::{
        backend::{Autodiff, NdArray},
        tensor::{Element, Tensor},
    };
    use num_traits::Float;

    // Define the Rosenbrock distribution.
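    // Unnormalized log-density: log p(x, y) = -[(a - x)^2 + b * (y - x^2)^2].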
    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
    struct Rosenbrock2D<T: Float> {
        a: T,
        b: T,
    }

    // For the batched version we implement `GradientTarget`.
    impl<T, B> GradientTarget<T, B> for Rosenbrock2D<T>
    where
        T: Float + std::fmt::Debug + Element,
        B: burn::tensor::backend::AutodiffBackend,
    {
        fn log_prob_batch(&self, positions: &Tensor<B, 2>) -> Tensor<B, 1> {
            let n = positions.dims()[0] as i64;
            let x = positions.clone().slice([(0, n), (0, 1)]);
            let y = positions.clone().slice([(0, n), (1, 2)]);

            // Compute (a - x)^2.
            let term_1 = (-x.clone()).add_scalar(self.a).powi_scalar(2);

            // Compute b * (y - x^2)^2.
            let term_2 = y.sub(x.powi_scalar(2)).powi_scalar(2).mul_scalar(self.b);

            // Return the negative sum as a flattened 1D tensor.
            -(term_1 + term_2).flatten(0, 1)
        }
    }

    // Define the Rosenbrock distribution in n dimensions.
    // From: https://arxiv.org/pdf/1903.09556.
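    // Unnormalized log-density:
    // log p(x) = -sum_{i=1}^{n-1} [100 * (x_{i+1} - x_i^2)^2 + (1 - x_i)^2].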
    struct RosenbrockND {}

    // For the batched version we implement `GradientTarget`.
    impl<T, B> GradientTarget<T, B> for RosenbrockND
    where
        T: Float + std::fmt::Debug + Element,
        B: burn::tensor::backend::AutodiffBackend,
    {
        fn log_prob_batch(&self, positions: &Tensor<B, 2>) -> Tensor<B, 1> {
            let k = positions.dims()[0] as i64;
            let n = positions.dims()[1] as i64;
            let low = positions.clone().slice([(0, k), (0, n - 1)]);
            let high = positions.clone().slice([(0, k), (1, n)]);
            let term_1 = (high - low.clone().powi_scalar(2))
                .powi_scalar(2)
                .mul_scalar(100);
            let term_2 = low.neg().add_scalar(1).powi_scalar(2);
            -(term_1 + term_2).sum_dim(1).squeeze(1)
        }
    }

    #[test]
    fn test_collect_hmc_samples_single() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100).
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // Define initial positions for a single chain (2-dimensional).
        let initial_positions = vec![vec![0.0_f32, 0.0]];
        let n_steps = 3;

        // Create the HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, // step size
            2,    // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run the sampler for n_steps with no discard.
        let mut timer = Timer::new();
        let samples: Tensor<BackendType, 3> = sampler.run(n_steps, 0);
        timer.log(format!(
            "Collected samples (1 chain) with shape: {:?}",
            samples.dims()
        ))
    }

    #[test]
    fn test_collect_hmc_samples_10_chains() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100).
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // Define 10 chains all initialized to (1.0, 2.0).
        let initial_positions = vec![vec![1.0_f32, 2.0_f32]; 10];
        let n_steps = 1000;

        // Create the HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, // step size
            10,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run the sampler for n_steps with no discard.
        let mut timer = Timer::new();
        let samples: Tensor<BackendType, 3> = sampler.run(n_steps, 0);
        timer.log(format!(
            "Collected samples (10 chains) with shape: {:?}",
            samples.dims()
        ))
    }

    #[test]
    fn test_collect_hmc_samples_10_chains_progress() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100).
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // Define 10 chains all initialized to (1.0, 2.0).
        let initial_positions = vec![vec![1.0_f32, 2.0_f32]; 10];
        let n_steps = 1000;

        // Create the HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.05, // step size
            10,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run the sampler for n_steps, discarding the first 100 iterations as burn-in.
        let mut timer = Timer::new();
        let samples: Tensor<BackendType, 3> = sampler.run_progress(n_steps, 100);
        timer.log(format!(
            "Collected samples (10 chains) with shape: {:?}",
            samples.dims()
        ))
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_collect_hmc_samples_benchmark() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<burn::backend::NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100).
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // We'll define 6 chains all initialized to (1.0, 2.0).
        let initial_positions = vec![vec![1.0_f32, 2.0_f32]; 6];
        let n_steps = 5000;

        // Create the data-parallel HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, // step size
            50,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run HMC for n_steps, collecting samples.
        let mut timer = Timer::new();
        let samples = sampler.run(n_steps, 0);
        timer.log(format!(
            "HMC sampler: generated {} samples.",
            samples.dims()[0..2].iter().product::<usize>()
        ))
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_benchmark_10000d() {
        // Use the Wgpu backend wrapped in Autodiff.
        type BackendType = Autodiff<burn::backend::Wgpu>;

        let seed = 42;
        let d = 10000;

        let rng = SmallRng::seed_from_u64(seed);
        // Define 6 chains, all initialized to the same random standard-normal point.
        let initial_positions: Vec<Vec<f32>> =
            vec![rng.sample_iter(StandardNormal).take(d).collect(); 6];
        let n_steps = 500;

        // Create the data-parallel HMC sampler.
        let mut sampler = HMC::<f32, BackendType, RosenbrockND>::new(
            RosenbrockND {},
            initial_positions,
            0.01, // step size
            50,   // number of leapfrog steps per update
        )
        .set_seed(seed);

        // Run HMC for n_steps, collecting samples.
        let mut timer = Timer::new();
        let samples = sampler.run(n_steps, 0);
        timer.log(format!(
            "HMC sampler: generated {} samples.",
            samples.dims()[0..2].iter().product::<usize>()
        ))
    }
}