1use std::{
2 cell::{Ref, RefCell},
3 fmt::Debug,
4 marker::PhantomData,
5 ops::DerefMut,
6};
7
8use nuts_storable::{HasDims, Storable};
9use rand::Rng;
10
11use crate::{
12 Math, NutsError,
13 hamiltonian::{Hamiltonian, Point},
14 nuts::{Collector, NutsOptions, SampleInfo, draw},
15 sampler::Progress,
16 sampler_stats::{SamplerStats, StatsDims},
17 state::State,
18};
19
20use anyhow::Result;
21
22pub trait Chain<M: Math>: SamplerStats<M> {
24 type AdaptStrategy: AdaptStrategy<M>;
25
26 fn set_position(&mut self, position: &[f64]) -> Result<()>;
31
32 fn draw(&mut self) -> Result<(Box<[f64]>, Progress)>;
34
35 fn dim(&self) -> usize;
37
38 fn expanded_draw(&mut self) -> Result<(Box<[f64]>, M::ExpandedVector, Self::Stats, Progress)>;
39
40 fn math(&self) -> Ref<'_, M>;
41}
42
43pub struct NutsChain<M, R, A>
44where
45 M: Math,
46 R: rand::Rng,
47 A: AdaptStrategy<M>,
48{
49 hamiltonian: A::Hamiltonian,
50 collector: A::Collector,
51 options: NutsOptions,
52 rng: R,
53 state: State<M, <A::Hamiltonian as Hamiltonian<M>>::Point>,
54 last_info: Option<SampleInfo>,
55 chain: u64,
56 draw_count: u64,
57 strategy: A,
58 math: RefCell<M>,
59 stats_options: StatOptions<M, A>,
60}
61
62impl<M, R, A> NutsChain<M, R, A>
63where
64 M: Math,
65 R: rand::Rng,
66 A: AdaptStrategy<M>,
67{
68 pub fn new(
69 mut math: M,
70 mut hamiltonian: A::Hamiltonian,
71 strategy: A,
72 options: NutsOptions,
73 rng: R,
74 chain: u64,
75 stats_options: StatOptions<M, A>,
76 ) -> Self {
77 let init = hamiltonian.pool().new_state(&mut math);
78 let collector = strategy.new_collector(&mut math);
79 NutsChain {
80 hamiltonian,
81 collector,
82 options,
83 rng,
84 state: init,
85 last_info: None,
86 chain,
87 draw_count: 0,
88 strategy,
89 math: math.into(),
90 stats_options,
91 }
92 }
93}
94
95pub trait AdaptStrategy<M: Math>: SamplerStats<M> {
96 type Hamiltonian: Hamiltonian<M>;
97 type Collector: Collector<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>;
98 type Options: Copy + Send + Debug + Default;
99
100 fn new(math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self;
101
102 fn init<R: Rng + ?Sized>(
103 &mut self,
104 math: &mut M,
105 options: &mut NutsOptions,
106 hamiltonian: &mut Self::Hamiltonian,
107 position: &[f64],
108 rng: &mut R,
109 ) -> Result<(), NutsError>;
110
111 #[allow(clippy::too_many_arguments)]
112 fn adapt<R: Rng + ?Sized>(
113 &mut self,
114 math: &mut M,
115 options: &mut NutsOptions,
116 hamiltonian: &mut Self::Hamiltonian,
117 draw: u64,
118 collector: &Self::Collector,
119 state: &State<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>,
120 rng: &mut R,
121 ) -> Result<(), NutsError>;
122
123 fn new_collector(&self, math: &mut M) -> Self::Collector;
124 fn is_tuning(&self) -> bool;
125 fn last_num_steps(&self) -> u64;
126}
127
128impl<M, R, A> Chain<M> for NutsChain<M, R, A>
129where
130 M: Math,
131 R: rand::Rng,
132 A: AdaptStrategy<M>,
133{
134 type AdaptStrategy = A;
135
136 fn set_position(&mut self, position: &[f64]) -> Result<()> {
137 let mut math_ = self.math.borrow_mut();
138 let math = math_.deref_mut();
139 self.strategy.init(
140 math,
141 &mut self.options,
142 &mut self.hamiltonian,
143 position,
144 &mut self.rng,
145 )?;
146 self.state = self.hamiltonian.init_state(math, position)?;
147 Ok(())
148 }
149
150 fn draw(&mut self) -> Result<(Box<[f64]>, Progress)> {
151 let mut math_ = self.math.borrow_mut();
152 let math = math_.deref_mut();
153 let (state, info) = draw(
154 math,
155 &mut self.state,
156 &mut self.rng,
157 &mut self.hamiltonian,
158 &self.options,
159 &mut self.collector,
160 )?;
161 let mut position: Box<[f64]> = vec![0f64; math.dim()].into();
162 state.write_position(math, &mut position);
163
164 self.strategy.adapt(
165 math,
166 &mut self.options,
167 &mut self.hamiltonian,
168 self.draw_count,
169 &self.collector,
170 &state,
171 &mut self.rng,
172 )?;
173 let progress = Progress {
174 draw: self.draw_count,
175 chain: self.chain,
176 diverging: info.divergence_info.is_some(),
177 tuning: self.strategy.is_tuning(),
178 step_size: self.hamiltonian.step_size(),
179 num_steps: self.strategy.last_num_steps(),
180 };
181
182 self.draw_count += 1;
183
184 self.state = state;
185 self.last_info = Some(info);
186 Ok((position, progress))
187 }
188
189 fn expanded_draw(&mut self) -> Result<(Box<[f64]>, M::ExpandedVector, Self::Stats, Progress)> {
190 let (position, progress) = self.draw()?;
191 let mut math_ = self.math.borrow_mut();
192 let math = math_.deref_mut();
193
194 let stats = self.extract_stats(&mut *math, self.stats_options);
195 let expanded = math.expand_vector(&mut self.rng, self.state.point().position())?;
196
197 Ok((position, expanded, stats, progress))
198 }
199
200 fn dim(&self) -> usize {
201 self.math.borrow().dim()
202 }
203
204 fn math(&self) -> Ref<'_, M> {
205 self.math.borrow()
206 }
207}
208
/// Per-draw statistics of a NUTS chain.
///
/// `NutsChain::extract_stats` fills this from the info of the last draw, the
/// current point, and the statistics of the Hamiltonian (`H`), the adaptation
/// strategy (`A`) and the point (`D`).
// NOTE(review): the `#[storable(...)]` attributes are interpreted by the
// `nuts_derive::Storable` derive; `dims(..)` presumably names the dimension of
// the stored array, `flatten` presumably inlines the nested stats and `ignore`
// presumably skips the field — confirm against the derive's documentation.
#[derive(Debug, nuts_derive::Storable)]
pub struct NutsStats<P: HasDims, H: Storable<P>, A: Storable<P>, D: Storable<P>> {
    // `depth` of the last draw's sample info.
    pub depth: u64,
    // `reached_maxdepth` of the last draw's sample info.
    pub maxdepth_reached: bool,
    // `index_in_trajectory()` of the current point.
    pub index_in_trajectory: i64,
    // `logp()` of the current point.
    pub logp: f64,
    // `energy()` of the current point.
    pub energy: f64,
    // Identifier of the chain that produced the draw.
    pub chain: u64,
    // Value of the chain's draw counter when the stats were extracted.
    pub draw: u64,
    // `energy_error()` of the current point.
    pub energy_error: f64,
    // Position of the current point.
    #[storable(dims("unconstrained_parameter"))]
    pub unconstrained_draw: Option<Vec<f64>>,
    // Gradient at the current point.
    #[storable(dims("unconstrained_parameter"))]
    pub gradient: Option<Vec<f64>>,
    // Statistics extracted from the Hamiltonian.
    #[storable(flatten)]
    pub hamiltonian: H,
    // Statistics extracted from the adaptation strategy.
    #[storable(flatten)]
    pub adapt: A,
    // Statistics extracted from the current point.
    #[storable(flatten)]
    pub point: D,
    // Whether the last draw recorded divergence info.
    pub diverging: bool,
    // `start_location` of the divergence info, if present.
    #[storable(dims("unconstrained_parameter"))]
    pub divergence_start: Option<Vec<f64>>,
    // `start_gradient` of the divergence info, if present.
    #[storable(dims("unconstrained_parameter"))]
    pub divergence_start_gradient: Option<Vec<f64>>,
    // `end_location` of the divergence info, if present.
    #[storable(dims("unconstrained_parameter"))]
    pub divergence_end: Option<Vec<f64>>,
    // `start_momentum` of the divergence info, if present.
    #[storable(dims("unconstrained_parameter"))]
    pub divergence_momentum: Option<Vec<f64>>,
    // Marker that ties the otherwise unused parameter `P` to the struct
    // without storing a `P`.
    #[storable(ignore)]
    _phantom: PhantomData<fn() -> P>,
}
242
243pub struct StatOptions<M: Math, A: AdaptStrategy<M>> {
244 pub adapt: A::StatsOptions,
245 pub hamiltonian: <A::Hamiltonian as SamplerStats<M>>::StatsOptions,
246 pub point: <<A::Hamiltonian as Hamiltonian<M>>::Point as SamplerStats<M>>::StatsOptions,
247}
248
249impl<M, A> Clone for StatOptions<M, A>
250where
251 M: Math,
252 A: AdaptStrategy<M>,
253{
254 fn clone(&self) -> Self {
255 *self
256 }
257}
258
259impl<M, A> Copy for StatOptions<M, A>
260where
261 M: Math,
262 A: AdaptStrategy<M>,
263{
264}
265
266impl<M: Math, R: rand::Rng, A: AdaptStrategy<M>> SamplerStats<M> for NutsChain<M, R, A> {
267 type Stats = NutsStats<
268 StatsDims,
269 <A::Hamiltonian as SamplerStats<M>>::Stats,
270 A::Stats,
271 <<A::Hamiltonian as Hamiltonian<M>>::Point as SamplerStats<M>>::Stats,
272 >;
273 type StatsOptions = StatOptions<M, A>;
274
275 fn extract_stats(&self, math: &mut M, options: Self::StatsOptions) -> Self::Stats {
276 let hamiltonian_stats = self.hamiltonian.extract_stats(math, options.hamiltonian);
277 let adapt_stats = self.strategy.extract_stats(math, options.adapt);
278 let point_stats = self.state.point().extract_stats(math, options.point);
279 let info = self.last_info.as_ref().expect("Sampler has not started");
280 let point = self.state.point();
281 let div_info = info.divergence_info.as_ref();
282
283 NutsStats {
284 depth: info.depth,
285 maxdepth_reached: info.reached_maxdepth,
286 index_in_trajectory: point.index_in_trajectory(),
287 logp: point.logp(),
288 energy: point.energy(),
289 chain: self.chain,
290 draw: self.draw_count,
291 energy_error: point.energy_error(),
292 unconstrained_draw: Some(math.box_array(point.position()).into_vec()),
293 gradient: Some(math.box_array(point.gradient()).into_vec()),
294 hamiltonian: hamiltonian_stats,
295 adapt: adapt_stats,
296 point: point_stats,
297 diverging: div_info.is_some(),
298 divergence_start: div_info
299 .and_then(|d| d.start_location.as_ref().map(|v| v.as_ref().to_vec())),
300 divergence_start_gradient: div_info
301 .and_then(|d| d.start_gradient.as_ref().map(|v| v.as_ref().to_vec())),
302 divergence_end: div_info
303 .and_then(|d| d.end_location.as_ref().map(|v| v.as_ref().to_vec())),
304 divergence_momentum: div_info
305 .and_then(|d| d.start_momentum.as_ref().map(|v| v.as_ref().to_vec())),
306 _phantom: PhantomData,
308 }
309 }
310}