1use std::sync::Arc;
2
3use rand_distr::{Distribution, StandardUniform};
4
5use crate::{
6 Math, NutsError,
7 nuts::Collector,
8 sampler_stats::SamplerStats,
9 state::{State, StatePool},
10};
11
/// Diagnostic details describing a divergence observed while sampling.
///
/// Every field is optional: whoever constructs this value records only the
/// information it has available. The error field is held in an [`Arc`], so
/// cloning a `DivergenceInfo` shares the error instead of copying it.
#[derive(Clone, Debug)]
pub struct DivergenceInfo {
    /// Momentum at the start of the diverging step, if recorded.
    pub start_momentum: Option<Box<[f64]>>,
    /// Location at the start of the diverging step, if recorded.
    pub start_location: Option<Box<[f64]>>,
    /// Gradient at the start of the diverging step, if recorded.
    pub start_gradient: Option<Box<[f64]>>,
    /// Location reached at the end of the diverging step, if recorded.
    pub end_location: Option<Box<[f64]>>,
    /// Energy error associated with the divergence, if known.
    pub energy_error: Option<f64>,
    /// Trajectory index of the end point, if known.
    pub end_idx_in_trajectory: Option<i64>,
    /// Trajectory index of the start point, if known.
    pub start_idx_in_trajectory: Option<i64>,
    /// Error reported by the logp function in connection with this
    /// divergence, if any.
    pub logp_function_error: Option<Arc<dyn std::error::Error + Send + Sync>>,
}
30
/// Direction argument passed to [`Hamiltonian::leapfrog`].
///
/// Besides `Copy`/`Clone`/`Debug`, the type derives `PartialEq`, `Eq` and
/// `Hash` so callers can compare directions with `==` and use them as map or
/// set keys, instead of resorting to `matches!`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Direction {
    /// Step forward.
    Forward,
    /// Step backward.
    Backward,
}
36
37impl Distribution<Direction> for StandardUniform {
38 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Direction {
39 if rng.random::<bool>() {
40 Direction::Forward
41 } else {
42 Direction::Backward
43 }
44 }
45}
46
47pub enum LeapfrogResult<M: Math, P: Point<M>> {
48 Ok(State<M, P>),
49 Divergence(DivergenceInfo),
50 Err(M::LogpErr),
51}
52
53pub trait Point<M: Math>: Sized + SamplerStats<M> {
54 fn position(&self) -> &M::Vector;
55 fn gradient(&self) -> &M::Vector;
56 fn index_in_trajectory(&self) -> i64;
57 fn energy(&self) -> f64;
58 fn logp(&self) -> f64;
59
60 fn energy_error(&self) -> f64 {
61 self.energy() - self.initial_energy()
62 }
63
64 fn initial_energy(&self) -> f64;
65
66 fn new(math: &mut M) -> Self;
67 fn copy_into(&self, math: &mut M, other: &mut Self);
68}
69
70pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
72 type Point: Point<M>;
76
77 fn leapfrog<C: Collector<M, Self::Point>>(
81 &mut self,
82 math: &mut M,
83 start: &State<M, Self::Point>,
84 dir: Direction,
85 collector: &mut C,
86 ) -> LeapfrogResult<M, Self::Point>;
87
88 fn is_turning(
89 &self,
90 math: &mut M,
91 state1: &State<M, Self::Point>,
92 state2: &State<M, Self::Point>,
93 ) -> bool;
94
95 fn init_state(
100 &mut self,
101 math: &mut M,
102 init: &[f64],
103 ) -> Result<State<M, Self::Point>, NutsError>;
104
105 fn initialize_trajectory<R: rand::Rng + ?Sized>(
107 &self,
108 math: &mut M,
109 state: &mut State<M, Self::Point>,
110 rng: &mut R,
111 ) -> Result<(), NutsError>;
112
113 fn pool(&mut self) -> &mut StatePool<M, Self::Point>;
114
115 fn copy_state(&mut self, math: &mut M, state: &State<M, Self::Point>) -> State<M, Self::Point>;
116
117 fn step_size(&self) -> f64;
118 fn step_size_mut(&mut self) -> &mut f64;
119}