1use crate::distributions::BatchedGradientTarget;
13use crate::stats::MultiChainTracker;
14use crate::stats::RunStats;
15use burn::prelude::*;
16use burn::tensor::backend::AutodiffBackend;
17use burn::tensor::Tensor;
18use indicatif::{ProgressBar, ProgressStyle};
19use num_traits::Float;
20use rand::prelude::*;
21use rand::Rng;
22use rand_distr::{StandardNormal, StandardUniform};
23use std::error::Error;
24
/// Hamiltonian Monte Carlo (HMC) sampler that advances several chains in
/// parallel. All chain positions live in one `[n_chains, dim]` tensor so the
/// target's log-density and gradients are evaluated as a single batch.
#[derive(Debug, Clone)]
pub struct HMC<T, B, GTarget>
where
    B: AutodiffBackend,
{
    /// Target distribution providing batched unnormalized log-probabilities.
    pub target: GTarget,
    /// Leapfrog integrator step size.
    pub step_size: T,
    /// Number of leapfrog steps per proposal.
    pub n_leapfrog: usize,
    /// Current positions of all chains, shape `[n_chains, dim]`.
    pub positions: Tensor<B, 2>,

    // Cached `0.5 * step_size * grad(logp)` from the most recent gradient
    // evaluation; reused as the opening half-step of the next leapfrog pass.
    last_grad_summands: Tensor<B, 2>,

    /// Sampler-local RNG (reseedable via `set_seed`).
    pub rng: SmallRng,
}
58
59impl<T, B, GTarget> HMC<T, B, GTarget>
60where
61 T: Float
62 + burn::tensor::ElementConversion
63 + burn::tensor::Element
64 + rand_distr::uniform::SampleUniform
65 + num_traits::FromPrimitive,
66 B: AutodiffBackend,
67 GTarget: BatchedGradientTarget<T, B> + std::marker::Sync,
68 StandardNormal: rand::distr::Distribution<T>,
69 StandardUniform: rand_distr::Distribution<T>,
70{
71 pub fn new(
88 target: GTarget,
89 initial_positions: Vec<Vec<T>>,
90 step_size: T,
91 n_leapfrog: usize,
92 ) -> Self {
93 let (n_chains, dim) = (initial_positions.len(), initial_positions[0].len());
95 let td: TensorData = TensorData::new(
96 initial_positions.into_iter().flatten().collect(),
97 [n_chains, dim],
98 );
99 let positions = Tensor::<B, 2>::from_data(td, &B::Device::default());
100 let rng = SmallRng::seed_from_u64(rand::rng().random::<u64>());
101 Self {
102 target,
103 step_size,
104 n_leapfrog,
105 last_grad_summands: Tensor::<B, 2>::zeros_like(&positions),
106 positions,
107 rng,
108 }
109 }
110
111 pub fn set_seed(mut self, seed: u64) -> Self {
119 self.rng = SmallRng::seed_from_u64(seed);
120 self
121 }
122
123 pub fn run(&mut self, n_collect: usize, n_discard: usize) -> Tensor<B, 3> {
138 let (n_chains, dim) = (self.positions.dims()[0], self.positions.dims()[1]);
139 let mut out = Tensor::<B, 3>::empty(
140 [n_collect, n_chains, self.positions.dims()[1]],
141 &B::Device::default(),
142 );
143
144 (0..n_discard).for_each(|_| self.step());
146
147 for step in 1..(n_collect + 1) {
149 self.step();
150 out.inplace(|_out| {
151 _out.slice_assign(
152 [step - 1..step, 0..n_chains, 0..dim],
153 self.positions.clone().unsqueeze_dim(0),
154 )
155 });
156 }
157 out.permute([1, 0, 2])
158 }
159
160 pub fn run_progress(
223 &mut self,
224 n_collect: usize,
225 n_discard: usize,
226 ) -> Result<(Tensor<B, 3>, RunStats), Box<dyn Error>> {
227 (0..n_discard).for_each(|_| self.step());
229
230 let (n_chains, dim) = (self.positions.dims()[0], self.positions.dims()[1]);
231 let mut out = Tensor::<B, 3>::empty([n_collect, n_chains, dim], &B::Device::default());
232
233 let pb = ProgressBar::new(n_collect as u64);
234 pb.set_style(
235 ProgressStyle::default_bar()
236 .template("{prefix:8} {bar:40.cyan/blue} {pos}/{len} ({eta}) | {msg}")
237 .unwrap()
238 .progress_chars("=>-"),
239 );
240 pb.set_prefix("HMC");
241
242 let mut tracker = MultiChainTracker::new(n_chains, dim);
243
244 let mut last_state = self.positions.clone();
245
246 let mut last_state_data = last_state.to_data();
247 if let Err(e) = tracker.step(last_state_data.as_slice::<T>().unwrap()) {
248 eprintln!("Warning: Shown progress statistics may be unreliable since updating them failed with: {}", e);
249 }
250
251 for i in 0..n_collect {
252 self.step();
253 let current_state = self.positions.clone();
254
255 out.inplace(|_out| {
257 _out.slice_assign(
258 [i..i + 1, 0..n_chains, 0..dim],
259 current_state.clone().unsqueeze_dim(0),
260 )
261 });
262 pb.inc(1);
263 last_state = current_state;
264
265 last_state_data = last_state.to_data();
266 if let Err(e) = tracker.step(last_state_data.as_slice::<T>().unwrap()) {
267 eprintln!("Warning: Shown progress statistics may be unreliable since updating them failed with: {}", e);
268 }
269
270 match tracker.max_rhat() {
271 Err(e) => {
272 eprintln!("Computing max(rhat) failed with: {}", e);
273 }
274 Ok(max_rhat) => {
275 pb.set_message(format!(
276 "p(accept)≈{:.2} max(rhat)≈{:.2}",
277 tracker.p_accept, max_rhat
278 ));
279 }
280 }
281 }
282 pb.finish_with_message("Done!");
283 let sample = out.permute([1, 0, 2]);
284
285 let stats = match tracker.stats(sample.clone()) {
286 Ok(stats) => stats,
287 Err(e) => {
288 eprintln!("Getting run statistics failed with: {}", e);
289 return Err(e);
290 }
291 };
292
293 Ok((sample, stats))
294 }
295
296 pub fn step(&mut self) {
305 let shape = self.positions.shape();
306 let (n_chains, dim) = (shape.dims[0], shape.dims[1]);
307
308 let momentum_0 = Tensor::<B, 2>::random(
310 Shape::new([n_chains, dim]),
311 burn::tensor::Distribution::Normal(0., 1.),
312 &B::Device::default(),
313 );
314
315 let pos = self.positions.clone().detach().require_grad();
318 let logp_current = self.target.unnorm_logp_batch(pos.clone());
319
320 let grads = pos.grad(&logp_current.backward()).unwrap();
323 let grad_summands =
324 Tensor::<B, 2>::from_inner(grads.mul_scalar(self.step_size * T::from(0.5).unwrap()));
325 self.last_grad_summands = grad_summands;
326
327 let ke_current = momentum_0
329 .clone()
330 .powf_scalar(2.0)
331 .sum_dim(1) .squeeze(1)
333 .mul_scalar(T::from(0.5).unwrap());
334
335 let h_current: Tensor<B, 1> = -logp_current + ke_current;
337
338 let (proposed_positions, proposed_momenta, logp_proposed) =
340 self.leapfrog(self.positions.clone(), momentum_0);
341
342 let ke_proposed = proposed_momenta
344 .powf_scalar(2.0)
345 .sum_dim(1)
346 .squeeze(1)
347 .mul_scalar(T::from(0.5).unwrap());
348
349 let h_proposed = -logp_proposed + ke_proposed;
350
351 let accept_logp = h_current.sub(h_proposed);
353
354 let mut uniform_data = Vec::with_capacity(n_chains);
356 for _ in 0..n_chains {
357 uniform_data.push(self.rng.random::<T>());
358 }
359 let uniform = Tensor::<B, 1>::random(
360 Shape::new([n_chains]),
361 burn::tensor::Distribution::Default,
362 &B::Device::default(),
363 );
364
365 let ln_u = uniform.log(); let accept_mask = accept_logp.greater_equal(ln_u); let mut accept_mask_big: Tensor<B, 2, Bool> = accept_mask.clone().unsqueeze_dim(1);
369 accept_mask_big = accept_mask_big.expand([n_chains, dim]);
370
371 self.positions.inplace(|x| {
373 x.clone()
374 .mask_where(accept_mask_big, proposed_positions)
375 .detach()
376 });
377 }
378
379 fn leapfrog(
398 &mut self,
399 mut pos: Tensor<B, 2>,
400 mut mom: Tensor<B, 2>,
401 ) -> (Tensor<B, 2>, Tensor<B, 2>, Tensor<B, 1>) {
402 let half = T::from(0.5).unwrap();
403 for _step_i in 0..self.n_leapfrog {
404 pos = pos.detach().require_grad();
406
407 mom.inplace(|_mom| _mom.add(self.last_grad_summands.clone()));
409
410 pos.inplace(|_pos| {
412 _pos.add(mom.clone().mul_scalar(self.step_size))
413 .detach()
414 .require_grad()
415 });
416
417 let logp = self.target.unnorm_logp_batch(pos.clone());
419 let grads = pos.grad(&logp.backward()).unwrap();
420 let grad_summands = Tensor::<B, 2>::from_inner(grads.mul_scalar(self.step_size * half));
421
422 mom.inplace(|_mom| _mom.add(grad_summands.clone()));
424
425 self.last_grad_summands = grad_summands;
426 }
427
428 let logp_final = self.target.unnorm_logp_batch(pos.clone());
430 (pos.detach(), mom.detach(), logp_final.detach())
431 }
432}
433
#[cfg(test)]
mod tests {
    use crate::{
        core::init,
        dev_tools::Timer,
        distributions::{DiffableGaussian2D, Rosenbrock2D, RosenbrockND},
        stats::split_rhat_mean_ess,
    };
    use ndarray::ArrayView3;
    use ndarray_stats::QuantileExt;

    use super::*;
    use burn::{
        backend::{Autodiff, NdArray},
        tensor::Tensor,
    };

    type BackendType = Autodiff<NdArray>;

    // Smoke test: one chain on the 2D Rosenbrock target yields a sample of
    // shape [n_chains, n_collect, dim] = [1, 3, 2].
    #[test]
    fn test_hmc_single() {
        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        let initial_positions = vec![vec![0.0_f32, 0.0]];
        let n_collect = 3;

        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, 2, )
        .set_seed(42);

        let mut timer = Timer::new();
        let sample: Tensor<BackendType, 3> = sampler.run(n_collect, 0);
        // NOTE(review): log message says "10 chains" but this test runs a
        // single chain; string left untouched here.
        timer.log(format!(
            "Collected sample (10 chains) with shape: {:?}",
            sample.dims()
        ));
        assert_eq!(sample.dims(), [1, 3, 2]);
    }

    // Three chains, ten collected draws each: checks output shape [3, 10, 2].
    #[test]
    fn test_3_chains() {
        type BackendType = Autodiff<NdArray>;

        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        let initial_positions = vec![vec![1.0_f32, 2.0_f32]; 3];
        let n_collect = 10;

        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, 2, )
        .set_seed(42);

        let mut timer = Timer::new();
        let sample: Tensor<BackendType, 3> = sampler.run(n_collect, 0);
        timer.log(format!(
            "Collected sample (3 chains) with shape: {:?}",
            sample.dims()
        ));
        assert_eq!(sample.dims(), [3, 10, 2]);
    }

    // Same as `test_3_chains` but exercising `run_progress` with burn-in.
    #[test]
    fn test_progress_3_chains() {
        type BackendType = Autodiff<NdArray>;

        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        let initial_positions = vec![vec![1.0_f32, 2.0_f32]; 3];
        let n_collect = 10;

        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.05, 2, )
        .set_seed(42);

        let mut timer = Timer::new();
        let sample: Tensor<BackendType, 3> = sampler.run_progress(n_collect, 3).unwrap().0;
        // NOTE(review): message says "10 chains" but this test runs 3 chains.
        timer.log(format!(
            "Collected sample (10 chains) with shape: {:?}",
            sample.dims()
        ));
        assert_eq!(sample.dims(), [3, 10, 2]);
    }

    // Minimal end-to-end run against a correlated 2D Gaussian target.
    #[test]
    fn test_gaussian_2d_hmc_debug() {
        let n_chains = 1;
        let n_discard = 1;
        let n_collect = 1;

        let target = DiffableGaussian2D::new([0.0, 1.0], [[4.0, 2.0], [2.0, 3.0]]);
        let initial_positions = vec![vec![0.0_f32, 0.0_f32]];

        type BackendType = Autodiff<NdArray>;
        let mut sampler = HMC::<f32, BackendType, DiffableGaussian2D<f32>>::new(
            target,
            initial_positions,
            0.1,
            1,
        )
        .set_seed(42);

        let sample_3d = sampler.run(n_collect, n_discard);

        assert_eq!(sample_3d.dims(), [n_chains, n_collect, 2]);
    }

    // Benchmark: one longer Gaussian run; asserts the effective sample size
    // for both parameters clears a minimum bar.
    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_gaussian_2d_hmc_single_run() {
        let n_chains = 3;

        let n_discard = 500;
        let n_collect = 1000;

        let target = DiffableGaussian2D::new([0.0, 1.0], [[4.0, 2.0], [2.0, 3.0]]);

        let initial_positions = vec![
            vec![1.0_f32, 2.0_f32],
            vec![1.0_f32, 2.0_f32],
            vec![1.0_f32, 2.0_f32],
        ];

        type BackendType = Autodiff<NdArray>;
        let mut sampler = HMC::<f32, BackendType, DiffableGaussian2D<f32>>::new(
            target,
            initial_positions,
            0.1, 10, )
        .set_seed(42);

        let sample_3d = sampler.run(n_collect, n_discard);

        assert_eq!(sample_3d.dims(), [n_chains, n_collect, 2]);

        let data = sample_3d.to_data();
        let arr = ArrayView3::from_shape(sample_3d.dims(), data.as_slice().unwrap()).unwrap();

        let (rhat, ess_vals) = split_rhat_mean_ess(arr.view());
        let ess1 = ess_vals[0];
        let ess2 = ess_vals[1];

        println!("\nSingle Run Results:");
        println!("Rhat: {:?}", rhat);
        println!("ESS(Param1): {:.2}", ess1);
        println!("ESS(Param2): {:.2}", ess2);

        assert!(ess1 > 50.0, "Expected param1 ESS > 50, got {:.2}", ess1);
        assert!(ess2 > 50.0, "Expected param2 ESS > 50, got {:.2}", ess2);
    }

    // Benchmark: repeats the Gaussian run many times with random starting
    // points and aggregates ESS / R-hat statistics, asserting they land in
    // empirically determined ranges.
    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_gaussian_2d_hmc_ess_stats() {
        use crate::stats::basic_stats;
        use indicatif::{ProgressBar, ProgressStyle};
        use ndarray::Array1;

        let n_runs = 100;
        let n_chains = 3;
        let n_discard = 500;
        let n_collect = 1000;
        let mut rng = SmallRng::seed_from_u64(42);

        let mut ess_param1s = Vec::with_capacity(n_runs);
        let mut ess_param2s = Vec::with_capacity(n_runs);
        let mut rhat_param1s = Vec::with_capacity(n_runs);
        let mut rhat_param2s = Vec::with_capacity(n_runs);

        let pb = ProgressBar::new(n_runs as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("{prefix:8} {bar:40.cyan/blue} {pos}/{len} ({eta}) | {msg}")
                .unwrap()
                .progress_chars("=>-"),
        );
        pb.set_prefix("HMC Test");

        for run in 0..n_runs {
            let target = DiffableGaussian2D::new([0.0_f32, 1.0], [[4.0, 2.0], [2.0, 3.0]]);

            // Random standard-normal starting point per chain.
            let initial_positions: Vec<Vec<f32>> = (0..n_chains)
                .map(|_| {
                    vec![
                        rng.sample::<f32, _>(StandardNormal),
                        rng.sample::<f32, _>(StandardNormal),
                    ]
                })
                .collect();

            type BackendType = Autodiff<NdArray>;
            let mut sampler = HMC::<f32, BackendType, DiffableGaussian2D<f32>>::new(
                target,
                initial_positions,
                0.1, 10, )
            .set_seed(run as u64);
            let sample_3d = sampler.run(n_collect, n_discard);

            assert_eq!(sample_3d.dims(), [n_chains, n_collect, 2]);

            let data = sample_3d.to_data();
            let arr = ArrayView3::from_shape(sample_3d.dims(), data.as_slice().unwrap()).unwrap();

            let (rhat, ess_vals) = split_rhat_mean_ess(arr.view());
            let ess1 = ess_vals[0];
            let ess2 = ess_vals[1];

            ess_param1s.push(ess1);
            ess_param2s.push(ess2);

            rhat_param1s.push(rhat[0]);
            rhat_param2s.push(rhat[1]);

            pb.inc(1);

            // Show running mean ± std of the ESS estimates in the bar.
            if run > 0 {
                let mean_ess1 = ess_param1s.iter().sum::<f32>() / (run as f32 + 1.0);
                let mean_ess2 = ess_param2s.iter().sum::<f32>() / (run as f32 + 1.0);

                let var_ess1 = ess_param1s
                    .iter()
                    .map(|&x| (x - mean_ess1).powi(2))
                    .sum::<f32>()
                    / (run as f32 + 1.0);
                let var_ess2 = ess_param2s
                    .iter()
                    .map(|&x| (x - mean_ess2).powi(2))
                    .sum::<f32>()
                    / (run as f32 + 1.0);

                let std_ess1 = var_ess1.sqrt();
                let std_ess2 = var_ess2.sqrt();

                pb.set_message(format!(
                    "ESS1={:.0}±{:.0} ESS2={:.0}±{:.0}",
                    mean_ess1, std_ess1, mean_ess2, std_ess2
                ));
            } else {
                pb.set_message(format!("ESS1={:.0} ESS2={:.0}", ess1, ess2));
            }
        }
        pb.finish_with_message("All runs complete!");

        let ess_param1_array = Array1::from_vec(ess_param1s);
        let ess_param2_array = Array1::from_vec(ess_param2s);
        let rhat_param1_array = Array1::from_vec(rhat_param1s);
        let rhat_param2_array = Array1::from_vec(rhat_param2s);

        let stats_p1_ess = basic_stats("ESS(Param1)", ess_param1_array);
        let stats_p2_ess = basic_stats("ESS(Param2)", ess_param2_array);
        let stats_p1_rhat = basic_stats("R-hat(Param1)", rhat_param1_array);
        let stats_p2_rhat = basic_stats("R-hat(Param2)", rhat_param2_array);

        println!("\nStatistics over {} runs:", n_runs);
        println!("\nESS Statistics:");
        println!("{stats_p1_ess}\n{stats_p2_ess}");
        println!("\nR-hat Statistics:");
        println!("{stats_p1_rhat}\n{stats_p2_rhat}");

        assert!(
            (135.0..=185.0).contains(&stats_p1_ess.mean),
            "Expected param1 ESS to average in [135, 185], got {:.2}",
            stats_p1_ess.mean
        );
        assert!(
            (141.0..=191.0).contains(&stats_p2_ess.mean),
            "Expected param2 ESS to average in [141, 191], got {:.2}",
            stats_p2_ess.mean
        );

        assert!(
            (0.95..=1.05).contains(&stats_p1_rhat.mean),
            "Expected param1 R-hat to be in [0.95, 1.05], got {:.2}",
            stats_p1_rhat.mean
        );
        assert!(
            (0.95..=1.05).contains(&stats_p2_rhat.mean),
            "Expected param2 R-hat to be in [0.95, 1.05], got {:.2}",
            stats_p2_rhat.mean
        );
    }

    // Benchmark: 6-chain Rosenbrock run without progress reporting; prints
    // convergence diagnostics.
    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_bench_noprogress() {
        type BackendType = Autodiff<burn::backend::NdArray>;

        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        let initial_positions = init(6, 2);
        let n_collect = 5000;
        let n_discard = 500;

        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, 50, )
        .set_seed(42);

        let mut timer = Timer::new();
        let sample = sampler.run(n_collect, n_discard);
        timer.log(format!(
            "HMC sampler: generated {} observations.",
            sample.dims()[0..2].iter().product::<usize>()
        ));
        assert_eq!(sample.dims(), [6, 5000, 2]);

        let data = sample.to_data();
        let array = ArrayView3::from_shape(sample.dims(), data.as_slice().unwrap()).unwrap();
        let (split_rhat, ess) = split_rhat_mean_ess(array);
        println!("MIN Split Rhat: {}", split_rhat.min().unwrap());
        println!("MIN ESS: {}", ess.min().unwrap());
    }

    // Benchmark: 6-chain Rosenbrock run with progress reporting; optionally
    // dumps the sample to CSV when the `csv` feature is enabled.
    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_progress_bench() {
        type BackendType = Autodiff<burn::backend::NdArray>;
        BackendType::seed(42);

        let target = Rosenbrock2D {
            a: 1.0_f32,
            b: 100.0_f32,
        };

        let n_chains = 6;
        let initial_positions = vec![vec![1.0_f32, 2.0_f32]; n_chains];
        let n_collect = 1000;
        let n_discard = 1000;

        let mut sampler = HMC::<f32, BackendType, Rosenbrock2D<f32>>::new(
            target,
            initial_positions,
            0.01, 50, )
        .set_seed(42);

        let mut timer = Timer::new();
        let sample = sampler.run_progress(n_collect, n_discard).unwrap().0;
        timer.log(format!(
            "HMC sampler: generated {} observations.",
            sample.dims()[0..2].iter().product::<usize>()
        ));
        println!(
            "Chain 1, first 10: {}",
            sample.clone().slice([0..1, 0..10, 0..1])
        );
        // NOTE(review): this slice selects chain index 2 (the third chain)
        // despite the "Chain 2" label.
        println!(
            "Chain 2, first 10: {}",
            sample.clone().slice([2..3, 0..10, 0..1])
        );

        #[cfg(feature = "csv")]
        crate::io::csv::save_csv_tensor(sample.clone(), "/tmp/hmc-sample.csv")
            .expect("Expected saving to succeed");

        assert_eq!(sample.dims(), [n_chains, n_collect, 2]);
    }

    // Benchmark: high-dimensional (10000-D) Rosenbrock on the CPU backend.
    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_bench_10000d() {
        type BackendType = Autodiff<burn::backend::NdArray>;

        let seed = 42;
        let d = 10000;
        let n_chains = 6;
        let n_collect = 100;
        let n_discard = 100;

        // NOTE(review): vec![...; n_chains] clones one draw, so every chain
        // shares the same starting point.
        let rng = SmallRng::seed_from_u64(seed);
        let initial_positions: Vec<Vec<f32>> =
            vec![rng.sample_iter(StandardNormal).take(d).collect(); n_chains];

        let mut sampler = HMC::<f32, BackendType, RosenbrockND>::new(
            RosenbrockND {},
            initial_positions,
            0.01, 50, )
        .set_seed(42);

        let mut timer = Timer::new();
        let sample = sampler.run(n_collect, n_discard);
        timer.log(format!(
            "HMC sampler: generated {} observations.",
            sample.dims()[0..2].iter().product::<usize>()
        ));
        assert_eq!(sample.dims(), [n_chains, n_collect, d]);
    }

    // Benchmark: 10000-D Rosenbrock on the WGPU backend with progress output.
    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    #[cfg(feature = "wgpu")]
    fn test_progress_10000d_bench() {
        type BackendType = Autodiff<burn::backend::Wgpu>;

        let seed = 42;
        let d = 10000;
        let n_chains = 6;

        // NOTE(review): as in `test_bench_10000d`, all chains share one
        // cloned starting point.
        let rng = SmallRng::seed_from_u64(seed);
        let initial_positions: Vec<Vec<f32>> =
            vec![rng.sample_iter(StandardNormal).take(d).collect(); n_chains];
        let n_collect = 100;
        let n_discard = 100;

        let mut sampler = HMC::<f32, BackendType, RosenbrockND>::new(
            RosenbrockND {},
            initial_positions,
            0.01, 50, )
        .set_seed(42);

        let mut timer = Timer::new();
        let sample = sampler.run_progress(n_collect, n_discard).unwrap().0;
        timer.log(format!(
            "HMC sampler: generated {} observations.",
            sample.dims()[0..2].iter().product::<usize>()
        ));
        assert_eq!(sample.dims(), [n_chains, n_collect, d]);
    }
}