//! Tree building for the No-U-Turn sampler (NUTS).
use std::{fmt::Debug, marker::PhantomData};

use rand::RngExt;
use thiserror::Error;

use crate::hamiltonian::{Direction, DivergenceInfo, Hamiltonian, LeapfrogResult, Point};
use crate::math::logaddexp;
use crate::math_base::Math;
use crate::state::State;

/// Errors that can occur during NUTS sampling.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum NutsError {
    /// The user-supplied log-density function returned an error.
    #[error("Logp function returned error: {0:?}")]
    LogpFailure(Box<dyn std::error::Error + Send + Sync>),

    /// Sampler statistics could not be serialized.
    #[error("Could not serialize sample stats")]
    SerializeFailure(),

    /// The gradient at the initial position was unusable, so the
    /// chain could not be initialized.
    #[error("Could not initialize state because of bad initial gradient: {0:?}")]
    BadInitGrad(Box<dyn std::error::Error + Send + Sync>),
}
24
/// Shorthand for results whose error type is [`NutsError`].
pub type Result<T> = std::result::Result<T, NutsError>;
26
/// Callbacks for various events during a Nuts sampling step.
///
/// Collectors can compute statistics like the mean acceptance rate
/// or collect data for mass matrix adaptation.
pub trait Collector<M: Math, P: Point<M>> {
    /// Called after each leapfrog step with its start and end states,
    /// and divergence details if the step diverged.
    fn register_leapfrog(
        &mut self,
        _math: &mut M,
        _start: &State<M, P>,
        _end: &State<M, P>,
        _divergence_info: Option<&DivergenceInfo>,
    ) {
    }
    /// Called once per draw with the selected state and trajectory info.
    fn register_draw(&mut self, _math: &mut M, _state: &State<M, P>, _info: &SampleInfo) {}
    /// Called after the trajectory's initial state has been set up.
    fn register_init(&mut self, _math: &mut M, _state: &State<M, P>, _options: &NutsOptions) {}
}
43
/// Information about a draw, exported as part of the sampler stats
#[derive(Debug)]
pub struct SampleInfo {
    /// The depth of the trajectory that this point was sampled from
    pub depth: u64,

    /// More detailed information about a divergence that might have
    /// occurred in the trajectory.
    pub divergence_info: Option<DivergenceInfo>,

    /// Whether the trajectory was terminated because it reached
    /// the maximum tree depth.
    pub reached_maxdepth: bool,

    /// Energy of the trajectory's initial point.
    pub initial_energy: f64,
    /// Energy of the accepted (drawn) point.
    pub draw_energy: f64,
}
61
/// A part of the trajectory tree during NUTS sampling.
struct NutsTree<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> {
    /// The left position of the tree.
    ///
    /// The left side always has the smaller index_in_trajectory.
    /// Leapfrogs in backward direction will replace the left.
    left: State<M, H::Point>,
    /// The right position of the tree (largest index_in_trajectory);
    /// forward leapfrogs replace it.
    right: State<M, H::Point>,

    /// A draw from the trajectory between left and right using
    /// multinomial sampling.
    draw: State<M, H::Point>,
    /// Log of the total multinomial weight of all points in this subtree.
    log_size: f64,
    /// Number of doublings; a tree of depth d covers 2^d states.
    depth: u64,

    /// A tree is the main tree if it contains the initial point
    /// of the trajectory.
    is_main: bool,
    /// Ties the otherwise-unused collector type parameter to the struct.
    _phantom2: PhantomData<C>,
}
82
/// Outcome of extending a [`NutsTree`] by one doubling.
enum ExtendResult<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> {
    /// The tree extension succeeded properly, and the termination
    /// criterion was not reached.
    Ok(NutsTree<M, H, C>),
    /// An unrecoverable error happened during a leapfrog step
    Err(NutsError),
    /// Tree extension succeeded and the termination criterion
    /// was reached.
    Turning(NutsTree<M, H, C>),
    /// A divergence happened during tree extension.
    Diverging(NutsTree<M, H, C>, DivergenceInfo),
}
95
impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
    /// Build a depth-0 tree whose left, right and draw are all the
    /// given initial state.
    fn new(state: State<M, H::Point>) -> NutsTree<M, H, C> {
        NutsTree {
            right: state.clone(),
            left: state.clone(),
            draw: state,
            depth: 0,
            // log(1) = 0: the initial point carries unit weight.
            log_size: 0.,
            is_main: true,
            _phantom2: PhantomData,
        }
    }

    /// Double the tree once in `direction`.
    ///
    /// Builds a second subtree of equal depth adjacent to `self` and
    /// merges it in. Returns `Turning` if the termination criterion
    /// fired, `Diverging` if a leapfrog step diverged, and `Err` if the
    /// logp function failed.
    #[allow(clippy::too_many_arguments)]
    #[inline]
    fn extend<R>(
        mut self,
        math: &mut M,
        rng: &mut R,
        hamiltonian: &mut H,
        direction: Direction,
        collector: &mut C,
        options: &NutsOptions,
    ) -> ExtendResult<M, H, C>
    where
        H: Hamiltonian<M>, // NOTE(review): redundant bound — already on the impl
        R: rand::Rng + ?Sized,
    {
        // Seed the new subtree with a single leapfrog step.
        let mut other = match self.single_step(math, hamiltonian, direction, collector) {
            Ok(Ok(tree)) => tree,
            Ok(Err(info)) => return ExtendResult::Diverging(self, info),
            Err(err) => return ExtendResult::Err(err),
        };

        // Grow the new subtree until it matches our depth. If it turns
        // or diverges on the way, the partial subtree is discarded and
        // only `self` is kept.
        while other.depth < self.depth {
            use ExtendResult::*;
            other = match other.extend(math, rng, hamiltonian, direction, collector, options) {
                Ok(tree) => tree,
                Turning(_) => {
                    return Turning(self);
                }
                Diverging(_, info) => {
                    return Diverging(self, info);
                }
                Err(error) => {
                    return Err(error);
                }
            };
        }

        // Endpoints of the combined trajectory, ordered left-to-right.
        let (first, last) = match direction {
            Direction::Forward => (&self.left, &other.right),
            Direction::Backward => (&other.left, &self.right),
        };

        let turning = if options.check_turning {
            // U-turn check over the full combined tree; for deeper trees
            // additionally check the two cross-subtree endpoint pairs.
            let mut turning = hamiltonian.is_turning(math, first, last);
            if self.depth > 0 {
                if !turning {
                    turning = hamiltonian.is_turning(math, &self.right, &other.right);
                }
                if !turning {
                    turning = hamiltonian.is_turning(math, &self.left, &other.left);
                }
            }
            turning
        } else {
            false
        };

        // Merge even when turning, so the returned tree reflects the
        // full trajectory that was built.
        self.merge_into(math, other, rng, direction);

        if turning {
            ExtendResult::Turning(self)
        } else {
            ExtendResult::Ok(self)
        }
    }

    /// Absorb `other` — a subtree of equal depth adjacent in
    /// `direction` — into `self`: update the endpoint, resample the
    /// draw multinomially, and increment the depth.
    fn merge_into<R: rand::Rng + ?Sized>(
        &mut self,
        _math: &mut M,
        other: NutsTree<M, H, C>,
        rng: &mut R,
        direction: Direction,
    ) {
        assert!(self.depth == other.depth);
        assert!(self.left.index_in_trajectory() <= self.right.index_in_trajectory());
        match direction {
            Direction::Forward => {
                self.right = other.right;
            }
            Direction::Backward => {
                self.left = other.left;
            }
        }
        // Combined multinomial weight of both subtrees.
        let log_size = logaddexp(self.log_size, other.log_size);

        let self_log_size = if self.is_main {
            // The main tree compares against its own (pre-merge) weight,
            // biasing the proposal away from the initial point
            // (cf. biased progressive sampling).
            assert!(self.left.index_in_trajectory() <= 0);
            assert!(self.right.index_in_trajectory() >= 0);
            self.log_size
        } else {
            log_size
        };

        // Take the other subtree's draw with probability
        // min(1, exp(other.log_size - self_log_size)); the short-circuit
        // guarantees the argument to random_bool is <= 1.
        if (other.log_size >= self_log_size)
            || (rng.random_bool((other.log_size - self_log_size).exp()))
        {
            self.draw = other.draw;
        }

        self.depth += 1;
        self.log_size = log_size;
    }

    /// Take one leapfrog step from the outermost state in `direction`.
    ///
    /// Returns `Ok(Ok(tree))` with a fresh depth-0 tree on success,
    /// `Ok(Err(info))` if the step diverged, and `Err` if the logp
    /// function itself failed.
    fn single_step(
        &self,
        math: &mut M,
        hamiltonian: &mut H,
        direction: Direction,
        collector: &mut C,
    ) -> Result<std::result::Result<NutsTree<M, H, C>, DivergenceInfo>> {
        let start = match direction {
            Direction::Forward => &self.right,
            Direction::Backward => &self.left,
        };
        let end = match hamiltonian.leapfrog(math, start, direction, collector) {
            LeapfrogResult::Divergence(info) => return Ok(Err(info)),
            LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())),
            LeapfrogResult::Ok(end) => end,
        };

        // Multinomial weight of the new point: exp(-energy_error).
        let log_size = -end.point().energy_error();
        Ok(Ok(NutsTree {
            right: end.clone(),
            left: end.clone(),
            draw: end,
            depth: 0,
            log_size,
            is_main: false,
            _phantom2: PhantomData,
        }))
    }

    /// Summarize this tree as the [`SampleInfo`] exported with the draw.
    fn info(&self, maxdepth: bool, divergence_info: Option<DivergenceInfo>) -> SampleInfo {
        SampleInfo {
            depth: self.depth,
            divergence_info,
            reached_maxdepth: maxdepth,
            initial_energy: self.draw.point().initial_energy(),
            draw_energy: self.draw.energy(),
        }
    }
}
251
/// Tuning options controlling a single NUTS transition.
pub struct NutsOptions {
    /// Maximum number of tree doublings per draw.
    pub maxdepth: u64,
    /// Depth below which the U-turn check is skipped, so the tree
    /// always reaches at least this depth (barring divergences).
    pub mindepth: u64,
    /// Store the gradient with the sampler stats (consumed elsewhere;
    /// not read in this module).
    pub store_gradient: bool,
    /// Store the unconstrained draw with the sampler stats (consumed
    /// elsewhere; not read in this module).
    pub store_unconstrained: bool,
    /// Whether to apply the U-turn termination criterion.
    pub check_turning: bool,
    /// Store divergence details with the sampler stats (consumed
    /// elsewhere; not read in this module).
    pub store_divergences: bool,
}
260
261impl Default for NutsOptions {
262    fn default() -> Self {
263        NutsOptions {
264            maxdepth: 10,
265            mindepth: 0,
266            store_gradient: false,
267            store_unconstrained: false,
268            check_turning: true,
269            store_divergences: false,
270        }
271    }
272}
273
/// Draw one sample using the No-U-Turn criterion.
///
/// Initializes the trajectory from `init`, then repeatedly doubles the
/// tree in a randomly chosen direction until it turns, diverges,
/// errors, or reaches `options.maxdepth`. Returns the selected state
/// together with information about the trajectory.
pub(crate) fn draw<M, H, R, C>(
    math: &mut M,
    init: &mut State<M, H::Point>,
    rng: &mut R,
    hamiltonian: &mut H,
    options: &NutsOptions,
    collector: &mut C,
) -> Result<(State<M, H::Point>, SampleInfo)>
where
    M: Math,
    H: Hamiltonian<M>,
    R: rand::Rng + ?Sized,
    C: Collector<M, H::Point>,
{
    hamiltonian.initialize_trajectory(math, init, rng)?;
    collector.register_init(math, init, options);

    let mut tree = NutsTree::new(init.clone());

    // Degenerate zero-dimensional problem: nothing to integrate,
    // return the initial state unchanged.
    if math.dim() == 0 {
        let info = tree.info(false, None);
        collector.register_draw(math, init, &info);
        return Ok((init.clone(), info));
    }

    // Options variant with the U-turn check disabled, used while the
    // tree is below `mindepth` so it cannot terminate early by turning.
    let options_no_check = NutsOptions {
        check_turning: false,
        ..*options
    };

    while tree.depth < options.maxdepth {
        // Random doubling direction (Direction's sampling impl lives
        // elsewhere in the crate).
        let direction: Direction = rng.random();
        let current_options = if tree.depth < options.mindepth {
            &options_no_check
        } else {
            options
        };
        tree = match tree.extend(
            math,
            rng,
            hamiltonian,
            direction,
            collector,
            current_options,
        ) {
            ExtendResult::Ok(tree) => tree,
            ExtendResult::Turning(tree) => {
                let info = tree.info(false, None);
                collector.register_draw(math, &tree.draw, &info);
                return Ok((tree.draw, info));
            }
            ExtendResult::Diverging(tree, info) => {
                let info = tree.info(false, Some(info));
                collector.register_draw(math, &tree.draw, &info);
                return Ok((tree.draw, info));
            }
            ExtendResult::Err(error) => {
                return Err(error);
            }
        };
    }
    // Maximum depth reached without turning or diverging.
    let info = tree.info(true, None);
    collector.register_draw(math, &tree.draw, &info);
    Ok((tree.draw, info))
}
339
#[cfg(test)]
mod tests {
    use rand::rng;

    use crate::{
        Chain, Settings, adapt_strategy::test_logps::NormalLogp, cpu_math::CpuMath,
        sampler::DiagGradNutsSettings,
    };

    /// Smoke test: run a short chain on a normal target and check that
    /// the last draw did not diverge. (Test name is historical.)
    #[test]
    fn to_arrow() {
        let dim = 10;
        let logp = NormalLogp::new(dim, 3.);
        let math = CpuMath::new(logp);

        let settings = DiagGradNutsSettings::default();
        let mut rng = rng();
        let mut chain = settings.new_chain(0, math, &mut rng);

        // Take 11 draws in total, keeping only the final progress report.
        let mut progress = chain.draw().unwrap().1;
        for _ in 0..10 {
            progress = chain.draw().unwrap().1;
        }

        assert!(!progress.diverging);
    }
}