// nuts_rs/adapt_strategy.rs
//! Orchestrate the tuning schedule that jointly adapts step size and mass matrix during warmup.

use std::{fmt::Debug, marker::PhantomData};

use nuts_derive::Storable;
use nuts_storable::{HasDims, Storable};
use rand::Rng;
use serde::{Deserialize, Serialize};

use super::stepsize::{AcceptanceRateCollector, StepSizeSettings, Strategy as StepSizeStrategy};
use crate::dynamics::{
    DivergenceInfo, Hamiltonian, Point, State, TransformedHamiltonian, TransformedPoint,
};
use crate::transform::MassMatrixAdaptStrategy;
use crate::{
    NutsError,
    chain::AdaptStrategy,
    math::Math,
    nuts::{Collector, NutsOptions},
    sampler_stats::{SamplerStats, StatsDims},
};
23
/// Warmup strategy that jointly adapts the integrator step size and the mass
/// matrix, following a windowed schedule over the tuning draws.
pub struct GlobalStrategy<M: Math, A: MassMatrixAdaptStrategy<M>> {
    // Step size adaptation strategy.
    step_size: StepSizeStrategy,
    // Nested mass-matrix adaptation strategy.
    mass_matrix_adapt: A,
    // Schedule configuration; fractions here are resolved to draw counts in `new`.
    options: EuclideanAdaptOptions<A::Options>,
    // Total number of tuning draws.
    num_tune: u64,
    // The number of draws in the early window
    early_end: u64,

    // The first draw number for the final step size adaptation window
    final_step_size_window: u64,
    // True while warmup is still running; cleared once `draw >= num_tune`.
    tuning: bool,
    // True until the mass matrix has been changed for the first time.
    has_initial_mass_matrix: bool,
    // Draw index at which the mass matrix estimate last changed.
    last_update: u64,
    // Current target window size for the main (non-early) phase; grows after each switch.
    current_window_size: u64,
}
40
/// Settings controlling the joint step-size / mass-matrix warmup schedule.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct EuclideanAdaptOptions<S: Debug + Default> {
    /// Settings forwarded to the step size adaptation strategy.
    pub step_size_settings: StepSizeSettings,
    /// Settings forwarded to the mass matrix adaptation strategy.
    pub mass_matrix_options: S,
    /// Fraction of the tuning draws that make up the early adaptation phase.
    pub early_window: f64,
    /// Fraction of the tuning draws reserved at the end for step-size-only adaptation.
    pub step_size_window: f64,
    /// Initial window size for the main (non-early) mass-matrix adaptation phase.
    pub mass_matrix_switch_freq: u64,
    /// Window size (in draws) between mass-matrix switches during the early phase.
    pub early_mass_matrix_switch_freq: u64,
    /// Minimum number of draws between incremental mass-matrix updates.
    pub mass_matrix_update_freq: u64,
    /// Multiplicative growth factor applied to the window size after each switch in the
    /// main phase. 1.0 means constant windows (old behaviour). Must be >= 1.0.
    pub mass_matrix_window_growth: f64,
}
55
impl<S: Debug + Default> Default for EuclideanAdaptOptions<S> {
    /// Defaults: 30% early phase, final 15% of draws for step size only,
    /// main-phase windows starting at 80 draws and growing 1.5x per switch.
    fn default() -> Self {
        Self {
            step_size_settings: StepSizeSettings::default(),
            mass_matrix_options: S::default(),
            early_window: 0.3,
            step_size_window: 0.15,
            mass_matrix_switch_freq: 80,
            early_mass_matrix_switch_freq: 10,
            mass_matrix_update_freq: 1,
            mass_matrix_window_growth: 1.5,
        }
    }
}
70
71impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy<M, A> {
72    type Hamiltonian = TransformedHamiltonian<M, A::Transformation>;
73    type Collector =
74        CombinedCollector<M, TransformedPoint<M>, AcceptanceRateCollector, A::Collector>;
75    type Options = EuclideanAdaptOptions<A::Options>;
76
77    fn new(math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self {
78        let num_tune_f = num_tune as f64;
79        let step_size_window = (options.step_size_window * num_tune_f) as u64;
80        let early_end = (options.early_window * num_tune_f) as u64;
81        let final_second_step_size = num_tune.saturating_sub(step_size_window);
82
83        assert!(early_end < num_tune);
84        assert!(options.mass_matrix_window_growth >= 1.0);
85
86        Self {
87            step_size: StepSizeStrategy::new(options.step_size_settings),
88            mass_matrix_adapt: A::new(math, options.mass_matrix_options, num_tune, chain),
89            options,
90            num_tune,
91            early_end,
92            final_step_size_window: final_second_step_size,
93            tuning: true,
94            has_initial_mass_matrix: true,
95            last_update: 0,
96            current_window_size: options.mass_matrix_switch_freq,
97        }
98    }
99
100    fn init<R: Rng + ?Sized>(
101        &mut self,
102        math: &mut M,
103        options: &mut NutsOptions,
104        hamiltonian: &mut Self::Hamiltonian,
105        position: &[f64],
106        rng: &mut R,
107    ) -> Result<(), NutsError> {
108        let state = hamiltonian.init_state_untransformed(math, position)?;
109        self.mass_matrix_adapt.init(
110            math,
111            options,
112            hamiltonian.transformation_mut(),
113            state.point(),
114            rng,
115        )?;
116        self.step_size
117            .init(math, options, hamiltonian, position, rng)?;
118        Ok(())
119    }
120
121    fn adapt<R: Rng + ?Sized>(
122        &mut self,
123        math: &mut M,
124        options: &mut NutsOptions,
125        hamiltonian: &mut Self::Hamiltonian,
126        draw: u64,
127        collector: &Self::Collector,
128        state: &State<M, TransformedPoint<M>>,
129        rng: &mut R,
130    ) -> Result<(), NutsError> {
131        self.step_size.update(&collector.collector1);
132
133        if draw >= self.num_tune {
134            // Needed for step size jitter
135            self.step_size.update_stepsize(rng, hamiltonian, true);
136            self.tuning = false;
137            return Ok(());
138        }
139
140        if draw < self.final_step_size_window {
141            let is_early = draw < self.early_end;
142
143            // At the transition from early to main phase, seed current_window_size as the
144            // maximum of the configured initial size and the background count already
145            // accumulated, so we never shrink the window.
146            if !is_early && draw == self.early_end {
147                self.current_window_size = self
148                    .current_window_size
149                    .max(self.mass_matrix_adapt.background_count());
150            }
151
152            let switch_freq = if is_early {
153                self.options.early_mass_matrix_switch_freq
154            } else {
155                self.current_window_size
156            };
157
158            self.mass_matrix_adapt
159                .update_estimators(math, &collector.collector2);
160            // We only switch if we have switch_freq draws in the background estimate,
161            // and if the number of remaining mass matrix steps is larger than
162            // the switch frequency.
163            let could_switch = self.mass_matrix_adapt.background_count() >= switch_freq;
164            // For the main phase: after switching, the *next* window will be larger, so
165            // is_late must look ahead using that next size to decide whether there is
166            // still room for another full window before the step-size window.
167            let next_window_size = if is_early {
168                self.options.early_mass_matrix_switch_freq
169            } else {
170                (self.current_window_size + 1).max(
171                    (self.current_window_size as f64 * self.options.mass_matrix_window_growth)
172                        .round() as u64,
173                )
174            };
175            let is_late = next_window_size + draw > self.final_step_size_window;
176
177            let mut force_update = false;
178            if could_switch && (!is_late) {
179                self.mass_matrix_adapt.switch(math);
180                force_update = true;
181                // Grow the window for the next main-phase switch.
182                if !is_early {
183                    self.current_window_size = next_window_size;
184                }
185            }
186
187            let did_change = if force_update
188                | (draw - self.last_update >= self.options.mass_matrix_update_freq)
189            {
190                self.mass_matrix_adapt
191                    .adapt(math, hamiltonian.transformation_mut())
192            } else {
193                false
194            };
195
196            if did_change {
197                self.last_update = draw;
198            }
199
200            if is_late {
201                self.step_size.update_estimator_late();
202            } else {
203                self.step_size.update_estimator_early();
204            }
205
206            // First time we change the mass matrix
207            if did_change & self.has_initial_mass_matrix {
208                self.has_initial_mass_matrix = false;
209                let position = math.box_array(state.point().position());
210                self.step_size
211                    .init(math, options, hamiltonian, &position, rng)?;
212            } else {
213                self.step_size.update_stepsize(rng, hamiltonian, false)
214            }
215            return Ok(());
216        }
217
218        self.step_size.update_estimator_late();
219        let is_last = draw == self.num_tune - 1;
220        self.step_size.update_stepsize(rng, hamiltonian, is_last);
221        Ok(())
222    }
223
224    fn new_collector(&self, math: &mut M) -> Self::Collector {
225        Self::Collector::new(
226            self.step_size.new_collector(),
227            self.mass_matrix_adapt.new_collector(math),
228        )
229    }
230
231    fn is_tuning(&self) -> bool {
232        self.tuning
233    }
234
235    fn last_num_steps(&self) -> u64 {
236        self.step_size.last_n_steps
237    }
238}
239
/// Per-draw statistics emitted by [`GlobalStrategy`]: the nested step-size and
/// mass-matrix stats (flattened into one record) plus the tuning flag.
#[derive(Debug, Storable)]
pub struct GlobalStrategyStats<P: HasDims, S: Storable<P>, M: Storable<P>> {
    #[storable(flatten)]
    pub step_size: S,
    #[storable(flatten)]
    pub mass_matrix: M,
    // Whether the sampler was still in warmup when these stats were taken.
    pub tuning: bool,
    // Ties the stats to the dim provider `P` without storing a `P` value;
    // `fn() -> P` keeps the struct Send/Sync regardless of `P`.
    #[storable(ignore)]
    _phantom: std::marker::PhantomData<fn() -> P>,
}
250
/// Options controlling which statistics [`GlobalStrategy`] extracts.
#[derive(Debug)]
pub struct GlobalStrategyStatsOptions<M: Math, A: MassMatrixAdaptStrategy<M>> {
    // The step size strategy takes no stats options.
    pub step_size: (),
    pub mass_matrix: A::StatsOptions,
}
256
// Manual Clone/Copy impls: deriving would add `M: Clone`/`M: Copy` bounds even
// though no `M` value is stored. NOTE(review): `*self` here relies on
// `A::StatsOptions` being `Copy` — presumably guaranteed by the trait; confirm.
impl<M: Math, A: MassMatrixAdaptStrategy<M>> Clone for GlobalStrategyStatsOptions<M, A> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<M: Math, A: MassMatrixAdaptStrategy<M>> Copy for GlobalStrategyStatsOptions<M, A> {}
264
265impl<M: Math, A> SamplerStats<M> for GlobalStrategy<M, A>
266where
267    A: MassMatrixAdaptStrategy<M>,
268{
269    type Stats =
270        GlobalStrategyStats<StatsDims, <StepSizeStrategy as SamplerStats<M>>::Stats, A::Stats>;
271    type StatsOptions = GlobalStrategyStatsOptions<M, A>;
272
273    fn extract_stats(&self, math: &mut M, opt: Self::StatsOptions) -> Self::Stats {
274        GlobalStrategyStats {
275            step_size: {
276                let _: () = opt.step_size;
277                self.step_size.extract_stats(math, ())
278            },
279            mass_matrix: self.mass_matrix_adapt.extract_stats(math, opt.mass_matrix),
280            tuning: self.tuning,
281            _phantom: PhantomData,
282        }
283    }
284}
285
/// Pairs two collectors so that both observe the same sampler events.
pub struct CombinedCollector<M, P, C1, C2>
where
    M: Math,
    P: Point<M>,
    C1: Collector<M, P>,
    C2: Collector<M, P>,
{
    pub collector1: C1,
    pub collector2: C2,
    // Marker fields: `M` and `P` appear only in the trait bounds, not as data.
    _phantom: PhantomData<M>,
    _phantom2: PhantomData<P>,
}
298
299impl<M, P, C1, C2> CombinedCollector<M, P, C1, C2>
300where
301    M: Math,
302    P: Point<M>,
303    C1: Collector<M, P>,
304    C2: Collector<M, P>,
305{
306    pub fn new(collector1: C1, collector2: C2) -> Self {
307        CombinedCollector {
308            collector1,
309            collector2,
310            _phantom: PhantomData,
311            _phantom2: PhantomData,
312        }
313    }
314}
315
impl<M, P, C1, C2> Collector<M, P> for CombinedCollector<M, P, C1, C2>
where
    M: Math,
    P: Point<M>,
    C1: Collector<M, P>,
    C2: Collector<M, P>,
{
    // Each event is forwarded to collector1 first, then collector2.
    fn register_leapfrog(
        &mut self,
        math: &mut M,
        start: &State<M, P>,
        end: &State<M, P>,
        divergence_info: Option<&DivergenceInfo>,
    ) {
        self.collector1
            .register_leapfrog(math, start, end, divergence_info);
        self.collector2
            .register_leapfrog(math, start, end, divergence_info);
    }

    fn register_draw(&mut self, math: &mut M, state: &State<M, P>, info: &crate::nuts::SampleInfo) {
        self.collector1.register_draw(math, state, info);
        self.collector2.register_draw(math, state, info);
    }

    fn register_init(
        &mut self,
        math: &mut M,
        state: &State<M, P>,
        options: &crate::nuts::NutsOptions,
    ) {
        self.collector1.register_init(math, state, options);
        self.collector2.register_init(math, state, options);
    }
}
351
#[cfg(test)]
pub mod test_logps {
    use std::collections::HashMap;

    use crate::math::{CpuLogpFunc, LogpError};
    use nuts_storable::HasDims;
    use thiserror::Error;

    /// Independent normal log-density with mean `mu` in every dimension,
    /// used as a simple target for sampler tests.
    #[derive(Clone, Debug)]
    pub struct NormalLogp {
        dim: usize,
        mu: f64,
    }

    impl NormalLogp {
        pub(crate) fn new(dim: usize, mu: f64) -> NormalLogp {
            NormalLogp { dim, mu }
        }
    }

    /// Error type for `NormalLogp`; the density never fails, so no variants exist.
    #[derive(Error, Debug)]
    pub enum NormalLogpError {}

    impl LogpError for NormalLogp {
        fn is_recoverable(&self) -> bool {
            false
        }
    }

    impl HasDims for NormalLogp {
        fn dim_sizes(&self) -> HashMap<String, u64> {
            vec![("unconstrained_parameter".to_string(), self.dim as u64)]
                .into_iter()
                .collect()
        }
    }

    impl CpuLogpFunc for NormalLogp {
        type LogpError = NormalLogpError;
        type FlowParameters = ();
        type ExpandedVector = Vec<f64>;

        fn dim(&self) -> usize {
            self.dim
        }

        /// Unnormalized Gaussian log-density: `-sum((p - mu)^2) / 2`.
        /// The gradient `-(p - mu)` is written into `gradient` in place.
        fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, NormalLogpError> {
            let n = position.len();
            assert!(gradient.len() == n);

            let mut logp = 0f64;
            for (p, g) in position.iter().zip(gradient.iter_mut()) {
                let val = *p - self.mu;
                logp -= val * val / 2.;
                *g = -val;
            }
            Ok(logp)
        }

        /// Identity expansion: the unconstrained draw is returned unchanged.
        fn expand_vector<R>(
            &mut self,
            _rng: &mut R,
            array: &[f64],
        ) -> Result<Self::ExpandedVector, crate::math::CpuMathError>
        where
            R: rand::Rng + ?Sized,
        {
            Ok(array.to_vec())
        }
    }
}
422
#[cfg(test)]
mod test {
    use super::test_logps::NormalLogp;
    use super::*;
    use crate::{
        Chain, DiagAdaptExpSettings,
        chain::{NutsChain, StatOptions},
        dynamics::{
            DivergenceStatsOptions, KineticEnergyKind, TransformedHamiltonian,
            TransformedPointStatsOptions,
        },
        math::CpuMath,
        transform::{DiagAdaptStrategy, DiagMassMatrix},
    };

    /// End-to-end smoke test: run a full warmup + sampling loop on a normal
    /// target centered at 30 and check the chain converges there.
    #[test]
    fn instanciate_adaptive_sampler() {
        let ndim = 10;
        // Target: independent normals with mean 30 in each dimension.
        let func = NormalLogp::new(ndim, 30.);
        let mut math = CpuMath::new(func);
        let num_tune = 100;
        let options = EuclideanAdaptOptions::<DiagAdaptExpSettings>::default();
        let strategy =
            GlobalStrategy::<_, DiagAdaptStrategy<_>>::new(&mut math, options, num_tune, 0u64);

        let mass_matrix = DiagMassMatrix::new(&mut math, true);

        let hamiltonian: TransformedHamiltonian<_, DiagMassMatrix<CpuMath<NormalLogp>>> =
            TransformedHamiltonian::new(&mut math, mass_matrix, KineticEnergyKind::Euclidean);

        let options = NutsOptions {
            maxdepth: 10u64,
            mindepth: 0,
            check_turning: true,
            store_divergences: false,
            target_integration_time: None,
            extra_doublings: 0,
            max_energy_error: 1000.0,
        };

        // Fixed seed keeps the test deterministic.
        let rng = {
            use rand::SeedableRng;
            rand::rngs::StdRng::seed_from_u64(42)
        };
        let chain = 0u64;

        let stats_options = StatOptions {
            adapt: GlobalStrategyStatsOptions {
                step_size: (),
                mass_matrix: (),
            },
            hamiltonian: -1i64,
            point: TransformedPointStatsOptions {
                store_gradient: true,
                store_unconstrained: true,
                store_transformed: false,
            },
            divergence: DivergenceStatsOptions {
                store_divergences: true,
            },
        };

        let mut sampler = NutsChain::new(
            math,
            hamiltonian,
            strategy,
            options,
            rng,
            chain,
            stats_options,
        );
        // Start away from the mode (all coordinates at 1.5, target mean is 30).
        sampler.set_position(&vec![1.5f64; ndim]).unwrap();
        for _ in 0..200 {
            sampler.draw().unwrap();
        }

        // Check that we arrive near the true mean mu = 30
        let (last_position, _, _, prog) = sampler.expanded_draw().unwrap();
        dbg!(&last_position);
        for p in last_position {
            assert!((p - 30.).abs() < 5.0);
        }
        assert!(!prog.diverging);
    }
}