// nuts_rs/hamiltonian.rs

1use std::sync::Arc;
2
3use rand::{
4    Rng, RngExt,
5    distr::{Distribution, StandardUniform},
6};
7
8use crate::{
9    Math, NutsError,
10    nuts::Collector,
11    sampler_stats::SamplerStats,
12    state::{State, StatePool},
13};
14
/// Details about a divergence that might have occurred during sampling
///
/// There are two reasons why we might observe a divergence:
/// - The integration error of the Hamiltonian is larger than
///   a cutoff value or nan.
/// - The logp function caused a recoverable error (eg if an ODE solver
///   failed)
///
/// Every field is an `Option`, so any of them may be absent. The field
/// descriptions below follow the field names; exactly when each value is
/// captured is decided by the code that constructs this struct.
#[derive(Debug, Clone)]
pub struct DivergenceInfo {
    /// Momentum at the start of the diverging step, if available.
    pub start_momentum: Option<Box<[f64]>>,
    /// Position at the start of the diverging step, if available.
    pub start_location: Option<Box<[f64]>>,
    /// Gradient at the start of the diverging step, if available.
    pub start_gradient: Option<Box<[f64]>>,
    /// Position at the end of the diverging step, if available.
    pub end_location: Option<Box<[f64]>>,
    /// Energy error associated with the divergence, if available.
    pub energy_error: Option<f64>,
    /// Trajectory index of the end point, if available.
    pub end_idx_in_trajectory: Option<i64>,
    /// Trajectory index of the start point, if available.
    pub start_idx_in_trajectory: Option<i64>,
    /// The recoverable error reported by the logp function, if that is
    /// what caused the divergence. Stored behind an `Arc` so the struct
    /// stays `Clone` without requiring the error itself to be cloneable.
    pub logp_function_error: Option<Arc<dyn std::error::Error + Send + Sync>>,
}
33
/// Direction selector passed to [`Hamiltonian::leapfrog`].
///
/// A uniformly random value can be drawn through the
/// `Distribution<Direction>` implementation for `StandardUniform`.
#[derive(Debug, Copy, Clone)]
pub enum Direction {
    /// The forward direction.
    Forward,
    /// The backward direction.
    Backward,
}
39
40impl Distribution<Direction> for StandardUniform {
41    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Direction {
42        if rng.random::<bool>() {
43            Direction::Forward
44        } else {
45            Direction::Backward
46        }
47    }
48}
49
50pub enum LeapfrogResult<M: Math, P: Point<M>> {
51    Ok(State<M, P>),
52    Divergence(DivergenceInfo),
53    Err(M::LogpErr),
54}
55
56pub trait Point<M: Math>: Sized + SamplerStats<M> {
57    fn position(&self) -> &M::Vector;
58    fn gradient(&self) -> &M::Vector;
59    fn index_in_trajectory(&self) -> i64;
60    fn energy(&self) -> f64;
61    fn logp(&self) -> f64;
62
63    fn energy_error(&self) -> f64 {
64        self.energy() - self.initial_energy()
65    }
66
67    fn initial_energy(&self) -> f64;
68
69    fn new(math: &mut M) -> Self;
70    fn copy_into(&self, math: &mut M, other: &mut Self);
71}
72
73/// The hamiltonian defined by the potential energy and the kinetic energy
74pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
75    /// The type that stores a point in phase space, together
76    /// with some information about the location inside the
77    /// integration trajectory.
78    type Point: Point<M>;
79
80    /// Perform one leapfrog step.
81    ///
82    /// Return either an unrecoverable error, a new state or a divergence.
83    fn leapfrog<C: Collector<M, Self::Point>>(
84        &mut self,
85        math: &mut M,
86        start: &State<M, Self::Point>,
87        dir: Direction,
88        collector: &mut C,
89    ) -> LeapfrogResult<M, Self::Point>;
90
91    fn is_turning(
92        &self,
93        math: &mut M,
94        state1: &State<M, Self::Point>,
95        state2: &State<M, Self::Point>,
96    ) -> bool;
97
98    /// Initialize a state at a new location.
99    ///
100    /// The momentum should be initialized to some arbitrary invalid number,
101    /// it will later be set using Self::randomize_momentum.
102    fn init_state(
103        &mut self,
104        math: &mut M,
105        init: &[f64],
106    ) -> Result<State<M, Self::Point>, NutsError>;
107
108    /// Randomize the momentum part of a state
109    fn initialize_trajectory<R: rand::Rng + ?Sized>(
110        &self,
111        math: &mut M,
112        state: &mut State<M, Self::Point>,
113        rng: &mut R,
114    ) -> Result<(), NutsError>;
115
116    fn pool(&mut self) -> &mut StatePool<M, Self::Point>;
117
118    fn copy_state(&mut self, math: &mut M, state: &State<M, Self::Point>) -> State<M, Self::Point>;
119
120    fn step_size(&self) -> f64;
121    fn step_size_mut(&mut self) -> &mut f64;
122}