1use std::{fmt::Debug, sync::Arc};
4
5use nuts_derive::Storable;
6use rand::{
7 Rng, RngExt,
8 distr::{Distribution, StandardUniform},
9};
10
11use crate::{
12 Math, NutsError,
13 dynamics::{State, StatePool},
14 nuts::Collector,
15 sampler_stats::SamplerStats,
16};
17
/// Information about a divergence encountered while integrating a trajectory.
///
/// Every field is optional, so whatever was available at the point of failure can be
/// recorded. It is converted into per-draw statistics by the
/// `From<(Option<&DivergenceInfo>, DivergenceStatsOptions, u64)>` impl for `DivergenceStats`.
#[derive(Debug, Clone)]
pub struct DivergenceInfo {
    /// Momentum at the start point (presumably of the diverging step — TODO confirm).
    /// Copied into `DivergenceStats::divergence_momentum` when `store_divergences` is set.
    pub start_momentum: Option<Box<[f64]>>,
    /// Position at the start point. Copied into `DivergenceStats::divergence_start` when
    /// `store_divergences` is set.
    pub start_location: Option<Box<[f64]>>,
    /// Gradient at the start point. Copied into `DivergenceStats::divergence_start_gradient`
    /// when `store_divergences` is set.
    pub start_gradient: Option<Box<[f64]>>,
    /// Position at the end point. Copied into `DivergenceStats::divergence_end` when
    /// `store_divergences` is set.
    pub end_location: Option<Box<[f64]>>,
    /// Energy error associated with the divergence. A `NaN` value gets its own message in
    /// `DivergenceStats::divergence_message`.
    pub energy_error: Option<f64>,
    /// Index of the end point within the trajectory.
    /// NOTE(review): not copied into `DivergenceStats`.
    pub end_idx_in_trajectory: Option<i64>,
    /// Index of the start point within the trajectory.
    /// NOTE(review): not copied into `DivergenceStats`.
    pub start_idx_in_trajectory: Option<i64>,
    /// Error reported by the log-density function, if that caused the divergence. When
    /// present, its `Display` text is used as the divergence message.
    pub logp_function_error: Option<Arc<dyn std::error::Error + Send + Sync>>,
}
36
/// Per-draw divergence statistics, built from an optional [`DivergenceInfo`] via `From`.
///
/// When the draw did not diverge, `diverging` is `false` and every other field is `None`.
#[derive(Debug, Storable)]
pub struct DivergenceStats {
    /// Whether this draw diverged.
    pub diverging: bool,
    /// Index of the draw in which the divergence happened.
    #[storable(event = "divergence")]
    pub divergence_draw: Option<u64>,
    /// Human-readable reason for the divergence.
    #[storable(event = "divergence")]
    pub divergence_message: Option<String>,
    /// Start position of the divergence; only filled when `store_divergences` is enabled.
    #[storable(event = "divergence", dims("unconstrained_parameter"))]
    pub divergence_start: Option<Vec<f64>>,
    /// Gradient at the start position; only filled when `store_divergences` is enabled.
    #[storable(event = "divergence", dims("unconstrained_parameter"))]
    pub divergence_start_gradient: Option<Vec<f64>>,
    /// End position of the divergence; only filled when `store_divergences` is enabled.
    #[storable(event = "divergence", dims("unconstrained_parameter"))]
    pub divergence_end: Option<Vec<f64>>,
    /// Momentum at the start position; only filled when `store_divergences` is enabled.
    #[storable(event = "divergence", dims("unconstrained_parameter"))]
    pub divergence_momentum: Option<Vec<f64>>,
    /// Energy error reported for the divergence, if known.
    #[storable(event = "divergence")]
    pub divergence_energy_error: Option<f64>,
}
56
/// Options controlling how much divergence detail is recorded in [`DivergenceStats`].
#[derive(Debug, Clone, Copy)]
pub struct DivergenceStatsOptions {
    /// When `true`, the start/end positions, start gradient and start momentum of each
    /// divergence are copied into the statistics; otherwise those fields stay `None`.
    pub store_divergences: bool,
}
61
62impl From<(Option<&DivergenceInfo>, DivergenceStatsOptions, u64)> for DivergenceStats {
63 fn from((info, options, draw): (Option<&DivergenceInfo>, DivergenceStatsOptions, u64)) -> Self {
64 DivergenceStats {
65 diverging: info.is_some(),
66 divergence_draw: info.map(|_| draw),
67 divergence_start: if options.store_divergences {
68 info.and_then(|d| d.start_location.as_ref().map(|v| v.as_ref().to_vec()))
69 } else {
70 None
71 },
72 divergence_start_gradient: if options.store_divergences {
73 info.and_then(|d| d.start_gradient.as_ref().map(|v| v.as_ref().to_vec()))
74 } else {
75 None
76 },
77 divergence_end: if options.store_divergences {
78 info.and_then(|d| d.end_location.as_ref().map(|v| v.as_ref().to_vec()))
79 } else {
80 None
81 },
82 divergence_momentum: if options.store_divergences {
83 info.and_then(|d| d.start_momentum.as_ref().map(|v| v.as_ref().to_vec()))
84 } else {
85 None
86 },
87 divergence_message: info.map(|d| {
88 if let Some(err) = &d.logp_function_error {
89 err.to_string()
90 } else if let Some(energy_err) = d.energy_error {
91 if energy_err.is_nan() {
92 "Divergence due to NaN energy error".to_string()
93 } else {
94 format!("Divergence due to large energy error: {:.4}", energy_err)
95 }
96 } else {
97 "Divergence (unknown cause)".to_string()
98 }
99 }),
100 divergence_energy_error: info.and_then(|d| d.energy_error),
101 }
102 }
103}
104
/// Direction of integration, passed as `dir` to `Hamiltonian::leapfrog`.
///
/// A random direction can be drawn through the `Distribution<Direction>` impl for
/// `StandardUniform`.
#[derive(Debug, Copy, Clone)]
pub enum Direction {
    /// Integrate forward.
    Forward,
    /// Integrate backward.
    Backward,
}
110
111impl Distribution<Direction> for StandardUniform {
112 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Direction {
113 if rng.random::<bool>() {
114 Direction::Forward
115 } else {
116 Direction::Backward
117 }
118 }
119}
120
/// Outcome of a single `Hamiltonian::leapfrog` step.
pub enum LeapfrogResult<M: Math, P: Point<M>> {
    /// The step succeeded and produced this new state.
    Ok(State<M, P>),
    /// The step diverged; details are in the attached [`DivergenceInfo`].
    Divergence(DivergenceInfo),
    /// The step failed with an error of the math backend's log-density error type.
    Err(M::LogpErr),
}
126
/// A point of a trajectory, as held inside a `State`.
///
/// NOTE(review): descriptions marked "presumably" are inferred from method names; confirm
/// against the implementors.
pub trait Point<M: Math>: Sized + SamplerStats<M> + Debug {
    /// Position of the point (presumably in the unconstrained parameter space).
    fn position(&self) -> &M::Vector;
    /// Gradient at `position` (presumably of the log density — TODO confirm).
    fn gradient(&self) -> &M::Vector;
    /// Index of this point within its trajectory. The type is signed; presumably points
    /// reached by integrating backward get negative indices — TODO confirm.
    fn index_in_trajectory(&self) -> i64;
    /// Energy of the point.
    fn energy(&self) -> f64;
    /// Log density at `position` (presumably).
    fn logp(&self) -> f64;

    /// Energy error of this point, computed as `energy() - initial_energy()`.
    fn energy_error(&self) -> f64 {
        self.energy() - self.initial_energy()
    }

    /// Reference energy that `energy_error` is measured against (presumably the energy at
    /// the start of the trajectory).
    fn initial_energy(&self) -> f64;

    /// Creates a new point; `math` presumably provides the vector storage.
    fn new(math: &mut M) -> Self;
    /// Copies the contents of `self` into `other` (presumably reusing `other`'s storage
    /// instead of allocating).
    fn copy_into(&self, math: &mut M, other: &mut Self);
}
143
144pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
146 type Point: Point<M>;
150
151 fn leapfrog<C: Collector<M, Self::Point>>(
160 &mut self,
161 math: &mut M,
162 start: &State<M, Self::Point>,
163 dir: Direction,
164 step_size_factor: f64,
165 energy_baseline: f64,
166 max_energy_error: f64,
167 collector: &mut C,
168 ) -> LeapfrogResult<M, Self::Point>;
169
170 fn is_turning(
171 &self,
172 math: &mut M,
173 state1: &State<M, Self::Point>,
174 state2: &State<M, Self::Point>,
175 ) -> bool;
176
177 fn init_state(
182 &mut self,
183 math: &mut M,
184 init: &[f64],
185 ) -> Result<State<M, Self::Point>, NutsError>;
186
187 fn init_state_untransformed(
189 &mut self,
190 math: &mut M,
191 init: &[f64],
192 ) -> Result<State<M, Self::Point>, NutsError>;
193
194 fn initialize_trajectory<R: rand::Rng + ?Sized>(
196 &self,
197 math: &mut M,
198 state: &mut State<M, Self::Point>,
199 resaple_velocity: bool,
200 rng: &mut R,
201 ) -> Result<(), NutsError>;
202
203 fn pool(&mut self) -> &mut StatePool<M, Self::Point>;
204
205 fn copy_state(&mut self, math: &mut M, state: &State<M, Self::Point>) -> State<M, Self::Point>;
206
207 fn step_size(&self) -> f64;
208 fn step_size_mut(&mut self) -> &mut f64;
209
210 fn update_stats_options(
219 &mut self,
220 _math: &mut M,
221 current: <Self as SamplerStats<M>>::StatsOptions,
222 ) -> <Self as SamplerStats<M>>::StatsOptions {
223 current
224 }
225
226 fn momentum_decoherence_length(&self) -> Option<f64> {
233 None
234 }
235
236 fn momentum_decoherence_length_mut(&mut self) -> Option<&mut f64> {
237 None
238 }
239
240 fn partial_momentum_refresh<R: rand::Rng + ?Sized>(
248 &mut self,
249 math: &mut M,
250 state: &mut State<M, Self::Point>,
251 noise: &M::Vector,
252 rng: &mut R,
253 factor: f64,
254 ) -> Result<(), NutsError> {
255 let _ = (math, state, noise, rng, factor);
256 Ok(())
257 }
258}