nuts_rs/
adapt_strategy.rs

1use std::{fmt::Debug, marker::PhantomData};
2
3use nuts_derive::Storable;
4use nuts_storable::{HasDims, Storable};
5use rand::Rng;
6use serde::Serialize;
7
8use super::stepsize::AcceptanceRateCollector;
9use super::stepsize::{StepSizeSettings, Strategy as StepSizeStrategy};
10use crate::mass_matrix::MassMatrixAdaptStrategy;
11use crate::{
12    NutsError,
13    chain::AdaptStrategy,
14    euclidean_hamiltonian::EuclideanHamiltonian,
15    hamiltonian::{DivergenceInfo, Hamiltonian, Point},
16    math_base::Math,
17    nuts::{Collector, NutsOptions},
18    sampler_stats::{SamplerStats, StatsDims},
19    state::State,
20};
21
22pub struct GlobalStrategy<M: Math, A: MassMatrixAdaptStrategy<M>> {
23    step_size: StepSizeStrategy,
24    mass_matrix: A,
25    options: EuclideanAdaptOptions<A::Options>,
26    num_tune: u64,
27    // The number of draws in the the early window
28    early_end: u64,
29
30    // The first draw number for the final step size adaptation window
31    final_step_size_window: u64,
32    tuning: bool,
33    has_initial_mass_matrix: bool,
34    last_update: u64,
35}
36
37#[derive(Debug, Clone, Copy, Serialize)]
38pub struct EuclideanAdaptOptions<S: Debug + Default> {
39    pub step_size_settings: StepSizeSettings,
40    pub mass_matrix_options: S,
41    pub early_window: f64,
42    pub step_size_window: f64,
43    pub mass_matrix_switch_freq: u64,
44    pub early_mass_matrix_switch_freq: u64,
45    pub mass_matrix_update_freq: u64,
46}
47
48impl<S: Debug + Default> Default for EuclideanAdaptOptions<S> {
49    fn default() -> Self {
50        Self {
51            step_size_settings: StepSizeSettings::default(),
52            mass_matrix_options: S::default(),
53            early_window: 0.3,
54            step_size_window: 0.15,
55            mass_matrix_switch_freq: 80,
56            early_mass_matrix_switch_freq: 10,
57            mass_matrix_update_freq: 1,
58        }
59    }
60}
61
62impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy<M, A> {
63    type Hamiltonian = EuclideanHamiltonian<M, A::MassMatrix>;
64    type Collector = CombinedCollector<
65        M,
66        <Self::Hamiltonian as Hamiltonian<M>>::Point,
67        AcceptanceRateCollector,
68        A::Collector,
69    >;
70    type Options = EuclideanAdaptOptions<A::Options>;
71
72    fn new(math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self {
73        let num_tune_f = num_tune as f64;
74        let step_size_window = (options.step_size_window * num_tune_f) as u64;
75        let early_end = (options.early_window * num_tune_f) as u64;
76        let final_second_step_size = num_tune.saturating_sub(step_size_window);
77
78        assert!(early_end < num_tune);
79
80        Self {
81            step_size: StepSizeStrategy::new(options.step_size_settings),
82            mass_matrix: A::new(math, options.mass_matrix_options, num_tune, chain),
83            options,
84            num_tune,
85            early_end,
86            final_step_size_window: final_second_step_size,
87            tuning: true,
88            has_initial_mass_matrix: true,
89            last_update: 0,
90        }
91    }
92
93    fn init<R: Rng + ?Sized>(
94        &mut self,
95        math: &mut M,
96        options: &mut NutsOptions,
97        hamiltonian: &mut Self::Hamiltonian,
98        position: &[f64],
99        rng: &mut R,
100    ) -> Result<(), NutsError> {
101        let state = hamiltonian.init_state(math, position)?;
102        self.mass_matrix.init(
103            math,
104            options,
105            &mut hamiltonian.mass_matrix,
106            state.point(),
107            rng,
108        )?;
109        self.step_size
110            .init(math, options, hamiltonian, position, rng)?;
111        Ok(())
112    }
113
114    fn adapt<R: Rng + ?Sized>(
115        &mut self,
116        math: &mut M,
117        options: &mut NutsOptions,
118        hamiltonian: &mut Self::Hamiltonian,
119        draw: u64,
120        collector: &Self::Collector,
121        state: &State<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>,
122        rng: &mut R,
123    ) -> Result<(), NutsError> {
124        self.step_size.update(&collector.collector1);
125
126        if draw >= self.num_tune {
127            // Needed for step size jitter
128            self.step_size.update_stepsize(rng, hamiltonian, true);
129            self.tuning = false;
130            return Ok(());
131        }
132
133        if draw < self.final_step_size_window {
134            let is_early = draw < self.early_end;
135            let switch_freq = if is_early {
136                self.options.early_mass_matrix_switch_freq
137            } else {
138                self.options.mass_matrix_switch_freq
139            };
140
141            self.mass_matrix
142                .update_estimators(math, &collector.collector2);
143            // We only switch if we have switch_freq draws in the background estimate,
144            // and if the number of remaining mass matrix steps is larger than
145            // the switch frequency.
146            let could_switch = self.mass_matrix.background_count() >= switch_freq;
147            let is_late = switch_freq + draw > self.final_step_size_window;
148
149            let mut force_update = false;
150            if could_switch && (!is_late) {
151                self.mass_matrix.switch(math);
152                force_update = true;
153            }
154
155            let did_change = if force_update
156                | (draw - self.last_update >= self.options.mass_matrix_update_freq)
157            {
158                self.mass_matrix.adapt(math, &mut hamiltonian.mass_matrix)
159            } else {
160                false
161            };
162
163            if did_change {
164                self.last_update = draw;
165            }
166
167            if is_late {
168                self.step_size.update_estimator_late();
169            } else {
170                self.step_size.update_estimator_early();
171            }
172
173            // First time we change the mass matrix
174            if did_change & self.has_initial_mass_matrix {
175                self.has_initial_mass_matrix = false;
176                let position = math.box_array(state.point().position());
177                self.step_size
178                    .init(math, options, hamiltonian, &position, rng)?;
179            } else {
180                self.step_size.update_stepsize(rng, hamiltonian, false)
181            }
182            return Ok(());
183        }
184
185        self.step_size.update_estimator_late();
186        let is_last = draw == self.num_tune - 1;
187        self.step_size.update_stepsize(rng, hamiltonian, is_last);
188        Ok(())
189    }
190
191    fn new_collector(&self, math: &mut M) -> Self::Collector {
192        Self::Collector::new(
193            self.step_size.new_collector(),
194            self.mass_matrix.new_collector(math),
195        )
196    }
197
198    fn is_tuning(&self) -> bool {
199        self.tuning
200    }
201
202    fn last_num_steps(&self) -> u64 {
203        self.step_size.last_n_steps
204    }
205}
206
/// Statistics extracted from [`GlobalStrategy`] by `extract_stats`.
#[derive(Debug, Storable)]
pub struct GlobalStrategyStats<P: HasDims, S: Storable<P>, M: Storable<P>> {
    // Step size adaptation stats (marked `storable(flatten)`).
    #[storable(flatten)]
    pub step_size: S,
    // Mass matrix adaptation stats (marked `storable(flatten)`).
    #[storable(flatten)]
    pub mass_matrix: M,
    // Copy of `GlobalStrategy::tuning` at extraction time.
    pub tuning: bool,
    // Only ties the dimension type `P` to the struct; excluded via `storable(ignore)`.
    #[storable(ignore)]
    _phantom: std::marker::PhantomData<fn() -> P>,
}
217
/// Options passed to [`GlobalStrategy`] when extracting statistics.
///
/// The step size part takes no options (`()`); the mass matrix part uses the mass matrix
/// strategy's own `StatsOptions`.
#[derive(Debug)]
pub struct GlobalStrategyStatsOptions<M: Math, A: MassMatrixAdaptStrategy<M>> {
    pub step_size: (),
    pub mass_matrix: A::StatsOptions,
}

// Implemented by hand rather than derived: `#[derive(Clone, Copy)]` would add `M: Clone` /
// `A: Clone` bounds, although only `A::StatsOptions` is stored.
impl<M: Math, A: MassMatrixAdaptStrategy<M>> Clone for GlobalStrategyStatsOptions<M, A> {
    fn clone(&self) -> Self {
        *self
    }
}

// NOTE(review): requires `A::StatsOptions: Copy`, presumably guaranteed by the bounds of
// `MassMatrixAdaptStrategy` — confirm.
impl<M: Math, A: MassMatrixAdaptStrategy<M>> Copy for GlobalStrategyStatsOptions<M, A> {}
231
232impl<M: Math, A> SamplerStats<M> for GlobalStrategy<M, A>
233where
234    A: MassMatrixAdaptStrategy<M>,
235{
236    type Stats =
237        GlobalStrategyStats<StatsDims, <StepSizeStrategy as SamplerStats<M>>::Stats, A::Stats>;
238    type StatsOptions = GlobalStrategyStatsOptions<M, A>;
239
240    fn extract_stats(&self, math: &mut M, opt: Self::StatsOptions) -> Self::Stats {
241        GlobalStrategyStats {
242            step_size: {
243                let _: () = opt.step_size;
244                self.step_size.extract_stats(math, ())
245            },
246            mass_matrix: self.mass_matrix.extract_stats(math, opt.mass_matrix),
247            tuning: self.tuning,
248            _phantom: PhantomData,
249        }
250    }
251}
252
253pub struct CombinedCollector<M, P, C1, C2>
254where
255    M: Math,
256    P: Point<M>,
257    C1: Collector<M, P>,
258    C2: Collector<M, P>,
259{
260    pub collector1: C1,
261    pub collector2: C2,
262    _phantom: PhantomData<M>,
263    _phantom2: PhantomData<P>,
264}
265
266impl<M, P, C1, C2> CombinedCollector<M, P, C1, C2>
267where
268    M: Math,
269    P: Point<M>,
270    C1: Collector<M, P>,
271    C2: Collector<M, P>,
272{
273    pub fn new(collector1: C1, collector2: C2) -> Self {
274        CombinedCollector {
275            collector1,
276            collector2,
277            _phantom: PhantomData,
278            _phantom2: PhantomData,
279        }
280    }
281}
282
283impl<M, P, C1, C2> Collector<M, P> for CombinedCollector<M, P, C1, C2>
284where
285    M: Math,
286    P: Point<M>,
287    C1: Collector<M, P>,
288    C2: Collector<M, P>,
289{
290    fn register_leapfrog(
291        &mut self,
292        math: &mut M,
293        start: &State<M, P>,
294        end: &State<M, P>,
295        divergence_info: Option<&DivergenceInfo>,
296    ) {
297        self.collector1
298            .register_leapfrog(math, start, end, divergence_info);
299        self.collector2
300            .register_leapfrog(math, start, end, divergence_info);
301    }
302
303    fn register_draw(&mut self, math: &mut M, state: &State<M, P>, info: &crate::nuts::SampleInfo) {
304        self.collector1.register_draw(math, state, info);
305        self.collector2.register_draw(math, state, info);
306    }
307
308    fn register_init(
309        &mut self,
310        math: &mut M,
311        state: &State<M, P>,
312        options: &crate::nuts::NutsOptions,
313    ) {
314        self.collector1.register_init(math, state, options);
315        self.collector2.register_init(math, state, options);
316    }
317}
318
#[cfg(test)]
pub mod test_logps {
    use std::collections::HashMap;

    use crate::{cpu_math::CpuLogpFunc, math_base::LogpError};
    use nuts_storable::HasDims;
    use thiserror::Error;

    /// Unnormalized normal log density with unit variance and mean `mu` in each of `dim`
    /// coordinates.
    #[derive(Clone, Debug)]
    pub struct NormalLogp {
        dim: usize,
        mu: f64,
    }

    impl NormalLogp {
        pub(crate) fn new(dim: usize, mu: f64) -> NormalLogp {
            Self { dim, mu }
        }
    }

    /// Error type of [`NormalLogp`]; it has no variants, so evaluation never fails.
    #[derive(Error, Debug)]
    pub enum NormalLogpError {}

    impl LogpError for NormalLogpError {
        fn is_recoverable(&self) -> bool {
            false
        }
    }

    impl HasDims for NormalLogp {
        fn dim_sizes(&self) -> HashMap<String, u64> {
            HashMap::from([("unconstrained_parameter".to_string(), self.dim as u64)])
        }
    }

    impl CpuLogpFunc for NormalLogp {
        type LogpError = NormalLogpError;
        type FlowParameters = ();
        type ExpandedVector = Vec<f64>;

        fn dim(&self) -> usize {
            self.dim
        }

        /// Returns `-sum((x - mu)^2) / 2` and writes the gradient `-(x - mu)`.
        fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, NormalLogpError> {
            let n = position.len();
            assert!(gradient.len() == n);

            let mu = self.mu;
            let logp = position
                .iter()
                .zip(gradient.iter_mut())
                .fold(0f64, |acc, (&x, grad)| {
                    let centered = x - mu;
                    *grad = -centered;
                    acc - centered * centered / 2.
                });
            Ok(logp)
        }

        fn expand_vector<R>(
            &mut self,
            _rng: &mut R,
            array: &[f64],
        ) -> Result<Self::ExpandedVector, crate::cpu_math::CpuMathError>
        where
            R: rand::Rng + ?Sized,
        {
            Ok(array.to_vec())
        }

        fn inv_transform_normalize(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &[f64],
            _untransformed_gradient: &[f64],
            _transformed_position: &mut [f64],
            _transformed_gradient: &mut [f64],
        ) -> Result<f64, Self::LogpError> {
            unimplemented!()
        }

        fn init_from_transformed_position(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &mut [f64],
            _untransformed_gradient: &mut [f64],
            _transformed_position: &[f64],
            _transformed_gradient: &mut [f64],
        ) -> Result<(f64, f64), Self::LogpError> {
            unimplemented!()
        }

        fn init_from_untransformed_position(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &[f64],
            _untransformed_gradient: &mut [f64],
            _transformed_position: &mut [f64],
            _transformed_gradient: &mut [f64],
        ) -> Result<(f64, f64), Self::LogpError> {
            unimplemented!()
        }

        fn update_transformation<'a, R: rand::Rng + ?Sized>(
            &'a mut self,
            _rng: &mut R,
            _untransformed_positions: impl Iterator<Item = &'a [f64]>,
            _untransformed_gradients: impl Iterator<Item = &'a [f64]>,
            _untransformed_logp: impl Iterator<Item = &'a f64>,
            _params: &'a mut Self::FlowParameters,
        ) -> Result<(), Self::LogpError> {
            unimplemented!()
        }

        fn new_transformation<R: rand::Rng + ?Sized>(
            &mut self,
            _rng: &mut R,
            _untransformed_position: &[f64],
            _untransformed_gradient: &[f64],
            _chain: u64,
        ) -> Result<Self::FlowParameters, Self::LogpError> {
            unimplemented!()
        }

        fn transformation_id(
            &self,
            _params: &Self::FlowParameters,
        ) -> Result<i64, Self::LogpError> {
            unimplemented!()
        }
    }
}
450
#[cfg(test)]
mod test {
    use super::test_logps::NormalLogp;
    use super::*;
    use crate::{
        Chain, DiagAdaptExpSettings,
        chain::{NutsChain, StatOptions},
        cpu_math::CpuMath,
        euclidean_hamiltonian::EuclideanHamiltonian,
        mass_matrix::DiagMassMatrix,
    };
    use rand::SeedableRng;

    /// Smoke test: build a NUTS chain with a diagonal mass matrix and the global adaptation
    /// strategy, then draw past the end of tuning without errors.
    #[test]
    fn instanciate_adaptive_sampler() {
        use crate::mass_matrix::Strategy;

        const NDIM: usize = 10;
        const NUM_TUNE: u64 = 100;
        const NUM_DRAWS: usize = 200;

        let mut math = CpuMath::new(NormalLogp::new(NDIM, 3.));
        let adapt_options = EuclideanAdaptOptions::<DiagAdaptExpSettings>::default();
        let strategy =
            GlobalStrategy::<_, Strategy<_>>::new(&mut math, adapt_options, NUM_TUNE, 0u64);

        let mass_matrix = DiagMassMatrix::new(&mut math, true);
        let max_energy_error = 1000f64;
        let initial_step_size = 0.1f64;
        let hamiltonian = EuclideanHamiltonian::new(
            &mut math,
            mass_matrix,
            max_energy_error,
            initial_step_size,
        );

        let nuts_options = NutsOptions {
            maxdepth: 10u64,
            mindepth: 0,
            store_gradient: true,
            store_unconstrained: true,
            check_turning: true,
            store_divergences: false,
        };
        let rng = rand::rngs::StdRng::seed_from_u64(42);
        let chain = 0u64;
        let stats_options = StatOptions {
            adapt: GlobalStrategyStatsOptions {
                step_size: (),
                mass_matrix: (),
            },
            hamiltonian: (),
            point: (),
        };

        let mut sampler = NutsChain::new(
            math,
            hamiltonian,
            strategy,
            nuts_options,
            rng,
            chain,
            stats_options,
        );
        sampler.set_position(&vec![1.5f64; NDIM]).unwrap();
        for _ in 0..NUM_DRAWS {
            sampler.draw().unwrap();
        }
    }
}