1use std::{fmt::Debug, marker::PhantomData};
4
5use nuts_derive::Storable;
6use nuts_storable::{HasDims, Storable};
7use rand::Rng;
8use serde::{Deserialize, Serialize};
9
10use super::stepsize::AcceptanceRateCollector;
11use super::stepsize::{StepSizeSettings, Strategy as StepSizeStrategy};
12use crate::dynamics::{
13 DivergenceInfo, Hamiltonian, Point, State, TransformedHamiltonian, TransformedPoint,
14};
15use crate::transform::MassMatrixAdaptStrategy;
16use crate::{
17 NutsError,
18 chain::AdaptStrategy,
19 math::Math,
20 nuts::{Collector, NutsOptions},
21 sampler_stats::{SamplerStats, StatsDims},
22};
23
24pub struct GlobalStrategy<M: Math, A: MassMatrixAdaptStrategy<M>> {
25 step_size: StepSizeStrategy,
26 mass_matrix_adapt: A,
27 options: EuclideanAdaptOptions<A::Options>,
28 num_tune: u64,
29 early_end: u64,
31
32 final_step_size_window: u64,
34 tuning: bool,
35 has_initial_mass_matrix: bool,
36 last_update: u64,
37 current_window_size: u64,
39}
40
41#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
42pub struct EuclideanAdaptOptions<S: Debug + Default> {
43 pub step_size_settings: StepSizeSettings,
44 pub mass_matrix_options: S,
45 pub early_window: f64,
46 pub step_size_window: f64,
47 pub mass_matrix_switch_freq: u64,
49 pub early_mass_matrix_switch_freq: u64,
50 pub mass_matrix_update_freq: u64,
51 pub mass_matrix_window_growth: f64,
54}
55
56impl<S: Debug + Default> Default for EuclideanAdaptOptions<S> {
57 fn default() -> Self {
58 Self {
59 step_size_settings: StepSizeSettings::default(),
60 mass_matrix_options: S::default(),
61 early_window: 0.3,
62 step_size_window: 0.15,
63 mass_matrix_switch_freq: 80,
64 early_mass_matrix_switch_freq: 10,
65 mass_matrix_update_freq: 1,
66 mass_matrix_window_growth: 1.5,
67 }
68 }
69}
70
71impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy<M, A> {
72 type Hamiltonian = TransformedHamiltonian<M, A::Transformation>;
73 type Collector =
74 CombinedCollector<M, TransformedPoint<M>, AcceptanceRateCollector, A::Collector>;
75 type Options = EuclideanAdaptOptions<A::Options>;
76
77 fn new(math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self {
78 let num_tune_f = num_tune as f64;
79 let step_size_window = (options.step_size_window * num_tune_f) as u64;
80 let early_end = (options.early_window * num_tune_f) as u64;
81 let final_second_step_size = num_tune.saturating_sub(step_size_window);
82
83 assert!(early_end < num_tune);
84 assert!(options.mass_matrix_window_growth >= 1.0);
85
86 Self {
87 step_size: StepSizeStrategy::new(options.step_size_settings),
88 mass_matrix_adapt: A::new(math, options.mass_matrix_options, num_tune, chain),
89 options,
90 num_tune,
91 early_end,
92 final_step_size_window: final_second_step_size,
93 tuning: true,
94 has_initial_mass_matrix: true,
95 last_update: 0,
96 current_window_size: options.mass_matrix_switch_freq,
97 }
98 }
99
100 fn init<R: Rng + ?Sized>(
101 &mut self,
102 math: &mut M,
103 options: &mut NutsOptions,
104 hamiltonian: &mut Self::Hamiltonian,
105 position: &[f64],
106 rng: &mut R,
107 ) -> Result<(), NutsError> {
108 let state = hamiltonian.init_state_untransformed(math, position)?;
109 self.mass_matrix_adapt.init(
110 math,
111 options,
112 hamiltonian.transformation_mut(),
113 state.point(),
114 rng,
115 )?;
116 self.step_size
117 .init(math, options, hamiltonian, position, rng)?;
118 Ok(())
119 }
120
121 fn adapt<R: Rng + ?Sized>(
122 &mut self,
123 math: &mut M,
124 options: &mut NutsOptions,
125 hamiltonian: &mut Self::Hamiltonian,
126 draw: u64,
127 collector: &Self::Collector,
128 state: &State<M, TransformedPoint<M>>,
129 rng: &mut R,
130 ) -> Result<(), NutsError> {
131 self.step_size.update(&collector.collector1);
132
133 if draw >= self.num_tune {
134 self.step_size.update_stepsize(rng, hamiltonian, true);
136 self.tuning = false;
137 return Ok(());
138 }
139
140 if draw < self.final_step_size_window {
141 let is_early = draw < self.early_end;
142
143 if !is_early && draw == self.early_end {
147 self.current_window_size = self
148 .current_window_size
149 .max(self.mass_matrix_adapt.background_count());
150 }
151
152 let switch_freq = if is_early {
153 self.options.early_mass_matrix_switch_freq
154 } else {
155 self.current_window_size
156 };
157
158 self.mass_matrix_adapt
159 .update_estimators(math, &collector.collector2);
160 let could_switch = self.mass_matrix_adapt.background_count() >= switch_freq;
164 let next_window_size = if is_early {
168 self.options.early_mass_matrix_switch_freq
169 } else {
170 (self.current_window_size + 1).max(
171 (self.current_window_size as f64 * self.options.mass_matrix_window_growth)
172 .round() as u64,
173 )
174 };
175 let is_late = next_window_size + draw > self.final_step_size_window;
176
177 let mut force_update = false;
178 if could_switch && (!is_late) {
179 self.mass_matrix_adapt.switch(math);
180 force_update = true;
181 if !is_early {
183 self.current_window_size = next_window_size;
184 }
185 }
186
187 let did_change = if force_update
188 | (draw - self.last_update >= self.options.mass_matrix_update_freq)
189 {
190 self.mass_matrix_adapt
191 .adapt(math, hamiltonian.transformation_mut())
192 } else {
193 false
194 };
195
196 if did_change {
197 self.last_update = draw;
198 }
199
200 if is_late {
201 self.step_size.update_estimator_late();
202 } else {
203 self.step_size.update_estimator_early();
204 }
205
206 if did_change & self.has_initial_mass_matrix {
208 self.has_initial_mass_matrix = false;
209 let position = math.box_array(state.point().position());
210 self.step_size
211 .init(math, options, hamiltonian, &position, rng)?;
212 } else {
213 self.step_size.update_stepsize(rng, hamiltonian, false)
214 }
215 return Ok(());
216 }
217
218 self.step_size.update_estimator_late();
219 let is_last = draw == self.num_tune - 1;
220 self.step_size.update_stepsize(rng, hamiltonian, is_last);
221 Ok(())
222 }
223
224 fn new_collector(&self, math: &mut M) -> Self::Collector {
225 Self::Collector::new(
226 self.step_size.new_collector(),
227 self.mass_matrix_adapt.new_collector(math),
228 )
229 }
230
231 fn is_tuning(&self) -> bool {
232 self.tuning
233 }
234
235 fn last_num_steps(&self) -> u64 {
236 self.step_size.last_n_steps
237 }
238}
239
// Adaptation statistics returned by `GlobalStrategy::extract_stats`.
//
// The step-size (`S`) and mass-matrix (`M`) statistics are marked
// `#[storable(flatten)]` (presumably their fields are stored inline in this
// record — confirm against `nuts_storable`). `_phantom` only carries the dims
// type `P` and is excluded via `#[storable(ignore)]`.
#[derive(Debug, Storable)]
pub struct GlobalStrategyStats<P: HasDims, S: Storable<P>, M: Storable<P>> {
    // Statistics of the step-size adaptation.
    #[storable(flatten)]
    pub step_size: S,
    // Statistics of the mass-matrix adaptation.
    #[storable(flatten)]
    pub mass_matrix: M,
    // Whether the strategy was still tuning when the stats were extracted.
    pub tuning: bool,
    // Marker for `P`; `fn() -> P` avoids implying ownership of a `P` value.
    #[storable(ignore)]
    _phantom: std::marker::PhantomData<fn() -> P>,
}
250
251#[derive(Debug)]
252pub struct GlobalStrategyStatsOptions<M: Math, A: MassMatrixAdaptStrategy<M>> {
253 pub step_size: (),
254 pub mass_matrix: A::StatsOptions,
255}
256
257impl<M: Math, A: MassMatrixAdaptStrategy<M>> Clone for GlobalStrategyStatsOptions<M, A> {
258 fn clone(&self) -> Self {
259 *self
260 }
261}
262
263impl<M: Math, A: MassMatrixAdaptStrategy<M>> Copy for GlobalStrategyStatsOptions<M, A> {}
264
265impl<M: Math, A> SamplerStats<M> for GlobalStrategy<M, A>
266where
267 A: MassMatrixAdaptStrategy<M>,
268{
269 type Stats =
270 GlobalStrategyStats<StatsDims, <StepSizeStrategy as SamplerStats<M>>::Stats, A::Stats>;
271 type StatsOptions = GlobalStrategyStatsOptions<M, A>;
272
273 fn extract_stats(&self, math: &mut M, opt: Self::StatsOptions) -> Self::Stats {
274 GlobalStrategyStats {
275 step_size: {
276 let _: () = opt.step_size;
277 self.step_size.extract_stats(math, ())
278 },
279 mass_matrix: self.mass_matrix_adapt.extract_stats(math, opt.mass_matrix),
280 tuning: self.tuning,
281 _phantom: PhantomData,
282 }
283 }
284}
285
286pub struct CombinedCollector<M, P, C1, C2>
287where
288 M: Math,
289 P: Point<M>,
290 C1: Collector<M, P>,
291 C2: Collector<M, P>,
292{
293 pub collector1: C1,
294 pub collector2: C2,
295 _phantom: PhantomData<M>,
296 _phantom2: PhantomData<P>,
297}
298
299impl<M, P, C1, C2> CombinedCollector<M, P, C1, C2>
300where
301 M: Math,
302 P: Point<M>,
303 C1: Collector<M, P>,
304 C2: Collector<M, P>,
305{
306 pub fn new(collector1: C1, collector2: C2) -> Self {
307 CombinedCollector {
308 collector1,
309 collector2,
310 _phantom: PhantomData,
311 _phantom2: PhantomData,
312 }
313 }
314}
315
316impl<M, P, C1, C2> Collector<M, P> for CombinedCollector<M, P, C1, C2>
317where
318 M: Math,
319 P: Point<M>,
320 C1: Collector<M, P>,
321 C2: Collector<M, P>,
322{
323 fn register_leapfrog(
324 &mut self,
325 math: &mut M,
326 start: &State<M, P>,
327 end: &State<M, P>,
328 divergence_info: Option<&DivergenceInfo>,
329 ) {
330 self.collector1
331 .register_leapfrog(math, start, end, divergence_info);
332 self.collector2
333 .register_leapfrog(math, start, end, divergence_info);
334 }
335
336 fn register_draw(&mut self, math: &mut M, state: &State<M, P>, info: &crate::nuts::SampleInfo) {
337 self.collector1.register_draw(math, state, info);
338 self.collector2.register_draw(math, state, info);
339 }
340
341 fn register_init(
342 &mut self,
343 math: &mut M,
344 state: &State<M, P>,
345 options: &crate::nuts::NutsOptions,
346 ) {
347 self.collector1.register_init(math, state, options);
348 self.collector2.register_init(math, state, options);
349 }
350}
351
#[cfg(test)]
mod test {
    use super::*;
    use crate::math::test_logps::NormalLogp;
    use crate::{
        Chain, DiagAdaptExpSettings,
        chain::{NutsChain, StatOptions},
        dynamics::{
            DivergenceStatsOptions, KineticEnergyKind, TransformedHamiltonian,
            TransformedPointStatsOptions,
        },
        math::CpuMath,
        transform::{DiagAdaptStrategy, DiagMassMatrix},
    };

    /// Build a full NUTS chain driven by `GlobalStrategy` with a diagonal mass matrix,
    /// run 100 tuning draws plus 100 more, and check that the final draw is not
    /// divergent and lies within 5.0 of 30.0 in every coordinate.
    #[test]
    fn instantiate_adaptive_sampler() {
        let ndim = 10;
        let func = NormalLogp::new(ndim, 30.);
        let mut math = CpuMath::new(func);
        let num_tune = 100;
        let options = EuclideanAdaptOptions::<DiagAdaptExpSettings>::default();
        let strategy =
            GlobalStrategy::<_, DiagAdaptStrategy<_>>::new(&mut math, options, num_tune, 0u64);

        let mass_matrix = DiagMassMatrix::new(&mut math, true);

        let hamiltonian: TransformedHamiltonian<_, DiagMassMatrix<CpuMath<NormalLogp>>> =
            TransformedHamiltonian::new(&mut math, mass_matrix, KineticEnergyKind::Euclidean);

        let options = NutsOptions {
            maxdepth: 10u64,
            mindepth: 0,
            check_turning: true,
            store_divergences: false,
            target_integration_time: None,
            extra_doublings: 0,
            max_energy_error: 1000.0,
        };

        let rng = {
            use rand::SeedableRng;
            rand::rngs::StdRng::seed_from_u64(42)
        };
        let chain = 0u64;

        let stats_options = StatOptions {
            adapt: GlobalStrategyStatsOptions {
                step_size: (),
                mass_matrix: (),
            },
            hamiltonian: -1i64,
            point: TransformedPointStatsOptions {
                store_gradient: true,
                store_unconstrained: true,
                store_transformed: false,
            },
            divergence: DivergenceStatsOptions {
                store_divergences: true,
            },
        };

        let mut sampler = NutsChain::new(
            math,
            hamiltonian,
            strategy,
            options,
            rng,
            chain,
            stats_options,
        );
        sampler.set_position(&vec![1.5f64; ndim]).unwrap();
        for _ in 0..200 {
            sampler.draw().unwrap();
        }

        let (last_position, _, _, prog) = sampler.expanded_draw().unwrap();
        for (i, p) in last_position.iter().enumerate() {
            assert!(
                (p - 30.).abs() < 5.0,
                "coordinate {i} = {p} is not within 5.0 of 30.0"
            );
        }
        assert!(!prog.diverging, "the final draw diverged");
    }
}