mini_mcmc/hmc.rs

//! Hamiltonian Monte Carlo (HMC) sampler.
//!
//! This is modeled similarly to a Metropolis–Hastings sampler but uses gradient-based proposals
//! for improved efficiency. The sampler works in a data-parallel fashion and can update multiple
//! chains simultaneously.
//!
//! The code relies on a target distribution provided via the `BatchedGradientTarget` trait, which computes
//! the unnormalized log probability for a batch of positions. The HMC implementation uses the leapfrog
//! integrator to simulate Hamiltonian dynamics, and the standard accept/reject step for proposal
//! validation.
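//!
//! # Example
//!
//! A minimal end-to-end sketch (assuming the `DiffableGaussian2D` target from
//! `crate::distributions` and the CPU `NdArray` autodiff backend):
//!
//! ```rust
//! use mini_mcmc::distributions::DiffableGaussian2D;
//! use mini_mcmc::hmc::HMC;
//! use burn::backend::{Autodiff, NdArray};
//! use burn::prelude::*;
//!
//! // 2D Gaussian with mean [0, 1] and a dense covariance matrix.
//! let target = DiffableGaussian2D::new([0.0_f32, 1.0], [[4.0, 2.0], [2.0, 3.0]]);
//!
//! let mut sampler = HMC::<f32, Autodiff<NdArray>, DiffableGaussian2D<f32>>::new(
//!     target,
//!     vec![vec![0.0_f32; 2]; 4], // 4 chains, each 2-dimensional
//!     0.1,                       // leapfrog step size
//!     5,                         // leapfrog steps per update
//! )
//! .set_seed(42);
//!
//! // Discard 50 burn-in steps, then collect 100 observations per chain.
//! let sample = sampler.run(100, 50);
//! assert_eq!(sample.dims(), [4, 100, 2]); // [n_chains, n_collect, D]
//! ```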

use crate::distributions::BatchedGradientTarget;
use crate::stats::MultiChainTracker;
use crate::stats::RunStats;
use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
use burn::tensor::Tensor;
use indicatif::{ProgressBar, ProgressStyle};
use num_traits::Float;
use rand::prelude::*;
use rand::Rng;
use rand_distr::{StandardNormal, StandardUniform};
use std::error::Error;

/// A data-parallel Hamiltonian Monte Carlo (HMC) sampler.
///
/// This struct encapsulates the HMC algorithm, including the leapfrog integrator and the
/// accept/reject mechanism, for sampling from a target distribution in a batched manner.
///
/// # Type Parameters
///
/// * `T`: Floating-point type for numerical calculations.
/// * `B`: Autodiff backend from the `burn` crate.
/// * `GTarget`: The target distribution type implementing the `BatchedGradientTarget` trait.
#[derive(Debug, Clone)]
pub struct HMC<T, B, GTarget>
where
    B: AutodiffBackend,
{
    /// The target distribution which provides log probability evaluations and gradients.
    pub target: GTarget,
    /// The step size for the leapfrog integrator.
    pub step_size: T,
    /// The number of leapfrog steps to take per HMC update.
    pub n_leapfrog: usize,
    /// The current positions for all chains, stored as a tensor of shape `[n_chains, D]` where:
    /// - `n_chains`: number of parallel chains
    /// - `D`: dimensionality of the state space
    pub positions: Tensor<B, 2>,

    /// Cached gradient half-step summands, i.e. `(step_size / 2) * grad(log p)` evaluated
    /// at the positions of the most recent leapfrog step.
    last_grad_summands: Tensor<B, 2>,

    /// A random number generator used to draw the uniform random numbers for the
    /// Metropolis acceptance test.
    pub rng: SmallRng,
}

impl<T, B, GTarget> HMC<T, B, GTarget>
where
    T: Float
        + burn::tensor::ElementConversion
        + burn::tensor::Element
        + rand_distr::uniform::SampleUniform
        + num_traits::FromPrimitive,
    B: AutodiffBackend,
    GTarget: BatchedGradientTarget<T, B> + std::marker::Sync,
    StandardNormal: rand::distr::Distribution<T>,
    StandardUniform: rand_distr::Distribution<T>,
{
    /// Create a new data-parallel HMC sampler.
    ///
    /// This method initializes the sampler with the target distribution, initial positions,
    /// step size, and number of leapfrog steps. The internal random number generator is
    /// seeded randomly; call [`Self::set_seed`] for reproducibility.
    ///
    /// # Parameters
    ///
    /// * `target`: The target distribution implementing the `BatchedGradientTarget` trait.
    /// * `initial_positions`: A vector of vectors containing the initial positions for each chain, with shape `[n_chains][D]`.
    /// * `step_size`: The step size used in the leapfrog integrator.
    /// * `n_leapfrog`: The number of leapfrog steps per update.
    ///
    /// # Returns
    ///
    /// A new instance of `HMC`.
    pub fn new(
        target: GTarget,
        initial_positions: Vec<Vec<T>>,
        step_size: T,
        n_leapfrog: usize,
    ) -> Self {
        // Build a [n_chains, D] tensor from the flattened initial positions.
        let (n_chains, dim) = (initial_positions.len(), initial_positions[0].len());
        let td: TensorData = TensorData::new(
            initial_positions.into_iter().flatten().collect(),
            [n_chains, dim],
        );
        let positions = Tensor::<B, 2>::from_data(td, &B::Device::default());
        let rng = SmallRng::seed_from_u64(rand::rng().random::<u64>());
        Self {
            target,
            step_size,
            n_leapfrog,
            last_grad_summands: Tensor::<B, 2>::zeros_like(&positions),
            positions,
            rng,
        }
    }

    /// Sets a new random seed for the sampler's internal RNG and returns the sampler.
    ///
    /// This makes the accept/reject draws reproducible across runs.
    ///
    /// # Arguments
    ///
    /// * `seed` - The new random seed value.
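    ///
    /// # Example
    ///
    /// A minimal sketch of builder-style seeding (assuming the same Gaussian target as in
    /// the module-level example):
    ///
    /// ```rust
    /// # use mini_mcmc::distributions::DiffableGaussian2D;
    /// # use mini_mcmc::hmc::HMC;
    /// # use burn::backend::{Autodiff, NdArray};
    /// # let target = DiffableGaussian2D::new([0.0_f32, 1.0], [[4.0, 2.0], [2.0, 3.0]]);
    /// let sampler = HMC::<f32, Autodiff<NdArray>, DiffableGaussian2D<f32>>::new(
    ///     target,
    ///     vec![vec![0.0_f32; 2]; 2], // 2 chains, each 2-dimensional
    ///     0.1,                       // leapfrog step size
    ///     5,                         // leapfrog steps per update
    /// )
    /// .set_seed(42); // fixed seed for the sampler's internal RNG
    /// # let _ = sampler;
    /// ```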
    pub fn set_seed(mut self, seed: u64) -> Self {
        self.rng = SmallRng::seed_from_u64(seed);
        self
    }

    /// Run the HMC sampler for `n_collect` + `n_discard` steps.
    ///
    /// First, the sampler takes `n_discard` burn-in steps, then takes
    /// `n_collect` further steps and collects those observations in a 3D tensor of
    /// shape `[n_chains, n_collect, D]`.
    ///
    /// # Parameters
    ///
    /// * `n_collect` - The number of observations to collect and return.
    /// * `n_discard` - The number of observations to discard (burn-in).
    ///
    /// # Returns
    ///
    /// A tensor containing the collected observations.
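    ///
    /// # Example
    ///
    /// A shape-focused sketch (assuming the same Gaussian target as in the module-level
    /// example):
    ///
    /// ```rust
    /// # use mini_mcmc::distributions::DiffableGaussian2D;
    /// # use mini_mcmc::hmc::HMC;
    /// # use burn::backend::{Autodiff, NdArray};
    /// # use burn::prelude::*;
    /// # let target = DiffableGaussian2D::new([0.0_f32, 1.0], [[4.0, 2.0], [2.0, 3.0]]);
    /// # let mut sampler = HMC::<f32, Autodiff<NdArray>, DiffableGaussian2D<f32>>::new(
    /// #     target,
    /// #     vec![vec![0.0_f32; 2]; 4],
    /// #     0.1,
    /// #     5,
    /// # )
    /// # .set_seed(42);
    /// let sample = sampler.run(100, 50);
    /// assert_eq!(sample.dims(), [4, 100, 2]); // [n_chains, n_collect, D]
    ///
    /// // Extract the first chain as a [1, 100, 2] tensor.
    /// let chain0 = sample.clone().slice([0..1, 0..100, 0..2]);
    /// # let _ = chain0;
    /// ```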
    pub fn run(&mut self, n_collect: usize, n_discard: usize) -> Tensor<B, 3> {
        let (n_chains, dim) = (self.positions.dims()[0], self.positions.dims()[1]);
        let mut out = Tensor::<B, 3>::empty([n_collect, n_chains, dim], &B::Device::default());

        // Discard the first `n_discard` positions (burn-in).
        (0..n_discard).for_each(|_| self.step());

        // Collect observations.
        for step in 0..n_collect {
            self.step();
            out.inplace(|_out| {
                _out.slice_assign(
                    [step..step + 1, 0..n_chains, 0..dim],
                    self.positions.clone().unsqueeze_dim(0),
                )
            });
        }
        // Reorder to [n_chains, n_collect, D].
        out.permute([1, 0, 2])
    }

    /// Run the HMC sampler for `n_collect` + `n_discard` steps and display progress with
    /// convergence statistics.
    ///
    /// First, the sampler takes `n_discard` burn-in steps, then takes
    /// `n_collect` further steps and collects those observations in a 3D tensor of
    /// shape `[n_chains, n_collect, D]`.
    ///
    /// This function displays a progress bar (using the `indicatif` crate) that is updated
    /// with an approximate acceptance probability computed over a sliding window of 100 iterations,
    /// as well as the potential scale reduction factor; see the [Stan Reference Manual][1].
    ///
    /// # Parameters
    ///
    /// * `n_collect` - The number of observations to collect and return.
    /// * `n_discard` - The number of observations to discard (burn-in).
    ///
    /// # Returns
    ///
    /// A tuple containing:
    /// - A tensor of shape `[n_chains, n_collect, D]` containing the collected observations.
    /// - A `RunStats` object containing convergence statistics including:
    ///   - Acceptance probability
    ///   - Potential scale reduction factor (R-hat)
    ///   - Effective sample size (ESS)
    ///   - Other convergence diagnostics
    ///
    /// # Example
    ///
    /// ```rust
    /// use mini_mcmc::hmc::HMC;
    /// use mini_mcmc::distributions::DiffableGaussian2D;
    /// use burn::backend::{Autodiff, NdArray};
    /// use burn::prelude::*;
    ///
    /// // Create a 2D Gaussian target distribution
    /// let target = DiffableGaussian2D::new(
    ///     [0.0_f32, 1.0],  // mean
    ///     [[4.0, 2.0],     // covariance
    ///      [2.0, 3.0]]
    /// );
    ///
    /// // Create HMC sampler with:
    /// // - target distribution
    /// // - initial positions for each chain
    /// // - step size for leapfrog integration
    /// // - number of leapfrog steps
    /// type BackendType = Autodiff<NdArray>;
    /// let mut sampler = HMC::<f32, BackendType, DiffableGaussian2D<f32>>::new(
    ///     target,
    ///     vec![vec![0.0; 2]; 4],    // Initial positions for 4 chains
    ///     0.1,                      // Step size
    ///     5,                        // Number of leapfrog steps
    /// );
    ///
    /// // Run with progress tracking: collect 12 observations after 34 burn-in steps.
    /// let (sample, stats) = sampler.run_progress(12, 34).unwrap();
    ///
    /// // Print convergence statistics
    /// println!("{stats}");
    /// ```
    ///
    /// [1]: https://mc-stan.org/docs/2_18/reference-manual/notation-for-samples-chains-and-draws.html
    pub fn run_progress(
        &mut self,
        n_collect: usize,
        n_discard: usize,
    ) -> Result<(Tensor<B, 3>, RunStats), Box<dyn Error>> {
        // Discard initial burn-in observations.
        (0..n_discard).for_each(|_| self.step());

        let (n_chains, dim) = (self.positions.dims()[0], self.positions.dims()[1]);
        let mut out = Tensor::<B, 3>::empty([n_collect, n_chains, dim], &B::Device::default());

        let pb = ProgressBar::new(n_collect as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("{prefix:8} {bar:40.cyan/blue} {pos}/{len} ({eta}) | {msg}")
                .unwrap()
                .progress_chars("=>-"),
        );
        pb.set_prefix("HMC");

        let mut tracker = MultiChainTracker::new(n_chains, dim);

        let mut last_state = self.positions.clone();

        let mut last_state_data = last_state.to_data();
        if let Err(e) = tracker.step(last_state_data.as_slice::<T>().unwrap()) {
            eprintln!("Warning: progress statistics may be unreliable; updating them failed with: {}", e);
        }

        for i in 0..n_collect {
            self.step();
            let current_state = self.positions.clone();

            // Store the current state.
            out.inplace(|_out| {
                _out.slice_assign(
                    [i..i + 1, 0..n_chains, 0..dim],
                    current_state.clone().unsqueeze_dim(0),
                )
            });
            pb.inc(1);
            last_state = current_state;

            last_state_data = last_state.to_data();
            if let Err(e) = tracker.step(last_state_data.as_slice::<T>().unwrap()) {
                eprintln!("Warning: progress statistics may be unreliable; updating them failed with: {}", e);
            }

            match tracker.max_rhat() {
                Err(e) => {
                    eprintln!("Computing max(rhat) failed with: {}", e);
                }
                Ok(max_rhat) => {
                    pb.set_message(format!(
                        "p(accept)≈{:.2} max(rhat)≈{:.2}",
                        tracker.p_accept, max_rhat
                    ));
                }
            }
        }
        pb.finish_with_message("Done!");
        // Reorder to [n_chains, n_collect, D].
        let sample = out.permute([1, 0, 2]);

        let stats = match tracker.stats(sample.clone()) {
            Ok(stats) => stats,
            Err(e) => {
                eprintln!("Getting run statistics failed with: {}", e);
                return Err(e);
            }
        };

        Ok((sample, stats))
    }

    /// Perform one batched HMC update for all chains in parallel.
    ///
    /// The update consists of:
    /// 1) Sampling momenta from a standard normal distribution.
    /// 2) Running the leapfrog integrator to propose new positions.
    /// 3) Performing an accept/reject step for each chain.
    ///
    /// This method updates `self.positions` in-place.
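    ///
    /// In symbols, with Hamiltonian `H(q, p) = -log p(q) + 0.5 * ||p||^2`, a proposal
    /// `(q', p')` is accepted iff `log u <= H(q, p) - H(q', p')` for `u ~ Uniform(0, 1)`,
    /// which is the standard Metropolis rule `u <= min(1, exp(H - H'))`.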
    pub fn step(&mut self) {
        let shape = self.positions.shape();
        let (n_chains, dim) = (shape.dims[0], shape.dims[1]);

        // 1) Sample momenta: shape [n_chains, D]
        let momentum_0 = Tensor::<B, 2>::random(
            Shape::new([n_chains, dim]),
            burn::tensor::Distribution::Normal(0., 1.),
            &B::Device::default(),
        );

        // Current log probability: shape [n_chains]
        // Detach pos from any previous graph and mark it as requiring gradients so the
        // log probability can be differentiated with respect to the positions.
        let pos = self.positions.clone().detach().require_grad();
        let logp_current = self.target.unnorm_logp_batch(pos.clone());

        // Compute gradient of log probability with respect to pos.
        // The first gradient half-step in leapfrog needs it.
        let grads = pos.grad(&logp_current.backward()).unwrap();
        let grad_summands =
            Tensor::<B, 2>::from_inner(grads.mul_scalar(self.step_size * T::from(0.5).unwrap()));
        self.last_grad_summands = grad_summands;

        // Compute kinetic energy: 0.5 * sum_{d} (p^2) for each chain.
        let ke_current = momentum_0
            .clone()
            .powf_scalar(2.0)
            .sum_dim(1) // Sum over dimension 1 => shape [n_chains]
            .squeeze(1)
            .mul_scalar(T::from(0.5).unwrap());

        // Compute the Hamiltonian: -logp + kinetic energy, shape [n_chains]
        let h_current: Tensor<B, 1> = -logp_current + ke_current;

        // 2) Run the leapfrog integrator.
        let (proposed_positions, proposed_momenta, logp_proposed) =
            self.leapfrog(self.positions.clone(), momentum_0);

        // Compute proposed kinetic energy.
        let ke_proposed = proposed_momenta
            .powf_scalar(2.0)
            .sum_dim(1)
            .squeeze(1)
            .mul_scalar(T::from(0.5).unwrap());

        let h_proposed = -logp_proposed + ke_proposed;

        // 3) Accept/Reject each proposal.
        let accept_logp = h_current.sub(h_proposed);

        // Draw a uniform random number for each chain from the sampler's seeded RNG.
        let mut uniform_data = Vec::with_capacity(n_chains);
        for _ in 0..n_chains {
            uniform_data.push(self.rng.random::<T>());
        }
        let uniform = Tensor::<B, 1>::from_data(
            TensorData::new(uniform_data, [n_chains]),
            &B::Device::default(),
        );

        // Accept the proposal if accept_logp >= ln(u).
        let ln_u = uniform.log(); // shape [n_chains]
        let accept_mask = accept_logp.greater_equal(ln_u); // Boolean mask of shape [n_chains]
        let mut accept_mask_big: Tensor<B, 2, Bool> = accept_mask.unsqueeze_dim(1);
        accept_mask_big = accept_mask_big.expand([n_chains, dim]);

        // Update positions: for accepted chains, replace current positions with proposed positions.
        self.positions.inplace(|x| {
            x.clone()
                .mask_where(accept_mask_big, proposed_positions)
                .detach()
        });
    }

    /// Perform the leapfrog integrator steps in a batched manner.
    ///
    /// This method performs `n_leapfrog` iterations of the leapfrog update:
    /// - A half-step update of the momentum.
    /// - A full-step update of the positions.
    /// - Another half-step update of the momentum.
    ///
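    /// With step size `ε` and target `log p`, one iteration reads:
    ///
    /// ```text
    /// p_half = p + (ε/2) ∇log p(q)        // half-step momentum update
    /// q'     = q + ε p_half               // full-step position update
    /// p'     = p_half + (ε/2) ∇log p(q')  // half-step momentum update
    /// ```
    ///
    /// As an optimization, the first half-step reuses the gradient summands cached in
    /// `last_grad_summands` from the previous iteration (or from `step`) instead of
    /// recomputing `∇log p(q)`.
    ///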
    /// # Parameters
    ///
    /// * `pos`: The current positions, a tensor of shape `[n_chains, D]`.
    /// * `mom`: The initial momenta, a tensor of shape `[n_chains, D]`.
    ///
    /// # Returns
    ///
    /// A tuple containing:
    /// - The new positions (tensor of shape `[n_chains, D]`),
    /// - The new momenta (tensor of shape `[n_chains, D]`),
    /// - The log probability evaluated at the new positions (tensor of shape `[n_chains]`).
    fn leapfrog(
        &mut self,
        mut pos: Tensor<B, 2>,
        mut mom: Tensor<B, 2>,
    ) -> (Tensor<B, 2>, Tensor<B, 2>, Tensor<B, 1>) {
        let half = T::from(0.5).unwrap();
        for _step_i in 0..self.n_leapfrog {
            // Detach pos from the previous graph and mark it as requiring gradients
            // for the upcoming gradient computation.
            pos = pos.detach().require_grad();

            // Update momentum by a half-step using the cached gradient summands.
            mom.inplace(|_mom| _mom.add(self.last_grad_summands.clone()));

            // Full-step update for positions.
            pos.inplace(|_pos| {
                _pos.add(mom.clone().mul_scalar(self.step_size))
                    .detach()
                    .require_grad()
            });

            // Compute gradient at the new positions.
            let logp = self.target.unnorm_logp_batch(pos.clone());
            let grads = pos.grad(&logp.backward()).unwrap();
            let grad_summands = Tensor::<B, 2>::from_inner(grads.mul_scalar(self.step_size * half));

            // Update momentum by another half-step using the new gradients.
            mom.inplace(|_mom| _mom.add(grad_summands.clone()));

            // Cache the summands for the next iteration's first half-step.
            self.last_grad_summands = grad_summands;
        }

        // Compute final log probability at the updated positions.
        let logp_final = self.target.unnorm_logp_batch(pos.clone());
        (pos.detach(), mom.detach(), logp_final.detach())
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        core::init,
        dev_tools::Timer,
        distributions::{DiffableGaussian2D, Rosenbrock2D, RosenbrockND},
        stats::split_rhat_mean_ess,
    };
    use ndarray::ArrayView3;
    use ndarray_stats::QuantileExt;

    use super::*;
    use burn::{
        backend::{Autodiff, NdArray},
        tensor::Tensor,
    };

    // Use the CPU backend (NdArray) wrapped in Autodiff.
    type BackendType = Autodiff<NdArray>;

    #[test]
    fn test_hmc_single() {
        // Create the Rosenbrock target (a = 1, b = 100)
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // Define initial positions for a single chain (2-dimensional).
        let initial_positions = vec![vec![0.0_f32, 0.0]];
        let n_collect = 3;

        // Create the HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, // step size
            2,    // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run the sampler for n_collect steps.
        let mut timer = Timer::new();
        let sample: Tensor<BackendType, 3> = sampler.run(n_collect, 0);
        timer.log(format!(
            "Collected sample (1 chain) with shape: {:?}",
            sample.dims()
        ));
        assert_eq!(sample.dims(), [1, 3, 2]);
    }

    #[test]
    fn test_3_chains() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100)
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // Define 3 chains all initialized to (1.0, 2.0).
        let initial_positions = vec![vec![1.0_f32, 2.0_f32]; 3];
        let n_collect = 10;

        // Create the HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, // step size
            2,    // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run the sampler for n_collect steps.
        let mut timer = Timer::new();
        let sample: Tensor<BackendType, 3> = sampler.run(n_collect, 0);
        timer.log(format!(
            "Collected sample (3 chains) with shape: {:?}",
            sample.dims()
        ));
        assert_eq!(sample.dims(), [3, 10, 2]);
    }

    #[test]
    fn test_progress_3_chains() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100)
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // Define 3 chains all initialized to (1.0, 2.0).
        let initial_positions = vec![vec![1.0_f32, 2.0_f32]; 3];
        let n_collect = 10;

        // Create the HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.05, // step size
            2,    // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run the sampler, discarding 3 burn-in steps before collecting n_collect.
        let mut timer = Timer::new();
        let sample: Tensor<BackendType, 3> = sampler.run_progress(n_collect, 3).unwrap().0;
        timer.log(format!(
            "Collected sample (3 chains) with shape: {:?}",
            sample.dims()
        ));
        assert_eq!(sample.dims(), [3, 10, 2]);
    }

    #[test]
    fn test_gaussian_2d_hmc_debug() {
        let n_chains = 1;
        let n_discard = 1;
        let n_collect = 1;

        let target = DiffableGaussian2D::new([0.0, 1.0], [[4.0, 2.0], [2.0, 3.0]]);
        let initial_positions = vec![vec![0.0_f32, 0.0_f32]];

        type BackendType = Autodiff<NdArray>;
        let mut sampler = HMC::<f32, BackendType, DiffableGaussian2D<f32>>::new(
            target,
            initial_positions,
            0.1,
            1,
        )
        .set_seed(42);

        let sample_3d = sampler.run(n_collect, n_discard);

        assert_eq!(sample_3d.dims(), [n_chains, n_collect, 2]);
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_gaussian_2d_hmc_single_run() {
        // Each experiment uses 3 chains:
        let n_chains = 3;

        let n_discard = 500;
        let n_collect = 1000;

        // 1) Define the 2D Gaussian target distribution:
        //    mean: [0.0, 1.0], cov: [[4.0, 2.0], [2.0, 3.0]]
        let target = DiffableGaussian2D::new([0.0, 1.0], [[4.0, 2.0], [2.0, 3.0]]);

        // 2) Define 3 chains, each chain is 2-dimensional:
        let initial_positions = vec![
            vec![1.0_f32, 2.0_f32],
            vec![1.0_f32, 2.0_f32],
            vec![1.0_f32, 2.0_f32],
        ];

        // 3) Create the HMC sampler using NdArray backend with autodiff
        type BackendType = Autodiff<NdArray>;
        let mut sampler = HMC::<f32, BackendType, DiffableGaussian2D<f32>>::new(
            target,
            initial_positions,
            0.1, // step size
            10,  // leapfrog steps
        )
        .set_seed(42);

        // 4) Run the sampler for (n_discard + n_collect) steps, discarding the first
        //    `n_discard`. The shape of `sample_3d` will be [n_chains, n_collect, 2].
        let sample_3d = sampler.run(n_collect, n_discard);

        // Check shape is as expected
        assert_eq!(sample_3d.dims(), [n_chains, n_collect, 2]);

        // 5) Convert the sample into an ndarray view
        let data = sample_3d.to_data();
        let arr = ArrayView3::from_shape(sample_3d.dims(), data.as_slice().unwrap()).unwrap();

        // 6) Compute split-Rhat and ESS
        let (rhat, ess_vals) = split_rhat_mean_ess(arr.view());
        let ess1 = ess_vals[0];
        let ess2 = ess_vals[1];

        println!("\nSingle Run Results:");
        println!("Rhat: {:?}", rhat);
        println!("ESS(Param1): {:.2}", ess1);
        println!("ESS(Param2): {:.2}", ess2);

        // Optionally, add some asserts about expected minimal ESS
        assert!(ess1 > 50.0, "Expected param1 ESS > 50, got {:.2}", ess1);
        assert!(ess2 > 50.0, "Expected param2 ESS > 50, got {:.2}", ess2);
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_gaussian_2d_hmc_ess_stats() {
        use crate::stats::basic_stats;
        use indicatif::{ProgressBar, ProgressStyle};
        use ndarray::Array1;

        let n_runs = 100;
        let n_chains = 3;
        let n_discard = 500;
        let n_collect = 1000;
        let mut rng = SmallRng::seed_from_u64(42);

        // We'll store the ESS and R-hat values for each parameter across all runs
        let mut ess_param1s = Vec::with_capacity(n_runs);
        let mut ess_param2s = Vec::with_capacity(n_runs);
        let mut rhat_param1s = Vec::with_capacity(n_runs);
        let mut rhat_param2s = Vec::with_capacity(n_runs);

        // Set up the progress bar
        let pb = ProgressBar::new(n_runs as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("{prefix:8} {bar:40.cyan/blue} {pos}/{len} ({eta}) | {msg}")
                .unwrap()
                .progress_chars("=>-"),
        );
        pb.set_prefix("HMC Test");

        for run in 0..n_runs {
            // 1) Define the 2D Gaussian target distribution:
            //    mean: [0.0, 1.0], cov: [[4.0, 2.0], [2.0, 3.0]]
            let target = DiffableGaussian2D::new([0.0_f32, 1.0], [[4.0, 2.0], [2.0, 3.0]]);

            // 2) Define 3 chains, each chain is 2-dimensional:
            //    sample initial positions from the seeded RNG created above.
            let initial_positions: Vec<Vec<f32>> = (0..n_chains)
                .map(|_| {
                    // Sample 2D position from standard normal
                    vec![
                        rng.sample::<f32, _>(StandardNormal),
                        rng.sample::<f32, _>(StandardNormal),
                    ]
                })
                .collect();

            // 3) Create the HMC sampler using NdArray backend with autodiff
            type BackendType = Autodiff<NdArray>;
            let mut sampler = HMC::<f32, BackendType, DiffableGaussian2D<f32>>::new(
                target,
                initial_positions,
                0.1, // step size
                10,  // leapfrog steps
            )
            .set_seed(run as u64); // Use run number as seed for reproducibility

            // 4) Run the sampler for (n_discard + n_collect) steps, discard the first `n_discard`
            //    observations
            let sample_3d = sampler.run(n_collect, n_discard);

            // Check shape is as expected
            assert_eq!(sample_3d.dims(), [n_chains, n_collect, 2]);

            // 5) Convert the sample into an ndarray view
            let data = sample_3d.to_data();
            let arr = ArrayView3::from_shape(sample_3d.dims(), data.as_slice().unwrap()).unwrap();

            // 6) Compute split-Rhat and ESS
            let (rhat, ess_vals) = split_rhat_mean_ess(arr.view());
            let ess1 = ess_vals[0];
            let ess2 = ess_vals[1];

            // Store ESS values
            ess_param1s.push(ess1);
            ess_param2s.push(ess2);

            // Store R-hat values
            rhat_param1s.push(rhat[0]);
            rhat_param2s.push(rhat[1]);

            pb.inc(1);

            // Update progress bar with current ESS statistics across runs
            if run > 0 {
                // Calculate mean and std of ESS for both parameters across all runs so far
                let mean_ess1 = ess_param1s.iter().sum::<f32>() / (run as f32 + 1.0);
                let mean_ess2 = ess_param2s.iter().sum::<f32>() / (run as f32 + 1.0);

                // Calculate standard deviations
                let var_ess1 = ess_param1s
                    .iter()
                    .map(|&x| (x - mean_ess1).powi(2))
                    .sum::<f32>()
                    / (run as f32 + 1.0);
                let var_ess2 = ess_param2s
                    .iter()
                    .map(|&x| (x - mean_ess2).powi(2))
                    .sum::<f32>()
                    / (run as f32 + 1.0);

                let std_ess1 = var_ess1.sqrt();
                let std_ess2 = var_ess2.sqrt();

                pb.set_message(format!(
                    "ESS1={:.0}±{:.0} ESS2={:.0}±{:.0}",
                    mean_ess1, std_ess1, mean_ess2, std_ess2
                ));
            } else {
                // For the first run, just show the current values
                pb.set_message(format!("ESS1={:.0} ESS2={:.0}", ess1, ess2));
            }
        }
        pb.finish_with_message("All runs complete!");

        // Convert to ndarray for statistics
        let ess_param1_array = Array1::from_vec(ess_param1s);
        let ess_param2_array = Array1::from_vec(ess_param2s);
        let rhat_param1_array = Array1::from_vec(rhat_param1s);
        let rhat_param2_array = Array1::from_vec(rhat_param2s);

        // Compute and print statistics
        let stats_p1_ess = basic_stats("ESS(Param1)", ess_param1_array);
        let stats_p2_ess = basic_stats("ESS(Param2)", ess_param2_array);
        let stats_p1_rhat = basic_stats("R-hat(Param1)", rhat_param1_array);
        let stats_p2_rhat = basic_stats("R-hat(Param2)", rhat_param2_array);

        println!("\nStatistics over {} runs:", n_runs);
        println!("\nESS Statistics:");
        println!("{stats_p1_ess}\n{stats_p2_ess}");
        println!("\nR-hat Statistics:");
        println!("{stats_p1_rhat}\n{stats_p2_rhat}");

        // Assertions for ESS
        assert!(
            (135.0..=185.0).contains(&stats_p1_ess.mean),
            "Expected param1 ESS to average in [135, 185], got {:.2}",
            stats_p1_ess.mean
        );
        assert!(
            (141.0..=191.0).contains(&stats_p2_ess.mean),
            "Expected param2 ESS to average in [141, 191], got {:.2}",
            stats_p2_ess.mean
        );

        // Assertions for R-hat (should be close to 1.0)
        assert!(
            (0.95..=1.05).contains(&stats_p1_rhat.mean),
            "Expected param1 R-hat to be in [0.95, 1.05], got {:.2}",
            stats_p1_rhat.mean
        );
        assert!(
            (0.95..=1.05).contains(&stats_p2_rhat.mean),
            "Expected param2 R-hat to be in [0.95, 1.05], got {:.2}",
            stats_p2_rhat.mean
        );
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_bench_noprogress() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<burn::backend::NdArray>;

        // Create the Rosenbrock target (a = 1, b = 100)
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // Define 6 chains (2-dimensional) via the crate's `init` helper.
        let initial_positions = init(6, 2);
        let n_collect = 5000;
        let n_discard = 500;

        // Create the data-parallel HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, // step size
            50,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run HMC, discarding `n_discard` burn-in steps before collecting `n_collect`.
        let mut timer = Timer::new();
        let sample = sampler.run(n_collect, n_discard);
        timer.log(format!(
            "HMC sampler: generated {} observations.",
            sample.dims()[0..2].iter().product::<usize>()
        ));
        assert_eq!(sample.dims(), [6, 5000, 2]);

        let data = sample.to_data();
        let array = ArrayView3::from_shape(sample.dims(), data.as_slice().unwrap()).unwrap();
        let (split_rhat, ess) = split_rhat_mean_ess(array);
        println!("MIN Split Rhat: {}", split_rhat.min().unwrap());
        println!("MIN ESS: {}", ess.min().unwrap());
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_progress_bench() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<burn::backend::NdArray>;
        BackendType::seed(42);

        // Create the Rosenbrock target (a = 1, b = 100)
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // We'll define 6 chains all initialized to (1.0, 2.0).
        let n_chains = 6;
        let initial_positions = vec![vec![1.0_f32, 2.0_f32]; n_chains];
        let n_collect = 1000;
        let n_discard = 1000;

        // Create the data-parallel HMC sampler.
        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, // step size
            50,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run HMC, discarding `n_discard` burn-in steps before collecting `n_collect`.
        let mut timer = Timer::new();
        let sample = sampler.run_progress(n_collect, n_discard).unwrap().0;
        timer.log(format!(
            "HMC sampler: generated {} observations.",
            sample.dims()[0..2].iter().product::<usize>()
        ));
        println!(
            "Chain 1, first 10: {}",
            sample.clone().slice([0..1, 0..10, 0..1])
        );
        println!(
            "Chain 3, first 10: {}",
            sample.clone().slice([2..3, 0..10, 0..1])
        );

        #[cfg(feature = "csv")]
        crate::io::csv::save_csv_tensor(sample.clone(), "/tmp/hmc-sample.csv")
            .expect("Expected saving to succeed");

        assert_eq!(sample.dims(), [n_chains, n_collect, 2]);
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_bench_10000d() {
        // Use the CPU backend (NdArray) wrapped in Autodiff.
        type BackendType = Autodiff<burn::backend::NdArray>;

        let seed = 42;
        let d = 10000;
        let n_chains = 6;
        let n_collect = 100;
        let n_discard = 100;

        let rng = SmallRng::seed_from_u64(seed);
        // Define 6 chains, all initialized to the same random standard-normal position.
        let initial_positions: Vec<Vec<f32>> =
            vec![rng.sample_iter(StandardNormal).take(d).collect(); n_chains];

        // Create the data-parallel HMC sampler.
        let mut sampler = HMC::<f32, BackendType, RosenbrockND>::new(
            RosenbrockND {},
            initial_positions,
            0.01, // step size
            50,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run HMC, discarding `n_discard` burn-in steps before collecting `n_collect`.
        let mut timer = Timer::new();
        let sample = sampler.run(n_collect, n_discard);
        timer.log(format!(
            "HMC sampler: generated {} observations.",
            sample.dims()[0..2].iter().product::<usize>()
        ));
        assert_eq!(sample.dims(), [n_chains, n_collect, d]);
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    #[cfg(feature = "wgpu")]
    fn test_progress_10000d_bench() {
        type BackendType = Autodiff<burn::backend::Wgpu>;

        let seed = 42;
        let d = 10000;
        let n_chains = 6;

        let rng = SmallRng::seed_from_u64(seed);
        // Define 6 chains, all initialized to the same random standard-normal position.
        let initial_positions: Vec<Vec<f32>> =
            vec![rng.sample_iter(StandardNormal).take(d).collect(); n_chains];
        let n_collect = 100;
        let n_discard = 100;

        // Create the data-parallel HMC sampler.
        let mut sampler = HMC::<f32, BackendType, RosenbrockND>::new(
            RosenbrockND {},
            initial_positions,
            0.01, // step size
            50,   // number of leapfrog steps per update
        )
        .set_seed(42);

        // Run HMC, discarding `n_discard` burn-in steps before collecting `n_collect`.
        let mut timer = Timer::new();
        let sample = sampler.run_progress(n_collect, n_discard).unwrap().0;
        timer.log(format!(
            "HMC sampler: generated {} observations.",
            sample.dims()[0..2].iter().product::<usize>()
        ));
        assert_eq!(sample.dims(), [n_chains, n_collect, d]);
    }
}