nuts_rs/
hamiltonian.rs

1use std::sync::Arc;
2
3use rand_distr::{Distribution, StandardUniform};
4
5use crate::{
6    Math, NutsError,
7    nuts::Collector,
8    sampler_stats::SamplerStats,
9    state::{State, StatePool},
10};
11
/// Details about a divergence that might have occurred during sampling.
///
/// There are two reasons why we might observe a divergence:
/// - The integration error of the Hamiltonian is larger than
///   a cutoff value, or is NaN.
/// - The logp function caused a recoverable error (e.g. if an ODE solver
///   failed).
///
/// Every field is an `Option`, so only the details that were actually
/// available when the divergence was detected need to be filled in.
#[derive(Debug, Clone)]
pub struct DivergenceInfo {
    /// Momentum at the start of the diverging step, if recorded.
    pub start_momentum: Option<Box<[f64]>>,
    /// Position at the start of the diverging step, if recorded.
    pub start_location: Option<Box<[f64]>>,
    /// Gradient at the start of the diverging step, if recorded.
    pub start_gradient: Option<Box<[f64]>>,
    /// Position reached when the divergence was detected, if recorded.
    pub end_location: Option<Box<[f64]>>,
    /// Integration error of the Hamiltonian at the divergence, if available.
    pub energy_error: Option<f64>,
    /// Index within the trajectory of the point where the divergence ended.
    pub end_idx_in_trajectory: Option<i64>,
    /// Index within the trajectory of the point the diverging step started from.
    pub start_idx_in_trajectory: Option<i64>,
    /// The recoverable logp error that caused the divergence, if that was the cause.
    pub logp_function_error: Option<Arc<dyn std::error::Error + Send + Sync>>,
}
30
/// Direction in which a trajectory is extended by `Hamiltonian::leapfrog`.
#[derive(Clone, Copy, Debug)]
pub enum Direction {
    /// Extend the trajectory forward.
    Forward,
    /// Extend the trajectory backward.
    Backward,
}
36
37impl Distribution<Direction> for StandardUniform {
38    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Direction {
39        if rng.random::<bool>() {
40            Direction::Forward
41        } else {
42            Direction::Backward
43        }
44    }
45}
46
47pub enum LeapfrogResult<M: Math, P: Point<M>> {
48    Ok(State<M, P>),
49    Divergence(DivergenceInfo),
50    Err(M::LogpErr),
51}
52
53pub trait Point<M: Math>: Sized + SamplerStats<M> {
54    fn position(&self) -> &M::Vector;
55    fn gradient(&self) -> &M::Vector;
56    fn index_in_trajectory(&self) -> i64;
57    fn energy(&self) -> f64;
58    fn logp(&self) -> f64;
59
60    fn energy_error(&self) -> f64 {
61        self.energy() - self.initial_energy()
62    }
63
64    fn initial_energy(&self) -> f64;
65
66    fn new(math: &mut M) -> Self;
67    fn copy_into(&self, math: &mut M, other: &mut Self);
68}
69
70/// The hamiltonian defined by the potential energy and the kinetic energy
71pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
72    /// The type that stores a point in phase space, together
73    /// with some information about the location inside the
74    /// integration trajectory.
75    type Point: Point<M>;
76
77    /// Perform one leapfrog step.
78    ///
79    /// Return either an unrecoverable error, a new state or a divergence.
80    fn leapfrog<C: Collector<M, Self::Point>>(
81        &mut self,
82        math: &mut M,
83        start: &State<M, Self::Point>,
84        dir: Direction,
85        collector: &mut C,
86    ) -> LeapfrogResult<M, Self::Point>;
87
88    fn is_turning(
89        &self,
90        math: &mut M,
91        state1: &State<M, Self::Point>,
92        state2: &State<M, Self::Point>,
93    ) -> bool;
94
95    /// Initialize a state at a new location.
96    ///
97    /// The momentum should be initialized to some arbitrary invalid number,
98    /// it will later be set using Self::randomize_momentum.
99    fn init_state(
100        &mut self,
101        math: &mut M,
102        init: &[f64],
103    ) -> Result<State<M, Self::Point>, NutsError>;
104
105    /// Randomize the momentum part of a state
106    fn initialize_trajectory<R: rand::Rng + ?Sized>(
107        &self,
108        math: &mut M,
109        state: &mut State<M, Self::Point>,
110        rng: &mut R,
111    ) -> Result<(), NutsError>;
112
113    fn pool(&mut self) -> &mut StatePool<M, Self::Point>;
114
115    fn copy_state(&mut self, math: &mut M, state: &State<M, Self::Point>) -> State<M, Self::Point>;
116
117    fn step_size(&self) -> f64;
118    fn step_size_mut(&mut self) -> &mut f64;
119}