nuts_rs/
nuts.rs

1use thiserror::Error;
2
3use std::{fmt::Debug, marker::PhantomData};
4
5use crate::hamiltonian::{Direction, DivergenceInfo, Hamiltonian, LeapfrogResult, Point};
6use crate::math::logaddexp;
7use crate::state::State;
8
9use crate::math_base::Math;
10
11#[non_exhaustive]
12#[derive(Error, Debug)]
13pub enum NutsError {
14    #[error("Logp function returned error: {0:?}")]
15    LogpFailure(Box<dyn std::error::Error + Send + Sync>),
16
17    #[error("Could not serialize sample stats")]
18    SerializeFailure(),
19
20    #[error("Could not initialize state because of bad initial gradient: {0:?}")]
21    BadInitGrad(Box<dyn std::error::Error + Send + Sync>),
22}
23
24pub type Result<T> = std::result::Result<T, NutsError>;
25
26/// Callbacks for various events during a Nuts sampling step.
27///
28/// Collectors can compute statistics like the mean acceptance rate
29/// or collect data for mass matrix adaptation.
30pub trait Collector<M: Math, P: Point<M>> {
31    fn register_leapfrog(
32        &mut self,
33        _math: &mut M,
34        _start: &State<M, P>,
35        _end: &State<M, P>,
36        _divergence_info: Option<&DivergenceInfo>,
37    ) {
38    }
39    fn register_draw(&mut self, _math: &mut M, _state: &State<M, P>, _info: &SampleInfo) {}
40    fn register_init(&mut self, _math: &mut M, _state: &State<M, P>, _options: &NutsOptions) {}
41}
42
43/// Information about a draw, exported as part of the sampler stats
44#[derive(Debug)]
45pub struct SampleInfo {
46    /// The depth of the trajectory that this point was sampled from
47    pub depth: u64,
48
49    /// More detailed information about a divergence that might have
50    /// occured in the trajectory.
51    pub divergence_info: Option<DivergenceInfo>,
52
53    /// Whether the trajectory was terminated because it reached
54    /// the maximum tree depth.
55    pub reached_maxdepth: bool,
56
57    pub initial_energy: f64,
58    pub draw_energy: f64,
59}
60
61/// A part of the trajectory tree during NUTS sampling.
62struct NutsTree<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> {
63    /// The left position of the tree.
64    ///
65    /// The left side always has the smaller index_in_trajectory.
66    /// Leapfrogs in backward direction will replace the left.
67    left: State<M, H::Point>,
68    right: State<M, H::Point>,
69
70    /// A draw from the trajectory between left and right using
71    /// multinomial sampling.
72    draw: State<M, H::Point>,
73    log_size: f64,
74    depth: u64,
75
76    /// A tree is the main tree if it contains the initial point
77    /// of the trajectory.
78    is_main: bool,
79    _phantom2: PhantomData<C>,
80}
81
82enum ExtendResult<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> {
83    /// The tree extension succeeded properly, and the termination
84    /// criterion was not reached.
85    Ok(NutsTree<M, H, C>),
86    /// An unrecoverable error happend during a leapfrog step
87    Err(NutsError),
88    /// Tree extension succeeded and the termination criterion
89    /// was reached.
90    Turning(NutsTree<M, H, C>),
91    /// A divergence happend during tree extension.
92    Diverging(NutsTree<M, H, C>, DivergenceInfo),
93}
94
95impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
96    fn new(state: State<M, H::Point>) -> NutsTree<M, H, C> {
97        NutsTree {
98            right: state.clone(),
99            left: state.clone(),
100            draw: state,
101            depth: 0,
102            log_size: 0.,
103            is_main: true,
104            _phantom2: PhantomData,
105        }
106    }
107
108    #[allow(clippy::too_many_arguments)]
109    #[inline]
110    fn extend<R>(
111        mut self,
112        math: &mut M,
113        rng: &mut R,
114        hamiltonian: &mut H,
115        direction: Direction,
116        collector: &mut C,
117        options: &NutsOptions,
118    ) -> ExtendResult<M, H, C>
119    where
120        H: Hamiltonian<M>,
121        R: rand::Rng + ?Sized,
122    {
123        let mut other = match self.single_step(math, hamiltonian, direction, collector) {
124            Ok(Ok(tree)) => tree,
125            Ok(Err(info)) => return ExtendResult::Diverging(self, info),
126            Err(err) => return ExtendResult::Err(err),
127        };
128
129        while other.depth < self.depth {
130            use ExtendResult::*;
131            other = match other.extend(math, rng, hamiltonian, direction, collector, options) {
132                Ok(tree) => tree,
133                Turning(_) => {
134                    return Turning(self);
135                }
136                Diverging(_, info) => {
137                    return Diverging(self, info);
138                }
139                Err(error) => {
140                    return Err(error);
141                }
142            };
143        }
144
145        let (first, last) = match direction {
146            Direction::Forward => (&self.left, &other.right),
147            Direction::Backward => (&other.left, &self.right),
148        };
149
150        let turning = if options.check_turning {
151            let mut turning = hamiltonian.is_turning(math, first, last);
152            if self.depth > 0 {
153                if !turning {
154                    turning = hamiltonian.is_turning(math, &self.right, &other.right);
155                }
156                if !turning {
157                    turning = hamiltonian.is_turning(math, &self.left, &other.left);
158                }
159            }
160            turning
161        } else {
162            false
163        };
164
165        self.merge_into(math, other, rng, direction);
166
167        if turning {
168            ExtendResult::Turning(self)
169        } else {
170            ExtendResult::Ok(self)
171        }
172    }
173
174    fn merge_into<R: rand::Rng + ?Sized>(
175        &mut self,
176        _math: &mut M,
177        other: NutsTree<M, H, C>,
178        rng: &mut R,
179        direction: Direction,
180    ) {
181        assert!(self.depth == other.depth);
182        assert!(self.left.index_in_trajectory() <= self.right.index_in_trajectory());
183        match direction {
184            Direction::Forward => {
185                self.right = other.right;
186            }
187            Direction::Backward => {
188                self.left = other.left;
189            }
190        }
191        let log_size = logaddexp(self.log_size, other.log_size);
192
193        let self_log_size = if self.is_main {
194            assert!(self.left.index_in_trajectory() <= 0);
195            assert!(self.right.index_in_trajectory() >= 0);
196            self.log_size
197        } else {
198            log_size
199        };
200
201        if (other.log_size >= self_log_size)
202            || (rng.random_bool((other.log_size - self_log_size).exp()))
203        {
204            self.draw = other.draw;
205        }
206
207        self.depth += 1;
208        self.log_size = log_size;
209    }
210
211    fn single_step(
212        &self,
213        math: &mut M,
214        hamiltonian: &mut H,
215        direction: Direction,
216        collector: &mut C,
217    ) -> Result<std::result::Result<NutsTree<M, H, C>, DivergenceInfo>> {
218        let start = match direction {
219            Direction::Forward => &self.right,
220            Direction::Backward => &self.left,
221        };
222        let end = match hamiltonian.leapfrog(math, start, direction, collector) {
223            LeapfrogResult::Divergence(info) => return Ok(Err(info)),
224            LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())),
225            LeapfrogResult::Ok(end) => end,
226        };
227
228        let log_size = -end.point().energy_error();
229        Ok(Ok(NutsTree {
230            right: end.clone(),
231            left: end.clone(),
232            draw: end,
233            depth: 0,
234            log_size,
235            is_main: false,
236            _phantom2: PhantomData,
237        }))
238    }
239
240    fn info(&self, maxdepth: bool, divergence_info: Option<DivergenceInfo>) -> SampleInfo {
241        SampleInfo {
242            depth: self.depth,
243            divergence_info,
244            reached_maxdepth: maxdepth,
245            initial_energy: self.draw.point().initial_energy(),
246            draw_energy: self.draw.energy(),
247        }
248    }
249}
250
/// Options that control how a single NUTS draw is built.
pub struct NutsOptions {
    /// Maximum tree depth. `draw` stops doubling once the tree reaches it and
    /// sets `reached_maxdepth` in the returned [`SampleInfo`].
    pub maxdepth: u64,
    /// NOTE(review): not read in this file; presumably controls whether
    /// gradients are stored in the sampler stats — confirm.
    pub store_gradient: bool,
    /// NOTE(review): not read in this file; presumably controls whether
    /// unconstrained positions are stored in the sampler stats — confirm.
    pub store_unconstrained: bool,
    /// Whether to apply the U-turn criterion while building the tree. When
    /// `false`, a trajectory only ends on a divergence, an error, or `maxdepth`.
    pub check_turning: bool,
    /// NOTE(review): not read in this file; presumably controls whether
    /// divergence details are stored in the sampler stats — confirm.
    pub store_divergences: bool,
}
258
259pub(crate) fn draw<M, H, R, C>(
260    math: &mut M,
261    init: &mut State<M, H::Point>,
262    rng: &mut R,
263    hamiltonian: &mut H,
264    options: &NutsOptions,
265    collector: &mut C,
266) -> Result<(State<M, H::Point>, SampleInfo)>
267where
268    M: Math,
269    H: Hamiltonian<M>,
270    R: rand::Rng + ?Sized,
271    C: Collector<M, H::Point>,
272{
273    hamiltonian.initialize_trajectory(math, init, rng)?;
274    collector.register_init(math, init, options);
275
276    let mut tree = NutsTree::new(init.clone());
277
278    if math.dim() == 0 {
279        let info = tree.info(false, None);
280        collector.register_draw(math, init, &info);
281        return Ok((init.clone(), info));
282    }
283
284    while tree.depth < options.maxdepth {
285        let direction: Direction = rng.random();
286        tree = match tree.extend(math, rng, hamiltonian, direction, collector, options) {
287            ExtendResult::Ok(tree) => tree,
288            ExtendResult::Turning(tree) => {
289                let info = tree.info(false, None);
290                collector.register_draw(math, &tree.draw, &info);
291                return Ok((tree.draw, info));
292            }
293            ExtendResult::Diverging(tree, info) => {
294                let info = tree.info(false, Some(info));
295                collector.register_draw(math, &tree.draw, &info);
296                return Ok((tree.draw, info));
297            }
298            ExtendResult::Err(error) => {
299                return Err(error);
300            }
301        };
302    }
303    let info = tree.info(true, None);
304    collector.register_draw(math, &tree.draw, &info);
305    Ok((tree.draw, info))
306}
307
#[cfg(test)]
mod tests {
    use rand::{rng, rngs::ThreadRng};

    use crate::{
        adapt_strategy::test_logps::NormalLogp,
        chain::NutsChain,
        cpu_math::CpuMath,
        sampler::DiagGradNutsSettings,
        sampler_stats::{SamplerStats, StatTraceBuilder},
        Chain, Settings,
    };

    /// Draw a few samples from a 10-dimensional normal, record their stats in
    /// a trace builder, and check that the builder can be finalized.
    #[test]
    fn to_arrow() {
        let ndim = 10;
        let math = CpuMath::new(NormalLogp::new(ndim, 3.));
        let settings = DiagGradNutsSettings::default();
        let mut rng = rng();
        let mut chain = settings.new_chain(0, math, &mut rng);

        let stats_options = settings.stats_options(&chain);
        let mut builder = chain.new_builder(stats_options, &settings, ndim);

        // The first draw is taken but not appended to the builder.
        let (_, mut progress) = chain.draw().unwrap();
        for _ in 0..10 {
            progress = chain.draw().unwrap().1;
            builder.append_value(None, &chain);
        }

        // Only the most recent draw's progress is inspected.
        assert!(!progress.diverging);
        StatTraceBuilder::<_, NutsChain<_, ThreadRng, _>>::finalize(builder);
    }
}