1use std::sync::Arc;
2
3use rand::{
4 Rng, RngExt,
5 distr::{Distribution, StandardUniform},
6};
7
8use crate::{
9 Math, NutsError,
10 nuts::Collector,
11 sampler_stats::SamplerStats,
12 state::{State, StatePool},
13};
14
/// Optional diagnostics carried by [`LeapfrogResult::Divergence`].
///
/// Every field is an `Option`, so a producer may fill in only the pieces it
/// has available. Cloning deep-copies the boxed slices, while the error in
/// `logp_function_error` sits behind an `Arc` and is shared between clones.
#[derive(Debug, Clone)]
pub struct DivergenceInfo {
    /// Momentum values at the start, when recorded.
    /// NOTE(review): presumably the start of the step that diverged — confirm
    /// against the code that builds this struct.
    pub start_momentum: Option<Box<[f64]>>,

    /// Position values at the start, when recorded.
    pub start_location: Option<Box<[f64]>>,

    /// Gradient values at the start, when recorded.
    pub start_gradient: Option<Box<[f64]>>,

    /// Position values at the end, when recorded.
    pub end_location: Option<Box<[f64]>>,

    /// Energy error associated with the divergence, when known.
    /// NOTE(review): likely the same quantity as `Point::energy_error` — confirm.
    pub energy_error: Option<f64>,

    /// Trajectory index of the end point, when known (an `i64`, like the
    /// value returned by `Point::index_in_trajectory`).
    pub end_idx_in_trajectory: Option<i64>,

    /// Trajectory index of the start point, when known.
    pub start_idx_in_trajectory: Option<i64>,

    /// Error value attached to the divergence, if there is one.
    /// NOTE(review): the name suggests it originates from evaluating the
    /// log-density function — confirm with the producer.
    pub logp_function_error: Option<Arc<dyn std::error::Error + Send + Sync>>,
}
33
/// Two-valued direction tag, passed as the `dir` argument of
/// `Hamiltonian::leapfrog`.
///
/// The type is a plain field-less enum, so besides `Debug`/`Copy`/`Clone` it
/// also derives `PartialEq`, `Eq` and `Hash`. That lets callers compare
/// directions with `==` / `assert_eq!` and use them as map or set keys.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Direction {
    /// The forward direction.
    Forward,
    /// The backward direction.
    Backward,
}
39
40impl Distribution<Direction> for StandardUniform {
41 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Direction {
42 if rng.random::<bool>() {
43 Direction::Forward
44 } else {
45 Direction::Backward
46 }
47 }
48}
49
50pub enum LeapfrogResult<M: Math, P: Point<M>> {
51 Ok(State<M, P>),
52 Divergence(DivergenceInfo),
53 Err(M::LogpErr),
54}
55
56pub trait Point<M: Math>: Sized + SamplerStats<M> {
57 fn position(&self) -> &M::Vector;
58 fn gradient(&self) -> &M::Vector;
59 fn index_in_trajectory(&self) -> i64;
60 fn energy(&self) -> f64;
61 fn logp(&self) -> f64;
62
63 fn energy_error(&self) -> f64 {
64 self.energy() - self.initial_energy()
65 }
66
67 fn initial_energy(&self) -> f64;
68
69 fn new(math: &mut M) -> Self;
70 fn copy_into(&self, math: &mut M, other: &mut Self);
71}
72
73pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
75 type Point: Point<M>;
79
80 fn leapfrog<C: Collector<M, Self::Point>>(
84 &mut self,
85 math: &mut M,
86 start: &State<M, Self::Point>,
87 dir: Direction,
88 collector: &mut C,
89 ) -> LeapfrogResult<M, Self::Point>;
90
91 fn is_turning(
92 &self,
93 math: &mut M,
94 state1: &State<M, Self::Point>,
95 state2: &State<M, Self::Point>,
96 ) -> bool;
97
98 fn init_state(
103 &mut self,
104 math: &mut M,
105 init: &[f64],
106 ) -> Result<State<M, Self::Point>, NutsError>;
107
108 fn initialize_trajectory<R: rand::Rng + ?Sized>(
110 &self,
111 math: &mut M,
112 state: &mut State<M, Self::Point>,
113 rng: &mut R,
114 ) -> Result<(), NutsError>;
115
116 fn pool(&mut self) -> &mut StatePool<M, Self::Point>;
117
118 fn copy_state(&mut self, math: &mut M, state: &State<M, Self::Point>) -> State<M, Self::Point>;
119
120 fn step_size(&self) -> f64;
121 fn step_size_mut(&mut self) -> &mut f64;
122}