mini_mcmc/hmc.rs

//! Hamiltonian Monte Carlo (HMC) sampler.
//!
//! This sampler is structured like a Metropolis–Hastings sampler but uses gradient-based
//! proposals for improved efficiency. It works in a data-parallel fashion and can update
//! multiple chains simultaneously.
//!
//! The code relies on a target distribution provided via the `GradientTarget` trait, which
//! computes the unnormalized log probability for a batch of positions. The HMC implementation
//! uses the leapfrog integrator to simulate Hamiltonian dynamics and the standard Metropolis
//! accept/reject step to validate proposals.
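//!
//! # Example
//!
//! A minimal usage sketch, assuming `MyTarget` implements [`GradientTarget`] (the target
//! type and the tuning values below are illustrative, not part of the crate):
//!
//! ```rust,ignore
//! type B = burn::backend::Autodiff<burn::backend::NdArray>;
//!
//! let mut sampler = HMC::<f32, B, MyTarget>::new(
//!     MyTarget,                    // your `GradientTarget` implementation
//!     vec![vec![0.0_f32, 0.0]; 4], // 4 chains, D = 2
//!     0.05,                        // leapfrog step size
//!     10,                          // leapfrog steps per update
//! )
//! .set_seed(42);
//!
//! // Discard 100 burn-in steps, then collect 1000 draws per chain.
//! let samples = sampler.run(1000, 100); // shape: [1000, 4, 2]
//! ```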

use crate::stats::RhatMulti;
use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
use burn::tensor::cast::ToElement;
use burn::tensor::Tensor;
use indicatif::{ProgressBar, ProgressStyle};
use num_traits::Float;
use rand::prelude::*;
use rand::Rng;
use rand_distr::StandardNormal;
use std::collections::VecDeque;
use std::error::Error;

/// A batched target trait for computing the unnormalized log probability (and its gradients)
/// for a collection of positions.
///
/// Implement this trait for your target distribution to enable gradient-based sampling.
///
/// # Type Parameters
///
/// * `T`: The floating-point type (e.g., `f32` or `f64`).
/// * `B`: The autodiff backend from the `burn` crate.
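///
/// # Example
///
/// A minimal sketch of an implementation for an isotropic Gaussian target (the type name
/// `StdGauss` is illustrative; only the unnormalized log density matters):
///
/// ```rust,ignore
/// struct StdGauss;
///
/// impl<T, B> GradientTarget<T, B> for StdGauss
/// where
///     T: num_traits::Float + std::fmt::Debug + burn::tensor::Element,
///     B: burn::tensor::backend::AutodiffBackend,
/// {
///     fn log_prob_batch(&self, positions: &Tensor<B, 2>) -> Tensor<B, 1> {
///         // log p(x) = -0.5 * ||x||^2, up to an additive constant.
///         positions
///             .clone()
///             .powi_scalar(2)
///             .sum_dim(1)
///             .squeeze(1)
///             .mul_scalar(-0.5)
///     }
/// }
/// ```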
pub trait GradientTarget<T: Float, B: AutodiffBackend> {
    /// Compute the log probability for a batch of positions.
    ///
    /// # Parameters
    ///
    /// * `positions`: A tensor of shape `[n_chains, D]` representing the current positions for each chain.
    ///
    /// # Returns
    ///
    /// A 1D tensor of shape `[n_chains]` containing the log probabilities for each chain.
    fn log_prob_batch(&self, positions: &Tensor<B, 2>) -> Tensor<B, 1>;
}

/// A data-parallel Hamiltonian Monte Carlo (HMC) sampler.
///
/// This struct encapsulates the HMC algorithm, including the leapfrog integrator and the
/// accept/reject mechanism, for sampling from a target distribution in a batched manner.
///
/// # Type Parameters
///
/// * `T`: Floating-point type for numerical calculations.
/// * `B`: Autodiff backend from the `burn` crate.
/// * `GTarget`: The target distribution type implementing the `GradientTarget` trait.
#[derive(Debug, Clone)]
pub struct HMC<T, B, GTarget>
where
    B: AutodiffBackend,
{
    /// The target distribution which provides log probability evaluations and gradients.
    pub target: GTarget,
    /// The step size for the leapfrog integrator.
    pub step_size: T,
    /// The number of leapfrog steps to take per HMC update.
    pub n_leapfrog: usize,
    /// The current positions for all chains, stored as a tensor of shape `[n_chains, D]`.
    pub positions: Tensor<B, 2>,
    /// A random number generator used to draw the uniform random numbers for the
    /// Metropolis acceptance test. (Momenta are sampled on the backend device.)
    pub rng: SmallRng,
}

impl<T, B, GTarget> HMC<T, B, GTarget>
where
    T: Float
        + burn::tensor::ElementConversion
        + burn::tensor::Element
        + rand_distr::uniform::SampleUniform
        + num_traits::FromPrimitive,
    B: AutodiffBackend,
    GTarget: GradientTarget<T, B> + std::marker::Sync,
    StandardNormal: rand::distributions::Distribution<T>,
    rand_distr::Standard: rand_distr::Distribution<T>,
{
    /// Create a new data-parallel HMC sampler.
    ///
    /// This method initializes the sampler with the target distribution, initial positions,
    /// step size, and number of leapfrog steps. The internal random number generator is
    /// seeded from OS entropy; use [`Self::set_seed`] for reproducible runs.
    ///
    /// # Parameters
    ///
    /// * `target`: The target distribution implementing the `GradientTarget` trait.
    /// * `initial_positions`: A vector of vectors containing the initial positions for each chain,
    ///    with shape `[n_chains][D]`.
    /// * `step_size`: The step size used in the leapfrog integrator.
    /// * `n_leapfrog`: The number of leapfrog steps per update.
    ///
    /// # Returns
    ///
    /// A new instance of `HMC`.
    pub fn new(
        target: GTarget,
        initial_positions: Vec<Vec<T>>,
        step_size: T,
        n_leapfrog: usize,
    ) -> Self {
        // Build a [n_chains, D] tensor from the flattened initial positions.
        let (n_chains, dim) = (initial_positions.len(), initial_positions[0].len());
        let td: TensorData = TensorData::new(
            initial_positions.into_iter().flatten().collect(),
            [n_chains, dim],
        );
        let positions = Tensor::<B, 2>::from_data(td, &B::Device::default());
        let rng = SmallRng::seed_from_u64(thread_rng().gen::<u64>());
        Self {
            target,
            step_size,
            n_leapfrog,
            positions,
            rng,
        }
    }

    /// Sets a new random seed.
    ///
    /// Reseeding the sampler's internal RNG makes the accept/reject draws reproducible
    /// across runs.
    ///
    /// # Arguments
    ///
    /// * `seed` - The new random seed value.
    pub fn set_seed(mut self, seed: u64) -> Self {
        self.rng = SmallRng::seed_from_u64(seed);
        self
    }

    /// Run the HMC sampler for `n_collect` + `n_discard` steps.
    ///
    /// First, the sampler takes `n_discard` burn-in steps, then takes
    /// `n_collect` further steps and collects those samples in a 3D tensor of
    /// shape `[n_collect, n_chains, D]`.
    ///
    /// # Parameters
    ///
    /// * `n_collect` - The number of samples to collect and return.
    /// * `n_discard` - The number of samples to discard (burn-in).
    ///
    /// # Returns
    ///
    /// A tensor of shape `[n_collect, n_chains, D]` containing the collected samples.
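    ///
    /// # Example
    ///
    /// A usage sketch (assuming `sampler` was built as in the module-level example):
    ///
    /// ```rust,ignore
    /// // Burn in for 100 steps, then collect 1000 draws per chain.
    /// let samples = sampler.run(1000, 100);
    /// assert_eq!(samples.dims(), [1000, n_chains, dim]);
    /// ```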
    pub fn run(&mut self, n_collect: usize, n_discard: usize) -> Tensor<B, 3> {
        let (n_chains, dim) = (self.positions.dims()[0], self.positions.dims()[1]);
        let mut out = Tensor::<B, 3>::empty([n_collect, n_chains, dim], &B::Device::default());

        // Discard the first `n_discard` positions (burn-in).
        (0..n_discard).for_each(|_| self.step());

        // Collect samples.
        for step in 0..n_collect {
            self.step();
            out.inplace(|_out| {
                _out.slice_assign(
                    [step..step + 1, 0..n_chains, 0..dim],
                    self.positions.clone().unsqueeze_dim(0),
                )
            });
        }
        out
    }

    /// Run the HMC sampler for `n_collect` + `n_discard` steps, displaying progress and
    /// convergence statistics.
    ///
    /// First, the sampler takes `n_discard` burn-in steps, then takes
    /// `n_collect` further steps and collects those samples in a 3D tensor of
    /// shape `[n_collect, n_chains, D]`.
    ///
    /// This function displays a progress bar (using the `indicatif` crate) that is updated
    /// with an approximate acceptance probability, computed over a sliding window of 100
    /// iterations, as well as the potential scale reduction factor (R-hat); see the
    /// [Stan Reference Manual][1].
    ///
    /// # Parameters
    ///
    /// * `n_collect` - The number of samples to collect and return.
    /// * `n_discard` - The number of samples to discard (burn-in).
    ///
    /// # Returns
    ///
    /// A tensor of shape `[n_collect, n_chains, D]` containing the collected samples.
    ///
    /// [1]: https://mc-stan.org/docs/2_18/reference-manual/notation-for-samples-chains-and-draws.html
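    ///
    /// # Example
    ///
    /// A usage sketch (assuming `sampler` was built as in the module-level example):
    ///
    /// ```rust,ignore
    /// let samples = sampler.run_progress(1000, 100)?; // prints a live progress bar
    /// ```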
    pub fn run_progress(
        &mut self,
        n_collect: usize,
        n_discard: usize,
    ) -> Result<Tensor<B, 3>, Box<dyn Error>> {
        // Discard initial burn-in samples.
        (0..n_discard).for_each(|_| self.step());

        let (n_chains, dim) = (self.positions.dims()[0], self.positions.dims()[1]);
        let mut out = Tensor::<B, 3>::empty([n_collect, n_chains, dim], &B::Device::default());

        let pb = ProgressBar::new(n_collect as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("{prefix:8} {bar:40.white} ETA {eta:3} | {msg}")
                .unwrap()
                .progress_chars("=>-"),
        );
        pb.set_prefix("HMC");

        // Use a sliding window of 100 iterations to estimate the acceptance probability.
        let window_size = 100;
        let mut accept_window: VecDeque<f32> = VecDeque::with_capacity(window_size);

        let mut psr = RhatMulti::new(n_chains, dim);

        let mut last_state = self.positions.clone();

        let mut last_state_data = last_state.to_data();
        psr.step(last_state_data.as_slice::<T>().unwrap())?;

        for i in 0..n_collect {
            self.step();
            let current_state = self.positions.clone();

            // A chain accepted its proposal iff its state changed, i.e. any coordinate differs.
            let accepted_count = last_state
                .clone()
                .not_equal(current_state.clone())
                .any_dim(1)
                .int()
                .sum()
                .into_scalar()
                .to_f32();

            let iter_accept_rate = accepted_count / n_chains as f32;

            // Update the sliding window.
            accept_window.push_front(iter_accept_rate);
            if accept_window.len() > window_size {
                accept_window.pop_back();
            }

            // Store the current state.
            out.inplace(|_out| {
                _out.slice_assign(
                    [i..i + 1, 0..n_chains, 0..dim],
                    current_state.clone().unsqueeze_dim(0),
                )
            });
            pb.inc(1);
            last_state = current_state;

            last_state_data = last_state.to_data();
            psr.step(last_state_data.as_slice::<T>().unwrap())?;
            let maxrhat = psr.max()?;

            // Compute average acceptance rate over the sliding window.
            let avg_accept_rate: f32 =
                accept_window.iter().sum::<f32>() / accept_window.len() as f32;
            pb.set_message(format!(
                "p(accept)≈{:.2} max(rhat)≈{:.2}",
                avg_accept_rate, maxrhat
            ));
        }
        pb.finish_with_message("Done!");
        Ok(out)
    }

    /// Perform one batched HMC update for all chains in parallel.
    ///
    /// The update consists of:
    /// 1) Sampling momenta from a standard normal distribution.
    /// 2) Running the leapfrog integrator to propose new positions.
    /// 3) Performing an accept/reject step for each chain.
    ///
    /// This method updates `self.positions` in-place.
    pub fn step(&mut self) {
        let shape = self.positions.shape();
        let (n_chains, dim) = (shape.dims[0], shape.dims[1]);

        // 1) Sample momenta on the backend device: shape [n_chains, D].
        let momentum_0 = Tensor::<B, 2>::random(
            Shape::new([n_chains, dim]),
            burn::tensor::Distribution::Normal(0., 1.),
            &B::Device::default(),
        );

        // Current log probability: shape [n_chains]
        let logp_current = self.target.log_prob_batch(&self.positions);

        // Compute kinetic energy: 0.5 * sum_d p_d^2 for each chain.
        let ke_current = momentum_0
            .clone()
            .powf_scalar(2.0)
            .sum_dim(1) // Sum over dimension 1 => shape [n_chains, 1]
            .squeeze(1) // => shape [n_chains]
            .mul_scalar(T::from(0.5).unwrap());

        // Compute the Hamiltonian: -logp + kinetic energy, shape [n_chains]
        let h_current: Tensor<B, 1> = -logp_current + ke_current;

        // 2) Run the leapfrog integrator.
        let (proposed_positions, proposed_momenta, logp_proposed) =
            self.leapfrog(self.positions.clone(), momentum_0);

        // Compute proposed kinetic energy.
        let ke_proposed = proposed_momenta
            .powf_scalar(2.0)
            .sum_dim(1)
            .squeeze(1)
            .mul_scalar(T::from(0.5).unwrap());

        let h_proposed = -logp_proposed + ke_proposed;

        // 3) Accept/Reject each proposal.
        let accept_logp = h_current.sub(h_proposed);

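        // Metropolis rule: accept with probability min(1, exp(H_current - H_proposed)),
        // i.e. accept exactly when ln(u) <= accept_logp for u ~ Uniform(0, 1).
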
        // Draw a uniform random number for each chain from the seeded RNG.
        let mut uniform_data = Vec::with_capacity(n_chains);
        for _ in 0..n_chains {
            uniform_data.push(self.rng.gen::<T>());
        }
        let uniform = Tensor::<B, 1>::from_data(
            TensorData::new(uniform_data, [n_chains]),
            &B::Device::default(),
        );

        // Accept the proposal if accept_logp >= ln(u).
        let ln_u = uniform.log(); // shape [n_chains]
        let accept_mask = accept_logp.greater_equal(ln_u); // Boolean mask of shape [n_chains]
        let mut accept_mask_big: Tensor<B, 2, Bool> = accept_mask.unsqueeze_dim(1);
        accept_mask_big = accept_mask_big.expand([n_chains, dim]);

        // Update positions: for accepted chains, replace current positions with proposed ones.
        self.positions
            .inplace(|x| x.mask_where(accept_mask_big, proposed_positions).detach());
    }

    /// Perform the leapfrog integrator steps in a batched manner.
    ///
    /// This method performs `n_leapfrog` iterations of the leapfrog update:
    /// - A half-step update of the momentum.
    /// - A full-step update of the positions.
    /// - Another half-step update of the momentum.
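    ///
    /// In equations, with step size `eps`, positions `q`, momenta `m`, and the target's
    /// log-density gradient `g(q) = ∇ log p(q)`, one iteration reads:
    ///
    /// ```text
    /// m = m + (eps / 2) * g(q)   // half-step momentum update
    /// q = q + eps * m            // full-step position update
    /// m = m + (eps / 2) * g(q)   // half-step momentum update
    /// ```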
    ///
    /// # Parameters
    ///
    /// * `pos`: The current positions, a tensor of shape `[n_chains, D]`.
    /// * `mom`: The initial momenta, a tensor of shape `[n_chains, D]`.
    ///
    /// # Returns
    ///
    /// A tuple containing:
    /// - The new positions (tensor of shape `[n_chains, D]`),
    /// - The new momenta (tensor of shape `[n_chains, D]`),
    /// - The log probability evaluated at the new positions (tensor of shape `[n_chains]`).
    fn leapfrog(
        &mut self,
        mut pos: Tensor<B, 2>,
        mut mom: Tensor<B, 2>,
    ) -> (Tensor<B, 2>, Tensor<B, 2>, Tensor<B, 1>) {
        let half = T::from(0.5).unwrap();
        for _step_i in 0..self.n_leapfrog {
            // Detach pos from the previous graph and re-mark it as requiring gradients,
            // so each iteration builds a fresh autodiff graph.
            pos = pos.detach().require_grad();

            // Compute gradient of log probability with respect to pos (batched over chains).
            let logp = self.target.log_prob_batch(&pos); // shape [n_chains]
            let grads = pos.grad(&logp.backward()).unwrap();

            // Update momentum by a half-step using the computed gradients.
            mom.inplace(|_mom| {
                _mom.add(Tensor::<B, 2>::from_inner(
                    grads.mul_scalar(self.step_size * half),
                ))
            });

            // Full-step update for positions.
            pos.inplace(|_pos| {
                _pos.add(mom.clone().mul_scalar(self.step_size))
                    .detach()
                    .require_grad()
            });

            // Compute gradient at the new positions.
            let logp2 = self.target.log_prob_batch(&pos);
            let grads2 = pos.grad(&logp2.backward()).unwrap();

            // Update momentum by another half-step using the new gradients.
            mom.inplace(|_mom| {
                _mom.add(Tensor::<B, 2>::from_inner(
                    grads2.mul_scalar(self.step_size * half),
                ))
            });
        }

        // Compute final log probability at the updated positions.
        let logp_final = self.target.log_prob_batch(&pos);
        (pos.detach(), mom.detach(), logp_final.detach())
    }
}

#[cfg(test)]
mod tests {
    use crate::dev_tools::Timer;

    use super::*;
    use burn::{
        backend::{Autodiff, NdArray},
        tensor::{Element, Tensor},
    };
    use num_traits::Float;

    // Define the 2D Rosenbrock distribution.
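    // Its unnormalized log-density is log p(x, y) = -[(a - x)^2 + b * (y - x^2)^2],
    // which has a unique mode at (a, a^2).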
    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
    struct Rosenbrock2D<T: Float> {
        a: T,
        b: T,
    }
    // For the batched version we implement `GradientTarget`.
    impl<T, B> GradientTarget<T, B> for Rosenbrock2D<T>
    where
        T: Float + std::fmt::Debug + Element,
        B: burn::tensor::backend::AutodiffBackend,
    {
        fn log_prob_batch(&self, positions: &Tensor<B, 2>) -> Tensor<B, 1> {
            let n = positions.dims()[0] as i64;
            let x = positions.clone().slice([(0, n), (0, 1)]);
            let y = positions.clone().slice([(0, n), (1, 2)]);

            // Compute (a - x)^2.
            let term_1 = (-x.clone()).add_scalar(self.a).powi_scalar(2);

            // Compute b * (y - x^2)^2.
            let term_2 = y.sub(x.powi_scalar(2)).powi_scalar(2).mul_scalar(self.b);

            // Return the negative sum as a flattened 1D tensor.
            -(term_1 + term_2).flatten(0, 1)
        }
    }

    // Define the n-dimensional Rosenbrock distribution.
    // From: https://arxiv.org/pdf/1903.09556.
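    // Its unnormalized log-density is
    // log p(x) = -sum_{i=1}^{n-1} [100 * (x_{i+1} - x_i^2)^2 + (1 - x_i)^2].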
    struct RosenbrockND {}

    // For the batched version we implement `GradientTarget`.
    impl<T, B> GradientTarget<T, B> for RosenbrockND
    where
        T: Float + std::fmt::Debug + Element,
        B: burn::tensor::backend::AutodiffBackend,
    {
        fn log_prob_batch(&self, positions: &Tensor<B, 2>) -> Tensor<B, 1> {
            let k = positions.dims()[0] as i64;
            let n = positions.dims()[1] as i64;
            let low = positions.clone().slice([(0, k), (0, n - 1)]);
            let high = positions.clone().slice([(0, k), (1, n)]);
            let term_1 = (high - low.clone().powi_scalar(2))
                .powi_scalar(2)
                .mul_scalar(100);
            let term_2 = low.neg().add_scalar(1).powi_scalar(2);
            -(term_1 + term_2).sum_dim(1).squeeze(1)
        }
    }

    #[test]
    fn test_single() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100).
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // Define initial positions for a single chain (2-dimensional).
        let initial_positions = vec![vec![0.0_f32, 0.0]];
        let n_collect = 3;

        // Create the HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, // step size
            2,    // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run the sampler for n_collect steps.
        let mut timer = Timer::new();
        let samples: Tensor<BackendType, 3> = sampler.run(n_collect, 0);
        timer.log(format!(
            "Collected samples (1 chain) with shape: {:?}",
            samples.dims()
        ))
    }

    #[test]
    fn test_10_chains() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100).
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // Define 10 chains all initialized to (1.0, 2.0).
        let initial_positions = vec![vec![1.0_f32, 2.0_f32]; 10];
        let n_collect = 1000;

        // Create the HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, // step size
            10,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run the sampler for n_collect steps.
        let mut timer = Timer::new();
        let samples: Tensor<BackendType, 3> = sampler.run(n_collect, 0);
        timer.log(format!(
            "Collected samples (10 chains) with shape: {:?}",
            samples.dims()
        ))
    }

    #[test]
    fn test_progress_10_chains() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100).
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // Define 10 chains all initialized to (1.0, 2.0).
        let initial_positions = vec![vec![1.0_f32, 2.0_f32]; 10];
        let n_collect = 1000;

        // Create the HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.05, // step size
            10,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run the sampler for n_collect steps after 100 burn-in steps.
        let mut timer = Timer::new();
        let samples: Tensor<BackendType, 3> = sampler.run_progress(n_collect, 100).unwrap();
        timer.log(format!(
            "Collected samples (10 chains) with shape: {:?}",
            samples.dims()
        ))
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_bench() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<burn::backend::NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100).
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // We'll define 6 chains all initialized to (1.0, 2.0).
        let initial_positions = vec![vec![1.0_f32, 2.0_f32]; 6];
        let n_collect = 5000;

        // Create the data-parallel HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, // step size
            50,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run HMC for `n_collect` steps.
        let mut timer = Timer::new();
        let samples = sampler.run(n_collect, 0);
        timer.log(format!(
            "HMC sampler: generated {} samples.",
            samples.dims()[0..2].iter().product::<usize>()
        ))
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_progress_bench() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<burn::backend::NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100).
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // We'll define 6 chains all initialized to (1.0, 2.0).
        let initial_positions = vec![vec![1.0_f32, 2.0_f32]; 6];
        let n_collect = 5000;

        // Create the data-parallel HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, // step size
            50,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run HMC for n_collect steps.
        let mut timer = Timer::new();
        let samples = sampler.run_progress(n_collect, 0).unwrap();
        timer.log(format!(
            "HMC sampler: generated {} samples.",
            samples.dims()[0..2].iter().product::<usize>()
        ))
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_bench_10000d() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<burn::backend::NdArray>;

        let seed = 42;
        let d = 10000;

        let rng = SmallRng::seed_from_u64(seed);
        // Define 6 chains that share one random 10,000-dimensional initial position
        // (`vec![expr; 6]` clones a single draw).
        let initial_positions: Vec<Vec<f32>> =
            vec![rng.sample_iter(StandardNormal).take(d).collect(); 6];
        let n_collect = 500;

        // Create the data-parallel HMC sampler.
        let mut sampler = HMC::<f32, BackendType, RosenbrockND>::new(
            RosenbrockND {},
            initial_positions,
            0.01, // step size
            50,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run HMC for n_collect steps.
        let mut timer = Timer::new();
        let samples = sampler.run(n_collect, 0);
        timer.log(format!(
            "HMC sampler: generated {} samples.",
            samples.dims()[0..2].iter().product::<usize>()
        ))
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    #[cfg(feature = "wgpu")]
    fn test_progress_10000d_bench() {
        type BackendType = Autodiff<burn::backend::Wgpu>;

        let seed = 42;
        let d = 10000;

        let rng = SmallRng::seed_from_u64(seed);
        // Define 6 chains that share one random 10,000-dimensional initial position
        // (`vec![expr; 6]` clones a single draw).
        let initial_positions: Vec<Vec<f32>> =
            vec![rng.sample_iter(StandardNormal).take(d).collect(); 6];
        let n_collect = 5000;

        // Create the data-parallel HMC sampler.
        let mut sampler = HMC::<f32, BackendType, RosenbrockND>::new(
            RosenbrockND {},
            initial_positions,
            0.01, // step size
            50,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run HMC for n_collect steps.
        let mut timer = Timer::new();
        let samples = sampler.run_progress(n_collect, 0).unwrap();
        timer.log(format!(
            "HMC sampler: generated {} samples.",
            samples.dims()[0..2].iter().product::<usize>()
        ))
    }
}