1use rand::RngExt;
2use thiserror::Error;
3
4use std::{fmt::Debug, marker::PhantomData};
5
6use crate::hamiltonian::{Direction, DivergenceInfo, Hamiltonian, LeapfrogResult, Point};
7use crate::math::logaddexp;
8use crate::state::State;
9
10use crate::math_base::Math;
11
/// Errors that can occur while producing a NUTS draw.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum NutsError {
    /// Wraps an error reported by the log-density function.
    ///
    /// In this file it is produced by `NutsTree::single_step` when a leapfrog
    /// step returns `LeapfrogResult::Err`.
    #[error("Logp function returned error: {0:?}")]
    LogpFailure(Box<dyn std::error::Error + Send + Sync>),

    /// Sample statistics could not be serialized.
    ///
    /// NOTE(review): not constructed in this file; presumably raised by the
    /// stats/storage code — confirm.
    #[error("Could not serialize sample stats")]
    SerializeFailure(),

    /// The sampler state could not be initialized because the gradient at the
    /// initial point was invalid.
    ///
    /// NOTE(review): not constructed in this file; presumably raised during
    /// chain initialization — confirm.
    #[error("Could not initialize state because of bad initial gradient: {0:?}")]
    BadInitGrad(Box<dyn std::error::Error + Send + Sync>),
}

/// Convenience alias for results whose error type is [`NutsError`].
pub type Result<T> = std::result::Result<T, NutsError>;
26
/// Observer hooks that are notified while a NUTS draw is being produced.
///
/// Every method has an empty default implementation, so an implementor only
/// overrides the events it is interested in.
pub trait Collector<M: Math, P: Point<M>> {
    /// Called for a leapfrog step from `_start` to `_end`.
    ///
    /// `_divergence_info` is presumably `Some` when the step diverged.
    /// NOTE(review): this hook is not called from this file; presumably the
    /// `Hamiltonian::leapfrog` implementation (which receives the collector)
    /// invokes it — confirm.
    fn register_leapfrog(
        &mut self,
        _math: &mut M,
        _start: &State<M, P>,
        _end: &State<M, P>,
        _divergence_info: Option<&DivergenceInfo>,
    ) {
    }
    /// Called by `draw` once per successful draw, with the selected state and
    /// its sample statistics. It is not called when `draw` returns an error.
    fn register_draw(&mut self, _math: &mut M, _state: &State<M, P>, _info: &SampleInfo) {}
    /// Called by `draw` right after the trajectory has been initialized and
    /// before any tree building takes place.
    fn register_init(&mut self, _math: &mut M, _state: &State<M, P>, _options: &NutsOptions) {}
}
43
/// Statistics describing a single NUTS draw, as returned by `draw`.
#[derive(Debug)]
pub struct SampleInfo {
    /// Depth of the trajectory tree when building stopped, i.e. the number of
    /// doublings performed (0 means only the initial state).
    pub depth: u64,

    /// Details of the divergence that stopped tree building, or `None` if the
    /// trajectory did not diverge.
    pub divergence_info: Option<DivergenceInfo>,

    /// `true` if tree building stopped because `NutsOptions::maxdepth` was
    /// reached rather than because of a U-turn or a divergence.
    pub reached_maxdepth: bool,

    /// Value of `initial_energy()` on the selected draw's point; presumably the
    /// energy at the start of the trajectory — confirm in the `Point` impl.
    pub initial_energy: f64,
    /// Energy of the selected draw (`State::energy`).
    pub draw_energy: f64,
}
61
/// A (sub)trajectory built by repeated doubling.
///
/// A tree of depth `d` is formed by merging two trees of depth `d - 1`.
struct NutsTree<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> {
    /// Outermost state in the backward direction; replaced when merging a
    /// subtree backward.
    left: State<M, H::Point>,
    /// Outermost state in the forward direction; replaced when merging a
    /// subtree forward.
    right: State<M, H::Point>,

    /// State currently selected as the candidate draw from this tree.
    draw: State<M, H::Point>,
    /// Log of the summed weights of all states in the tree. A leaf created by
    /// `single_step` contributes `-energy_error`; merging combines with
    /// `logaddexp`. The root's initial state contributes 0.
    log_size: f64,
    /// Number of doublings performed on this tree (0 for a single state).
    depth: u64,

    /// `true` only for the root tree created by `new`. It selects the
    /// draw-replacement rule used in `merge_into`.
    is_main: bool,
    /// Marks use of the collector type parameter `C`, which no field stores.
    _phantom2: PhantomData<C>,
}
82
/// Outcome of `NutsTree::extend`.
enum ExtendResult<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> {
    /// The tree was doubled and no U-turn was detected; building may continue.
    Ok(NutsTree<M, H, C>),
    /// A log-density evaluation failed; `draw` aborts with this error.
    Err(NutsError),
    /// A U-turn was detected. This holds either the merged tree (the U-turn
    /// spans the combined tree), or the original, un-extended tree (the new
    /// subtree turned while it was being built and was discarded).
    Turning(NutsTree<M, H, C>),
    /// A leapfrog step diverged. This holds the tree as it was before the failed
    /// extension, together with the divergence details.
    Diverging(NutsTree<M, H, C>, DivergenceInfo),
}
95
impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
    /// Creates the root ("main") tree holding a single state.
    ///
    /// `state` is simultaneously the left edge, the right edge and the current
    /// draw. The tree starts at depth 0 with log-weight 0.
    fn new(state: State<M, H::Point>) -> NutsTree<M, H, C> {
        NutsTree {
            right: state.clone(),
            left: state.clone(),
            draw: state,
            depth: 0,
            log_size: 0.,
            is_main: true,
            _phantom2: PhantomData,
        }
    }

    /// Doubles the tree by attaching a new subtree of equal depth in `direction`.
    ///
    /// The new subtree is seeded with one leapfrog step beyond the current edge
    /// and grown recursively until its depth matches `self.depth`. If
    /// `options.check_turning` is set, it is checked for a U-turn against `self`.
    /// It is then merged into `self`.
    ///
    /// Returns:
    /// - `Ok(tree)`: merged, and no U-turn was detected.
    /// - `Turning(tree)`: merged, but the combined tree makes a U-turn.
    /// - `Turning(self)`: the new subtree turned while being built; it is discarded.
    /// - `Diverging(self, info)`: a leapfrog step diverged; the new subtree is discarded.
    /// - `Err(err)`: a log-density evaluation failed.
    #[allow(clippy::too_many_arguments)]
    #[inline]
    fn extend<R>(
        mut self,
        math: &mut M,
        rng: &mut R,
        hamiltonian: &mut H,
        direction: Direction,
        collector: &mut C,
        options: &NutsOptions,
    ) -> ExtendResult<M, H, C>
    where
        H: Hamiltonian<M>,
        R: rand::Rng + ?Sized,
    {
        // Seed the new subtree with a single leapfrog step past the edge of
        // `self` in `direction`.
        let mut other = match self.single_step(math, hamiltonian, direction, collector) {
            Ok(Ok(tree)) => tree,
            Ok(Err(info)) => return ExtendResult::Diverging(self, info),
            Err(err) => return ExtendResult::Err(err),
        };

        // Grow the new subtree until it is as deep as `self`. The glob import
        // below makes `Ok`/`Err` in this loop refer to the `ExtendResult`
        // variants, not to `std::result::Result`. If the subtree turns or
        // diverges, it is dropped and `self` is handed back without merging.
        while other.depth < self.depth {
            use ExtendResult::*;
            other = match other.extend(math, rng, hamiltonian, direction, collector, options) {
                Ok(tree) => tree,
                Turning(_) => {
                    return Turning(self);
                }
                Diverging(_, info) => {
                    return Diverging(self, info);
                }
                Err(error) => {
                    return Err(error);
                }
            };
        }

        // Outermost states of the combined trajectory, ordered from the
        // backward end to the forward end.
        let (first, last) = match direction {
            Direction::Forward => (&self.left, &other.right),
            Direction::Backward => (&other.left, &self.right),
        };

        let turning = if options.check_turning {
            let mut turning = hamiltonian.is_turning(math, first, last);
            // Once each half holds more than one state, also compare the
            // corresponding edges of the two halves. Presumably this catches
            // U-turns that the outer-edge check alone can miss.
            if self.depth > 0 {
                if !turning {
                    turning = hamiltonian.is_turning(math, &self.right, &other.right);
                }
                if !turning {
                    turning = hamiltonian.is_turning(math, &self.left, &other.left);
                }
            }
            turning
        } else {
            false
        };

        // The merge happens even when a U-turn was found, so `Turning` carries
        // the combined tree and its possibly updated draw.
        self.merge_into(math, other, rng, direction);

        if turning {
            ExtendResult::Turning(self)
        } else {
            ExtendResult::Ok(self)
        }
    }

    /// Merges `other` into `self`. `other` must have the same depth and lie
    /// adjacent to `self` in `direction`.
    ///
    /// The merge moves the edge on the `direction` side outward, possibly
    /// replaces the draw with `other.draw`, increments the depth, and sets
    /// `log_size` to the `logaddexp` of both weights.
    ///
    /// Draw replacement:
    /// - root tree (`is_main`): take `other.draw` with probability
    ///   `min(1, exp(other.log_size - self.log_size))`, using the root's weight
    ///   before the merge;
    /// - subtree: take it with probability
    ///   `min(1, exp(other.log_size - combined_log_size))`, i.e. `other`'s
    ///   share of the combined weight.
    ///
    /// # Panics
    /// Panics if the depths differ or the edges are out of order. For the root
    /// tree, it also panics if, after moving the edge, the span no longer
    /// contains trajectory index 0.
    fn merge_into<R: rand::Rng + ?Sized>(
        &mut self,
        _math: &mut M,
        other: NutsTree<M, H, C>,
        rng: &mut R,
        direction: Direction,
    ) {
        assert!(self.depth == other.depth);
        assert!(self.left.index_in_trajectory() <= self.right.index_in_trajectory());
        match direction {
            Direction::Forward => {
                self.right = other.right;
            }
            Direction::Backward => {
                self.left = other.left;
            }
        }
        let log_size = logaddexp(self.log_size, other.log_size);

        // Reference weight for the acceptance test: the root compares against
        // its own previous weight, a subtree against the combined weight.
        let self_log_size = if self.is_main {
            assert!(self.left.index_in_trajectory() <= 0);
            assert!(self.right.index_in_trajectory() >= 0);
            self.log_size
        } else {
            log_size
        };

        // The first comparison short-circuits, so the RNG is only consumed when
        // the acceptance probability is below 1.
        if (other.log_size >= self_log_size)
            || (rng.random_bool((other.log_size - self_log_size).exp()))
        {
            self.draw = other.draw;
        }

        self.depth += 1;
        self.log_size = log_size;
    }

    /// Takes one leapfrog step from the edge of `self` in `direction`.
    ///
    /// On success, the new state is wrapped in a depth-0, non-root tree whose
    /// log-weight is `-energy_error` of the new point. Returns `Ok(Err(info))`
    /// if the step diverged, and `Err(NutsError::LogpFailure)` if the leapfrog
    /// reported an error.
    fn single_step(
        &self,
        math: &mut M,
        hamiltonian: &mut H,
        direction: Direction,
        collector: &mut C,
    ) -> Result<std::result::Result<NutsTree<M, H, C>, DivergenceInfo>> {
        let start = match direction {
            Direction::Forward => &self.right,
            Direction::Backward => &self.left,
        };
        let end = match hamiltonian.leapfrog(math, start, direction, collector) {
            LeapfrogResult::Divergence(info) => return Ok(Err(info)),
            LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())),
            LeapfrogResult::Ok(end) => end,
        };

        let log_size = -end.point().energy_error();
        Ok(Ok(NutsTree {
            right: end.clone(),
            left: end.clone(),
            draw: end,
            depth: 0,
            log_size,
            is_main: false,
            _phantom2: PhantomData,
        }))
    }

    /// Builds the `SampleInfo` for this tree's current draw.
    ///
    /// `maxdepth` records whether the depth limit stopped tree building.
    /// `divergence_info` carries the divergence that stopped it, if any.
    fn info(&self, maxdepth: bool, divergence_info: Option<DivergenceInfo>) -> SampleInfo {
        SampleInfo {
            depth: self.depth,
            divergence_info,
            reached_maxdepth: maxdepth,
            initial_energy: self.draw.point().initial_energy(),
            draw_energy: self.draw.energy(),
        }
    }
}
251
/// Options controlling how a single NUTS draw builds its trajectory.
///
/// All fields are plain `Copy` values, so the type is cheap to copy. Variants
/// can be built with struct-update syntax (`NutsOptions { check_turning: false, ..opts }`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NutsOptions {
    /// Maximum tree depth. Tree building stops once this many doublings have
    /// been performed, and the draw is flagged with `reached_maxdepth`.
    pub maxdepth: u64,
    /// Minimum tree depth. U-turn checks are skipped while the tree depth is
    /// below this value.
    pub mindepth: u64,
    /// NOTE(review): not read by the sampling loop in this file; presumably
    /// tells collectors/stat storage to record gradients — confirm.
    pub store_gradient: bool,
    /// NOTE(review): not read by the sampling loop in this file; presumably
    /// tells collectors/stat storage to record unconstrained values — confirm.
    pub store_unconstrained: bool,
    /// Whether to check for U-turns at all. When `false`, only a divergence or
    /// `maxdepth` ends a trajectory.
    pub check_turning: bool,
    /// NOTE(review): not read by the sampling loop in this file; presumably
    /// tells collectors/stat storage to record divergence details — confirm.
    pub store_divergences: bool,
}

impl Default for NutsOptions {
    /// Depth limited to 10 doublings, no minimum depth, U-turn checks enabled,
    /// and all `store_*` flags off.
    fn default() -> Self {
        Self {
            maxdepth: 10,
            mindepth: 0,
            store_gradient: false,
            store_unconstrained: false,
            check_turning: true,
            store_divergences: false,
        }
    }
}
273
274pub(crate) fn draw<M, H, R, C>(
275 math: &mut M,
276 init: &mut State<M, H::Point>,
277 rng: &mut R,
278 hamiltonian: &mut H,
279 options: &NutsOptions,
280 collector: &mut C,
281) -> Result<(State<M, H::Point>, SampleInfo)>
282where
283 M: Math,
284 H: Hamiltonian<M>,
285 R: rand::Rng + ?Sized,
286 C: Collector<M, H::Point>,
287{
288 hamiltonian.initialize_trajectory(math, init, rng)?;
289 collector.register_init(math, init, options);
290
291 let mut tree = NutsTree::new(init.clone());
292
293 if math.dim() == 0 {
294 let info = tree.info(false, None);
295 collector.register_draw(math, init, &info);
296 return Ok((init.clone(), info));
297 }
298
299 let options_no_check = NutsOptions {
300 check_turning: false,
301 ..*options
302 };
303
304 while tree.depth < options.maxdepth {
305 let direction: Direction = rng.random();
306 let current_options = if tree.depth < options.mindepth {
307 &options_no_check
308 } else {
309 options
310 };
311 tree = match tree.extend(
312 math,
313 rng,
314 hamiltonian,
315 direction,
316 collector,
317 current_options,
318 ) {
319 ExtendResult::Ok(tree) => tree,
320 ExtendResult::Turning(tree) => {
321 let info = tree.info(false, None);
322 collector.register_draw(math, &tree.draw, &info);
323 return Ok((tree.draw, info));
324 }
325 ExtendResult::Diverging(tree, info) => {
326 let info = tree.info(false, Some(info));
327 collector.register_draw(math, &tree.draw, &info);
328 return Ok((tree.draw, info));
329 }
330 ExtendResult::Err(error) => {
331 return Err(error);
332 }
333 };
334 }
335 let info = tree.info(true, None);
336 collector.register_draw(math, &tree.draw, &info);
337 Ok((tree.draw, info))
338}
339
#[cfg(test)]
mod tests {
    use rand::rng;

    use crate::{
        Chain, Settings, adapt_strategy::test_logps::NormalLogp, cpu_math::CpuMath,
        sampler::DiagGradNutsSettings,
    };

    /// Smoke test: create a chain from `DiagGradNutsSettings::default()` on a
    /// 10-dimensional `NormalLogp` target, take 11 draws, and check that the
    /// progress reported for the last one is not marked as divergent.
    ///
    /// NOTE(review): despite the name, nothing in this test touches Arrow output.
    #[test]
    fn to_arrow() {
        const NUM_DRAWS: usize = 11;

        let logp = NormalLogp::new(10, 3.);
        let math = CpuMath::new(logp);

        let settings = DiagGradNutsSettings::default();
        let mut rng = rng();
        let mut chain = settings.new_chain(0, math, &mut rng);

        let mut last_progress = None;
        for _ in 0..NUM_DRAWS {
            let (_, progress) = chain.draw().unwrap();
            last_progress = Some(progress);
        }

        let progress = last_progress.expect("NUM_DRAWS is non-zero");
        assert!(!progress.diverging);
    }
}