1use thiserror::Error;
2
3use std::{fmt::Debug, marker::PhantomData};
4
5use crate::hamiltonian::{Direction, DivergenceInfo, Hamiltonian, LeapfrogResult, Point};
6use crate::math::logaddexp;
7use crate::state::State;
8
9use crate::math_base::Math;
10
/// Errors that can occur while running the NUTS sampler.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum NutsError {
    /// The user-supplied log-density function returned an error.
    #[error("Logp function returned error: {0:?}")]
    LogpFailure(Box<dyn std::error::Error + Send + Sync>),

    /// Sampler statistics could not be serialized.
    #[error("Could not serialize sample stats")]
    SerializeFailure(),

    /// The sampler state could not be initialized because the gradient
    /// at the initial point was invalid.
    #[error("Could not initialize state because of bad initial gradient: {0:?}")]
    BadInitGrad(Box<dyn std::error::Error + Send + Sync>),
}
23
/// Convenience alias for results whose error type is [`NutsError`].
pub type Result<T> = std::result::Result<T, NutsError>;
25
/// Observer hooks invoked while a NUTS trajectory is built.
///
/// All methods have empty default bodies, so implementors only override
/// the events they care about (e.g. to accumulate adaptation statistics).
pub trait Collector<M: Math, P: Point<M>> {
    /// Called after each leapfrog step with its start and end states and,
    /// if the step diverged, the divergence details.
    fn register_leapfrog(
        &mut self,
        _math: &mut M,
        _start: &State<M, P>,
        _end: &State<M, P>,
        _divergence_info: Option<&DivergenceInfo>,
    ) {
    }
    /// Called once per draw with the accepted state and its sample info.
    fn register_draw(&mut self, _math: &mut M, _state: &State<M, P>, _info: &SampleInfo) {}
    /// Called when a new trajectory is initialized from a starting state.
    fn register_init(&mut self, _math: &mut M, _state: &State<M, P>, _options: &NutsOptions) {}
}
42
/// Diagnostic information about a single NUTS draw.
#[derive(Debug)]
pub struct SampleInfo {
    /// Final depth of the trajectory tree (number of doublings).
    pub depth: u64,

    /// Details about the divergence, if one terminated the trajectory.
    pub divergence_info: Option<DivergenceInfo>,

    /// Whether the trajectory stopped because it reached `maxdepth`.
    pub reached_maxdepth: bool,

    /// Energy at the initial point of the trajectory.
    pub initial_energy: f64,
    /// Energy at the state chosen as the draw.
    pub draw_energy: f64,
}
60
/// A balanced subtree of the NUTS trajectory, built by repeated doubling.
struct NutsTree<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> {
    /// Leftmost (most backward-in-time) state of the subtree.
    left: State<M, H::Point>,
    /// Rightmost (most forward-in-time) state of the subtree.
    right: State<M, H::Point>,

    /// Candidate draw selected from this subtree by progressive sampling.
    draw: State<M, H::Point>,
    /// Log of the total (unnormalized) weight of all states in the subtree.
    log_size: f64,
    /// Number of doublings; the subtree covers 2^depth leapfrog steps.
    depth: u64,

    /// True only for the root tree that contains the initial point
    /// (affects how draws are weighted in `merge_into`).
    is_main: bool,
    // The collector type only appears in method signatures, not in fields.
    _phantom2: PhantomData<C>,
}
81
/// Outcome of extending a [`NutsTree`] by one doubling.
enum ExtendResult<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> {
    /// The tree was extended successfully.
    Ok(NutsTree<M, H, C>),
    /// The log-density function failed.
    Err(NutsError),
    /// Extension stopped because a U-turn was detected; the tree before
    /// the failed extension is returned.
    Turning(NutsTree<M, H, C>),
    /// Extension stopped because a leapfrog step diverged; the tree before
    /// the failed extension is returned together with the divergence info.
    Diverging(NutsTree<M, H, C>, DivergenceInfo),
}
94
95impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
96 fn new(state: State<M, H::Point>) -> NutsTree<M, H, C> {
97 NutsTree {
98 right: state.clone(),
99 left: state.clone(),
100 draw: state,
101 depth: 0,
102 log_size: 0.,
103 is_main: true,
104 _phantom2: PhantomData,
105 }
106 }
107
108 #[allow(clippy::too_many_arguments)]
109 #[inline]
110 fn extend<R>(
111 mut self,
112 math: &mut M,
113 rng: &mut R,
114 hamiltonian: &mut H,
115 direction: Direction,
116 collector: &mut C,
117 options: &NutsOptions,
118 ) -> ExtendResult<M, H, C>
119 where
120 H: Hamiltonian<M>,
121 R: rand::Rng + ?Sized,
122 {
123 let mut other = match self.single_step(math, hamiltonian, direction, collector) {
124 Ok(Ok(tree)) => tree,
125 Ok(Err(info)) => return ExtendResult::Diverging(self, info),
126 Err(err) => return ExtendResult::Err(err),
127 };
128
129 while other.depth < self.depth {
130 use ExtendResult::*;
131 other = match other.extend(math, rng, hamiltonian, direction, collector, options) {
132 Ok(tree) => tree,
133 Turning(_) => {
134 return Turning(self);
135 }
136 Diverging(_, info) => {
137 return Diverging(self, info);
138 }
139 Err(error) => {
140 return Err(error);
141 }
142 };
143 }
144
145 let (first, last) = match direction {
146 Direction::Forward => (&self.left, &other.right),
147 Direction::Backward => (&other.left, &self.right),
148 };
149
150 let turning = if options.check_turning {
151 let mut turning = hamiltonian.is_turning(math, first, last);
152 if self.depth > 0 {
153 if !turning {
154 turning = hamiltonian.is_turning(math, &self.right, &other.right);
155 }
156 if !turning {
157 turning = hamiltonian.is_turning(math, &self.left, &other.left);
158 }
159 }
160 turning
161 } else {
162 false
163 };
164
165 self.merge_into(math, other, rng, direction);
166
167 if turning {
168 ExtendResult::Turning(self)
169 } else {
170 ExtendResult::Ok(self)
171 }
172 }
173
174 fn merge_into<R: rand::Rng + ?Sized>(
175 &mut self,
176 _math: &mut M,
177 other: NutsTree<M, H, C>,
178 rng: &mut R,
179 direction: Direction,
180 ) {
181 assert!(self.depth == other.depth);
182 assert!(self.left.index_in_trajectory() <= self.right.index_in_trajectory());
183 match direction {
184 Direction::Forward => {
185 self.right = other.right;
186 }
187 Direction::Backward => {
188 self.left = other.left;
189 }
190 }
191 let log_size = logaddexp(self.log_size, other.log_size);
192
193 let self_log_size = if self.is_main {
194 assert!(self.left.index_in_trajectory() <= 0);
195 assert!(self.right.index_in_trajectory() >= 0);
196 self.log_size
197 } else {
198 log_size
199 };
200
201 if (other.log_size >= self_log_size)
202 || (rng.random_bool((other.log_size - self_log_size).exp()))
203 {
204 self.draw = other.draw;
205 }
206
207 self.depth += 1;
208 self.log_size = log_size;
209 }
210
211 fn single_step(
212 &self,
213 math: &mut M,
214 hamiltonian: &mut H,
215 direction: Direction,
216 collector: &mut C,
217 ) -> Result<std::result::Result<NutsTree<M, H, C>, DivergenceInfo>> {
218 let start = match direction {
219 Direction::Forward => &self.right,
220 Direction::Backward => &self.left,
221 };
222 let end = match hamiltonian.leapfrog(math, start, direction, collector) {
223 LeapfrogResult::Divergence(info) => return Ok(Err(info)),
224 LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())),
225 LeapfrogResult::Ok(end) => end,
226 };
227
228 let log_size = -end.point().energy_error();
229 Ok(Ok(NutsTree {
230 right: end.clone(),
231 left: end.clone(),
232 draw: end,
233 depth: 0,
234 log_size,
235 is_main: false,
236 _phantom2: PhantomData,
237 }))
238 }
239
240 fn info(&self, maxdepth: bool, divergence_info: Option<DivergenceInfo>) -> SampleInfo {
241 SampleInfo {
242 depth: self.depth,
243 divergence_info,
244 reached_maxdepth: maxdepth,
245 initial_energy: self.draw.point().initial_energy(),
246 draw_energy: self.draw.energy(),
247 }
248 }
249}
250
/// Static configuration for building NUTS trajectories.
pub struct NutsOptions {
    /// Maximum number of tree doublings per draw.
    pub maxdepth: u64,
    /// Presumably controls whether gradients are stored in the sampler
    /// stats — not read in this module; consumed elsewhere.
    pub store_gradient: bool,
    /// Presumably controls whether unconstrained values are stored in the
    /// sampler stats — not read in this module; consumed elsewhere.
    pub store_unconstrained: bool,
    /// Enable the U-turn termination checks in `NutsTree::extend`.
    pub check_turning: bool,
    /// Presumably controls whether divergence details are stored in the
    /// sampler stats — not read in this module; consumed elsewhere.
    pub store_divergences: bool,
}
258
/// Perform a single NUTS draw starting from `init`.
///
/// Resamples the momentum, then repeatedly doubles the trajectory in a
/// random direction until a U-turn is detected, a leapfrog step diverges,
/// or `options.maxdepth` doublings have been performed. Returns the
/// selected state together with diagnostic [`SampleInfo`], and notifies
/// `collector` of initialization and of the final draw.
///
/// # Errors
/// Returns an error if trajectory initialization or the log-density
/// function fails.
pub(crate) fn draw<M, H, R, C>(
    math: &mut M,
    init: &mut State<M, H::Point>,
    rng: &mut R,
    hamiltonian: &mut H,
    options: &NutsOptions,
    collector: &mut C,
) -> Result<(State<M, H::Point>, SampleInfo)>
where
    M: Math,
    H: Hamiltonian<M>,
    R: rand::Rng + ?Sized,
    C: Collector<M, H::Point>,
{
    hamiltonian.initialize_trajectory(math, init, rng)?;
    collector.register_init(math, init, options);

    let mut tree = NutsTree::new(init.clone());

    // Zero-dimensional problem: nothing to integrate, return the initial
    // state as the draw.
    if math.dim() == 0 {
        let info = tree.info(false, None);
        collector.register_draw(math, init, &info);
        return Ok((init.clone(), info));
    }

    while tree.depth < options.maxdepth {
        // Each doubling extends the trajectory in a freshly sampled
        // random direction, as required by NUTS reversibility.
        let direction: Direction = rng.random();
        tree = match tree.extend(math, rng, hamiltonian, direction, collector, options) {
            ExtendResult::Ok(tree) => tree,
            ExtendResult::Turning(tree) => {
                let info = tree.info(false, None);
                collector.register_draw(math, &tree.draw, &info);
                return Ok((tree.draw, info));
            }
            ExtendResult::Diverging(tree, info) => {
                let info = tree.info(false, Some(info));
                collector.register_draw(math, &tree.draw, &info);
                return Ok((tree.draw, info));
            }
            ExtendResult::Err(error) => {
                return Err(error);
            }
        };
    }
    // Loop exit: maximum depth reached without turning or diverging.
    let info = tree.info(true, None);
    collector.register_draw(math, &tree.draw, &info);
    Ok((tree.draw, info))
}
307
#[cfg(test)]
mod tests {
    use rand::{rng, rngs::ThreadRng};

    use crate::{
        adapt_strategy::test_logps::NormalLogp,
        chain::NutsChain,
        cpu_math::CpuMath,
        sampler::DiagGradNutsSettings,
        sampler_stats::{SamplerStats, StatTraceBuilder},
        Chain, Settings,
    };

    /// Smoke test: run a short chain on a normal target, collect stats
    /// for each draw, and finalize the stats builder without panicking.
    #[test]
    fn to_arrow() {
        let dim = 10;
        let logp = NormalLogp::new(dim, 3.);
        let math = CpuMath::new(logp);

        let settings = DiagGradNutsSettings::default();
        let mut rng = rng();
        let mut chain = settings.new_chain(0, math, &mut rng);

        let stats_opts = settings.stats_options(&chain);
        let mut builder = chain.new_builder(stats_opts, &settings, dim);

        // One warm-up draw, then ten recorded draws.
        let (_, mut last_progress) = chain.draw().unwrap();
        for _ in 0..10 {
            let (_, progress) = chain.draw().unwrap();
            builder.append_value(None, &chain);
            last_progress = progress;
        }

        // The final draw of a well-behaved normal target should not diverge.
        assert!(!last_progress.diverging);
        StatTraceBuilder::<_, NutsChain<_, ThreadRng, _>>::finalize(builder);
    }
}
345}