Skip to main content

general_mcmc/
generic_nuts.rs

1//! Backend-agnostic No-U-Turn Sampler (NUTS) core.
2
3use crate::euclidean::EuclideanVector;
4use crate::generic_hmc::HamiltonianTarget;
5use crate::stats::{ChainStats, ChainTracker, RunStats, collect_rhat, max_skipnan};
6use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
7use ndarray::{Array2, Array3, ArrayView1, ArrayView2, Axis, s};
8use num_traits::{Float, FromPrimitive, One, ToPrimitive, Zero};
9use rand::distr::Distribution as RandDistribution;
10// rand_distr provides the distributions, but we rely on rand's Distribution trait for compatibility.
11use rand::rngs::SmallRng;
12use rand::{Rng, SeedableRng};
13use rand_distr::{Exp1, StandardNormal, StandardUniform};
14use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
15use std::error::Error;
16use std::sync::Arc;
17use std::sync::mpsc;
18use std::sync::mpsc::{Receiver, Sender};
19use std::thread;
20use std::time::{Duration, Instant};
21
/// Backend-agnostic No-U-Turn Sampler (NUTS) spanning multiple chains.
///
/// The sampler is a thin container: every entry of `chains` carries its own position,
/// step-size adaptation state and random number generator, and the chains are advanced
/// independently of each other when the sampler is run.
pub struct GenericNUTS<V, Target>
where
    V: EuclideanVector,
    Target: HamiltonianTarget<V>,
{
    /// One single-chain sampler per initial position, in the order they were supplied.
    chains: Vec<GenericNUTSChain<V, Target>>,
}
30
/// Outcome of a progress-reporting run: the stacked samples together with the
/// statistics derived from them, or a boxed error.
type RunResult<T> = Result<(Array3<T>, RunStats), Box<dyn Error>>;
32
/// Mass-matrix adaptation strategy for warmup.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MassMatrixAdaptation {
    /// No adaptation: the chain keeps the identity mass matrix and creates no warmup state.
    None,
    /// Adapt from per-coordinate statistics only.
    Diagonal,
    /// Adapt from the full covariance. A chain falls back to `Diagonal` when its dimension
    /// exceeds `NUTSMassMatrixConfig::dense_max_dim`.
    Dense,
}
40
/// Controls warmup-time mass-matrix adaptation for NUTS.
#[derive(Clone, Debug)]
pub struct NUTSMassMatrixConfig {
    /// Which adaptation strategy to use during warmup.
    pub adaptation: MassMatrixAdaptation,
    /// Warmup iterations `m <= start_buffer` are never collected; the first adaptation
    /// window ends `initial_window` iterations after this buffer (the buffer is treated
    /// as at least 1 for that purpose).
    pub start_buffer: usize,
    /// Collection stops once the iteration index reaches `n_warmup - end_buffer`
    /// (saturating at zero).
    pub end_buffer: usize,
    /// Length of the first adaptation window; values below 10 are raised to 10.
    pub initial_window: usize,
    /// NOTE(review): not read in the visible part of this file; presumably a shrinkage
    /// weight applied to the estimated (co)variance. Confirm against the update routine.
    pub regularize: f64,
    /// NOTE(review): not read in the visible part of this file; presumably forwarded as the
    /// variance floor / diagonal boost when a mass matrix is built. Confirm against the
    /// update routine.
    pub jitter: f64,
    /// Largest dimension for which dense adaptation is attempted; above it a chain
    /// requested as `Dense` adapts a diagonal matrix instead.
    pub dense_max_dim: usize,
}
52
53impl NUTSMassMatrixConfig {
54    pub fn disabled() -> Self {
55        Self {
56            adaptation: MassMatrixAdaptation::None,
57            start_buffer: 0,
58            end_buffer: 0,
59            initial_window: 0,
60            regularize: 0.0,
61            jitter: 0.0,
62            dense_max_dim: 0,
63        }
64    }
65}
66
67impl Default for NUTSMassMatrixConfig {
68    fn default() -> Self {
69        Self {
70            adaptation: MassMatrixAdaptation::Diagonal,
71            start_buffer: 75,
72            end_buffer: 50,
73            initial_window: 25,
74            regularize: 0.05,
75            jitter: 1e-6,
76            dense_max_dim: 75,
77        }
78    }
79}
80
/// Streaming accumulator for per-coordinate (and optionally pairwise) second moments.
struct RunningCov<S: Float> {
    /// Number of coordinates in every observation.
    dim: usize,
    /// Number of observations folded in since the last reset.
    n: usize,
    /// Running mean of the observations.
    mean: Vec<S>,
    /// Per-coordinate sum of `(x - mean_before) * (x - mean_after)` over all observations.
    m2_diag: Vec<S>,
    /// Row-major `dim * dim` buffer of the same products for coordinate pairs; only the
    /// entries with `column >= row` are written. `None` when dense tracking is off.
    m2_dense: Option<Vec<S>>,
}
88
89impl<S: Float + FromPrimitive> RunningCov<S> {
90    fn new(dim: usize, dense: bool) -> Self {
91        Self {
92            dim,
93            n: 0,
94            mean: vec![S::zero(); dim],
95            m2_diag: vec![S::zero(); dim],
96            m2_dense: dense.then(|| vec![S::zero(); dim * dim]),
97        }
98    }
99
100    fn reset(&mut self) {
101        self.n = 0;
102        self.mean.fill(S::zero());
103        self.m2_diag.fill(S::zero());
104        if let Some(m2) = self.m2_dense.as_mut() {
105            m2.fill(S::zero());
106        }
107    }
108
109    fn update(&mut self, x: &[S]) {
110        self.n += 1;
111        let n_s = S::from_usize(self.n).unwrap();
112        let mut delta = vec![S::zero(); self.dim];
113        for i in 0..self.dim {
114            delta[i] = x[i] - self.mean[i];
115            self.mean[i] = self.mean[i] + delta[i] / n_s;
116            let delta2 = x[i] - self.mean[i];
117            self.m2_diag[i] = self.m2_diag[i] + delta[i] * delta2;
118        }
119        if let Some(m2) = self.m2_dense.as_mut() {
120            let mut delta2 = vec![S::zero(); self.dim];
121            for i in 0..self.dim {
122                delta2[i] = x[i] - self.mean[i];
123            }
124            for i in 0..self.dim {
125                for j in i..self.dim {
126                    let idx = i * self.dim + j;
127                    m2[idx] = m2[idx] + delta[i] * delta2[j];
128                }
129            }
130        }
131    }
132}
133
/// Warmup bookkeeping for mass-matrix adaptation: the window schedule plus the running
/// statistics gathered inside the current window.
struct MassMatrixWarmup<S: Float> {
    /// User-supplied adaptation settings (buffers, first window length, ...).
    config: NUTSMassMatrixConfig,
    /// Iteration index at which the current window is considered complete.
    next_window_end: usize,
    /// Amount added to `next_window_end` when a window closes; doubled afterwards,
    /// capped at 400.
    window_len: usize,
    /// Running statistics of the positions collected so far.
    running: RunningCov<S>,
}
140
141impl<S: Float + FromPrimitive> MassMatrixWarmup<S> {
142    fn new(dim: usize, config: NUTSMassMatrixConfig, dense: bool) -> Self {
143        let start_buffer = config.start_buffer.max(1);
144        let window_len = config.initial_window.max(10);
145        Self {
146            config,
147            next_window_end: start_buffer + window_len,
148            window_len,
149            running: RunningCov::new(dim, dense),
150        }
151    }
152
153    fn should_collect(&self, m: usize, n_warmup: usize) -> bool {
154        if m == 0 || m > n_warmup {
155            return false;
156        }
157        if m <= self.config.start_buffer {
158            return false;
159        }
160        m < n_warmup.saturating_sub(self.config.end_buffer)
161    }
162
163    fn note_if_window_end(&mut self, m: usize, n_warmup: usize) -> bool {
164        if !self.should_collect(m, n_warmup) {
165            return false;
166        }
167        if m >= self.next_window_end || m + 1 >= n_warmup.saturating_sub(self.config.end_buffer) {
168            self.next_window_end = self.next_window_end.saturating_add(self.window_len);
169            self.window_len = (self.window_len.saturating_mul(2)).min(400);
170            return true;
171        }
172        false
173    }
174}
175
/// Metric used for sampling momenta and evaluating kinetic energy.
#[derive(Clone)]
enum MassMatrix<S: Float> {
    /// Unit metric: momenta are standard-normal draws used as they are.
    Identity {
        /// Dimension of the state this metric was created for.
        dim: usize,
    },
    /// Independent per-coordinate scaling.
    Diagonal {
        /// Reciprocals of the floored variances; multiplies momentum entries.
        inv: Vec<S>,
        /// Square roots of the floored variances; scales standard-normal draws.
        sqrt: Vec<S>,
    },
    /// Full matrix, stored row-major.
    Dense {
        /// Number of rows/columns of `inv` and `chol`.
        dim: usize,
        /// Inverse of the (jittered) covariance the factor was computed from.
        inv: Vec<S>,
        /// Lower-triangular Cholesky factor of the (jittered) covariance; applied to
        /// standard-normal draws to produce momenta.
        chol: Vec<S>,
    },
}
191
192impl<S: Float + FromPrimitive> MassMatrix<S> {
193    fn identity(dim: usize) -> Self {
194        Self::Identity { dim }
195    }
196
197    fn diagonal_from_var(mut var: Vec<S>, jitter: S) -> Self {
198        let mut inv = vec![S::zero(); var.len()];
199        let mut sqrt = vec![S::zero(); var.len()];
200        for i in 0..var.len() {
201            let v = var[i].max(jitter);
202            var[i] = v;
203            inv[i] = S::one() / v;
204            sqrt[i] = v.sqrt();
205        }
206        Self::Diagonal { inv, sqrt }
207    }
208
209    fn dense_from_cov(cov: Vec<S>, dim: usize, jitter: S) -> Option<Self> {
210        let max_tries = 8usize;
211        let mut j = jitter.max(S::from_f64(1e-10).unwrap());
212        for _ in 0..max_tries {
213            let mut cov_try = cov.clone();
214            for d in 0..dim {
215                cov_try[d * dim + d] = cov_try[d * dim + d] + j;
216            }
217            if let Some(chol) = cholesky_spd(&cov_try, dim)
218                && let Some(inv) = invert_spd_from_cholesky(&chol, dim)
219            {
220                return Some(Self::Dense { dim, inv, chol });
221            }
222            j = j * S::from_f64(10.0).unwrap();
223        }
224        None
225    }
226
227    fn kinetic(&self, momentum: &[S]) -> S {
228        let half = S::from_f64(0.5).unwrap();
229        match self {
230            Self::Identity { .. } => {
231                let mut q = S::zero();
232                for v in momentum {
233                    q = q + *v * *v;
234                }
235                half * q
236            }
237            Self::Diagonal { inv, .. } => {
238                let mut q = S::zero();
239                for i in 0..momentum.len() {
240                    q = q + momentum[i] * momentum[i] * inv[i];
241                }
242                half * q
243            }
244            Self::Dense { inv, dim, .. } => {
245                let mut q = S::zero();
246                for i in 0..*dim {
247                    let mut row_dot = S::zero();
248                    for j in 0..*dim {
249                        row_dot = row_dot + inv[i * *dim + j] * momentum[j];
250                    }
251                    q = q + momentum[i] * row_dot;
252                }
253                half * q
254            }
255        }
256    }
257
258    fn inv_mul(&self, input: &[S], out: &mut [S]) {
259        match self {
260            Self::Identity { .. } => out.copy_from_slice(input),
261            Self::Diagonal { inv, .. } => {
262                for i in 0..input.len() {
263                    out[i] = inv[i] * input[i];
264                }
265            }
266            Self::Dense { inv, dim, .. } => {
267                for i in 0..*dim {
268                    let mut acc = S::zero();
269                    for j in 0..*dim {
270                        acc = acc + inv[i * *dim + j] * input[j];
271                    }
272                    out[i] = acc;
273                }
274            }
275        }
276    }
277
278    fn sample_momentum(&self, rng: &mut SmallRng, out: &mut [S])
279    where
280        StandardNormal: RandDistribution<S>,
281    {
282        for v in out.iter_mut() {
283            *v = rng.sample(StandardNormal);
284        }
285        match self {
286            Self::Identity { .. } => {}
287            Self::Diagonal { sqrt, .. } => {
288                for i in 0..out.len() {
289                    out[i] = out[i] * sqrt[i];
290                }
291            }
292            Self::Dense { chol, dim, .. } => {
293                let z = out.to_vec();
294                for i in 0..*dim {
295                    let mut acc = S::zero();
296                    for j in 0..=i {
297                        acc = acc + chol[i * *dim + j] * z[j];
298                    }
299                    out[i] = acc;
300                }
301            }
302        }
303    }
304}
305
306fn cholesky_spd<S: Float + FromPrimitive>(a: &[S], dim: usize) -> Option<Vec<S>> {
307    let mut l = vec![S::zero(); dim * dim];
308    for i in 0..dim {
309        for j in 0..=i {
310            let mut sum = a[i * dim + j];
311            for k in 0..j {
312                sum = sum - l[i * dim + k] * l[j * dim + k];
313            }
314            if i == j {
315                if sum <= S::zero() || !sum.is_finite() {
316                    return None;
317                }
318                l[i * dim + j] = sum.sqrt();
319            } else {
320                let d = l[j * dim + j];
321                if d <= S::zero() || !d.is_finite() {
322                    return None;
323                }
324                l[i * dim + j] = sum / d;
325            }
326        }
327    }
328    Some(l)
329}
330
331fn invert_spd_from_cholesky<S: Float + FromPrimitive>(l: &[S], dim: usize) -> Option<Vec<S>> {
332    let mut inv_l = vec![S::zero(); dim * dim];
333    for i in 0..dim {
334        let d = l[i * dim + i];
335        if d <= S::zero() || !d.is_finite() {
336            return None;
337        }
338        inv_l[i * dim + i] = S::one() / d;
339        for j in (i + 1)..dim {
340            let mut sum = S::zero();
341            for k in i..j {
342                sum = sum + l[j * dim + k] * inv_l[k * dim + i];
343            }
344            inv_l[j * dim + i] = -sum / l[j * dim + j];
345        }
346    }
347    let mut inv = vec![S::zero(); dim * dim];
348    for i in 0..dim {
349        for j in 0..=i {
350            let mut sum = S::zero();
351            for k in i.max(j)..dim {
352                sum = sum + inv_l[k * dim + i] * inv_l[k * dim + j];
353            }
354            inv[i * dim + j] = sum;
355            inv[j * dim + i] = sum;
356        }
357    }
358    Some(inv)
359}
360
361impl<V, Target> GenericNUTS<V, Target>
362where
363    V: EuclideanVector + Send,
364    V::Scalar: Float + FromPrimitive + ToPrimitive + Send,
365    Target: HamiltonianTarget<V> + Sync + Send,
366    StandardNormal: RandDistribution<V::Scalar>,
367    StandardUniform: RandDistribution<V::Scalar>,
368    Exp1: RandDistribution<V::Scalar>,
369{
370    pub fn new(target: Target, initial_positions: Vec<V>, target_accept_p: V::Scalar) -> Self {
371        Self::new_with_mass_matrix(
372            target,
373            initial_positions,
374            target_accept_p,
375            NUTSMassMatrixConfig::disabled(),
376        )
377    }
378
379    pub fn new_with_mass_matrix(
380        target: Target,
381        initial_positions: Vec<V>,
382        target_accept_p: V::Scalar,
383        mass_config: NUTSMassMatrixConfig,
384    ) -> Self {
385        let target = Arc::new(target);
386        let chains = initial_positions
387            .into_iter()
388            .map(|pos| {
389                GenericNUTSChain::new_shared(
390                    Arc::clone(&target),
391                    pos,
392                    target_accept_p,
393                    mass_config.clone(),
394                )
395            })
396            .collect();
397        Self { chains }
398    }
399
400    pub(crate) fn chains_mut(&mut self) -> &mut [GenericNUTSChain<V, Target>] {
401        &mut self.chains
402    }
403
404    pub fn run(&mut self, n_collect: usize, n_discard: usize) -> Array3<V::Scalar> {
405        let chain_samples: Vec<Array2<V::Scalar>> = self
406            .chains
407            .par_iter_mut()
408            .map(|chain| chain.run(n_collect, n_discard))
409            .collect();
410        let views: Vec<ArrayView2<V::Scalar>> = chain_samples.iter().map(|s| s.view()).collect();
411        ndarray::stack(Axis(0), &views).expect("expected stacking chain samples to succeed")
412    }
413
414    pub fn run_progress(&mut self, n_collect: usize, n_discard: usize) -> RunResult<V::Scalar> {
415        let chains = &mut self.chains;
416
417        let mut rxs: Vec<Receiver<ChainStats>> = vec![];
418        let mut txs: Vec<Sender<ChainStats>> = vec![];
419        (0..chains.len()).for_each(|_| {
420            let (tx, rx) = mpsc::channel();
421            rxs.push(rx);
422            txs.push(tx);
423        });
424
425        let progress_handle = thread::spawn(move || {
426            let sleep_ms = Duration::from_millis(250);
427            let timeout_ms = Duration::from_millis(0);
428            let multi = MultiProgress::new();
429
430            let pb_style = ProgressStyle::default_bar()
431                .template("{prefix:8} {bar:40.cyan/blue} {pos}/{len} ({eta}) | {msg}")
432                .unwrap()
433                .progress_chars("=>-");
434            let total: u64 = (n_collect + n_discard).try_into().unwrap();
435
436            let global_pb = multi.add(ProgressBar::new((rxs.len() as u64) * total));
437            global_pb.set_style(pb_style.clone());
438            global_pb.set_prefix("Global");
439
440            let mut active: Vec<(usize, ProgressBar)> = (0..rxs.len().min(5))
441                .map(|chain_idx| {
442                    let pb = multi.add(ProgressBar::new(total));
443                    pb.set_style(pb_style.clone());
444                    pb.set_prefix(format!("Chain {chain_idx}"));
445                    (chain_idx, pb)
446                })
447                .collect();
448            let mut next_active = active.len();
449            let mut n_finished = 0;
450            let mut most_recent = vec![None; rxs.len()];
451
452            loop {
453                for (i, rx) in rxs.iter().enumerate() {
454                    while let Ok(stats) = rx.recv_timeout(timeout_ms) {
455                        most_recent[i] = Some(stats)
456                    }
457                }
458
459                let mut to_replace = vec![false; active.len()];
460                let mut avg_p_accept = 0.0;
461                let mut n_available_stats = 0.0;
462                for (vec_idx, (i, pb)) in active.iter().enumerate() {
463                    if let Some(stats) = &most_recent[*i] {
464                        pb.set_position(stats.n);
465                        pb.set_message(format!("p(accept)≈{:.2}", stats.p_accept));
466                        avg_p_accept += stats.p_accept;
467                        n_available_stats += 1.0;
468
469                        if stats.n == total {
470                            to_replace[vec_idx] = true;
471                            n_finished += 1;
472                        }
473                    }
474                }
475                if n_available_stats > 0.0 {
476                    avg_p_accept /= n_available_stats;
477                }
478
479                let mut total_progress = 0;
480                for stats in most_recent.iter().flatten() {
481                    total_progress += stats.n;
482                }
483                global_pb.set_position(total_progress);
484                let valid: Vec<&ChainStats> = most_recent.iter().flatten().collect();
485                if valid.len() >= 2 {
486                    let rhats = collect_rhat(valid.as_slice());
487                    let max = max_skipnan(&rhats);
488                    global_pb.set_message(format!(
489                        "p(accept)≈{:.2} max(rhat)≈{:.2}",
490                        avg_p_accept, max
491                    ))
492                }
493
494                let mut to_remove = vec![];
495                for (i, replace) in to_replace.iter().enumerate() {
496                    if *replace && next_active < most_recent.len() {
497                        let pb = multi.add(ProgressBar::new(total));
498                        pb.set_style(pb_style.clone());
499                        pb.set_prefix(format!("Chain {next_active}"));
500                        active[i] = (next_active, pb);
501                        next_active += 1;
502                    } else if *replace {
503                        to_remove.push(i);
504                    }
505                }
506
507                to_remove.sort();
508                for i in to_remove.iter().rev() {
509                    active.remove(*i);
510                }
511
512                if n_finished >= most_recent.len() {
513                    break;
514                }
515                std::thread::sleep(sleep_ms);
516            }
517        });
518
519        let chain_sample: Vec<Array2<V::Scalar>> = thread::scope(|s| {
520            let handles: Vec<thread::ScopedJoinHandle<Array2<V::Scalar>>> = chains
521                .iter_mut()
522                .zip(txs)
523                .map(|(chain, tx)| {
524                    s.spawn(|| {
525                        chain
526                            .run_progress(n_collect, n_discard, tx)
527                            .expect("expected running chain to succeed.")
528                    })
529                })
530                .collect();
531            handles
532                .into_iter()
533                .map(|h| {
534                    h.join()
535                        .expect("expected thread to succeed in generating observation.")
536                })
537                .collect()
538        });
539        let views: Vec<ArrayView2<V::Scalar>> = chain_sample.iter().map(|s| s.view()).collect();
540        let sample = ndarray::stack(Axis(0), &views).expect("expected stacking sample to succeed");
541
542        if let Err(e) = progress_handle.join() {
543            eprintln!("Progress bar thread emitted error message: {:?}", e);
544        }
545
546        let run_stats = RunStats::from(sample.view());
547        Ok((sample, run_stats))
548    }
549
550    pub fn set_seed(mut self, seed: u64) -> Self {
551        for (i, chain) in self.chains.iter_mut().enumerate() {
552            let chain_seed = seed + i as u64 + 1;
553            chain.rng = SmallRng::seed_from_u64(chain_seed);
554        }
555        self
556    }
557}
558
/// Single-chain state and adaptation for NUTS.
pub struct GenericNUTSChain<V, Target>
where
    V: EuclideanVector,
    Target: HamiltonianTarget<V>,
{
    /// Target whose log-density and gradient drive the sampler; shared between chains.
    target: Arc<Target>,
    /// Current state of the chain.
    position: V,
    /// Acceptance statistic the step-size adaptation aims for.
    target_accept_p: V::Scalar,
    /// Current integrator step size; `-1` marks "not chosen yet" until the chain is
    /// initialised.
    epsilon: V::Scalar,
    /// Number of steps taken since the chain was last initialised.
    m: usize,
    /// Number of draws requested for the current run.
    n_collect: usize,
    /// Number of warmup steps; step size and mass matrix adapt while `m <= n_discard`.
    n_discard: usize,
    /// Step-size adaptation constant: divides `sqrt(m) * h_bar` in the step-size update.
    gamma: V::Scalar,
    /// Step-size adaptation constant: offset added to `m` in the averaging weight of `h_bar`.
    t_0: usize,
    /// Step-size adaptation constant: decay exponent of the weight used for `epsilon_bar`.
    kappa: V::Scalar,
    /// Step-size adaptation constant: `ln(10 * epsilon)` of the reference step size.
    mu: V::Scalar,
    /// Log-averaged step size, used as the step size once warmup is over.
    epsilon_bar: V::Scalar,
    /// Running average of `target_accept_p` minus the observed acceptance statistic.
    h_bar: V::Scalar,
    /// Metric used for momentum draws and kinetic energy.
    mass_matrix: MassMatrix<V::Scalar>,
    /// Mass-matrix warmup state; `None` when adaptation is disabled.
    mass_warmup: Option<MassMatrixWarmup<V::Scalar>>,
    /// Per-chain random number generator.
    rng: SmallRng,
}
582
583impl<V, Target> GenericNUTSChain<V, Target>
584where
585    V: EuclideanVector,
586    V::Scalar: Float + FromPrimitive + ToPrimitive,
587    Target: HamiltonianTarget<V> + Sync + Send,
588    StandardNormal: RandDistribution<V::Scalar>,
589    StandardUniform: RandDistribution<V::Scalar>,
590    Exp1: RandDistribution<V::Scalar>,
591{
592    pub fn new(target: Target, initial_position: V, target_accept_p: V::Scalar) -> Self {
593        let target = Arc::new(target);
594        Self::new_shared(
595            target,
596            initial_position,
597            target_accept_p,
598            NUTSMassMatrixConfig::disabled(),
599        )
600    }
601
602    pub(crate) fn new_shared(
603        target: Arc<Target>,
604        initial_position: V,
605        target_accept_p: V::Scalar,
606        mass_config: NUTSMassMatrixConfig,
607    ) -> Self {
608        let mut thread_rng = rand::rng();
609        let rng = SmallRng::from_rng(&mut thread_rng);
610        let epsilon = -V::Scalar::one();
611        let dim = initial_position.len();
612        let adaptation = if mass_config.adaptation == MassMatrixAdaptation::Dense
613            && dim > mass_config.dense_max_dim
614        {
615            MassMatrixAdaptation::Diagonal
616        } else {
617            mass_config.adaptation
618        };
619        let mass_matrix = MassMatrix::identity(dim);
620        let mass_warmup = match adaptation {
621            MassMatrixAdaptation::None => None,
622            MassMatrixAdaptation::Diagonal => {
623                Some(MassMatrixWarmup::new(dim, mass_config.clone(), false))
624            }
625            MassMatrixAdaptation::Dense => {
626                Some(MassMatrixWarmup::new(dim, mass_config.clone(), true))
627            }
628        };
629
630        Self {
631            target,
632            position: initial_position,
633            target_accept_p,
634            epsilon,
635            m: 0,
636            n_collect: 0,
637            n_discard: 0,
638            gamma: V::Scalar::from_f64(0.05).unwrap(),
639            t_0: 10,
640            kappa: V::Scalar::from_f64(0.75).unwrap(),
641            mu: (V::Scalar::from_f64(10.0).unwrap() * V::Scalar::one()).ln(),
642            epsilon_bar: V::Scalar::one(),
643            h_bar: V::Scalar::zero(),
644            mass_matrix,
645            mass_warmup,
646            rng,
647        }
648    }
649
650    pub fn set_seed(mut self, seed: u64) -> Self {
651        self.rng = SmallRng::seed_from_u64(seed);
652        self
653    }
654
    /// Borrows the chain's current state.
    pub fn position(&self) -> &V {
        &self.position
    }
658
659    pub fn run(&mut self, n_collect: usize, n_discard: usize) -> Array2<V::Scalar> {
660        let (dim, mut sample) = self.init_chain(n_collect, n_discard);
661        let mut scratch = vec![V::Scalar::zero(); dim];
662
663        for m in 1..(n_collect + n_discard) {
664            self.step();
665
666            if m >= n_discard {
667                self.position.write_to_slice(&mut scratch);
668                let view = ArrayView1::from(&scratch);
669                sample.slice_mut(s![m - n_discard, ..]).assign(&view);
670            }
671        }
672        sample
673    }
674
675    fn run_progress(
676        &mut self,
677        n_collect: usize,
678        n_discard: usize,
679        tx: Sender<ChainStats>,
680    ) -> Result<Array2<V::Scalar>, Box<dyn Error>> {
681        let (dim, mut sample) = self.init_chain(n_collect, n_discard);
682        let mut scratch = vec![V::Scalar::zero(); dim];
683        self.position.write_to_slice(&mut scratch);
684
685        let mut tracker = ChainTracker::new(dim, &scratch);
686        let mut last = Instant::now();
687        let freq = Duration::from_secs(1);
688        let total = n_discard + n_collect;
689
690        for i in 0..total {
691            self.step();
692            self.position.write_to_slice(&mut scratch);
693            tracker.step(&scratch).map_err(|e| {
694                let msg = format!(
695                    "Chain statistics tracker caused error: {}.\nAborting generation of further observations.",
696                    e
697                );
698                println!("{}", msg);
699                msg
700            })?;
701
702            let now = Instant::now();
703            if (now >= last + freq) | (i == total - 1) {
704                if let Err(e) = tx.send(tracker.stats()) {
705                    eprintln!("Sending chain statistics failed: {e}");
706                }
707                last = now;
708            }
709
710            if i >= n_discard {
711                let view = ArrayView1::from(&scratch);
712                sample.slice_mut(s![i - n_discard, ..]).assign(&view);
713            }
714        }
715
716        Ok(sample)
717    }
718
719    fn init_chain(&mut self, n_collect: usize, n_discard: usize) -> (usize, Array2<V::Scalar>) {
720        let dim = self.init_chain_state(n_collect, n_discard);
721
722        let mut sample = Array2::<V::Scalar>::zeros((n_collect, dim));
723        let mut scratch = vec![V::Scalar::zero(); dim];
724        self.position.write_to_slice(&mut scratch);
725        let view = ArrayView1::from(&scratch);
726        sample.slice_mut(s![0, ..]).assign(&view);
727
728        (dim, sample)
729    }
730
731    pub(crate) fn init_chain_state(&mut self, n_collect: usize, n_discard: usize) -> usize {
732        let dim = self.position.len();
733        self.n_collect = n_collect;
734        self.n_discard = n_discard;
735        self.m = 0;
736
737        let mut mom_0 = self.position.zeros_like();
738        let mut mom_buf = vec![V::Scalar::zero(); dim];
739        self.mass_matrix
740            .sample_momentum(&mut self.rng, &mut mom_buf);
741        mom_0.read_from_slice(&mom_buf);
742        if let Some(warmup) = self.mass_warmup.as_mut() {
743            warmup.running.reset();
744        }
745        if V::Scalar::abs(self.epsilon + V::Scalar::one()) <= V::Scalar::epsilon() {
746            self.epsilon = find_reasonable_epsilon(&self.position, &mom_0, self.target.as_ref());
747        }
748        self.mu = (V::Scalar::from_f64(10.0).unwrap() * self.epsilon).ln();
749        dim
750    }
751
752    pub fn step(&mut self) {
753        self.m += 1;
754
755        let dim = self.position.len();
756        let mut mom_0 = self.position.zeros_like();
757        let mut mom_buf = vec![V::Scalar::zero(); dim];
758        self.mass_matrix
759            .sample_momentum(&mut self.rng, &mut mom_buf);
760        mom_0.read_from_slice(&mom_buf);
761
762        let mut grad = self.position.zeros_like();
763        let logp = self.target.logp_and_grad(&self.position, &mut grad);
764        let joint = logp - kinetic_energy(&self.mass_matrix, &mom_0);
765        let exp1_obs: V::Scalar = self.rng.sample(Exp1);
766        let logu = joint - exp1_obs;
767
768        let mut position_minus = self.position.clone();
769        let mut position_plus = self.position.clone();
770        let mut mom_minus = mom_0.clone();
771        let mut mom_plus = mom_0.clone();
772        let mut grad_minus = grad.clone();
773        let mut grad_plus = grad.clone();
774        let mut j = 0;
775        let mut n = 1;
776        let mut s = true;
777        let mut alpha: V::Scalar = V::Scalar::zero();
778        let mut n_alpha: usize = 0;
779
780        while s {
781            let u_run_1: V::Scalar = self.rng.random();
782            let v = (2 * (u_run_1 < V::Scalar::from_f64(0.5).unwrap()) as i8) - 1;
783
784            let (position_prime, n_prime, s_prime) = if v == -1 {
785                let (
786                    position_minus_2,
787                    mom_minus_2,
788                    grad_minus_2,
789                    _,
790                    _,
791                    _,
792                    position_prime_2,
793                    _,
794                    _,
795                    n_prime_2,
796                    s_prime_2,
797                    alpha_2,
798                    n_alpha_2,
799                ) = build_tree_with_mass(
800                    position_minus.clone(),
801                    mom_minus.clone(),
802                    grad_minus.clone(),
803                    logu,
804                    v,
805                    j,
806                    self.epsilon,
807                    self.target.as_ref(),
808                    &self.mass_matrix,
809                    joint,
810                    &mut self.rng,
811                );
812
813                position_minus = position_minus_2;
814                mom_minus = mom_minus_2;
815                grad_minus = grad_minus_2;
816
817                alpha = alpha_2;
818                n_alpha = n_alpha_2;
819                (position_prime_2, n_prime_2, s_prime_2)
820            } else {
821                let (
822                    _,
823                    _,
824                    _,
825                    position_plus_2,
826                    mom_plus_2,
827                    grad_plus_2,
828                    position_prime_2,
829                    _,
830                    _,
831                    n_prime_2,
832                    s_prime_2,
833                    alpha_2,
834                    n_alpha_2,
835                ) = build_tree_with_mass(
836                    position_plus.clone(),
837                    mom_plus.clone(),
838                    grad_plus.clone(),
839                    logu,
840                    v,
841                    j,
842                    self.epsilon,
843                    self.target.as_ref(),
844                    &self.mass_matrix,
845                    joint,
846                    &mut self.rng,
847                );
848
849                position_plus = position_plus_2;
850                mom_plus = mom_plus_2;
851                grad_plus = grad_plus_2;
852
853                alpha = alpha_2;
854                n_alpha = n_alpha_2;
855                (position_prime_2, n_prime_2, s_prime_2)
856            };
857
858            let tmp = V::Scalar::one().min(
859                V::Scalar::from_usize(n_prime)
860                    .expect("successful conversion of n_prime from usize")
861                    / V::Scalar::from_usize(n).expect("successful conversion of n from usize"),
862            );
863            let u_run_2: V::Scalar = self.rng.random();
864            if s_prime && (u_run_2 < tmp) {
865                self.position = position_prime;
866            }
867            n += n_prime;
868
869            s = s_prime
870                && stop_criterion_with_mass(
871                    position_minus.clone(),
872                    position_plus.clone(),
873                    mom_minus.clone(),
874                    mom_plus.clone(),
875                    &self.mass_matrix,
876                );
877            j += 1
878        }
879
880        let mut eta = V::Scalar::one()
881            / V::Scalar::from_usize(self.m + self.t_0).expect("successful conversion of m + t_0");
882        self.h_bar = (V::Scalar::one() - eta) * self.h_bar
883            + eta
884                * (self.target_accept_p
885                    - alpha
886                        / V::Scalar::from_usize(n_alpha)
887                            .expect("successful conversion of n_alpha"));
888        if self.m <= self.n_discard {
889            let m = V::Scalar::from_usize(self.m).expect("successful conversion of m");
890            self.epsilon = (self.mu - m.sqrt() / self.gamma * self.h_bar).exp();
891            eta = m.powf(-self.kappa);
892            self.epsilon_bar =
893                ((V::Scalar::one() - eta) * self.epsilon_bar.ln() + eta * self.epsilon.ln()).exp();
894
895            if let Some(warmup) = self.mass_warmup.as_mut()
896                && warmup.should_collect(self.m, self.n_discard)
897            {
898                let mut q = vec![V::Scalar::zero(); dim];
899                self.position.write_to_slice(&mut q);
900                warmup.running.update(&q);
901                if warmup.note_if_window_end(self.m, self.n_discard)
902                    && let Some(updated) = maybe_update_mass_matrix(&self.mass_matrix, warmup)
903                {
904                    self.mass_matrix = updated;
905                    let mut probe = self.position.zeros_like();
906                    let mut probe_buf = vec![V::Scalar::zero(); dim];
907                    self.mass_matrix
908                        .sample_momentum(&mut self.rng, &mut probe_buf);
909                    probe.read_from_slice(&probe_buf);
910                    self.epsilon =
911                        find_reasonable_epsilon(&self.position, &probe, self.target.as_ref());
912                    self.mu = (V::Scalar::from_f64(10.0).unwrap() * self.epsilon).ln();
913                    self.epsilon_bar = self.epsilon;
914                    self.h_bar = V::Scalar::zero();
915                    warmup.running.reset();
916                }
917            }
918        } else {
919            self.epsilon = self.epsilon_bar;
920        }
921    }
922}
923
924fn kinetic_energy<V: EuclideanVector>(mass: &MassMatrix<V::Scalar>, mom: &V) -> V::Scalar
925where
926    V::Scalar: Float + FromPrimitive,
927{
928    let mut p = vec![V::Scalar::zero(); mom.len()];
929    mom.write_to_slice(&mut p);
930    mass.kinetic(&p)
931}
932
933fn apply_inv_mass<V: EuclideanVector>(mass: &MassMatrix<V::Scalar>, input: &V, out: &mut V)
934where
935    V::Scalar: Float + FromPrimitive,
936{
937    let mut p = vec![V::Scalar::zero(); input.len()];
938    let mut v = vec![V::Scalar::zero(); input.len()];
939    input.write_to_slice(&mut p);
940    mass.inv_mul(&p, &mut v);
941    out.read_from_slice(&v);
942}
943
944fn maybe_update_mass_matrix<S: Float + FromPrimitive>(
945    current: &MassMatrix<S>,
946    warmup: &MassMatrixWarmup<S>,
947) -> Option<MassMatrix<S>> {
948    let n = warmup.running.n;
949    if n < 5 {
950        return None;
951    }
952    let n_denom = S::from_usize(n - 1).unwrap();
953    let reg = S::from_f64(warmup.config.regularize).unwrap();
954    let one_minus_reg = S::one() - reg;
955    let jitter = S::from_f64(warmup.config.jitter.max(1e-10)).unwrap();
956    match warmup.config.adaptation {
957        MassMatrixAdaptation::None => None,
958        MassMatrixAdaptation::Diagonal => {
959            let mut var = vec![S::zero(); warmup.running.dim];
960            for (i, vi) in var.iter_mut().enumerate().take(warmup.running.dim) {
961                let raw = warmup.running.m2_diag[i] / n_denom;
962                *vi = (one_minus_reg * raw + reg).max(jitter);
963            }
964            Some(MassMatrix::diagonal_from_var(var, jitter))
965        }
966        MassMatrixAdaptation::Dense => {
967            let dim = warmup.running.dim;
968            let Some(m2_dense) = warmup.running.m2_dense.as_ref() else {
969                return None;
970            };
971            let mut cov = vec![S::zero(); dim * dim];
972            for i in 0..dim {
973                for j in i..dim {
974                    let idx = i * dim + j;
975                    let raw = m2_dense[idx] / n_denom;
976                    let v = if i == j {
977                        (one_minus_reg * raw + reg).max(jitter)
978                    } else {
979                        one_minus_reg * raw
980                    };
981                    cov[idx] = v;
982                    cov[j * dim + i] = v;
983                }
984            }
985            MassMatrix::dense_from_cov(cov, dim, jitter).or_else(|| match current {
986                MassMatrix::Diagonal { .. } | MassMatrix::Dense { .. } => None,
987                MassMatrix::Identity { dim } => {
988                    Some(MassMatrix::diagonal_from_var(vec![S::one(); *dim], jitter))
989                }
990            })
991        }
992    }
993}
994
995fn all_real_vec<V: EuclideanVector>(v: &V) -> bool
996where
997    V::Scalar: Float,
998{
999    let mut scratch = vec![V::Scalar::zero(); v.len()];
1000    v.write_to_slice(&mut scratch);
1001    scratch.iter().all(|x: &V::Scalar| x.is_finite())
1002}
1003
1004#[allow(dead_code)]
1005pub(crate) fn find_reasonable_epsilon<V, Target>(
1006    position: &V,
1007    mom: &V,
1008    gradient_target: &Target,
1009) -> V::Scalar
1010where
1011    V: EuclideanVector,
1012    V::Scalar: Float + FromPrimitive,
1013    Target: HamiltonianTarget<V> + Sync,
1014    StandardNormal: RandDistribution<V::Scalar>,
1015    StandardUniform: RandDistribution<V::Scalar>,
1016{
1017    let mass_matrix = MassMatrix::identity(position.len());
1018    find_reasonable_epsilon_with_mass(position, mom, gradient_target, &mass_matrix)
1019}
1020
1021fn find_reasonable_epsilon_with_mass<V, Target>(
1022    position: &V,
1023    mom: &V,
1024    gradient_target: &Target,
1025    mass_matrix: &MassMatrix<V::Scalar>,
1026) -> V::Scalar
1027where
1028    V: EuclideanVector,
1029    V::Scalar: Float + FromPrimitive,
1030    Target: HamiltonianTarget<V> + Sync,
1031    StandardNormal: RandDistribution<V::Scalar>,
1032    StandardUniform: RandDistribution<V::Scalar>,
1033{
1034    let mut epsilon = V::Scalar::one();
1035    let half = V::Scalar::from_f64(0.5).unwrap();
1036
1037    let mut grad = position.zeros_like();
1038    let ulogp = gradient_target.logp_and_grad(position, &mut grad);
1039
1040    let mut position_prime = position.clone();
1041    let mut mom_prime = mom.clone();
1042    let mut grad_prime = grad.clone();
1043    let mut ulogp_prime = leapfrog_with_mass(
1044        &mut position_prime,
1045        &mut mom_prime,
1046        &mut grad_prime,
1047        epsilon,
1048        gradient_target,
1049        mass_matrix,
1050    );
1051    let mut k = V::Scalar::one();
1052
1053    while !ulogp_prime.is_finite() || !all_real_vec(&grad_prime) {
1054        k = k * half;
1055        position_prime.assign(position);
1056        mom_prime.assign(mom);
1057        grad_prime.assign(&grad);
1058        ulogp_prime = leapfrog_with_mass(
1059            &mut position_prime,
1060            &mut mom_prime,
1061            &mut grad_prime,
1062            epsilon * k,
1063            gradient_target,
1064            mass_matrix,
1065        );
1066    }
1067
1068    epsilon = half * k * epsilon;
1069    let log_accept_prob = ulogp_prime
1070        - ulogp
1071        - (kinetic_energy(mass_matrix, &mom_prime) - kinetic_energy(mass_matrix, mom));
1072    let mut log_accept_prob = log_accept_prob;
1073
1074    let a = if log_accept_prob > half.ln() {
1075        V::Scalar::one()
1076    } else {
1077        -V::Scalar::one()
1078    };
1079
1080    while a * log_accept_prob > -a * V::Scalar::from_f64(2.0).unwrap().ln() {
1081        epsilon = epsilon * V::Scalar::from_f64(2.0).unwrap().powf(a);
1082        position_prime.assign(position);
1083        mom_prime.assign(mom);
1084        grad_prime.assign(&grad);
1085        ulogp_prime = leapfrog_with_mass(
1086            &mut position_prime,
1087            &mut mom_prime,
1088            &mut grad_prime,
1089            epsilon,
1090            gradient_target,
1091            mass_matrix,
1092        );
1093        log_accept_prob = ulogp_prime
1094            - ulogp
1095            - (kinetic_energy(mass_matrix, &mom_prime) - kinetic_energy(mass_matrix, mom));
1096    }
1097
1098    epsilon
1099}
1100
1101#[allow(clippy::too_many_arguments, clippy::type_complexity)]
1102fn build_tree_with_mass<V, Target>(
1103    position: V,
1104    mom: V,
1105    grad: V,
1106    logu: V::Scalar,
1107    v: i8,
1108    j: usize,
1109    epsilon: V::Scalar,
1110    gradient_target: &Target,
1111    mass_matrix: &MassMatrix<V::Scalar>,
1112    joint_0: V::Scalar,
1113    rng: &mut SmallRng,
1114) -> (
1115    V,
1116    V,
1117    V,
1118    V,
1119    V,
1120    V,
1121    V,
1122    V,
1123    V::Scalar,
1124    usize,
1125    bool,
1126    V::Scalar,
1127    usize,
1128)
1129where
1130    V: EuclideanVector,
1131    V::Scalar: Float + FromPrimitive,
1132    Target: HamiltonianTarget<V> + Sync,
1133{
1134    if j == 0 {
1135        let mut position_prime = position.clone();
1136        let mut mom_prime = mom.clone();
1137        let mut grad_prime = grad.clone();
1138        let logp_prime = leapfrog_with_mass(
1139            &mut position_prime,
1140            &mut mom_prime,
1141            &mut grad_prime,
1142            V::Scalar::from_i64(v as i64).unwrap() * epsilon,
1143            gradient_target,
1144            mass_matrix,
1145        );
1146        let joint = logp_prime - kinetic_energy(mass_matrix, &mom_prime);
1147        let n_prime = (logu < joint) as usize;
1148        let s_prime = (logu - V::Scalar::from_f64(1000.0).unwrap()) < joint;
1149        let position_minus = position_prime.clone();
1150        let position_plus = position_prime.clone();
1151        let mom_minus = mom_prime.clone();
1152        let mom_plus = mom_prime.clone();
1153        let grad_minus = grad_prime.clone();
1154        let grad_plus = grad_prime.clone();
1155        let alpha_prime = V::Scalar::one().min((joint - joint_0).exp());
1156        let n_alpha_prime = 1_usize;
1157        (
1158            position_minus,
1159            mom_minus,
1160            grad_minus,
1161            position_plus,
1162            mom_plus,
1163            grad_plus,
1164            position_prime,
1165            grad_prime,
1166            logp_prime,
1167            n_prime,
1168            s_prime,
1169            alpha_prime,
1170            n_alpha_prime,
1171        )
1172    } else {
1173        let (
1174            mut position_minus,
1175            mut mom_minus,
1176            mut grad_minus,
1177            mut position_plus,
1178            mut mom_plus,
1179            mut grad_plus,
1180            mut position_prime,
1181            mut grad_prime,
1182            mut logp_prime,
1183            mut n_prime,
1184            mut s_prime,
1185            mut alpha_prime,
1186            mut n_alpha_prime,
1187        ) = build_tree_with_mass(
1188            position,
1189            mom,
1190            grad,
1191            logu,
1192            v,
1193            j - 1,
1194            epsilon,
1195            gradient_target,
1196            mass_matrix,
1197            joint_0,
1198            rng,
1199        );
1200        if s_prime {
1201            let (
1202                position_minus_2,
1203                mom_minus_2,
1204                grad_minus_2,
1205                position_plus_2,
1206                mom_plus_2,
1207                grad_plus_2,
1208                position_prime_2,
1209                grad_prime_2,
1210                logp_prime_2,
1211                n_prime_2,
1212                s_prime_2,
1213                alpha_prime_2,
1214                n_alpha_prime_2,
1215            ) = if v == -1 {
1216                build_tree_with_mass(
1217                    position_minus.clone(),
1218                    mom_minus.clone(),
1219                    grad_minus.clone(),
1220                    logu,
1221                    v,
1222                    j - 1,
1223                    epsilon,
1224                    gradient_target,
1225                    mass_matrix,
1226                    joint_0,
1227                    rng,
1228                )
1229            } else {
1230                build_tree_with_mass(
1231                    position_plus.clone(),
1232                    mom_plus.clone(),
1233                    grad_plus.clone(),
1234                    logu,
1235                    v,
1236                    j - 1,
1237                    epsilon,
1238                    gradient_target,
1239                    mass_matrix,
1240                    joint_0,
1241                    rng,
1242                )
1243            };
1244            if v == -1 {
1245                position_minus = position_minus_2;
1246                mom_minus = mom_minus_2;
1247                grad_minus = grad_minus_2;
1248            } else {
1249                position_plus = position_plus_2;
1250                mom_plus = mom_plus_2;
1251                grad_plus = grad_plus_2;
1252            }
1253
1254            let u_build_tree: f64 = rng.random();
1255            if u_build_tree < (n_prime_2 as f64 / (n_prime + n_prime_2).max(1) as f64) {
1256                position_prime = position_prime_2;
1257                grad_prime = grad_prime_2;
1258                logp_prime = logp_prime_2;
1259            }
1260
1261            n_prime += n_prime_2;
1262
1263            s_prime = s_prime
1264                && s_prime_2
1265                && stop_criterion(
1266                    position_minus.clone(),
1267                    position_plus.clone(),
1268                    mom_minus.clone(),
1269                    mom_plus.clone(),
1270                );
1271            alpha_prime = alpha_prime + alpha_prime_2;
1272            n_alpha_prime += n_alpha_prime_2;
1273        }
1274        (
1275            position_minus,
1276            mom_minus,
1277            grad_minus,
1278            position_plus,
1279            mom_plus,
1280            grad_plus,
1281            position_prime,
1282            grad_prime,
1283            logp_prime,
1284            n_prime,
1285            s_prime,
1286            alpha_prime,
1287            n_alpha_prime,
1288        )
1289    }
1290}
1291
1292pub(crate) fn stop_criterion<V>(
1293    position_minus: V,
1294    position_plus: V,
1295    mom_minus: V,
1296    mom_plus: V,
1297) -> bool
1298where
1299    V: EuclideanVector,
1300    V::Scalar: Float + FromPrimitive,
1301{
1302    let mass_matrix = MassMatrix::identity(position_minus.len());
1303    stop_criterion_with_mass(
1304        position_minus,
1305        position_plus,
1306        mom_minus,
1307        mom_plus,
1308        &mass_matrix,
1309    )
1310}
1311
1312fn stop_criterion_with_mass<V>(
1313    position_minus: V,
1314    position_plus: V,
1315    mom_minus: V,
1316    mom_plus: V,
1317    mass_matrix: &MassMatrix<V::Scalar>,
1318) -> bool
1319where
1320    V: EuclideanVector,
1321    V::Scalar: Float + FromPrimitive,
1322{
1323    // Use proper subtraction to match original Tensor semantics
1324    let mut diff = position_plus.clone();
1325    diff.sub_assign(&position_minus);
1326    let mut vel_minus = mom_minus.zeros_like();
1327    let mut vel_plus = mom_plus.zeros_like();
1328    apply_inv_mass(mass_matrix, &mom_minus, &mut vel_minus);
1329    apply_inv_mass(mass_matrix, &mom_plus, &mut vel_plus);
1330    let dot_minus = diff.dot(&vel_minus);
1331    let dot_plus = diff.dot(&vel_plus);
1332    dot_minus >= V::Scalar::zero() && dot_plus >= V::Scalar::zero()
1333}
1334
1335fn leapfrog_with_mass<V, Target>(
1336    position: &mut V,
1337    momentum: &mut V,
1338    grad: &mut V,
1339    epsilon: V::Scalar,
1340    gradient_target: &Target,
1341    mass_matrix: &MassMatrix<V::Scalar>,
1342) -> V::Scalar
1343where
1344    V: EuclideanVector,
1345    V::Scalar: Float + FromPrimitive,
1346    Target: HamiltonianTarget<V>,
1347{
1348    // Match original operation order: grad * epsilon * 0.5 (not grad * (0.5 * epsilon))
1349    let half = V::Scalar::from_f64(0.5).unwrap();
1350    momentum.add_scaled_assign(grad, epsilon * half);
1351    let mut velocity = momentum.zeros_like();
1352    apply_inv_mass(mass_matrix, momentum, &mut velocity);
1353    position.add_scaled_assign(&velocity, epsilon);
1354    let logp = gradient_target.logp_and_grad(position, grad);
1355    momentum.add_scaled_assign(grad, epsilon * half);
1356    logp
1357}
1358
#[cfg(test)]
mod tests {
    use super::{
        MassMatrix, MassMatrixAdaptation, MassMatrixWarmup, NUTSMassMatrixConfig,
        maybe_update_mass_matrix,
    };

    /// Absolute tolerance for the closed-form comparisons below.
    const TOL: f64 = 1e-12;

    /// A diagonal mass matrix built from variances (4, 9) must give the
    /// hand-computed kinetic energy and `inv_mul` output for momentum (2, 3).
    #[test]
    fn diagonal_mass_matrix_kinetic_and_inv_mul_are_consistent() {
        let momentum = [2.0_f64, 3.0_f64];
        let mass = MassMatrix::diagonal_from_var(vec![4.0_f64, 9.0_f64], 1e-12);

        // 0.5 * (2^2/4 + 3^2/9) = 1.0
        let energy = mass.kinetic(&momentum);
        assert!((energy - 1.0).abs() < TOL);

        let mut velocity = [0.0_f64; 2];
        mass.inv_mul(&momentum, &mut velocity);
        assert!((velocity[0] - 0.5).abs() < TOL);
        assert!((velocity[1] - (1.0 / 3.0)).abs() < TOL);
    }

    /// For a dense matrix built from an SPD covariance, the quadratic form
    /// `p' * inv_mul(p)` must be strictly positive.
    #[test]
    fn dense_mass_matrix_inverse_matches_identity_action() {
        let covariance = vec![
            2.0_f64, 0.3_f64, //
            0.3_f64, 1.0_f64,
        ];
        let mass = MassMatrix::dense_from_cov(covariance, 2, 1e-12).expect("dense mass matrix");

        let momentum = [0.7_f64, -1.1_f64];
        let mut velocity = [0.0_f64; 2];
        mass.inv_mul(&momentum, &mut velocity);

        let quadratic_form = momentum[0] * velocity[0] + momentum[1] * velocity[1];
        assert!(quadratic_form > 0.0);
    }

    /// Five samples are enough for a diagonal update, and every entry of the
    /// resulting matrix must be finite and strictly positive.
    #[test]
    fn warmup_diagonal_update_produces_positive_variances() {
        let config = NUTSMassMatrixConfig {
            adaptation: MassMatrixAdaptation::Diagonal,
            start_buffer: 1,
            end_buffer: 1,
            initial_window: 4,
            regularize: 0.05,
            jitter: 1e-6,
            dense_max_dim: 75,
        };
        let mut warmup = MassMatrixWarmup::new(2, config, false);
        let current = MassMatrix::identity(2);

        let samples = [
            [-2.0_f64, 1.0_f64],
            [-1.0, 0.0],
            [0.0, 1.0],
            [2.0, -1.0],
            [1.0, 0.5],
        ];
        for sample in &samples {
            warmup.running.update(sample);
        }

        let updated = maybe_update_mass_matrix(&current, &warmup).expect("updated mass");
        let MassMatrix::Diagonal { inv, sqrt } = updated else {
            panic!("expected diagonal mass matrix");
        };
        for i in 0..2 {
            assert!(inv[i].is_finite() && inv[i] > 0.0);
            assert!(sqrt[i].is_finite() && sqrt[i] > 0.0);
        }
    }
}