1use itertools::Either;
2use nuts_derive::Storable;
3use rand::Rng;
4use rand_distr::Uniform;
5use serde::Serialize;
6
7use super::adam::{Adam, AdamOptions};
8use super::dual_avg::{AcceptanceRateCollector, DualAverage, DualAverageOptions};
9use crate::{
10 Math, NutsError,
11 hamiltonian::{Direction, Hamiltonian, LeapfrogResult, Point},
12 nuts::{Collector, NutsOptions},
13 sampler_stats::SamplerStats,
14};
15use std::fmt::Debug;
16
/// Algorithm used to adapt the leapfrog step size.
#[derive(Debug, Clone, Copy, Serialize, Default)]
pub enum StepSizeAdaptMethod {
    /// Dual-averaging adaptation (the default).
    #[default]
    DualAverage,
    /// Adam-based adaptation.
    Adam,
    /// No adaptation: the step size is held at the contained value.
    Fixed(f64),
}
27
/// Step-size adaptation algorithm together with the options for each adaptor.
#[derive(Debug, Clone, Copy, Serialize)]
pub struct StepSizeAdaptOptions {
    /// Which adaptation algorithm `Strategy` uses.
    pub method: StepSizeAdaptMethod,
    /// Options for the dual-averaging adaptor; read by `Strategy` only when
    /// `method` is `DualAverage`.
    pub dual_average: DualAverageOptions,
    /// Options for the Adam adaptor; read by `Strategy` only when `method`
    /// is `Adam`.
    pub adam: AdamOptions,
}
37
38impl Default for StepSizeAdaptOptions {
39 fn default() -> Self {
40 Self {
41 method: StepSizeAdaptMethod::DualAverage,
42 dual_average: DualAverageOptions::default(),
43 adam: AdamOptions::default(),
44 }
45 }
46}
47
/// Step-size adaptation state.
///
/// Holds the active adaptor (dual averaging, Adam, or none for a fixed step
/// size) and the acceptance statistics of the most recent trajectory
/// recorded by `update`. `update_estimator_early` and `update_estimator_late`
/// feed those statistics into the adaptor.
pub struct Strategy {
    /// `Left` for dual averaging, `Right` for Adam, `None` when the configured
    /// method is `StepSizeAdaptMethod::Fixed`.
    adaptation: Option<Either<DualAverage, Adam>>,
    /// Settings this strategy was constructed with.
    options: StepSizeSettings,
    /// Mean acceptance statistic (`collector.mean`) from the last `update`.
    pub last_mean_tree_accept: f64,
    /// Symmetric mean acceptance statistic (`collector.mean_sym`) from the last `update`.
    pub last_sym_mean_tree_accept: f64,
    /// `collector.mean.count()` from the last `update`; presumably the number
    /// of leapfrog steps in that trajectory.
    pub last_n_steps: u64,
}
61
62impl Strategy {
63 pub fn new(options: StepSizeSettings) -> Self {
64 let adaptation = match options.adapt_options.method {
65 StepSizeAdaptMethod::DualAverage => Some(Either::Left(DualAverage::new(
66 options.adapt_options.dual_average,
67 options.initial_step,
68 ))),
69 StepSizeAdaptMethod::Adam => Some(Either::Right(Adam::new(
70 options.adapt_options.adam,
71 options.initial_step,
72 ))),
73 StepSizeAdaptMethod::Fixed(_) => None,
74 };
75
76 Self {
77 adaptation,
78 options,
79 last_n_steps: 0,
80 last_sym_mean_tree_accept: 0.0,
81 last_mean_tree_accept: 0.0,
82 }
83 }
84
85 pub fn init<M: Math, R: Rng + ?Sized, P: Point<M>>(
86 &mut self,
87 math: &mut M,
88 options: &mut NutsOptions,
89 hamiltonian: &mut impl Hamiltonian<M, Point = P>,
90 position: &[f64],
91 rng: &mut R,
92 ) -> Result<(), NutsError> {
93 if let StepSizeAdaptMethod::Fixed(step_size) = self.options.adapt_options.method {
94 *hamiltonian.step_size_mut() = step_size;
95 return Ok(());
96 };
97 let mut state = hamiltonian.init_state(math, position)?;
98 hamiltonian.initialize_trajectory(math, &mut state, rng)?;
99
100 let mut collector = AcceptanceRateCollector::new();
101
102 collector.register_init(math, &state, options);
103
104 *hamiltonian.step_size_mut() = self.options.initial_step;
105
106 let state_next = hamiltonian.leapfrog(math, &state, Direction::Forward, &mut collector);
107
108 let LeapfrogResult::Ok(_) = state_next else {
109 return Ok(());
110 };
111
112 let accept_stat = collector.mean.current();
113 let dir = if accept_stat > self.options.target_accept {
114 Direction::Forward
115 } else {
116 Direction::Backward
117 };
118
119 for _ in 0..100 {
120 let mut collector = AcceptanceRateCollector::new();
121 collector.register_init(math, &state, options);
122 let state_next = hamiltonian.leapfrog(math, &state, dir, &mut collector);
123 let LeapfrogResult::Ok(_) = state_next else {
124 *hamiltonian.step_size_mut() = self.options.initial_step;
125 return Ok(());
126 };
127 let accept_stat = collector.mean.current();
128 match dir {
129 Direction::Forward => {
130 if (accept_stat <= self.options.target_accept) | (hamiltonian.step_size() > 1e5)
131 {
132 match self.adaptation.as_mut().expect("Adaptation must be set") {
133 Either::Left(adapt) => {
134 *adapt = DualAverage::new(
135 self.options.adapt_options.dual_average,
136 hamiltonian.step_size(),
137 );
138 }
139 Either::Right(adapt) => {
140 *adapt = Adam::new(
141 self.options.adapt_options.adam,
142 hamiltonian.step_size(),
143 );
144 }
145 }
146 return Ok(());
147 }
148 *hamiltonian.step_size_mut() *= 2.;
149 }
150 Direction::Backward => {
151 if (accept_stat >= self.options.target_accept)
152 | (hamiltonian.step_size() < 1e-10)
153 {
154 match self.adaptation.as_mut().expect("Adaptation must be set") {
155 Either::Left(adapt) => {
156 *adapt = DualAverage::new(
157 self.options.adapt_options.dual_average,
158 hamiltonian.step_size(),
159 );
160 }
161 Either::Right(adapt) => {
162 *adapt = Adam::new(
163 self.options.adapt_options.adam,
164 hamiltonian.step_size(),
165 );
166 }
167 }
168 return Ok(());
169 }
170 *hamiltonian.step_size_mut() /= 2.;
171 }
172 }
173 }
174 *hamiltonian.step_size_mut() = self.options.initial_step;
176 Ok(())
177 }
178
179 pub fn update(&mut self, collector: &AcceptanceRateCollector) {
180 let mean_sym = collector.mean_sym.current();
181 let mean = collector.mean.current();
182 let n_steps = collector.mean.count();
183 self.last_mean_tree_accept = mean;
184 self.last_sym_mean_tree_accept = mean_sym;
185 self.last_n_steps = n_steps;
186 }
187
188 pub fn update_estimator_early(&mut self) {
189 match self.adaptation.as_mut() {
190 None => {}
191 Some(Either::Left(adapt)) => {
192 adapt.advance(self.last_mean_tree_accept, self.options.target_accept);
193 }
194 Some(Either::Right(adapt)) => {
195 adapt.advance(self.last_mean_tree_accept, self.options.target_accept);
196 }
197 }
198 }
199
200 pub fn update_estimator_late(&mut self) {
201 match self.adaptation.as_mut() {
202 None => {}
203 Some(Either::Left(adapt)) => {
204 adapt.advance(self.last_sym_mean_tree_accept, self.options.target_accept);
205 }
206 Some(Either::Right(adapt)) => {
207 adapt.advance(self.last_sym_mean_tree_accept, self.options.target_accept);
208 }
209 }
210 }
211
212 pub fn update_stepsize<M: Math, R: Rng + ?Sized>(
213 &mut self,
214 rng: &mut R,
215 hamiltonian: &mut impl Hamiltonian<M>,
216 use_best_guess: bool,
217 ) {
218 let step_size = match self.adaptation {
219 None => {
220 if let StepSizeAdaptMethod::Fixed(val) = self.options.adapt_options.method {
221 val
222 } else {
223 panic!("Adaptation method must be Fixed if adaptation is None")
224 }
225 }
226 Some(Either::Left(ref adapt)) => {
227 if use_best_guess {
228 adapt.current_step_size_adapted()
229 } else {
230 adapt.current_step_size()
231 }
232 }
233 Some(Either::Right(ref adapt)) => adapt.current_step_size(),
234 };
235
236 if let Some(jitter) = self.options.jitter {
237 let jitter =
238 rng.sample(Uniform::new(1.0 - jitter, 1.0 + jitter).expect("Invalid jitter"));
239 let jittered_step_size = step_size * jitter;
240 *hamiltonian.step_size_mut() = jittered_step_size;
241 } else {
242 *hamiltonian.step_size_mut() = step_size;
243 }
244 }
245
246 pub fn new_collector(&self) -> AcceptanceRateCollector {
247 AcceptanceRateCollector::new()
248 }
249}
250
/// Step-size statistics exported by `Strategy` through `SamplerStats`.
#[derive(Debug, Storable)]
pub struct Stats {
    /// The step-size value the strategy reports:
    /// - Dual averaging: `current_step_size_adapted()`.
    /// - Adam: `current_step_size()`.
    /// - Fixed: the fixed value.
    pub step_size_bar: f64,
    /// Mean acceptance statistic of the last recorded trajectory.
    pub mean_tree_accept: f64,
    /// Symmetric mean acceptance statistic of the last recorded trajectory.
    pub mean_tree_accept_sym: f64,
    /// Count from the last recorded collector (presumably leapfrog steps).
    pub n_steps: u64,
}
258
259impl<M: Math> SamplerStats<M> for Strategy {
260 type Stats = Stats;
261 type StatsOptions = ();
262
263 fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {
264 Stats {
265 step_size_bar: match self.adaptation {
266 None => {
267 if let StepSizeAdaptMethod::Fixed(val) = self.options.adapt_options.method {
268 val
269 } else {
270 panic!("Adaptation method must be Fixed if adaptation is None")
271 }
272 }
273 Some(Either::Left(ref adapt)) => adapt.current_step_size_adapted(),
274 Some(Either::Right(ref adapt)) => adapt.current_step_size(),
275 },
276 mean_tree_accept: self.last_mean_tree_accept,
277 mean_tree_accept_sym: self.last_sym_mean_tree_accept,
278 n_steps: self.last_n_steps,
279 }
280 }
281}
282
/// User-facing settings for step-size selection and adaptation.
#[derive(Debug, Clone, Copy, Serialize)]
pub struct StepSizeSettings {
    /// Acceptance statistic the adaptor steers towards (default 0.8).
    pub target_accept: f64,
    /// Step size used to seed the adaptors and the initial search (default 0.1).
    pub initial_step: f64,
    /// Half-width of the relative jitter on the step size. The step size is
    /// multiplied by a factor drawn uniformly from `[1 - jitter, 1 + jitter)`.
    /// `None` disables jitter. Default `Some(0.1)`.
    pub jitter: Option<f64>,
    /// Adaptation algorithm and its per-algorithm options.
    pub adapt_options: StepSizeAdaptOptions,
}
294
295impl Default for StepSizeSettings {
296 fn default() -> Self {
297 Self {
298 target_accept: 0.8,
299 initial_step: 0.1,
300 jitter: Some(0.1),
301 adapt_options: StepSizeAdaptOptions::default(),
302 }
303 }
304}