nuts_rs/
chain.rs

1use std::{
2    cell::{Ref, RefCell},
3    fmt::Debug,
4    marker::PhantomData,
5    ops::DerefMut,
6};
7
8use nuts_storable::{HasDims, Storable};
9use rand::Rng;
10
11use crate::{
12    Math, NutsError,
13    hamiltonian::{Hamiltonian, Point},
14    nuts::{Collector, NutsOptions, SampleInfo, draw},
15    sampler::Progress,
16    sampler_stats::{SamplerStats, StatsDims},
17    state::State,
18};
19
20use anyhow::Result;
21
22/// Draw samples from the posterior distribution using Hamiltonian MCMC.
23pub trait Chain<M: Math>: SamplerStats<M> {
24    type AdaptStrategy: AdaptStrategy<M>;
25
26    /// Initialize the sampler to a position. This should be called
27    /// before calling draw.
28    ///
29    /// This fails if the logp function returns an error.
30    fn set_position(&mut self, position: &[f64]) -> Result<()>;
31
32    /// Draw a new sample and return the position and some diagnosic information.
33    fn draw(&mut self) -> Result<(Box<[f64]>, Progress)>;
34
35    /// The dimensionality of the posterior.
36    fn dim(&self) -> usize;
37
38    fn expanded_draw(&mut self) -> Result<(Box<[f64]>, M::ExpandedVector, Self::Stats, Progress)>;
39
40    fn math(&self) -> Ref<'_, M>;
41}
42
43pub struct NutsChain<M, R, A>
44where
45    M: Math,
46    R: rand::Rng,
47    A: AdaptStrategy<M>,
48{
49    hamiltonian: A::Hamiltonian,
50    collector: A::Collector,
51    options: NutsOptions,
52    rng: R,
53    state: State<M, <A::Hamiltonian as Hamiltonian<M>>::Point>,
54    last_info: Option<SampleInfo>,
55    chain: u64,
56    draw_count: u64,
57    strategy: A,
58    math: RefCell<M>,
59    stats_options: StatOptions<M, A>,
60}
61
62impl<M, R, A> NutsChain<M, R, A>
63where
64    M: Math,
65    R: rand::Rng,
66    A: AdaptStrategy<M>,
67{
68    pub fn new(
69        mut math: M,
70        mut hamiltonian: A::Hamiltonian,
71        strategy: A,
72        options: NutsOptions,
73        rng: R,
74        chain: u64,
75        stats_options: StatOptions<M, A>,
76    ) -> Self {
77        let init = hamiltonian.pool().new_state(&mut math);
78        let collector = strategy.new_collector(&mut math);
79        NutsChain {
80            hamiltonian,
81            collector,
82            options,
83            rng,
84            state: init,
85            last_info: None,
86            chain,
87            draw_count: 0,
88            strategy,
89            math: math.into(),
90            stats_options,
91        }
92    }
93}
94
95pub trait AdaptStrategy<M: Math>: SamplerStats<M> {
96    type Hamiltonian: Hamiltonian<M>;
97    type Collector: Collector<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>;
98    type Options: Copy + Send + Debug + Default;
99
100    fn new(math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self;
101
102    fn init<R: Rng + ?Sized>(
103        &mut self,
104        math: &mut M,
105        options: &mut NutsOptions,
106        hamiltonian: &mut Self::Hamiltonian,
107        position: &[f64],
108        rng: &mut R,
109    ) -> Result<(), NutsError>;
110
111    #[allow(clippy::too_many_arguments)]
112    fn adapt<R: Rng + ?Sized>(
113        &mut self,
114        math: &mut M,
115        options: &mut NutsOptions,
116        hamiltonian: &mut Self::Hamiltonian,
117        draw: u64,
118        collector: &Self::Collector,
119        state: &State<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>,
120        rng: &mut R,
121    ) -> Result<(), NutsError>;
122
123    fn new_collector(&self, math: &mut M) -> Self::Collector;
124    fn is_tuning(&self) -> bool;
125    fn last_num_steps(&self) -> u64;
126}
127
128impl<M, R, A> Chain<M> for NutsChain<M, R, A>
129where
130    M: Math,
131    R: rand::Rng,
132    A: AdaptStrategy<M>,
133{
134    type AdaptStrategy = A;
135
136    fn set_position(&mut self, position: &[f64]) -> Result<()> {
137        let mut math_ = self.math.borrow_mut();
138        let math = math_.deref_mut();
139        self.strategy.init(
140            math,
141            &mut self.options,
142            &mut self.hamiltonian,
143            position,
144            &mut self.rng,
145        )?;
146        self.state = self.hamiltonian.init_state(math, position)?;
147        Ok(())
148    }
149
150    fn draw(&mut self) -> Result<(Box<[f64]>, Progress)> {
151        let mut math_ = self.math.borrow_mut();
152        let math = math_.deref_mut();
153        let (state, info) = draw(
154            math,
155            &mut self.state,
156            &mut self.rng,
157            &mut self.hamiltonian,
158            &self.options,
159            &mut self.collector,
160        )?;
161        let mut position: Box<[f64]> = vec![0f64; math.dim()].into();
162        state.write_position(math, &mut position);
163
164        self.strategy.adapt(
165            math,
166            &mut self.options,
167            &mut self.hamiltonian,
168            self.draw_count,
169            &self.collector,
170            &state,
171            &mut self.rng,
172        )?;
173        let progress = Progress {
174            draw: self.draw_count,
175            chain: self.chain,
176            diverging: info.divergence_info.is_some(),
177            tuning: self.strategy.is_tuning(),
178            step_size: self.hamiltonian.step_size(),
179            num_steps: self.strategy.last_num_steps(),
180        };
181
182        self.draw_count += 1;
183
184        self.state = state;
185        self.last_info = Some(info);
186        Ok((position, progress))
187    }
188
189    fn expanded_draw(&mut self) -> Result<(Box<[f64]>, M::ExpandedVector, Self::Stats, Progress)> {
190        let (position, progress) = self.draw()?;
191        let mut math_ = self.math.borrow_mut();
192        let math = math_.deref_mut();
193
194        let stats = self.extract_stats(&mut *math, self.stats_options);
195        let expanded = math.expand_vector(&mut self.rng, self.state.point().position())?;
196
197        Ok((position, expanded, stats, progress))
198    }
199
200    fn dim(&self) -> usize {
201        self.math.borrow().dim()
202    }
203
204    fn math(&self) -> Ref<'_, M> {
205        self.math.borrow()
206    }
207}
208
209#[derive(Debug, nuts_derive::Storable)]
210pub struct NutsStats<P: HasDims, H: Storable<P>, A: Storable<P>, D: Storable<P>> {
211    pub depth: u64,
212    pub maxdepth_reached: bool,
213    pub index_in_trajectory: i64,
214    pub logp: f64,
215    pub energy: f64,
216    pub chain: u64,
217    pub draw: u64,
218    pub energy_error: f64,
219    #[storable(dims("unconstrained_parameter"))]
220    pub unconstrained_draw: Option<Vec<f64>>,
221    #[storable(dims("unconstrained_parameter"))]
222    pub gradient: Option<Vec<f64>>,
223    #[storable(flatten)]
224    pub hamiltonian: H,
225    #[storable(flatten)]
226    pub adapt: A,
227    #[storable(flatten)]
228    pub point: D,
229    pub diverging: bool,
230    #[storable(dims("unconstrained_parameter"))]
231    pub divergence_start: Option<Vec<f64>>,
232    #[storable(dims("unconstrained_parameter"))]
233    pub divergence_start_gradient: Option<Vec<f64>>,
234    #[storable(dims("unconstrained_parameter"))]
235    pub divergence_end: Option<Vec<f64>>,
236    #[storable(dims("unconstrained_parameter"))]
237    pub divergence_momentum: Option<Vec<f64>>,
238    //pub divergence_message: Option<String>,
239    #[storable(ignore)]
240    _phantom: PhantomData<fn() -> P>,
241}
242
243pub struct StatOptions<M: Math, A: AdaptStrategy<M>> {
244    pub adapt: A::StatsOptions,
245    pub hamiltonian: <A::Hamiltonian as SamplerStats<M>>::StatsOptions,
246    pub point: <<A::Hamiltonian as Hamiltonian<M>>::Point as SamplerStats<M>>::StatsOptions,
247}
248
249impl<M, A> Clone for StatOptions<M, A>
250where
251    M: Math,
252    A: AdaptStrategy<M>,
253{
254    fn clone(&self) -> Self {
255        *self
256    }
257}
258
259impl<M, A> Copy for StatOptions<M, A>
260where
261    M: Math,
262    A: AdaptStrategy<M>,
263{
264}
265
266impl<M: Math, R: rand::Rng, A: AdaptStrategy<M>> SamplerStats<M> for NutsChain<M, R, A> {
267    type Stats = NutsStats<
268        StatsDims,
269        <A::Hamiltonian as SamplerStats<M>>::Stats,
270        A::Stats,
271        <<A::Hamiltonian as Hamiltonian<M>>::Point as SamplerStats<M>>::Stats,
272    >;
273    type StatsOptions = StatOptions<M, A>;
274
275    fn extract_stats(&self, math: &mut M, options: Self::StatsOptions) -> Self::Stats {
276        let hamiltonian_stats = self.hamiltonian.extract_stats(math, options.hamiltonian);
277        let adapt_stats = self.strategy.extract_stats(math, options.adapt);
278        let point_stats = self.state.point().extract_stats(math, options.point);
279        let info = self.last_info.as_ref().expect("Sampler has not started");
280        let point = self.state.point();
281        let div_info = info.divergence_info.as_ref();
282
283        NutsStats {
284            depth: info.depth,
285            maxdepth_reached: info.reached_maxdepth,
286            index_in_trajectory: point.index_in_trajectory(),
287            logp: point.logp(),
288            energy: point.energy(),
289            chain: self.chain,
290            draw: self.draw_count,
291            energy_error: point.energy_error(),
292            unconstrained_draw: Some(math.box_array(point.position()).into_vec()),
293            gradient: Some(math.box_array(point.gradient()).into_vec()),
294            hamiltonian: hamiltonian_stats,
295            adapt: adapt_stats,
296            point: point_stats,
297            diverging: div_info.is_some(),
298            divergence_start: div_info
299                .and_then(|d| d.start_location.as_ref().map(|v| v.as_ref().to_vec())),
300            divergence_start_gradient: div_info
301                .and_then(|d| d.start_gradient.as_ref().map(|v| v.as_ref().to_vec())),
302            divergence_end: div_info
303                .and_then(|d| d.end_location.as_ref().map(|v| v.as_ref().to_vec())),
304            divergence_momentum: div_info
305                .and_then(|d| d.start_momentum.as_ref().map(|v| v.as_ref().to_vec())),
306            //divergence_message: self.divergence_msg.clone(),
307            _phantom: PhantomData,
308        }
309    }
310}