// nuts-rs — nuts.rs
1//! Implement the recursive doubling tree expansion that is the heart of the NUTS algorithm.
2
3use rand::RngExt;
4use rand_distr::num_traits::ToPrimitive;
5use thiserror::Error;
6
7use std::{fmt::Debug, marker::PhantomData};
8
9use crate::dynamics::{Direction, DivergenceInfo, Hamiltonian, LeapfrogResult, Point, State};
10use crate::math::{Math, logaddexp};
11
/// Errors that can abort a NUTS sampling step.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum NutsError {
    /// The user-supplied log-density (and gradient) function returned an error.
    #[error("Logp function returned error: {0:?}")]
    LogpFailure(Box<dyn std::error::Error + Send + Sync>),

    /// Sampler statistics could not be serialized.
    #[error("Could not serialize sample stats")]
    SerializeFailure(),

    /// The gradient at the initial position was unusable, so the
    /// trajectory state could not be initialized.
    #[error("Could not initialize state because of bad initial gradient: {0:?}")]
    BadInitGrad(Box<dyn std::error::Error + Send + Sync>),
}
24
25pub type Result<T> = std::result::Result<T, NutsError>;
26
/// Callbacks for various events during a Nuts sampling step.
///
/// Collectors can compute statistics like the mean acceptance rate
/// or collect data for mass matrix adaptation.
pub trait Collector<M: Math, P: Point<M>> {
    /// Called for each leapfrog step with its start and end states and,
    /// if the step diverged, the divergence details. Default: no-op.
    fn register_leapfrog(
        &mut self,
        _math: &mut M,
        _start: &State<M, P>,
        _end: &State<M, P>,
        _divergence_info: Option<&DivergenceInfo>,
    ) {
    }
    /// Called once at the end of a NUTS step with the state that was
    /// selected as the draw. Default: no-op.
    fn register_draw(&mut self, _math: &mut M, _state: &State<M, P>, _info: &SampleInfo) {}
    /// Called at the start of a NUTS step, after the trajectory has been
    /// initialized (momentum refreshed). Default: no-op.
    fn register_init(&mut self, _math: &mut M, _state: &State<M, P>, _options: &NutsOptions) {}
}
43
/// Information about a draw, exported as part of the sampler stats
#[derive(Debug)]
pub struct SampleInfo {
    /// The depth of the trajectory that this point was sampled from
    pub depth: u64,

    /// More detailed information about a divergence that might have
    /// occurred in the trajectory.
    pub divergence_info: Option<DivergenceInfo>,

    /// Whether the trajectory was terminated because it reached
    /// the maximum tree depth.
    pub reached_maxdepth: bool,
}
58
/// A part of the trajectory tree during NUTS sampling.
struct NutsTree<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> {
    /// The left position of the tree.
    ///
    /// The left side always has the smaller index_in_trajectory.
    /// Leapfrogs in backward direction will replace the left.
    left: State<M, H::Point>,
    /// The right position of the tree; leapfrogs in forward
    /// direction replace it.
    right: State<M, H::Point>,

    /// A draw from the trajectory between left and right using
    /// multinomial sampling.
    draw: State<M, H::Point>,
    /// Log of the total multinomial weight of the states covered by
    /// this tree (combined via log-sum-exp when trees are merged).
    log_size: f64,
    /// Number of doublings this tree has undergone.
    depth: u64,

    /// A tree is the main tree if it contains the initial point
    /// of the trajectory.
    is_main: bool,
    // Marker so the collector type parameter is carried without storage.
    _phantom2: PhantomData<C>,
}
79
/// Outcome of one attempt to double a [`NutsTree`].
enum ExtendResult<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> {
    /// The tree extension succeeded properly, and the termination
    /// criterion was not reached.
    Ok(NutsTree<M, H, C>),
    /// An unrecoverable error happened during a leapfrog step
    Err(NutsError),
    /// Tree extension succeeded and the termination criterion
    /// was reached.
    Turning(NutsTree<M, H, C>),
    /// A divergence happened during tree extension.
    Diverging(NutsTree<M, H, C>, DivergenceInfo),
}
92
impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
    /// Create a depth-zero tree containing only the initial state.
    ///
    /// This is the "main" tree: it holds the starting point of the
    /// trajectory, and its log weight starts at 0 (weight 1).
    fn new(state: State<M, H::Point>) -> NutsTree<M, H, C> {
        NutsTree {
            right: state.clone(),
            left: state.clone(),
            draw: state,
            depth: 0,
            log_size: 0.,
            is_main: true,
            _phantom2: PhantomData,
        }
    }

    /// Perform one doubling of the tree in `direction`.
    ///
    /// Recursively builds a sibling tree of the same depth next to
    /// `self`, checks the U-turn termination criterion (if enabled in
    /// `options`), and merges the sibling into `self`. Any divergence
    /// or turning detected inside the sibling terminates the doubling
    /// early and discards the partial sibling.
    #[allow(clippy::too_many_arguments)]
    #[inline]
    fn extend<R>(
        mut self,
        math: &mut M,
        rng: &mut R,
        hamiltonian: &mut H,
        direction: Direction,
        collector: &mut C,
        options: &NutsOptions,
    ) -> ExtendResult<M, H, C>
    where
        H: Hamiltonian<M>,
        R: rand::Rng + ?Sized,
    {
        // Seed the sibling tree with a single leapfrog step.
        let mut other = match self.single_step(math, hamiltonian, direction, options, collector) {
            Ok(Ok(tree)) => tree,
            Ok(Err(info)) => return ExtendResult::Diverging(self, info),
            Err(err) => return ExtendResult::Err(err),
        };

        // Grow the sibling until it matches our depth. If it turns or
        // diverges internally, the partial sibling is dropped and only
        // `self` survives.
        while other.depth < self.depth {
            use ExtendResult::*;
            other = match other.extend(math, rng, hamiltonian, direction, collector, options) {
                Ok(tree) => tree,
                Turning(_) => {
                    return Turning(self);
                }
                Diverging(_, info) => {
                    return Diverging(self, info);
                }
                Err(error) => {
                    return Err(error);
                }
            };
        }

        // Outermost endpoints of the would-be merged tree, ordered so
        // `first` has the smaller trajectory index.
        let (first, last) = match direction {
            Direction::Forward => (&self.left, &other.right),
            Direction::Backward => (&other.left, &self.right),
        };

        let turning = if options.check_turning {
            // End-to-end check, plus (for non-leaf trees) the two
            // additional cross-subtree checks between the endpoints of
            // `self` and `other`.
            let mut turning = hamiltonian.is_turning(math, first, last);
            if self.depth > 0 {
                if !turning {
                    turning = hamiltonian.is_turning(math, &self.right, &other.right);
                }
                if !turning {
                    turning = hamiltonian.is_turning(math, &self.left, &other.left);
                }
            }
            turning
        } else {
            false
        };

        // Merge even when turning: the merged tree is still reported,
        // only further expansion stops.
        self.merge_into(math, other, rng, direction);

        if turning {
            ExtendResult::Turning(self)
        } else {
            ExtendResult::Ok(self)
        }
    }

    /// Merge a sibling tree of equal depth into `self`.
    ///
    /// Extends the endpoint on the `direction` side, combines the log
    /// weights via log-sum-exp, and selects the merged tree's draw by
    /// multinomial sampling. The main tree compares against its own
    /// pre-merge weight, which biases the draw toward the newly added
    /// subtree (progressive sampling).
    fn merge_into<R: rand::Rng + ?Sized>(
        &mut self,
        _math: &mut M,
        other: NutsTree<M, H, C>,
        rng: &mut R,
        direction: Direction,
    ) {
        assert!(self.depth == other.depth);
        assert!(self.left.index_in_trajectory() <= self.right.index_in_trajectory());
        // Adopt the new outer endpoint on the side that was extended.
        match direction {
            Direction::Forward => {
                self.right = other.right;
            }
            Direction::Backward => {
                self.left = other.left;
            }
        }
        let log_size = logaddexp(self.log_size, other.log_size);

        let self_log_size = if self.is_main {
            // The main tree must still bracket the initial point.
            assert!(self.left.index_in_trajectory() <= 0);
            assert!(self.right.index_in_trajectory() >= 0);
            self.log_size
        } else {
            log_size
        };

        // Keep `other`'s draw with probability exp(other - self_log_size).
        // The `>=` short-circuit also guards `random_bool` against
        // probabilities above 1.
        if (other.log_size >= self_log_size)
            || (rng.random_bool((other.log_size - self_log_size).exp()))
        {
            self.draw = other.draw;
        }

        self.depth += 1;
        self.log_size = log_size;
    }

    /// Take one leapfrog step from the outermost state in `direction`
    /// and wrap the result in a fresh depth-zero tree.
    ///
    /// Returns `Ok(Err(info))` when the step diverged and `Err` when
    /// the logp function itself failed.
    fn single_step(
        &self,
        math: &mut M,
        hamiltonian: &mut H,
        direction: Direction,
        options: &NutsOptions,
        collector: &mut C,
    ) -> Result<std::result::Result<NutsTree<M, H, C>, DivergenceInfo>> {
        let start = match direction {
            Direction::Forward => &self.right,
            Direction::Backward => &self.left,
        };
        let end = match hamiltonian.leapfrog(
            math,
            start,
            direction,
            1.0,
            start.point().initial_energy(),
            options.max_energy_error,
            collector,
        ) {
            LeapfrogResult::Divergence(info) => return Ok(Err(info)),
            LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())),
            LeapfrogResult::Ok(end) => end,
        };

        // Multinomial weight of the new state: exp(-energy_error).
        let log_size = -end.point().energy_error();
        Ok(Ok(NutsTree {
            right: end.clone(),
            left: end.clone(),
            draw: end,
            depth: 0,
            log_size,
            is_main: false,
            _phantom2: PhantomData,
        }))
    }

    /// Build the [`SampleInfo`] exported alongside this tree's draw.
    fn info(&self, maxdepth: bool, divergence_info: Option<DivergenceInfo>) -> SampleInfo {
        SampleInfo {
            depth: self.depth,
            divergence_info,
            reached_maxdepth: maxdepth,
        }
    }
}
255
/// Tuning options for NUTS tree expansion.
#[derive(Debug, Clone)]
pub struct NutsOptions {
    /// Maximum tree depth; expansion stops after this many doublings.
    pub maxdepth: u64,
    /// Depth below which U-turn checks are skipped during expansion.
    pub mindepth: u64,
    /// Whether the U-turn termination criterion is checked at all.
    pub check_turning: bool,
    /// NOTE(review): not read anywhere in this module — presumably
    /// controls whether divergence details are stored in the sampler
    /// stats; confirm at the call sites.
    pub store_divergences: bool,
    /// If set, derive the effective (mindepth, maxdepth) from this
    /// target trajectory length and the current step size instead of
    /// using `mindepth`/`maxdepth` directly.
    pub target_integration_time: Option<f64>,
    /// Number of additional doublings (with turning checks disabled)
    /// performed after the termination criterion first triggers.
    pub extra_doublings: u64,
    /// Energy error bound handed to the leapfrog integrator for
    /// divergence detection.
    pub max_energy_error: f64,
}
266
267impl Default for NutsOptions {
268    fn default() -> Self {
269        NutsOptions {
270            maxdepth: 10,
271            mindepth: 0,
272            check_turning: true,
273            store_divergences: false,
274            target_integration_time: None,
275            extra_doublings: 0,
276            max_energy_error: 1000.0,
277        }
278    }
279}
280
/// Perform a single NUTS draw.
///
/// Refreshes the momentum of `init`, then repeatedly doubles the
/// trajectory tree in a random direction until it turns, diverges,
/// errors, or reaches the maximum depth. Returns the sampled state
/// together with [`SampleInfo`] describing the trajectory.
pub(crate) fn draw<M, H, R, C>(
    math: &mut M,
    init: &mut State<M, H::Point>,
    rng: &mut R,
    hamiltonian: &mut H,
    options: &NutsOptions,
    collector: &mut C,
) -> Result<(State<M, H::Point>, SampleInfo)>
where
    M: Math,
    H: Hamiltonian<M>,
    R: rand::Rng + ?Sized,
    C: Collector<M, H::Point>,
{
    // Fresh momentum for this draw; collectors get to see the start state.
    hamiltonian.initialize_trajectory(math, init, true, rng)?;
    collector.register_init(math, init, options);

    let mut tree = NutsTree::new(init.clone());

    // If a target integration time is given, choose depth bounds so that
    // 2^depth leapfrog steps bracket target_time / step_size.
    let (mindepth, maxdepth) = if let Some(target_time) = options.target_integration_time {
        let step_size = hamiltonian.step_size();
        let max_steps = (target_time / step_size).ceil() as u64;
        // NOTE(review): if `max_steps` is 0 (target_time <= 0 or
        // non-finite), log2 is -inf/NaN and `to_u64().unwrap()` panics
        // here — confirm callers guarantee a positive target time.
        let mindepth = (max_steps as f64)
            .log2()
            .floor()
            .to_u64()
            .unwrap()
            .max(options.mindepth);
        let maxdepth = (max_steps as f64)
            .log2()
            .ceil()
            .to_u64()
            .unwrap()
            .max(mindepth)
            .min(options.maxdepth);

        (mindepth, maxdepth)
    } else {
        (options.mindepth, options.maxdepth)
    };

    // Zero-dimensional posterior: nothing to integrate, return the
    // initial point unchanged.
    if math.dim() == 0 {
        let info = tree.info(false, None);
        collector.register_draw(math, init, &info);
        return Ok((init.clone(), info));
    }

    // Variant of the options with U-turn checks disabled, used below
    // `mindepth` and during extra doublings.
    let options_no_check = NutsOptions {
        check_turning: false,
        ..*options
    };

    while tree.depth < maxdepth {
        // Each doubling goes in a freshly sampled random direction.
        let direction: Direction = rng.random();
        let current_options = if tree.depth < mindepth {
            &options_no_check
        } else {
            options
        };
        tree = match tree.extend(
            math,
            rng,
            hamiltonian,
            direction,
            collector,
            current_options,
        ) {
            ExtendResult::Ok(tree) => tree,
            ExtendResult::Turning(mut tree) => {
                // Termination criterion reached: optionally keep doubling
                // (without turning checks) before emitting the draw.
                for _ in 0..options.extra_doublings {
                    tree = match tree.extend(
                        math,
                        rng,
                        hamiltonian,
                        direction,
                        collector,
                        &options_no_check,
                    ) {
                        ExtendResult::Ok(tree) => tree,
                        ExtendResult::Turning(tree) => tree,
                        ExtendResult::Diverging(tree, info) => {
                            let info = tree.info(false, Some(info));
                            collector.register_draw(math, &tree.draw, &info);
                            return Ok((tree.draw, info));
                        }
                        ExtendResult::Err(error) => {
                            return Err(error);
                        }
                    }
                }
                let info = tree.info(false, None);
                collector.register_draw(math, &tree.draw, &info);
                return Ok((tree.draw, info));
            }
            ExtendResult::Diverging(tree, info) => {
                // Divergence: report the draw from the tree built so far.
                let info = tree.info(false, Some(info));
                collector.register_draw(math, &tree.draw, &info);
                return Ok((tree.draw, info));
            }
            ExtendResult::Err(error) => {
                return Err(error);
            }
        };
    }
    // Loop exhausted: maximum depth reached without turning/divergence.
    let info = tree.info(true, None);
    collector.register_draw(math, &tree.draw, &info);
    Ok((tree.draw, info))
}
389
#[cfg(test)]
mod tests {
    use rand::rng;

    use crate::{
        Chain, Settings, math::test_logps::NormalLogp, math::CpuMath,
        sampler::DiagNutsSettings,
    };

    /// Smoke test: run a short chain on a standard-normal-style target
    /// and check that the last reported draw did not diverge.
    #[test]
    fn to_arrow() {
        let dim = 10;
        let logp = NormalLogp::new(dim, 3.);
        let math = CpuMath::new(logp);

        let settings = DiagNutsSettings::default();
        let mut generator = rng();
        let mut chain = settings.new_chain(0, math, &mut generator);

        chain.set_position(&vec![0.0; dim]).unwrap();

        // Eleven draws total; keep only the final progress report.
        let (_, mut progress) = chain.draw().unwrap();
        for _ in 0..10 {
            progress = chain.draw().unwrap().1;
        }

        assert!(!progress.diverging);
    }
}
420}