1use std::{
4 cell::{Ref, RefCell},
5 fmt::Debug,
6 marker::PhantomData,
7 ops::DerefMut,
8};
9
10use nuts_storable::{HasDims, Storable};
11use rand::Rng;
12
13use crate::{
14 Math, NutsError,
15 dynamics::{DivergenceStats, Hamiltonian, Point, State},
16 nuts::{Collector, NutsOptions, SampleInfo, draw},
17 sampler::Progress,
18 sampler_stats::{SamplerStats, StatsDims},
19};
20
21use anyhow::Result;
22
23pub trait Chain<M: Math>: SamplerStats<M> {
25 type AdaptStrategy: AdaptStrategy<M>;
26
27 fn set_position(&mut self, position: &[f64]) -> Result<()>;
32
33 fn draw(&mut self) -> Result<(Box<[f64]>, Progress)>;
35
36 fn dim(&self) -> usize;
38
39 fn expanded_draw(&mut self) -> Result<(Box<[f64]>, M::ExpandedVector, Self::Stats, Progress)>;
40
41 fn math(&self) -> Ref<'_, M>;
42}
43
/// A single NUTS chain.
///
/// It owns the Hamiltonian, adaptation strategy, RNG and current state needed to
/// produce draws, plus the bookkeeping used for progress reports and statistics.
pub struct NutsChain<M, R, A>
where
    M: Math,
    R: rand::Rng,
    A: AdaptStrategy<M>,
{
    /// Hamiltonian used to build trajectories; `strategy` may mutate it in `init`/`adapt`.
    hamiltonian: A::Hamiltonian,
    /// Created by `strategy.new_collector`; filled during each draw and read by `strategy.adapt`.
    collector: A::Collector,
    /// NUTS settings passed to each draw; `strategy` may mutate them.
    options: NutsOptions,
    /// Random number generator shared by drawing, adaptation and vector expansion.
    rng: R,
    /// Current state; replaced by `set_position` and after every successful draw.
    state: State<M, <A::Hamiltonian as Hamiltonian<M>>::Point>,
    /// Info from the most recent draw; `None` until the first draw completes.
    last_info: Option<SampleInfo>,
    /// Identifier of this chain, reported in `Progress` and in the stats.
    chain: u64,
    /// Number of draws completed so far; also the draw index handed to `strategy.adapt`.
    draw_count: u64,
    /// Adaptation strategy that tunes `hamiltonian` and `options`.
    strategy: A,
    /// Wrapped in `RefCell` so `&self` methods (`dim`, `math`) can borrow the backend.
    math: RefCell<M>,
    /// Options used by `expanded_draw` when extracting stats; the Hamiltonian part is
    /// refreshed after each expanded draw.
    stats_options: StatOptions<M, A>,
}
62
63impl<M, R, A> NutsChain<M, R, A>
64where
65 M: Math,
66 R: rand::Rng,
67 A: AdaptStrategy<M>,
68{
69 pub fn new(
70 mut math: M,
71 mut hamiltonian: A::Hamiltonian,
72 strategy: A,
73 options: NutsOptions,
74 rng: R,
75 chain: u64,
76 stats_options: StatOptions<M, A>,
77 ) -> Self {
78 let init = hamiltonian.pool().new_state(&mut math);
79 let collector = strategy.new_collector(&mut math);
80 NutsChain {
81 hamiltonian,
82 collector,
83 options,
84 rng,
85 state: init,
86 last_info: None,
87 chain,
88 draw_count: 0,
89 strategy,
90 math: math.into(),
91 stats_options,
92 }
93 }
94}
95
96pub trait AdaptStrategy<M: Math>: SamplerStats<M> {
97 type Hamiltonian: Hamiltonian<M>;
98 type Collector: Collector<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>;
99 type Options: Copy + Send + Debug + Default;
100
101 fn new(math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self;
102
103 fn init<R: Rng + ?Sized>(
104 &mut self,
105 math: &mut M,
106 options: &mut NutsOptions,
107 hamiltonian: &mut Self::Hamiltonian,
108 position: &[f64],
109 rng: &mut R,
110 ) -> Result<(), NutsError>;
111
112 #[allow(clippy::too_many_arguments)]
113 fn adapt<R: Rng + ?Sized>(
114 &mut self,
115 math: &mut M,
116 options: &mut NutsOptions,
117 hamiltonian: &mut Self::Hamiltonian,
118 draw: u64,
119 collector: &Self::Collector,
120 state: &State<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>,
121 rng: &mut R,
122 ) -> Result<(), NutsError>;
123
124 fn new_collector(&self, math: &mut M) -> Self::Collector;
125 fn is_tuning(&self) -> bool;
126 fn last_num_steps(&self) -> u64;
127}
128
129impl<M, R, A> Chain<M> for NutsChain<M, R, A>
130where
131 M: Math,
132 R: rand::Rng,
133 A: AdaptStrategy<M>,
134{
135 type AdaptStrategy = A;
136
137 fn set_position(&mut self, position: &[f64]) -> Result<()> {
138 let mut math_ = self.math.borrow_mut();
139 let math = math_.deref_mut();
140 self.strategy.init(
141 math,
142 &mut self.options,
143 &mut self.hamiltonian,
144 position,
145 &mut self.rng,
146 )?;
147 self.state = self.hamiltonian.init_state(math, position)?;
148 Ok(())
149 }
150
151 fn draw(&mut self) -> Result<(Box<[f64]>, Progress)> {
152 let mut math_ = self.math.borrow_mut();
153 let math = math_.deref_mut();
154 let (state, info) = draw(
155 math,
156 &mut self.state,
157 &mut self.rng,
158 &mut self.hamiltonian,
159 &self.options,
160 &mut self.collector,
161 )?;
162 let mut position: Box<[f64]> = vec![0f64; math.dim()].into();
163 state.write_position(math, &mut position);
164
165 self.strategy.adapt(
166 math,
167 &mut self.options,
168 &mut self.hamiltonian,
169 self.draw_count,
170 &self.collector,
171 &state,
172 &mut self.rng,
173 )?;
174 let progress = Progress {
175 draw: self.draw_count,
176 chain: self.chain,
177 diverging: info.divergence_info.is_some(),
178 tuning: self.strategy.is_tuning(),
179 step_size: self.hamiltonian.step_size(),
180 num_steps: self.strategy.last_num_steps(),
181 };
182
183 self.draw_count += 1;
184
185 self.state = state;
186 self.last_info = Some(info);
187 Ok((position, progress))
188 }
189
190 fn expanded_draw(&mut self) -> Result<(Box<[f64]>, M::ExpandedVector, Self::Stats, Progress)> {
191 let (position, progress) = self.draw()?;
192 let mut math_ = self.math.borrow_mut();
193 let math = math_.deref_mut();
194
195 let stats = self.extract_stats(&mut *math, self.stats_options);
196 self.stats_options.hamiltonian = self
199 .hamiltonian
200 .update_stats_options(&mut *math, self.stats_options.hamiltonian);
201 let expanded = math.expand_vector(&mut self.rng, self.state.point().position())?;
202
203 Ok((position, expanded, stats, progress))
204 }
205
206 fn dim(&self) -> usize {
207 self.math.borrow().dim()
208 }
209
210 fn math(&self) -> Ref<'_, M> {
211 self.math.borrow()
212 }
213}
214
/// Per-draw statistics of a [`NutsChain`], assembled by its `extract_stats`.
///
/// `P` is the `HasDims` type the nested stats are stored against. It is not stored;
/// it appears only in `_phantom` and in the `Storable<P>` bounds.
#[derive(Debug, nuts_derive::Storable)]
pub struct NutsStats<P: HasDims, H: Storable<P>, A: Storable<P>, D: Storable<P>> {
    /// Depth reported for this draw (copied from `SampleInfo::depth`).
    pub depth: u64,
    /// Whether the maximum depth was reached (copied from `SampleInfo::reached_maxdepth`).
    pub maxdepth_reached: bool,
    /// Identifier of the chain that produced the draw.
    pub chain: u64,
    /// The chain's `draw_count` at extraction time.
    ///
    /// NOTE(review): when extracted in `expanded_draw`, this is one greater than
    /// `Progress::draw` for the same draw. Confirm this is intended.
    pub draw: u64,
    /// Stats reported by the Hamiltonian.
    // NOTE(review): `flatten` presumably stores the nested fields inline; confirm in nuts_derive.
    #[storable(flatten)]
    pub hamiltonian: H,
    /// Stats reported by the adaptation strategy.
    #[storable(flatten)]
    pub adapt: A,
    /// Stats reported by the current point.
    #[storable(flatten)]
    pub point: D,
    /// Divergence stats, built from the draw's divergence info, options and draw count.
    #[storable(flatten)]
    pub divergence: DivergenceStats,
    // Marks the use of `P` without storing one. The `fn() -> P` form keeps the
    // struct's auto traits (`Send`/`Sync`) independent of `P`. Excluded from storage.
    #[storable(ignore)]
    _phantom: PhantomData<fn() -> P>,
}
232
233pub struct StatOptions<M: Math, A: AdaptStrategy<M>> {
234 pub adapt: A::StatsOptions,
235 pub hamiltonian: <A::Hamiltonian as SamplerStats<M>>::StatsOptions,
236 pub point: <<A::Hamiltonian as Hamiltonian<M>>::Point as SamplerStats<M>>::StatsOptions,
237 pub divergence: crate::dynamics::DivergenceStatsOptions,
238}
239
240impl<M, A> Clone for StatOptions<M, A>
241where
242 M: Math,
243 A: AdaptStrategy<M>,
244{
245 fn clone(&self) -> Self {
246 *self
247 }
248}
249
250impl<M, A> Copy for StatOptions<M, A>
251where
252 M: Math,
253 A: AdaptStrategy<M>,
254{
255}
256
257impl<M: Math, R: rand::Rng, A: AdaptStrategy<M>> SamplerStats<M> for NutsChain<M, R, A> {
258 type Stats = NutsStats<
259 StatsDims,
260 <A::Hamiltonian as SamplerStats<M>>::Stats,
261 A::Stats,
262 <<A::Hamiltonian as Hamiltonian<M>>::Point as SamplerStats<M>>::Stats,
263 >;
264 type StatsOptions = StatOptions<M, A>;
265
266 fn extract_stats(&self, math: &mut M, options: Self::StatsOptions) -> Self::Stats {
267 let hamiltonian_stats = self.hamiltonian.extract_stats(math, options.hamiltonian);
268 let adapt_stats = self.strategy.extract_stats(math, options.adapt);
269 let point_stats = self.state.point().extract_stats(math, options.point);
270 let info = self.last_info.as_ref().expect("Sampler has not started");
271 let div_info = info.divergence_info.as_ref();
272
273 NutsStats {
274 depth: info.depth,
275 maxdepth_reached: info.reached_maxdepth,
276 chain: self.chain,
277 draw: self.draw_count,
278 hamiltonian: hamiltonian_stats,
279 adapt: adapt_stats,
280 point: point_stats,
281 divergence: (div_info, options.divergence, self.draw_count).into(),
282 _phantom: PhantomData,
283 }
284 }
285}