// nuts_rs/nuts.rs

use thiserror::Error;

use std::{fmt::Debug, marker::PhantomData};

use crate::hamiltonian::{Direction, DivergenceInfo, Hamiltonian, LeapfrogResult, Point};
use crate::math::logaddexp;
use crate::state::State;

use crate::math_base::Math;

#[non_exhaustive]
#[derive(Error, Debug)]
pub enum NutsError {
    #[error("Logp function returned error: {0:?}")]
    LogpFailure(Box<dyn std::error::Error + Send + Sync>),

    #[error("Could not serialize sample stats")]
    SerializeFailure(),

    #[error("Could not initialize state because of bad initial gradient: {0:?}")]
    BadInitGrad(Box<dyn std::error::Error + Send + Sync>),
}

pub type Result<T> = std::result::Result<T, NutsError>;

/// Callbacks for various events during a Nuts sampling step.
///
/// Collectors can compute statistics like the mean acceptance rate
/// or collect data for mass matrix adaptation.
pub trait Collector<M: Math, P: Point<M>> {
    fn register_leapfrog(
        &mut self,
        _math: &mut M,
        _start: &State<M, P>,
        _end: &State<M, P>,
        _divergence_info: Option<&DivergenceInfo>,
    ) {
    }
    fn register_draw(&mut self, _math: &mut M, _state: &State<M, P>, _info: &SampleInfo) {}
    fn register_init(&mut self, _math: &mut M, _state: &State<M, P>, _options: &NutsOptions) {}
}
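
// Illustrative sketch (not part of the original implementation): a minimal
// collector that only counts leapfrog steps and divergences, relying on the
// default no-op implementations of the other callbacks. The name
// `CountingCollector` is hypothetical.
#[allow(dead_code)]
struct CountingCollector {
    n_leapfrog: u64,
    n_divergent: u64,
}

impl<M: Math, P: Point<M>> Collector<M, P> for CountingCollector {
    fn register_leapfrog(
        &mut self,
        _math: &mut M,
        _start: &State<M, P>,
        _end: &State<M, P>,
        divergence_info: Option<&DivergenceInfo>,
    ) {
        // Called once per leapfrog step; divergent steps carry a `DivergenceInfo`.
        self.n_leapfrog += 1;
        if divergence_info.is_some() {
            self.n_divergent += 1;
        }
    }
}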

/// Information about a draw, exported as part of the sampler stats
#[derive(Debug)]
pub struct SampleInfo {
    /// The depth of the trajectory that this point was sampled from
    pub depth: u64,

    /// More detailed information about a divergence that might have
    /// occurred in the trajectory.
    pub divergence_info: Option<DivergenceInfo>,

    /// Whether the trajectory was terminated because it reached
    /// the maximum tree depth.
    pub reached_maxdepth: bool,

    pub initial_energy: f64,
    pub draw_energy: f64,
}

/// A part of the trajectory tree during NUTS sampling.
struct NutsTree<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> {
    /// The left position of the tree.
    ///
    /// The left side always has the smaller index_in_trajectory.
    /// Leapfrogs in the backward direction will replace the left.
    left: State<M, H::Point>,
    right: State<M, H::Point>,

    /// A draw from the trajectory between left and right using
    /// multinomial sampling.
    draw: State<M, H::Point>,
    log_size: f64,
    depth: u64,

    /// A tree is the main tree if it contains the initial point
    /// of the trajectory.
    is_main: bool,
    _phantom2: PhantomData<C>,
}

enum ExtendResult<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> {
    /// The tree extension succeeded, and the termination
    /// criterion was not reached.
    Ok(NutsTree<M, H, C>),
    /// An unrecoverable error happened during a leapfrog step.
    Err(NutsError),
    /// Tree extension succeeded and the termination criterion
    /// was reached.
    Turning(NutsTree<M, H, C>),
    /// A divergence happened during tree extension.
    Diverging(NutsTree<M, H, C>, DivergenceInfo),
}

impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
    fn new(state: State<M, H::Point>) -> NutsTree<M, H, C> {
        NutsTree {
            right: state.clone(),
            left: state.clone(),
            draw: state,
            depth: 0,
            log_size: 0.,
            is_main: true,
            _phantom2: PhantomData,
        }
    }

    /// Double the tree by adding a subtree of the same depth in the given
    /// direction, then check the termination (U-turn) criterion.
    #[allow(clippy::too_many_arguments)]
    #[inline]
    fn extend<R>(
        mut self,
        math: &mut M,
        rng: &mut R,
        hamiltonian: &mut H,
        direction: Direction,
        collector: &mut C,
        options: &NutsOptions,
    ) -> ExtendResult<M, H, C>
    where
        H: Hamiltonian<M>,
        R: rand::Rng + ?Sized,
    {
        let mut other = match self.single_step(math, hamiltonian, direction, collector) {
            Ok(Ok(tree)) => tree,
            Ok(Err(info)) => return ExtendResult::Diverging(self, info),
            Err(err) => return ExtendResult::Err(err),
        };

        while other.depth < self.depth {
            use ExtendResult::*;
            other = match other.extend(math, rng, hamiltonian, direction, collector, options) {
                Ok(tree) => tree,
                Turning(_) => {
                    return Turning(self);
                }
                Diverging(_, info) => {
                    return Diverging(self, info);
                }
                Err(error) => {
                    return Err(error);
                }
            };
        }

        let (first, last) = match direction {
            Direction::Forward => (&self.left, &other.right),
            Direction::Backward => (&other.left, &self.right),
        };

        let turning = if options.check_turning {
            let mut turning = hamiltonian.is_turning(math, first, last);
            if self.depth > 0 {
                if !turning {
                    turning = hamiltonian.is_turning(math, &self.right, &other.right);
                }
                if !turning {
                    turning = hamiltonian.is_turning(math, &self.left, &other.left);
                }
            }
            turning
        } else {
            false
        };

        self.merge_into(math, other, rng, direction);

        if turning {
            ExtendResult::Turning(self)
        } else {
            ExtendResult::Ok(self)
        }
    }

    fn merge_into<R: rand::Rng + ?Sized>(
        &mut self,
        _math: &mut M,
        other: NutsTree<M, H, C>,
        rng: &mut R,
        direction: Direction,
    ) {
        assert!(self.depth == other.depth);
        assert!(self.left.index_in_trajectory() <= self.right.index_in_trajectory());
        match direction {
            Direction::Forward => {
                self.right = other.right;
            }
            Direction::Backward => {
                self.left = other.left;
            }
        }
        let log_size = logaddexp(self.log_size, other.log_size);

        let self_log_size = if self.is_main {
            assert!(self.left.index_in_trajectory() <= 0);
            assert!(self.right.index_in_trajectory() >= 0);
            self.log_size
        } else {
            log_size
        };

        // Multinomial sampling: the draw from the other subtree is accepted
        // with probability min(1, exp(other.log_size - self_log_size)).
        if (other.log_size >= self_log_size)
            || (rng.random_bool((other.log_size - self_log_size).exp()))
        {
            self.draw = other.draw;
        }

        self.depth += 1;
        self.log_size = log_size;
    }

    fn single_step(
        &self,
        math: &mut M,
        hamiltonian: &mut H,
        direction: Direction,
        collector: &mut C,
    ) -> Result<std::result::Result<NutsTree<M, H, C>, DivergenceInfo>> {
        let start = match direction {
            Direction::Forward => &self.right,
            Direction::Backward => &self.left,
        };
        let end = match hamiltonian.leapfrog(math, start, direction, collector) {
            LeapfrogResult::Divergence(info) => return Ok(Err(info)),
            LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())),
            LeapfrogResult::Ok(end) => end,
        };

        let log_size = -end.point().energy_error();
        Ok(Ok(NutsTree {
            right: end.clone(),
            left: end.clone(),
            draw: end,
            depth: 0,
            log_size,
            is_main: false,
            _phantom2: PhantomData,
        }))
    }

    fn info(&self, maxdepth: bool, divergence_info: Option<DivergenceInfo>) -> SampleInfo {
        SampleInfo {
            depth: self.depth,
            divergence_info,
            reached_maxdepth: maxdepth,
            initial_energy: self.draw.point().initial_energy(),
            draw_energy: self.draw.energy(),
        }
    }
}

pub struct NutsOptions {
    pub maxdepth: u64,
    pub mindepth: u64,
    pub store_gradient: bool,
    pub store_unconstrained: bool,
    pub check_turning: bool,
    pub store_divergences: bool,
}

impl Default for NutsOptions {
    fn default() -> Self {
        NutsOptions {
            maxdepth: 10,
            mindepth: 0,
            store_gradient: false,
            store_unconstrained: false,
            check_turning: true,
            store_divergences: false,
        }
    }
}
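
// Illustrative sketch (not part of the original file): overriding a single
// field while keeping the remaining defaults via struct update syntax. The
// function name and the chosen depth are hypothetical.
#[allow(dead_code)]
fn example_deep_options() -> NutsOptions {
    NutsOptions {
        maxdepth: 12,
        ..NutsOptions::default()
    }
}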

pub(crate) fn draw<M, H, R, C>(
    math: &mut M,
    init: &mut State<M, H::Point>,
    rng: &mut R,
    hamiltonian: &mut H,
    options: &NutsOptions,
    collector: &mut C,
) -> Result<(State<M, H::Point>, SampleInfo)>
where
    M: Math,
    H: Hamiltonian<M>,
    R: rand::Rng + ?Sized,
    C: Collector<M, H::Point>,
{
    hamiltonian.initialize_trajectory(math, init, rng)?;
    collector.register_init(math, init, options);

    let mut tree = NutsTree::new(init.clone());

    if math.dim() == 0 {
        let info = tree.info(false, None);
        collector.register_draw(math, init, &info);
        return Ok((init.clone(), info));
    }

    let options_no_check = NutsOptions {
        check_turning: false,
        ..*options
    };

    // Keep doubling the trajectory in a random direction until it turns,
    // diverges, or reaches the maximum tree depth.
    while tree.depth < options.maxdepth {
        let direction: Direction = rng.random();
        let current_options = if tree.depth < options.mindepth {
            &options_no_check
        } else {
            options
        };
        tree = match tree.extend(
            math,
            rng,
            hamiltonian,
            direction,
            collector,
            current_options,
        ) {
            ExtendResult::Ok(tree) => tree,
            ExtendResult::Turning(tree) => {
                let info = tree.info(false, None);
                collector.register_draw(math, &tree.draw, &info);
                return Ok((tree.draw, info));
            }
            ExtendResult::Diverging(tree, info) => {
                let info = tree.info(false, Some(info));
                collector.register_draw(math, &tree.draw, &info);
                return Ok((tree.draw, info));
            }
            ExtendResult::Err(error) => {
                return Err(error);
            }
        };
    }
    let info = tree.info(true, None);
    collector.register_draw(math, &tree.draw, &info);
    Ok((tree.draw, info))
}

#[cfg(test)]
mod tests {
    use rand::rng;

    use crate::{
        Chain, Settings, adapt_strategy::test_logps::NormalLogp, cpu_math::CpuMath,
        sampler::DiagGradNutsSettings,
    };

    #[test]
    fn to_arrow() {
        let ndim = 10;
        let func = NormalLogp::new(ndim, 3.);
        let math = CpuMath::new(func);

        let settings = DiagGradNutsSettings::default();
        let mut rng = rng();

        let mut chain = settings.new_chain(0, math, &mut rng);

        let (_, mut progress) = chain.draw().unwrap();
        for _ in 0..10 {
            let (_, prog) = chain.draw().unwrap();
            progress = prog;
        }

        assert!(!progress.diverging);
    }
}