nuts_rs/stepsize/
adapt.rs

1use itertools::Either;
2use nuts_derive::Storable;
3use rand::Rng;
4use rand_distr::Uniform;
5use serde::Serialize;
6
7use super::adam::{Adam, AdamOptions};
8use super::dual_avg::{AcceptanceRateCollector, DualAverage, DualAverageOptions};
9use crate::{
10    Math, NutsError,
11    hamiltonian::{Direction, Hamiltonian, LeapfrogResult, Point},
12    nuts::{Collector, NutsOptions},
13    sampler_stats::SamplerStats,
14};
15use std::fmt::Debug;
16
17/// Method used for step size adaptation
18#[derive(Debug, Clone, Copy, Serialize, Default)]
19pub enum StepSizeAdaptMethod {
20    /// Use dual averaging for step size adaptation (default)
21    #[default]
22    DualAverage,
23    /// Use Adam optimizer for step size adaptation
24    Adam,
25    Fixed(f64),
26}
27
28/// Options for step size adaptation
29#[derive(Debug, Clone, Copy, Serialize)]
30pub struct StepSizeAdaptOptions {
31    pub method: StepSizeAdaptMethod,
32    /// Dual averaging adaptation options
33    pub dual_average: DualAverageOptions,
34    /// Adam optimizer adaptation options
35    pub adam: AdamOptions,
36}
37
38impl Default for StepSizeAdaptOptions {
39    fn default() -> Self {
40        Self {
41            method: StepSizeAdaptMethod::DualAverage,
42            dual_average: DualAverageOptions::default(),
43            adam: AdamOptions::default(),
44        }
45    }
46}
47
48/// Step size adaptation strategy
49pub struct Strategy {
50    /// The step size adaptation method being used
51    adaptation: Option<Either<DualAverage, Adam>>,
52    /// Settings for step size adaptation
53    options: StepSizeSettings,
54    /// Last mean tree accept rate
55    pub last_mean_tree_accept: f64,
56    /// Last symmetric mean tree accept rate
57    pub last_sym_mean_tree_accept: f64,
58    /// Last number of steps
59    pub last_n_steps: u64,
60}
61
62impl Strategy {
63    pub fn new(options: StepSizeSettings) -> Self {
64        let adaptation = match options.adapt_options.method {
65            StepSizeAdaptMethod::DualAverage => Some(Either::Left(DualAverage::new(
66                options.adapt_options.dual_average,
67                options.initial_step,
68            ))),
69            StepSizeAdaptMethod::Adam => Some(Either::Right(Adam::new(
70                options.adapt_options.adam,
71                options.initial_step,
72            ))),
73            StepSizeAdaptMethod::Fixed(_) => None,
74        };
75
76        Self {
77            adaptation,
78            options,
79            last_n_steps: 0,
80            last_sym_mean_tree_accept: 0.0,
81            last_mean_tree_accept: 0.0,
82        }
83    }
84
85    pub fn init<M: Math, R: Rng + ?Sized, P: Point<M>>(
86        &mut self,
87        math: &mut M,
88        options: &mut NutsOptions,
89        hamiltonian: &mut impl Hamiltonian<M, Point = P>,
90        position: &[f64],
91        rng: &mut R,
92    ) -> Result<(), NutsError> {
93        if let StepSizeAdaptMethod::Fixed(step_size) = self.options.adapt_options.method {
94            *hamiltonian.step_size_mut() = step_size;
95            return Ok(());
96        };
97        let mut state = hamiltonian.init_state(math, position)?;
98        hamiltonian.initialize_trajectory(math, &mut state, rng)?;
99
100        let mut collector = AcceptanceRateCollector::new();
101
102        collector.register_init(math, &state, options);
103
104        *hamiltonian.step_size_mut() = self.options.initial_step;
105
106        let state_next = hamiltonian.leapfrog(math, &state, Direction::Forward, &mut collector);
107
108        let LeapfrogResult::Ok(_) = state_next else {
109            return Ok(());
110        };
111
112        let accept_stat = collector.mean.current();
113        let dir = if accept_stat > self.options.target_accept {
114            Direction::Forward
115        } else {
116            Direction::Backward
117        };
118
119        for _ in 0..100 {
120            let mut collector = AcceptanceRateCollector::new();
121            collector.register_init(math, &state, options);
122            let state_next = hamiltonian.leapfrog(math, &state, dir, &mut collector);
123            let LeapfrogResult::Ok(_) = state_next else {
124                *hamiltonian.step_size_mut() = self.options.initial_step;
125                return Ok(());
126            };
127            let accept_stat = collector.mean.current();
128            match dir {
129                Direction::Forward => {
130                    if (accept_stat <= self.options.target_accept) | (hamiltonian.step_size() > 1e5)
131                    {
132                        match self.adaptation.as_mut().expect("Adaptation must be set") {
133                            Either::Left(adapt) => {
134                                *adapt = DualAverage::new(
135                                    self.options.adapt_options.dual_average,
136                                    hamiltonian.step_size(),
137                                );
138                            }
139                            Either::Right(adapt) => {
140                                *adapt = Adam::new(
141                                    self.options.adapt_options.adam,
142                                    hamiltonian.step_size(),
143                                );
144                            }
145                        }
146                        return Ok(());
147                    }
148                    *hamiltonian.step_size_mut() *= 2.;
149                }
150                Direction::Backward => {
151                    if (accept_stat >= self.options.target_accept)
152                        | (hamiltonian.step_size() < 1e-10)
153                    {
154                        match self.adaptation.as_mut().expect("Adaptation must be set") {
155                            Either::Left(adapt) => {
156                                *adapt = DualAverage::new(
157                                    self.options.adapt_options.dual_average,
158                                    hamiltonian.step_size(),
159                                );
160                            }
161                            Either::Right(adapt) => {
162                                *adapt = Adam::new(
163                                    self.options.adapt_options.adam,
164                                    hamiltonian.step_size(),
165                                );
166                            }
167                        }
168                        return Ok(());
169                    }
170                    *hamiltonian.step_size_mut() /= 2.;
171                }
172            }
173        }
174        // If we don't find something better, use the specified initial value
175        *hamiltonian.step_size_mut() = self.options.initial_step;
176        Ok(())
177    }
178
179    pub fn update(&mut self, collector: &AcceptanceRateCollector) {
180        let mean_sym = collector.mean_sym.current();
181        let mean = collector.mean.current();
182        let n_steps = collector.mean.count();
183        self.last_mean_tree_accept = mean;
184        self.last_sym_mean_tree_accept = mean_sym;
185        self.last_n_steps = n_steps;
186    }
187
188    pub fn update_estimator_early(&mut self) {
189        match self.adaptation.as_mut() {
190            None => {}
191            Some(Either::Left(adapt)) => {
192                adapt.advance(self.last_mean_tree_accept, self.options.target_accept);
193            }
194            Some(Either::Right(adapt)) => {
195                adapt.advance(self.last_mean_tree_accept, self.options.target_accept);
196            }
197        }
198    }
199
200    pub fn update_estimator_late(&mut self) {
201        match self.adaptation.as_mut() {
202            None => {}
203            Some(Either::Left(adapt)) => {
204                adapt.advance(self.last_sym_mean_tree_accept, self.options.target_accept);
205            }
206            Some(Either::Right(adapt)) => {
207                adapt.advance(self.last_sym_mean_tree_accept, self.options.target_accept);
208            }
209        }
210    }
211
212    pub fn update_stepsize<M: Math, R: Rng + ?Sized>(
213        &mut self,
214        rng: &mut R,
215        hamiltonian: &mut impl Hamiltonian<M>,
216        use_best_guess: bool,
217    ) {
218        let step_size = match self.adaptation {
219            None => {
220                if let StepSizeAdaptMethod::Fixed(val) = self.options.adapt_options.method {
221                    val
222                } else {
223                    panic!("Adaptation method must be Fixed if adaptation is None")
224                }
225            }
226            Some(Either::Left(ref adapt)) => {
227                if use_best_guess {
228                    adapt.current_step_size_adapted()
229                } else {
230                    adapt.current_step_size()
231                }
232            }
233            Some(Either::Right(ref adapt)) => adapt.current_step_size(),
234        };
235
236        if let Some(jitter) = self.options.jitter {
237            let jitter =
238                rng.sample(Uniform::new(1.0 - jitter, 1.0 + jitter).expect("Invalid jitter"));
239            let jittered_step_size = step_size * jitter;
240            *hamiltonian.step_size_mut() = jittered_step_size;
241        } else {
242            *hamiltonian.step_size_mut() = step_size;
243        }
244    }
245
246    pub fn new_collector(&self) -> AcceptanceRateCollector {
247        AcceptanceRateCollector::new()
248    }
249}
250
251#[derive(Debug, Storable)]
252pub struct Stats {
253    pub step_size_bar: f64,
254    pub mean_tree_accept: f64,
255    pub mean_tree_accept_sym: f64,
256    pub n_steps: u64,
257}
258
259impl<M: Math> SamplerStats<M> for Strategy {
260    type Stats = Stats;
261    type StatsOptions = ();
262
263    fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {
264        Stats {
265            step_size_bar: match self.adaptation {
266                None => {
267                    if let StepSizeAdaptMethod::Fixed(val) = self.options.adapt_options.method {
268                        val
269                    } else {
270                        panic!("Adaptation method must be Fixed if adaptation is None")
271                    }
272                }
273                Some(Either::Left(ref adapt)) => adapt.current_step_size_adapted(),
274                Some(Either::Right(ref adapt)) => adapt.current_step_size(),
275            },
276            mean_tree_accept: self.last_mean_tree_accept,
277            mean_tree_accept_sym: self.last_sym_mean_tree_accept,
278            n_steps: self.last_n_steps,
279        }
280    }
281}
282
283#[derive(Debug, Clone, Copy, Serialize)]
284pub struct StepSizeSettings {
285    /// Target acceptance rate
286    pub target_accept: f64,
287    /// Initial step size
288    pub initial_step: f64,
289    /// Optional jitter to add to step size (randomization)
290    pub jitter: Option<f64>,
291    /// Adaptation options specific to the chosen method
292    pub adapt_options: StepSizeAdaptOptions,
293}
294
295impl Default for StepSizeSettings {
296    fn default() -> Self {
297        Self {
298            target_accept: 0.8,
299            initial_step: 0.1,
300            jitter: Some(0.1),
301            adapt_options: StepSizeAdaptOptions::default(),
302        }
303    }
304}