Skip to main content

nuts_rs/
adapt_strategy.rs

//! Orchestrates the warmup (tuning) schedule that jointly adapts the step size and the mass matrix.
2
3use std::{fmt::Debug, marker::PhantomData};
4
5use nuts_derive::Storable;
6use nuts_storable::{HasDims, Storable};
7use rand::Rng;
8use serde::{Deserialize, Serialize};
9
10use super::stepsize::AcceptanceRateCollector;
11use super::stepsize::{StepSizeSettings, Strategy as StepSizeStrategy};
12use crate::dynamics::{
13    DivergenceInfo, Hamiltonian, Point, State, TransformedHamiltonian, TransformedPoint,
14};
15use crate::transform::MassMatrixAdaptStrategy;
16use crate::{
17    NutsError,
18    chain::AdaptStrategy,
19    math::Math,
20    nuts::{Collector, NutsOptions},
21    sampler_stats::{SamplerStats, StatsDims},
22};
23
24pub struct GlobalStrategy<M: Math, A: MassMatrixAdaptStrategy<M>> {
25    step_size: StepSizeStrategy,
26    mass_matrix_adapt: A,
27    options: EuclideanAdaptOptions<A::Options>,
28    num_tune: u64,
29    // The number of draws in the early window
30    early_end: u64,
31
32    // The first draw number for the final step size adaptation window
33    final_step_size_window: u64,
34    tuning: bool,
35    has_initial_mass_matrix: bool,
36    last_update: u64,
37    // Current target window size for the main (non-early) phase; grows after each switch.
38    current_window_size: u64,
39}
40
41#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
42pub struct EuclideanAdaptOptions<S: Debug + Default> {
43    pub step_size_settings: StepSizeSettings,
44    pub mass_matrix_options: S,
45    pub early_window: f64,
46    pub step_size_window: f64,
47    /// Initial window size for the main (non-early) mass-matrix adaptation phase.
48    pub mass_matrix_switch_freq: u64,
49    pub early_mass_matrix_switch_freq: u64,
50    pub mass_matrix_update_freq: u64,
51    /// Multiplicative growth factor applied to the window size after each switch in the
52    /// main phase. 1.0 means constant windows (old behaviour). Must be >= 1.0.
53    pub mass_matrix_window_growth: f64,
54}
55
56impl<S: Debug + Default> Default for EuclideanAdaptOptions<S> {
57    fn default() -> Self {
58        Self {
59            step_size_settings: StepSizeSettings::default(),
60            mass_matrix_options: S::default(),
61            early_window: 0.3,
62            step_size_window: 0.15,
63            mass_matrix_switch_freq: 80,
64            early_mass_matrix_switch_freq: 10,
65            mass_matrix_update_freq: 1,
66            mass_matrix_window_growth: 1.5,
67        }
68    }
69}
70
71impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy<M, A> {
72    type Hamiltonian = TransformedHamiltonian<M, A::Transformation>;
73    type Collector =
74        CombinedCollector<M, TransformedPoint<M>, AcceptanceRateCollector, A::Collector>;
75    type Options = EuclideanAdaptOptions<A::Options>;
76
77    fn new(math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self {
78        let num_tune_f = num_tune as f64;
79        let step_size_window = (options.step_size_window * num_tune_f) as u64;
80        let early_end = (options.early_window * num_tune_f) as u64;
81        let final_second_step_size = num_tune.saturating_sub(step_size_window);
82
83        assert!(early_end < num_tune);
84        assert!(options.mass_matrix_window_growth >= 1.0);
85
86        Self {
87            step_size: StepSizeStrategy::new(options.step_size_settings),
88            mass_matrix_adapt: A::new(math, options.mass_matrix_options, num_tune, chain),
89            options,
90            num_tune,
91            early_end,
92            final_step_size_window: final_second_step_size,
93            tuning: true,
94            has_initial_mass_matrix: true,
95            last_update: 0,
96            current_window_size: options.mass_matrix_switch_freq,
97        }
98    }
99
100    fn init<R: Rng + ?Sized>(
101        &mut self,
102        math: &mut M,
103        options: &mut NutsOptions,
104        hamiltonian: &mut Self::Hamiltonian,
105        position: &[f64],
106        rng: &mut R,
107    ) -> Result<(), NutsError> {
108        let state = hamiltonian.init_state_untransformed(math, position)?;
109        self.mass_matrix_adapt.init(
110            math,
111            options,
112            hamiltonian.transformation_mut(),
113            state.point(),
114            rng,
115        )?;
116        self.step_size
117            .init(math, options, hamiltonian, position, rng)?;
118        Ok(())
119    }
120
121    fn adapt<R: Rng + ?Sized>(
122        &mut self,
123        math: &mut M,
124        options: &mut NutsOptions,
125        hamiltonian: &mut Self::Hamiltonian,
126        draw: u64,
127        collector: &Self::Collector,
128        state: &State<M, TransformedPoint<M>>,
129        rng: &mut R,
130    ) -> Result<(), NutsError> {
131        self.step_size.update(&collector.collector1);
132
133        if draw >= self.num_tune {
134            // Needed for step size jitter
135            self.step_size.update_stepsize(rng, hamiltonian, true);
136            self.tuning = false;
137            return Ok(());
138        }
139
140        if draw < self.final_step_size_window {
141            let is_early = draw < self.early_end;
142
143            // At the transition from early to main phase, seed current_window_size as the
144            // maximum of the configured initial size and the background count already
145            // accumulated, so we never shrink the window.
146            if !is_early && draw == self.early_end {
147                self.current_window_size = self
148                    .current_window_size
149                    .max(self.mass_matrix_adapt.background_count());
150            }
151
152            let switch_freq = if is_early {
153                self.options.early_mass_matrix_switch_freq
154            } else {
155                self.current_window_size
156            };
157
158            self.mass_matrix_adapt
159                .update_estimators(math, &collector.collector2);
160            // We only switch if we have switch_freq draws in the background estimate,
161            // and if the number of remaining mass matrix steps is larger than
162            // the switch frequency.
163            let could_switch = self.mass_matrix_adapt.background_count() >= switch_freq;
164            // For the main phase: after switching, the *next* window will be larger, so
165            // is_late must look ahead using that next size to decide whether there is
166            // still room for another full window before the step-size window.
167            let next_window_size = if is_early {
168                self.options.early_mass_matrix_switch_freq
169            } else {
170                (self.current_window_size + 1).max(
171                    (self.current_window_size as f64 * self.options.mass_matrix_window_growth)
172                        .round() as u64,
173                )
174            };
175            let is_late = next_window_size + draw > self.final_step_size_window;
176
177            let mut force_update = false;
178            if could_switch && (!is_late) {
179                self.mass_matrix_adapt.switch(math);
180                force_update = true;
181                // Grow the window for the next main-phase switch.
182                if !is_early {
183                    self.current_window_size = next_window_size;
184                }
185            }
186
187            let did_change = if force_update
188                | (draw - self.last_update >= self.options.mass_matrix_update_freq)
189            {
190                self.mass_matrix_adapt
191                    .adapt(math, hamiltonian.transformation_mut())
192            } else {
193                false
194            };
195
196            if did_change {
197                self.last_update = draw;
198            }
199
200            if is_late {
201                self.step_size.update_estimator_late();
202            } else {
203                self.step_size.update_estimator_early();
204            }
205
206            // First time we change the mass matrix
207            if did_change & self.has_initial_mass_matrix {
208                self.has_initial_mass_matrix = false;
209                let position = math.box_array(state.point().position());
210                self.step_size
211                    .init(math, options, hamiltonian, &position, rng)?;
212            } else {
213                self.step_size.update_stepsize(rng, hamiltonian, false)
214            }
215            return Ok(());
216        }
217
218        self.step_size.update_estimator_late();
219        let is_last = draw == self.num_tune - 1;
220        self.step_size.update_stepsize(rng, hamiltonian, is_last);
221        Ok(())
222    }
223
224    fn new_collector(&self, math: &mut M) -> Self::Collector {
225        Self::Collector::new(
226            self.step_size.new_collector(),
227            self.mass_matrix_adapt.new_collector(math),
228        )
229    }
230
231    fn is_tuning(&self) -> bool {
232        self.tuning
233    }
234
235    fn last_num_steps(&self) -> u64 {
236        self.step_size.last_n_steps
237    }
238}
239
240#[derive(Debug, Storable)]
241pub struct GlobalStrategyStats<P: HasDims, S: Storable<P>, M: Storable<P>> {
242    #[storable(flatten)]
243    pub step_size: S,
244    #[storable(flatten)]
245    pub mass_matrix: M,
246    pub tuning: bool,
247    #[storable(ignore)]
248    _phantom: std::marker::PhantomData<fn() -> P>,
249}
250
251#[derive(Debug)]
252pub struct GlobalStrategyStatsOptions<M: Math, A: MassMatrixAdaptStrategy<M>> {
253    pub step_size: (),
254    pub mass_matrix: A::StatsOptions,
255}
256
257impl<M: Math, A: MassMatrixAdaptStrategy<M>> Clone for GlobalStrategyStatsOptions<M, A> {
258    fn clone(&self) -> Self {
259        *self
260    }
261}
262
263impl<M: Math, A: MassMatrixAdaptStrategy<M>> Copy for GlobalStrategyStatsOptions<M, A> {}
264
265impl<M: Math, A> SamplerStats<M> for GlobalStrategy<M, A>
266where
267    A: MassMatrixAdaptStrategy<M>,
268{
269    type Stats =
270        GlobalStrategyStats<StatsDims, <StepSizeStrategy as SamplerStats<M>>::Stats, A::Stats>;
271    type StatsOptions = GlobalStrategyStatsOptions<M, A>;
272
273    fn extract_stats(&self, math: &mut M, opt: Self::StatsOptions) -> Self::Stats {
274        GlobalStrategyStats {
275            step_size: {
276                let _: () = opt.step_size;
277                self.step_size.extract_stats(math, ())
278            },
279            mass_matrix: self.mass_matrix_adapt.extract_stats(math, opt.mass_matrix),
280            tuning: self.tuning,
281            _phantom: PhantomData,
282        }
283    }
284}
285
286pub struct CombinedCollector<M, P, C1, C2>
287where
288    M: Math,
289    P: Point<M>,
290    C1: Collector<M, P>,
291    C2: Collector<M, P>,
292{
293    pub collector1: C1,
294    pub collector2: C2,
295    _phantom: PhantomData<M>,
296    _phantom2: PhantomData<P>,
297}
298
299impl<M, P, C1, C2> CombinedCollector<M, P, C1, C2>
300where
301    M: Math,
302    P: Point<M>,
303    C1: Collector<M, P>,
304    C2: Collector<M, P>,
305{
306    pub fn new(collector1: C1, collector2: C2) -> Self {
307        CombinedCollector {
308            collector1,
309            collector2,
310            _phantom: PhantomData,
311            _phantom2: PhantomData,
312        }
313    }
314}
315
316impl<M, P, C1, C2> Collector<M, P> for CombinedCollector<M, P, C1, C2>
317where
318    M: Math,
319    P: Point<M>,
320    C1: Collector<M, P>,
321    C2: Collector<M, P>,
322{
323    fn register_leapfrog(
324        &mut self,
325        math: &mut M,
326        start: &State<M, P>,
327        end: &State<M, P>,
328        divergence_info: Option<&DivergenceInfo>,
329    ) {
330        self.collector1
331            .register_leapfrog(math, start, end, divergence_info);
332        self.collector2
333            .register_leapfrog(math, start, end, divergence_info);
334    }
335
336    fn register_draw(&mut self, math: &mut M, state: &State<M, P>, info: &crate::nuts::SampleInfo) {
337        self.collector1.register_draw(math, state, info);
338        self.collector2.register_draw(math, state, info);
339    }
340
341    fn register_init(
342        &mut self,
343        math: &mut M,
344        state: &State<M, P>,
345        options: &crate::nuts::NutsOptions,
346    ) {
347        self.collector1.register_init(math, state, options);
348        self.collector2.register_init(math, state, options);
349    }
350}
351
#[cfg(test)]
mod test {
    use super::*;
    use crate::math::test_logps::NormalLogp;
    use crate::{
        Chain, DiagAdaptExpSettings,
        chain::{NutsChain, StatOptions},
        dynamics::{
            DivergenceStatsOptions, KineticEnergyKind, TransformedHamiltonian,
            TransformedPointStatsOptions,
        },
        math::CpuMath,
        transform::{DiagAdaptStrategy, DiagMassMatrix},
    };

    /// Tunes and samples a 10-dimensional `NormalLogp` (location argument 30) starting
    /// from 1.5 and checks that the chain ends up near 30 without diverging.
    #[test]
    fn instantiate_adaptive_sampler() {
        let ndim = 10;
        let func = NormalLogp::new(ndim, 30.);
        let mut math = CpuMath::new(func);
        let num_tune = 100;
        let options = EuclideanAdaptOptions::<DiagAdaptExpSettings>::default();
        let strategy =
            GlobalStrategy::<_, DiagAdaptStrategy<_>>::new(&mut math, options, num_tune, 0u64);

        let mass_matrix = DiagMassMatrix::new(&mut math, true);

        let hamiltonian: TransformedHamiltonian<_, DiagMassMatrix<CpuMath<NormalLogp>>> =
            TransformedHamiltonian::new(&mut math, mass_matrix, KineticEnergyKind::Euclidean);

        let options = NutsOptions {
            maxdepth: 10u64,
            mindepth: 0,
            check_turning: true,
            store_divergences: false,
            target_integration_time: None,
            extra_doublings: 0,
            max_energy_error: 1000.0,
        };

        let rng = {
            use rand::SeedableRng;
            rand::rngs::StdRng::seed_from_u64(42)
        };
        let chain = 0u64;

        let stats_options = StatOptions {
            adapt: GlobalStrategyStatsOptions {
                step_size: (),
                mass_matrix: (),
            },
            hamiltonian: -1i64,
            point: TransformedPointStatsOptions {
                store_gradient: true,
                store_unconstrained: true,
                store_transformed: false,
            },
            divergence: DivergenceStatsOptions {
                store_divergences: true,
            },
        };

        let mut sampler = NutsChain::new(
            math,
            hamiltonian,
            strategy,
            options,
            rng,
            chain,
            stats_options,
        );
        sampler.set_position(&vec![1.5f64; ndim]).unwrap();
        // `num_tune` (100) tuning draws followed by 100 post-tuning draws.
        for _ in 0..200 {
            sampler.draw().unwrap();
        }

        // Check that every coordinate arrives within 5 of 30.
        let (last_position, _, _, prog) = sampler.expanded_draw().unwrap();
        for p in last_position {
            assert!((p - 30.).abs() < 5.0, "coordinate {p} is not close to 30");
        }
        assert!(!prog.diverging);
    }

    /// A window growth factor below 1.0 would shrink windows and must be rejected.
    #[test]
    #[should_panic]
    fn rejects_shrinking_window_growth() {
        let mut math = CpuMath::new(NormalLogp::new(3, 0.));
        let options = EuclideanAdaptOptions::<DiagAdaptExpSettings> {
            mass_matrix_window_growth: 0.5,
            ..Default::default()
        };
        let _ = GlobalStrategy::<_, DiagAdaptStrategy<_>>::new(&mut math, options, 100, 0u64);
    }
}