1use std::{fmt::Debug, marker::PhantomData};
4
5use nuts_derive::Storable;
6use nuts_storable::{HasDims, Storable};
7use rand::Rng;
8use serde::{Deserialize, Serialize};
9
10use super::stepsize::AcceptanceRateCollector;
11use super::stepsize::{StepSizeSettings, Strategy as StepSizeStrategy};
12use crate::dynamics::{
13 DivergenceInfo, Hamiltonian, Point, State, TransformedHamiltonian, TransformedPoint,
14};
15use crate::transform::MassMatrixAdaptStrategy;
16use crate::{
17 NutsError,
18 chain::AdaptStrategy,
19 math::Math,
20 nuts::{Collector, NutsOptions},
21 sampler_stats::{SamplerStats, StatsDims},
22};
23
/// Warmup adaptation that jointly tunes the step size and the mass matrix.
///
/// Tuning is split into an early window (frequent mass matrix switches),
/// a middle phase where the switch window grows geometrically, and a final
/// window in which only the step size is adapted.
pub struct GlobalStrategy<M: Math, A: MassMatrixAdaptStrategy<M>> {
    /// Step size adaptation sub-strategy.
    step_size: StepSizeStrategy,
    /// Mass matrix adaptation sub-strategy.
    mass_matrix_adapt: A,
    options: EuclideanAdaptOptions<A::Options>,
    /// Total number of tuning draws.
    num_tune: u64,
    /// Draw index at which the early window ends.
    early_end: u64,

    /// Draw index where the final, step-size-only window begins.
    final_step_size_window: u64,
    /// True while the sampler is still tuning.
    tuning: bool,
    /// True until the initial mass matrix has been replaced for the
    /// first time (triggers a step size re-initialization).
    has_initial_mass_matrix: bool,
    /// Draw index of the most recent mass matrix update.
    last_update: u64,
    /// Current length (in draws) of the mass matrix switch window.
    current_window_size: u64,
}
40
/// Settings controlling the combined step size and mass matrix adaptation.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct EuclideanAdaptOptions<S: Debug + Default> {
    /// Settings for the step size adaptation.
    pub step_size_settings: StepSizeSettings,
    /// Settings forwarded to the mass matrix adaptation strategy.
    pub mass_matrix_options: S,
    /// Fraction of the tuning draws that make up the early window.
    pub early_window: f64,
    /// Fraction of the tuning draws reserved at the end for
    /// step-size-only adaptation.
    pub step_size_window: f64,
    /// Initial number of draws between mass matrix switches after the
    /// early window.
    pub mass_matrix_switch_freq: u64,
    /// Number of draws between mass matrix switches during the early window.
    pub early_mass_matrix_switch_freq: u64,
    /// Minimum number of draws between mass matrix updates.
    pub mass_matrix_update_freq: u64,
    /// Multiplicative growth factor for the switch window (must be >= 1).
    pub mass_matrix_window_growth: f64,
}
55
56impl<S: Debug + Default> Default for EuclideanAdaptOptions<S> {
57 fn default() -> Self {
58 Self {
59 step_size_settings: StepSizeSettings::default(),
60 mass_matrix_options: S::default(),
61 early_window: 0.3,
62 step_size_window: 0.15,
63 mass_matrix_switch_freq: 80,
64 early_mass_matrix_switch_freq: 10,
65 mass_matrix_update_freq: 1,
66 mass_matrix_window_growth: 1.5,
67 }
68 }
69}
70
71impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy<M, A> {
72 type Hamiltonian = TransformedHamiltonian<M, A::Transformation>;
73 type Collector =
74 CombinedCollector<M, TransformedPoint<M>, AcceptanceRateCollector, A::Collector>;
75 type Options = EuclideanAdaptOptions<A::Options>;
76
77 fn new(math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self {
78 let num_tune_f = num_tune as f64;
79 let step_size_window = (options.step_size_window * num_tune_f) as u64;
80 let early_end = (options.early_window * num_tune_f) as u64;
81 let final_second_step_size = num_tune.saturating_sub(step_size_window);
82
83 assert!(early_end < num_tune);
84 assert!(options.mass_matrix_window_growth >= 1.0);
85
86 Self {
87 step_size: StepSizeStrategy::new(options.step_size_settings),
88 mass_matrix_adapt: A::new(math, options.mass_matrix_options, num_tune, chain),
89 options,
90 num_tune,
91 early_end,
92 final_step_size_window: final_second_step_size,
93 tuning: true,
94 has_initial_mass_matrix: true,
95 last_update: 0,
96 current_window_size: options.mass_matrix_switch_freq,
97 }
98 }
99
100 fn init<R: Rng + ?Sized>(
101 &mut self,
102 math: &mut M,
103 options: &mut NutsOptions,
104 hamiltonian: &mut Self::Hamiltonian,
105 position: &[f64],
106 rng: &mut R,
107 ) -> Result<(), NutsError> {
108 let state = hamiltonian.init_state_untransformed(math, position)?;
109 self.mass_matrix_adapt.init(
110 math,
111 options,
112 hamiltonian.transformation_mut(),
113 state.point(),
114 rng,
115 )?;
116 self.step_size
117 .init(math, options, hamiltonian, position, rng)?;
118 Ok(())
119 }
120
121 fn adapt<R: Rng + ?Sized>(
122 &mut self,
123 math: &mut M,
124 options: &mut NutsOptions,
125 hamiltonian: &mut Self::Hamiltonian,
126 draw: u64,
127 collector: &Self::Collector,
128 state: &State<M, TransformedPoint<M>>,
129 rng: &mut R,
130 ) -> Result<(), NutsError> {
131 self.step_size.update(&collector.collector1);
132
133 if draw >= self.num_tune {
134 self.step_size.update_stepsize(rng, hamiltonian, true);
136 self.tuning = false;
137 return Ok(());
138 }
139
140 if draw < self.final_step_size_window {
141 let is_early = draw < self.early_end;
142
143 if !is_early && draw == self.early_end {
147 self.current_window_size = self
148 .current_window_size
149 .max(self.mass_matrix_adapt.background_count());
150 }
151
152 let switch_freq = if is_early {
153 self.options.early_mass_matrix_switch_freq
154 } else {
155 self.current_window_size
156 };
157
158 self.mass_matrix_adapt
159 .update_estimators(math, &collector.collector2);
160 let could_switch = self.mass_matrix_adapt.background_count() >= switch_freq;
164 let next_window_size = if is_early {
168 self.options.early_mass_matrix_switch_freq
169 } else {
170 (self.current_window_size + 1).max(
171 (self.current_window_size as f64 * self.options.mass_matrix_window_growth)
172 .round() as u64,
173 )
174 };
175 let is_late = next_window_size + draw > self.final_step_size_window;
176
177 let mut force_update = false;
178 if could_switch && (!is_late) {
179 self.mass_matrix_adapt.switch(math);
180 force_update = true;
181 if !is_early {
183 self.current_window_size = next_window_size;
184 }
185 }
186
187 let did_change = if force_update
188 | (draw - self.last_update >= self.options.mass_matrix_update_freq)
189 {
190 self.mass_matrix_adapt
191 .adapt(math, hamiltonian.transformation_mut())
192 } else {
193 false
194 };
195
196 if did_change {
197 self.last_update = draw;
198 }
199
200 if is_late {
201 self.step_size.update_estimator_late();
202 } else {
203 self.step_size.update_estimator_early();
204 }
205
206 if did_change & self.has_initial_mass_matrix {
208 self.has_initial_mass_matrix = false;
209 let position = math.box_array(state.point().position());
210 self.step_size
211 .init(math, options, hamiltonian, &position, rng)?;
212 } else {
213 self.step_size.update_stepsize(rng, hamiltonian, false)
214 }
215 return Ok(());
216 }
217
218 self.step_size.update_estimator_late();
219 let is_last = draw == self.num_tune - 1;
220 self.step_size.update_stepsize(rng, hamiltonian, is_last);
221 Ok(())
222 }
223
224 fn new_collector(&self, math: &mut M) -> Self::Collector {
225 Self::Collector::new(
226 self.step_size.new_collector(),
227 self.mass_matrix_adapt.new_collector(math),
228 )
229 }
230
231 fn is_tuning(&self) -> bool {
232 self.tuning
233 }
234
235 fn last_num_steps(&self) -> u64 {
236 self.step_size.last_n_steps
237 }
238}
239
/// Sampler statistics reported by the combined adaptation strategy.
#[derive(Debug, Storable)]
pub struct GlobalStrategyStats<P: HasDims, S: Storable<P>, M: Storable<P>> {
    /// Step size adaptation stats, flattened into this struct's fields.
    #[storable(flatten)]
    pub step_size: S,
    /// Mass matrix adaptation stats, flattened into this struct's fields.
    #[storable(flatten)]
    pub mass_matrix: M,
    /// Whether the sampler was still tuning when the stats were extracted.
    pub tuning: bool,
    // Ties the dims type `P` to the struct without storing any data;
    // `fn() -> P` keeps the struct Send/Sync regardless of `P`.
    #[storable(ignore)]
    _phantom: std::marker::PhantomData<fn() -> P>,
}
250
/// Options controlling which statistics the global strategy extracts.
#[derive(Debug)]
pub struct GlobalStrategyStatsOptions<M: Math, A: MassMatrixAdaptStrategy<M>> {
    /// The step size strategy takes no stats options.
    pub step_size: (),
    /// Stats options forwarded to the mass matrix strategy.
    pub mass_matrix: A::StatsOptions,
}
256
// Clone and Copy are implemented manually: a derive would require
// `M: Clone`/`A: Clone` bounds even though neither type parameter is
// stored, only `A::StatsOptions`.
impl<M: Math, A: MassMatrixAdaptStrategy<M>> Clone for GlobalStrategyStatsOptions<M, A> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<M: Math, A: MassMatrixAdaptStrategy<M>> Copy for GlobalStrategyStatsOptions<M, A> {}
264
impl<M: Math, A> SamplerStats<M> for GlobalStrategy<M, A>
where
    A: MassMatrixAdaptStrategy<M>,
{
    type Stats =
        GlobalStrategyStats<StatsDims, <StepSizeStrategy as SamplerStats<M>>::Stats, A::Stats>;
    type StatsOptions = GlobalStrategyStatsOptions<M, A>;

    /// Collect the current statistics from both sub-strategies.
    fn extract_stats(&self, math: &mut M, opt: Self::StatsOptions) -> Self::Stats {
        GlobalStrategyStats {
            step_size: {
                // The step size strategy takes no options; binding the
                // unit value makes that explicit at the type level.
                let _: () = opt.step_size;
                self.step_size.extract_stats(math, ())
            },
            mass_matrix: self.mass_matrix_adapt.extract_stats(math, opt.mass_matrix),
            tuning: self.tuning,
            _phantom: PhantomData,
        }
    }
}
285
/// A collector that forwards every sampler event to two inner collectors.
pub struct CombinedCollector<M, P, C1, C2>
where
    M: Math,
    P: Point<M>,
    C1: Collector<M, P>,
    C2: Collector<M, P>,
{
    pub collector1: C1,
    pub collector2: C2,
    // Phantom fields tie the otherwise unused `M` and `P` parameters
    // to the struct so the trait bounds above are expressible.
    _phantom: PhantomData<M>,
    _phantom2: PhantomData<P>,
}
298
299impl<M, P, C1, C2> CombinedCollector<M, P, C1, C2>
300where
301 M: Math,
302 P: Point<M>,
303 C1: Collector<M, P>,
304 C2: Collector<M, P>,
305{
306 pub fn new(collector1: C1, collector2: C2) -> Self {
307 CombinedCollector {
308 collector1,
309 collector2,
310 _phantom: PhantomData,
311 _phantom2: PhantomData,
312 }
313 }
314}
315
// Forwards every callback to both inner collectors, always calling
// `collector1` before `collector2`.
impl<M, P, C1, C2> Collector<M, P> for CombinedCollector<M, P, C1, C2>
where
    M: Math,
    P: Point<M>,
    C1: Collector<M, P>,
    C2: Collector<M, P>,
{
    fn register_leapfrog(
        &mut self,
        math: &mut M,
        start: &State<M, P>,
        end: &State<M, P>,
        divergence_info: Option<&DivergenceInfo>,
    ) {
        self.collector1
            .register_leapfrog(math, start, end, divergence_info);
        self.collector2
            .register_leapfrog(math, start, end, divergence_info);
    }

    fn register_draw(&mut self, math: &mut M, state: &State<M, P>, info: &crate::nuts::SampleInfo) {
        self.collector1.register_draw(math, state, info);
        self.collector2.register_draw(math, state, info);
    }

    fn register_init(
        &mut self,
        math: &mut M,
        state: &State<M, P>,
        options: &crate::nuts::NutsOptions,
    ) {
        self.collector1.register_init(math, state, options);
        self.collector2.register_init(math, state, options);
    }
}
351
#[cfg(test)]
pub mod test_logps {
    use std::collections::HashMap;

    use crate::math::{CpuLogpFunc, LogpError};
    use nuts_storable::HasDims;
    use thiserror::Error;

    /// Test density: independent unit-variance normals centered at `mu`.
    #[derive(Clone, Debug)]
    pub struct NormalLogp {
        dim: usize,
        mu: f64,
    }

    impl NormalLogp {
        pub(crate) fn new(dim: usize, mu: f64) -> NormalLogp {
            NormalLogp { dim, mu }
        }
    }

    /// The normal logp never fails, so this error enum has no variants.
    #[derive(Error, Debug)]
    pub enum NormalLogpError {}

    impl LogpError for NormalLogpError {
        fn is_recoverable(&self) -> bool {
            false
        }
    }

    impl HasDims for NormalLogp {
        fn dim_sizes(&self) -> HashMap<String, u64> {
            HashMap::from([("unconstrained_parameter".to_string(), self.dim as u64)])
        }
    }

    impl CpuLogpFunc for NormalLogp {
        type LogpError = NormalLogpError;
        type FlowParameters = ();
        type ExpandedVector = Vec<f64>;

        fn dim(&self) -> usize {
            self.dim
        }

        /// Log-density (up to an additive constant) of independent
        /// normals with mean `mu`; the gradient is written into
        /// `gradient`, which must have the same length as `position`.
        fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, NormalLogpError> {
            assert!(gradient.len() == position.len());

            let mut logp = 0f64;
            for (&p, g) in position.iter().zip(gradient.iter_mut()) {
                let centered = p - self.mu;
                logp -= centered * centered / 2.;
                *g = -centered;
            }
            Ok(logp)
        }

        fn expand_vector<R>(
            &mut self,
            _rng: &mut R,
            array: &[f64],
        ) -> Result<Self::ExpandedVector, crate::math::CpuMathError>
        where
            R: rand::Rng + ?Sized,
        {
            Ok(Vec::from(array))
        }
    }
}
422
#[cfg(test)]
mod test {
    use super::test_logps::NormalLogp;
    use super::*;
    use crate::{
        Chain, DiagAdaptExpSettings,
        chain::{NutsChain, StatOptions},
        dynamics::{
            DivergenceStatsOptions, KineticEnergyKind, TransformedHamiltonian,
            TransformedPointStatsOptions,
        },
        math::CpuMath,
        transform::{DiagAdaptStrategy, DiagMassMatrix},
    };

    /// Smoke test: run a full adaptive NUTS chain on a shifted-normal
    /// target and check that draws end up near the true mean without
    /// diverging.
    #[test]
    fn instanciate_adaptive_sampler() {
        let ndim = 10;
        // Independent normals with mean 30 in every dimension.
        let func = NormalLogp::new(ndim, 30.);
        let mut math = CpuMath::new(func);
        let num_tune = 100;
        let options = EuclideanAdaptOptions::<DiagAdaptExpSettings>::default();
        let strategy =
            GlobalStrategy::<_, DiagAdaptStrategy<_>>::new(&mut math, options, num_tune, 0u64);

        let mass_matrix = DiagMassMatrix::new(&mut math, true);

        let hamiltonian: TransformedHamiltonian<_, DiagMassMatrix<CpuMath<NormalLogp>>> =
            TransformedHamiltonian::new(&mut math, mass_matrix, KineticEnergyKind::Euclidean);

        let options = NutsOptions {
            maxdepth: 10u64,
            mindepth: 0,
            check_turning: true,
            store_divergences: false,
            target_integration_time: None,
            extra_doublings: 0,
            max_energy_error: 1000.0,
        };

        // Fixed seed so the test is deterministic.
        let rng = {
            use rand::SeedableRng;
            rand::rngs::StdRng::seed_from_u64(42)
        };
        let chain = 0u64;

        let stats_options = StatOptions {
            adapt: GlobalStrategyStatsOptions {
                step_size: (),
                mass_matrix: (),
            },
            hamiltonian: -1i64,
            point: TransformedPointStatsOptions {
                store_gradient: true,
                store_unconstrained: true,
                store_transformed: false,
            },
            divergence: DivergenceStatsOptions {
                store_divergences: true,
            },
        };

        let mut sampler = NutsChain::new(
            math,
            hamiltonian,
            strategy,
            options,
            rng,
            chain,
            stats_options,
        );
        // Start far from the mode so adaptation actually has work to do.
        sampler.set_position(&vec![1.5f64; ndim]).unwrap();
        // 100 tuning draws plus 100 post-tuning draws.
        for _ in 0..200 {
            sampler.draw().unwrap();
        }

        let (last_position, _, _, prog) = sampler.expanded_draw().unwrap();
        dbg!(&last_position);
        // After warmup the chain should sit near the true mean of 30.
        for p in last_position {
            assert!((p - 30.).abs() < 5.0);
        }
        assert!(!prog.diverging);
    }
}
507}