Skip to main content

nuts_rs/dynamics/
hamiltonian.rs

1//! Define the abstract interface for a Hamiltonian system (leapfrog, U-turn test, divergence detection).
2
3use std::{fmt::Debug, sync::Arc};
4
5use nuts_derive::Storable;
6use rand::{
7    Rng, RngExt,
8    distr::{Distribution, StandardUniform},
9};
10
11use crate::{
12    Math, NutsError,
13    dynamics::{State, StatePool},
14    nuts::Collector,
15    sampler_stats::SamplerStats,
16};
17
/// Details about a divergence that might have occurred during sampling
///
/// There are two reasons why we might observe a divergence:
/// - The integration error of the Hamiltonian is larger than
///   a cutoff value or nan.
/// - The logp function caused a recoverable error (eg if an ODE solver
///   failed)
///
/// Every field is optional. Which fields are populated is decided by the
/// code that constructs the value, which is not part of this file.
#[derive(Debug, Clone)]
pub struct DivergenceInfo {
    /// Momentum at the start point; reported as `divergence_momentum`
    /// in [`DivergenceStats`].
    pub start_momentum: Option<Box<[f64]>>,
    /// Position at the start point; reported as `divergence_start`
    /// in [`DivergenceStats`].
    pub start_location: Option<Box<[f64]>>,
    /// Gradient at the start point; reported as
    /// `divergence_start_gradient` in [`DivergenceStats`].
    pub start_gradient: Option<Box<[f64]>>,
    /// Position at the end point; reported as `divergence_end`
    /// in [`DivergenceStats`].
    pub end_location: Option<Box<[f64]>>,
    /// Energy error associated with the divergence. May be NaN, which is
    /// reported with a dedicated message in [`DivergenceStats`].
    pub energy_error: Option<f64>,
    /// Trajectory index of the end point.
    /// NOTE(review): presumably the value of `Point::index_in_trajectory`
    /// for the end state; not read anywhere in this file.
    pub end_idx_in_trajectory: Option<i64>,
    /// Trajectory index of the start point.
    /// NOTE(review): presumably the value of `Point::index_in_trajectory`
    /// for the start state; not read anywhere in this file.
    pub start_idx_in_trajectory: Option<i64>,
    /// Error produced by the logp function. When present, its `Display`
    /// text is used as the divergence message in [`DivergenceStats`].
    pub logp_function_error: Option<Arc<dyn std::error::Error + Send + Sync>>,
}
36
/// Per-draw divergence statistics, suitable for storage.
///
/// Values are produced by the
/// `From<(Option<&DivergenceInfo>, DivergenceStatsOptions, u64)>` impl.
/// All `Option` fields are `None` for draws without a divergence.
///
/// NOTE(review): the `#[storable(...)]` attributes are interpreted by the
/// `Storable` derive from `nuts_derive`; presumably `event = "divergence"`
/// groups a field into a divergence event and `dims(...)` names its
/// dimension. Confirm against the derive's documentation.
#[derive(Debug, Storable)]
pub struct DivergenceStats {
    /// `true` when a divergence was reported for this draw.
    pub diverging: bool,
    /// Draw index that was passed to the conversion; only set when diverging.
    #[storable(event = "divergence")]
    pub divergence_draw: Option<u64>,
    /// Human-readable cause: the logp error text, a NaN / large energy
    /// error description, or an "unknown cause" fallback.
    #[storable(event = "divergence")]
    pub divergence_message: Option<String>,
    /// Copy of `DivergenceInfo::start_location`; only filled when
    /// `DivergenceStatsOptions::store_divergences` is set.
    #[storable(event = "divergence", dims("unconstrained_parameter"))]
    pub divergence_start: Option<Vec<f64>>,
    /// Copy of `DivergenceInfo::start_gradient`; only filled when
    /// `DivergenceStatsOptions::store_divergences` is set.
    #[storable(event = "divergence", dims("unconstrained_parameter"))]
    pub divergence_start_gradient: Option<Vec<f64>>,
    /// Copy of `DivergenceInfo::end_location`; only filled when
    /// `DivergenceStatsOptions::store_divergences` is set.
    #[storable(event = "divergence", dims("unconstrained_parameter"))]
    pub divergence_end: Option<Vec<f64>>,
    /// Copy of `DivergenceInfo::start_momentum`; only filled when
    /// `DivergenceStatsOptions::store_divergences` is set.
    #[storable(event = "divergence", dims("unconstrained_parameter"))]
    pub divergence_momentum: Option<Vec<f64>>,
    /// Copy of `DivergenceInfo::energy_error`; stored regardless of
    /// `store_divergences`.
    #[storable(event = "divergence")]
    pub divergence_energy_error: Option<f64>,
}
56
/// Options controlling how much detail is copied into [`DivergenceStats`].
#[derive(Debug, Clone, Copy)]
pub struct DivergenceStatsOptions {
    /// When `false`, the vector-valued fields of [`DivergenceStats`]
    /// (start, start gradient, end, momentum) are left as `None`.
    /// The scalar fields (flag, draw, message, energy error) are always set.
    pub store_divergences: bool,
}
61
62impl From<(Option<&DivergenceInfo>, DivergenceStatsOptions, u64)> for DivergenceStats {
63    fn from((info, options, draw): (Option<&DivergenceInfo>, DivergenceStatsOptions, u64)) -> Self {
64        DivergenceStats {
65            diverging: info.is_some(),
66            divergence_draw: info.map(|_| draw),
67            divergence_start: if options.store_divergences {
68                info.and_then(|d| d.start_location.as_ref().map(|v| v.as_ref().to_vec()))
69            } else {
70                None
71            },
72            divergence_start_gradient: if options.store_divergences {
73                info.and_then(|d| d.start_gradient.as_ref().map(|v| v.as_ref().to_vec()))
74            } else {
75                None
76            },
77            divergence_end: if options.store_divergences {
78                info.and_then(|d| d.end_location.as_ref().map(|v| v.as_ref().to_vec()))
79            } else {
80                None
81            },
82            divergence_momentum: if options.store_divergences {
83                info.and_then(|d| d.start_momentum.as_ref().map(|v| v.as_ref().to_vec()))
84            } else {
85                None
86            },
87            divergence_message: info.map(|d| {
88                if let Some(err) = &d.logp_function_error {
89                    err.to_string()
90                } else if let Some(energy_err) = d.energy_error {
91                    if energy_err.is_nan() {
92                        "Divergence due to NaN energy error".to_string()
93                    } else {
94                        format!("Divergence due to large energy error: {:.4}", energy_err)
95                    }
96                } else {
97                    "Divergence (unknown cause)".to_string()
98                }
99            }),
100            divergence_energy_error: info.and_then(|d| d.energy_error),
101        }
102    }
103}
104
/// Direction in which a trajectory is extended by a leapfrog step.
///
/// A random value can be drawn through the
/// `Distribution<Direction> for StandardUniform` impl in this file.
#[derive(Debug, Copy, Clone)]
pub enum Direction {
    /// Step forward along the trajectory.
    Forward,
    /// Step backward along the trajectory.
    Backward,
}
110
111impl Distribution<Direction> for StandardUniform {
112    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Direction {
113        if rng.random::<bool>() {
114            Direction::Forward
115        } else {
116            Direction::Backward
117        }
118    }
119}
120
/// Outcome of a single call to `Hamiltonian::leapfrog`: a new state,
/// a divergence, or an unrecoverable error.
pub enum LeapfrogResult<M: Math, P: Point<M>> {
    /// The step succeeded and produced this new state.
    Ok(State<M, P>),
    /// The step diverged; the payload describes the divergence.
    Divergence(DivergenceInfo),
    /// The step failed with an unrecoverable error of the math backend's
    /// logp error type.
    Err(M::LogpErr),
}
126
/// A point of a trajectory as seen by the sampler: position, gradient,
/// energies and its index inside the trajectory.
///
/// NOTE(review): the precise meaning of each quantity (e.g. whether
/// `energy` includes the kinetic part) is defined by the implementations,
/// which are not part of this file.
pub trait Point<M: Math>: Sized + SamplerStats<M> + Debug {
    /// Position vector of this point.
    fn position(&self) -> &M::Vector;
    /// Gradient vector stored with this point (presumably the gradient of
    /// logp at `position` — confirm against implementations).
    fn gradient(&self) -> &M::Vector;
    /// Index of this point inside its trajectory.
    fn index_in_trajectory(&self) -> i64;
    /// Energy at this point.
    fn energy(&self) -> f64;
    /// Log-density value at this point.
    fn logp(&self) -> f64;

    /// Difference between the energy at this point and the initial energy
    /// (`energy() - initial_energy()`).
    fn energy_error(&self) -> f64 {
        self.energy() - self.initial_energy()
    }

    /// Reference energy that `energy_error` is measured against.
    fn initial_energy(&self) -> f64;

    /// Create a new point using the math backend.
    fn new(math: &mut M) -> Self;
    /// Copy the contents of `self` into `other`.
    fn copy_into(&self, math: &mut M, other: &mut Self);
}
143
144/// The hamiltonian defined by the potential energy and the kinetic energy
145pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
146    /// The type that stores a point in phase space, together
147    /// with some information about the location inside the
148    /// integration trajectory.
149    type Point: Point<M>;
150
151    /// Perform one leapfrog step.
152    ///
153    /// `step_size_factor` scales the hamiltonian's base step size for this
154    /// step only.
155    /// `energy_baseline` is the energy value against which the divergence
156    /// check (`|energy_error| >= max_energy_error`) is evaluated.
157    ///
158    /// Return either an unrecoverable error, a new state or a divergence.
159    fn leapfrog<C: Collector<M, Self::Point>>(
160        &mut self,
161        math: &mut M,
162        start: &State<M, Self::Point>,
163        dir: Direction,
164        step_size_factor: f64,
165        energy_baseline: f64,
166        max_energy_error: f64,
167        collector: &mut C,
168    ) -> LeapfrogResult<M, Self::Point>;
169
170    fn is_turning(
171        &self,
172        math: &mut M,
173        state1: &State<M, Self::Point>,
174        state2: &State<M, Self::Point>,
175    ) -> bool;
176
177    /// Initialize a state at a new location.
178    ///
179    /// The momentum should be initialized to some arbitrary invalid number,
180    /// it will later be set using Self::randomize_momentum.
181    fn init_state(
182        &mut self,
183        math: &mut M,
184        init: &[f64],
185    ) -> Result<State<M, Self::Point>, NutsError>;
186
187    /// Initialize a state at a new location, without applying a transformation.
188    fn init_state_untransformed(
189        &mut self,
190        math: &mut M,
191        init: &[f64],
192    ) -> Result<State<M, Self::Point>, NutsError>;
193
194    /// Randomize the momentum part of a state
195    fn initialize_trajectory<R: rand::Rng + ?Sized>(
196        &self,
197        math: &mut M,
198        state: &mut State<M, Self::Point>,
199        resaple_velocity: bool,
200        rng: &mut R,
201    ) -> Result<(), NutsError>;
202
203    fn pool(&mut self) -> &mut StatePool<M, Self::Point>;
204
205    fn copy_state(&mut self, math: &mut M, state: &State<M, Self::Point>) -> State<M, Self::Point>;
206
207    fn step_size(&self) -> f64;
208    fn step_size_mut(&mut self) -> &mut f64;
209
210    /// Return updated hamiltonian stats options to use on the next draw.
211    ///
212    /// Called in `expanded_draw` after stats extraction.  For hamiltonians
213    /// with a trackable transformation, this records the current transformation
214    /// id into the options so the following `extract_stats` call can detect
215    /// whether the mass matrix changed and emit a `transformation_update` event.
216    /// The default passes the current options through unchanged, meaning no
217    /// transformation-update events are ever emitted.
218    fn update_stats_options(
219        &mut self,
220        _math: &mut M,
221        current: <Self as SamplerStats<M>>::StatsOptions,
222    ) -> <Self as SamplerStats<M>>::StatsOptions {
223        current
224    }
225
226    /// The momentum decoherence length `L` used for the isokinetic Langevin
227    /// (partial momentum refresh) step.
228    ///
229    /// - `None` means no refresh is performed (default, used by NUTS).
230    /// - `Some(L)` enables a half-step Ornstein–Uhlenbeck refresh with
231    ///   `ν = sqrt((exp(2·ε/L) − 1) / n)` around each trajectory.
232    fn momentum_decoherence_length(&self) -> Option<f64> {
233        None
234    }
235
236    fn momentum_decoherence_length_mut(&mut self) -> Option<&mut f64> {
237        None
238    }
239
240    /// Apply one isokinetic Langevin partial momentum refresh to `state`.
241    ///
242    /// `factor` scales the base step size: the half-step used internally is
243    /// `hamiltonian.step_size() * factor / 2`.  When
244    /// [`Self::momentum_decoherence_length`] returns `None` this must be a
245    /// no-op.  Implementations that support the refresh should override this
246    /// method.
247    fn partial_momentum_refresh<R: rand::Rng + ?Sized>(
248        &mut self,
249        math: &mut M,
250        state: &mut State<M, Self::Point>,
251        noise: &M::Vector,
252        rng: &mut R,
253        factor: f64,
254    ) -> Result<(), NutsError> {
255        let _ = (math, state, noise, rng, factor);
256        Ok(())
257    }
258}