1use itertools::Either;
4use nuts_derive::Storable;
5use rand::distr::Uniform;
6use rand::{Rng, RngExt};
7use serde::{Deserialize, Serialize};
8
9use super::adam::{Adam, AdamOptions};
10use super::dual_avg::{AcceptanceRateCollector, DualAverage, DualAverageOptions};
11use crate::{
12 Math, NutsError,
13 dynamics::{Direction, Hamiltonian, LeapfrogResult, Point},
14 nuts::{Collector, NutsOptions},
15 sampler_stats::SamplerStats,
16};
17use std::f64;
18use std::fmt::Debug;
19
20#[derive(Debug, Clone, Copy, Serialize, Default, Deserialize)]
22pub enum StepSizeAdaptMethod {
23 #[default]
25 DualAverage,
26 Adam,
28 Fixed(f64),
29}
30
31#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
33pub struct StepSizeAdaptOptions {
34 pub method: StepSizeAdaptMethod,
35 pub dual_average: DualAverageOptions,
37 pub adam: AdamOptions,
39}
40
41impl Default for StepSizeAdaptOptions {
42 fn default() -> Self {
43 Self {
44 method: StepSizeAdaptMethod::DualAverage,
45 dual_average: DualAverageOptions::default(),
46 adam: AdamOptions::default(),
47 }
48 }
49}
50
51pub struct Strategy {
53 adaptation: Option<Either<DualAverage, Adam>>,
55 options: StepSizeSettings,
57 pub last_mean_tree_accept: f64,
59 pub last_sym_mean_tree_accept: f64,
61 pub last_n_steps: u64,
63 pub last_max_energy_error: f64,
65}
66
67impl Strategy {
68 pub fn new(options: StepSizeSettings) -> Self {
69 let adaptation = match options.adapt_options.method {
70 StepSizeAdaptMethod::DualAverage => Some(Either::Left(DualAverage::new(
71 options.adapt_options.dual_average,
72 options.initial_step,
73 ))),
74 StepSizeAdaptMethod::Adam => Some(Either::Right(Adam::new(
75 options.adapt_options.adam,
76 options.initial_step,
77 ))),
78 StepSizeAdaptMethod::Fixed(_) => None,
79 };
80
81 Self {
82 adaptation,
83 options,
84 last_n_steps: 0,
85 last_sym_mean_tree_accept: 0.0,
86 last_mean_tree_accept: 0.0,
87 last_max_energy_error: 0.0,
88 }
89 }
90
91 pub fn init<M: Math, R: Rng + ?Sized, P: Point<M>>(
92 &mut self,
93 math: &mut M,
94 options: &mut NutsOptions,
95 hamiltonian: &mut impl Hamiltonian<M, Point = P>,
96 position: &[f64],
97 rng: &mut R,
98 ) -> Result<(), NutsError> {
99 if let StepSizeAdaptMethod::Fixed(step_size) = self.options.adapt_options.method {
100 *hamiltonian.step_size_mut() = step_size;
101 return Ok(());
102 };
103 let mut state = hamiltonian.init_state(math, position)?;
104 hamiltonian.initialize_trajectory(math, &mut state, true, rng)?;
105
106 let mut collector = AcceptanceRateCollector::new();
107
108 collector.register_init(math, &state, options);
109
110 *hamiltonian.step_size_mut() = self.options.initial_step;
111
112 let state_next = hamiltonian.leapfrog(
113 math,
114 &state,
115 Direction::Forward,
116 1.0,
117 state.point().initial_energy(),
118 1000.0,
119 &mut collector,
120 );
121
122 let LeapfrogResult::Ok(_) = state_next else {
123 return Ok(());
124 };
125
126 let accept_stat = collector.mean.current();
127 let dir = if accept_stat > self.options.target_accept {
128 Direction::Forward
129 } else {
130 Direction::Backward
131 };
132
133 for _ in 0..100 {
134 let mut collector = AcceptanceRateCollector::new();
135 collector.register_init(math, &state, options);
136 let state_next = hamiltonian.leapfrog(
137 math,
138 &state,
139 dir,
140 1.0,
141 state.point().initial_energy(),
142 1000.0,
143 &mut collector,
144 );
145 let LeapfrogResult::Ok(_) = state_next else {
146 *hamiltonian.step_size_mut() = self.options.initial_step;
147 return Ok(());
148 };
149 let accept_stat = collector.mean.current();
150 match dir {
151 Direction::Forward => {
152 if (accept_stat <= self.options.target_accept) | (hamiltonian.step_size() > 1e5)
153 {
154 match self.adaptation.as_mut().expect("Adaptation must be set") {
155 Either::Left(adapt) => {
156 *adapt = DualAverage::new(
157 self.options.adapt_options.dual_average,
158 hamiltonian.step_size(),
159 );
160 }
161 Either::Right(adapt) => {
162 *adapt = Adam::new(
163 self.options.adapt_options.adam,
164 hamiltonian.step_size(),
165 );
166 }
167 }
168 return Ok(());
169 }
170 *hamiltonian.step_size_mut() *= 2.;
171 }
172 Direction::Backward => {
173 if (accept_stat >= self.options.target_accept)
174 | (hamiltonian.step_size() < 1e-10)
175 {
176 match self.adaptation.as_mut().expect("Adaptation must be set") {
177 Either::Left(adapt) => {
178 *adapt = DualAverage::new(
179 self.options.adapt_options.dual_average,
180 hamiltonian.step_size(),
181 );
182 }
183 Either::Right(adapt) => {
184 *adapt = Adam::new(
185 self.options.adapt_options.adam,
186 hamiltonian.step_size(),
187 );
188 }
189 }
190 return Ok(());
191 }
192 *hamiltonian.step_size_mut() /= 2.;
193 }
194 }
195 }
196 *hamiltonian.step_size_mut() = self.options.initial_step;
198 Ok(())
199 }
200
201 pub fn update(&mut self, collector: &AcceptanceRateCollector) {
202 let mean_sym = collector.mean_sym.current();
203 let mean = collector.mean.current();
204 let n_steps = collector.mean.count();
205 self.last_mean_tree_accept = mean;
206 self.last_sym_mean_tree_accept = mean_sym;
207 self.last_n_steps = n_steps;
208 self.last_max_energy_error = collector.max_energy_error;
209 }
210
211 pub fn update_estimator_early(&mut self) {
212 match self.adaptation.as_mut() {
213 None => {}
214 Some(Either::Left(adapt)) => {
215 adapt.advance(self.last_mean_tree_accept, self.options.target_accept);
216 }
217 Some(Either::Right(adapt)) => {
218 adapt.advance(self.last_mean_tree_accept, self.options.target_accept);
219 }
220 }
221 }
222
223 pub fn update_estimator_late(&mut self) {
224 match self.adaptation.as_mut() {
225 None => {}
226 Some(Either::Left(adapt)) => {
227 adapt.advance(self.last_sym_mean_tree_accept, self.options.target_accept);
228 }
229 Some(Either::Right(adapt)) => {
230 adapt.advance(self.last_sym_mean_tree_accept, self.options.target_accept);
231 }
232 }
233 }
234
235 pub fn update_stepsize<M: Math, R: Rng + ?Sized>(
236 &mut self,
237 rng: &mut R,
238 hamiltonian: &mut impl Hamiltonian<M>,
239 use_best_guess: bool,
240 ) {
241 let step_size = match self.adaptation {
242 None => {
243 if let StepSizeAdaptMethod::Fixed(val) = self.options.adapt_options.method {
244 val
245 } else {
246 panic!("Adaptation method must be Fixed if adaptation is None")
247 }
248 }
249 Some(Either::Left(ref adapt)) => {
250 if use_best_guess {
251 adapt.current_step_size_adapted()
252 } else {
253 adapt.current_step_size()
254 }
255 }
256 Some(Either::Right(ref adapt)) => adapt.current_step_size(),
257 };
258
259 if let Some(jitter) = self.options.jitter {
260 let jitter =
261 rng.sample(Uniform::new(1.0 - jitter, 1.0 + jitter).expect("Invalid jitter"));
262 let jittered_step_size = step_size * jitter;
263 *hamiltonian.step_size_mut() = jittered_step_size;
264 } else {
265 *hamiltonian.step_size_mut() = step_size;
266 }
267 }
268
269 pub fn new_collector(&self) -> AcceptanceRateCollector {
270 AcceptanceRateCollector::new()
271 }
272}
273
274#[derive(Debug, Storable)]
275pub struct Stats {
276 pub step_size_bar: f64,
277 pub mean_tree_accept: f64,
278 pub mean_tree_accept_sym: f64,
279 pub n_steps: u64,
280 pub max_energy_error: f64,
281}
282
283impl<M: Math> SamplerStats<M> for Strategy {
284 type Stats = Stats;
285 type StatsOptions = ();
286
287 fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {
288 Stats {
289 step_size_bar: match self.adaptation {
290 None => {
291 if let StepSizeAdaptMethod::Fixed(val) = self.options.adapt_options.method {
292 val
293 } else {
294 panic!("Adaptation method must be Fixed if adaptation is None")
295 }
296 }
297 Some(Either::Left(ref adapt)) => adapt.current_step_size_adapted(),
298 Some(Either::Right(ref adapt)) => adapt.current_step_size(),
299 },
300 mean_tree_accept: self.last_mean_tree_accept,
301 mean_tree_accept_sym: self.last_sym_mean_tree_accept,
302 n_steps: self.last_n_steps,
303 max_energy_error: self.last_max_energy_error,
304 }
305 }
306}
307
308#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
309pub struct StepSizeSettings {
310 pub target_accept: f64,
312 pub initial_step: f64,
314 pub jitter: Option<f64>,
316 pub adapt_options: StepSizeAdaptOptions,
318}
319
320impl Default for StepSizeSettings {
321 fn default() -> Self {
322 Self {
323 target_accept: 0.8,
324 initial_step: 0.1,
325 jitter: Some(0.1),
326 adapt_options: StepSizeAdaptOptions::default(),
327 }
328 }
329}