use crate::{common::matrix::DMat, error::NetworkError, LearningRateScheduler};
use serde::{Deserialize, Serialize};
use typetag;
use super::{Optimizer, OptimizerConfig, OptimizerConfigClone};
/// Runtime state for the Adam optimizer of a single layer: the hyperparameters
/// plus exponential moving averages of the gradients (first moment) and of the
/// squared gradients (second moment), for both weights and biases.
#[derive(Serialize, Deserialize, Clone)]
struct AdamOptimizer {
    // Hyperparameters (learning rate, betas, epsilon, optional scheduler).
    config: AdamConfig,
    // First-moment (mean) EMA of the weight gradients.
    moment1_weights: DMat,
    // Second-moment (uncentered variance) EMA of the weight gradients.
    moment2_weights: DMat,
    // First-moment EMA of the bias gradients.
    moment1_biases: DMat,
    // Second-moment EMA of the bias gradients.
    moment2_biases: DMat,
    // Timestep counter; incremented once per `update` call.
    t: usize,
    // Cached bias-correction denominator 1 - beta1^t for the current step.
    m_hat_factor: f32,
    // Cached bias-correction denominator 1 - beta2^t for the current step.
    v_hat_factor: f32,
}
impl AdamOptimizer {
    /// Builds an optimizer around `config` with empty (0x0) moment matrices;
    /// `initialize` must size them before the first `update`.
    pub(crate) fn new(config: AdamConfig) -> Self {
        Self {
            config,
            t: 0,
            m_hat_factor: 1.0,
            v_hat_factor: 1.0,
            moment1_weights: DMat::zeros(0, 0),
            moment2_weights: DMat::zeros(0, 0),
            moment1_biases: DMat::zeros(0, 0),
            moment2_biases: DMat::zeros(0, 0),
        }
    }

    /// Folds the current gradients into the moment EMAs:
    /// m <- beta1*m + (1-beta1)*g and v <- beta2*v + (1-beta2)*g^2.
    fn update_moments(&mut self, d_weights: &DMat, d_biases: &DMat) {
        // Copy the decay rates into locals so the closures capture plain f32s.
        let b1 = self.config.beta1;
        let b2 = self.config.beta2;
        self.moment1_weights
            .apply_with_indices(|i, j, m| *m = b1 * *m + (1.0 - b1) * d_weights.at(i, j));
        self.moment2_weights.apply_with_indices(|i, j, v| {
            let g = d_weights.at(i, j);
            *v = b2 * *v + (1.0 - b2) * g * g;
        });
        self.moment1_biases
            .apply_with_indices(|i, j, m| *m = b1 * *m + (1.0 - b1) * d_biases.at(i, j));
        self.moment2_biases.apply_with_indices(|i, j, v| {
            let g = d_biases.at(i, j);
            *v = b2 * *v + (1.0 - b2) * g * g;
        });
    }

    /// Applies the Adam step to each parameter using bias-corrected moments:
    /// p <- p - step_size * m_hat / (sqrt(v_hat) + epsilon).
    fn update_parameters(&self, weights: &mut DMat, biases: &mut DMat, step_size: f32) {
        let m_corr = self.m_hat_factor;
        let v_corr = self.v_hat_factor;
        let eps = self.config.epsilon;
        weights.apply_with_indices(|i, j, w| {
            let m_hat = self.moment1_weights.at(i, j) / m_corr;
            let v_hat = self.moment2_weights.at(i, j) / v_corr;
            *w -= step_size * m_hat / (v_hat.sqrt() + eps);
        });
        biases.apply_with_indices(|i, j, b| {
            let m_hat = self.moment1_biases.at(i, j) / m_corr;
            let v_hat = self.moment2_biases.at(i, j) / v_corr;
            *b -= step_size * m_hat / (v_hat.sqrt() + eps);
        });
    }
}
#[typetag::serde]
impl Optimizer for AdamOptimizer {
    /// Sizes the four moment matrices to match this layer's weight and bias
    /// shapes, starting every estimate at zero.
    fn initialize(&mut self, weights: &DMat, biases: &DMat) {
        self.moment1_weights = DMat::zeros(weights.rows(), weights.cols());
        self.moment2_weights = DMat::zeros(weights.rows(), weights.cols());
        self.moment1_biases = DMat::zeros(biases.rows(), biases.cols());
        self.moment2_biases = DMat::zeros(biases.rows(), biases.cols());
    }

    /// Performs one Adam step: optionally reschedules the learning rate,
    /// advances the timestep, refreshes the bias-correction factors, folds the
    /// gradients into the moment EMAs, and updates the parameters in place.
    fn update(&mut self, weights: &mut DMat, biases: &mut DMat, d_weights: &DMat, d_biases: &DMat, epoch: usize) {
        // Let an optional scheduler adjust the stored learning rate for this epoch.
        if let Some(scheduler) = self.config.scheduler.as_ref() {
            self.config.learning_rate = scheduler.schedule(epoch, self.config.learning_rate);
        }
        self.t += 1;
        // Bias-correction denominators 1 - beta^t for the new timestep.
        self.m_hat_factor = 1.0 - self.config.beta1.powi(self.t as i32);
        self.v_hat_factor = 1.0 - self.config.beta2.powi(self.t as i32);
        // BUG FIX: bias correction is already applied per element inside
        // `update_parameters` (m_hat = m / m_hat_factor, v_hat = v / v_hat_factor),
        // so the step size must be the plain learning rate. The previous code
        // multiplied in `m_hat_factor / v_hat_factor.sqrt()` on top of that,
        // re-applying (and inverting) the correction and inflating early steps
        // by roughly (1 - beta1^t) / sqrt(1 - beta2^t) (~3.16x at t = 1 with
        // the default betas). The reference computation in `test_adam_optimizer`
        // uses the plain learning rate, matching Adam's standard form.
        let step_size = self.config.learning_rate;
        self.update_moments(d_weights, d_biases);
        self.update_parameters(weights, biases, step_size);
    }

    /// Replaces the stored learning rate with `learning_rate`.
    fn update_learning_rate(&mut self, learning_rate: f32) {
        self.config.learning_rate = learning_rate;
    }
}
/// Serializable Adam hyperparameters; also serves as the `OptimizerConfig`
/// implementation that spawns `AdamOptimizer` instances.
#[derive(Serialize, Deserialize, Clone)]
struct AdamConfig {
    // Step size alpha; rewritten each update when a scheduler is present.
    learning_rate: f32,
    // Decay rate for the first-moment EMA (conventionally 0.9).
    beta1: f32,
    // Decay rate for the second-moment EMA (conventionally 0.999).
    beta2: f32,
    // Small constant added to the denominator for numerical stability.
    epsilon: f32,
    // Optional learning-rate scheduler consulted on every update.
    scheduler: Option<Box<dyn LearningRateScheduler>>,
}
#[typetag::serde]
impl OptimizerConfig for AdamConfig {
    /// Returns the currently configured learning rate.
    fn learning_rate(&self) -> f32 {
        self.learning_rate
    }

    /// Overwrites the stored learning rate.
    fn update_learning_rate(&mut self, learning_rate: f32) {
        self.learning_rate = learning_rate;
    }

    /// Builds a fresh `AdamOptimizer` seeded with a copy of this config.
    fn create_optimizer(&self) -> Box<dyn Optimizer> {
        Box::new(AdamOptimizer::new(self.clone()))
    }
}
/// Fluent builder for an Adam optimizer configuration.
///
/// Defaults: learning_rate 0.01, beta1 0.9, beta2 0.999, epsilon
/// `f32::EPSILON`, no scheduler. All inputs are validated in `build`.
pub struct Adam {
    learning_rate: f32,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
    // Scheduler construction may itself fail; the error is stored here and
    // surfaced later by `validate`/`build` so the builder chain stays fluent.
    scheduler: Option<Result<Box<dyn LearningRateScheduler>, NetworkError>>,
}
impl Adam {
fn new() -> Adam {
Adam {
learning_rate: 0.01,
beta1: 0.9,
beta2: 0.999,
epsilon: f32::EPSILON,
scheduler: None,
}
}
}
impl Default for Adam {
fn default() -> Self {
Self::new()
}
}
impl Adam {
    /// Sets the learning rate; must be > 0.0 to pass validation.
    pub fn learning_rate(mut self, learning_rate: f32) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Sets the first-moment decay rate; must lie in (0, 1).
    pub fn beta1(mut self, beta1: f32) -> Self {
        self.beta1 = beta1;
        self
    }

    /// Sets the second-moment decay rate; must lie in (0, 1).
    pub fn beta2(mut self, beta2: f32) -> Self {
        self.beta2 = beta2;
        self
    }

    /// Sets the numerical-stability constant; must be > 0.0.
    pub fn epsilon(mut self, epsilon: f32) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Attaches a scheduler (or the error produced while building one); any
    /// stored error is reported by `build`.
    pub fn scheduler(mut self, scheduler: Result<Box<dyn LearningRateScheduler>, NetworkError>) -> Self {
        self.scheduler = Some(scheduler);
        self
    }

    /// Checks every hyperparameter range and surfaces a stored scheduler
    /// error, in field order, so `build` fails with a precise message.
    fn validate(&self) -> Result<(), NetworkError> {
        if self.learning_rate <= 0.0 {
            return Err(NetworkError::ConfigError(format!(
                "Learning rate for Adam must be greater than 0.0, but was {}",
                self.learning_rate
            )));
        }
        if self.beta1 <= 0.0 || self.beta1 >= 1.0 {
            return Err(NetworkError::ConfigError(format!(
                "Beta1 for Adam must be in the range (0, 1), but was {}",
                self.beta1
            )));
        }
        if self.beta2 <= 0.0 || self.beta2 >= 1.0 {
            return Err(NetworkError::ConfigError(format!(
                "Beta2 for Adam must be in the range (0, 1), but was {}",
                self.beta2
            )));
        }
        if self.epsilon <= 0.0 {
            return Err(NetworkError::ConfigError(format!(
                "Epsilon for Adam must be greater than 0.0, but was {}",
                self.epsilon
            )));
        }
        if let Some(Err(e)) = &self.scheduler {
            return Err(e.clone());
        }
        Ok(())
    }

    /// Validates the builder and produces the boxed `OptimizerConfig`.
    ///
    /// # Errors
    /// Returns `NetworkError::ConfigError` for out-of-range hyperparameters,
    /// or the stored scheduler error if scheduler construction failed.
    pub fn build(self) -> Result<Box<dyn OptimizerConfig>, NetworkError> {
        self.validate()?;
        Ok(Box::new(AdamConfig {
            learning_rate: self.learning_rate,
            beta1: self.beta1,
            beta2: self.beta2,
            epsilon: self.epsilon,
            // `transpose` turns Option<Result<T, E>> into Result<Option<T>, E>,
            // replacing the previous `map(|s| s.unwrap())` which had a latent
            // panic path; `validate` has already reported any stored error, so
            // behavior is unchanged on every reachable input.
            scheduler: self.scheduler.transpose()?,
        }))
    }
}
impl OptimizerConfigClone for AdamConfig {
    /// Produces a boxed deep copy, enabling cloning through the
    /// `dyn OptimizerConfig` trait object.
    fn clone_box(&self) -> Box<dyn OptimizerConfig> {
        let copy: AdamConfig = self.clone();
        Box::new(copy)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{common::matrix::DMat, util::equal_approx};

    // `initialize` must size the moment matrices to the layer's weight/bias shapes.
    #[test]
    fn test_initialize() {
        let adam_config = AdamConfig {
            learning_rate: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            scheduler: None,
        };
        let mut optimizer = AdamOptimizer::new(adam_config);
        let weights = DMat::new(2, 2, &[0.1, 0.2, 0.3, 0.4]);
        let biases = DMat::new(2, 1, &[0.1, 0.2]);
        optimizer.initialize(&weights, &biases);
        assert_eq!(optimizer.moment1_weights.rows(), 2);
        assert_eq!(optimizer.moment1_weights.cols(), 2);
        assert_eq!(optimizer.moment1_biases.rows(), 2);
        assert_eq!(optimizer.moment1_biases.cols(), 1);
    }

    // With positive gradients, one Adam step must decrease the parameters.
    #[test]
    fn test_update() {
        let config = AdamConfig {
            learning_rate: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            scheduler: None,
        };
        let mut optimizer = AdamOptimizer::new(config);
        let mut weights = DMat::new(2, 2, &[1.0, 1.0, 1.0, 1.0]);
        let mut biases = DMat::new(2, 1, &[1.0, 1.0]);
        let d_weights = DMat::new(2, 2, &[0.1, 0.1, 0.1, 0.1]);
        let d_biases = DMat::new(2, 1, &[0.1, 0.1]);
        optimizer.initialize(&weights, &biases);
        optimizer.update(&mut weights, &mut biases, &d_weights, &d_biases, 1);
        assert!(weights.at(0, 0) < 1.0);
        assert!(biases.at(0, 0) < 1.0);
    }

    // `update_learning_rate` must overwrite the stored rate in the config.
    #[test]
    fn test_update_learning_rate() {
        let adam_config = AdamConfig {
            learning_rate: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            scheduler: None,
        };
        let mut optimizer = AdamOptimizer::new(adam_config);
        optimizer.update_learning_rate(0.01);
        assert_eq!(optimizer.config.learning_rate, 0.01);
    }

    // Compares one optimizer step against a hand-rolled Adam computation.
    //
    // NOTE(review): this check is loose, not exact. The reference below reads
    // `optimizer.moment1_weights` / `moment2_weights` AFTER `update()` ran and
    // applies the EMA recurrence to them again (double-applying it), and the
    // final expected value is based on the already-updated `weights`, so one
    // extra step is subtracted. Both discrepancies are smaller than the 1e-2
    // tolerance with these inputs, so the assertion still passes, but it would
    // be stronger to recompute the moments from zero state and use a tight
    // tolerance. TODO: tighten once the reference is recomputed from scratch.
    #[test]
    fn test_adam_optimizer() {
        let mut weights = DMat::new(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let mut biases = DMat::new(2, 1, &[1.0, 2.0]);
        let d_weights = DMat::new(2, 2, &[10.0, 11.0, 12.0, 13.0]);
        let d_biases = DMat::new(2, 1, &[10.0, 11.0]);
        let adam_config = AdamConfig {
            learning_rate: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            scheduler: None,
        };
        let mut optimizer = AdamOptimizer::new(adam_config);
        optimizer.initialize(&weights, &biases);
        optimizer.update(&mut weights, &mut biases, &d_weights, &d_biases, 1);
        let mut expected_params = DMat::new(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let (rows, cols) = (weights.rows(), weights.cols());
        // Recompute m_t = beta1 * m + (1 - beta1) * g (on post-update moments; see NOTE above).
        let mut m_t;
        let mut v_t;
        m_t = optimizer.moment1_weights.clone();
        m_t.scale(optimizer.config.beta1);
        let mut m_tmp;
        m_tmp = d_weights.clone();
        m_tmp.scale(1.0 - optimizer.config.beta1);
        m_t.add(&m_tmp);
        // Recompute v_t = beta2 * v + (1 - beta2) * g^2 the same way.
        v_t = optimizer.moment2_weights.clone();
        v_t.scale(optimizer.config.beta2);
        let mut v_tmp = DMat::zeros(rows, cols);
        v_tmp.apply_with_indices(|r, c, v| {
            let g = d_weights.at(r, c);
            *v = g * g;
        });
        v_tmp.scale(1.0 - optimizer.config.beta2);
        v_t.add(&v_tmp);
        // Bias-correct: m_hat = m_t / (1 - beta1^t), v_hat = v_t / (1 - beta2^t).
        let mut m_hat;
        let mut v_hat;
        m_hat = m_t.clone();
        m_hat.scale(1.0 / (1.0 - optimizer.config.beta1.powi(optimizer.t as i32)));
        v_hat = v_t.clone();
        v_hat.scale(1.0 / (1.0 - optimizer.config.beta2.powi(optimizer.t as i32)));
        // Standard Adam step with the plain learning rate:
        // p <- p - lr * m_hat / (sqrt(v_hat) + eps).
        expected_params.apply_with_indices(|r, c, v| {
            let m_h = m_hat.at(r, c);
            let v_h = v_hat.at(r, c);
            let update = optimizer.config.learning_rate * m_h / (v_h.sqrt() + optimizer.config.epsilon);
            *v = weights.at(r, c) - update;
        });
        assert!(equal_approx(&weights, &expected_params, 1e-2));
    }

    // Cloning the optimizer must copy the hyperparameters.
    #[test]
    fn test_clone_adam_optimizer() {
        let adam_config = AdamConfig {
            learning_rate: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            scheduler: None,
        };
        let optimizer = AdamOptimizer::new(adam_config);
        let cloned_optimizer = optimizer.clone();
        assert_eq!(optimizer.config.learning_rate, cloned_optimizer.config.learning_rate);
        assert_eq!(optimizer.config.beta1, cloned_optimizer.config.beta1);
        assert_eq!(optimizer.config.beta2, cloned_optimizer.config.beta2);
    }

    // Cloning a boxed config (via OptimizerConfigClone) must preserve the rate.
    #[test]
    fn test_clone_adam_optimizer_config() {
        let adam_config = Adam::default()
            .learning_rate(0.001)
            .beta1(0.9)
            .beta2(0.999)
            .epsilon(1e-8)
            .build()
            .unwrap();
        let cloned_config = adam_config.clone();
        assert_eq!(adam_config.learning_rate(), cloned_config.learning_rate());
    }

    // A non-positive learning rate must be rejected at build time.
    #[test]
    fn test_adam_invalid_learning_rate() {
        let adam = Adam::default().learning_rate(0.0);
        let result = adam.build();
        assert!(result.is_err());
        if let Err(err) = result {
            assert_eq!(
                err.to_string(),
                "Configuration error: Learning rate for Adam must be greater than 0.0, but was 0"
            );
        }
    }

    // beta1 outside (0, 1) must be rejected at build time.
    #[test]
    fn test_adam_invalid_beta1() {
        let adam = Adam::default().beta1(1.5);
        let result = adam.build();
        assert!(result.is_err());
        if let Err(err) = result {
            assert_eq!(err.to_string(), "Configuration error: Beta1 for Adam must be in the range (0, 1), but was 1.5");
        }
    }

    // beta2 outside (0, 1) must be rejected at build time.
    #[test]
    fn test_adam_invalid_beta2() {
        let adam = Adam::default().beta2(1.5);
        let result = adam.build();
        assert!(result.is_err());
        if let Err(err) = result {
            assert_eq!(err.to_string(), "Configuration error: Beta2 for Adam must be in the range (0, 1), but was 1.5");
        }
    }
}