nuts_rs/stepsize/adam.rs
//! Adam optimizer for step size adaptation.
//!
//! This implements a single-parameter version of the Adam optimizer
//! for adapting the step size in the NUTS algorithm. Unlike dual averaging,
//! Adam maintains both first and second moment estimates of gradients,
//! which can potentially lead to better adaptation in some scenarios.

8use std::f64;
9
10use serde::Serialize;
11
/// Settings for Adam step size adaptation.
///
/// `beta1`, `beta2`, and `epsilon` follow the conventional Adam
/// hyperparameters; `learning_rate` scales the per-iteration move in
/// log-step-size space.
#[derive(Debug, Clone, Copy, Serialize)]
pub struct AdamOptions {
    /// First moment decay rate (default: 0.9)
    pub beta1: f64,
    /// Second moment decay rate (default: 0.999)
    pub beta2: f64,
    /// Small constant for numerical stability (default: 1e-8)
    pub epsilon: f64,
    /// Learning rate (default: 0.05 — deliberately larger than the
    /// conventional Adam default of 0.001; see the `Default` impl)
    pub learning_rate: f64,
}
24
25impl Default for AdamOptions {
26 fn default() -> Self {
27 Self {
28 beta1: 0.9,
29 beta2: 0.999,
30 epsilon: 1e-8,
31 learning_rate: 0.05,
32 }
33 }
34}
35
/// Adam optimizer for step size adaptation.
///
/// This implements the Adam optimizer for a single parameter (the step size).
/// The adaptation takes the acceptance probability statistic and adjusts
/// the step size to reach the target acceptance rate.
#[derive(Clone)]
pub struct Adam {
    /// Natural log of the current step size. Optimization happens in
    /// log space so the step size itself always stays positive.
    log_step: f64,
    /// First moment estimate (exponential moving average of the gradient)
    m: f64,
    /// Second moment estimate (exponential moving average of the squared gradient)
    v: f64,
    /// Iteration counter; drives the bias correction of `m` and `v`
    t: u64,
    /// Adam hyperparameters
    settings: AdamOptions,
}
54
55impl Adam {
56 /// Create a new Adam optimizer with given settings and initial step size
57 pub fn new(settings: AdamOptions, initial_step: f64) -> Self {
58 Self {
59 log_step: initial_step.ln(),
60 m: 0.0,
61 v: 0.0,
62 t: 0,
63 settings,
64 }
65 }
66
67 /// Advance the optimizer by one step using the current acceptance statistic
68 ///
69 /// This updates the step size to move towards the target acceptance rate.
70 /// The error signal is the difference between the target and current acceptance rates.
71 pub fn advance(&mut self, accept_stat: f64, target: f64) {
72 // Compute the error/gradient - we want to minimize (target - accept_stat)²
73 // So gradient is -2 * (target - accept_stat)
74 // We simplify and just use (accept_stat - target) as our gradient
75 let gradient = accept_stat - target;
76
77 // Increment timestep
78 self.t += 1;
79
80 // Update biased first moment estimate
81 self.m = self.settings.beta1 * self.m + (1.0 - self.settings.beta1) * gradient;
82
83 // Update biased second moment estimate
84 self.v = self.settings.beta2 * self.v + (1.0 - self.settings.beta2) * gradient * gradient;
85
86 // Compute bias-corrected first moment estimate
87 let m_hat = self.m / (1.0 - self.settings.beta1.powi(self.t as i32));
88
89 // Compute bias-corrected second moment estimate
90 let v_hat = self.v / (1.0 - self.settings.beta2.powi(self.t as i32));
91
92 // Update log step size
93 // Note: if gradient is positive (accept_stat > target), we should decrease step size
94 // if gradient is negative (accept_stat < target), we should increase step size
95 self.log_step +=
96 self.settings.learning_rate * m_hat / (v_hat.sqrt() + self.settings.epsilon);
97 }
98
99 /// Get the current step size (not adapted)
100 pub fn current_step_size(&self) -> f64 {
101 self.log_step.exp()
102 }
103
104 /// Reset the optimizer with a new initial step size and bias factor
105 #[allow(dead_code)]
106 pub fn reset(&mut self, initial_step: f64, _bias_factor: f64) {
107 self.log_step = initial_step.ln();
108 self.m = 0.0;
109 self.v = 0.0;
110 self.t = 0;
111 }
112}