mini_mcmc/hmc.rs
//! Hamiltonian Monte Carlo (HMC) sampler.
//!
//! This is modeled similarly to a Metropolis–Hastings sampler but uses gradient-based proposals
//! for improved efficiency. The sampler works in a data-parallel fashion and can update multiple
//! chains simultaneously.
//!
//! The code relies on a target distribution provided via the `GradientTarget` trait, which computes
//! the unnormalized log probability for a batch of positions. The HMC implementation uses the leapfrog
//! integrator to simulate Hamiltonian dynamics, and the standard accept/reject step for proposal
//! validation.

use crate::stats::RhatMulti;
use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
use burn::tensor::cast::ToElement;
use burn::tensor::Tensor;
use indicatif::{ProgressBar, ProgressStyle};
use num_traits::Float;
use rand::prelude::*;
use rand::Rng;
use rand_distr::StandardNormal;
use std::collections::VecDeque;
use std::error::Error;

/// A batched target trait for computing the unnormalized log probability (and gradients) for a
/// collection of positions.
///
/// Implement this trait for your target distribution to enable gradient-based sampling.
///
/// # Type Parameters
///
/// * `T`: The floating-point type (e.g., `f32` or `f64`).
/// * `B`: The autodiff backend from the `burn` crate.
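///
/// # Example
///
/// A minimal sketch for an isotropic standard-normal target, whose unnormalized
/// log density is `-0.5 * ||x||^2` per chain. The struct name `StdNormal` is
/// illustrative and not part of this crate:
///
/// ```ignore
/// use burn::prelude::*;
/// use burn::tensor::backend::AutodiffBackend;
///
/// struct StdNormal;
///
/// impl<T, B> GradientTarget<T, B> for StdNormal
/// where
///     T: num_traits::Float + burn::tensor::Element,
///     B: AutodiffBackend,
/// {
///     fn log_prob_batch(&self, positions: &Tensor<B, 2>) -> Tensor<B, 1> {
///         // -0.5 * sum_d x_d^2 for each chain => shape [n_chains].
///         positions
///             .clone()
///             .powi_scalar(2)
///             .sum_dim(1)
///             .squeeze(1)
///             .mul_scalar(-0.5)
///     }
/// }
/// ```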
pub trait GradientTarget<T: Float, B: AutodiffBackend> {
    /// Compute the log probability for a batch of positions.
    ///
    /// # Parameters
    ///
    /// * `positions`: A tensor of shape `[n_chains, D]` representing the current positions for each chain.
    ///
    /// # Returns
    ///
    /// A 1D tensor of shape `[n_chains]` containing the log probabilities for each chain.
    fn log_prob_batch(&self, positions: &Tensor<B, 2>) -> Tensor<B, 1>;
}

/// A data-parallel Hamiltonian Monte Carlo (HMC) sampler.
///
/// This struct encapsulates the HMC algorithm, including the leapfrog integrator and the
/// accept/reject mechanism, for sampling from a target distribution in a batched manner.
///
/// # Type Parameters
///
/// * `T`: Floating-point type for numerical calculations.
/// * `B`: Autodiff backend from the `burn` crate.
/// * `GTarget`: The target distribution type implementing the `GradientTarget` trait.
#[derive(Debug, Clone)]
pub struct HMC<T, B, GTarget>
where
    B: AutodiffBackend,
{
    /// The target distribution which provides log probability evaluations and gradients.
    pub target: GTarget,
    /// The step size for the leapfrog integrator.
    pub step_size: T,
    /// The number of leapfrog steps to take per HMC update.
    pub n_leapfrog: usize,
    /// The current positions for all chains, stored as a tensor of shape `[n_chains, D]`.
    pub positions: Tensor<B, 2>,
    /// A random number generator for sampling momenta and uniform random numbers for the
    /// Metropolis acceptance test.
    pub rng: SmallRng,
}

impl<T, B, GTarget> HMC<T, B, GTarget>
where
    T: Float
        + burn::tensor::ElementConversion
        + burn::tensor::Element
        + rand_distr::uniform::SampleUniform
        + num_traits::FromPrimitive,
    B: AutodiffBackend,
    GTarget: GradientTarget<T, B> + std::marker::Sync,
    StandardNormal: rand::distributions::Distribution<T>,
    rand_distr::Standard: rand_distr::Distribution<T>,
{
    /// Create a new data-parallel HMC sampler.
    ///
    /// This method initializes the sampler with the target distribution, initial positions,
    /// step size, and number of leapfrog steps. The internal random number generator is
    /// seeded from entropy; use [`set_seed`](Self::set_seed) for reproducible runs.
    ///
    /// # Parameters
    ///
    /// * `target`: The target distribution implementing the `GradientTarget` trait.
    /// * `initial_positions`: A vector of vectors containing the initial positions for each chain,
    ///   with shape `[n_chains][D]`.
    /// * `step_size`: The step size used in the leapfrog integrator.
    /// * `n_leapfrog`: The number of leapfrog steps per update.
    ///
    /// # Returns
    ///
    /// A new instance of `HMC`.
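    ///
    /// # Example
    ///
    /// A minimal sketch; `MyTarget` stands in for any `GradientTarget`
    /// implementation and is not part of this crate:
    ///
    /// ```ignore
    /// // Two chains in two dimensions.
    /// let mut sampler = HMC::<f32, Autodiff<NdArray>, MyTarget>::new(
    ///     MyTarget,
    ///     vec![vec![0.0, 0.0], vec![1.0, 1.0]],
    ///     0.05, // leapfrog step size
    ///     20,   // leapfrog steps per proposal
    /// )
    /// .set_seed(42);
    /// ```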
    pub fn new(
        target: GTarget,
        initial_positions: Vec<Vec<T>>,
        step_size: T,
        n_leapfrog: usize,
    ) -> Self {
        // Build a [n_chains, D] tensor from the flattened initial positions.
        let (n_chains, dim) = (initial_positions.len(), initial_positions[0].len());
        let td: TensorData = TensorData::new(
            initial_positions.into_iter().flatten().collect(),
            [n_chains, dim],
        );
        let positions = Tensor::<B, 2>::from_data(td, &B::Device::default());
        let rng = SmallRng::seed_from_u64(thread_rng().gen::<u64>());
        Self {
            target,
            step_size,
            n_leapfrog,
            positions,
            rng,
        }
    }

    /// Sets a new random seed.
    ///
    /// This method ensures reproducibility across runs.
    ///
    /// # Arguments
    ///
    /// * `seed` - The new random seed value.
    pub fn set_seed(mut self, seed: u64) -> Self {
        self.rng = SmallRng::seed_from_u64(seed);
        self
    }

    /// Run the HMC sampler for `n_collect` + `n_discard` steps.
    ///
    /// First, the sampler takes `n_discard` burn-in steps, then takes
    /// `n_collect` further steps and collects those samples in a 3D tensor of
    /// shape `[n_collect, n_chains, D]`.
    ///
    /// # Parameters
    ///
    /// * `n_collect` - The number of samples to collect and return.
    /// * `n_discard` - The number of samples to discard (burn-in).
    ///
    /// # Returns
    ///
    /// A tensor containing the collected samples.
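    ///
    /// # Example
    ///
    /// A minimal sketch, continuing the `new` example above (`sampler` is an
    /// `HMC` instance with 2 chains in 2 dimensions):
    ///
    /// ```ignore
    /// // Discard 100 burn-in steps, then collect 1000 samples per chain.
    /// let samples = sampler.run(1000, 100);
    /// assert_eq!(samples.dims(), [1000, 2, 2]); // [n_collect, n_chains, D]
    /// ```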
    pub fn run(&mut self, n_collect: usize, n_discard: usize) -> Tensor<B, 3> {
        let (n_chains, dim) = (self.positions.dims()[0], self.positions.dims()[1]);
        let mut out = Tensor::<B, 3>::empty([n_collect, n_chains, dim], &B::Device::default());

        // Discard the first `n_discard` positions (burn-in).
        (0..n_discard).for_each(|_| self.step());

        // Collect samples.
        for step in 0..n_collect {
            self.step();
            out.inplace(|_out| {
                _out.slice_assign(
                    [step..step + 1, 0..n_chains, 0..dim],
                    self.positions.clone().unsqueeze_dim(0),
                )
            });
        }
        out
    }

    /// Run the HMC sampler for `n_collect` + `n_discard` steps, displaying progress and
    /// convergence statistics.
    ///
    /// First, the sampler takes `n_discard` burn-in steps, then takes
    /// `n_collect` further steps and collects those samples in a 3D tensor of
    /// shape `[n_collect, n_chains, D]`.
    ///
    /// This function displays a progress bar (using the `indicatif` crate) that is updated
    /// with an approximate acceptance probability, computed over a sliding window of 100
    /// iterations, as well as the potential scale reduction factor (R-hat); see the
    /// [Stan Reference Manual][1]. Values of R-hat close to 1 indicate that the chains
    /// have mixed well.
    ///
    /// # Parameters
    ///
    /// * `n_collect` - The number of samples to collect and return.
    /// * `n_discard` - The number of samples to discard (burn-in).
    ///
    /// # Returns
    ///
    /// A tensor of shape `[n_collect, n_chains, D]` containing the collected samples.
    ///
    /// [1]: https://mc-stan.org/docs/2_18/reference-manual/notation-for-samples-chains-and-draws.html
    pub fn run_progress(
        &mut self,
        n_collect: usize,
        n_discard: usize,
    ) -> Result<Tensor<B, 3>, Box<dyn Error>> {
        // Discard initial burn-in samples.
        (0..n_discard).for_each(|_| self.step());

        let (n_chains, dim) = (self.positions.dims()[0], self.positions.dims()[1]);
        let mut out = Tensor::<B, 3>::empty([n_collect, n_chains, dim], &B::Device::default());

        let pb = ProgressBar::new(n_collect as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("{prefix:8} {bar:40.white} ETA {eta:3} | {msg}")
                .unwrap()
                .progress_chars("=>-"),
        );
        pb.set_prefix("HMC");

        // Use a sliding window of 100 iterations to estimate the acceptance probability.
        let window_size = 100;
        let mut accept_window: VecDeque<f32> = VecDeque::with_capacity(window_size);

        let mut psr = RhatMulti::new(n_chains, dim);

        let mut last_state = self.positions.clone();

        let mut last_state_data = last_state.to_data();
        psr.step(last_state_data.as_slice::<T>().unwrap())?;
        for i in 0..n_collect {
            self.step();
            let current_state = self.positions.clone();

            // Count the chains whose proposal was accepted, i.e., whose state changed
            // in every coordinate since the last iteration.
            let accepted_count = last_state
                .clone()
                .not_equal(current_state.clone())
                .all_dim(1)
                .int()
                .sum()
                .into_scalar()
                .to_f32();

            let iter_accept_rate = accepted_count / n_chains as f32;

            // Update the sliding window.
            accept_window.push_front(iter_accept_rate);
            if accept_window.len() > window_size {
                accept_window.pop_back();
            }

            // Store the current state.
            out.inplace(|_out| {
                _out.slice_assign(
                    [i..i + 1, 0..n_chains, 0..dim],
                    current_state.clone().unsqueeze_dim(0),
                )
            });
            pb.inc(1);
            last_state = current_state;

            last_state_data = last_state.to_data();
            psr.step(last_state_data.as_slice::<T>().unwrap())?;
            let maxrhat = psr.max()?;

            // Compute the average acceptance rate over the sliding window.
            let avg_accept_rate: f32 =
                accept_window.iter().sum::<f32>() / accept_window.len() as f32;
            pb.set_message(format!(
                "p(accept)≈{:.2} max(rhat)≈{:.2}",
                avg_accept_rate, maxrhat
            ));
        }
        pb.finish_with_message("Done!");
        Ok(out)
    }

    /// Perform one batched HMC update for all chains in parallel.
    ///
    /// The update consists of:
    /// 1) Sampling momenta from a standard normal distribution.
    /// 2) Running the leapfrog integrator to propose new positions.
    /// 3) Performing an accept/reject step for each chain.
    ///
    /// This method updates `self.positions` in-place.
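    ///
    /// A proposal `(q', p')` is accepted with probability
    /// `min(1, exp(H(q, p) - H(q', p')))`, where `H(q, p) = -log_prob(q) + 0.5 * ||p||^2`
    /// is the Hamiltonian; in log space this is the test `ln(u) <= H(q, p) - H(q', p')`
    /// against a uniform draw `u`.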
    pub fn step(&mut self) {
        let shape = self.positions.shape();
        let (n_chains, dim) = (shape.dims[0], shape.dims[1]);

        // 1) Sample momenta from N(0, 1) using the sampler's seeded RNG: shape [n_chains, D].
        let momentum_data: Vec<T> = (0..n_chains * dim)
            .map(|_| self.rng.sample(StandardNormal))
            .collect();
        let momentum_0 = Tensor::<B, 2>::from_data(
            TensorData::new(momentum_data, [n_chains, dim]),
            &B::Device::default(),
        );

        // Current log probability: shape [n_chains].
        let logp_current = self.target.log_prob_batch(&self.positions);

        // Compute the kinetic energy: 0.5 * sum_d p_d^2 for each chain.
        let ke_current = momentum_0
            .clone()
            .powf_scalar(2.0)
            .sum_dim(1) // Sum over dimension 1 => shape [n_chains, 1].
            .squeeze(1)
            .mul_scalar(T::from(0.5).unwrap());

        // Compute the Hamiltonian: -logp + kinetic energy, shape [n_chains].
        let h_current: Tensor<B, 1> = -logp_current + ke_current;

        // 2) Run the leapfrog integrator.
        let (proposed_positions, proposed_momenta, logp_proposed) =
            self.leapfrog(self.positions.clone(), momentum_0);

        // Compute the proposed kinetic energy.
        let ke_proposed = proposed_momenta
            .powf_scalar(2.0)
            .sum_dim(1)
            .squeeze(1)
            .mul_scalar(T::from(0.5).unwrap());

        let h_proposed = -logp_proposed + ke_proposed;

        // 3) Accept/reject each proposal.
        let accept_logp = h_current.sub(h_proposed);

        // Draw one uniform random number per chain from the seeded RNG.
        let mut uniform_data = Vec::with_capacity(n_chains);
        for _ in 0..n_chains {
            uniform_data.push(self.rng.gen::<T>());
        }
        let uniform = Tensor::<B, 1>::from_data(
            TensorData::new(uniform_data, [n_chains]),
            &B::Device::default(),
        );

        // Accept the proposal if accept_logp >= ln(u).
        let ln_u = uniform.log(); // shape [n_chains]
        let accept_mask = accept_logp.greater_equal(ln_u); // Boolean mask of shape [n_chains].
        let mut accept_mask_big: Tensor<B, 2, Bool> = accept_mask.clone().unsqueeze_dim(1);
        accept_mask_big = accept_mask_big.expand([n_chains, dim]);

        // Update positions: for accepted chains, replace current positions with proposed ones.
        self.positions.inplace(|x| {
            x.clone()
                .mask_where(accept_mask_big, proposed_positions)
                .detach()
        });
    }

    /// Perform the leapfrog integrator steps in a batched manner.
    ///
    /// This method performs `n_leapfrog` iterations of the leapfrog update:
    /// - A half-step update of the momentum.
    /// - A full-step update of the positions.
    /// - Another half-step update of the momentum.
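    ///
    /// In symbols, with step size `ε` and target log density `log π`, one iteration is:
    ///
    /// ```text
    /// p_{t+1/2} = p_t       + (ε/2) ∇ log π(q_t)
    /// q_{t+1}   = q_t       + ε p_{t+1/2}
    /// p_{t+1}   = p_{t+1/2} + (ε/2) ∇ log π(q_{t+1})
    /// ```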
    ///
    /// # Parameters
    ///
    /// * `pos`: The current positions, a tensor of shape `[n_chains, D]`.
    /// * `mom`: The initial momenta, a tensor of shape `[n_chains, D]`.
    ///
    /// # Returns
    ///
    /// A tuple containing:
    /// - The new positions (tensor of shape `[n_chains, D]`),
    /// - The new momenta (tensor of shape `[n_chains, D]`),
    /// - The log probability evaluated at the new positions (tensor of shape `[n_chains]`).
    fn leapfrog(
        &mut self,
        mut pos: Tensor<B, 2>,
        mut mom: Tensor<B, 2>,
    ) -> (Tensor<B, 2>, Tensor<B, 2>, Tensor<B, 1>) {
        let half = T::from(0.5).unwrap();
        for _step_i in 0..self.n_leapfrog {
            // Detach pos from the previous graph and mark it as requiring gradients
            // for this iteration's gradient computation.
            pos = pos.detach().require_grad();

            // Compute the gradient of the log probability with respect to pos (batched over chains).
            let logp = self.target.log_prob_batch(&pos); // shape [n_chains]
            let grads = pos.grad(&logp.backward()).unwrap();

            // Update the momentum by a half-step using the computed gradients.
            mom.inplace(|_mom| {
                _mom.add(Tensor::<B, 2>::from_inner(
                    grads.mul_scalar(self.step_size * half),
                ))
            });

            // Full-step update for the positions.
            pos.inplace(|_pos| {
                _pos.add(mom.clone().mul_scalar(self.step_size))
                    .detach()
                    .require_grad()
            });

            // Compute the gradient at the new positions.
            let logp2 = self.target.log_prob_batch(&pos);
            let grads2 = pos.grad(&logp2.backward()).unwrap();

            // Update the momentum by another half-step using the new gradients.
            mom.inplace(|_mom| {
                _mom.add(Tensor::<B, 2>::from_inner(
                    grads2.mul_scalar(self.step_size * half),
                ))
            });
        }

        // Compute the final log probability at the updated positions.
        let logp_final = self.target.log_prob_batch(&pos);
        (pos.detach(), mom.detach(), logp_final.detach())
    }
}

#[cfg(test)]
mod tests {
    use crate::dev_tools::Timer;

    use super::*;
    use burn::{
        backend::{Autodiff, NdArray},
        tensor::{Element, Tensor},
    };
    use num_traits::Float;

    // Define the Rosenbrock distribution.
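    // Unnormalized log density: log π(x, y) = -[(a - x)^2 + b (y - x^2)^2].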
    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
    struct Rosenbrock2D<T: Float> {
        a: T,
        b: T,
    }

    // For the batched version we implement `GradientTarget`.
    impl<T, B> GradientTarget<T, B> for Rosenbrock2D<T>
    where
        T: Float + std::fmt::Debug + Element,
        B: burn::tensor::backend::AutodiffBackend,
    {
        fn log_prob_batch(&self, positions: &Tensor<B, 2>) -> Tensor<B, 1> {
            let n = positions.dims()[0] as i64;
            let x = positions.clone().slice([(0, n), (0, 1)]);
            let y = positions.clone().slice([(0, n), (1, 2)]);

            // Compute (a - x)^2.
            let term_1 = (-x.clone()).add_scalar(self.a).powi_scalar(2);

            // Compute b * (y - x^2)^2.
            let term_2 = y.sub(x.powi_scalar(2)).powi_scalar(2).mul_scalar(self.b);

            // Return the negative sum as a flattened 1D tensor.
            -(term_1 + term_2).flatten(0, 1)
        }
    }

    // Define the n-dimensional Rosenbrock distribution.
    // From: https://arxiv.org/pdf/1903.09556.
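    // Unnormalized log density:
    // log π(x) = -Σ_{i=1}^{n-1} [ 100 (x_{i+1} - x_i^2)^2 + (1 - x_i)^2 ].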
    struct RosenbrockND {}

    // For the batched version we implement `GradientTarget`.
    impl<T, B> GradientTarget<T, B> for RosenbrockND
    where
        T: Float + std::fmt::Debug + Element,
        B: burn::tensor::backend::AutodiffBackend,
    {
        fn log_prob_batch(&self, positions: &Tensor<B, 2>) -> Tensor<B, 1> {
            let k = positions.dims()[0] as i64;
            let n = positions.dims()[1] as i64;
            let low = positions.clone().slice([(0, k), (0, (n - 1))]);
            let high = positions.clone().slice([(0, k), (1, n)]);
            let term_1 = (high - low.clone().powi_scalar(2))
                .powi_scalar(2)
                .mul_scalar(100);
            let term_2 = low.neg().add_scalar(1).powi_scalar(2);
            -(term_1 + term_2).sum_dim(1).squeeze(1)
        }
    }

    #[test]
    fn test_single() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100).
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // Define initial positions for a single chain (2-dimensional).
        let initial_positions = vec![vec![0.0_f32, 0.0]];
        let n_collect = 3;

        // Create the HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, // step size
            2,    // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run the sampler for n_collect steps.
        let mut timer = Timer::new();
        let samples: Tensor<BackendType, 3> = sampler.run(n_collect, 0);
        timer.log(format!(
            "Collected samples (1 chain) with shape: {:?}",
            samples.dims()
        ))
    }

    #[test]
    fn test_10_chains() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100).
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // Define 10 chains all initialized to (1.0, 2.0).
        let initial_positions = vec![vec![1.0_f32, 2.0_f32]; 10];
        let n_collect = 1000;

        // Create the HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, // step size
            10,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run the sampler for n_collect steps.
        let mut timer = Timer::new();
        let samples: Tensor<BackendType, 3> = sampler.run(n_collect, 0);
        timer.log(format!(
            "Collected samples (10 chains) with shape: {:?}",
            samples.dims()
        ))
    }

    #[test]
    fn test_progress_10_chains() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100).
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // Define 10 chains all initialized to (1.0, 2.0).
        let initial_positions = vec![vec![1.0_f32, 2.0_f32]; 10];
        let n_collect = 1000;

        // Create the HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.05, // step size
            10,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run the sampler for n_collect steps, discarding the first 100 as burn-in.
        let mut timer = Timer::new();
        let samples: Tensor<BackendType, 3> = sampler.run_progress(n_collect, 100).unwrap();
        timer.log(format!(
            "Collected samples (10 chains) with shape: {:?}",
            samples.dims()
        ))
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_bench() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<burn::backend::NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100).
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // We'll define 6 chains all initialized to (1.0, 2.0).
        let initial_positions = vec![vec![1.0_f32, 2.0_f32]; 6];
        let n_collect = 5000;

        // Create the data-parallel HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, // step size
            50,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run HMC for n_collect steps.
        let mut timer = Timer::new();
        let samples = sampler.run(n_collect, 0);
        timer.log(format!(
            "HMC sampler: generated {} samples.",
            samples.dims()[0..2].iter().product::<usize>()
        ))
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_progress_bench() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<burn::backend::NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100).
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // We'll define 6 chains all initialized to (1.0, 2.0).
        let initial_positions = vec![vec![1.0_f32, 2.0_f32]; 6];
        let n_collect = 5000;

        // Create the data-parallel HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, // step size
            50,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run HMC for n_collect steps.
        let mut timer = Timer::new();
        let samples = sampler.run_progress(n_collect, 0).unwrap();
        timer.log(format!(
            "HMC sampler: generated {} samples.",
            samples.dims()[0..2].iter().product::<usize>()
        ))
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_bench_10000d() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<burn::backend::NdArray>;

        let seed = 42;
        let d = 10000;

        let rng = SmallRng::seed_from_u64(seed);
        // We'll define 6 chains, all initialized to the same standard-normal draw.
        let initial_positions: Vec<Vec<f32>> =
            vec![rng.sample_iter(StandardNormal).take(d).collect(); 6];
        let n_collect = 500;

        // Create the data-parallel HMC sampler.
        let mut sampler = HMC::<f32, BackendType, RosenbrockND>::new(
            RosenbrockND {},
            initial_positions,
            0.01, // step size
            50,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run HMC for n_collect steps.
        let mut timer = Timer::new();
        let samples = sampler.run(n_collect, 0);
        timer.log(format!(
            "HMC sampler: generated {} samples.",
            samples.dims()[0..2].iter().product::<usize>()
        ))
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    #[cfg(feature = "wgpu")]
    fn test_progress_10000d_bench() {
        type BackendType = Autodiff<burn::backend::Wgpu>;

        let seed = 42;
        let d = 10000;

        let rng = SmallRng::seed_from_u64(seed);
        // We'll define 6 chains, all initialized to the same standard-normal draw.
        let initial_positions: Vec<Vec<f32>> =
            vec![rng.sample_iter(StandardNormal).take(d).collect(); 6];
        let n_collect = 5000;

        // Create the data-parallel HMC sampler.
        let mut sampler = HMC::<f32, BackendType, RosenbrockND>::new(
            RosenbrockND {},
            initial_positions,
            0.01, // step size
            50,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run HMC for n_collect steps.
        let mut timer = Timer::new();
        let samples = sampler.run_progress(n_collect, 0).unwrap();
        timer.log(format!(
            "HMC sampler: generated {} samples.",
            samples.dims()[0..2].iter().product::<usize>()
        ))
    }
}