use std::f64;
use serde::{Deserialize, Serialize};
/// Hyperparameters for the Adam-based step-size adaptation performed by `Adam`.
///
/// All four values are plain `f64`s; `Default` provides a ready-made set.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct AdamOptions {
    /// Decay rate of the moving average of the gradient (first moment).
    pub beta1: f64,
    /// Decay rate of the moving average of the squared gradient (second moment).
    pub beta2: f64,
    /// Added to the square root of the second-moment estimate so the update never divides
    /// by zero.
    pub epsilon: f64,
    /// Multiplier applied to every update of the log step size.
    pub learning_rate: f64,
}
impl Default for AdamOptions {
fn default() -> Self {
Self {
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
learning_rate: 0.05,
}
}
}
/// Adam-style adaptation of a step size, carried out on the natural-log scale.
///
/// Each update treats `accept_stat - target` as the gradient and moves `log_step` by a
/// bias-corrected Adam step; the step size handed out to callers is `exp(log_step)`.
#[derive(Debug, Clone)]
pub struct Adam {
    /// Current step size, stored as its natural logarithm.
    log_step: f64,
    /// Exponential moving average of the gradient (first moment).
    m: f64,
    /// Exponential moving average of the squared gradient (second moment).
    v: f64,
    /// Number of updates performed so far; drives the bias correction.
    t: u64,
    /// Hyperparameters used by every update.
    settings: AdamOptions,
}
impl Adam {
    /// Creates an optimizer whose step size starts at `initial_step`, with zeroed moments.
    ///
    /// The step size is stored as `initial_step.ln()`, so `initial_step` should be strictly
    /// positive: `0.0` produces a log step of `-inf` and a negative value produces `NaN`.
    pub fn new(settings: AdamOptions, initial_step: f64) -> Self {
        Self {
            log_step: initial_step.ln(),
            m: 0.0,
            v: 0.0,
            t: 0,
            settings,
        }
    }

    /// Performs one Adam update of the log step size.
    ///
    /// The gradient is `accept_stat - target`. If the observed statistic exceeds the target,
    /// the step size grows. If it falls short, the step size shrinks.
    pub fn advance(&mut self, accept_stat: f64, target: f64) {
        let gradient = accept_stat - target;
        self.t += 1;

        let beta1 = self.settings.beta1;
        let beta2 = self.settings.beta2;
        self.m = beta1 * self.m + (1.0 - beta1) * gradient;
        self.v = beta2 * self.v + (1.0 - beta2) * gradient * gradient;

        // `powi` takes an `i32`. A bare `self.t as i32` wraps to a negative exponent once
        // `t` exceeds `i32::MAX`, which sends `beta.powi(..)` to `inf` and zeroes every later
        // update. Clamping first makes the cast lossless. For `beta` in [0, 1) the correction
        // factor then stays in (0, 1].
        let exponent = self.t.min(i32::MAX as u64) as i32;

        // Bias-corrected moment estimates. Both moments start at zero and would otherwise be
        // biased toward zero during the first updates.
        let m_hat = self.m / (1.0 - beta1.powi(exponent));
        let v_hat = self.v / (1.0 - beta2.powi(exponent));

        self.log_step +=
            self.settings.learning_rate * m_hat / (v_hat.sqrt() + self.settings.epsilon);
    }

    /// Returns the current step size, i.e. `exp(log_step)`.
    pub fn current_step_size(&self) -> f64 {
        self.log_step.exp()
    }

    /// Restarts the optimizer at `initial_step`, discarding both moment estimates and the
    /// update counter. The hyperparameters are kept.
    ///
    /// `_bias_factor` is accepted but currently ignored.
    // NOTE(review): the lint allowance suggests this method has no in-crate caller; confirm
    // whether it is still needed as public API.
    #[allow(dead_code)]
    pub fn reset(&mut self, initial_step: f64, _bias_factor: f64) {
        *self = Self::new(self.settings, initial_step);
    }
}