use std::{fmt::Debug, marker::PhantomData};

use nuts_derive::Storable;
use nuts_storable::{HasDims, Storable};
use rand::Rng;
use serde::Serialize;

use super::stepsize::AcceptanceRateCollector;
use super::stepsize::{StepSizeSettings, Strategy as StepSizeStrategy};
use crate::mass_matrix::MassMatrixAdaptStrategy;
use crate::{
    NutsError,
    chain::AdaptStrategy,
    euclidean_hamiltonian::EuclideanHamiltonian,
    hamiltonian::{DivergenceInfo, Hamiltonian, Point},
    math_base::Math,
    nuts::{Collector, NutsOptions},
    sampler_stats::{SamplerStats, StatsDims},
    state::State,
};

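/// Adaptation strategy for Euclidean HMC that runs step size and mass matrix
/// adaptation side by side. Tuning is split into an early window, a main
/// window, and a final window in which only the step size is still adapted
/// (see [`EuclideanAdaptOptions`]).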
pub struct GlobalStrategy<M: Math, A: MassMatrixAdaptStrategy<M>> {
    step_size: StepSizeStrategy,
    mass_matrix: A,
    options: EuclideanAdaptOptions<A::Options>,
    num_tune: u64,
    // Number of draws in the early adaptation window
    early_end: u64,

    // First draw of the final step-size-only adaptation window
    final_step_size_window: u64,
    tuning: bool,
    has_initial_mass_matrix: bool,
    last_update: u64,
}

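/// Settings that control the adaptation schedule.
///
/// `early_window` and `step_size_window` are fractions of the number of tuning
/// draws: the early window covers the first `early_window * num_tune` draws,
/// and the final step-size-only window covers the last
/// `step_size_window * num_tune` draws. The `*_switch_freq` and
/// `mass_matrix_update_freq` values are measured in draws.
///
/// A minimal sketch of overriding the defaults (assuming the
/// `DiagAdaptExpSettings` mass matrix options used in the tests below):
///
/// ```ignore
/// let mut options = EuclideanAdaptOptions::<DiagAdaptExpSettings>::default();
/// options.early_window = 0.5; // spend half of tuning in the early window
/// options.mass_matrix_update_freq = 5; // refresh the mass matrix every 5 draws
/// ```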
#[derive(Debug, Clone, Copy, Serialize)]
pub struct EuclideanAdaptOptions<S: Debug + Default> {
    pub step_size_settings: StepSizeSettings,
    pub mass_matrix_options: S,
    pub early_window: f64,
    pub step_size_window: f64,
    pub mass_matrix_switch_freq: u64,
    pub early_mass_matrix_switch_freq: u64,
    pub mass_matrix_update_freq: u64,
}

impl<S: Debug + Default> Default for EuclideanAdaptOptions<S> {
    fn default() -> Self {
        Self {
            step_size_settings: StepSizeSettings::default(),
            mass_matrix_options: S::default(),
            early_window: 0.3,
            step_size_window: 0.15,
            mass_matrix_switch_freq: 80,
            early_mass_matrix_switch_freq: 10,
            mass_matrix_update_freq: 1,
        }
    }
}

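// Worked example of the window arithmetic in `new` below: with num_tune = 1000
// and the default options, early_end = (0.3 * 1000) = 300, and the final
// step-size-only window spans the last (0.15 * 1000) = 150 draws, so it starts
// at draw 1000 - 150 = 850.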
impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy<M, A> {
    type Hamiltonian = EuclideanHamiltonian<M, A::MassMatrix>;
    type Collector = CombinedCollector<
        M,
        <Self::Hamiltonian as Hamiltonian<M>>::Point,
        AcceptanceRateCollector,
        A::Collector,
    >;
    type Options = EuclideanAdaptOptions<A::Options>;

    fn new(math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self {
        let num_tune_f = num_tune as f64;
        let step_size_window = (options.step_size_window * num_tune_f) as u64;
        let early_end = (options.early_window * num_tune_f) as u64;
        let final_step_size_window = num_tune.saturating_sub(step_size_window);

        assert!(early_end < num_tune);

        Self {
            step_size: StepSizeStrategy::new(options.step_size_settings),
            mass_matrix: A::new(math, options.mass_matrix_options, num_tune, chain),
            options,
            num_tune,
            early_end,
            final_step_size_window,
            tuning: true,
            has_initial_mass_matrix: true,
            last_update: 0,
        }
    }

    fn init<R: Rng + ?Sized>(
        &mut self,
        math: &mut M,
        options: &mut NutsOptions,
        hamiltonian: &mut Self::Hamiltonian,
        position: &[f64],
        rng: &mut R,
    ) -> Result<(), NutsError> {
        let state = hamiltonian.init_state(math, position)?;
        self.mass_matrix.init(
            math,
            options,
            &mut hamiltonian.mass_matrix,
            state.point(),
            rng,
        )?;
        self.step_size
            .init(math, options, hamiltonian, position, rng)?;
        Ok(())
    }

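    /// Run one adaptation step after a draw. Before `final_step_size_window`
    /// both the mass matrix and the step size are adapted; after it only the
    /// step size is updated, and once `draw >= num_tune` tuning is finished.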
    fn adapt<R: Rng + ?Sized>(
        &mut self,
        math: &mut M,
        options: &mut NutsOptions,
        hamiltonian: &mut Self::Hamiltonian,
        draw: u64,
        collector: &Self::Collector,
        state: &State<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>,
        rng: &mut R,
    ) -> Result<(), NutsError> {
        self.step_size.update(&collector.collector1);

        if draw >= self.num_tune {
            self.step_size.update_stepsize(rng, hamiltonian, true);
            self.tuning = false;
            return Ok(());
        }

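        // Mass matrix window: feed this draw into the running estimators and
        // periodically switch to the fresher background estimator.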
        if draw < self.final_step_size_window {
            let is_early = draw < self.early_end;
            let switch_freq = if is_early {
                self.options.early_mass_matrix_switch_freq
            } else {
                self.options.mass_matrix_switch_freq
            };

            self.mass_matrix
                .update_estimators(math, &collector.collector2);
            let could_switch = self.mass_matrix.background_count() >= switch_freq;
            // Only switch if another full switch interval still fits before
            // the final step-size window begins.
            let is_late = switch_freq + draw > self.final_step_size_window;

            let mut force_update = false;
            if could_switch && !is_late {
                self.mass_matrix.switch(math);
                force_update = true;
            }

            let did_change = if force_update
                || (draw - self.last_update >= self.options.mass_matrix_update_freq)
            {
                self.mass_matrix.adapt(math, &mut hamiltonian.mass_matrix)
            } else {
                false
            };

            if did_change {
                self.last_update = draw;
            }

            if is_late {
                self.step_size.update_estimator_late();
            } else {
                self.step_size.update_estimator_early();
            }

            if did_change && self.has_initial_mass_matrix {
                // The mass matrix changed for the first time, so reinitialize
                // the step size from scratch.
                self.has_initial_mass_matrix = false;
                let position = math.box_array(state.point().position());
                self.step_size
                    .init(math, options, hamiltonian, &position, rng)?;
            } else {
                self.step_size.update_stepsize(rng, hamiltonian, false)
            }
            return Ok(());
        }

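        // Final window: the mass matrix is frozen, only the step size is
        // still adapted.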
        self.step_size.update_estimator_late();
        let is_last = draw == self.num_tune - 1;
        self.step_size.update_stepsize(rng, hamiltonian, is_last);
        Ok(())
    }

    fn new_collector(&self, math: &mut M) -> Self::Collector {
        Self::Collector::new(
            self.step_size.new_collector(),
            self.mass_matrix.new_collector(math),
        )
    }

    fn is_tuning(&self) -> bool {
        self.tuning
    }

    fn last_num_steps(&self) -> u64 {
        self.step_size.last_n_steps
    }
}

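/// Statistics reported by [`GlobalStrategy`]: the step size and mass matrix
/// statistics are flattened into one record, plus a `tuning` flag.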
#[derive(Debug, Storable)]
pub struct GlobalStrategyStats<P: HasDims, S: Storable<P>, M: Storable<P>> {
    #[storable(flatten)]
    pub step_size: S,
    #[storable(flatten)]
    pub mass_matrix: M,
    pub tuning: bool,
    #[storable(ignore)]
    _phantom: std::marker::PhantomData<fn() -> P>,
}

#[derive(Debug)]
pub struct GlobalStrategyStatsOptions<M: Math, A: MassMatrixAdaptStrategy<M>> {
    pub step_size: (),
    pub mass_matrix: A::StatsOptions,
}

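// Implemented by hand instead of derived so that `M` and `A` themselves do
// not need to implement `Clone`/`Copy`.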
impl<M: Math, A: MassMatrixAdaptStrategy<M>> Clone for GlobalStrategyStatsOptions<M, A> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<M: Math, A: MassMatrixAdaptStrategy<M>> Copy for GlobalStrategyStatsOptions<M, A> {}

impl<M: Math, A> SamplerStats<M> for GlobalStrategy<M, A>
where
    A: MassMatrixAdaptStrategy<M>,
{
    type Stats =
        GlobalStrategyStats<StatsDims, <StepSizeStrategy as SamplerStats<M>>::Stats, A::Stats>;
    type StatsOptions = GlobalStrategyStatsOptions<M, A>;

    fn extract_stats(&self, math: &mut M, opt: Self::StatsOptions) -> Self::Stats {
        GlobalStrategyStats {
            step_size: self.step_size.extract_stats(math, opt.step_size),
            mass_matrix: self.mass_matrix.extract_stats(math, opt.mass_matrix),
            tuning: self.tuning,
            _phantom: PhantomData,
        }
    }
}

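/// Collector that forwards every callback to two inner collectors, so that
/// the step size strategy and the mass matrix strategy can both observe the
/// same leapfrog steps and draws.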
pub struct CombinedCollector<M, P, C1, C2>
where
    M: Math,
    P: Point<M>,
    C1: Collector<M, P>,
    C2: Collector<M, P>,
{
    pub collector1: C1,
    pub collector2: C2,
    _phantom: PhantomData<M>,
    _phantom2: PhantomData<P>,
}

impl<M, P, C1, C2> CombinedCollector<M, P, C1, C2>
where
    M: Math,
    P: Point<M>,
    C1: Collector<M, P>,
    C2: Collector<M, P>,
{
    pub fn new(collector1: C1, collector2: C2) -> Self {
        CombinedCollector {
            collector1,
            collector2,
            _phantom: PhantomData,
            _phantom2: PhantomData,
        }
    }
}

impl<M, P, C1, C2> Collector<M, P> for CombinedCollector<M, P, C1, C2>
where
    M: Math,
    P: Point<M>,
    C1: Collector<M, P>,
    C2: Collector<M, P>,
{
    fn register_leapfrog(
        &mut self,
        math: &mut M,
        start: &State<M, P>,
        end: &State<M, P>,
        divergence_info: Option<&DivergenceInfo>,
    ) {
        self.collector1
            .register_leapfrog(math, start, end, divergence_info);
        self.collector2
            .register_leapfrog(math, start, end, divergence_info);
    }

    fn register_draw(&mut self, math: &mut M, state: &State<M, P>, info: &crate::nuts::SampleInfo) {
        self.collector1.register_draw(math, state, info);
        self.collector2.register_draw(math, state, info);
    }

    fn register_init(
        &mut self,
        math: &mut M,
        state: &State<M, P>,
        options: &crate::nuts::NutsOptions,
    ) {
        self.collector1.register_init(math, state, options);
        self.collector2.register_init(math, state, options);
    }
}

#[cfg(test)]
pub mod test_logps {
    use std::collections::HashMap;

    use crate::{cpu_math::CpuLogpFunc, math_base::LogpError};
    use nuts_storable::HasDims;
    use thiserror::Error;

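    /// Isotropic normal test density with mean `mu` in every dimension:
    /// `logp(x) = -0.5 * Σ (x_i - mu)^2`, with gradient `-(x - mu)`.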
    #[derive(Clone, Debug)]
    pub struct NormalLogp {
        dim: usize,
        mu: f64,
    }

    impl NormalLogp {
        pub(crate) fn new(dim: usize, mu: f64) -> NormalLogp {
            NormalLogp { dim, mu }
        }
    }

    #[derive(Error, Debug)]
    pub enum NormalLogpError {}

    impl LogpError for NormalLogpError {
        fn is_recoverable(&self) -> bool {
            false
        }
    }

    impl HasDims for NormalLogp {
        fn dim_sizes(&self) -> HashMap<String, u64> {
            HashMap::from([("unconstrained_parameter".to_string(), self.dim as u64)])
        }
    }

    impl CpuLogpFunc for NormalLogp {
        type LogpError = NormalLogpError;
        type FlowParameters = ();
        type ExpandedVector = Vec<f64>;

        fn dim(&self) -> usize {
            self.dim
        }

        fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, NormalLogpError> {
            let n = position.len();
            assert!(gradient.len() == n);

            let mut logp = 0f64;
            for (p, g) in position.iter().zip(gradient.iter_mut()) {
                let val = *p - self.mu;
                logp -= val * val / 2.;
                *g = -val;
            }
            Ok(logp)
        }

        fn expand_vector<R>(
            &mut self,
            _rng: &mut R,
            array: &[f64],
        ) -> Result<Self::ExpandedVector, crate::cpu_math::CpuMathError>
        where
            R: rand::Rng + ?Sized,
        {
            Ok(array.to_vec())
        }

        fn inv_transform_normalize(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &[f64],
            _untransformed_gradient: &[f64],
            _transformed_position: &mut [f64],
            _transformed_gradient: &mut [f64],
        ) -> Result<f64, Self::LogpError> {
            unimplemented!()
        }

        fn init_from_transformed_position(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &mut [f64],
            _untransformed_gradient: &mut [f64],
            _transformed_position: &[f64],
            _transformed_gradient: &mut [f64],
        ) -> Result<(f64, f64), Self::LogpError> {
            unimplemented!()
        }

        fn init_from_untransformed_position(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &[f64],
            _untransformed_gradient: &mut [f64],
            _transformed_position: &mut [f64],
            _transformed_gradient: &mut [f64],
        ) -> Result<(f64, f64), Self::LogpError> {
            unimplemented!()
        }

        fn update_transformation<'a, R: rand::Rng + ?Sized>(
            &'a mut self,
            _rng: &mut R,
            _untransformed_positions: impl Iterator<Item = &'a [f64]>,
            _untransformed_gradients: impl Iterator<Item = &'a [f64]>,
            _untransformed_logp: impl Iterator<Item = &'a f64>,
            _params: &'a mut Self::FlowParameters,
        ) -> Result<(), Self::LogpError> {
            unimplemented!()
        }

        fn new_transformation<R: rand::Rng + ?Sized>(
            &mut self,
            _rng: &mut R,
            _untransformed_position: &[f64],
            _untransformed_gradient: &[f64],
            _chain: u64,
        ) -> Result<Self::FlowParameters, Self::LogpError> {
            unimplemented!()
        }

        fn transformation_id(
            &self,
            _params: &Self::FlowParameters,
        ) -> Result<i64, Self::LogpError> {
            unimplemented!()
        }
    }
}

#[cfg(test)]
mod test {
    use super::test_logps::NormalLogp;
    use super::*;
    use crate::{
        Chain, DiagAdaptExpSettings,
        chain::{NutsChain, StatOptions},
        cpu_math::CpuMath,
        euclidean_hamiltonian::EuclideanHamiltonian,
        mass_matrix::DiagMassMatrix,
    };

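    // Smoke test: assemble a chain with the global adaptation strategy and
    // check that tuning and sampling run without errors.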
    #[test]
    fn instantiate_adaptive_sampler() {
        use crate::mass_matrix::Strategy;

        let ndim = 10;
        let func = NormalLogp::new(ndim, 3.);
        let mut math = CpuMath::new(func);
        let num_tune = 100;
        let options = EuclideanAdaptOptions::<DiagAdaptExpSettings>::default();
        let strategy = GlobalStrategy::<_, Strategy<_>>::new(&mut math, options, num_tune, 0u64);

        let mass_matrix = DiagMassMatrix::new(&mut math, true);
        let max_energy_error = 1000f64;
        let step_size = 0.1f64;

        let hamiltonian =
            EuclideanHamiltonian::new(&mut math, mass_matrix, max_energy_error, step_size);
        let options = NutsOptions {
            maxdepth: 10u64,
            mindepth: 0,
            store_gradient: true,
            store_unconstrained: true,
            check_turning: true,
            store_divergences: false,
        };

        let rng = {
            use rand::SeedableRng;
            rand::rngs::StdRng::seed_from_u64(42)
        };
        let chain = 0u64;

        let stats_options = StatOptions {
            adapt: GlobalStrategyStatsOptions {
                step_size: (),
                mass_matrix: (),
            },
            hamiltonian: (),
            point: (),
        };

        let mut sampler = NutsChain::new(
            math,
            hamiltonian,
            strategy,
            options,
            rng,
            chain,
            stats_options,
        );
        sampler.set_position(&vec![1.5f64; ndim]).unwrap();
        for _ in 0..200 {
            sampler.draw().unwrap();
        }
    }
}