mini_mcmc/
hmc.rs

//! A simple Hamiltonian (Hybrid) Monte Carlo sampler using the `burn` crate for autodiff.
//!
//! This is modeled similarly to a Metropolis–Hastings sampler but uses gradient-based proposals
//! for improved efficiency. The sampler works in a data-parallel fashion and can update multiple
//! chains simultaneously.
//!
//! The code relies on a target distribution provided via the `GradientTarget` trait, which computes
//! the unnormalized log probability for a batch of positions. The HMC implementation uses the leapfrog
//! integrator to simulate Hamiltonian dynamics, and the standard accept/reject step for proposal
//! validation.

use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
use burn::tensor::cast::ToElement;
use burn::tensor::Tensor;
use indicatif::{ProgressBar, ProgressStyle};
use num_traits::Float;
use rand::prelude::*;
use rand_distr::StandardNormal;
use std::collections::VecDeque;

/// A batched target trait for computing the unnormalized log probability (and gradients) for a
/// collection of positions.
///
/// Implement this trait for your target distribution to enable gradient-based sampling.
///
/// # Type Parameters
///
/// * `T`: The floating-point type (e.g., `f32` or `f64`).
/// * `B`: The autodiff backend from the `burn` crate.
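///
/// # Example
///
/// A minimal sketch (not part of this crate): an isotropic standard Gaussian whose
/// unnormalized log density is `-0.5 * ||x||^2`, here on the `Autodiff<NdArray>`
/// backend as one possible choice.
///
/// ```ignore
/// use burn::backend::{Autodiff, NdArray};
/// use burn::tensor::Tensor;
///
/// type B = Autodiff<NdArray>;
///
/// struct StdGaussian;
///
/// impl GradientTarget<f32, B> for StdGaussian {
///     fn log_prob_batch(&self, positions: &Tensor<B, 2>) -> Tensor<B, 1> {
///         // -0.5 * sum_d x_d^2, evaluated per chain (row).
///         positions
///             .clone()
///             .powf_scalar(2.0)
///             .sum_dim(1)
///             .squeeze(1)
///             .mul_scalar(-0.5)
///     }
/// }
/// ```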
pub trait GradientTarget<T: Float, B: AutodiffBackend> {
    /// Compute the log probability for a batch of positions.
    ///
    /// # Parameters
    ///
    /// * `positions`: A tensor of shape `[n_chains, D]` representing the current positions for each chain.
    ///
    /// # Returns
    ///
    /// A 1D tensor of shape `[n_chains]` containing the log probabilities for each chain.
    fn log_prob_batch(&self, positions: &Tensor<B, 2>) -> Tensor<B, 1>;
}

/// A data-parallel Hamiltonian Monte Carlo (HMC) sampler.
///
/// This struct encapsulates the HMC algorithm, including the leapfrog integrator and the
/// accept/reject mechanism, for sampling from a target distribution in a batched manner.
///
/// # Type Parameters
///
/// * `T`: Floating-point type for numerical calculations.
/// * `B`: Autodiff backend from the `burn` crate.
/// * `GTarget`: The target distribution type implementing the `GradientTarget` trait.
#[derive(Debug, Clone)]
pub struct HMC<T, B, GTarget>
where
    B: AutodiffBackend,
{
    /// The target distribution which provides log probability evaluations and gradients.
    pub target: GTarget,
    /// The step size for the leapfrog integrator.
    pub step_size: T,
    /// The number of leapfrog steps to take per HMC update.
    pub n_leapfrog: usize,
    /// The current positions for all chains, stored as a tensor of shape `[n_chains, D]`.
    pub positions: Tensor<B, 2>,
    /// A random number generator for sampling momenta and uniform random numbers for the
    /// Metropolis acceptance test.
    pub rng: SmallRng,
}

impl<T, B, GTarget> HMC<T, B, GTarget>
where
    T: Float
        + burn::tensor::ElementConversion
        + burn::tensor::Element
        + rand_distr::uniform::SampleUniform,
    B: AutodiffBackend,
    GTarget: GradientTarget<T, B> + std::marker::Sync,
    StandardNormal: rand::distributions::Distribution<T>,
    rand_distr::Standard: rand_distr::Distribution<T>,
{
    /// Create a new data-parallel HMC sampler.
    ///
    /// This method initializes the sampler with the target distribution, initial positions,
    /// step size, and number of leapfrog steps. The internal RNG is seeded from `thread_rng`;
    /// call [`HMC::set_seed`] afterwards for reproducible runs.
    ///
    /// # Parameters
    ///
    /// * `target`: The target distribution implementing the `GradientTarget` trait.
    /// * `initial_positions`: A vector of vectors containing the initial positions for each chain,
    ///    with shape `[n_chains][D]`.
    /// * `step_size`: The step size used in the leapfrog integrator.
    /// * `n_leapfrog`: The number of leapfrog steps per update.
    ///
    /// # Returns
    ///
    /// A new instance of `HMC`.
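    ///
    /// # Example
    ///
    /// A sketch (reusing the hypothetical `StdGaussian` target from the trait docs):
    ///
    /// ```ignore
    /// let mut sampler = HMC::<f32, Autodiff<NdArray>, StdGaussian>::new(
    ///     StdGaussian,
    ///     vec![vec![0.0_f32; 2]; 4], // 4 chains, 2 dimensions
    ///     0.05,                      // step size
    ///     10,                        // leapfrog steps per update
    /// )
    /// .set_seed(42);
    /// ```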
    pub fn new(
        target: GTarget,
        initial_positions: Vec<Vec<T>>,
        step_size: T,
        n_leapfrog: usize,
    ) -> Self {
        // Build a [n_chains, D] tensor from the flattened initial positions.
        let (n_chains, dim) = (initial_positions.len(), initial_positions[0].len());
        let td: TensorData = TensorData::new(
            initial_positions.into_iter().flatten().collect(),
            [n_chains, dim],
        );
        let positions = Tensor::<B, 2>::from_data(td, &B::Device::default());
        let rng = SmallRng::seed_from_u64(thread_rng().gen::<u64>());
        Self {
            target,
            step_size,
            n_leapfrog,
            positions,
            rng,
        }
    }

    /// Sets a new random seed.
    ///
    /// This method re-seeds the internal RNG, making runs reproducible.
    ///
    /// # Arguments
    ///
    /// * `seed` - The new random seed value.
    pub fn set_seed(mut self, seed: u64) -> Self {
        self.rng = SmallRng::seed_from_u64(seed);
        self
    }

    /// Run the HMC sampler for a specified number of steps.
    ///
    /// The sampler first discards `discard` steps as burn-in, then collects samples for the
    /// remaining steps. The collected samples are stored in a 3D tensor of shape
    /// `[n_steps, n_chains, D]`.
    ///
    /// # Parameters
    ///
    /// * `n_steps`: The number of sampling steps to run (after burn-in).
    /// * `discard`: The number of initial steps to discard as burn-in.
    ///
    /// # Returns
    ///
    /// A tensor of shape `[n_steps, n_chains, D]` containing the collected samples.
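    ///
    /// # Example
    ///
    /// A sketch, assuming `sampler` was built as in [`HMC::new`] with 4 chains in 2 dimensions:
    ///
    /// ```ignore
    /// // 200 burn-in steps, then 1_000 retained steps:
    /// let samples = sampler.run(1_000, 200);
    /// assert_eq!(samples.dims(), [1_000, 4, 2]); // [n_steps, n_chains, D]
    /// ```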
    pub fn run(&mut self, n_steps: usize, discard: usize) -> Tensor<B, 3> {
        let (n_chains, dim) = (self.positions.dims()[0], self.positions.dims()[1]);
        let mut out = Tensor::<B, 3>::empty([n_steps, n_chains, dim], &B::Device::default());

        // Discard the first `discard` positions as burn-in.
        (0..discard).for_each(|_| self.step());

        // Collect samples.
        for step in 0..n_steps {
            self.step();
            out.inplace(|_out| {
                _out.slice_assign(
                    [step..step + 1, 0..n_chains, 0..dim],
                    self.positions.clone().unsqueeze_dim(0),
                )
            });
        }
        out
    }

    /// Runs the HMC sampler for a specified number of steps with progress updates,
    /// discarding the first `discard` iterations as burn-in.
    ///
    /// This function displays a progress bar (using the `indicatif` crate) that is updated
    /// with an approximate acceptance probability computed over a sliding window of 50 iterations.
    ///
    /// # Parameters
    ///
    /// * `n_steps` - The number of sampling steps to run (after burn-in).
    /// * `discard` - The number of initial iterations to discard.
    ///
    /// # Returns
    ///
    /// A tensor of shape `[n_steps, n_chains, D]` containing the collected samples.
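    ///
    /// # Example
    ///
    /// A sketch, assuming `sampler` was built as in [`HMC::new`]:
    ///
    /// ```ignore
    /// // Same as `run`, but renders a progress bar with a running acceptance estimate:
    /// let samples = sampler.run_progress(1_000, 200);
    /// ```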
    pub fn run_progress(&mut self, n_steps: usize, discard: usize) -> Tensor<B, 3> {
        // Discard initial burn-in samples.
        (0..discard).for_each(|_| self.step());

        let (n_chains, dim) = (self.positions.dims()[0], self.positions.dims()[1]);
        let mut out = Tensor::<B, 3>::empty([n_steps, n_chains, dim], &B::Device::default());

        let pb = ProgressBar::new(n_steps as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("{prefix} [{elapsed_precise}, {eta}] {bar:40.white} {pos}/{len} {msg}")
                .unwrap()
                .progress_chars("=>-"),
        );
        pb.set_prefix("HMC");

        // Use a sliding window of 50 iterations to estimate the acceptance probability.
        let window_size = 50;
        let mut accept_window: VecDeque<f32> = VecDeque::with_capacity(window_size);

        let mut last_state = self.positions.clone();

        for i in 0..n_steps {
            self.step();
            let current_state = self.positions.clone();

            // A chain accepted its proposal iff its state changed, i.e. any coordinate differs.
            let accepted_count = last_state
                .clone()
                .not_equal(current_state.clone())
                .any_dim(1)
                .int()
                .sum()
                .into_scalar()
                .to_f32();

            let iter_accept_rate = accepted_count / n_chains as f32;

            // Update the sliding window.
            accept_window.push_front(iter_accept_rate);
            if accept_window.len() > window_size {
                accept_window.pop_back();
            }

            // Compute the average acceptance rate over the sliding window.
            let avg_accept_rate: f32 =
                accept_window.iter().sum::<f32>() / accept_window.len() as f32;
            pb.set_message(format!("p(accept) ≈ {:.2}", avg_accept_rate));

            // Store the current state.
            out.inplace(|_out| {
                _out.slice_assign(
                    [i..i + 1, 0..n_chains, 0..dim],
                    current_state.clone().unsqueeze_dim(0),
                )
            });
            pb.inc(1);
            last_state = current_state;
        }
        pb.finish_with_message("Done!");
        out
    }

    /// Perform one batched HMC update for all chains in parallel.
    ///
    /// The update consists of:
    /// 1) Sampling momenta from a standard normal distribution.
    /// 2) Running the leapfrog integrator to propose new positions.
    /// 3) Performing an accept/reject step for each chain.
    ///
    /// This method updates `self.positions` in place.
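    ///
    /// Writing the Hamiltonian as `H(q, p) = -log π(q) + 0.5 * ‖p‖²`, a proposal
    /// `(q', p')` is accepted with probability `min(1, exp(H(q, p) - H(q', p')))`,
    /// which the code checks in log space as `H - H' ≥ ln u` for `u ~ Uniform(0, 1)`.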
    pub fn step(&mut self) {
        let shape = self.positions.shape();
        let (n_chains, dim) = (shape.dims[0], shape.dims[1]);

        // 1) Sample momenta from `self.rng` so that `set_seed` makes runs reproducible:
        //    shape [n_chains, D].
        let momentum_data: Vec<T> = (0..n_chains * dim)
            .map(|_| self.rng.sample(StandardNormal))
            .collect();
        let momentum_0 = Tensor::<B, 2>::from_data(
            TensorData::new(momentum_data, [n_chains, dim]),
            &B::Device::default(),
        );

        // Current log probability: shape [n_chains].
        let logp_current = self.target.log_prob_batch(&self.positions);

        // Compute kinetic energy: 0.5 * sum_d p_d^2 for each chain.
        let ke_current = momentum_0
            .clone()
            .powf_scalar(2.0)
            .sum_dim(1) // Sum over dimension 1 => shape [n_chains, 1]
            .squeeze(1)
            .mul_scalar(T::from(0.5).unwrap());

        // Compute the Hamiltonian: -logp + kinetic energy, shape [n_chains].
        let h_current: Tensor<B, 1> = -logp_current + ke_current;

        // 2) Run the leapfrog integrator.
        let (proposed_positions, proposed_momenta, logp_proposed) =
            self.leapfrog(self.positions.clone(), momentum_0);

        // Compute the proposed kinetic energy.
        let ke_proposed = proposed_momenta
            .powf_scalar(2.0)
            .sum_dim(1)
            .squeeze(1)
            .mul_scalar(T::from(0.5).unwrap());

        let h_proposed = -logp_proposed + ke_proposed;

        // 3) Accept/reject each proposal.
        let accept_logp = h_current.sub(h_proposed);

        // Draw a uniform random number for each chain from `self.rng`.
        let uniform_data: Vec<T> = (0..n_chains).map(|_| self.rng.gen::<T>()).collect();
        let uniform = Tensor::<B, 1>::from_data(
            TensorData::new(uniform_data, [n_chains]),
            &B::Device::default(),
        );

        // Accept the proposal if accept_logp >= ln(u).
        let ln_u = uniform.log(); // shape [n_chains]
        let accept_mask = accept_logp.greater_equal(ln_u); // Boolean mask of shape [n_chains]
        let mut accept_mask_big: Tensor<B, 2, Bool> = accept_mask.clone().unsqueeze_dim(1);
        accept_mask_big = accept_mask_big.expand([n_chains, dim]);

        // Update positions: for accepted chains, replace current positions with proposed ones.
        self.positions.inplace(|x| {
            x.clone()
                .mask_where(accept_mask_big, proposed_positions)
                .detach()
        });
    }

    /// Perform the leapfrog integrator steps in a batched manner.
    ///
    /// This method performs `n_leapfrog` iterations of the leapfrog update:
    /// - A half-step update of the momentum.
    /// - A full-step update of the positions.
    /// - Another half-step update of the momentum.
    ///
    /// # Parameters
    ///
    /// * `pos`: The current positions, a tensor of shape `[n_chains, D]`.
    /// * `mom`: The initial momenta, a tensor of shape `[n_chains, D]`.
    ///
    /// # Returns
    ///
    /// A tuple containing:
    /// - The new positions (tensor of shape `[n_chains, D]`),
    /// - The new momenta (tensor of shape `[n_chains, D]`),
    /// - The log probability evaluated at the new positions (tensor of shape `[n_chains]`).
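    ///
    /// With step size `ε`, one iteration reads:
    ///
    /// ```text
    /// p ← p + (ε/2) ∇ log π(q)
    /// q ← q + ε p
    /// p ← p + (ε/2) ∇ log π(q)
    /// ```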
    fn leapfrog(
        &mut self,
        mut pos: Tensor<B, 2>,
        mut mom: Tensor<B, 2>,
    ) -> (Tensor<B, 2>, Tensor<B, 2>, Tensor<B, 1>) {
        let half = T::from(0.5).unwrap();
        for _step_i in 0..self.n_leapfrog {
            // Detach pos from the previous graph and mark it as requiring gradients,
            // so each iteration builds a fresh autodiff graph.
            pos = pos.detach().require_grad();

            // Compute the gradient of the log probability with respect to pos (batched over chains).
            let logp = self.target.log_prob_batch(&pos); // shape [n_chains]
            let grads = pos.grad(&logp.backward()).unwrap();

            // Update momentum by a half-step using the computed gradients.
            mom.inplace(|_mom| {
                _mom.add(Tensor::<B, 2>::from_inner(
                    grads.mul_scalar(self.step_size * half),
                ))
            });

            // Full-step update for positions.
            pos.inplace(|_pos| {
                _pos.add(mom.clone().mul_scalar(self.step_size))
                    .detach()
                    .require_grad()
            });

            // Compute the gradient at the new positions.
            let logp2 = self.target.log_prob_batch(&pos);
            let grads2 = pos.grad(&logp2.backward()).unwrap();

            // Update momentum by another half-step using the new gradients.
            mom.inplace(|_mom| {
                _mom.add(Tensor::<B, 2>::from_inner(
                    grads2.mul_scalar(self.step_size * half),
                ))
            });
        }

        // Compute the final log probability at the updated positions.
        let logp_final = self.target.log_prob_batch(&pos);
        (pos.detach(), mom.detach(), logp_final.detach())
    }
}

#[cfg(test)]
mod tests {
    use crate::dev_tools::Timer;

    use super::*;
    use burn::{
        backend::{Autodiff, NdArray},
        tensor::{Element, Tensor},
    };
    use num_traits::Float;

    // Define the 2D Rosenbrock distribution.
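    // Its unnormalized log density is log π(x, y) = -[(a - x)² + b(y - x²)²].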
    #[derive(Debug, Clone, Copy)]
    struct Rosenbrock2D<T: Float> {
        a: T,
        b: T,
    }

    // Implement `GradientTarget` so the batched sampler can use this target.
    impl<T, B> GradientTarget<T, B> for Rosenbrock2D<T>
    where
        T: Float + std::fmt::Debug + Element,
        B: burn::tensor::backend::AutodiffBackend,
    {
        fn log_prob_batch(&self, positions: &Tensor<B, 2>) -> Tensor<B, 1> {
            let n = positions.dims()[0] as i64;
            let x = positions.clone().slice([(0, n), (0, 1)]);
            let y = positions.clone().slice([(0, n), (1, 2)]);

            // Compute (a - x)^2.
            let term_1 = (-x.clone()).add_scalar(self.a).powi_scalar(2);

            // Compute b * (y - x^2)^2.
            let term_2 = y.sub(x.powi_scalar(2)).powi_scalar(2).mul_scalar(self.b);

            // Return the negative sum as a flattened 1D tensor.
            -(term_1 + term_2).flatten(0, 1)
        }
    }

    // Define the n-dimensional Rosenbrock distribution.
    // From: https://arxiv.org/pdf/1903.09556.
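    // Its unnormalized log density is
    // log π(x) = -Σ_{i=1}^{n-1} [100(x_{i+1} - x_i²)² + (1 - x_i)²].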
    struct RosenbrockND {}

    // Implement `GradientTarget` so the batched sampler can use this target.
    impl<T, B> GradientTarget<T, B> for RosenbrockND
    where
        T: Float + std::fmt::Debug + Element,
        B: burn::tensor::backend::AutodiffBackend,
    {
        fn log_prob_batch(&self, positions: &Tensor<B, 2>) -> Tensor<B, 1> {
            let k = positions.dims()[0] as i64;
            let n = positions.dims()[1] as i64;
            let low = positions.clone().slice([(0, k), (0, n - 1)]);
            let high = positions.clone().slice([(0, k), (1, n)]);
            let term_1 = (high - low.clone().powi_scalar(2))
                .powi_scalar(2)
                .mul_scalar(100);
            let term_2 = low.neg().add_scalar(1).powi_scalar(2);
            -(term_1 + term_2).sum_dim(1).squeeze(1)
        }
    }

    #[test]
    fn test_collect_hmc_samples_single() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100).
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // Define initial positions for a single chain (2-dimensional).
        let initial_positions = vec![vec![0.0_f32, 0.0]];
        let n_steps = 3;

        // Create the HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, // step size
            2,    // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run the sampler for n_steps with no discard.
        let mut timer = Timer::new();
        let samples: Tensor<BackendType, 3> = sampler.run(n_steps, 0);
        timer.log(format!(
            "Collected samples (1 chain) with shape: {:?}",
            samples.dims()
        ))
    }

    #[test]
    fn test_collect_hmc_samples_10_chains() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100).
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // Define 10 chains all initialized to (1.0, 2.0).
        let initial_positions = vec![vec![1.0_f32, 2.0_f32]; 10];
        let n_steps = 1000;

        // Create the HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, // step size
            10,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run the sampler for n_steps with no discard.
        let mut timer = Timer::new();
        let samples: Tensor<BackendType, 3> = sampler.run(n_steps, 0);
        timer.log(format!(
            "Collected samples (10 chains) with shape: {:?}",
            samples.dims()
        ))
    }

    #[test]
    fn test_collect_hmc_samples_10_chains_progress() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100).
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // Define 10 chains all initialized to (1.0, 2.0).
        let initial_positions = vec![vec![1.0_f32, 2.0_f32]; 10];
        let n_steps = 1000;

        // Create the HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.05, // step size
            10,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run the sampler for n_steps, discarding the first 100 iterations as burn-in.
        let mut timer = Timer::new();
        let samples: Tensor<BackendType, 3> = sampler.run_progress(n_steps, 100);
        timer.log(format!(
            "Collected samples (10 chains) with shape: {:?}",
            samples.dims()
        ))
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_collect_hmc_samples_benchmark() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<burn::backend::NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100).
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // Define 6 chains all initialized to (1.0, 2.0).
        let initial_positions = vec![vec![1.0_f32, 2.0_f32]; 6];
        let n_steps = 5000;

        // Create the data-parallel HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, // step size
            50,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run HMC for n_steps, collecting samples.
        let mut timer = Timer::new();
        let samples = sampler.run(n_steps, 0);
        timer.log(format!(
            "HMC sampler: generated {} samples.",
            samples.dims()[0..2].iter().product::<usize>()
        ))
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_benchmark_10000d() {
        // Use the Wgpu backend wrapped in Autodiff.
        type BackendType = Autodiff<burn::backend::Wgpu>;

        let seed = 42;
        let d = 10000;

        let rng = SmallRng::seed_from_u64(seed);
        // Define 6 chains, each initialized to the same d-dimensional standard-normal draw.
        let initial_positions: Vec<Vec<f32>> =
            vec![rng.sample_iter(StandardNormal).take(d).collect(); 6];
        let n_steps = 500;

        // Create the data-parallel HMC sampler.
        let mut sampler = HMC::<f32, BackendType, RosenbrockND>::new(
            RosenbrockND {},
            initial_positions,
            0.01, // step size
            50,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run HMC for n_steps, collecting samples.
        let mut timer = Timer::new();
        let samples = sampler.run(n_steps, 0);
        timer.log(format!(
            "HMC sampler: generated {} samples.",
            samples.dims()[0..2].iter().product::<usize>()
        ))
    }
}