1use crate::euclidean::EuclideanVector;
4use crate::generic_hmc::HamiltonianTarget;
5use crate::stats::{ChainStats, ChainTracker, RunStats, collect_rhat, max_skipnan};
6use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
7use ndarray::{Array2, Array3, ArrayView1, ArrayView2, Axis, s};
8use num_traits::{Float, FromPrimitive, One, ToPrimitive, Zero};
9use rand::distr::Distribution as RandDistribution;
10use rand::rngs::SmallRng;
12use rand::{Rng, SeedableRng};
13use rand_distr::{Exp1, StandardNormal, StandardUniform};
14use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
15use std::error::Error;
16use std::sync::Arc;
17use std::sync::mpsc;
18use std::sync::mpsc::{Receiver, Sender};
19use std::thread;
20use std::time::{Duration, Instant};
21
/// Multi-chain sampler that owns one [`GenericNUTSChain`] per initial position.
///
/// The chains are created by `GenericNUTS::new` / `GenericNUTS::new_with_mass_matrix`
/// and all share the same target through an `Arc`.
pub struct GenericNUTS<V, Target>
where
    V: EuclideanVector,
    Target: HamiltonianTarget<V>,
{
    /// One chain per initial position supplied at construction, in the same order.
    chains: Vec<GenericNUTSChain<V, Target>>,
}
30
/// Outcome of a progress-reporting run: the samples of all chains stacked along axis 0
/// together with the [`RunStats`] computed from them, or a boxed error.
type RunResult<T> = Result<(Array3<T>, RunStats), Box<dyn Error>>;
32
/// Which kind of mass matrix, if any, is estimated during warmup.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MassMatrixAdaptation {
    /// No adaptation: the chain keeps the identity mass matrix it starts with.
    None,
    /// Estimate a diagonal matrix from per-coordinate running variances.
    Diagonal,
    /// Estimate a full matrix from the running covariance of the positions.
    Dense,
}
40
/// Tuning knobs for the warmup mass-matrix adaptation.
#[derive(Clone, Debug)]
pub struct NUTSMassMatrixConfig {
    /// Kind of mass matrix to estimate during warmup.
    pub adaptation: MassMatrixAdaptation,
    /// Number of initial warmup iterations whose positions are not collected;
    /// collection starts once the iteration counter exceeds this value.
    pub start_buffer: usize,
    /// Collection stops this many iterations before the end of warmup.
    pub end_buffer: usize,
    /// Length of the first adaptation window (values below 10 are raised to 10).
    pub initial_window: usize,
    /// Blend weight `r` in `(1 - r) * estimate + r` applied to estimated variances
    /// (off-diagonal covariances are only scaled by `1 - r`).
    /// NOTE(review): presumably expected to lie in `[0, 1]` — not enforced anywhere visible.
    pub regularize: f64,
    /// Lower bound for estimated variances and initial diagonal jitter when factorizing a
    /// dense covariance; values below `1e-10` are raised to `1e-10` when used.
    pub jitter: f64,
    /// Largest dimension for which `Dense` statistics are accumulated; for larger problems
    /// the chain only accumulates diagonal statistics.
    pub dense_max_dim: usize,
}
52
53impl NUTSMassMatrixConfig {
54 pub fn disabled() -> Self {
55 Self {
56 adaptation: MassMatrixAdaptation::None,
57 start_buffer: 0,
58 end_buffer: 0,
59 initial_window: 0,
60 regularize: 0.0,
61 jitter: 0.0,
62 dense_max_dim: 0,
63 }
64 }
65}
66
67impl Default for NUTSMassMatrixConfig {
68 fn default() -> Self {
69 Self {
70 adaptation: MassMatrixAdaptation::Diagonal,
71 start_buffer: 75,
72 end_buffer: 50,
73 initial_window: 25,
74 regularize: 0.05,
75 jitter: 1e-6,
76 dense_max_dim: 75,
77 }
78 }
79}
80
/// Streaming accumulator for the mean and the (co)variance sums of a sequence of vectors.
struct RunningCov<S: Float> {
    /// Dimension of the vectors fed to `update`.
    dim: usize,
    /// Number of vectors accumulated since construction or the last `reset`.
    n: usize,
    /// Running mean, one entry per coordinate.
    mean: Vec<S>,
    /// Per-coordinate running sum of squared deviations (divide by `n - 1` for a variance).
    m2_diag: Vec<S>,
    /// Optional row-major `dim * dim` buffer of cross-deviation sums; `update` only writes
    /// the entries with column index >= row index. `None` when dense tracking is disabled.
    m2_dense: Option<Vec<S>>,
}
88
89impl<S: Float + FromPrimitive> RunningCov<S> {
90 fn new(dim: usize, dense: bool) -> Self {
91 Self {
92 dim,
93 n: 0,
94 mean: vec![S::zero(); dim],
95 m2_diag: vec![S::zero(); dim],
96 m2_dense: dense.then(|| vec![S::zero(); dim * dim]),
97 }
98 }
99
100 fn reset(&mut self) {
101 self.n = 0;
102 self.mean.fill(S::zero());
103 self.m2_diag.fill(S::zero());
104 if let Some(m2) = self.m2_dense.as_mut() {
105 m2.fill(S::zero());
106 }
107 }
108
109 fn update(&mut self, x: &[S]) {
110 self.n += 1;
111 let n_s = S::from_usize(self.n).unwrap();
112 let mut delta = vec![S::zero(); self.dim];
113 for i in 0..self.dim {
114 delta[i] = x[i] - self.mean[i];
115 self.mean[i] = self.mean[i] + delta[i] / n_s;
116 let delta2 = x[i] - self.mean[i];
117 self.m2_diag[i] = self.m2_diag[i] + delta[i] * delta2;
118 }
119 if let Some(m2) = self.m2_dense.as_mut() {
120 let mut delta2 = vec![S::zero(); self.dim];
121 for i in 0..self.dim {
122 delta2[i] = x[i] - self.mean[i];
123 }
124 for i in 0..self.dim {
125 for j in i..self.dim {
126 let idx = i * self.dim + j;
127 m2[idx] = m2[idx] + delta[i] * delta2[j];
128 }
129 }
130 }
131 }
132}
133
/// Bookkeeping for mass-matrix adaptation windows during warmup.
struct MassMatrixWarmup<S: Float> {
    /// Adaptation settings this schedule was created with.
    config: NUTSMassMatrixConfig,
    /// Iteration index at (or after) which the current window is considered complete.
    next_window_end: usize,
    /// Amount by which `next_window_end` advances at the next window end; doubled after
    /// every window and capped at 400.
    window_len: usize,
    /// Position statistics accumulated for the current window.
    running: RunningCov<S>,
}
140
141impl<S: Float + FromPrimitive> MassMatrixWarmup<S> {
142 fn new(dim: usize, config: NUTSMassMatrixConfig, dense: bool) -> Self {
143 let start_buffer = config.start_buffer.max(1);
144 let window_len = config.initial_window.max(10);
145 Self {
146 config,
147 next_window_end: start_buffer + window_len,
148 window_len,
149 running: RunningCov::new(dim, dense),
150 }
151 }
152
153 fn should_collect(&self, m: usize, n_warmup: usize) -> bool {
154 if m == 0 || m > n_warmup {
155 return false;
156 }
157 if m <= self.config.start_buffer {
158 return false;
159 }
160 m < n_warmup.saturating_sub(self.config.end_buffer)
161 }
162
163 fn note_if_window_end(&mut self, m: usize, n_warmup: usize) -> bool {
164 if !self.should_collect(m, n_warmup) {
165 return false;
166 }
167 if m >= self.next_window_end || m + 1 >= n_warmup.saturating_sub(self.config.end_buffer) {
168 self.next_window_end = self.next_window_end.saturating_add(self.window_len);
169 self.window_len = (self.window_len.saturating_mul(2)).min(400);
170 return true;
171 }
172 false
173 }
174}
175
/// Mass matrix (metric) representation used for kinetic energy and momentum sampling.
#[derive(Clone)]
enum MassMatrix<S: Float> {
    /// Identity matrix of the given dimension.
    Identity {
        dim: usize,
    },
    /// Diagonal matrix.
    Diagonal {
        /// Diagonal of the inverse mass matrix (used by `kinetic` and `inv_mul`).
        inv: Vec<S>,
        /// Per-coordinate factor multiplied into standard-normal draws by `sample_momentum`.
        sqrt: Vec<S>,
    },
    /// Full matrix, stored row-major.
    Dense {
        dim: usize,
        /// Row-major `dim * dim` inverse mass matrix (used by `kinetic` and `inv_mul`).
        inv: Vec<S>,
        /// Row-major lower-triangular factor applied to standard-normal draws by
        /// `sample_momentum` (only entries with column <= row are read).
        chol: Vec<S>,
    },
}
191
192impl<S: Float + FromPrimitive> MassMatrix<S> {
193 fn identity(dim: usize) -> Self {
194 Self::Identity { dim }
195 }
196
197 fn diagonal_from_var(mut var: Vec<S>, jitter: S) -> Self {
198 let mut inv = vec![S::zero(); var.len()];
199 let mut sqrt = vec![S::zero(); var.len()];
200 for i in 0..var.len() {
201 let v = var[i].max(jitter);
202 var[i] = v;
203 inv[i] = S::one() / v;
204 sqrt[i] = v.sqrt();
205 }
206 Self::Diagonal { inv, sqrt }
207 }
208
209 fn dense_from_cov(cov: Vec<S>, dim: usize, jitter: S) -> Option<Self> {
210 let max_tries = 8usize;
211 let mut j = jitter.max(S::from_f64(1e-10).unwrap());
212 for _ in 0..max_tries {
213 let mut cov_try = cov.clone();
214 for d in 0..dim {
215 cov_try[d * dim + d] = cov_try[d * dim + d] + j;
216 }
217 if let Some(chol) = cholesky_spd(&cov_try, dim)
218 && let Some(inv) = invert_spd_from_cholesky(&chol, dim)
219 {
220 return Some(Self::Dense { dim, inv, chol });
221 }
222 j = j * S::from_f64(10.0).unwrap();
223 }
224 None
225 }
226
227 fn kinetic(&self, momentum: &[S]) -> S {
228 let half = S::from_f64(0.5).unwrap();
229 match self {
230 Self::Identity { .. } => {
231 let mut q = S::zero();
232 for v in momentum {
233 q = q + *v * *v;
234 }
235 half * q
236 }
237 Self::Diagonal { inv, .. } => {
238 let mut q = S::zero();
239 for i in 0..momentum.len() {
240 q = q + momentum[i] * momentum[i] * inv[i];
241 }
242 half * q
243 }
244 Self::Dense { inv, dim, .. } => {
245 let mut q = S::zero();
246 for i in 0..*dim {
247 let mut row_dot = S::zero();
248 for j in 0..*dim {
249 row_dot = row_dot + inv[i * *dim + j] * momentum[j];
250 }
251 q = q + momentum[i] * row_dot;
252 }
253 half * q
254 }
255 }
256 }
257
258 fn inv_mul(&self, input: &[S], out: &mut [S]) {
259 match self {
260 Self::Identity { .. } => out.copy_from_slice(input),
261 Self::Diagonal { inv, .. } => {
262 for i in 0..input.len() {
263 out[i] = inv[i] * input[i];
264 }
265 }
266 Self::Dense { inv, dim, .. } => {
267 for i in 0..*dim {
268 let mut acc = S::zero();
269 for j in 0..*dim {
270 acc = acc + inv[i * *dim + j] * input[j];
271 }
272 out[i] = acc;
273 }
274 }
275 }
276 }
277
278 fn sample_momentum(&self, rng: &mut SmallRng, out: &mut [S])
279 where
280 StandardNormal: RandDistribution<S>,
281 {
282 for v in out.iter_mut() {
283 *v = rng.sample(StandardNormal);
284 }
285 match self {
286 Self::Identity { .. } => {}
287 Self::Diagonal { sqrt, .. } => {
288 for i in 0..out.len() {
289 out[i] = out[i] * sqrt[i];
290 }
291 }
292 Self::Dense { chol, dim, .. } => {
293 let z = out.to_vec();
294 for i in 0..*dim {
295 let mut acc = S::zero();
296 for j in 0..=i {
297 acc = acc + chol[i * *dim + j] * z[j];
298 }
299 out[i] = acc;
300 }
301 }
302 }
303 }
304}
305
306fn cholesky_spd<S: Float + FromPrimitive>(a: &[S], dim: usize) -> Option<Vec<S>> {
307 let mut l = vec![S::zero(); dim * dim];
308 for i in 0..dim {
309 for j in 0..=i {
310 let mut sum = a[i * dim + j];
311 for k in 0..j {
312 sum = sum - l[i * dim + k] * l[j * dim + k];
313 }
314 if i == j {
315 if sum <= S::zero() || !sum.is_finite() {
316 return None;
317 }
318 l[i * dim + j] = sum.sqrt();
319 } else {
320 let d = l[j * dim + j];
321 if d <= S::zero() || !d.is_finite() {
322 return None;
323 }
324 l[i * dim + j] = sum / d;
325 }
326 }
327 }
328 Some(l)
329}
330
331fn invert_spd_from_cholesky<S: Float + FromPrimitive>(l: &[S], dim: usize) -> Option<Vec<S>> {
332 let mut inv_l = vec![S::zero(); dim * dim];
333 for i in 0..dim {
334 let d = l[i * dim + i];
335 if d <= S::zero() || !d.is_finite() {
336 return None;
337 }
338 inv_l[i * dim + i] = S::one() / d;
339 for j in (i + 1)..dim {
340 let mut sum = S::zero();
341 for k in i..j {
342 sum = sum + l[j * dim + k] * inv_l[k * dim + i];
343 }
344 inv_l[j * dim + i] = -sum / l[j * dim + j];
345 }
346 }
347 let mut inv = vec![S::zero(); dim * dim];
348 for i in 0..dim {
349 for j in 0..=i {
350 let mut sum = S::zero();
351 for k in i.max(j)..dim {
352 sum = sum + inv_l[k * dim + i] * inv_l[k * dim + j];
353 }
354 inv[i * dim + j] = sum;
355 inv[j * dim + i] = sum;
356 }
357 }
358 Some(inv)
359}
360
361impl<V, Target> GenericNUTS<V, Target>
362where
363 V: EuclideanVector + Send,
364 V::Scalar: Float + FromPrimitive + ToPrimitive + Send,
365 Target: HamiltonianTarget<V> + Sync + Send,
366 StandardNormal: RandDistribution<V::Scalar>,
367 StandardUniform: RandDistribution<V::Scalar>,
368 Exp1: RandDistribution<V::Scalar>,
369{
370 pub fn new(target: Target, initial_positions: Vec<V>, target_accept_p: V::Scalar) -> Self {
371 Self::new_with_mass_matrix(
372 target,
373 initial_positions,
374 target_accept_p,
375 NUTSMassMatrixConfig::disabled(),
376 )
377 }
378
379 pub fn new_with_mass_matrix(
380 target: Target,
381 initial_positions: Vec<V>,
382 target_accept_p: V::Scalar,
383 mass_config: NUTSMassMatrixConfig,
384 ) -> Self {
385 let target = Arc::new(target);
386 let chains = initial_positions
387 .into_iter()
388 .map(|pos| {
389 GenericNUTSChain::new_shared(
390 Arc::clone(&target),
391 pos,
392 target_accept_p,
393 mass_config.clone(),
394 )
395 })
396 .collect();
397 Self { chains }
398 }
399
400 pub(crate) fn chains_mut(&mut self) -> &mut [GenericNUTSChain<V, Target>] {
401 &mut self.chains
402 }
403
404 pub fn run(&mut self, n_collect: usize, n_discard: usize) -> Array3<V::Scalar> {
405 let chain_samples: Vec<Array2<V::Scalar>> = self
406 .chains
407 .par_iter_mut()
408 .map(|chain| chain.run(n_collect, n_discard))
409 .collect();
410 let views: Vec<ArrayView2<V::Scalar>> = chain_samples.iter().map(|s| s.view()).collect();
411 ndarray::stack(Axis(0), &views).expect("expected stacking chain samples to succeed")
412 }
413
414 pub fn run_progress(&mut self, n_collect: usize, n_discard: usize) -> RunResult<V::Scalar> {
415 let chains = &mut self.chains;
416
417 let mut rxs: Vec<Receiver<ChainStats>> = vec![];
418 let mut txs: Vec<Sender<ChainStats>> = vec![];
419 (0..chains.len()).for_each(|_| {
420 let (tx, rx) = mpsc::channel();
421 rxs.push(rx);
422 txs.push(tx);
423 });
424
425 let progress_handle = thread::spawn(move || {
426 let sleep_ms = Duration::from_millis(250);
427 let timeout_ms = Duration::from_millis(0);
428 let multi = MultiProgress::new();
429
430 let pb_style = ProgressStyle::default_bar()
431 .template("{prefix:8} {bar:40.cyan/blue} {pos}/{len} ({eta}) | {msg}")
432 .unwrap()
433 .progress_chars("=>-");
434 let total: u64 = (n_collect + n_discard).try_into().unwrap();
435
436 let global_pb = multi.add(ProgressBar::new((rxs.len() as u64) * total));
437 global_pb.set_style(pb_style.clone());
438 global_pb.set_prefix("Global");
439
440 let mut active: Vec<(usize, ProgressBar)> = (0..rxs.len().min(5))
441 .map(|chain_idx| {
442 let pb = multi.add(ProgressBar::new(total));
443 pb.set_style(pb_style.clone());
444 pb.set_prefix(format!("Chain {chain_idx}"));
445 (chain_idx, pb)
446 })
447 .collect();
448 let mut next_active = active.len();
449 let mut n_finished = 0;
450 let mut most_recent = vec![None; rxs.len()];
451
452 loop {
453 for (i, rx) in rxs.iter().enumerate() {
454 while let Ok(stats) = rx.recv_timeout(timeout_ms) {
455 most_recent[i] = Some(stats)
456 }
457 }
458
459 let mut to_replace = vec![false; active.len()];
460 let mut avg_p_accept = 0.0;
461 let mut n_available_stats = 0.0;
462 for (vec_idx, (i, pb)) in active.iter().enumerate() {
463 if let Some(stats) = &most_recent[*i] {
464 pb.set_position(stats.n);
465 pb.set_message(format!("p(accept)≈{:.2}", stats.p_accept));
466 avg_p_accept += stats.p_accept;
467 n_available_stats += 1.0;
468
469 if stats.n == total {
470 to_replace[vec_idx] = true;
471 n_finished += 1;
472 }
473 }
474 }
475 if n_available_stats > 0.0 {
476 avg_p_accept /= n_available_stats;
477 }
478
479 let mut total_progress = 0;
480 for stats in most_recent.iter().flatten() {
481 total_progress += stats.n;
482 }
483 global_pb.set_position(total_progress);
484 let valid: Vec<&ChainStats> = most_recent.iter().flatten().collect();
485 if valid.len() >= 2 {
486 let rhats = collect_rhat(valid.as_slice());
487 let max = max_skipnan(&rhats);
488 global_pb.set_message(format!(
489 "p(accept)≈{:.2} max(rhat)≈{:.2}",
490 avg_p_accept, max
491 ))
492 }
493
494 let mut to_remove = vec![];
495 for (i, replace) in to_replace.iter().enumerate() {
496 if *replace && next_active < most_recent.len() {
497 let pb = multi.add(ProgressBar::new(total));
498 pb.set_style(pb_style.clone());
499 pb.set_prefix(format!("Chain {next_active}"));
500 active[i] = (next_active, pb);
501 next_active += 1;
502 } else if *replace {
503 to_remove.push(i);
504 }
505 }
506
507 to_remove.sort();
508 for i in to_remove.iter().rev() {
509 active.remove(*i);
510 }
511
512 if n_finished >= most_recent.len() {
513 break;
514 }
515 std::thread::sleep(sleep_ms);
516 }
517 });
518
519 let chain_sample: Vec<Array2<V::Scalar>> = thread::scope(|s| {
520 let handles: Vec<thread::ScopedJoinHandle<Array2<V::Scalar>>> = chains
521 .iter_mut()
522 .zip(txs)
523 .map(|(chain, tx)| {
524 s.spawn(|| {
525 chain
526 .run_progress(n_collect, n_discard, tx)
527 .expect("expected running chain to succeed.")
528 })
529 })
530 .collect();
531 handles
532 .into_iter()
533 .map(|h| {
534 h.join()
535 .expect("expected thread to succeed in generating observation.")
536 })
537 .collect()
538 });
539 let views: Vec<ArrayView2<V::Scalar>> = chain_sample.iter().map(|s| s.view()).collect();
540 let sample = ndarray::stack(Axis(0), &views).expect("expected stacking sample to succeed");
541
542 if let Err(e) = progress_handle.join() {
543 eprintln!("Progress bar thread emitted error message: {:?}", e);
544 }
545
546 let run_stats = RunStats::from(sample.view());
547 Ok((sample, run_stats))
548 }
549
550 pub fn set_seed(mut self, seed: u64) -> Self {
551 for (i, chain) in self.chains.iter_mut().enumerate() {
552 let chain_seed = seed + i as u64 + 1;
553 chain.rng = SmallRng::seed_from_u64(chain_seed);
554 }
555 self
556 }
557}
558
/// A single sampling chain: current position, step-size adaptation state, mass matrix
/// and its own random number generator.
pub struct GenericNUTSChain<V, Target>
where
    V: EuclideanVector,
    Target: HamiltonianTarget<V>,
{
    /// Target whose log-density and gradient drive the sampler (shared between chains).
    target: Arc<Target>,
    /// Current position of the chain.
    position: V,
    /// Acceptance statistic the step-size adaptation aims for.
    target_accept_p: V::Scalar,
    /// Current leapfrog step size; `-1` marks "not initialised yet".
    epsilon: V::Scalar,
    /// Number of `step` calls since the chain state was last initialised.
    m: usize,
    /// Number of draws to keep, as passed to the last initialisation.
    n_collect: usize,
    /// Number of warmup iterations; step size and mass matrix adapt while `m <= n_discard`.
    n_discard: usize,
    /// Step-size adaptation constant (initialised to 0.05).
    gamma: V::Scalar,
    /// Step-size adaptation offset added to the iteration count (initialised to 10).
    t_0: usize,
    /// Exponent of the averaging weight `m^(-kappa)` (initialised to 0.75).
    kappa: V::Scalar,
    /// Shrinkage point of the adaptation: `ln(10 * epsilon)`.
    mu: V::Scalar,
    /// Averaged step size; used as the step size once warmup is over.
    epsilon_bar: V::Scalar,
    /// Running average of `target_accept_p - observed acceptance statistic`.
    h_bar: V::Scalar,
    /// Mass matrix currently in use (starts as the identity).
    mass_matrix: MassMatrix<V::Scalar>,
    /// Warmup adaptation state; `None` when mass-matrix adaptation is disabled.
    mass_warmup: Option<MassMatrixWarmup<V::Scalar>>,
    /// Random number generator owned by this chain.
    rng: SmallRng,
}
582
583impl<V, Target> GenericNUTSChain<V, Target>
584where
585 V: EuclideanVector,
586 V::Scalar: Float + FromPrimitive + ToPrimitive,
587 Target: HamiltonianTarget<V> + Sync + Send,
588 StandardNormal: RandDistribution<V::Scalar>,
589 StandardUniform: RandDistribution<V::Scalar>,
590 Exp1: RandDistribution<V::Scalar>,
591{
592 pub fn new(target: Target, initial_position: V, target_accept_p: V::Scalar) -> Self {
593 let target = Arc::new(target);
594 Self::new_shared(
595 target,
596 initial_position,
597 target_accept_p,
598 NUTSMassMatrixConfig::disabled(),
599 )
600 }
601
602 pub(crate) fn new_shared(
603 target: Arc<Target>,
604 initial_position: V,
605 target_accept_p: V::Scalar,
606 mass_config: NUTSMassMatrixConfig,
607 ) -> Self {
608 let mut thread_rng = rand::rng();
609 let rng = SmallRng::from_rng(&mut thread_rng);
610 let epsilon = -V::Scalar::one();
611 let dim = initial_position.len();
612 let adaptation = if mass_config.adaptation == MassMatrixAdaptation::Dense
613 && dim > mass_config.dense_max_dim
614 {
615 MassMatrixAdaptation::Diagonal
616 } else {
617 mass_config.adaptation
618 };
619 let mass_matrix = MassMatrix::identity(dim);
620 let mass_warmup = match adaptation {
621 MassMatrixAdaptation::None => None,
622 MassMatrixAdaptation::Diagonal => {
623 Some(MassMatrixWarmup::new(dim, mass_config.clone(), false))
624 }
625 MassMatrixAdaptation::Dense => {
626 Some(MassMatrixWarmup::new(dim, mass_config.clone(), true))
627 }
628 };
629
630 Self {
631 target,
632 position: initial_position,
633 target_accept_p,
634 epsilon,
635 m: 0,
636 n_collect: 0,
637 n_discard: 0,
638 gamma: V::Scalar::from_f64(0.05).unwrap(),
639 t_0: 10,
640 kappa: V::Scalar::from_f64(0.75).unwrap(),
641 mu: (V::Scalar::from_f64(10.0).unwrap() * V::Scalar::one()).ln(),
642 epsilon_bar: V::Scalar::one(),
643 h_bar: V::Scalar::zero(),
644 mass_matrix,
645 mass_warmup,
646 rng,
647 }
648 }
649
650 pub fn set_seed(mut self, seed: u64) -> Self {
651 self.rng = SmallRng::seed_from_u64(seed);
652 self
653 }
654
655 pub fn position(&self) -> &V {
656 &self.position
657 }
658
659 pub fn run(&mut self, n_collect: usize, n_discard: usize) -> Array2<V::Scalar> {
660 let (dim, mut sample) = self.init_chain(n_collect, n_discard);
661 let mut scratch = vec![V::Scalar::zero(); dim];
662
663 for m in 1..(n_collect + n_discard) {
664 self.step();
665
666 if m >= n_discard {
667 self.position.write_to_slice(&mut scratch);
668 let view = ArrayView1::from(&scratch);
669 sample.slice_mut(s![m - n_discard, ..]).assign(&view);
670 }
671 }
672 sample
673 }
674
675 fn run_progress(
676 &mut self,
677 n_collect: usize,
678 n_discard: usize,
679 tx: Sender<ChainStats>,
680 ) -> Result<Array2<V::Scalar>, Box<dyn Error>> {
681 let (dim, mut sample) = self.init_chain(n_collect, n_discard);
682 let mut scratch = vec![V::Scalar::zero(); dim];
683 self.position.write_to_slice(&mut scratch);
684
685 let mut tracker = ChainTracker::new(dim, &scratch);
686 let mut last = Instant::now();
687 let freq = Duration::from_secs(1);
688 let total = n_discard + n_collect;
689
690 for i in 0..total {
691 self.step();
692 self.position.write_to_slice(&mut scratch);
693 tracker.step(&scratch).map_err(|e| {
694 let msg = format!(
695 "Chain statistics tracker caused error: {}.\nAborting generation of further observations.",
696 e
697 );
698 println!("{}", msg);
699 msg
700 })?;
701
702 let now = Instant::now();
703 if (now >= last + freq) | (i == total - 1) {
704 if let Err(e) = tx.send(tracker.stats()) {
705 eprintln!("Sending chain statistics failed: {e}");
706 }
707 last = now;
708 }
709
710 if i >= n_discard {
711 let view = ArrayView1::from(&scratch);
712 sample.slice_mut(s![i - n_discard, ..]).assign(&view);
713 }
714 }
715
716 Ok(sample)
717 }
718
719 fn init_chain(&mut self, n_collect: usize, n_discard: usize) -> (usize, Array2<V::Scalar>) {
720 let dim = self.init_chain_state(n_collect, n_discard);
721
722 let mut sample = Array2::<V::Scalar>::zeros((n_collect, dim));
723 let mut scratch = vec![V::Scalar::zero(); dim];
724 self.position.write_to_slice(&mut scratch);
725 let view = ArrayView1::from(&scratch);
726 sample.slice_mut(s![0, ..]).assign(&view);
727
728 (dim, sample)
729 }
730
731 pub(crate) fn init_chain_state(&mut self, n_collect: usize, n_discard: usize) -> usize {
732 let dim = self.position.len();
733 self.n_collect = n_collect;
734 self.n_discard = n_discard;
735 self.m = 0;
736
737 let mut mom_0 = self.position.zeros_like();
738 let mut mom_buf = vec![V::Scalar::zero(); dim];
739 self.mass_matrix
740 .sample_momentum(&mut self.rng, &mut mom_buf);
741 mom_0.read_from_slice(&mom_buf);
742 if let Some(warmup) = self.mass_warmup.as_mut() {
743 warmup.running.reset();
744 }
745 if V::Scalar::abs(self.epsilon + V::Scalar::one()) <= V::Scalar::epsilon() {
746 self.epsilon = find_reasonable_epsilon(&self.position, &mom_0, self.target.as_ref());
747 }
748 self.mu = (V::Scalar::from_f64(10.0).unwrap() * self.epsilon).ln();
749 dim
750 }
751
752 pub fn step(&mut self) {
753 self.m += 1;
754
755 let dim = self.position.len();
756 let mut mom_0 = self.position.zeros_like();
757 let mut mom_buf = vec![V::Scalar::zero(); dim];
758 self.mass_matrix
759 .sample_momentum(&mut self.rng, &mut mom_buf);
760 mom_0.read_from_slice(&mom_buf);
761
762 let mut grad = self.position.zeros_like();
763 let logp = self.target.logp_and_grad(&self.position, &mut grad);
764 let joint = logp - kinetic_energy(&self.mass_matrix, &mom_0);
765 let exp1_obs: V::Scalar = self.rng.sample(Exp1);
766 let logu = joint - exp1_obs;
767
768 let mut position_minus = self.position.clone();
769 let mut position_plus = self.position.clone();
770 let mut mom_minus = mom_0.clone();
771 let mut mom_plus = mom_0.clone();
772 let mut grad_minus = grad.clone();
773 let mut grad_plus = grad.clone();
774 let mut j = 0;
775 let mut n = 1;
776 let mut s = true;
777 let mut alpha: V::Scalar = V::Scalar::zero();
778 let mut n_alpha: usize = 0;
779
780 while s {
781 let u_run_1: V::Scalar = self.rng.random();
782 let v = (2 * (u_run_1 < V::Scalar::from_f64(0.5).unwrap()) as i8) - 1;
783
784 let (position_prime, n_prime, s_prime) = if v == -1 {
785 let (
786 position_minus_2,
787 mom_minus_2,
788 grad_minus_2,
789 _,
790 _,
791 _,
792 position_prime_2,
793 _,
794 _,
795 n_prime_2,
796 s_prime_2,
797 alpha_2,
798 n_alpha_2,
799 ) = build_tree_with_mass(
800 position_minus.clone(),
801 mom_minus.clone(),
802 grad_minus.clone(),
803 logu,
804 v,
805 j,
806 self.epsilon,
807 self.target.as_ref(),
808 &self.mass_matrix,
809 joint,
810 &mut self.rng,
811 );
812
813 position_minus = position_minus_2;
814 mom_minus = mom_minus_2;
815 grad_minus = grad_minus_2;
816
817 alpha = alpha_2;
818 n_alpha = n_alpha_2;
819 (position_prime_2, n_prime_2, s_prime_2)
820 } else {
821 let (
822 _,
823 _,
824 _,
825 position_plus_2,
826 mom_plus_2,
827 grad_plus_2,
828 position_prime_2,
829 _,
830 _,
831 n_prime_2,
832 s_prime_2,
833 alpha_2,
834 n_alpha_2,
835 ) = build_tree_with_mass(
836 position_plus.clone(),
837 mom_plus.clone(),
838 grad_plus.clone(),
839 logu,
840 v,
841 j,
842 self.epsilon,
843 self.target.as_ref(),
844 &self.mass_matrix,
845 joint,
846 &mut self.rng,
847 );
848
849 position_plus = position_plus_2;
850 mom_plus = mom_plus_2;
851 grad_plus = grad_plus_2;
852
853 alpha = alpha_2;
854 n_alpha = n_alpha_2;
855 (position_prime_2, n_prime_2, s_prime_2)
856 };
857
858 let tmp = V::Scalar::one().min(
859 V::Scalar::from_usize(n_prime)
860 .expect("successful conversion of n_prime from usize")
861 / V::Scalar::from_usize(n).expect("successful conversion of n from usize"),
862 );
863 let u_run_2: V::Scalar = self.rng.random();
864 if s_prime && (u_run_2 < tmp) {
865 self.position = position_prime;
866 }
867 n += n_prime;
868
869 s = s_prime
870 && stop_criterion_with_mass(
871 position_minus.clone(),
872 position_plus.clone(),
873 mom_minus.clone(),
874 mom_plus.clone(),
875 &self.mass_matrix,
876 );
877 j += 1
878 }
879
880 let mut eta = V::Scalar::one()
881 / V::Scalar::from_usize(self.m + self.t_0).expect("successful conversion of m + t_0");
882 self.h_bar = (V::Scalar::one() - eta) * self.h_bar
883 + eta
884 * (self.target_accept_p
885 - alpha
886 / V::Scalar::from_usize(n_alpha)
887 .expect("successful conversion of n_alpha"));
888 if self.m <= self.n_discard {
889 let m = V::Scalar::from_usize(self.m).expect("successful conversion of m");
890 self.epsilon = (self.mu - m.sqrt() / self.gamma * self.h_bar).exp();
891 eta = m.powf(-self.kappa);
892 self.epsilon_bar =
893 ((V::Scalar::one() - eta) * self.epsilon_bar.ln() + eta * self.epsilon.ln()).exp();
894
895 if let Some(warmup) = self.mass_warmup.as_mut()
896 && warmup.should_collect(self.m, self.n_discard)
897 {
898 let mut q = vec![V::Scalar::zero(); dim];
899 self.position.write_to_slice(&mut q);
900 warmup.running.update(&q);
901 if warmup.note_if_window_end(self.m, self.n_discard)
902 && let Some(updated) = maybe_update_mass_matrix(&self.mass_matrix, warmup)
903 {
904 self.mass_matrix = updated;
905 let mut probe = self.position.zeros_like();
906 let mut probe_buf = vec![V::Scalar::zero(); dim];
907 self.mass_matrix
908 .sample_momentum(&mut self.rng, &mut probe_buf);
909 probe.read_from_slice(&probe_buf);
910 self.epsilon =
911 find_reasonable_epsilon(&self.position, &probe, self.target.as_ref());
912 self.mu = (V::Scalar::from_f64(10.0).unwrap() * self.epsilon).ln();
913 self.epsilon_bar = self.epsilon;
914 self.h_bar = V::Scalar::zero();
915 warmup.running.reset();
916 }
917 }
918 } else {
919 self.epsilon = self.epsilon_bar;
920 }
921 }
922}
923
924fn kinetic_energy<V: EuclideanVector>(mass: &MassMatrix<V::Scalar>, mom: &V) -> V::Scalar
925where
926 V::Scalar: Float + FromPrimitive,
927{
928 let mut p = vec![V::Scalar::zero(); mom.len()];
929 mom.write_to_slice(&mut p);
930 mass.kinetic(&p)
931}
932
933fn apply_inv_mass<V: EuclideanVector>(mass: &MassMatrix<V::Scalar>, input: &V, out: &mut V)
934where
935 V::Scalar: Float + FromPrimitive,
936{
937 let mut p = vec![V::Scalar::zero(); input.len()];
938 let mut v = vec![V::Scalar::zero(); input.len()];
939 input.write_to_slice(&mut p);
940 mass.inv_mul(&p, &mut v);
941 out.read_from_slice(&v);
942}
943
944fn maybe_update_mass_matrix<S: Float + FromPrimitive>(
945 current: &MassMatrix<S>,
946 warmup: &MassMatrixWarmup<S>,
947) -> Option<MassMatrix<S>> {
948 let n = warmup.running.n;
949 if n < 5 {
950 return None;
951 }
952 let n_denom = S::from_usize(n - 1).unwrap();
953 let reg = S::from_f64(warmup.config.regularize).unwrap();
954 let one_minus_reg = S::one() - reg;
955 let jitter = S::from_f64(warmup.config.jitter.max(1e-10)).unwrap();
956 match warmup.config.adaptation {
957 MassMatrixAdaptation::None => None,
958 MassMatrixAdaptation::Diagonal => {
959 let mut var = vec![S::zero(); warmup.running.dim];
960 for (i, vi) in var.iter_mut().enumerate().take(warmup.running.dim) {
961 let raw = warmup.running.m2_diag[i] / n_denom;
962 *vi = (one_minus_reg * raw + reg).max(jitter);
963 }
964 Some(MassMatrix::diagonal_from_var(var, jitter))
965 }
966 MassMatrixAdaptation::Dense => {
967 let dim = warmup.running.dim;
968 let Some(m2_dense) = warmup.running.m2_dense.as_ref() else {
969 return None;
970 };
971 let mut cov = vec![S::zero(); dim * dim];
972 for i in 0..dim {
973 for j in i..dim {
974 let idx = i * dim + j;
975 let raw = m2_dense[idx] / n_denom;
976 let v = if i == j {
977 (one_minus_reg * raw + reg).max(jitter)
978 } else {
979 one_minus_reg * raw
980 };
981 cov[idx] = v;
982 cov[j * dim + i] = v;
983 }
984 }
985 MassMatrix::dense_from_cov(cov, dim, jitter).or_else(|| match current {
986 MassMatrix::Diagonal { .. } | MassMatrix::Dense { .. } => None,
987 MassMatrix::Identity { dim } => {
988 Some(MassMatrix::diagonal_from_var(vec![S::one(); *dim], jitter))
989 }
990 })
991 }
992 }
993}
994
995fn all_real_vec<V: EuclideanVector>(v: &V) -> bool
996where
997 V::Scalar: Float,
998{
999 let mut scratch = vec![V::Scalar::zero(); v.len()];
1000 v.write_to_slice(&mut scratch);
1001 scratch.iter().all(|x: &V::Scalar| x.is_finite())
1002}
1003
1004#[allow(dead_code)]
1005pub(crate) fn find_reasonable_epsilon<V, Target>(
1006 position: &V,
1007 mom: &V,
1008 gradient_target: &Target,
1009) -> V::Scalar
1010where
1011 V: EuclideanVector,
1012 V::Scalar: Float + FromPrimitive,
1013 Target: HamiltonianTarget<V> + Sync,
1014 StandardNormal: RandDistribution<V::Scalar>,
1015 StandardUniform: RandDistribution<V::Scalar>,
1016{
1017 let mass_matrix = MassMatrix::identity(position.len());
1018 find_reasonable_epsilon_with_mass(position, mom, gradient_target, &mass_matrix)
1019}
1020
1021fn find_reasonable_epsilon_with_mass<V, Target>(
1022 position: &V,
1023 mom: &V,
1024 gradient_target: &Target,
1025 mass_matrix: &MassMatrix<V::Scalar>,
1026) -> V::Scalar
1027where
1028 V: EuclideanVector,
1029 V::Scalar: Float + FromPrimitive,
1030 Target: HamiltonianTarget<V> + Sync,
1031 StandardNormal: RandDistribution<V::Scalar>,
1032 StandardUniform: RandDistribution<V::Scalar>,
1033{
1034 let mut epsilon = V::Scalar::one();
1035 let half = V::Scalar::from_f64(0.5).unwrap();
1036
1037 let mut grad = position.zeros_like();
1038 let ulogp = gradient_target.logp_and_grad(position, &mut grad);
1039
1040 let mut position_prime = position.clone();
1041 let mut mom_prime = mom.clone();
1042 let mut grad_prime = grad.clone();
1043 let mut ulogp_prime = leapfrog_with_mass(
1044 &mut position_prime,
1045 &mut mom_prime,
1046 &mut grad_prime,
1047 epsilon,
1048 gradient_target,
1049 mass_matrix,
1050 );
1051 let mut k = V::Scalar::one();
1052
1053 while !ulogp_prime.is_finite() || !all_real_vec(&grad_prime) {
1054 k = k * half;
1055 position_prime.assign(position);
1056 mom_prime.assign(mom);
1057 grad_prime.assign(&grad);
1058 ulogp_prime = leapfrog_with_mass(
1059 &mut position_prime,
1060 &mut mom_prime,
1061 &mut grad_prime,
1062 epsilon * k,
1063 gradient_target,
1064 mass_matrix,
1065 );
1066 }
1067
1068 epsilon = half * k * epsilon;
1069 let log_accept_prob = ulogp_prime
1070 - ulogp
1071 - (kinetic_energy(mass_matrix, &mom_prime) - kinetic_energy(mass_matrix, mom));
1072 let mut log_accept_prob = log_accept_prob;
1073
1074 let a = if log_accept_prob > half.ln() {
1075 V::Scalar::one()
1076 } else {
1077 -V::Scalar::one()
1078 };
1079
1080 while a * log_accept_prob > -a * V::Scalar::from_f64(2.0).unwrap().ln() {
1081 epsilon = epsilon * V::Scalar::from_f64(2.0).unwrap().powf(a);
1082 position_prime.assign(position);
1083 mom_prime.assign(mom);
1084 grad_prime.assign(&grad);
1085 ulogp_prime = leapfrog_with_mass(
1086 &mut position_prime,
1087 &mut mom_prime,
1088 &mut grad_prime,
1089 epsilon,
1090 gradient_target,
1091 mass_matrix,
1092 );
1093 log_accept_prob = ulogp_prime
1094 - ulogp
1095 - (kinetic_energy(mass_matrix, &mom_prime) - kinetic_energy(mass_matrix, mom));
1096 }
1097
1098 epsilon
1099}
1100
1101#[allow(clippy::too_many_arguments, clippy::type_complexity)]
1102fn build_tree_with_mass<V, Target>(
1103 position: V,
1104 mom: V,
1105 grad: V,
1106 logu: V::Scalar,
1107 v: i8,
1108 j: usize,
1109 epsilon: V::Scalar,
1110 gradient_target: &Target,
1111 mass_matrix: &MassMatrix<V::Scalar>,
1112 joint_0: V::Scalar,
1113 rng: &mut SmallRng,
1114) -> (
1115 V,
1116 V,
1117 V,
1118 V,
1119 V,
1120 V,
1121 V,
1122 V,
1123 V::Scalar,
1124 usize,
1125 bool,
1126 V::Scalar,
1127 usize,
1128)
1129where
1130 V: EuclideanVector,
1131 V::Scalar: Float + FromPrimitive,
1132 Target: HamiltonianTarget<V> + Sync,
1133{
1134 if j == 0 {
1135 let mut position_prime = position.clone();
1136 let mut mom_prime = mom.clone();
1137 let mut grad_prime = grad.clone();
1138 let logp_prime = leapfrog_with_mass(
1139 &mut position_prime,
1140 &mut mom_prime,
1141 &mut grad_prime,
1142 V::Scalar::from_i64(v as i64).unwrap() * epsilon,
1143 gradient_target,
1144 mass_matrix,
1145 );
1146 let joint = logp_prime - kinetic_energy(mass_matrix, &mom_prime);
1147 let n_prime = (logu < joint) as usize;
1148 let s_prime = (logu - V::Scalar::from_f64(1000.0).unwrap()) < joint;
1149 let position_minus = position_prime.clone();
1150 let position_plus = position_prime.clone();
1151 let mom_minus = mom_prime.clone();
1152 let mom_plus = mom_prime.clone();
1153 let grad_minus = grad_prime.clone();
1154 let grad_plus = grad_prime.clone();
1155 let alpha_prime = V::Scalar::one().min((joint - joint_0).exp());
1156 let n_alpha_prime = 1_usize;
1157 (
1158 position_minus,
1159 mom_minus,
1160 grad_minus,
1161 position_plus,
1162 mom_plus,
1163 grad_plus,
1164 position_prime,
1165 grad_prime,
1166 logp_prime,
1167 n_prime,
1168 s_prime,
1169 alpha_prime,
1170 n_alpha_prime,
1171 )
1172 } else {
1173 let (
1174 mut position_minus,
1175 mut mom_minus,
1176 mut grad_minus,
1177 mut position_plus,
1178 mut mom_plus,
1179 mut grad_plus,
1180 mut position_prime,
1181 mut grad_prime,
1182 mut logp_prime,
1183 mut n_prime,
1184 mut s_prime,
1185 mut alpha_prime,
1186 mut n_alpha_prime,
1187 ) = build_tree_with_mass(
1188 position,
1189 mom,
1190 grad,
1191 logu,
1192 v,
1193 j - 1,
1194 epsilon,
1195 gradient_target,
1196 mass_matrix,
1197 joint_0,
1198 rng,
1199 );
1200 if s_prime {
1201 let (
1202 position_minus_2,
1203 mom_minus_2,
1204 grad_minus_2,
1205 position_plus_2,
1206 mom_plus_2,
1207 grad_plus_2,
1208 position_prime_2,
1209 grad_prime_2,
1210 logp_prime_2,
1211 n_prime_2,
1212 s_prime_2,
1213 alpha_prime_2,
1214 n_alpha_prime_2,
1215 ) = if v == -1 {
1216 build_tree_with_mass(
1217 position_minus.clone(),
1218 mom_minus.clone(),
1219 grad_minus.clone(),
1220 logu,
1221 v,
1222 j - 1,
1223 epsilon,
1224 gradient_target,
1225 mass_matrix,
1226 joint_0,
1227 rng,
1228 )
1229 } else {
1230 build_tree_with_mass(
1231 position_plus.clone(),
1232 mom_plus.clone(),
1233 grad_plus.clone(),
1234 logu,
1235 v,
1236 j - 1,
1237 epsilon,
1238 gradient_target,
1239 mass_matrix,
1240 joint_0,
1241 rng,
1242 )
1243 };
1244 if v == -1 {
1245 position_minus = position_minus_2;
1246 mom_minus = mom_minus_2;
1247 grad_minus = grad_minus_2;
1248 } else {
1249 position_plus = position_plus_2;
1250 mom_plus = mom_plus_2;
1251 grad_plus = grad_plus_2;
1252 }
1253
1254 let u_build_tree: f64 = rng.random();
1255 if u_build_tree < (n_prime_2 as f64 / (n_prime + n_prime_2).max(1) as f64) {
1256 position_prime = position_prime_2;
1257 grad_prime = grad_prime_2;
1258 logp_prime = logp_prime_2;
1259 }
1260
1261 n_prime += n_prime_2;
1262
1263 s_prime = s_prime
1264 && s_prime_2
1265 && stop_criterion(
1266 position_minus.clone(),
1267 position_plus.clone(),
1268 mom_minus.clone(),
1269 mom_plus.clone(),
1270 );
1271 alpha_prime = alpha_prime + alpha_prime_2;
1272 n_alpha_prime += n_alpha_prime_2;
1273 }
1274 (
1275 position_minus,
1276 mom_minus,
1277 grad_minus,
1278 position_plus,
1279 mom_plus,
1280 grad_plus,
1281 position_prime,
1282 grad_prime,
1283 logp_prime,
1284 n_prime,
1285 s_prime,
1286 alpha_prime,
1287 n_alpha_prime,
1288 )
1289 }
1290}
1291
1292pub(crate) fn stop_criterion<V>(
1293 position_minus: V,
1294 position_plus: V,
1295 mom_minus: V,
1296 mom_plus: V,
1297) -> bool
1298where
1299 V: EuclideanVector,
1300 V::Scalar: Float + FromPrimitive,
1301{
1302 let mass_matrix = MassMatrix::identity(position_minus.len());
1303 stop_criterion_with_mass(
1304 position_minus,
1305 position_plus,
1306 mom_minus,
1307 mom_plus,
1308 &mass_matrix,
1309 )
1310}
1311
1312fn stop_criterion_with_mass<V>(
1313 position_minus: V,
1314 position_plus: V,
1315 mom_minus: V,
1316 mom_plus: V,
1317 mass_matrix: &MassMatrix<V::Scalar>,
1318) -> bool
1319where
1320 V: EuclideanVector,
1321 V::Scalar: Float + FromPrimitive,
1322{
1323 let mut diff = position_plus.clone();
1325 diff.sub_assign(&position_minus);
1326 let mut vel_minus = mom_minus.zeros_like();
1327 let mut vel_plus = mom_plus.zeros_like();
1328 apply_inv_mass(mass_matrix, &mom_minus, &mut vel_minus);
1329 apply_inv_mass(mass_matrix, &mom_plus, &mut vel_plus);
1330 let dot_minus = diff.dot(&vel_minus);
1331 let dot_plus = diff.dot(&vel_plus);
1332 dot_minus >= V::Scalar::zero() && dot_plus >= V::Scalar::zero()
1333}
1334
1335fn leapfrog_with_mass<V, Target>(
1336 position: &mut V,
1337 momentum: &mut V,
1338 grad: &mut V,
1339 epsilon: V::Scalar,
1340 gradient_target: &Target,
1341 mass_matrix: &MassMatrix<V::Scalar>,
1342) -> V::Scalar
1343where
1344 V: EuclideanVector,
1345 V::Scalar: Float + FromPrimitive,
1346 Target: HamiltonianTarget<V>,
1347{
1348 let half = V::Scalar::from_f64(0.5).unwrap();
1350 momentum.add_scaled_assign(grad, epsilon * half);
1351 let mut velocity = momentum.zeros_like();
1352 apply_inv_mass(mass_matrix, momentum, &mut velocity);
1353 position.add_scaled_assign(&velocity, epsilon);
1354 let logp = gradient_target.logp_and_grad(position, grad);
1355 momentum.add_scaled_assign(grad, epsilon * half);
1356 logp
1357}
1358
#[cfg(test)]
mod tests {
    use super::{
        MassMatrix, MassMatrixAdaptation, MassMatrixWarmup, NUTSMassMatrixConfig,
        maybe_update_mass_matrix,
    };

    /// A diagonal mass matrix built from variances `[4, 9]` must give kinetic energy
    /// `1.0` for `p = [2, 3]` and map `p` to `[0.5, 1/3]` under `inv_mul`.
    #[test]
    fn diagonal_mass_matrix_kinetic_and_inv_mul_are_consistent() {
        let mass = MassMatrix::diagonal_from_var(vec![4.0_f64, 9.0_f64], 1e-12);
        let p = [2.0_f64, 3.0_f64];
        let ke = mass.kinetic(&p);
        assert!((ke - 1.0).abs() < 1e-12);

        let mut out = [0.0_f64; 2];
        mass.inv_mul(&p, &mut out);
        assert!((out[0] - 0.5).abs() < 1e-12);
        assert!((out[1] - (1.0 / 3.0)).abs() < 1e-12);
    }

    /// A dense mass matrix built from a 2x2 covariance must yield a strictly positive
    /// quadratic form `p . inv_mul(p)` for a non-zero `p`.
    #[test]
    fn dense_mass_matrix_inverse_matches_identity_action() {
        // Flattened 2x2 covariance: [[2.0, 0.3], [0.3, 1.0]].
        let cov = vec![2.0_f64, 0.3_f64, 0.3_f64, 1.0_f64];
        let mass = MassMatrix::dense_from_cov(cov, 2, 1e-12).expect("dense mass matrix");
        let p = [0.7_f64, -1.1_f64];
        let mut out = [0.0_f64; 2];
        mass.inv_mul(&p, &mut out);

        let quad = p[0] * out[0] + p[1] * out[1];
        assert!(quad > 0.0);
    }

    /// After feeding five 2-D samples into the warmup accumulator, a diagonal update
    /// must produce finite, strictly positive `inv` and `sqrt` entries.
    #[test]
    fn warmup_diagonal_update_produces_positive_variances() {
        let cfg = NUTSMassMatrixConfig {
            adaptation: MassMatrixAdaptation::Diagonal,
            start_buffer: 1,
            end_buffer: 1,
            initial_window: 4,
            regularize: 0.05,
            jitter: 1e-6,
            dense_max_dim: 75,
        };
        let mut warmup = MassMatrixWarmup::new(2, cfg, false);
        let current = MassMatrix::identity(2);
        for x in [
            [-2.0_f64, 1.0_f64],
            [-1.0, 0.0],
            [0.0, 1.0],
            [2.0, -1.0],
            [1.0, 0.5],
        ] {
            warmup.running.update(&x);
        }
        let updated = maybe_update_mass_matrix(&current, &warmup).expect("updated mass");
        match updated {
            MassMatrix::Diagonal { inv, sqrt } => {
                for i in 0..2 {
                    assert!(inv[i].is_finite() && inv[i] > 0.0);
                    assert!(sqrt[i].is_finite() && sqrt[i] > 0.0);
                }
            }
            _ => panic!("expected diagonal mass matrix"),
        }
    }
}