Skip to main content

nuts_rs/
chain.rs

1//! Drive a single sampling chain by wiring together the Hamiltonian, adaptation, and per-draw bookkeeping.
2
3use std::{
4    cell::{Ref, RefCell},
5    fmt::Debug,
6    marker::PhantomData,
7    ops::DerefMut,
8};
9
10use nuts_storable::{HasDims, Storable};
11use rand::Rng;
12
13use crate::{
14    Math, NutsError,
15    dynamics::{DivergenceStats, Hamiltonian, Point, State},
16    nuts::{Collector, NutsOptions, SampleInfo, draw},
17    sampler::Progress,
18    sampler_stats::{SamplerStats, StatsDims},
19};
20
21use anyhow::Result;
22
23/// Draw samples from the posterior distribution using Hamiltonian MCMC.
24pub trait Chain<M: Math>: SamplerStats<M> {
25    type AdaptStrategy: AdaptStrategy<M>;
26
27    /// Initialize the sampler to a position. This should be called
28    /// before calling draw.
29    ///
30    /// This fails if the logp function returns an error.
31    fn set_position(&mut self, position: &[f64]) -> Result<()>;
32
33    /// Draw a new sample and return the position and some diagnosic information.
34    fn draw(&mut self) -> Result<(Box<[f64]>, Progress)>;
35
36    /// The dimensionality of the posterior.
37    fn dim(&self) -> usize;
38
39    fn expanded_draw(&mut self) -> Result<(Box<[f64]>, M::ExpandedVector, Self::Stats, Progress)>;
40
41    fn math(&self) -> Ref<'_, M>;
42}
43
44pub struct NutsChain<M, R, A>
45where
46    M: Math,
47    R: rand::Rng,
48    A: AdaptStrategy<M>,
49{
50    hamiltonian: A::Hamiltonian,
51    collector: A::Collector,
52    options: NutsOptions,
53    rng: R,
54    state: State<M, <A::Hamiltonian as Hamiltonian<M>>::Point>,
55    last_info: Option<SampleInfo>,
56    chain: u64,
57    draw_count: u64,
58    strategy: A,
59    math: RefCell<M>,
60    stats_options: StatOptions<M, A>,
61}
62
63impl<M, R, A> NutsChain<M, R, A>
64where
65    M: Math,
66    R: rand::Rng,
67    A: AdaptStrategy<M>,
68{
69    pub fn new(
70        mut math: M,
71        mut hamiltonian: A::Hamiltonian,
72        strategy: A,
73        options: NutsOptions,
74        rng: R,
75        chain: u64,
76        stats_options: StatOptions<M, A>,
77    ) -> Self {
78        let init = hamiltonian.pool().new_state(&mut math);
79        let collector = strategy.new_collector(&mut math);
80        NutsChain {
81            hamiltonian,
82            collector,
83            options,
84            rng,
85            state: init,
86            last_info: None,
87            chain,
88            draw_count: 0,
89            strategy,
90            math: math.into(),
91            stats_options,
92        }
93    }
94}
95
96pub trait AdaptStrategy<M: Math>: SamplerStats<M> {
97    type Hamiltonian: Hamiltonian<M>;
98    type Collector: Collector<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>;
99    type Options: Copy + Send + Debug + Default;
100
101    fn new(math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self;
102
103    fn init<R: Rng + ?Sized>(
104        &mut self,
105        math: &mut M,
106        options: &mut NutsOptions,
107        hamiltonian: &mut Self::Hamiltonian,
108        position: &[f64],
109        rng: &mut R,
110    ) -> Result<(), NutsError>;
111
112    #[allow(clippy::too_many_arguments)]
113    fn adapt<R: Rng + ?Sized>(
114        &mut self,
115        math: &mut M,
116        options: &mut NutsOptions,
117        hamiltonian: &mut Self::Hamiltonian,
118        draw: u64,
119        collector: &Self::Collector,
120        state: &State<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>,
121        rng: &mut R,
122    ) -> Result<(), NutsError>;
123
124    fn new_collector(&self, math: &mut M) -> Self::Collector;
125    fn is_tuning(&self) -> bool;
126    fn last_num_steps(&self) -> u64;
127}
128
129impl<M, R, A> Chain<M> for NutsChain<M, R, A>
130where
131    M: Math,
132    R: rand::Rng,
133    A: AdaptStrategy<M>,
134{
135    type AdaptStrategy = A;
136
137    fn set_position(&mut self, position: &[f64]) -> Result<()> {
138        let mut math_ = self.math.borrow_mut();
139        let math = math_.deref_mut();
140        self.strategy.init(
141            math,
142            &mut self.options,
143            &mut self.hamiltonian,
144            position,
145            &mut self.rng,
146        )?;
147        self.state = self.hamiltonian.init_state(math, position)?;
148        Ok(())
149    }
150
151    fn draw(&mut self) -> Result<(Box<[f64]>, Progress)> {
152        let mut math_ = self.math.borrow_mut();
153        let math = math_.deref_mut();
154        let (state, info) = draw(
155            math,
156            &mut self.state,
157            &mut self.rng,
158            &mut self.hamiltonian,
159            &self.options,
160            &mut self.collector,
161        )?;
162        let mut position: Box<[f64]> = vec![0f64; math.dim()].into();
163        state.write_position(math, &mut position);
164
165        self.strategy.adapt(
166            math,
167            &mut self.options,
168            &mut self.hamiltonian,
169            self.draw_count,
170            &self.collector,
171            &state,
172            &mut self.rng,
173        )?;
174        let progress = Progress {
175            draw: self.draw_count,
176            chain: self.chain,
177            diverging: info.divergence_info.is_some(),
178            tuning: self.strategy.is_tuning(),
179            step_size: self.hamiltonian.step_size(),
180            num_steps: self.strategy.last_num_steps(),
181        };
182
183        self.draw_count += 1;
184
185        self.state = state;
186        self.last_info = Some(info);
187        Ok((position, progress))
188    }
189
190    fn expanded_draw(&mut self) -> Result<(Box<[f64]>, M::ExpandedVector, Self::Stats, Progress)> {
191        let (position, progress) = self.draw()?;
192        let mut math_ = self.math.borrow_mut();
193        let math = math_.deref_mut();
194
195        let stats = self.extract_stats(&mut *math, self.stats_options);
196        // Update the stats_options of the hamiltonian. This is used to
197        // store only changes in the transformation.
198        self.stats_options.hamiltonian = self
199            .hamiltonian
200            .update_stats_options(&mut *math, self.stats_options.hamiltonian);
201        let expanded = math.expand_vector(&mut self.rng, self.state.point().position())?;
202
203        Ok((position, expanded, stats, progress))
204    }
205
206    fn dim(&self) -> usize {
207        self.math.borrow().dim()
208    }
209
210    fn math(&self) -> Ref<'_, M> {
211        self.math.borrow()
212    }
213}
214
/// Statistics recorded for a single draw of a [`NutsChain`].
///
/// Built by `NutsChain::extract_stats` from the most recent sample info and
/// from the stats of the Hamiltonian, the adaptation strategy and the current point.
#[derive(Debug, nuts_derive::Storable)]
pub struct NutsStats<P: HasDims, H: Storable<P>, A: Storable<P>, D: Storable<P>> {
    /// Depth reported by the sampler for this draw (`SampleInfo::depth`).
    pub depth: u64,
    /// Whether the draw reached the maximum depth (`SampleInfo::reached_maxdepth`).
    pub maxdepth_reached: bool,
    /// Index of the chain that produced the draw.
    pub chain: u64,
    /// The chain's draw counter at extraction time.
    /// NOTE(review): `expanded_draw` extracts stats after `draw` has incremented the
    /// counter, so this is one greater than `Progress::draw` for the same draw. Confirm.
    pub draw: u64,
    /// Hamiltonian stats; `storable(flatten)` presumably inlines their fields here.
    #[storable(flatten)]
    pub hamiltonian: H,
    /// Adaptation-strategy stats (flattened).
    #[storable(flatten)]
    pub adapt: A,
    /// Stats of the current point (flattened).
    #[storable(flatten)]
    pub point: D,
    /// Divergence stats for this draw (flattened).
    #[storable(flatten)]
    pub divergence: DivergenceStats,
    /// Ties the dims type `P` to the struct without storing a value; skipped via `storable(ignore)`.
    #[storable(ignore)]
    _phantom: PhantomData<fn() -> P>,
}
232
233pub struct StatOptions<M: Math, A: AdaptStrategy<M>> {
234    pub adapt: A::StatsOptions,
235    pub hamiltonian: <A::Hamiltonian as SamplerStats<M>>::StatsOptions,
236    pub point: <<A::Hamiltonian as Hamiltonian<M>>::Point as SamplerStats<M>>::StatsOptions,
237    pub divergence: crate::dynamics::DivergenceStatsOptions,
238}
239
240impl<M, A> Clone for StatOptions<M, A>
241where
242    M: Math,
243    A: AdaptStrategy<M>,
244{
245    fn clone(&self) -> Self {
246        *self
247    }
248}
249
250impl<M, A> Copy for StatOptions<M, A>
251where
252    M: Math,
253    A: AdaptStrategy<M>,
254{
255}
256
257impl<M: Math, R: rand::Rng, A: AdaptStrategy<M>> SamplerStats<M> for NutsChain<M, R, A> {
258    type Stats = NutsStats<
259        StatsDims,
260        <A::Hamiltonian as SamplerStats<M>>::Stats,
261        A::Stats,
262        <<A::Hamiltonian as Hamiltonian<M>>::Point as SamplerStats<M>>::Stats,
263    >;
264    type StatsOptions = StatOptions<M, A>;
265
266    fn extract_stats(&self, math: &mut M, options: Self::StatsOptions) -> Self::Stats {
267        let hamiltonian_stats = self.hamiltonian.extract_stats(math, options.hamiltonian);
268        let adapt_stats = self.strategy.extract_stats(math, options.adapt);
269        let point_stats = self.state.point().extract_stats(math, options.point);
270        let info = self.last_info.as_ref().expect("Sampler has not started");
271        let div_info = info.divergence_info.as_ref();
272
273        NutsStats {
274            depth: info.depth,
275            maxdepth_reached: info.reached_maxdepth,
276            chain: self.chain,
277            draw: self.draw_count,
278            hamiltonian: hamiltonian_stats,
279            adapt: adapt_stats,
280            point: point_stats,
281            divergence: (div_info, options.divergence, self.draw_count).into(),
282            _phantom: PhantomData,
283        }
284    }
285}