Skip to main content

nuts_rs/dynamics/
transformed_hamiltonian.rs

1//! Concrete Hamiltonian that runs leapfrog in a whitened space to improve sampling geometry.
2//!
3//! Three trajectory kinds are supported via [`KineticEnergyKind`]:
4//! - [`KineticEnergyKind::Euclidean`]: standard leapfrog with Euclidean kinetic energy.
5//! - [`KineticEnergyKind::ExactNormal`]: geodesic leapfrog that is exact for a standard-normal
6//!   potential (position and velocity rotate together in each 2-D plane).
7//! - [`KineticEnergyKind::Microcanonical`]: isokinetic ESH-dynamics leapfrog (microcanonical
8//!   HMC). The momentum is constrained to the unit sphere and the ESH update keeps it there
9//!   while tracking the accumulated kinetic-energy change along the trajectory.
10
11use std::{fmt::Debug, marker::PhantomData, sync::Arc};
12
13use nuts_derive::Storable;
14use nuts_storable::HasDims;
15use serde::{Deserialize, Serialize};
16
17use crate::{
18    DivergenceInfo, LogpError, Math, NutsError,
19    dynamics::{Direction, Hamiltonian, LeapfrogResult, Point},
20    dynamics::{State, StatePool},
21    sampler_stats::{SamplerStats, StatsDims},
22    transform::{ExternalTransformation, Transformation},
23};
24
25/// Selects the kinetic-energy form (and thus the integrator) used by
26/// [`TransformedHamiltonian`].
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
28pub enum KineticEnergyKind {
29    /// Standard Euclidean kinetic energy `K = ½ ‖v‖²`.
30    /// Uses the ordinary leapfrog integrator (velocity Verlet).
31    #[default]
32    Euclidean,
33
34    /// Geodesic leapfrog that is *exact* for a standard-normal potential.
35    /// Position and velocity are rotated together in each 2-D plane `(q_i, v_i)`.
36    ExactNormal,
37
38    /// Microcanonical / isokinetic HMC using ESH dynamics
39    /// ([Steeg & Gallagher 2021](https://arxiv.org/abs/2111.02434),
40    /// ported from the [BlackJAX implementation](https://github.com/blackjax-devs/blackjax/blob/main/blackjax/mcmc/integrators.py#L314)).
41    ///
42    /// The momentum is constrained to the unit sphere (`‖v‖ = 1`).
43    /// The momentum update uses the ESH formula which preserves `‖v‖ = 1` exactly
44    /// while tracking the cumulative kinetic-energy change needed for the
45    /// Metropolis accept/reject decision.
46    ///
47    /// No partial momentum refreshment is performed — this variant only implements
48    /// the deterministic ESH trajectory.
49    Microcanonical,
50}
51
// ---------------------------------------------------------------------------
// Phase-space point (position, gradient and velocity in both coordinate systems)
// ---------------------------------------------------------------------------
55
56pub struct TransformedPoint<M: Math> {
57    pub(crate) untransformed_position: M::Vector,
58    pub(crate) untransformed_gradient: M::Vector,
59    pub(crate) transformed_position: M::Vector,
60    pub(crate) transformed_gradient: M::Vector,
61    pub(crate) velocity: M::Vector,
62    index_in_trajectory: i64,
63    logp: f64,
64    logdet: f64,
65    /// For Euclidean / ExactNormal: `½ ‖v‖²`.
66    /// For Microcanonical: the accumulated kinetic-energy change ΔKE along the
67    /// current leapfrog step (carried through `kinetic_energy` between the two
68    /// half-steps, then fixed after the second half-step).
69    kinetic_energy: f64,
70    initial_energy: f64,
71    transform_id: i64,
72    /// The step size factor used by the leapfrog step that produced this point.
73    /// For NUTS and static MCLMC this is always `1.0`; for MCLMC with dynamic
74    /// step size reduction it may be `< 1.0`.  Used to compute importance
75    /// weights: `log_weight = log(step_size_factor) - energy_error`.
76    pub(crate) step_size_factor: f64,
77}
78
79impl<M: Math> Debug for TransformedPoint<M> {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        f.debug_struct("TransformedPoint")
82            .field("untransformed_position", &self.untransformed_position)
83            .field("untransformed_gradient", &self.untransformed_gradient)
84            .field("transformed_position", &self.transformed_position)
85            .field("transformed_gradient", &self.transformed_gradient)
86            .field("velocity", &self.velocity)
87            .field("index_in_trajectory", &self.index_in_trajectory)
88            .field("logp", &self.logp)
89            .field("logdet", &self.logdet)
90            .field("kinetic_energy", &self.kinetic_energy)
91            .field("transform_id", &self.transform_id)
92            .finish()
93    }
94}
95
96#[derive(Debug, Storable)]
97pub struct PointStats {
98    pub index_in_trajectory: i64,
99    pub logp: f64,
100    pub energy: f64,
101    pub energy_error: f64,
102    #[storable(dims("unconstrained_parameter"))]
103    pub unconstrained_draw: Option<Vec<f64>>,
104    #[storable(dims("unconstrained_parameter"))]
105    pub gradient: Option<Vec<f64>>,
106    pub fisher_distance: f64,
107    #[storable(dims("unconstrained_parameter"))]
108    pub transformed_position: Option<Vec<f64>>,
109    #[storable(dims("unconstrained_parameter"))]
110    pub transformed_gradient: Option<Vec<f64>>,
111    pub transformation_index: i64,
112}
113
114#[derive(Debug, Clone, Copy)]
115pub struct TransformedPointStatsOptions {
116    pub store_gradient: bool,
117    pub store_unconstrained: bool,
118    pub store_transformed: bool,
119}
120
121impl<M: Math> SamplerStats<M> for TransformedPoint<M> {
122    type Stats = PointStats;
123    type StatsOptions = TransformedPointStatsOptions;
124
125    fn extract_stats(&self, math: &mut M, opt: Self::StatsOptions) -> Self::Stats {
126        let unconstrained_draw = if opt.store_unconstrained {
127            Some(math.box_array(&self.untransformed_position).into_vec())
128        } else {
129            None
130        };
131        let gradient = if opt.store_gradient {
132            Some(math.box_array(&self.untransformed_gradient).into_vec())
133        } else {
134            None
135        };
136        let mut transformed_position = None;
137        let mut transformed_gradient = None;
138        if opt.store_transformed {
139            transformed_position = Some(math.box_array(&self.transformed_position));
140            transformed_gradient = Some(math.box_array(&self.transformed_gradient));
141        }
142        let fisher_distance =
143            math.sq_norm_sum(&self.transformed_position, &self.transformed_gradient);
144        PointStats {
145            index_in_trajectory: self.index_in_trajectory,
146            logp: self.logp,
147            energy: self.energy(),
148            energy_error: self.energy_error(),
149            unconstrained_draw,
150            gradient,
151            fisher_distance,
152            transformation_index: self.transform_id,
153            transformed_gradient: transformed_gradient.map(|x| x.into_vec()),
154            transformed_position: transformed_position.map(|x| x.into_vec()),
155        }
156    }
157}
158
159impl<M: Math> TransformedPoint<M> {
160    /// First velocity half-step.
161    fn first_velocity_halfstep(
162        &self,
163        math: &mut M,
164        out: &mut Self,
165        epsilon: f64,
166        kind: KineticEnergyKind,
167    ) {
168        match kind {
169            KineticEnergyKind::ExactNormal => {
170                math.std_norm_grad_flow(
171                    &self.transformed_position,
172                    &self.transformed_gradient,
173                    &self.velocity,
174                    &mut out.velocity,
175                    epsilon / 2.,
176                );
177            }
178            KineticEnergyKind::Euclidean => {
179                math.axpy_out(
180                    &self.transformed_gradient,
181                    &self.velocity,
182                    epsilon / 2.,
183                    &mut out.velocity,
184                );
185            }
186            KineticEnergyKind::Microcanonical => {
187                // TODO this is an extra copy we could get rid of
188                math.copy_into(&self.velocity, &mut out.velocity);
189                let ndim = math.dim();
190                out.kinetic_energy = self.kinetic_energy
191                    + math.esh_momentum_update(
192                        &self.transformed_gradient,
193                        &mut out.velocity,
194                        // Make the step sizes comparable
195                        (ndim as f64).sqrt() * epsilon / 2.,
196                    );
197            }
198        }
199    }
200
201    /// Position (and, for geodesic integrators, simultaneous velocity) step.
202    fn position_step(&self, math: &mut M, out: &mut Self, epsilon: f64, kind: KineticEnergyKind) {
203        match kind {
204            //   q' =  q cos ε + v sin ε
205            //   v' = −q sin ε + v cos ε
206            KineticEnergyKind::ExactNormal => {
207                math.std_norm_flow(
208                    &self.transformed_position,
209                    &mut out.transformed_position,
210                    &mut out.velocity,
211                    epsilon,
212                );
213            }
214            KineticEnergyKind::Euclidean | KineticEnergyKind::Microcanonical => {
215                let epsilon = if matches!(kind, KineticEnergyKind::Microcanonical) {
216                    epsilon * (math.dim() as f64).sqrt()
217                } else {
218                    epsilon
219                };
220                math.axpy_out(
221                    &out.velocity,
222                    &self.transformed_position,
223                    epsilon,
224                    &mut out.transformed_position,
225                );
226            }
227        }
228    }
229
230    /// Second velocity half-step.
231    ///
232    /// `accumulated_delta_ke` is the ΔKE from the first half-step (only used for
233    /// Microcanonical; ignored for other variants).  After this call `self.kinetic_energy`
234    /// holds the final value appropriate for `energy()`.
235    fn second_velocity_halfstep(&mut self, math: &mut M, epsilon: f64, kind: KineticEnergyKind) {
236        match kind {
237            KineticEnergyKind::ExactNormal => {
238                math.std_norm_grad_flow_inplace(
239                    &self.transformed_position,
240                    &self.transformed_gradient,
241                    &mut self.velocity,
242                    epsilon / 2.,
243                );
244            }
245            KineticEnergyKind::Euclidean => {
246                math.axpy(&self.transformed_gradient, &mut self.velocity, epsilon / 2.);
247            }
248            KineticEnergyKind::Microcanonical => {
249                let ndim = math.dim();
250                self.kinetic_energy = self.kinetic_energy
251                    + math.esh_momentum_update(
252                        &self.transformed_gradient,
253                        &mut self.velocity,
254                        (ndim as f64).sqrt() * epsilon / 2.,
255                    );
256            }
257        }
258    }
259
260    fn update_kinetic_energy(&mut self, math: &mut M) {
261        self.kinetic_energy = 0.5 * math.array_vector_dot(&self.velocity, &self.velocity);
262    }
263
264    fn init_from_untransformed_position<T: Transformation<M>>(
265        &mut self,
266        transformation: &T,
267        math: &mut M,
268    ) -> Result<(), M::LogpErr> {
269        let (logp, logdet) = transformation.init_from_untransformed_position(
270            math,
271            &self.untransformed_position,
272            &mut self.untransformed_gradient,
273            &mut self.transformed_position,
274            &mut self.transformed_gradient,
275        )?;
276        self.logp = logp;
277        self.logdet = logdet;
278        self.transform_id = transformation.transformation_id(math);
279        Ok(())
280    }
281
282    fn init_from_transformed_position<T: Transformation<M>>(
283        &mut self,
284        transformation: &T,
285        math: &mut M,
286    ) -> Result<(), M::LogpErr> {
287        let (logp, logdet) = transformation.init_from_transformed_position(
288            math,
289            &mut self.untransformed_position,
290            &mut self.untransformed_gradient,
291            &self.transformed_position,
292            &mut self.transformed_gradient,
293        )?;
294        self.logp = logp;
295        self.logdet = logdet;
296        self.transform_id = transformation.transformation_id(math);
297        Ok(())
298    }
299
300    fn check_untransformed(&self, math: &mut M) -> bool {
301        if !math.array_all_finite(&self.untransformed_gradient) {
302            return false;
303        }
304        if !math.array_all_finite(&self.untransformed_position) {
305            return false;
306        }
307        true
308    }
309
310    fn check_all(&self, math: &mut M) -> bool {
311        if !math.array_all_finite(&self.transformed_position) {
312            return false;
313        }
314        if !math.array_all_finite_and_nonzero(&self.transformed_gradient) {
315            return false;
316        }
317        if !math.array_all_finite(&self.untransformed_gradient) {
318            return false;
319        }
320        if !math.array_all_finite(&self.untransformed_position) {
321            return false;
322        }
323        true
324    }
325}
326
327impl<M: Math> Point<M> for TransformedPoint<M> {
328    fn position(&self) -> &<M as Math>::Vector {
329        &self.untransformed_position
330    }
331
332    fn gradient(&self) -> &<M as Math>::Vector {
333        &self.untransformed_gradient
334    }
335
336    fn index_in_trajectory(&self) -> i64 {
337        self.index_in_trajectory
338    }
339
340    /// The Hamiltonian energy at this point.
341    ///
342    /// For Euclidean / ExactNormal:  `E = ½‖v‖² − (logp + logdet)`
343    /// For Microcanonical:           `E = ΔKE_accum − (logp + logdet)`
344    ///
345    /// In both cases `energy_error = energy − initial_energy` is used for
346    /// divergence detection.  The constant offset `−(n−1) log 2` present in
347    /// the ESH kinetic energy cancels in the difference and is therefore
348    /// omitted.
349    fn energy(&self) -> f64 {
350        self.kinetic_energy - (self.logp + self.logdet)
351    }
352
353    fn initial_energy(&self) -> f64 {
354        self.initial_energy
355    }
356
357    fn logp(&self) -> f64 {
358        self.logp
359    }
360
361    fn new(math: &mut M) -> Self {
362        Self {
363            untransformed_position: math.new_array(),
364            untransformed_gradient: math.new_array(),
365            transformed_position: math.new_array(),
366            transformed_gradient: math.new_array(),
367            velocity: math.new_array(),
368            index_in_trajectory: 0,
369            logp: 0f64,
370            logdet: 0f64,
371            kinetic_energy: 0f64,
372            transform_id: -1,
373            initial_energy: 0f64,
374            step_size_factor: 1.0,
375        }
376    }
377
378    fn copy_into(&self, math: &mut M, other: &mut Self) {
379        let Self {
380            untransformed_position,
381            untransformed_gradient,
382            transformed_position,
383            transformed_gradient,
384            velocity,
385            index_in_trajectory,
386            logp,
387            logdet,
388            kinetic_energy,
389            transform_id,
390            initial_energy,
391            step_size_factor,
392        } = self;
393
394        other.index_in_trajectory = *index_in_trajectory;
395        other.logp = *logp;
396        other.logdet = *logdet;
397        other.kinetic_energy = *kinetic_energy;
398        other.transform_id = *transform_id;
399        other.initial_energy = *initial_energy;
400        other.step_size_factor = *step_size_factor;
401        math.copy_into(untransformed_position, &mut other.untransformed_position);
402        math.copy_into(untransformed_gradient, &mut other.untransformed_gradient);
403        math.copy_into(transformed_position, &mut other.transformed_position);
404        math.copy_into(transformed_gradient, &mut other.transformed_gradient);
405        math.copy_into(velocity, &mut other.velocity);
406    }
407}
408
409pub struct TransformedHamiltonian<M: Math, T: Transformation<M>> {
410    ones: M::Vector,
411    zeros: M::Vector,
412    step_size: f64,
413    /// Momentum decoherence length `L` for the isokinetic Langevin refresh.
414    /// `None` disables the refresh (used by NUTS); `Some(L)` enables it (MCLMC).
415    momentum_decoherence_length: Option<f64>,
416    transformation: T,
417    pub kinetic_energy_kind: KineticEnergyKind,
418    pool: StatePool<M, TransformedPoint<M>>,
419}
420
421impl<M: Math, T: Transformation<M>> TransformedHamiltonian<M, T> {
422    pub fn new(math: &mut M, transformation: T, kinetic_energy_kind: KineticEnergyKind) -> Self {
423        let mut ones = math.new_array();
424        math.fill_array(&mut ones, 1f64);
425        let mut zeros = math.new_array();
426        math.fill_array(&mut zeros, 0f64);
427        let pool = StatePool::new(math, 10);
428        Self {
429            step_size: 0f64,
430            momentum_decoherence_length: None,
431            ones,
432            zeros,
433            transformation,
434            kinetic_energy_kind,
435            pool,
436        }
437    }
438
439    pub fn transformation(&self) -> &T {
440        &self.transformation
441    }
442
443    pub fn transformation_mut(&mut self) -> &mut T {
444        &mut self.transformation
445    }
446
447    pub fn set_momentum_decoherence_length(&mut self, l: Option<f64>) {
448        self.momentum_decoherence_length = l;
449    }
450
451    /// Change the kinetic-energy kind (and thus the leapfrog integrator and
452    /// momentum distribution) used by this Hamiltonian.
453    ///
454    /// When switching from [`KineticEnergyKind::Euclidean`] to
455    /// [`KineticEnergyKind::Microcanonical`] the caller is responsible for
456    /// reinitializing the state.
457    pub fn set_kinetic_energy_kind(&mut self, kind: KineticEnergyKind) {
458        self.kinetic_energy_kind = kind;
459    }
460}
461
462impl<M: Math> TransformedHamiltonian<M, ExternalTransformation<M>> {
463    pub fn init_transformation<R: rand::Rng + ?Sized>(
464        &mut self,
465        rng: &mut R,
466        math: &mut M,
467        position: &[f64],
468        chain: u64,
469    ) -> Result<(), NutsError> {
470        let mut gradient_array = math.new_array();
471        let mut position_array = math.new_array();
472        math.read_from_slice(&mut position_array, position);
473        let _ = math
474            .logp_array(&position_array, &mut gradient_array)
475            .map_err(|e| NutsError::BadInitGrad(Box::new(e)))?;
476        let mut params = math
477            .init_transformation(rng, &position_array, &gradient_array, chain)
478            .map_err(|e| NutsError::BadInitGrad(Box::new(e)))?;
479        std::mem::swap(self.transformation_mut().params_mut(), &mut params);
480        Ok(())
481    }
482
483    pub fn update_params<'a, R: rand::Rng + ?Sized>(
484        &'a mut self,
485        math: &'a mut M,
486        rng: &mut R,
487        draws: impl ExactSizeIterator<Item = &'a M::Vector>,
488        grads: impl ExactSizeIterator<Item = &'a M::Vector>,
489        logps: impl ExactSizeIterator<Item = &'a f64>,
490    ) -> Result<(), NutsError> {
491        let t = self.transformation_mut();
492        math.update_transformation(rng, draws, grads, logps, t.params_mut())
493            .map_err(|e| NutsError::BadInitGrad(Box::new(e)))?;
494        Ok(())
495    }
496}
497
498#[derive(Debug, Storable)]
499pub struct HamiltonianStats<P: HasDims, S: nuts_storable::Storable<P>> {
500    pub step_size: f64,
501    #[storable(flatten)]
502    pub transformation: S,
503    #[storable(ignore)]
504    _phantom: PhantomData<fn() -> P>,
505}
506
507impl<M: Math, T: Transformation<M>> SamplerStats<M> for TransformedHamiltonian<M, T> {
508    type Stats = HamiltonianStats<StatsDims, T::Stats>;
509    type StatsOptions = T::StatsOptions;
510
511    fn extract_stats(&self, math: &mut M, opt: Self::StatsOptions) -> Self::Stats {
512        let transformation_stats = self.transformation.extract_stats(math, opt);
513        HamiltonianStats {
514            step_size: self.step_size,
515            transformation: transformation_stats,
516            _phantom: PhantomData,
517        }
518    }
519}
520
521impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M, T> {
522    type Point = TransformedPoint<M>;
523
524    fn leapfrog<C: crate::nuts::Collector<M, Self::Point>>(
525        &mut self,
526        math: &mut M,
527        start: &State<M, Self::Point>,
528        dir: Direction,
529        step_size_factor: f64,
530        energy_baseline: f64,
531        max_energy_error: f64,
532        collector: &mut C,
533    ) -> LeapfrogResult<M, Self::Point> {
534        let mut out = self.pool().new_state(math);
535        let out_point = out.try_point_mut().expect("New point has other references");
536
537        out_point.initial_energy = start.point().initial_energy();
538        out_point.transform_id = start.point().transform_id;
539
540        let sign = match dir {
541            Direction::Forward => 1,
542            Direction::Backward => -1,
543        };
544
545        let epsilon = (sign as f64) * self.step_size * step_size_factor;
546        out_point.step_size_factor = step_size_factor;
547        let kind = self.kinetic_energy_kind;
548
549        // --- First velocity half-step ---
550        // For Microcanonical: out_point.kinetic_energy receives the running ΔKE
551        // after this call; for other kinds it is left at whatever value it had
552        // (it will be overwritten by update_kinetic_energy below).
553        start
554            .point()
555            .first_velocity_halfstep(math, out_point, epsilon, kind);
556
557        // --- Position step ---
558        start.point().position_step(math, out_point, epsilon, kind);
559
560        // --- Evaluate log-density at new position ---
561        let transformation = self.transformation();
562        if let Err(logp_error) = out_point.init_from_transformed_position(transformation, math) {
563            if !logp_error.is_recoverable() {
564                return LeapfrogResult::Err(logp_error);
565            }
566            let div_info = DivergenceInfo {
567                logp_function_error: Some(Arc::new(Box::new(logp_error))),
568                start_location: Some(math.box_array(start.point().position())),
569                start_gradient: Some(math.box_array(start.point().gradient())),
570                start_momentum: None,
571                end_location: None,
572                start_idx_in_trajectory: Some(start.point().index_in_trajectory()),
573                end_idx_in_trajectory: None,
574                energy_error: None,
575            };
576            collector.register_leapfrog(math, start, &out, Some(&div_info));
577            return LeapfrogResult::Divergence(div_info);
578        }
579
580        out_point.second_velocity_halfstep(math, epsilon, kind);
581
582        // For Microcanonical, kinetic_energy already holds the total accumulated ΔKE
583        // (set by second_velocity_halfstep). For other kinds we recompute from ½‖v‖².
584        if kind != KineticEnergyKind::Microcanonical {
585            out_point.update_kinetic_energy(math);
586        }
587
588        out_point.index_in_trajectory = start.index_in_trajectory() + sign;
589
590        let energy_error = out_point.energy() - energy_baseline;
591        let bad_energy = match self.kinetic_energy_kind {
592            KineticEnergyKind::Euclidean | KineticEnergyKind::ExactNormal => {
593                energy_error > max_energy_error
594            }
595            KineticEnergyKind::Microcanonical => energy_error.abs() >= max_energy_error,
596        };
597        if bad_energy | !energy_error.is_finite() {
598            let divergence_info = DivergenceInfo {
599                logp_function_error: None,
600                start_location: Some(math.box_array(start.point().position())),
601                start_gradient: Some(math.box_array(start.point().gradient())),
602                end_location: Some(math.box_array(out_point.position())),
603                start_momentum: None,
604                start_idx_in_trajectory: Some(start.index_in_trajectory()),
605                end_idx_in_trajectory: Some(out.index_in_trajectory()),
606                energy_error: Some(energy_error),
607            };
608            collector.register_leapfrog(math, start, &out, Some(&divergence_info));
609            return LeapfrogResult::Divergence(divergence_info);
610        }
611
612        collector.register_leapfrog(math, start, &out, None);
613
614        LeapfrogResult::Ok(out)
615    }
616
617    fn is_turning(
618        &self,
619        math: &mut M,
620        state1: &State<M, Self::Point>,
621        state2: &State<M, Self::Point>,
622    ) -> bool {
623        let (start, end) = if state1.index_in_trajectory() < state2.index_in_trajectory() {
624            (state1, state2)
625        } else {
626            (state2, state1)
627        };
628
629        let (turn1, turn2) = math.scalar_prods3(
630            &end.point().transformed_position,
631            &start.point().transformed_position,
632            &self.zeros,
633            &start.point().velocity,
634            &end.point().velocity,
635        );
636
637        (turn1 < 0f64) | (turn2 < 0f64)
638    }
639
640    fn init_state(
641        &mut self,
642        math: &mut M,
643        init: &[f64],
644    ) -> Result<State<M, Self::Point>, NutsError> {
645        let mut state = self.pool().new_state(math);
646        let point = state.try_point_mut().expect("State already in use");
647        math.read_from_slice(&mut point.untransformed_position, init);
648
649        let transformation = self.transformation();
650        point
651            .init_from_untransformed_position(transformation, math)
652            .map_err(|e| NutsError::LogpFailure(Box::new(e)))?;
653
654        if !point.check_all(math) {
655            Err(NutsError::BadInitGrad(
656                anyhow::anyhow!("Invalid initial point").into(),
657            ))
658        } else {
659            Ok(state)
660        }
661    }
662
663    fn init_state_untransformed(
664        &mut self,
665        math: &mut M,
666        untransformed_position: &[f64],
667    ) -> Result<State<M, Self::Point>, NutsError> {
668        let mut state = self.pool().new_state(math);
669        let point = state.try_point_mut().expect("State already in use");
670        math.read_from_slice(&mut point.untransformed_position, untransformed_position);
671        math.logp_array(
672            &point.untransformed_position,
673            &mut point.untransformed_gradient,
674        )
675        .map_err(|e| NutsError::LogpFailure(Box::new(e)))?;
676        // Force recomputation of transformed coordinates on first leapfrog step
677        point.transform_id = -1;
678        if !point.check_untransformed(math) {
679            Err(NutsError::BadInitGrad(
680                anyhow::anyhow!("Invalid initial point").into(),
681            ))
682        } else {
683            Ok(state)
684        }
685    }
686
687    fn initialize_trajectory<R: rand::Rng + ?Sized>(
688        &self,
689        math: &mut M,
690        state: &mut State<M, Self::Point>,
691        resample_velocity: bool,
692        rng: &mut R,
693    ) -> Result<(), NutsError> {
694        let point = state.try_point_mut().expect("State has other references");
695
696        if resample_velocity {
697            // Sample raw isotropic Gaussian momentum.
698            math.array_gaussian(rng, &mut point.velocity, &self.ones);
699
700            // For Microcanonical HMC the momentum must lie on the unit sphere.
701            if self.kinetic_energy_kind == KineticEnergyKind::Microcanonical {
702                math.array_normalize(&mut point.velocity);
703            }
704        }
705
706        let current_transform_id = self.transformation().transformation_id(math);
707        if current_transform_id != point.transform_id {
708            let logdet = self
709                .transformation()
710                .inv_transform_normalize(
711                    math,
712                    &point.untransformed_position,
713                    &point.untransformed_gradient,
714                    &mut point.transformed_position,
715                    &mut point.transformed_gradient,
716                )
717                .map_err(|e| NutsError::LogpFailure(Box::new(e)))?;
718            point.logdet = logdet;
719            point.transform_id = current_transform_id;
720        }
721
722        match self.kinetic_energy_kind {
723            KineticEnergyKind::Microcanonical => {
724                // Initial accumulated ΔKE is 0 (no steps taken yet).
725                // energy() = 0 − (logp + logdet) = −(logp + logdet).
726                point.kinetic_energy = 0.0;
727            }
728            _ => {
729                point.update_kinetic_energy(math);
730            }
731        }
732
733        point.index_in_trajectory = 0;
734        point.initial_energy = point.energy();
735        Ok(())
736    }
737
738    fn pool(&mut self) -> &mut StatePool<M, Self::Point> {
739        &mut self.pool
740    }
741
742    fn copy_state(&mut self, math: &mut M, state: &State<M, Self::Point>) -> State<M, Self::Point> {
743        let mut new_state = self.pool.new_state(math);
744        state.point().copy_into(
745            math,
746            new_state
747                .try_point_mut()
748                .expect("New point should not have other references"),
749        );
750        new_state
751    }
752
753    fn step_size(&self) -> f64 {
754        self.step_size
755    }
756
757    fn update_stats_options(
758        &mut self,
759        math: &mut M,
760        current: <Self as SamplerStats<M>>::StatsOptions,
761    ) -> <Self as SamplerStats<M>>::StatsOptions {
762        self.transformation.next_stats_options(math, current)
763    }
764
765    fn step_size_mut(&mut self) -> &mut f64 {
766        &mut self.step_size
767    }
768
769    fn momentum_decoherence_length(&self) -> Option<f64> {
770        self.momentum_decoherence_length
771    }
772
773    fn momentum_decoherence_length_mut(&mut self) -> Option<&mut f64> {
774        self.momentum_decoherence_length.as_mut()
775    }
776
777    fn partial_momentum_refresh<R: rand::Rng + ?Sized>(
778        &mut self,
779        math: &mut M,
780        state: &mut State<M, Self::Point>,
781        noise: &M::Vector,
782        _rng: &mut R,
783        factor: f64,
784    ) -> Result<(), NutsError> {
785        let Some(momentum_decoherence_length) = self.momentum_decoherence_length else {
786            return Ok(());
787        };
788
789        let half_step = self.step_size * factor / 2.0;
790
791        let point = state.try_point_mut().map_err(|_| {
792            NutsError::BadInitGrad(anyhow::anyhow!("State in use during momentum refresh").into())
793        })?;
794
795        match self.kinetic_energy_kind {
796            KineticEnergyKind::Microcanonical => {
797                // Isokinetic Langevin (OU on the unit sphere):
798                // ν = sqrt((exp(2·half_step/L) − 1) / n),  n = dim
799                // p ← (p + ν·z) / ‖p + ν·z‖,  z ~ N(0, I)
800                let n = math.dim() as f64;
801                let nu = ((2.0 * half_step / momentum_decoherence_length).exp_m1() / n).sqrt();
802                math.axpy(&noise, &mut point.velocity, nu);
803                math.array_normalize(&mut point.velocity);
804            }
805            KineticEnergyKind::Euclidean | KineticEnergyKind::ExactNormal => {
806                // Ornstein–Uhlenbeck for Gaussian momentum p ~ N(0, I):
807                //   α = exp(−half_step / L)
808                //   β = sqrt(1 − α²)
809                //   p_new = α · p + β · z,  z ~ N(0, I)
810                //
811                // `axpy_out(x, y, a, out)` computes `out = y + a·x`.
812                // So `axpy_out(&velocity, &zeros, alpha, &mut new_velocity)`
813                // gives `new_velocity = zeros + alpha·velocity = alpha·velocity`.
814                let alpha = (-half_step / momentum_decoherence_length).exp();
815                let beta = (1.0 - alpha * alpha).sqrt();
816                let mut new_velocity = math.new_array();
817                math.axpy_out(&point.velocity, &self.zeros, alpha, &mut new_velocity);
818                math.axpy(&noise, &mut new_velocity, beta);
819                math.copy_into(&new_velocity, &mut point.velocity);
820                // Keep kinetic_energy consistent with the updated velocity.
821                point.update_kinetic_energy(math);
822            }
823        }
824
825        Ok(())
826    }
827}