1use rand::RngExt;
4use rand_distr::num_traits::ToPrimitive;
5use thiserror::Error;
6
7use std::{fmt::Debug, marker::PhantomData};
8
9use crate::dynamics::{Direction, DivergenceInfo, Hamiltonian, LeapfrogResult, Point, State};
10use crate::math::{Math, logaddexp};
11
12#[non_exhaustive]
13#[derive(Error, Debug)]
14pub enum NutsError {
15 #[error("Logp function returned error: {0:?}")]
16 LogpFailure(Box<dyn std::error::Error + Send + Sync>),
17
18 #[error("Could not serialize sample stats")]
19 SerializeFailure(),
20
21 #[error("Could not initialize state because of bad initial gradient: {0:?}")]
22 BadInitGrad(Box<dyn std::error::Error + Send + Sync>),
23}
24
25pub type Result<T> = std::result::Result<T, NutsError>;
26
27pub trait Collector<M: Math, P: Point<M>> {
32 fn register_leapfrog(
33 &mut self,
34 _math: &mut M,
35 _start: &State<M, P>,
36 _end: &State<M, P>,
37 _divergence_info: Option<&DivergenceInfo>,
38 ) {
39 }
40 fn register_draw(&mut self, _math: &mut M, _state: &State<M, P>, _info: &SampleInfo) {}
41 fn register_init(&mut self, _math: &mut M, _state: &State<M, P>, _options: &NutsOptions) {}
42}
43
44#[derive(Debug)]
46pub struct SampleInfo {
47 pub depth: u64,
49
50 pub divergence_info: Option<DivergenceInfo>,
53
54 pub reached_maxdepth: bool,
57}
58
59struct NutsTree<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> {
61 left: State<M, H::Point>,
66 right: State<M, H::Point>,
67
68 draw: State<M, H::Point>,
71 log_size: f64,
72 depth: u64,
73
74 is_main: bool,
77 _phantom2: PhantomData<C>,
78}
79
80enum ExtendResult<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> {
81 Ok(NutsTree<M, H, C>),
84 Err(NutsError),
86 Turning(NutsTree<M, H, C>),
89 Diverging(NutsTree<M, H, C>, DivergenceInfo),
91}
92
93impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
94 fn new(state: State<M, H::Point>) -> NutsTree<M, H, C> {
95 NutsTree {
96 right: state.clone(),
97 left: state.clone(),
98 draw: state,
99 depth: 0,
100 log_size: 0.,
101 is_main: true,
102 _phantom2: PhantomData,
103 }
104 }
105
106 #[allow(clippy::too_many_arguments)]
107 #[inline]
108 fn extend<R>(
109 mut self,
110 math: &mut M,
111 rng: &mut R,
112 hamiltonian: &mut H,
113 direction: Direction,
114 collector: &mut C,
115 options: &NutsOptions,
116 ) -> ExtendResult<M, H, C>
117 where
118 H: Hamiltonian<M>,
119 R: rand::Rng + ?Sized,
120 {
121 let mut other = match self.single_step(math, hamiltonian, direction, options, collector) {
122 Ok(Ok(tree)) => tree,
123 Ok(Err(info)) => return ExtendResult::Diverging(self, info),
124 Err(err) => return ExtendResult::Err(err),
125 };
126
127 while other.depth < self.depth {
128 use ExtendResult::*;
129 other = match other.extend(math, rng, hamiltonian, direction, collector, options) {
130 Ok(tree) => tree,
131 Turning(_) => {
132 return Turning(self);
133 }
134 Diverging(_, info) => {
135 return Diverging(self, info);
136 }
137 Err(error) => {
138 return Err(error);
139 }
140 };
141 }
142
143 let (first, last) = match direction {
144 Direction::Forward => (&self.left, &other.right),
145 Direction::Backward => (&other.left, &self.right),
146 };
147
148 let turning = if options.check_turning {
149 let mut turning = hamiltonian.is_turning(math, first, last);
150 if self.depth > 0 {
151 if !turning {
152 turning = hamiltonian.is_turning(math, &self.right, &other.right);
153 }
154 if !turning {
155 turning = hamiltonian.is_turning(math, &self.left, &other.left);
156 }
157 }
158 turning
159 } else {
160 false
161 };
162
163 self.merge_into(math, other, rng, direction);
164
165 if turning {
166 ExtendResult::Turning(self)
167 } else {
168 ExtendResult::Ok(self)
169 }
170 }
171
172 fn merge_into<R: rand::Rng + ?Sized>(
173 &mut self,
174 _math: &mut M,
175 other: NutsTree<M, H, C>,
176 rng: &mut R,
177 direction: Direction,
178 ) {
179 assert!(self.depth == other.depth);
180 assert!(self.left.index_in_trajectory() <= self.right.index_in_trajectory());
181 match direction {
182 Direction::Forward => {
183 self.right = other.right;
184 }
185 Direction::Backward => {
186 self.left = other.left;
187 }
188 }
189 let log_size = logaddexp(self.log_size, other.log_size);
190
191 let self_log_size = if self.is_main {
192 assert!(self.left.index_in_trajectory() <= 0);
193 assert!(self.right.index_in_trajectory() >= 0);
194 self.log_size
195 } else {
196 log_size
197 };
198
199 if (other.log_size >= self_log_size)
200 || (rng.random_bool((other.log_size - self_log_size).exp()))
201 {
202 self.draw = other.draw;
203 }
204
205 self.depth += 1;
206 self.log_size = log_size;
207 }
208
209 fn single_step(
210 &self,
211 math: &mut M,
212 hamiltonian: &mut H,
213 direction: Direction,
214 options: &NutsOptions,
215 collector: &mut C,
216 ) -> Result<std::result::Result<NutsTree<M, H, C>, DivergenceInfo>> {
217 let start = match direction {
218 Direction::Forward => &self.right,
219 Direction::Backward => &self.left,
220 };
221 let end = match hamiltonian.leapfrog(
222 math,
223 start,
224 direction,
225 1.0,
226 start.point().initial_energy(),
227 options.max_energy_error,
228 collector,
229 ) {
230 LeapfrogResult::Divergence(info) => return Ok(Err(info)),
231 LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())),
232 LeapfrogResult::Ok(end) => end,
233 };
234
235 let log_size = -end.point().energy_error();
236 Ok(Ok(NutsTree {
237 right: end.clone(),
238 left: end.clone(),
239 draw: end,
240 depth: 0,
241 log_size,
242 is_main: false,
243 _phantom2: PhantomData,
244 }))
245 }
246
247 fn info(&self, maxdepth: bool, divergence_info: Option<DivergenceInfo>) -> SampleInfo {
248 SampleInfo {
249 depth: self.depth,
250 divergence_info,
251 reached_maxdepth: maxdepth,
252 }
253 }
254}
255
/// Tuning knobs for tree building in `draw`.
#[derive(Debug, Clone)]
pub struct NutsOptions {
    /// Tree building stops once the tree has reached this depth. When
    /// `target_integration_time` is set, the effective limit may be lower.
    pub maxdepth: u64,
    /// While the tree is shallower than this, extensions skip the turning
    /// check. When `target_integration_time` is set, the effective value may
    /// be higher.
    pub mindepth: u64,
    /// Whether extensions evaluate the turning criterion at all.
    pub check_turning: bool,
    /// Not read in this module.
    // NOTE(review): presumably controls whether divergence details are kept
    // in the sample statistics — confirm at the use site.
    pub store_divergences: bool,
    /// If set, `draw` derives its depth limits from this value divided by the
    /// current step size instead of using `mindepth`/`maxdepth` directly.
    pub target_integration_time: Option<f64>,
    /// Number of additional extensions (without turning checks) that `draw`
    /// attempts after the turning criterion has fired.
    pub extra_doublings: u64,
    /// Forwarded to every leapfrog step.
    // NOTE(review): presumably the energy-error threshold beyond which a step
    // counts as divergent — confirm in the `Hamiltonian` implementation.
    pub max_energy_error: f64,
}
266
267impl Default for NutsOptions {
268 fn default() -> Self {
269 NutsOptions {
270 maxdepth: 10,
271 mindepth: 0,
272 check_turning: true,
273 store_divergences: false,
274 target_integration_time: None,
275 extra_doublings: 0,
276 max_energy_error: 1000.0,
277 }
278 }
279}
280
281pub(crate) fn draw<M, H, R, C>(
282 math: &mut M,
283 init: &mut State<M, H::Point>,
284 rng: &mut R,
285 hamiltonian: &mut H,
286 options: &NutsOptions,
287 collector: &mut C,
288) -> Result<(State<M, H::Point>, SampleInfo)>
289where
290 M: Math,
291 H: Hamiltonian<M>,
292 R: rand::Rng + ?Sized,
293 C: Collector<M, H::Point>,
294{
295 hamiltonian.initialize_trajectory(math, init, true, rng)?;
296 collector.register_init(math, init, options);
297
298 let mut tree = NutsTree::new(init.clone());
299
300 let (mindepth, maxdepth) = if let Some(target_time) = options.target_integration_time {
301 let step_size = hamiltonian.step_size();
302 let max_steps = (target_time / step_size).ceil() as u64;
303 let mindepth = (max_steps as f64)
304 .log2()
305 .floor()
306 .to_u64()
307 .unwrap()
308 .max(options.mindepth);
309 let maxdepth = (max_steps as f64)
310 .log2()
311 .ceil()
312 .to_u64()
313 .unwrap()
314 .max(mindepth)
315 .min(options.maxdepth);
316
317 (mindepth, maxdepth)
318 } else {
319 (options.mindepth, options.maxdepth)
320 };
321
322 if math.dim() == 0 {
323 let info = tree.info(false, None);
324 collector.register_draw(math, init, &info);
325 return Ok((init.clone(), info));
326 }
327
328 let options_no_check = NutsOptions {
329 check_turning: false,
330 ..*options
331 };
332
333 while tree.depth < maxdepth {
334 let direction: Direction = rng.random();
335 let current_options = if tree.depth < mindepth {
336 &options_no_check
337 } else {
338 options
339 };
340 tree = match tree.extend(
341 math,
342 rng,
343 hamiltonian,
344 direction,
345 collector,
346 current_options,
347 ) {
348 ExtendResult::Ok(tree) => tree,
349 ExtendResult::Turning(mut tree) => {
350 for _ in 0..options.extra_doublings {
351 tree = match tree.extend(
352 math,
353 rng,
354 hamiltonian,
355 direction,
356 collector,
357 &options_no_check,
358 ) {
359 ExtendResult::Ok(tree) => tree,
360 ExtendResult::Turning(tree) => tree,
361 ExtendResult::Diverging(tree, info) => {
362 let info = tree.info(false, Some(info));
363 collector.register_draw(math, &tree.draw, &info);
364 return Ok((tree.draw, info));
365 }
366 ExtendResult::Err(error) => {
367 return Err(error);
368 }
369 }
370 }
371 let info = tree.info(false, None);
372 collector.register_draw(math, &tree.draw, &info);
373 return Ok((tree.draw, info));
374 }
375 ExtendResult::Diverging(tree, info) => {
376 let info = tree.info(false, Some(info));
377 collector.register_draw(math, &tree.draw, &info);
378 return Ok((tree.draw, info));
379 }
380 ExtendResult::Err(error) => {
381 return Err(error);
382 }
383 };
384 }
385 let info = tree.info(true, None);
386 collector.register_draw(math, &tree.draw, &info);
387 Ok((tree.draw, info))
388}
389
#[cfg(test)]
mod tests {
    use rand::rng;

    use crate::{
        Chain, Settings, math::test_logps::NormalLogp, math::CpuMath,
        sampler::DiagNutsSettings,
    };

    /// Smoke test: runs eleven draws with default settings on the
    /// 10-dimensional `NormalLogp` test density and checks that the last draw
    /// did not diverge.
    #[test]
    fn to_arrow() {
        let ndim = 10;
        let math = CpuMath::new(NormalLogp::new(ndim, 3.));

        let settings = DiagNutsSettings::default();
        let mut rng = rng();
        let mut chain = settings.new_chain(0, math, &mut rng);

        let start = vec![0.0; ndim];
        chain.set_position(&start).unwrap();

        // Only the progress report of the final draw is inspected.
        let mut progress = chain.draw().unwrap().1;
        for _ in 0..10 {
            progress = chain.draw().unwrap().1;
        }

        assert!(!progress.diverging);
    }
}
420}