1use std::error::Error;
38use std::sync::mpsc;
39use std::sync::mpsc::{Receiver, Sender};
40use std::thread;
41use std::time::{Duration, Instant};
42
43use crate::distributions::GradientTarget;
44use crate::stats::{collect_rhat, ChainStats, ChainTracker, RunStats};
45use burn::prelude::*;
46use burn::tensor::backend::AutodiffBackend;
47use burn::tensor::cast::ToElement;
48use burn::tensor::Element;
49use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
50use ndarray::ArrayView3;
51use ndarray_stats::QuantileExt;
52use num_traits::{Float, FromPrimitive};
53use rand::prelude::*;
54use rand::Rng;
55use rand_distr::uniform::SampleUniform;
56use rand_distr::{Exp1, StandardNormal, StandardUniform};
57use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
58
/// Multi-chain No-U-Turn Sampler (NUTS) driver.
///
/// Owns one [`NUTSChain`] per starting position; chains are run in parallel
/// (rayon in [`NUTS::run`], scoped threads in `run_progress`) and their
/// samples are stacked into a single rank-3 tensor.
#[derive(Debug, Clone)]
pub struct NUTS<T, B, GTarget>
where
    T: Float + ElementConversion + Element + SampleUniform + FromPrimitive,
    B: AutodiffBackend,
    GTarget: GradientTarget<T, B> + Sync,
    StandardNormal: rand::distr::Distribution<T>,
    StandardUniform: rand_distr::Distribution<T>,
    rand_distr::Exp1: rand_distr::Distribution<T>,
{
    /// Independent chains; each carries its own RNG, step size and adaptation state.
    chains: Vec<NUTSChain<T, B, GTarget>>,
}
82
impl<T, B, GTarget> NUTS<T, B, GTarget>
where
    T: Float + ElementConversion + Element + SampleUniform + FromPrimitive + Send,
    B: AutodiffBackend + Send,
    GTarget: GradientTarget<T, B> + Sync + Clone + Send,
    StandardNormal: rand::distr::Distribution<T>,
    StandardUniform: rand_distr::Distribution<T>,
    rand_distr::Exp1: rand_distr::Distribution<T>,
{
    /// Creates one [`NUTSChain`] per entry of `initial_positions`.
    ///
    /// Each chain receives its own clone of `target` and the shared
    /// dual-averaging acceptance target `target_accept_p`.
    pub fn new(target: GTarget, initial_positions: Vec<Vec<T>>, target_accept_p: T) -> Self {
        let chains = initial_positions
            .into_iter()
            .map(|pos| NUTSChain::new(target.clone(), pos, target_accept_p))
            .collect();
        Self { chains }
    }

    /// Runs all chains in parallel (rayon) and returns the collected draws
    /// stacked as a `[n_chains, n_collect, dim]` tensor.
    ///
    /// Each chain discards `n_discard` warm-up transitions before keeping
    /// `n_collect` observations.
    pub fn run(&mut self, n_collect: usize, n_discard: usize) -> Tensor<B, 3> {
        let chain_samples: Vec<Tensor<B, 2>> = self
            .chains
            .par_iter_mut()
            .map(|chain| chain.run(n_collect, n_discard))
            .collect();
        // Stack the rank-2 per-chain samples along a new leading chain axis.
        Tensor::<B, 2>::stack(chain_samples, 0)
    }

    /// Like [`Self::run`], but renders live progress bars and additionally
    /// returns [`RunStats`] computed from the collected sample.
    ///
    /// Each chain periodically sends [`ChainStats`] over its own mpsc channel;
    /// a dedicated thread drains the channels, drives an `indicatif`
    /// [`MultiProgress`] display (one global bar plus at most 5 per-chain
    /// bars at a time), and reports the running acceptance rate and — once at
    /// least two chains have reported — the maximum split-Rhat.
    ///
    /// # Errors
    /// Propagates failures from the per-chain statistics tracker
    /// (see [`NUTSChain::run_progress`]).
    pub fn run_progress(
        &mut self,
        n_collect: usize,
        n_discard: usize,
    ) -> Result<(Tensor<B, 3>, RunStats), Box<dyn Error>> {
        let chains = &mut self.chains;

        // One channel per chain: the worker owns the Sender, the progress
        // thread owns all Receivers.
        let mut rxs: Vec<Receiver<ChainStats>> = vec![];
        let mut txs: Vec<Sender<ChainStats>> = vec![];
        (0..chains.len()).for_each(|_| {
            let (tx, rx) = mpsc::channel();
            rxs.push(rx);
            txs.push(tx);
        });

        // UI thread: polls every 250 ms, exits once every chain has reported
        // completion of all n_collect + n_discard transitions.
        let progress_handle = thread::spawn(move || {
            let sleep_ms = Duration::from_millis(250);
            // Zero timeout makes recv_timeout behave like a non-blocking drain.
            let timeout_ms = Duration::from_millis(0);
            let multi = MultiProgress::new();

            let pb_style = ProgressStyle::default_bar()
                .template("{prefix:8} {bar:40.cyan/blue} {pos}/{len} ({eta}) | {msg}")
                .unwrap()
                .progress_chars("=>-");
            let total: u64 = (n_collect + n_discard).try_into().unwrap();

            // Global bar counts transitions summed over all chains.
            let global_pb = multi.add(ProgressBar::new((rxs.len() as u64) * total));
            global_pb.set_style(pb_style.clone());
            global_pb.set_prefix("Global");

            // Show at most 5 per-chain bars; finished bars are replaced by
            // bars for not-yet-displayed chains (tracked via next_active).
            let mut active: Vec<(usize, ProgressBar)> = (0..rxs.len().min(5))
                .map(|chain_idx| {
                    let pb = multi.add(ProgressBar::new(total));
                    pb.set_style(pb_style.clone());
                    pb.set_prefix(format!("Chain {chain_idx}"));
                    (chain_idx, pb)
                })
                .collect();
            let mut next_active = active.len();
            let mut n_finished = 0;
            // Latest ChainStats per chain (None until the chain first reports).
            let mut most_recent = vec![None; rxs.len()];
            let mut total_progress;

            loop {
                // Drain every channel, keeping only the newest stats per chain.
                for (i, rx) in rxs.iter().enumerate() {
                    while let Ok(stats) = rx.recv_timeout(timeout_ms) {
                        most_recent[i] = Some(stats)
                    }
                }

                let mut to_replace = vec![false; active.len()];
                let mut avg_p_accept = 0.0;
                let mut n_available_stats = 0.0;
                for (vec_idx, (i, pb)) in active.iter().enumerate() {
                    if let Some(stats) = &most_recent[*i] {
                        pb.set_position(stats.n);
                        pb.set_message(format!("p(accept)≈{:.2}", stats.p_accept));
                        avg_p_accept += stats.p_accept;
                        n_available_stats += 1.0;

                        // Chain done: schedule its bar slot for reuse/removal.
                        if stats.n == total {
                            to_replace[vec_idx] = true;
                            n_finished += 1;
                        }
                    }
                }
                // NOTE(review): NaN if no active chain has reported yet; benign
                // because the message below is only set once valid.len() >= 2.
                avg_p_accept /= n_available_stats;

                total_progress = 0;
                for stats in most_recent.iter().flatten() {
                    total_progress += stats.n;
                }
                global_pb.set_position(total_progress);
                let valid: Vec<&ChainStats> = most_recent.iter().flatten().collect();
                if valid.len() >= 2 {
                    // Rhat needs at least two chains to be meaningful.
                    let rhats = collect_rhat(valid.as_slice());
                    let max = rhats.max_skipnan();
                    global_pb.set_message(format!(
                        "p(accept)≈{:.2} max(rhat)≈{:.2}",
                        avg_p_accept, max
                    ))
                }

                // Reuse finished bar slots for chains that have no bar yet;
                // remove slots once all chains have been displayed.
                let mut to_remove = vec![];
                for (i, replace) in to_replace.iter().enumerate() {
                    if *replace && next_active < most_recent.len() {
                        let pb = multi.add(ProgressBar::new(total));
                        pb.set_style(pb_style.clone());
                        pb.set_prefix(format!("Chain {next_active}"));
                        active[i] = (next_active, pb);
                        next_active += 1;
                    } else if *replace {
                        to_remove.push(i);
                    }
                }

                // Remove back-to-front so earlier indices stay valid.
                to_remove.sort();
                for i in to_remove.iter().rev() {
                    active.remove(*i);
                }

                if n_finished >= most_recent.len() {
                    break;
                }
                std::thread::sleep(sleep_ms);
            }
        });

        // Scoped threads: chains borrow from `self` without 'static bounds.
        let chain_sample: Vec<Tensor<B, 2>> = thread::scope(|s| {
            let handles: Vec<thread::ScopedJoinHandle<Tensor<B, 2>>> = chains
                .iter_mut()
                .zip(txs)
                .map(|(chain, tx)| {
                    s.spawn(|| {
                        chain
                            .run_progress(n_collect, n_discard, tx)
                            .expect("Expected running chain to succeed.")
                    })
                })
                .collect();
            handles
                .into_iter()
                .map(|h| {
                    h.join()
                        .expect("Expected thread to succeed in generating observation.")
                })
                .collect()
        });
        let sample = Tensor::<B, 2>::stack(chain_sample, 0);

        if let Err(e) = progress_handle.join() {
            eprintln!("Progress bar thread emitted error message: {:?}", e);
        }

        // Compute run statistics on an f32 ndarray view of the stacked sample.
        let sample_f32 = sample.to_data();
        let view =
            ArrayView3::<f32>::from_shape(sample.dims(), sample_f32.as_slice().unwrap()).unwrap();
        let run_stats = RunStats::from(view);

        Ok((sample, run_stats))
    }

    /// Reseeds every chain deterministically: chain `i` gets `seed + i + 1`,
    /// so chains stay mutually decorrelated while the whole run is reproducible.
    pub fn set_seed(mut self, seed: u64) -> Self {
        for (i, chain) in self.chains.iter_mut().enumerate() {
            let chain_seed = seed + i as u64 + 1;
            chain.rng = SmallRng::seed_from_u64(chain_seed);
        }
        self
    }
}
355
/// A single NUTS chain: current position, step-size adaptation state, and RNG.
#[derive(Debug, Clone)]
pub struct NUTSChain<T, B, GTarget>
where
    B: AutodiffBackend,
{
    /// Target distribution supplying unnormalized log-density and gradient.
    target: GTarget,

    /// Current position of the chain in parameter space.
    pub position: Tensor<B, 1>,

    /// Desired average acceptance probability for dual-averaging adaptation.
    target_accept_p: T,

    /// Leapfrog step size; `-1` is a sentinel meaning "not yet initialized"
    /// (resolved in `init_chain` via `find_reasonable_epsilon`).
    epsilon: T,

    /// Number of transitions taken so far (dual-averaging iteration counter).
    m: usize,
    /// Number of observations to keep (set by `init_chain`).
    n_collect: usize,
    /// Number of warm-up transitions to discard (set by `init_chain`).
    n_discard: usize,
    /// Dual-averaging shrinkage parameter.
    gamma: T,
    /// Dual-averaging iteration offset (stabilizes early iterations).
    t_0: usize,
    /// Dual-averaging step-size decay exponent.
    kappa: T,
    /// Shrinkage anchor: log of 10 times the initial step size.
    mu: T,
    /// Running (averaged) step size used after warm-up.
    epsilon_bar: T,
    /// Running average of the acceptance-probability error.
    h_bar: T,

    /// Per-chain RNG (seedable for reproducibility).
    rng: SmallRng,
    phantom_data: std::marker::PhantomData<T>,
}
391
impl<T, B, GTarget> NUTSChain<T, B, GTarget>
where
    T: Float + ElementConversion + Element + SampleUniform + FromPrimitive,
    B: AutodiffBackend,
    GTarget: GradientTarget<T, B> + std::marker::Sync,
    StandardNormal: rand::distr::Distribution<T>,
    StandardUniform: rand_distr::Distribution<T>,
    rand_distr::Exp1: rand_distr::Distribution<T>,
{
    /// Creates a chain at `initial_position` with an OS-seeded RNG.
    ///
    /// `epsilon` starts at the sentinel `-1`; a concrete step size is found
    /// lazily in `init_chain` on the first run. The adaptation constants
    /// (gamma, t_0, kappa) follow the usual dual-averaging defaults.
    pub fn new(target: GTarget, initial_position: Vec<T>, target_accept_p: T) -> Self {
        let dim = initial_position.len();
        let td: TensorData = TensorData::new(initial_position, [dim]);
        let position = Tensor::<B, 1>::from_data(td, &B::Device::default());
        let rng = SmallRng::from_os_rng();
        // Sentinel value; replaced by find_reasonable_epsilon in init_chain.
        let epsilon = -T::one();

        Self {
            target,
            position,
            target_accept_p,
            epsilon,
            m: 0,
            n_collect: 0,
            n_discard: 0,
            gamma: T::from(0.05).unwrap(),
            t_0: 10,
            kappa: T::from(0.75).unwrap(),
            // ln(10 * epsilon) with epsilon provisionally 1; recomputed in
            // init_chain once the real epsilon is known.
            mu: (T::from(10.0).unwrap() * T::one()).ln(),
            epsilon_bar: T::one(),
            h_bar: T::zero(),
            rng,
            phantom_data: std::marker::PhantomData,
        }
    }

    /// Replaces the chain's RNG with one seeded from `seed` (builder style).
    pub fn set_seed(mut self, seed: u64) -> Self {
        self.rng = SmallRng::seed_from_u64(seed);
        self
    }

    /// Runs the chain and returns an `[n_collect, dim]` tensor of draws.
    ///
    /// Row 0 holds the pre-step position written by `init_chain`; subsequent
    /// rows are filled by transitions `m >= n_discard`.
    pub fn run(&mut self, n_collect: usize, n_discard: usize) -> Tensor<B, 2> {
        let (dim, mut sample) = self.init_chain(n_collect, n_discard);

        for m in 1..(n_collect + n_discard) {
            self.step();

            if m >= n_discard {
                sample = sample.slice_assign(
                    [m - n_discard..m - n_discard + 1, 0..dim],
                    self.position.clone().unsqueeze(),
                );
            }
        }
        sample
    }

    /// Like [`Self::run`], but streams [`ChainStats`] over `tx` (at most once
    /// per second, plus a final send) for progress reporting.
    ///
    /// NOTE(review): this loop runs `0..total` while `run` loops
    /// `1..total`, so it performs one extra transition and (for
    /// `n_discard == 0`) overwrites row 0 instead of keeping the initial
    /// position — confirm whether the two are intended to match.
    ///
    /// # Errors
    /// Returns an error if the statistics tracker rejects an observation.
    fn run_progress(
        &mut self,
        n_collect: usize,
        n_discard: usize,
        tx: Sender<ChainStats>,
    ) -> Result<Tensor<B, 2>, Box<dyn Error>> {
        let (dim, mut sample) = self.init_chain(n_collect, n_discard);
        // Tracker works in f32 regardless of T.
        let pos_0: Vec<f32> = self
            .position
            .to_data()
            .iter()
            .map(|x: T| ToElement::to_f32(&x))
            .collect();
        let mut tracker = ChainTracker::new(dim, &pos_0);
        let mut last = Instant::now();
        // Throttle stats messages to at most one per second.
        let freq = Duration::from_secs(1);
        let total = n_discard + n_collect;

        for i in 0..total {
            self.step();
            let pos_i: Vec<f32> = self
                .position
                .to_data()
                .iter()
                .map(|x: T| ToElement::to_f32(&x))
                .collect();
            tracker.step(&pos_i).map_err(|e| {
                let msg = format!(
                    "Chain statistics tracker caused error: {}.\nAborting generation of further observations.",
                    e
                );
                println!("{}", msg);
                msg
            })?;

            // Send stats when a second has elapsed or on the last iteration
            // (the final send lets the UI thread observe n == total).
            let now = Instant::now();
            if (now >= last + freq) | (i == total - 1) {
                if let Err(e) = tx.send(tracker.stats()) {
                    eprintln!("Sending chain statistics failed: {e}");
                }
                last = now;
            }

            if i >= n_discard {
                sample = sample.slice_assign(
                    [i - n_discard..i - n_discard + 1, 0..dim],
                    self.position.clone().unsqueeze(),
                );
            }
        }

        Ok(sample)
    }

    /// Allocates the output buffer, stores the run lengths, writes the current
    /// position into row 0, and resolves the step size if still the sentinel.
    /// Returns the dimensionality and the (mostly uninitialized) sample buffer.
    fn init_chain(&mut self, n_collect: usize, n_discard: usize) -> (usize, Tensor<B, 2>) {
        let dim = self.position.dims()[0];
        self.n_collect = n_collect;
        self.n_discard = n_discard;

        let mut sample = Tensor::<B, 2>::empty([n_collect, dim], &B::Device::default());
        sample = sample.slice_assign([0..1, 0..dim], self.position.clone().unsqueeze());
        // Fresh Gaussian momentum used only for the epsilon heuristic below.
        let mom_0_data: Vec<T> = (&mut self.rng)
            .sample_iter(StandardNormal)
            .take(dim)
            .collect();
        let mom_0 = Tensor::<B, 1>::from_data(mom_0_data.as_slice(), &B::Device::default());
        // epsilon == -1 is the "uninitialized" sentinel set in `new`.
        if T::abs(self.epsilon + T::one()) <= T::epsilon() {
            self.epsilon = find_reasonable_epsilon(self.position.clone(), mom_0, &self.target);
        }
        // Anchor the dual-averaging shrinkage at ln(10 * epsilon).
        self.mu = T::ln(T::from(10).unwrap() * self.epsilon);
        (dim, sample)
    }

    /// One NUTS transition: samples momentum and a slice variable, doubles the
    /// trajectory tree until the U-turn/divergence criterion stops it, updates
    /// the position by progressive sampling, then adapts the step size by dual
    /// averaging during warm-up.
    pub fn step(&mut self) {
        self.m += 1;

        // Resample momentum ~ N(0, I).
        let dim = self.position.dims()[0];
        let mom_0 = (&mut self.rng)
            .sample_iter(StandardNormal)
            .take(dim)
            .collect::<Vec<T>>();
        let mom_0 = Tensor::<B, 1>::from_data(mom_0.as_slice(), &B::Device::default());
        let (ulogp, grad) = self.target.unnorm_logp_and_grad(self.position.clone());
        // Joint log-density of (position, momentum): logp - |mom|^2 / 2.
        let joint = ulogp.clone() - (mom_0.clone() * mom_0.clone()).sum() * 0.5;
        let joint =
            T::from_f64(joint.into_scalar().to_f64()).expect("successful conversion from 64 to T");
        // Slice variable: logu = joint - Exp(1) is equivalent to
        // u ~ Uniform(0, exp(joint)) on the log scale.
        let exp1_obs = self.rng.sample(Exp1);
        let logu = joint - exp1_obs;

        // Trajectory endpoints start at the current state.
        let mut position_minus = self.position.clone();
        let mut position_plus = self.position.clone();
        let mut mom_minus = mom_0.clone();
        let mut mom_plus = mom_0.clone();
        let mut grad_minus = grad.clone();
        let mut grad_plus = grad.clone();
        let mut j = 0;
        let mut n = 1;
        let mut s = true;
        let mut alpha: T = T::zero();
        let mut n_alpha: usize = 0;

        // Double the tree until a U-turn or divergence is detected.
        while s {
            // Choose expansion direction uniformly: v = -1 (backward) or +1.
            let u_run_1: T = self.rng.random::<T>();
            let v = (2 * (u_run_1 < T::from(0.5).unwrap()) as i8) - 1;

            let (position_prime, n_prime, s_prime) = {
                if v == -1 {
                    // Expand backward from the minus endpoint.
                    let (
                        position_minus_2,
                        mom_minus_2,
                        grad_minus_2,
                        _,
                        _,
                        _,
                        position_prime_2,
                        _,
                        _,
                        n_prime_2,
                        s_prime_2,
                        alpha_2,
                        n_alpha_2,
                    ) = build_tree(
                        position_minus.clone(),
                        mom_minus.clone(),
                        grad_minus.clone(),
                        logu,
                        v,
                        j,
                        self.epsilon,
                        &self.target,
                        joint,
                        &mut self.rng,
                    );

                    position_minus = position_minus_2;
                    mom_minus = mom_minus_2;
                    grad_minus = grad_minus_2;
                    alpha = alpha_2;
                    n_alpha = n_alpha_2;

                    (position_prime_2, n_prime_2, s_prime_2)
                } else {
                    // Expand forward from the plus endpoint.
                    let (
                        _,
                        _,
                        _,
                        position_plus_2,
                        mom_plus_2,
                        grad_plus_2,
                        position_prime_2,
                        _,
                        _,
                        n_prime_2,
                        s_prime_2,
                        alpha_2,
                        n_alpha_2,
                    ) = build_tree(
                        position_plus.clone(),
                        mom_plus.clone(),
                        grad_plus.clone(),
                        logu,
                        v,
                        j,
                        self.epsilon,
                        &self.target,
                        joint,
                        &mut self.rng,
                    );

                    position_plus = position_plus_2;
                    mom_plus = mom_plus_2;
                    grad_plus = grad_plus_2;
                    alpha = alpha_2;
                    n_alpha = n_alpha_2;

                    (position_prime_2, n_prime_2, s_prime_2)
                }
            };

            // Progressive sampling: accept the subtree's proposal with
            // probability min(1, n' / n).
            let tmp = T::one().min(
                T::from(n_prime).expect("successful conversion of n_prime from usize to T")
                    / T::from(n).expect("successful conversion of n from usize to T"),
            );
            let u_run_2 = self.rng.random::<T>();
            if s_prime && (u_run_2 < tmp) {
                self.position = position_prime;
            }
            n += n_prime;

            // Stop when the subtree flagged divergence or the overall
            // trajectory makes a U-turn.
            s = s_prime
                && stop_criterion(
                    position_minus.clone(),
                    position_plus.clone(),
                    mom_minus.clone(),
                    mom_plus.clone(),
                );
            j += 1
        }

        // Dual-averaging step-size adaptation (active only during warm-up).
        let mut eta =
            T::one() / T::from(self.m + self.t_0).expect("successful conversion of m + t_0 to T");
        self.h_bar = (T::one() - eta) * self.h_bar
            + eta
                * (self.target_accept_p
                    - alpha / T::from(n_alpha).expect("successful conversion of n_alpha to T"));
        if self.m <= self.n_discard {
            let _m = T::from(self.m).expect("successful conversion of m to T");
            self.epsilon = T::exp(self.mu - T::sqrt(_m) / self.gamma * self.h_bar);
            eta = _m.powf(-self.kappa);
            self.epsilon_bar =
                T::exp((T::one() - eta) * T::ln(self.epsilon_bar) + eta * T::ln(self.epsilon));
        } else {
            // After warm-up, freeze the step size at its running average.
            self.epsilon = self.epsilon_bar;
        }
    }
}
693
694#[allow(dead_code)]
695fn find_reasonable_epsilon<B, T, GTarget>(
696 position: Tensor<B, 1>,
697 mom: Tensor<B, 1>,
698 gradient_target: >arget,
699) -> T
700where
701 T: Float + Element,
702 B: AutodiffBackend,
703 GTarget: GradientTarget<T, B> + Sync,
704{
705 let mut epsilon = T::one();
706 let half = T::from(0.5).unwrap();
707 let (ulogp, grad) = gradient_target.unnorm_logp_and_grad(position.clone());
708 let (_, mut mom_prime, grad_prime, mut ulogp_prime) = leapfrog(
709 position.clone(),
710 mom.clone(),
711 grad.clone(),
712 epsilon,
713 gradient_target,
714 );
715 let mut k = T::one();
716
717 while !all_real::<B, T>(ulogp_prime.clone()) && !all_real::<B, T>(grad_prime.clone()) {
718 k = k * half;
719 (_, mom_prime, _, ulogp_prime) = leapfrog(
720 position.clone(),
721 mom.clone(),
722 grad.clone(),
723 epsilon * k,
724 gradient_target,
725 );
726 }
727
728 epsilon = half * k * epsilon;
729 let log_accept_prob = ulogp_prime
730 - ulogp.clone()
731 - ((mom_prime.clone() * mom_prime).sum() - (mom.clone() * mom.clone()).sum()) * half;
732 let mut log_accept_prob = T::from(log_accept_prob.into_scalar().to_f64()).unwrap();
733
734 let a = if log_accept_prob > half.ln() {
735 T::one()
736 } else {
737 -T::one()
738 };
739
740 while a * log_accept_prob > -a * T::from(2.0).unwrap().ln() {
741 epsilon = epsilon * T::from(2.0).unwrap().powf(a);
742 (_, mom_prime, _, ulogp_prime) = leapfrog(
743 position.clone(),
744 mom.clone(),
745 grad.clone(),
746 epsilon,
747 gradient_target,
748 );
749 log_accept_prob = T::from(
750 (ulogp_prime
751 - ulogp.clone()
752 - ((mom_prime.clone() * mom_prime).sum() - (mom.clone() * mom.clone()).sum())
753 * 0.5)
754 .into_scalar()
755 .to_f64(),
756 )
757 .unwrap();
758 }
759
760 epsilon
761}
762
/// Recursively builds one side of the NUTS trajectory tree (tree doubling).
///
/// At depth `j == 0` a single leapfrog step of size `v * epsilon` is taken
/// (`v` is ±1, the expansion direction). At depth `j > 0`, two subtrees of
/// depth `j - 1` are built and merged, with the second built only if the
/// first did not flag a stop.
///
/// Returns, in order: leftmost (position, momentum, gradient), rightmost
/// (position, momentum, gradient), the sampled proposal (position, gradient,
/// log-density tensor), the number of slice-valid leaves `n'`, the
/// continue-flag `s'`, and the accumulated acceptance statistic `alpha'`
/// with its leaf count `n_alpha'`.
#[allow(clippy::too_many_arguments, clippy::type_complexity)]
fn build_tree<B, T, GTarget>(
    position: Tensor<B, 1>,
    mom: Tensor<B, 1>,
    grad: Tensor<B, 1>,
    logu: T,
    v: i8,
    j: usize,
    epsilon: T,
    gradient_target: &GTarget,
    joint_0: T,
    rng: &mut SmallRng,
) -> (
    Tensor<B, 1>,
    Tensor<B, 1>,
    Tensor<B, 1>,
    Tensor<B, 1>,
    Tensor<B, 1>,
    Tensor<B, 1>,
    Tensor<B, 1>,
    Tensor<B, 1>,
    Tensor<B, 1>,
    usize,
    bool,
    T,
    usize,
)
where
    T: Float + Element,
    B: AutodiffBackend,
    GTarget: GradientTarget<T, B> + Sync,
{
    if j == 0 {
        // Base case: one leapfrog step in direction v.
        let (position_prime, mom_prime, grad_prime, logp_prime) = leapfrog(
            position.clone(),
            mom.clone(),
            grad.clone(),
            T::from(v as i32).unwrap() * epsilon,
            gradient_target,
        );
        // Joint log-density at the new leaf: logp' - |mom'|^2 / 2.
        let joint = logp_prime.clone() - (mom_prime.clone() * mom_prime.clone()).sum() * 0.5;
        let joint = T::from(joint.into_scalar().to_f64())
            .expect("type conversion from joint tensor to scalar type T to succeed");
        // Leaf is slice-valid iff it lies under the slice variable.
        let n_prime = (logu < joint) as usize;
        // Divergence check: stop if the energy error exceeds 1000.
        let s_prime = (logu - T::from(1000.0).unwrap()) < joint;
        // A single leaf is both the left and right end of its (sub)tree.
        let position_minus = position_prime.clone();
        let position_plus = position_prime.clone();
        let mom_minus = mom_prime.clone();
        let mom_plus = mom_prime.clone();
        let grad_minus = grad_prime.clone();
        let grad_plus = grad_prime.clone();
        // Metropolis acceptance statistic relative to the trajectory start.
        let alpha_prime = T::min(T::one(), (joint - joint_0).exp());
        let n_alpha_prime = 1_usize;
        (
            position_minus,
            mom_minus,
            grad_minus,
            position_plus,
            mom_plus,
            grad_plus,
            position_prime,
            grad_prime,
            logp_prime,
            n_prime,
            s_prime,
            alpha_prime,
            n_alpha_prime,
        )
    } else {
        // Recursion: build the first half-tree of depth j - 1.
        let (
            mut position_minus,
            mut mom_minus,
            mut grad_minus,
            mut position_plus,
            mut mom_plus,
            mut grad_plus,
            mut position_prime,
            mut grad_prime,
            mut logp_prime,
            mut n_prime,
            mut s_prime,
            mut alpha_prime,
            mut n_alpha_prime,
        ) = build_tree(
            position,
            mom,
            grad,
            logu,
            v,
            j - 1,
            epsilon,
            gradient_target,
            joint_0,
            rng,
        );
        // Build the second half only if the first half did not stop.
        if s_prime {
            let (
                position_minus_2,
                mom_minus_2,
                grad_minus_2,
                position_plus_2,
                mom_plus_2,
                grad_plus_2,
                position_prime_2,
                grad_prime_2,
                logp_prime_2,
                n_prime_2,
                s_prime_2,
                alpha_prime_2,
                n_alpha_prime_2,
            ) = if v == -1 {
                // Backward: continue from the current minus endpoint.
                build_tree(
                    position_minus.clone(),
                    mom_minus.clone(),
                    grad_minus.clone(),
                    logu,
                    v,
                    j - 1,
                    epsilon,
                    gradient_target,
                    joint_0,
                    rng,
                )
            } else {
                // Forward: continue from the current plus endpoint.
                build_tree(
                    position_plus.clone(),
                    mom_plus.clone(),
                    grad_plus.clone(),
                    logu,
                    v,
                    j - 1,
                    epsilon,
                    gradient_target,
                    joint_0,
                    rng,
                )
            };
            // Only the expanded endpoint moves; the other stays put.
            if v == -1 {
                position_minus = position_minus_2;
                mom_minus = mom_minus_2;
                grad_minus = grad_minus_2;
            } else {
                position_plus = position_plus_2;
                mom_plus = mom_plus_2;
                grad_plus = grad_plus_2;
            }

            // Progressive sampling between the two half-trees, proportional to
            // their valid-leaf counts; max(1) guards division by zero when
            // both halves have zero valid leaves.
            let u_build_tree: f64 = (*rng).random::<f64>();
            if u_build_tree < (n_prime_2 as f64 / (n_prime + n_prime_2).max(1) as f64) {
                position_prime = position_prime_2;
                grad_prime = grad_prime_2;
                logp_prime = logp_prime_2;
            }

            n_prime += n_prime_2;

            // Merge stop flags with the U-turn check across the merged tree.
            s_prime = s_prime
                && s_prime_2
                && stop_criterion(
                    position_minus.clone(),
                    position_plus.clone(),
                    mom_minus.clone(),
                    mom_plus.clone(),
                );
            alpha_prime = alpha_prime + alpha_prime_2;
            n_alpha_prime += n_alpha_prime_2;
        }
        (
            position_minus,
            mom_minus,
            grad_minus,
            position_plus,
            mom_plus,
            grad_plus,
            position_prime,
            grad_prime,
            logp_prime,
            n_prime,
            s_prime,
            alpha_prime,
            n_alpha_prime,
        )
    }
}
947
948fn all_real<B, T>(x: Tensor<B, 1>) -> bool
949where
950 T: Float + Element,
951 B: AutodiffBackend,
952{
953 x.clone()
954 .equal_elem(T::infinity())
955 .bool_or(x.clone().equal_elem(T::neg_infinity()))
956 .bool_or(x.is_nan())
957 .any()
958 .bool_not()
959 .into_scalar()
960 .to_bool()
961}
962
963fn stop_criterion<B>(
964 position_minus: Tensor<B, 1>,
965 position_plus: Tensor<B, 1>,
966 mom_minus: Tensor<B, 1>,
967 mom_plus: Tensor<B, 1>,
968) -> bool
969where
970 B: AutodiffBackend,
971{
972 let diff = position_plus - position_minus;
973 let dot_minus = (diff.clone() * mom_minus).sum();
974 let dot_plus = (diff * mom_plus).sum();
975 dot_minus.greater_equal_elem(0).into_scalar().to_bool()
976 && dot_plus.greater_equal_elem(0).into_scalar().to_bool()
977}
978
979fn leapfrog<B, T, GTarget>(
980 position: Tensor<B, 1>,
981 mom: Tensor<B, 1>,
982 grad: Tensor<B, 1>,
983 epsilon: T,
984 gradient_target: >arget,
985) -> (Tensor<B, 1>, Tensor<B, 1>, Tensor<B, 1>, Tensor<B, 1>)
986where
987 T: Float + ElementConversion,
988 B: AutodiffBackend,
989 GTarget: GradientTarget<T, B>,
990{
991 let mom_prime = mom + grad * epsilon * 0.5;
992 let position_prime = position + mom_prime.clone() * epsilon;
993 let (ulogp_prime, grad_prime) = gradient_target.unnorm_logp_and_grad(position_prime.clone());
994 let mom_prime = mom_prime + grad_prime.clone() * epsilon * 0.5;
995 (position_prime, mom_prime, grad_prime, ulogp_prime)
996}
997
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use crate::{
        core::init,
        dev_tools::Timer,
        distributions::{DiffableGaussian2D, Rosenbrock2D},
        stats::split_rhat_mean_ess,
    };

    #[cfg(feature = "csv")]
    use crate::io::csv::save_csv_tensor;

    use super::*;
    use burn::{
        backend::{Autodiff, NdArray},
        tensor::{Tensor, Tolerance},
    };
    use ndarray::ArrayView3;
    use ndarray_stats::QuantileExt;
    use num_traits::Float;

    type BackendType = Autodiff<NdArray>;

    /// Minimal test target: an isotropic standard normal with
    /// unnormalized log-density -0.5 * sum(x^2).
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub struct StandardNormal;

    impl<T, B> GradientTarget<T, B> for StandardNormal
    where
        T: Float + Debug + ElementConversion + Element,
        B: AutodiffBackend,
    {
        fn unnorm_logp(&self, positions: Tensor<B, 1>) -> Tensor<B, 1> {
            let sq = positions.clone().powi_scalar(2);
            let half = T::from(0.5).unwrap();
            -(sq.mul_scalar(half)).sum()
        }
    }

    /// Asserts elementwise approximate equality of a rank-1 tensor against
    /// expected f64 values, within `tol`.
    fn assert_tensor_approx_eq<T: Backend, F: Float + burn::tensor::Element>(
        actual: Tensor<T, 1>,
        expected: &[f64],
        tol: Tolerance<F>,
    ) {
        let a = actual.clone().to_data();
        let e = Tensor::<T, 1>::from(expected).to_data();
        a.assert_approx_eq(&e, tol);
    }

    // Step-size heuristic on a standard normal from a known state.
    #[test]
    fn test_find_reasonable_epsilon() {
        let position = Tensor::<BackendType, 1>::from([0.0, 1.0]);
        let mom = Tensor::<BackendType, 1>::from([1.0, 0.0]);
        let epsilon = find_reasonable_epsilon::<_, f64, _>(position, mom, &StandardNormal);
        assert_eq!(epsilon, 2.0);
    }

    // Deterministic regression test for build_tree with a fixed RNG seed.
    #[test]
    fn test_build_tree() {
        let gradient_target = DiffableGaussian2D::new([0.0_f64, 1.0], [[4.0, 2.0], [2.0, 3.0]]);
        let position = Tensor::<BackendType, 1>::from([0.0, 1.0]);
        let mom = Tensor::<BackendType, 1>::from([2.0, 3.0]);
        let grad = Tensor::<BackendType, 1>::from([4.0, 5.0]);
        let logu = -2.0;
        let v: i8 = -1;
        let j: usize = 3;
        let epsilon: f64 = 0.01;
        let joint_0 = 0.1_f64;
        let mut rng = SmallRng::seed_from_u64(0);
        let (
            position_minus,
            mom_minus,
            grad_minus,
            position_plus,
            mom_plus,
            grad_plus,
            position_prime,
            grad_prime,
            logp_prime,
            n_prime,
            s_prime,
            alpha_prime,
            n_alpha_prime,
        ) = build_tree::<BackendType, f64, _>(
            position,
            mom,
            grad,
            logu,
            v,
            j,
            epsilon,
            &gradient_target,
            joint_0,
            &mut rng,
        );
        let tol = Tolerance::<f64>::default()
            .set_relative(1e-5)
            .set_absolute(1e-6);

        assert_tensor_approx_eq(position_minus, &[-0.1584001, 0.76208336], tol);
        assert_tensor_approx_eq(mom_minus, &[1.980_003_6, 2.971_825_3], tol);
        assert_tensor_approx_eq(grad_minus, &[-7.912_36e-5, 7.935_829_5e-2], tol);

        assert_tensor_approx_eq(position_plus, &[-0.0198, 0.97025], tol);
        assert_tensor_approx_eq(mom_plus, &[1.98, 2.974_950_3], tol);
        assert_tensor_approx_eq(grad_plus, &[-1.250e-05, 9.925e-03], tol);

        assert_tensor_approx_eq(position_prime, &[-0.0198, 0.97025], tol);
        assert_tensor_approx_eq(grad_prime, &[-1.250e-05, 9.925e-03], tol);

        assert_eq!(n_prime, 0);
        assert!(s_prime);
        assert_eq!(n_alpha_prime, 8);

        let logp_exp = -2.877_745_4_f64;
        let alpha_exp = 0.000_686_661_7_f64;
        assert!(
            (logp_prime.into_scalar().to_f64() - logp_exp).abs() < 1e-6,
            "logp mismatch"
        );
        assert!((alpha_prime - alpha_exp).abs() < 1e-8, "alpha mismatch");
    }

    // With n_collect = 1 and no warm-up, the sample is the initial position
    // written by init_chain (the step loop does not execute).
    #[test]
    fn test_chain_1() {
        let target = DiffableGaussian2D::new([0.0_f64, 1.0], [[4.0, 2.0], [2.0, 3.0]]);
        let initial_positions = vec![0.0_f64, 1.0];
        let n_discard = 0;
        let n_collect = 1;
        let mut sampler = NUTSChain::new(target, initial_positions, 0.8).set_seed(42);
        let sample: Tensor<BackendType, 2> = sampler.run(n_collect, n_discard);
        assert_eq!(sample.dims(), [n_collect, 2]);
        let tol = Tolerance::<f64>::default()
            .set_relative(1e-5)
            .set_absolute(1e-6);
        assert_tensor_approx_eq(sample.flatten(0, 1), &[0.0, 1.0], tol);
    }

    // Seeded regression test: exact draws from a short single-chain run.
    #[test]
    fn test_chain_2() {
        let target = DiffableGaussian2D::new([0.0_f64, 1.0], [[4.0, 2.0], [2.0, 3.0]]);
        let initial_positions = vec![0.0_f64, 1.0];
        let n_discard = 3;
        let n_collect = 3;
        let mut sampler = NUTSChain::new(target, initial_positions, 0.8).set_seed(42);
        let sample: Tensor<BackendType, 2> = sampler.run(n_collect, n_discard);
        assert_eq!(sample.dims(), [n_collect, 2]);
        let tol = Tolerance::<f64>::default()
            .set_relative(1e-5)
            .set_absolute(1e-6);
        assert_tensor_approx_eq(
            sample.flatten(0, 1),
            &[
                -1.168318748474121,
                -0.4077277183532715,
                -1.8463939428329468,
                0.19176559150218964,
                -1.0662782192230225,
                -0.3948383331298828,
            ],
            tol,
        );
    }

    // Seeded regression test on a correlated Gaussian target.
    #[test]
    fn test_chain_3() {
        let target = DiffableGaussian2D::new([1.0_f64, 2.0], [[1.0, 2.0], [2.0, 5.0]]);
        let initial_positions = vec![-2.0_f64, 1.0];
        let n_discard = 5;
        let n_collect = 5;
        let mut sampler = NUTSChain::new(target, initial_positions, 0.8).set_seed(42);
        let sample: Tensor<BackendType, 2> = sampler.run(n_collect, n_discard);
        assert_eq!(sample.dims(), [n_collect, 2]);
        let tol = Tolerance::<f64>::default()
            .set_relative(1e-5)
            .set_absolute(1e-6);
        assert_tensor_approx_eq(
            sample.flatten(0, 1),
            &[
                2.653707265853882,
                5.560618877410889,
                2.9760334491729736,
                6.325948715209961,
                2.187873125076294,
                5.611990928649902,
                2.1512224674224854,
                5.416507720947266,
                2.4165120124816895,
                3.9120564460754395,
            ],
            tol,
        );
    }

    // Single-chain NUTS::run must reproduce the NUTSChain draws
    // (seed 41 + chain offset 1 == chain seed 42, matching test_chain_3).
    #[test]
    fn test_run_1() {
        let target = DiffableGaussian2D::new([1.0_f64, 2.0], [[1.0, 2.0], [2.0, 5.0]]);
        let initial_positions = vec![vec![-2_f64, 1.0]];
        let n_discard = 5;
        let n_collect = 5;
        let mut sampler = NUTS::new(target, initial_positions, 0.8).set_seed(41);
        let sample: Tensor<BackendType, 3> = sampler.run(n_collect, n_discard);
        assert_eq!(sample.dims(), [1, n_collect, 2]);
        let tol = Tolerance::<f64>::default()
            .set_relative(1e-5)
            .set_absolute(1e-6);
        assert_tensor_approx_eq(
            sample.flatten(0, 2),
            &[
                2.653707265853882,
                5.560618877410889,
                2.9760334491729736,
                6.325948715209961,
                2.187873125076294,
                5.611990928649902,
                2.1512224674224854,
                5.416507720947266,
                2.4165120124816895,
                3.9120564460754395,
            ],
            tol,
        );
    }

    // Smoke test for the progress-reporting path (shape check only).
    #[test]
    fn test_progress_1() {
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        let initial_positions = init::<f32>(6, 2);
        let n_collect = 10;
        let n_discard = 10;

        let mut sampler =
            NUTS::<_, BackendType, _>::new(target, initial_positions, 0.95).set_seed(42);
        let (sample, stats) = sampler.run_progress(n_collect, n_discard).unwrap();
        println!(
            "NUTS sampler: generated {} observations.",
            sample.dims()[0..2].iter().product::<usize>()
        );
        assert_eq!(sample.dims(), [6, n_collect, 2]);

        println!("Statistics: {stats}");

        #[cfg(feature = "csv")]
        save_csv_tensor(sample, "/tmp/nuts-sample.csv").expect("saving data should succeed")
    }

    // Benchmark: long run without progress reporting; prints Rhat/ESS.
    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_bench_noprogress_1() {
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        let initial_positions = init::<f32>(6, 2);
        let n_collect = 5000;
        let n_discard = 500;

        let mut sampler = NUTS::new(target, initial_positions, 0.95).set_seed(42);
        let mut timer = Timer::new();
        let sample: Tensor<BackendType, 3> = sampler.run(n_collect, n_discard);
        timer.log(format!(
            "NUTS sampler: generated {} observations.",
            sample.dims()[0..2].iter().product::<usize>()
        ));
        assert_eq!(sample.dims(), [6, 5000, 2]);

        let data = sample.to_data();
        let array = ArrayView3::from_shape(sample.dims(), data.as_slice().unwrap()).unwrap();
        let (split_rhat, ess) = split_rhat_mean_ess(array);
        println!("AVG Split Rhat: {}", split_rhat.mean().unwrap());
        println!("AVG ESS: {}", ess.mean().unwrap());

        #[cfg(feature = "csv")]
        save_csv_tensor(sample, "/tmp/nuts-sample.csv").expect("saving data should succeed")
    }

    // Benchmark variant with equal warm-up and collection lengths.
    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_bench_noprogress_2() {
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        let initial_positions = init::<f32>(6, 2);
        let n_collect = 1000;
        let n_discard = 1000;

        let mut sampler = NUTS::new(target, initial_positions, 0.95).set_seed(42);
        let mut timer = Timer::new();
        let sample: Tensor<BackendType, 3> = sampler.run(n_collect, n_discard);
        timer.log(format!(
            "NUTS sampler: generated {} observations.",
            sample.dims()[0..2].iter().product::<usize>()
        ));
        assert_eq!(sample.dims(), [6, 1000, 2]);

        let data = sample.to_data();
        let array = ArrayView3::from_shape(sample.dims(), data.as_slice().unwrap()).unwrap();
        let (split_rhat, ess) = split_rhat_mean_ess(array);
        println!("MIN Split Rhat: {}", split_rhat.min().unwrap());
        println!("MIN ESS: {}", ess.min().unwrap());

        #[cfg(feature = "csv")]
        save_csv_tensor(sample, "/tmp/nuts-sample.csv").expect("saving data should succeed")
    }
}