
//! No-U-Turn Sampler (NUTS).
//!
//! A parallel implementation of NUTS running independent Markov chains via Rayon.
//!
//! ## Example: custom 2D Rosenbrock target
//! ```rust
//! use mini_mcmc::core::init;
//! use mini_mcmc::distributions::GradientTarget;
//! use mini_mcmc::nuts::NUTS;
//! use burn::backend::{Autodiff, NdArray};
//! use burn::prelude::*;
//!
//! type B = Autodiff<NdArray>;
//!
//! #[derive(Clone)]
//! struct Rosenbrock2D { a: f32, b: f32 }
//!
//! impl GradientTarget<f32, B> for Rosenbrock2D {
//!     fn unnorm_logp(&self, position: Tensor<B, 1>) -> Tensor<B, 1> {
//!         let x = position.clone().slice(s![0..1]);
//!         let y = position.slice(s![1..2]);
//!         let term_1 = (-x.clone()).add_scalar(self.a).powi_scalar(2);
//!         let term_2 = y.sub(x.powi_scalar(2)).powi_scalar(2).mul_scalar(self.b);
//!         -(term_1 + term_2)
//!     }
//! }
//!
//! let target = Rosenbrock2D { a: 1.0, b: 100.0 };
//! let initial_positions = init::<f32>(4, 2);    // 4 chains in 2D
//! let mut sampler = NUTS::new(target, initial_positions, 0.9);
//! let (samples, stats) = sampler.run_progress(100, 20).unwrap();
//! ```
//!
//! ## Inspiration
//! Borrowed ideas from [mfouesneau/NUTS](https://github.com/mfouesneau/NUTS).

use std::error::Error;
use std::sync::mpsc;
use std::sync::mpsc::{Receiver, Sender};
use std::thread;
use std::time::{Duration, Instant};

use crate::distributions::GradientTarget;
use crate::stats::{collect_rhat, ChainStats, ChainTracker, RunStats};
use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
use burn::tensor::cast::ToElement;
use burn::tensor::Element;
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use ndarray::ArrayView3;
use ndarray_stats::QuantileExt;
use num_traits::{Float, FromPrimitive};
use rand::prelude::*;
use rand::Rng;
use rand_distr::uniform::SampleUniform;
use rand_distr::{Exp1, StandardNormal, StandardUniform};
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};

/// No-U-Turn Sampler (NUTS).
///
/// Encapsulates multiple independent Markov chains driven by the NUTS algorithm,
/// using dual-averaging step-size adaptation and dynamic trajectory lengths to
/// efficiently explore complex posterior geometries.
/// Chains are executed concurrently via Rayon, each evolving independently.
///
/// # Type Parameters
/// - `T`: Floating-point type for numerical calculations.
/// - `B`: Autodiff backend from the `burn` crate.
/// - `GTarget`: Target distribution type implementing the `GradientTarget` trait.
#[derive(Debug, Clone)]
pub struct NUTS<T, B, GTarget>
where
    T: Float + ElementConversion + Element + SampleUniform + FromPrimitive,
    B: AutodiffBackend,
    GTarget: GradientTarget<T, B> + Sync,
    StandardNormal: rand::distr::Distribution<T>,
    StandardUniform: rand_distr::Distribution<T>,
    rand_distr::Exp1: rand_distr::Distribution<T>,
{
    /// The vector of independent Markov chains.
    chains: Vec<NUTSChain<T, B, GTarget>>,
}

impl<T, B, GTarget> NUTS<T, B, GTarget>
where
    T: Float + ElementConversion + Element + SampleUniform + FromPrimitive + Send,
    B: AutodiffBackend + Send,
    GTarget: GradientTarget<T, B> + Sync + Clone + Send,
    StandardNormal: rand::distr::Distribution<T>,
    StandardUniform: rand_distr::Distribution<T>,
    rand_distr::Exp1: rand_distr::Distribution<T>,
{
    /// Creates a new NUTS sampler with the given target distribution and initial state for each chain.
    ///
    /// # Parameters
    /// - `target`: The target distribution implementing `GradientTarget`.
    /// - `initial_positions`: A vector of initial positions for each chain, shape `[n_chains, D]`.
    /// - `target_accept_p`: Desired average acceptance probability for the dual-averaging adaptation. Try values between 0.6 and 0.95.
    ///
    /// # Returns
    /// A newly initialized `NUTS` instance.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use burn::backend::{Autodiff, NdArray};
    /// # use mini_mcmc::nuts::NUTS;
    /// # use mini_mcmc::distributions::DiffableGaussian2D;
    /// type B = Autodiff<NdArray>;
    ///
    /// // Create a 2D Gaussian with mean [0,0] and identity covariance
    /// let gauss = DiffableGaussian2D::new([0.0_f64, 0.0], [[1.0, 0.0], [0.0, 1.0]]);
    ///
    /// // Initialize 3 chains in 2D at different starting points
    /// let init_positions = vec![
    ///     vec![-1.0, -1.0],
    ///     vec![ 0.0,  0.0],
    ///     vec![ 1.0,  1.0],
    /// ];
    ///
    /// // Build the sampler targeting 85% acceptance probability
    /// let sampler: NUTS<f64, B, _> = NUTS::new(gauss, init_positions, 0.85);
    /// ```
    pub fn new(target: GTarget, initial_positions: Vec<Vec<T>>, target_accept_p: T) -> Self {
        let chains = initial_positions
            .into_iter()
            .map(|pos| NUTSChain::new(target.clone(), pos, target_accept_p))
            .collect();
        Self { chains }
    }

    /// Runs all chains for a total of `n_collect + n_discard` steps and collects samples.
    ///
    /// First discards `n_discard` warm-up steps for each chain (during which adaptation occurs),
    /// then collects `n_collect` samples per chain.
    ///
    /// # Parameters
    /// - `n_collect`: Number of samples to collect after warm-up per chain.
    /// - `n_discard`: Number of warm-up (burn-in) steps to discard per chain.
    ///
    /// # Returns
    /// A 3D tensor of shape `[n_chains, n_collect, D]` containing the collected samples.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use burn::backend::{Autodiff, NdArray};
    /// # use burn::prelude::Tensor;
    /// # use mini_mcmc::nuts::NUTS;
    /// # use mini_mcmc::core::init;
    /// # use mini_mcmc::distributions::DiffableGaussian2D;
    /// type B = Autodiff<NdArray>;
    ///
    /// // As above, construct the sampler
    /// let gauss = DiffableGaussian2D::new([0.0_f32, 0.0], [[1.0,0.0],[0.0,1.0]]);
    /// let mut sampler = NUTS::new(gauss, init::<f32>(2, 2), 0.8);
    ///
    /// // Discard 50 warm-up steps, then collect 150 observations per chain
    /// let sample: Tensor<B, 3> = sampler.run(150, 50);
    ///
    /// // sample.dims() == [2 chains, 150 observations, 2 dimensions]
    /// assert_eq!(sample.dims(), [2, 150, 2]);
    /// ```
    pub fn run(&mut self, n_collect: usize, n_discard: usize) -> Tensor<B, 3> {
        let chain_samples: Vec<Tensor<B, 2>> = self
            .chains
            .par_iter_mut()
            .map(|chain| chain.run(n_collect, n_discard))
            .collect();
        Tensor::<B, 2>::stack(chain_samples, 0)
    }

    /// Run with live progress bars and collect summary stats.
    ///
    /// Spawns a background thread to render per-chain and global bars,
    /// then returns `(samples, RunStats)` when done.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn::backend::{Autodiff, NdArray};
    /// use mini_mcmc::distributions::Rosenbrock2D;
    /// use mini_mcmc::nuts::NUTS;
    /// use mini_mcmc::core::init;
    ///
    /// type B = Autodiff<NdArray>;
    ///
    /// let target = Rosenbrock2D { a: 1.0, b: 100.0 };
    /// let init   = init::<f64>(4, 2);    // 4 chains in 2D
    /// let mut sampler = NUTS::<f64, B, Rosenbrock2D<f64>>::new(target, init, 0.9);
    /// let (samples, stats) = sampler.run_progress(100, 20).unwrap();
    /// ```
    ///
    /// You can swap in any other [`GradientTarget`] just as easily.
    pub fn run_progress(
        &mut self,
        n_collect: usize,
        n_discard: usize,
    ) -> Result<(Tensor<B, 3>, RunStats), Box<dyn Error>> {
        let chains = &mut self.chains;

        let mut rxs: Vec<Receiver<ChainStats>> = vec![];
        let mut txs: Vec<Sender<ChainStats>> = vec![];
        (0..chains.len()).for_each(|_| {
            let (tx, rx) = mpsc::channel();
            rxs.push(rx);
            txs.push(tx);
        });

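        // One mpsc channel per chain: each worker sends periodic ChainStats
        // snapshots, and the dedicated thread below renders them as progress bars.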
        let progress_handle = thread::spawn(move || {
            let sleep_ms = Duration::from_millis(250);
            let timeout_ms = Duration::from_millis(0);
            let multi = MultiProgress::new();

            let pb_style = ProgressStyle::default_bar()
                .template("{prefix:8} {bar:40.cyan/blue} {pos}/{len} ({eta}) | {msg}")
                .unwrap()
                .progress_chars("=>-");
            let total: u64 = (n_collect + n_discard).try_into().unwrap();

            // Global progress bar
            let global_pb = multi.add(ProgressBar::new((rxs.len() as u64) * total));
            global_pb.set_style(pb_style.clone());
            global_pb.set_prefix("Global");

            let mut active: Vec<(usize, ProgressBar)> = (0..rxs.len().min(5))
                .map(|chain_idx| {
                    let pb = multi.add(ProgressBar::new(total));
                    pb.set_style(pb_style.clone());
                    pb.set_prefix(format!("Chain {chain_idx}"));
                    (chain_idx, pb)
                })
                .collect();
            let mut next_active = active.len();
            let mut n_finished = 0;
            let mut most_recent = vec![None; rxs.len()];
            let mut total_progress;

            loop {
                for (i, rx) in rxs.iter().enumerate() {
                    while let Ok(stats) = rx.recv_timeout(timeout_ms) {
                        most_recent[i] = Some(stats)
                    }
                }

                // Update chain progress bar messages
                // and compute average acceptance probability
                let mut to_replace = vec![false; active.len()];
                let mut avg_p_accept = 0.0;
                let mut n_available_stats = 0.0;
                for (vec_idx, (i, pb)) in active.iter().enumerate() {
                    if let Some(stats) = &most_recent[*i] {
                        pb.set_position(stats.n);
                        pb.set_message(format!("p(accept)≈{:.2}", stats.p_accept));
                        avg_p_accept += stats.p_accept;
                        n_available_stats += 1.0;

                        if stats.n == total {
                            to_replace[vec_idx] = true;
                            n_finished += 1;
                        }
                    }
                }
                avg_p_accept /= n_available_stats;

                // Update global progress bar
                total_progress = 0;
                for stats in most_recent.iter().flatten() {
                    total_progress += stats.n;
                }
                global_pb.set_position(total_progress);
                let valid: Vec<&ChainStats> = most_recent.iter().flatten().collect();
                if valid.len() >= 2 {
                    let rhats = collect_rhat(valid.as_slice());
                    let max = rhats.max_skipnan();
                    global_pb.set_message(format!(
                        "p(accept)≈{:.2} max(rhat)≈{:.2}",
                        avg_p_accept, max
                    ))
                }

                // At most 5 chain bars are shown at once; when a chain finishes,
                // its bar slot is handed to the next not-yet-displayed chain.
                let mut to_remove = vec![];
                for (i, replace) in to_replace.iter().enumerate() {
                    if *replace && next_active < most_recent.len() {
                        let pb = multi.add(ProgressBar::new(total));
                        pb.set_style(pb_style.clone());
                        pb.set_prefix(format!("Chain {next_active}"));
                        active[i] = (next_active, pb);
                        next_active += 1;
                    } else if *replace {
                        to_remove.push(i);
                    }
                }

                to_remove.sort();
                for i in to_remove.iter().rev() {
                    active.remove(*i);
                }

                if n_finished >= most_recent.len() {
                    break;
                }
                std::thread::sleep(sleep_ms);
            }
        });

        let chain_sample: Vec<Tensor<B, 2>> = thread::scope(|s| {
            let handles: Vec<thread::ScopedJoinHandle<Tensor<B, 2>>> = chains
                .iter_mut()
                .zip(txs)
                .map(|(chain, tx)| {
                    s.spawn(|| {
                        chain
                            .run_progress(n_collect, n_discard, tx)
                            .expect("Expected running chain to succeed.")
                    })
                })
                .collect();
            handles
                .into_iter()
                .map(|h| {
                    h.join()
                        .expect("Expected thread to succeed in generating observation.")
                })
                .collect()
        });
        let sample = Tensor::<B, 2>::stack(chain_sample, 0);

        if let Err(e) = progress_handle.join() {
            eprintln!("Progress bar thread panicked: {:?}", e);
        }

        // NOTE: computing run statistics currently assumes the backend stores
        // the sample data as f32.
        let sample_data = sample.to_data();
        let view =
            ArrayView3::<f32>::from_shape(sample.dims(), sample_data.as_slice().unwrap()).unwrap();
        let run_stats = RunStats::from(view);

        Ok((sample, run_stats))
    }

    /// Sets a new random seed for all chains to ensure reproducibility.
    ///
    /// # Parameters
    /// - `seed`: Base seed value. Each chain will derive its own seed for independence.
    ///
    /// # Returns
    /// `self` with the RNGs re-seeded.
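    ///
    /// # Example
    ///
    /// A minimal sketch reusing this crate's `DiffableGaussian2D` target:
    ///
    /// ```rust
    /// # use burn::backend::{Autodiff, NdArray};
    /// # use mini_mcmc::nuts::NUTS;
    /// # use mini_mcmc::core::init;
    /// # use mini_mcmc::distributions::DiffableGaussian2D;
    /// type B = Autodiff<NdArray>;
    ///
    /// let gauss = DiffableGaussian2D::new([0.0_f32, 0.0], [[1.0, 0.0], [0.0, 1.0]]);
    /// // Re-seeding makes repeated runs reproducible.
    /// let sampler: NUTS<f32, B, _> = NUTS::new(gauss, init::<f32>(2, 2), 0.8).set_seed(42);
    /// ```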
    pub fn set_seed(mut self, seed: u64) -> Self {
        for (i, chain) in self.chains.iter_mut().enumerate() {
            let chain_seed = seed + i as u64 + 1;
            chain.rng = SmallRng::seed_from_u64(chain_seed);
        }
        self
    }
}

/// Single-chain state and adaptation for NUTS.
///
/// Manages the dynamic trajectory building, dual-averaging adaptation of step size,
/// and current position for one chain.
#[derive(Debug, Clone)]
pub struct NUTSChain<T, B, GTarget>
where
    B: AutodiffBackend,
{
    /// Target distribution providing gradients and log-probabilities.
    target: GTarget,

    /// Current position in parameter space.
    pub position: Tensor<B, 1>,

    /// Desired average acceptance probability.
    target_accept_p: T,

    /// Current step size (epsilon).
    epsilon: T,

    // Internal variables for dual-averaging step-size adaptation
    // (notation follows Hoffman & Gelman, 2014).
    /// Iteration counter.
    m: usize,
    n_collect: usize,
    n_discard: usize,
    /// Shrinkage strength.
    gamma: T,
    /// Stabilization offset for early iterations.
    t_0: usize,
    /// Decay exponent for the averaging weights.
    kappa: T,
    /// Value the log step size is shrunk towards.
    mu: T,
    /// Averaged step size, used after warm-up.
    epsilon_bar: T,
    /// Running average of the acceptance-probability error.
    h_bar: T,

    rng: SmallRng,
    phantom_data: std::marker::PhantomData<T>,
}

impl<T, B, GTarget> NUTSChain<T, B, GTarget>
where
    T: Float + ElementConversion + Element + SampleUniform + FromPrimitive,
    B: AutodiffBackend,
    GTarget: GradientTarget<T, B> + std::marker::Sync,
    StandardNormal: rand::distr::Distribution<T>,
    StandardUniform: rand_distr::Distribution<T>,
    rand_distr::Exp1: rand_distr::Distribution<T>,
{
    /// Constructs a new NUTSChain for a single chain with the given initial position.
    ///
    /// # Parameters
    /// - `target`: The target distribution implementing `GradientTarget`.
    /// - `initial_position`: Initial position vector of length `D`.
    /// - `target_accept_p`: Desired average acceptance probability for adaptation.
    ///
    /// # Returns
    /// An initialized `NUTSChain`.
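    ///
    /// # Example
    ///
    /// A minimal sketch reusing this crate's `DiffableGaussian2D` target:
    ///
    /// ```rust
    /// # use burn::backend::{Autodiff, NdArray};
    /// # use mini_mcmc::nuts::NUTSChain;
    /// # use mini_mcmc::distributions::DiffableGaussian2D;
    /// type B = Autodiff<NdArray>;
    ///
    /// let gauss = DiffableGaussian2D::new([0.0_f64, 0.0], [[1.0, 0.0], [0.0, 1.0]]);
    /// let chain: NUTSChain<f64, B, _> = NUTSChain::new(gauss, vec![0.0, 0.0], 0.8);
    /// ```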
    pub fn new(target: GTarget, initial_position: Vec<T>, target_accept_p: T) -> Self {
        let dim = initial_position.len();
        let td: TensorData = TensorData::new(initial_position, [dim]);
        let position = Tensor::<B, 1>::from_data(td, &B::Device::default());
        let rng = SmallRng::from_os_rng();
        // Sentinel value: -1 marks the step size as uninitialized; a reasonable
        // epsilon is found on the first call to `init_chain`.
        let epsilon = -T::one();

        Self {
            target,
            position,
            target_accept_p,
            epsilon,
            m: 0,
            n_collect: 0,
            n_discard: 0,
            gamma: T::from(0.05).unwrap(),
            t_0: 10,
            kappa: T::from(0.75).unwrap(),
            mu: T::from(10.0).unwrap().ln(),
            epsilon_bar: T::one(),
            h_bar: T::zero(),
            rng,
            phantom_data: std::marker::PhantomData,
        }
    }

    /// Sets a new random seed for this chain to ensure reproducibility.
    ///
    /// # Parameters
    /// - `seed`: Seed value for the chain's RNG.
    ///
    /// # Returns
    /// `self` with the RNG re-seeded.
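    ///
    /// # Example
    ///
    /// A minimal sketch mirroring this module's tests:
    ///
    /// ```rust
    /// # use burn::backend::{Autodiff, NdArray};
    /// # use mini_mcmc::nuts::NUTSChain;
    /// # use mini_mcmc::distributions::DiffableGaussian2D;
    /// type B = Autodiff<NdArray>;
    ///
    /// let gauss = DiffableGaussian2D::new([0.0_f64, 1.0], [[4.0, 2.0], [2.0, 3.0]]);
    /// let chain: NUTSChain<f64, B, _> = NUTSChain::new(gauss, vec![0.0, 1.0], 0.8).set_seed(42);
    /// ```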
    pub fn set_seed(mut self, seed: u64) -> Self {
        self.rng = SmallRng::seed_from_u64(seed);
        self
    }

    /// Runs the chain for `n_collect + n_discard` steps, adapting during burn-in and
    /// returning collected samples.
    ///
    /// # Parameters
    /// - `n_collect`: Number of samples to collect after adaptation.
    /// - `n_discard`: Number of burn-in steps for adaptation.
    ///
    /// # Returns
    /// A 2D tensor of shape `[n_collect, D]` containing collected samples.
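    ///
    /// # Example
    ///
    /// A minimal sketch mirroring this module's tests:
    ///
    /// ```rust
    /// # use burn::backend::{Autodiff, NdArray};
    /// # use burn::prelude::Tensor;
    /// # use mini_mcmc::nuts::NUTSChain;
    /// # use mini_mcmc::distributions::DiffableGaussian2D;
    /// type B = Autodiff<NdArray>;
    ///
    /// let gauss = DiffableGaussian2D::new([0.0_f64, 1.0], [[4.0, 2.0], [2.0, 3.0]]);
    /// let mut chain: NUTSChain<f64, B, _> = NUTSChain::new(gauss, vec![0.0, 1.0], 0.8).set_seed(42);
    ///
    /// // Discard 10 warm-up steps, then collect 20 observations.
    /// let sample: Tensor<B, 2> = chain.run(20, 10);
    /// assert_eq!(sample.dims(), [20, 2]);
    /// ```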
    pub fn run(&mut self, n_collect: usize, n_discard: usize) -> Tensor<B, 2> {
        let (dim, mut sample) = self.init_chain(n_collect, n_discard);

        // Row 0 already holds the initial position (written by `init_chain`),
        // so stepping starts at m = 1.
        for m in 1..(n_collect + n_discard) {
            self.step();

            if m >= n_discard {
                sample = sample.slice_assign(
                    [m - n_discard..m - n_discard + 1, 0..dim],
                    self.position.clone().unsqueeze(),
                );
            }
        }
        sample
    }

    fn run_progress(
        &mut self,
        n_collect: usize,
        n_discard: usize,
        tx: Sender<ChainStats>,
    ) -> Result<Tensor<B, 2>, Box<dyn Error>> {
        let (dim, mut sample) = self.init_chain(n_collect, n_discard);
        let pos_0: Vec<f32> = self
            .position
            .to_data()
            .iter()
            .map(|x: T| ToElement::to_f32(&x))
            .collect();
        let mut tracker = ChainTracker::new(dim, &pos_0);
        let mut last = Instant::now();
        let freq = Duration::from_secs(1);
        let total = n_discard + n_collect;

        for i in 0..total {
            self.step();
            let pos_i: Vec<f32> = self
                .position
                .to_data()
                .iter()
                .map(|x: T| ToElement::to_f32(&x))
                .collect();
            tracker.step(&pos_i).map_err(|e| {
                let msg = format!(
                    "Chain statistics tracker caused error: {}.\nAborting generation of further observations.",
                    e
                );
                eprintln!("{}", msg);
                msg
            })?;

            // Send stats at most once per second, plus a final update.
            let now = Instant::now();
            if (now >= last + freq) || (i == total - 1) {
                if let Err(e) = tx.send(tracker.stats()) {
                    eprintln!("Sending chain statistics failed: {e}");
                }
                last = now;
            }

            if i >= n_discard {
                sample = sample.slice_assign(
                    [i - n_discard..i - n_discard + 1, 0..dim],
                    self.position.clone().unsqueeze(),
                );
            }
        }

        // TODO: Somehow save state of the chains and enable continuing runs
        Ok(sample)
    }

    fn init_chain(&mut self, n_collect: usize, n_discard: usize) -> (usize, Tensor<B, 2>) {
        let dim = self.position.dims()[0];
        self.n_collect = n_collect;
        self.n_discard = n_discard;

        // Pre-fill row 0 with the current position.
        let mut sample = Tensor::<B, 2>::empty([n_collect, dim], &B::Device::default());
        sample = sample.slice_assign([0..1, 0..dim], self.position.clone().unsqueeze());
        let mom_0_data: Vec<T> = (&mut self.rng)
            .sample_iter(StandardNormal)
            .take(dim)
            .collect();
        let mom_0 = Tensor::<B, 1>::from_data(mom_0_data.as_slice(), &B::Device::default());
        // If epsilon still holds its -1 sentinel, pick a reasonable initial step size.
        if T::abs(self.epsilon + T::one()) <= T::epsilon() {
            self.epsilon = find_reasonable_epsilon(self.position.clone(), mom_0, &self.target);
        }
        self.mu = T::ln(T::from(10).unwrap() * self.epsilon);
        (dim, sample)
    }

    /// Performs one NUTS update step, including tree expansion and adaptation updates.
    ///
    /// This method updates `self.position` and adaptation statistics in-place.
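    ///
    /// One step consists of (1) resampling the momentum, (2) drawing a slice
    /// variable, (3) repeatedly doubling the trajectory via `build_tree` until
    /// the U-turn criterion or a divergence stops it, and (4) a dual-averaging
    /// update of the step size while within the warm-up phase.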
    pub fn step(&mut self) {
        self.m += 1;

        // Resample the momentum from a standard normal distribution.
        let dim = self.position.dims()[0];
        let mom_0 = (&mut self.rng)
            .sample_iter(StandardNormal)
            .take(dim)
            .collect::<Vec<T>>();
        let mom_0 = Tensor::<B, 1>::from_data(mom_0.as_slice(), &B::Device::default());
        let (ulogp, grad) = self.target.unnorm_logp_and_grad(self.position.clone());
        let joint = ulogp.clone() - (mom_0.clone() * mom_0.clone()).sum() * 0.5;
        let joint =
            T::from_f64(joint.into_scalar().to_f64()).expect("successful conversion from f64 to T");
        // Draw the slice variable in log space: log(u) = joint - Exp(1)
        // is equivalent to u ~ Uniform(0, exp(joint)).
        let exp1_obs = self.rng.sample(Exp1);
        let logu = joint - exp1_obs;

        let mut position_minus = self.position.clone();
        let mut position_plus = self.position.clone();
        let mut mom_minus = mom_0.clone();
        let mut mom_plus = mom_0.clone();
        let mut grad_minus = grad.clone();
        let mut grad_plus = grad.clone();
        let mut j = 0;
        let mut n = 1;
        let mut s = true; // Continuation flag: trajectory doubling continues while `s` is true.
        let mut alpha: T = T::zero();
        let mut n_alpha: usize = 0;

        while s {
            // Choose a direction uniformly at random: v = 1 (forward) or v = -1 (backward).
            let u_run_1: T = self.rng.random::<T>();
            let v = (2 * (u_run_1 < T::from(0.5).unwrap()) as i8) - 1;

            let (position_prime, n_prime, s_prime) = {
                if v == -1 {
                    let (
                        position_minus_2,
                        mom_minus_2,
                        grad_minus_2,
                        _,
                        _,
                        _,
                        position_prime_2,
                        _,
                        _,
                        n_prime_2,
                        s_prime_2,
                        alpha_2,
                        n_alpha_2,
                    ) = build_tree(
                        position_minus.clone(),
                        mom_minus.clone(),
                        grad_minus.clone(),
                        logu,
                        v,
                        j,
                        self.epsilon,
                        &self.target,
                        joint,
                        &mut self.rng,
                    );

                    position_minus = position_minus_2;
                    mom_minus = mom_minus_2;
                    grad_minus = grad_minus_2;
                    alpha = alpha_2;
                    n_alpha = n_alpha_2;

                    (position_prime_2, n_prime_2, s_prime_2)
                } else {
                    let (
                        _,
                        _,
                        _,
                        position_plus_2,
                        mom_plus_2,
                        grad_plus_2,
                        position_prime_2,
                        _,
                        _,
                        n_prime_2,
                        s_prime_2,
                        alpha_2,
                        n_alpha_2,
                    ) = build_tree(
                        position_plus.clone(),
                        mom_plus.clone(),
                        grad_plus.clone(),
                        logu,
                        v,
                        j,
                        self.epsilon,
                        &self.target,
                        joint,
                        &mut self.rng,
                    );

                    position_plus = position_plus_2;
                    mom_plus = mom_plus_2;
                    grad_plus = grad_plus_2;
                    alpha = alpha_2;
                    n_alpha = n_alpha_2;

                    (position_prime_2, n_prime_2, s_prime_2)
                }
            };

            // Accept the proposed position with probability min(1, n_prime / n)
            // (Hoffman & Gelman, 2014, Algorithm 6).
            let tmp = T::one().min(
                T::from(n_prime).expect("successful conversion of n_prime from usize to T")
                    / T::from(n).expect("successful conversion of n from usize to T"),
            );
            let u_run_2 = self.rng.random::<T>();
            if s_prime && (u_run_2 < tmp) {
                self.position = position_prime;
            }
            n += n_prime;

            s = s_prime
                && stop_criterion(
                    position_minus.clone(),
                    position_plus.clone(),
                    mom_minus.clone(),
                    mom_plus.clone(),
                );
            j += 1;
        }

        // Dual-averaging step-size adaptation (Hoffman & Gelman, 2014, Algorithm 6):
        // adapt epsilon towards the target acceptance probability during warm-up,
        // then freeze it at the averaged value.
        let mut eta =
            T::one() / T::from(self.m + self.t_0).expect("successful conversion of m + t_0 to T");
        self.h_bar = (T::one() - eta) * self.h_bar
            + eta
                * (self.target_accept_p
                    - alpha / T::from(n_alpha).expect("successful conversion of n_alpha to T"));
        if self.m <= self.n_discard {
            let _m = T::from(self.m).expect("successful conversion of m to T");
            self.epsilon = T::exp(self.mu - T::sqrt(_m) / self.gamma * self.h_bar);
            eta = _m.powf(-self.kappa);
            self.epsilon_bar =
                T::exp((T::one() - eta) * T::ln(self.epsilon_bar) + eta * T::ln(self.epsilon));
        } else {
            self.epsilon = self.epsilon_bar;
        }
    }
}

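/// Heuristically picks an initial step size, following Algorithm 4 of
/// Hoffman & Gelman (2014): starting from `epsilon = 1`, the step size is
/// repeatedly doubled or halved until the acceptance probability of a single
/// leapfrog step crosses 0.5.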
#[allow(dead_code)]
fn find_reasonable_epsilon<B, T, GTarget>(
    position: Tensor<B, 1>,
    mom: Tensor<B, 1>,
    gradient_target: &GTarget,
) -> T
where
    T: Float + Element,
    B: AutodiffBackend,
    GTarget: GradientTarget<T, B> + Sync,
{
    let mut epsilon = T::one();
    let half = T::from(0.5).unwrap();
    let (ulogp, grad) = gradient_target.unnorm_logp_and_grad(position.clone());
    let (_, mut mom_prime, mut grad_prime, mut ulogp_prime) = leapfrog(
        position.clone(),
        mom.clone(),
        grad.clone(),
        epsilon,
        gradient_target,
    );
    let mut k = T::one();

    // Keep halving the step size until the leapfrog step yields finite
    // log-probability and gradient values.
    while !all_real::<B, T>(ulogp_prime.clone()) || !all_real::<B, T>(grad_prime.clone()) {
        k = k * half;
        (_, mom_prime, grad_prime, ulogp_prime) = leapfrog(
            position.clone(),
            mom.clone(),
            grad.clone(),
            epsilon * k,
            gradient_target,
        );
    }

    epsilon = half * k * epsilon;
    let log_accept_prob = ulogp_prime
        - ulogp.clone()
        - ((mom_prime.clone() * mom_prime).sum() - (mom.clone() * mom.clone()).sum()) * half;
    let mut log_accept_prob = T::from(log_accept_prob.into_scalar().to_f64()).unwrap();

    // Double or halve epsilon until the acceptance probability crosses 0.5.
    let a = if log_accept_prob > half.ln() {
        T::one()
    } else {
        -T::one()
    };

    while a * log_accept_prob > -a * T::from(2.0).unwrap().ln() {
        epsilon = epsilon * T::from(2.0).unwrap().powf(a);
        (_, mom_prime, _, ulogp_prime) = leapfrog(
            position.clone(),
            mom.clone(),
            grad.clone(),
            epsilon,
            gradient_target,
        );
        log_accept_prob = T::from(
            (ulogp_prime
                - ulogp.clone()
                - ((mom_prime.clone() * mom_prime).sum() - (mom.clone() * mom.clone()).sum())
                    * 0.5)
                .into_scalar()
                .to_f64(),
        )
        .unwrap();
    }

    epsilon
}

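/// Recursively builds a balanced trajectory subtree of depth `j` in direction
/// `v` (Algorithm 6 of Hoffman & Gelman, 2014). The base case (`j == 0`) takes
/// a single leapfrog step; deeper trees combine two subtrees, propagating the
/// leftmost/rightmost states, a proposal sampled from valid leaves, the leaf
/// count `n_prime`, the continuation flag `s_prime`, and the acceptance
/// statistics `alpha_prime` / `n_alpha_prime`.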
#[allow(clippy::too_many_arguments, clippy::type_complexity)]
fn build_tree<B, T, GTarget>(
    position: Tensor<B, 1>,
    mom: Tensor<B, 1>,
    grad: Tensor<B, 1>,
    logu: T,
    v: i8,
    j: usize,
    epsilon: T,
    gradient_target: &GTarget,
    joint_0: T,
    rng: &mut SmallRng,
) -> (
    Tensor<B, 1>,
    Tensor<B, 1>,
    Tensor<B, 1>,
    Tensor<B, 1>,
    Tensor<B, 1>,
    Tensor<B, 1>,
    Tensor<B, 1>,
    Tensor<B, 1>,
    Tensor<B, 1>,
    usize,
    bool,
    T,
    usize,
)
where
    T: Float + Element,
    B: AutodiffBackend,
    GTarget: GradientTarget<T, B> + Sync,
{
    if j == 0 {
        // Base case: take a single leapfrog step in direction v.
        let (position_prime, mom_prime, grad_prime, logp_prime) = leapfrog(
            position.clone(),
            mom.clone(),
            grad.clone(),
            T::from(v as i32).unwrap() * epsilon,
            gradient_target,
        );
        let joint = logp_prime.clone() - (mom_prime.clone() * mom_prime.clone()).sum() * 0.5;
        let joint = T::from(joint.into_scalar().to_f64())
            .expect("type conversion from joint tensor to scalar type T to succeed");
        let n_prime = (logu < joint) as usize;
        // Divergence check with Delta_max = 1000, as in the NUTS paper.
        let s_prime = (logu - T::from(1000.0).unwrap()) < joint;
        let position_minus = position_prime.clone();
        let position_plus = position_prime.clone();
        let mom_minus = mom_prime.clone();
        let mom_plus = mom_prime.clone();
        let grad_minus = grad_prime.clone();
        let grad_plus = grad_prime.clone();
        let alpha_prime = T::min(T::one(), (joint - joint_0).exp());
        let n_alpha_prime = 1_usize;
        (
            position_minus,
            mom_minus,
            grad_minus,
            position_plus,
            mom_plus,
            grad_plus,
            position_prime,
            grad_prime,
            logp_prime,
            n_prime,
            s_prime,
            alpha_prime,
            n_alpha_prime,
        )
    } else {
        // Recursion: implicitly build the left and right subtrees.
        let (
            mut position_minus,
            mut mom_minus,
            mut grad_minus,
            mut position_plus,
            mut mom_plus,
            mut grad_plus,
            mut position_prime,
            mut grad_prime,
            mut logp_prime,
            mut n_prime,
            mut s_prime,
            mut alpha_prime,
            mut n_alpha_prime,
        ) = build_tree(
            position,
            mom,
            grad,
            logu,
            v,
            j - 1,
            epsilon,
            gradient_target,
            joint_0,
            rng,
        );
        if s_prime {
            let (
                position_minus_2,
                mom_minus_2,
                grad_minus_2,
                position_plus_2,
                mom_plus_2,
                grad_plus_2,
                position_prime_2,
                grad_prime_2,
                logp_prime_2,
                n_prime_2,
                s_prime_2,
                alpha_prime_2,
                n_alpha_prime_2,
            ) = if v == -1 {
                build_tree(
                    position_minus.clone(),
                    mom_minus.clone(),
                    grad_minus.clone(),
                    logu,
                    v,
                    j - 1,
                    epsilon,
                    gradient_target,
                    joint_0,
                    rng,
                )
            } else {
                build_tree(
                    position_plus.clone(),
                    mom_plus.clone(),
                    grad_plus.clone(),
                    logu,
                    v,
                    j - 1,
                    epsilon,
                    gradient_target,
                    joint_0,
                    rng,
                )
            };
            if v == -1 {
                position_minus = position_minus_2;
                mom_minus = mom_minus_2;
                grad_minus = grad_minus_2;
            } else {
                position_plus = position_plus_2;
                mom_plus = mom_plus_2;
                grad_plus = grad_plus_2;
            }

            // Progressive sampling: adopt the second subtree's proposal with
            // probability n_prime_2 / (n_prime + n_prime_2).
            let u_build_tree: f64 = (*rng).random::<f64>();
            if u_build_tree < (n_prime_2 as f64 / (n_prime + n_prime_2).max(1) as f64) {
                position_prime = position_prime_2;
                grad_prime = grad_prime_2;
                logp_prime = logp_prime_2;
            }

            n_prime += n_prime_2;

            s_prime = s_prime
                && s_prime_2
                && stop_criterion(
                    position_minus.clone(),
                    position_plus.clone(),
                    mom_minus.clone(),
                    mom_plus.clone(),
                );
            alpha_prime = alpha_prime + alpha_prime_2;
            n_alpha_prime += n_alpha_prime_2;
        }
        (
            position_minus,
            mom_minus,
            grad_minus,
            position_plus,
            mom_plus,
            grad_plus,
            position_prime,
            grad_prime,
            logp_prime,
            n_prime,
            s_prime,
            alpha_prime,
            n_alpha_prime,
        )
    }
}

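/// Returns `true` iff every element of `x` is finite (neither infinite nor NaN).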
fn all_real<B, T>(x: Tensor<B, 1>) -> bool
where
    T: Float + Element,
    B: AutodiffBackend,
{
    x.clone()
        .equal_elem(T::infinity())
        .bool_or(x.clone().equal_elem(T::neg_infinity()))
        .bool_or(x.is_nan())
        .any()
        .bool_not()
        .into_scalar()
        .to_bool()
}

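/// U-turn check from the NUTS paper: the trajectory may keep expanding while
/// both `(position_plus - position_minus) . mom_minus >= 0` and
/// `(position_plus - position_minus) . mom_plus >= 0`, i.e. while both ends
/// are still moving away from each other.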
fn stop_criterion<B>(
    position_minus: Tensor<B, 1>,
    position_plus: Tensor<B, 1>,
    mom_minus: Tensor<B, 1>,
    mom_plus: Tensor<B, 1>,
) -> bool
where
    B: AutodiffBackend,
{
    let diff = position_plus - position_minus;
    let dot_minus = (diff.clone() * mom_minus).sum();
    let dot_plus = (diff * mom_plus).sum();
    dot_minus.greater_equal_elem(0).into_scalar().to_bool()
        && dot_plus.greater_equal_elem(0).into_scalar().to_bool()
}

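/// Performs one leapfrog integration step of size `epsilon`: a half-step
/// momentum update, a full-step position update, and a second half-step
/// momentum update using the gradient at the new position. Returns
/// `(position', momentum', gradient', unnorm_logp')`.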
fn leapfrog<B, T, GTarget>(
    position: Tensor<B, 1>,
    mom: Tensor<B, 1>,
    grad: Tensor<B, 1>,
    epsilon: T,
    gradient_target: &GTarget,
) -> (Tensor<B, 1>, Tensor<B, 1>, Tensor<B, 1>, Tensor<B, 1>)
where
    T: Float + ElementConversion,
    B: AutodiffBackend,
    GTarget: GradientTarget<T, B>,
{
    // Half-step momentum update, full-step position update,
    // then a second half-step momentum update with the new gradient.
    let mom_prime = mom + grad * epsilon * 0.5;
    let position_prime = position + mom_prime.clone() * epsilon;
    let (ulogp_prime, grad_prime) = gradient_target.unnorm_logp_and_grad(position_prime.clone());
    let mom_prime = mom_prime + grad_prime.clone() * epsilon * 0.5;
    (position_prime, mom_prime, grad_prime, ulogp_prime)
}

#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use crate::{
        core::init,
        dev_tools::Timer,
        distributions::{DiffableGaussian2D, Rosenbrock2D},
        stats::split_rhat_mean_ess,
    };

    #[cfg(feature = "csv")]
    use crate::io::csv::save_csv_tensor;

    use super::*;
    use burn::{
        backend::{Autodiff, NdArray},
        tensor::{Tensor, Tolerance},
    };
    use ndarray::ArrayView3;
    use ndarray_stats::QuantileExt;
    use num_traits::Float;

    // Use the CPU backend (NdArray) wrapped in Autodiff.
    type BackendType = Autodiff<NdArray>;

    /// Standard normal test target; shadows `rand_distr::StandardNormal`
    /// within this test module.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub struct StandardNormal;

    impl<T, B> GradientTarget<T, B> for StandardNormal
    where
        T: Float + Debug + ElementConversion + Element,
        B: AutodiffBackend,
    {
        fn unnorm_logp(&self, positions: Tensor<B, 1>) -> Tensor<B, 1> {
            let sq = positions.clone().powi_scalar(2);
            let half = T::from(0.5).unwrap();
            -(sq.mul_scalar(half)).sum()
        }
    }

    fn assert_tensor_approx_eq<T: Backend, F: Float + burn::tensor::Element>(
        actual: Tensor<T, 1>,
        expected: &[f64],
        tol: Tolerance<F>,
    ) {
        let a = actual.clone().to_data();
        let e = Tensor::<T, 1>::from(expected).to_data();
        a.assert_approx_eq(&e, tol);
    }

    #[test]
    fn test_find_reasonable_epsilon() {
        let position = Tensor::<BackendType, 1>::from([0.0, 1.0]);
        let mom = Tensor::<BackendType, 1>::from([1.0, 0.0]);
        let epsilon = find_reasonable_epsilon::<_, f64, _>(position, mom, &StandardNormal);
        assert_eq!(epsilon, 2.0);
    }

    #[test]
    fn test_build_tree() {
        let gradient_target = DiffableGaussian2D::new([0.0_f64, 1.0], [[4.0, 2.0], [2.0, 3.0]]);
        let position = Tensor::<BackendType, 1>::from([0.0, 1.0]);
        let mom = Tensor::<BackendType, 1>::from([2.0, 3.0]);
        let grad = Tensor::<BackendType, 1>::from([4.0, 5.0]);
        let logu = -2.0;
        let v: i8 = -1;
        let j: usize = 3;
        let epsilon: f64 = 0.01;
        let joint_0 = 0.1_f64;
        let mut rng = SmallRng::seed_from_u64(0);
        let (
            position_minus,
            mom_minus,
            grad_minus,
            position_plus,
            mom_plus,
            grad_plus,
            position_prime,
            grad_prime,
            logp_prime,
            n_prime,
            s_prime,
            alpha_prime,
            n_alpha_prime,
        ) = build_tree::<BackendType, f64, _>(
            position,
            mom,
            grad,
            logu,
            v,
            j,
            epsilon,
            &gradient_target,
            joint_0,
            &mut rng,
        );
        let tol = Tolerance::<f64>::default()
            .set_relative(1e-5)
            .set_absolute(1e-6);

        assert_tensor_approx_eq(position_minus, &[-0.1584001, 0.76208336], tol);
        assert_tensor_approx_eq(mom_minus, &[1.980_003_6, 2.971_825_3], tol);
        assert_tensor_approx_eq(grad_minus, &[-7.912_36e-5, 7.935_829_5e-2], tol);

        assert_tensor_approx_eq(position_plus, &[-0.0198, 0.97025], tol);
        assert_tensor_approx_eq(mom_plus, &[1.98, 2.974_950_3], tol);
        assert_tensor_approx_eq(grad_plus, &[-1.250e-05, 9.925e-03], tol);

        assert_tensor_approx_eq(position_prime, &[-0.0198, 0.97025], tol);
        assert_tensor_approx_eq(grad_prime, &[-1.250e-05, 9.925e-03], tol);

        assert_eq!(n_prime, 0);
        assert!(s_prime);
        assert_eq!(n_alpha_prime, 8);

        let logp_exp = -2.877_745_4_f64;
        let alpha_exp = 0.000_686_661_7_f64;
        assert!(
            (logp_prime.into_scalar().to_f64() - logp_exp).abs() < 1e-6,
            "logp mismatch"
        );
        assert!((alpha_prime - alpha_exp).abs() < 1e-8, "alpha mismatch");
    }

    #[test]
    fn test_chain_1() {
        let target = DiffableGaussian2D::new([0.0_f64, 1.0], [[4.0, 2.0], [2.0, 3.0]]);
        let initial_positions = vec![0.0_f64, 1.0];
        let n_discard = 0;
        let n_collect = 1;
        let mut sampler = NUTSChain::new(target, initial_positions, 0.8).set_seed(42);
        let sample: Tensor<BackendType, 2> = sampler.run(n_collect, n_discard);
        assert_eq!(sample.dims(), [n_collect, 2]);
        let tol = Tolerance::<f64>::default()
            .set_relative(1e-5)
            .set_absolute(1e-6);
        assert_tensor_approx_eq(sample.flatten(0, 1), &[0.0, 1.0], tol);
    }

    #[test]
    fn test_chain_2() {
        let target = DiffableGaussian2D::new([0.0_f64, 1.0], [[4.0, 2.0], [2.0, 3.0]]);
        let initial_positions = vec![0.0_f64, 1.0];
        let n_discard = 3;
        let n_collect = 3;
        let mut sampler = NUTSChain::new(target, initial_positions, 0.8).set_seed(42);
        let sample: Tensor<BackendType, 2> = sampler.run(n_collect, n_discard);
        assert_eq!(sample.dims(), [n_collect, 2]);
        let tol = Tolerance::<f64>::default()
            .set_relative(1e-5)
            .set_absolute(1e-6);
        assert_tensor_approx_eq(
            sample.flatten(0, 1),
            &[
                -1.168318748474121,
                -0.4077277183532715,
                -1.8463939428329468,
                0.19176559150218964,
                -1.0662782192230225,
                -0.3948383331298828,
            ],
            tol,
        );
    }

    #[test]
    fn test_chain_3() {
        let target = DiffableGaussian2D::new([1.0_f64, 2.0], [[1.0, 2.0], [2.0, 5.0]]);
        let initial_positions = vec![-2.0_f64, 1.0];
        let n_discard = 5;
        let n_collect = 5;
        let mut sampler = NUTSChain::new(target, initial_positions, 0.8).set_seed(42);
        let sample: Tensor<BackendType, 2> = sampler.run(n_collect, n_discard);
        assert_eq!(sample.dims(), [n_collect, 2]);
        let tol = Tolerance::<f64>::default()
            .set_relative(1e-5)
            .set_absolute(1e-6);
        assert_tensor_approx_eq(
            sample.flatten(0, 1),
            &[
                2.653707265853882,
                5.560618877410889,
                2.9760334491729736,
                6.325948715209961,
                2.187873125076294,
                5.611990928649902,
                2.1512224674224854,
                5.416507720947266,
                2.4165120124816895,
                3.9120564460754395,
            ],
            tol,
        );
    }

    #[test]
    fn test_run_1() {
        let target = DiffableGaussian2D::new([1.0_f64, 2.0], [[1.0, 2.0], [2.0, 5.0]]);
        let initial_positions = vec![vec![-2_f64, 1.0]];
        let n_discard = 5;
        let n_collect = 5;
        let mut sampler = NUTS::new(target, initial_positions, 0.8).set_seed(41);
        let sample: Tensor<BackendType, 3> = sampler.run(n_collect, n_discard);
        assert_eq!(sample.dims(), [1, n_collect, 2]);
        let tol = Tolerance::<f64>::default()
            .set_relative(1e-5)
            .set_absolute(1e-6);
        assert_tensor_approx_eq(
            sample.flatten(0, 2),
            &[
                2.653707265853882,
                5.560618877410889,
                2.9760334491729736,
                6.325948715209961,
                2.187873125076294,
                5.611990928649902,
                2.1512224674224854,
                5.416507720947266,
                2.4165120124816895,
                3.9120564460754395,
            ],
            tol,
        );
    }

    #[test]
    fn test_progress_1() {
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // Define 6 chains in 2D.
        let initial_positions = init::<f32>(6, 2);
        let n_collect = 10;
        let n_discard = 10;

        let mut sampler =
            NUTS::<_, BackendType, _>::new(target, initial_positions, 0.95).set_seed(42);
        let (sample, stats) = sampler.run_progress(n_collect, n_discard).unwrap();
        println!(
            "NUTS sampler: generated {} observations.",
            sample.dims()[0..2].iter().product::<usize>()
        );
        assert_eq!(sample.dims(), [6, n_collect, 2]);

        println!("Statistics: {stats}");

        #[cfg(feature = "csv")]
        save_csv_tensor(sample, "/tmp/nuts-sample.csv").expect("saving data should succeed")
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_bench_noprogress_1() {
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // Define 6 chains in 2D.
        let initial_positions = init::<f32>(6, 2);
        let n_collect = 5000;
        let n_discard = 500;

        let mut sampler = NUTS::new(target, initial_positions, 0.95).set_seed(42);
        let mut timer = Timer::new();
        let sample: Tensor<BackendType, 3> = sampler.run(n_collect, n_discard);
        timer.log(format!(
            "NUTS sampler: generated {} observations.",
            sample.dims()[0..2].iter().product::<usize>()
        ));
        assert_eq!(sample.dims(), [6, 5000, 2]);

        let data = sample.to_data();
        let array = ArrayView3::from_shape(sample.dims(), data.as_slice().unwrap()).unwrap();
        let (split_rhat, ess) = split_rhat_mean_ess(array);
        println!("AVG Split Rhat: {}", split_rhat.mean().unwrap());
        println!("AVG ESS: {}", ess.mean().unwrap());

        #[cfg(feature = "csv")]
        save_csv_tensor(sample, "/tmp/nuts-sample.csv").expect("saving data should succeed")
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_bench_noprogress_2() {
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        // Define 6 chains in 2D.
        let initial_positions = init::<f32>(6, 2);
        let n_collect = 1000;
        let n_discard = 1000;

        let mut sampler = NUTS::new(target, initial_positions, 0.95).set_seed(42);
        let mut timer = Timer::new();
        let sample: Tensor<BackendType, 3> = sampler.run(n_collect, n_discard);
        timer.log(format!(
            "NUTS sampler: generated {} observations.",
            sample.dims()[0..2].iter().product::<usize>()
        ));
        assert_eq!(sample.dims(), [6, 1000, 2]);

        let data = sample.to_data();
        let array = ArrayView3::from_shape(sample.dims(), data.as_slice().unwrap()).unwrap();
        let (split_rhat, ess) = split_rhat_mean_ess(array);
        println!("MIN Split Rhat: {}", split_rhat.min().unwrap());
        println!("MIN ESS: {}", ess.min().unwrap());

        #[cfg(feature = "csv")]
        save_csv_tensor(sample, "/tmp/nuts-sample.csv").expect("saving data should succeed")
    }
}