Skip to main content

nuts_rs/stepsize/
adapt.rs

//! Coordinates the step-size search at initialisation and dispatches to the chosen adaptation algorithm during tuning.
2
3use itertools::Either;
4use nuts_derive::Storable;
5use rand::distr::Uniform;
6use rand::{Rng, RngExt};
7use serde::{Deserialize, Serialize};
8
9use super::adam::{Adam, AdamOptions};
10use super::dual_avg::{AcceptanceRateCollector, DualAverage, DualAverageOptions};
11use crate::{
12    Math, NutsError,
13    dynamics::{Direction, Hamiltonian, LeapfrogResult, Point},
14    nuts::{Collector, NutsOptions},
15    sampler_stats::SamplerStats,
16};
17use std::f64;
18use std::fmt::Debug;
19
20/// Method used for step size adaptation
21#[derive(Debug, Clone, Copy, Serialize, Default, Deserialize)]
22pub enum StepSizeAdaptMethod {
23    /// Use dual averaging for step size adaptation (default)
24    #[default]
25    DualAverage,
26    /// Use Adam optimizer for step size adaptation
27    Adam,
28    Fixed(f64),
29}
30
31/// Options for step size adaptation
32#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
33pub struct StepSizeAdaptOptions {
34    pub method: StepSizeAdaptMethod,
35    /// Dual averaging adaptation options
36    pub dual_average: DualAverageOptions,
37    /// Adam optimizer adaptation options
38    pub adam: AdamOptions,
39}
40
41impl Default for StepSizeAdaptOptions {
42    fn default() -> Self {
43        Self {
44            method: StepSizeAdaptMethod::DualAverage,
45            dual_average: DualAverageOptions::default(),
46            adam: AdamOptions::default(),
47        }
48    }
49}
50
51/// Step size adaptation strategy
52pub struct Strategy {
53    /// The step size adaptation method being used
54    adaptation: Option<Either<DualAverage, Adam>>,
55    /// Settings for step size adaptation
56    options: StepSizeSettings,
57    /// Last mean tree accept rate
58    pub last_mean_tree_accept: f64,
59    /// Last symmetric mean tree accept rate
60    pub last_sym_mean_tree_accept: f64,
61    /// Last number of steps
62    pub last_n_steps: u64,
63    /// Maximum absolute energy error observed in the last trajectory
64    pub last_max_energy_error: f64,
65}
66
67impl Strategy {
68    pub fn new(options: StepSizeSettings) -> Self {
69        let adaptation = match options.adapt_options.method {
70            StepSizeAdaptMethod::DualAverage => Some(Either::Left(DualAverage::new(
71                options.adapt_options.dual_average,
72                options.initial_step,
73            ))),
74            StepSizeAdaptMethod::Adam => Some(Either::Right(Adam::new(
75                options.adapt_options.adam,
76                options.initial_step,
77            ))),
78            StepSizeAdaptMethod::Fixed(_) => None,
79        };
80
81        Self {
82            adaptation,
83            options,
84            last_n_steps: 0,
85            last_sym_mean_tree_accept: 0.0,
86            last_mean_tree_accept: 0.0,
87            last_max_energy_error: 0.0,
88        }
89    }
90
91    pub fn init<M: Math, R: Rng + ?Sized, P: Point<M>>(
92        &mut self,
93        math: &mut M,
94        options: &mut NutsOptions,
95        hamiltonian: &mut impl Hamiltonian<M, Point = P>,
96        position: &[f64],
97        rng: &mut R,
98    ) -> Result<(), NutsError> {
99        if let StepSizeAdaptMethod::Fixed(step_size) = self.options.adapt_options.method {
100            *hamiltonian.step_size_mut() = step_size;
101            return Ok(());
102        };
103        let mut state = hamiltonian.init_state(math, position)?;
104        hamiltonian.initialize_trajectory(math, &mut state, true, rng)?;
105
106        let mut collector = AcceptanceRateCollector::new();
107
108        collector.register_init(math, &state, options);
109
110        *hamiltonian.step_size_mut() = self.options.initial_step;
111
112        let state_next = hamiltonian.leapfrog(
113            math,
114            &state,
115            Direction::Forward,
116            1.0,
117            state.point().initial_energy(),
118            1000.0,
119            &mut collector,
120        );
121
122        let LeapfrogResult::Ok(_) = state_next else {
123            return Ok(());
124        };
125
126        let accept_stat = collector.mean.current();
127        let dir = if accept_stat > self.options.target_accept {
128            Direction::Forward
129        } else {
130            Direction::Backward
131        };
132
133        for _ in 0..100 {
134            let mut collector = AcceptanceRateCollector::new();
135            collector.register_init(math, &state, options);
136            let state_next = hamiltonian.leapfrog(
137                math,
138                &state,
139                dir,
140                1.0,
141                state.point().initial_energy(),
142                1000.0,
143                &mut collector,
144            );
145            let LeapfrogResult::Ok(_) = state_next else {
146                *hamiltonian.step_size_mut() = self.options.initial_step;
147                return Ok(());
148            };
149            let accept_stat = collector.mean.current();
150            match dir {
151                Direction::Forward => {
152                    if (accept_stat <= self.options.target_accept) | (hamiltonian.step_size() > 1e5)
153                    {
154                        match self.adaptation.as_mut().expect("Adaptation must be set") {
155                            Either::Left(adapt) => {
156                                *adapt = DualAverage::new(
157                                    self.options.adapt_options.dual_average,
158                                    hamiltonian.step_size(),
159                                );
160                            }
161                            Either::Right(adapt) => {
162                                *adapt = Adam::new(
163                                    self.options.adapt_options.adam,
164                                    hamiltonian.step_size(),
165                                );
166                            }
167                        }
168                        return Ok(());
169                    }
170                    *hamiltonian.step_size_mut() *= 2.;
171                }
172                Direction::Backward => {
173                    if (accept_stat >= self.options.target_accept)
174                        | (hamiltonian.step_size() < 1e-10)
175                    {
176                        match self.adaptation.as_mut().expect("Adaptation must be set") {
177                            Either::Left(adapt) => {
178                                *adapt = DualAverage::new(
179                                    self.options.adapt_options.dual_average,
180                                    hamiltonian.step_size(),
181                                );
182                            }
183                            Either::Right(adapt) => {
184                                *adapt = Adam::new(
185                                    self.options.adapt_options.adam,
186                                    hamiltonian.step_size(),
187                                );
188                            }
189                        }
190                        return Ok(());
191                    }
192                    *hamiltonian.step_size_mut() /= 2.;
193                }
194            }
195        }
196        // If we don't find something better, use the specified initial value
197        *hamiltonian.step_size_mut() = self.options.initial_step;
198        Ok(())
199    }
200
201    pub fn update(&mut self, collector: &AcceptanceRateCollector) {
202        let mean_sym = collector.mean_sym.current();
203        let mean = collector.mean.current();
204        let n_steps = collector.mean.count();
205        self.last_mean_tree_accept = mean;
206        self.last_sym_mean_tree_accept = mean_sym;
207        self.last_n_steps = n_steps;
208        self.last_max_energy_error = collector.max_energy_error;
209    }
210
211    pub fn update_estimator_early(&mut self) {
212        match self.adaptation.as_mut() {
213            None => {}
214            Some(Either::Left(adapt)) => {
215                adapt.advance(self.last_mean_tree_accept, self.options.target_accept);
216            }
217            Some(Either::Right(adapt)) => {
218                adapt.advance(self.last_mean_tree_accept, self.options.target_accept);
219            }
220        }
221    }
222
223    pub fn update_estimator_late(&mut self) {
224        match self.adaptation.as_mut() {
225            None => {}
226            Some(Either::Left(adapt)) => {
227                adapt.advance(self.last_sym_mean_tree_accept, self.options.target_accept);
228            }
229            Some(Either::Right(adapt)) => {
230                adapt.advance(self.last_sym_mean_tree_accept, self.options.target_accept);
231            }
232        }
233    }
234
235    pub fn update_stepsize<M: Math, R: Rng + ?Sized>(
236        &mut self,
237        rng: &mut R,
238        hamiltonian: &mut impl Hamiltonian<M>,
239        use_best_guess: bool,
240    ) {
241        let step_size = match self.adaptation {
242            None => {
243                if let StepSizeAdaptMethod::Fixed(val) = self.options.adapt_options.method {
244                    val
245                } else {
246                    panic!("Adaptation method must be Fixed if adaptation is None")
247                }
248            }
249            Some(Either::Left(ref adapt)) => {
250                if use_best_guess {
251                    adapt.current_step_size_adapted()
252                } else {
253                    adapt.current_step_size()
254                }
255            }
256            Some(Either::Right(ref adapt)) => adapt.current_step_size(),
257        };
258
259        if let Some(jitter) = self.options.jitter {
260            let jitter =
261                rng.sample(Uniform::new(1.0 - jitter, 1.0 + jitter).expect("Invalid jitter"));
262            let jittered_step_size = step_size * jitter;
263            *hamiltonian.step_size_mut() = jittered_step_size;
264        } else {
265            *hamiltonian.step_size_mut() = step_size;
266        }
267    }
268
269    pub fn new_collector(&self) -> AcceptanceRateCollector {
270        AcceptanceRateCollector::new()
271    }
272}
273
274#[derive(Debug, Storable)]
275pub struct Stats {
276    pub step_size_bar: f64,
277    pub mean_tree_accept: f64,
278    pub mean_tree_accept_sym: f64,
279    pub n_steps: u64,
280    pub max_energy_error: f64,
281}
282
283impl<M: Math> SamplerStats<M> for Strategy {
284    type Stats = Stats;
285    type StatsOptions = ();
286
287    fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {
288        Stats {
289            step_size_bar: match self.adaptation {
290                None => {
291                    if let StepSizeAdaptMethod::Fixed(val) = self.options.adapt_options.method {
292                        val
293                    } else {
294                        panic!("Adaptation method must be Fixed if adaptation is None")
295                    }
296                }
297                Some(Either::Left(ref adapt)) => adapt.current_step_size_adapted(),
298                Some(Either::Right(ref adapt)) => adapt.current_step_size(),
299            },
300            mean_tree_accept: self.last_mean_tree_accept,
301            mean_tree_accept_sym: self.last_sym_mean_tree_accept,
302            n_steps: self.last_n_steps,
303            max_energy_error: self.last_max_energy_error,
304        }
305    }
306}
307
308#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
309pub struct StepSizeSettings {
310    /// Target acceptance rate
311    pub target_accept: f64,
312    /// Initial step size
313    pub initial_step: f64,
314    /// Optional jitter to add to step size (randomization)
315    pub jitter: Option<f64>,
316    /// Adaptation options specific to the chosen method
317    pub adapt_options: StepSizeAdaptOptions,
318}
319
320impl Default for StepSizeSettings {
321    fn default() -> Self {
322        Self {
323            target_accept: 0.8,
324            initial_step: 0.1,
325            jitter: Some(0.1),
326            adapt_options: StepSizeAdaptOptions::default(),
327        }
328    }
329}