use itertools::Either;
use nuts_derive::Storable;
use rand::distr::Uniform;
use rand::{Rng, RngExt};
use serde::{Deserialize, Serialize};
use super::adam::{Adam, AdamOptions};
use super::dual_avg::{AcceptanceRateCollector, DualAverage, DualAverageOptions};
use crate::{
Math, NutsError,
dynamics::{Direction, Hamiltonian, LeapfrogResult, Point},
nuts::{Collector, NutsOptions},
sampler_stats::SamplerStats,
};
use std::f64;
use std::fmt::Debug;
/// Algorithm used to adapt the leapfrog step size.
#[derive(Debug, Clone, Copy, Serialize, Default, Deserialize)]
pub enum StepSizeAdaptMethod {
    /// Dual averaging adaptation (the default variant).
    #[default]
    DualAverage,
    /// Adam-based adaptation.
    Adam,
    /// No adaptation: the step size is held at the contained value.
    /// Jitter from `StepSizeSettings::jitter` is still applied when set.
    Fixed(f64),
}
/// Selects the step size adaptation method and carries the options for each adapter.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct StepSizeAdaptOptions {
    /// Which adaptation algorithm to run (or a fixed step size).
    pub method: StepSizeAdaptMethod,
    /// Options used when `method` is `StepSizeAdaptMethod::DualAverage`.
    pub dual_average: DualAverageOptions,
    /// Options used when `method` is `StepSizeAdaptMethod::Adam`.
    pub adam: AdamOptions,
}
impl Default for StepSizeAdaptOptions {
    /// Dual averaging, with default options for both the dual averaging and Adam adapters.
    fn default() -> Self {
        let dual_average = DualAverageOptions::default();
        let adam = AdamOptions::default();
        Self {
            method: StepSizeAdaptMethod::DualAverage,
            dual_average,
            adam,
        }
    }
}
/// Step size adaptation state.
///
/// Holds the configured adapter (if any) together with the acceptance statistics
/// recorded by the most recent call to `update`.
pub struct Strategy {
    /// `None` when the method is `StepSizeAdaptMethod::Fixed`.
    /// Otherwise it holds dual averaging (`Left`) or Adam (`Right`).
    adaptation: Option<Either<DualAverage, Adam>>,
    /// Settings this strategy was constructed with.
    options: StepSizeSettings,
    /// `collector.mean.current()` from the last collector passed to `update`.
    pub last_mean_tree_accept: f64,
    /// `collector.mean_sym.current()` from the last collector passed to `update`.
    pub last_sym_mean_tree_accept: f64,
    /// `collector.mean.count()` from the last `update`.
    /// Presumably the number of leapfrog steps in the trajectory; verify against the collector.
    pub last_n_steps: u64,
    /// `collector.max_energy_error` from the last collector passed to `update`.
    pub last_max_energy_error: f64,
}
impl Strategy {
pub fn new(options: StepSizeSettings) -> Self {
let adaptation = match options.adapt_options.method {
StepSizeAdaptMethod::DualAverage => Some(Either::Left(DualAverage::new(
options.adapt_options.dual_average,
options.initial_step,
))),
StepSizeAdaptMethod::Adam => Some(Either::Right(Adam::new(
options.adapt_options.adam,
options.initial_step,
))),
StepSizeAdaptMethod::Fixed(_) => None,
};
Self {
adaptation,
options,
last_n_steps: 0,
last_sym_mean_tree_accept: 0.0,
last_mean_tree_accept: 0.0,
last_max_energy_error: 0.0,
}
}
pub fn init<M: Math, R: Rng + ?Sized, P: Point<M>>(
&mut self,
math: &mut M,
options: &mut NutsOptions,
hamiltonian: &mut impl Hamiltonian<M, Point = P>,
position: &[f64],
rng: &mut R,
) -> Result<(), NutsError> {
if let StepSizeAdaptMethod::Fixed(step_size) = self.options.adapt_options.method {
*hamiltonian.step_size_mut() = step_size;
return Ok(());
};
let mut state = hamiltonian.init_state(math, position)?;
hamiltonian.initialize_trajectory(math, &mut state, true, rng)?;
let mut collector = AcceptanceRateCollector::new();
collector.register_init(math, &state, options);
*hamiltonian.step_size_mut() = self.options.initial_step;
let state_next = hamiltonian.leapfrog(
math,
&state,
Direction::Forward,
1.0,
state.point().initial_energy(),
1000.0,
&mut collector,
);
let LeapfrogResult::Ok(_) = state_next else {
return Ok(());
};
let accept_stat = collector.mean.current();
let dir = if accept_stat > self.options.target_accept {
Direction::Forward
} else {
Direction::Backward
};
for _ in 0..100 {
let mut collector = AcceptanceRateCollector::new();
collector.register_init(math, &state, options);
let state_next = hamiltonian.leapfrog(
math,
&state,
dir,
1.0,
state.point().initial_energy(),
1000.0,
&mut collector,
);
let LeapfrogResult::Ok(_) = state_next else {
*hamiltonian.step_size_mut() = self.options.initial_step;
return Ok(());
};
let accept_stat = collector.mean.current();
match dir {
Direction::Forward => {
if (accept_stat <= self.options.target_accept) | (hamiltonian.step_size() > 1e5)
{
match self.adaptation.as_mut().expect("Adaptation must be set") {
Either::Left(adapt) => {
*adapt = DualAverage::new(
self.options.adapt_options.dual_average,
hamiltonian.step_size(),
);
}
Either::Right(adapt) => {
*adapt = Adam::new(
self.options.adapt_options.adam,
hamiltonian.step_size(),
);
}
}
return Ok(());
}
*hamiltonian.step_size_mut() *= 2.;
}
Direction::Backward => {
if (accept_stat >= self.options.target_accept)
| (hamiltonian.step_size() < 1e-10)
{
match self.adaptation.as_mut().expect("Adaptation must be set") {
Either::Left(adapt) => {
*adapt = DualAverage::new(
self.options.adapt_options.dual_average,
hamiltonian.step_size(),
);
}
Either::Right(adapt) => {
*adapt = Adam::new(
self.options.adapt_options.adam,
hamiltonian.step_size(),
);
}
}
return Ok(());
}
*hamiltonian.step_size_mut() /= 2.;
}
}
}
*hamiltonian.step_size_mut() = self.options.initial_step;
Ok(())
}
pub fn update(&mut self, collector: &AcceptanceRateCollector) {
let mean_sym = collector.mean_sym.current();
let mean = collector.mean.current();
let n_steps = collector.mean.count();
self.last_mean_tree_accept = mean;
self.last_sym_mean_tree_accept = mean_sym;
self.last_n_steps = n_steps;
self.last_max_energy_error = collector.max_energy_error;
}
pub fn update_estimator_early(&mut self) {
match self.adaptation.as_mut() {
None => {}
Some(Either::Left(adapt)) => {
adapt.advance(self.last_mean_tree_accept, self.options.target_accept);
}
Some(Either::Right(adapt)) => {
adapt.advance(self.last_mean_tree_accept, self.options.target_accept);
}
}
}
pub fn update_estimator_late(&mut self) {
match self.adaptation.as_mut() {
None => {}
Some(Either::Left(adapt)) => {
adapt.advance(self.last_sym_mean_tree_accept, self.options.target_accept);
}
Some(Either::Right(adapt)) => {
adapt.advance(self.last_sym_mean_tree_accept, self.options.target_accept);
}
}
}
pub fn update_stepsize<M: Math, R: Rng + ?Sized>(
&mut self,
rng: &mut R,
hamiltonian: &mut impl Hamiltonian<M>,
use_best_guess: bool,
) {
let step_size = match self.adaptation {
None => {
if let StepSizeAdaptMethod::Fixed(val) = self.options.adapt_options.method {
val
} else {
panic!("Adaptation method must be Fixed if adaptation is None")
}
}
Some(Either::Left(ref adapt)) => {
if use_best_guess {
adapt.current_step_size_adapted()
} else {
adapt.current_step_size()
}
}
Some(Either::Right(ref adapt)) => adapt.current_step_size(),
};
if let Some(jitter) = self.options.jitter {
let jitter =
rng.sample(Uniform::new(1.0 - jitter, 1.0 + jitter).expect("Invalid jitter"));
let jittered_step_size = step_size * jitter;
*hamiltonian.step_size_mut() = jittered_step_size;
} else {
*hamiltonian.step_size_mut() = step_size;
}
}
pub fn new_collector(&self) -> AcceptanceRateCollector {
AcceptanceRateCollector::new()
}
}
/// Per-draw step size diagnostics produced by `Strategy::extract_stats`.
#[derive(Debug, Storable)]
pub struct Stats {
    /// Step size reported by the adapter:
    /// - dual averaging: `current_step_size_adapted()`
    /// - Adam: `current_step_size()`
    /// - `Fixed`: the fixed value
    pub step_size_bar: f64,
    /// Mean acceptance statistic recorded by the last `update`.
    pub mean_tree_accept: f64,
    /// Symmetric mean acceptance statistic recorded by the last `update`.
    pub mean_tree_accept_sym: f64,
    /// Count of the acceptance mean recorded by the last `update`.
    pub n_steps: u64,
    /// Maximum energy error recorded by the last `update`.
    pub max_energy_error: f64,
}
impl<M: Math> SamplerStats<M> for Strategy {
    type Stats = Stats;
    type StatsOptions = ();

    /// Snapshots the step size diagnostics recorded by the most recent `update`.
    ///
    /// # Panics
    ///
    /// Panics if no adapter exists while the method is not `Fixed`.
    fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {
        let step_size_bar = match &self.adaptation {
            Some(Either::Left(dual_avg)) => dual_avg.current_step_size_adapted(),
            Some(Either::Right(adam)) => adam.current_step_size(),
            None => match self.options.adapt_options.method {
                StepSizeAdaptMethod::Fixed(fixed) => fixed,
                _ => panic!("Adaptation method must be Fixed if adaptation is None"),
            },
        };
        Stats {
            step_size_bar,
            mean_tree_accept: self.last_mean_tree_accept,
            mean_tree_accept_sym: self.last_sym_mean_tree_accept,
            n_steps: self.last_n_steps,
            max_energy_error: self.last_max_energy_error,
        }
    }
}
/// User-facing configuration for step size selection and adaptation.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct StepSizeSettings {
    /// Acceptance statistic the adapter steers toward.
    /// `Strategy::init` also uses it to pick the initial search direction.
    pub target_accept: f64,
    /// Step size the adapter and the initial search start from.
    pub initial_step: f64,
    /// If `Some(j)`, each step size is multiplied by a factor drawn uniformly
    /// from `[1 - j, 1 + j)`. `None` disables jittering.
    pub jitter: Option<f64>,
    /// Adaptation method and per-method options.
    pub adapt_options: StepSizeAdaptOptions,
}
impl Default for StepSizeSettings {
    /// Default settings:
    /// - target acceptance of 0.8
    /// - initial step size of 0.1
    /// - ±10% jitter
    /// - default adaptation options
    fn default() -> Self {
        let adapt_options = StepSizeAdaptOptions::default();
        Self {
            initial_step: 0.1,
            target_accept: 0.8,
            jitter: Some(0.1),
            adapt_options,
        }
    }
}