nuts_rs/stepsize/adam.rs

//! Adam optimizer for step size adaptation.
//!
//! This implements a single-parameter version of the Adam optimizer
//! for adapting the step size in the NUTS algorithm. Unlike dual averaging,
//! Adam maintains both first and second moment estimates of the error signal,
//! which can lead to better adaptation in some scenarios.
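//!
//! The update treats the difference between the observed acceptance
//! statistic and the target acceptance rate as the signal `g`, and adjusts
//! the log step size with the usual Adam recursion (see `advance` below):
//! `m = beta1 * m + (1 - beta1) * g`,
//! `v = beta2 * v + (1 - beta2) * g * g`, and
//! `log_step += learning_rate * m_hat / (sqrt(v_hat) + epsilon)`,
//! where `m_hat` and `v_hat` are the bias-corrected moment estimates.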

use std::f64;

use serde::Serialize;

/// Settings for Adam step size adaptation
#[derive(Debug, Clone, Copy, Serialize)]
pub struct AdamOptions {
    /// First moment decay rate (default: 0.9)
    pub beta1: f64,
    /// Second moment decay rate (default: 0.999)
    pub beta2: f64,
    /// Small constant for numerical stability (default: 1e-8)
    pub epsilon: f64,
    /// Learning rate (default: 0.05)
    pub learning_rate: f64,
}

impl Default for AdamOptions {
    fn default() -> Self {
        Self {
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            learning_rate: 0.05,
        }
    }
}

/// Adam optimizer for step size adaptation.
///
/// This implements the Adam optimizer for a single parameter (the step size).
/// The adaptation takes the acceptance probability statistic and adjusts
/// the step size to reach the target acceptance rate.
#[derive(Clone)]
pub struct Adam {
    /// Current log step size
    log_step: f64,
    /// First moment estimate
    m: f64,
    /// Second moment estimate
    v: f64,
    /// Iteration counter
    t: u64,
    /// Adam settings
    settings: AdamOptions,
}

impl Adam {
    /// Create a new Adam optimizer with given settings and initial step size
    pub fn new(settings: AdamOptions, initial_step: f64) -> Self {
        Self {
            log_step: initial_step.ln(),
            m: 0.0,
            v: 0.0,
            t: 0,
            settings,
        }
    }

    /// Advance the optimizer by one step using the current acceptance statistic
    ///
    /// This updates the step size to move towards the target acceptance rate.
    /// The error signal is the difference between the observed acceptance
    /// statistic and the target acceptance rate.
    pub fn advance(&mut self, accept_stat: f64, target: f64) {
        // Error signal fed to Adam: positive if the acceptance rate is above
        // the target (the step size can grow), negative if it is below
        // (the step size should shrink).
        let gradient = accept_stat - target;

        // Increment timestep
        self.t += 1;

        // Update biased first moment estimate
        self.m = self.settings.beta1 * self.m + (1.0 - self.settings.beta1) * gradient;

        // Update biased second moment estimate
        self.v = self.settings.beta2 * self.v + (1.0 - self.settings.beta2) * gradient * gradient;

        // Compute bias-corrected first moment estimate
        let m_hat = self.m / (1.0 - self.settings.beta1.powi(self.t as i32));

        // Compute bias-corrected second moment estimate
        let v_hat = self.v / (1.0 - self.settings.beta2.powi(self.t as i32));

        // Ascend on the error signal: a positive bias-corrected moment
        // (acceptance above target) increases the log step size, a negative
        // one decreases it.
        self.log_step +=
            self.settings.learning_rate * m_hat / (v_hat.sqrt() + self.settings.epsilon);
    }

    /// Get the current step size
    pub fn current_step_size(&self) -> f64 {
        self.log_step.exp()
    }

    /// Reset the optimizer with a new initial step size; the bias factor
    /// argument is ignored by Adam.
    #[allow(dead_code)]
    pub fn reset(&mut self, initial_step: f64, _bias_factor: f64) {
        self.log_step = initial_step.ln();
        self.m = 0.0;
        self.v = 0.0;
        self.t = 0;
    }
}
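
// A minimal usage sketch, added here as an illustrative test (not part of the
// original module). It assumes only the `Adam` API defined above and checks
// that acceptance statistics consistently above or below the target move the
// step size in the expected direction.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn step_size_moves_toward_target() {
        // Acceptance consistently above the target: the step size should grow.
        let mut adam = Adam::new(AdamOptions::default(), 0.1);
        for _ in 0..50 {
            adam.advance(0.95, 0.8);
        }
        assert!(adam.current_step_size() > 0.1);

        // Acceptance consistently below the target: the step size should shrink.
        let mut adam = Adam::new(AdamOptions::default(), 0.1);
        for _ in 0..50 {
            adam.advance(0.2, 0.8);
        }
        assert!(adam.current_step_size() < 0.1);
    }
}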