use crate::core::{error::BellandeError, tensor::Tensor};
use crate::optim::{Optimizer, OptimizerState, ParameterGroup};
use std::collections::HashMap;
/// Adam optimizer state: the managed parameters, the optimizer-level
/// hyperparameters, and the per-parameter moment buffers.
pub struct Adam {
    /// Parameters updated by `step`; the position in this vector is the key
    /// used for the `m` and `v` buffers.
    params: Vec<Tensor>,
    /// Step size applied to every parameter update.
    lr: f32,
    /// Decay coefficients for the first (`.0`) and second (`.1`) moment
    /// running averages.
    betas: (f32, f32),
    /// Constant added to the square-rooted second moment in the update
    /// denominator.
    eps: f32,
    /// Coefficient multiplied by the parameter value and added to its
    /// gradient when non-zero.
    weight_decay: f32,
    /// First-moment buffers, keyed by parameter index, one entry per element
    /// of the parameter's data.
    m: HashMap<usize, Vec<f32>>,
    /// Second-moment buffers, keyed by parameter index, one entry per element
    /// of the parameter's data.
    v: HashMap<usize, Vec<f32>>,
    /// Parameter groups registered with this optimizer; the constructor
    /// creates the first one from its arguments.
    param_groups: Vec<ParameterGroup>,
    /// Shared optimizer bookkeeping; its step counter drives bias correction.
    state: OptimizerState,
}
impl Adam {
    /// Builds an Adam optimizer over `params`.
    ///
    /// A zero-filled first- and second-moment buffer is allocated for every
    /// parameter (sized to the parameter's data length and keyed by its index
    /// in `params`), and a single default [`ParameterGroup`] carrying the
    /// supplied hyperparameters is registered.
    ///
    /// * `lr` - step size.
    /// * `betas` - decay coefficients for the first and second moment averages.
    /// * `eps` - constant added to the update denominator.
    /// * `weight_decay` - coefficient of the parameter value added to its gradient.
    pub fn new(
        params: Vec<Tensor>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
    ) -> Self {
        // Both moment tables start out as identical all-zero buffers, so the
        // second one can simply be an independent copy of the first.
        let m: HashMap<usize, Vec<f32>> = params
            .iter()
            .enumerate()
            .map(|(index, tensor)| (index, vec![0.0; tensor.data.len()]))
            .collect();
        let v = m.clone();

        let (beta1, beta2) = betas;
        let default_group = ParameterGroup::new(params.clone())
            .with_lr(lr)
            .with_weight_decay(weight_decay)
            .with_betas(beta1, beta2)
            .with_eps(eps);

        Self {
            params,
            lr,
            betas,
            eps,
            weight_decay,
            m,
            v,
            param_groups: vec![default_group],
            state: OptimizerState::new(),
        }
    }
}
impl Optimizer for Adam {
    /// Performs a single Adam update on every parameter that currently holds
    /// a gradient.
    ///
    /// The step counter in `state` is incremented first and then used for the
    /// bias-correction terms `1 - beta^step`. Parameters whose `grad` is
    /// `None` are skipped. A non-zero `weight_decay` is folded into the
    /// gradient (`g + weight_decay * p`) before the moment estimates are
    /// updated.
    ///
    /// Every parameter is updated with the optimizer-level `lr`, `betas`,
    /// `eps` and `weight_decay`; the values stored on individual entries of
    /// `param_groups` are not consulted here.
    ///
    /// # Errors
    ///
    /// This implementation currently always returns `Ok(())`.
    fn step(&mut self) -> Result<(), BellandeError> {
        self.state.increment_step();

        // Hoist everything that is constant across parameters and elements.
        let step = self.state.step as i32;
        let (beta1, beta2) = self.betas;
        let (lr, eps, weight_decay) = (self.lr, self.eps, self.weight_decay);
        let bias_correction1 = 1.0 - beta1.powi(step);
        let bias_correction2 = 1.0 - beta2.powi(step);

        for (idx, param) in self.params.iter_mut().enumerate() {
            if let Some(grad) = &param.grad {
                // Moment buffers normally exist already (created by `new` or
                // `add_param_group`). A parameter pushed directly through
                // `parameters_mut` has none, so create zeroed buffers on
                // first use instead of panicking.
                let len = param.data.len();
                let m = self.m.entry(idx).or_insert_with(|| vec![0.0; len]);
                let v = self.v.entry(idx).or_insert_with(|| vec![0.0; len]);

                for ((p, g), (m_i, v_i)) in param
                    .data
                    .iter_mut()
                    .zip(grad.iter())
                    .zip(m.iter_mut().zip(v.iter_mut()))
                {
                    let mut d_p = *g;
                    if weight_decay != 0.0 {
                        d_p += weight_decay * *p;
                    }
                    // Exponential moving averages of the gradient and of its square.
                    *m_i = beta1 * *m_i + (1.0 - beta1) * d_p;
                    *v_i = beta2 * *v_i + (1.0 - beta2) * d_p * d_p;
                    // Bias-corrected estimates compensate for the zero initialisation.
                    let m_hat = *m_i / bias_correction1;
                    let v_hat = *v_i / bias_correction2;
                    *p -= lr * m_hat / (v_hat.sqrt() + eps);
                }
            }
        }
        Ok(())
    }

    /// Overwrites every element of every existing gradient with `0.0`.
    /// Parameters without a gradient are left untouched.
    fn zero_grad(&mut self) {
        for param in &mut self.params {
            if let Some(grad) = &mut param.grad {
                grad.iter_mut().for_each(|g| *g = 0.0);
            }
        }
    }

    /// Returns the optimizer-level learning rate.
    fn get_lr(&self) -> f32 {
        self.lr
    }

    /// Sets the optimizer-level learning rate and copies it onto every
    /// registered parameter group.
    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
        for group in &mut self.param_groups {
            group.lr = lr;
        }
    }

    /// Returns the parameters managed by this optimizer.
    fn parameters(&self) -> &Vec<Tensor> {
        &self.params
    }

    /// Returns mutable access to the managed parameters.
    fn parameters_mut(&mut self) -> &mut Vec<Tensor> {
        &mut self.params
    }

    /// Returns the optimizer's name, `"Adam"`.
    fn name(&self) -> &str {
        "Adam"
    }

    /// Returns the registered parameter groups.
    fn get_param_groups(&self) -> &[ParameterGroup] {
        &self.param_groups
    }

    /// Returns mutable access to the registered parameter groups.
    fn get_param_groups_mut(&mut self) -> &mut [ParameterGroup] {
        &mut self.param_groups
    }

    /// Registers an additional parameter group.
    ///
    /// Clones of the group's parameters are appended to the optimizer's own
    /// parameter list, and zeroed moment buffers are created for each of them
    /// at the indices they will occupy in that list.
    fn add_param_group(&mut self, group: ParameterGroup) {
        let start_idx = self.params.len();
        for (offset, param) in group.params.iter().enumerate() {
            let len = param.data.len();
            self.m.insert(start_idx + offset, vec![0.0; len]);
            self.v.insert(start_idx + offset, vec![0.0; len]);
        }
        self.params.extend(group.params.iter().cloned());
        self.param_groups.push(group);
    }

    /// Returns the shared optimizer bookkeeping state.
    fn state(&self) -> &OptimizerState {
        &self.state
    }

    /// Returns mutable access to the shared optimizer bookkeeping state.
    fn state_mut(&mut self) -> &mut OptimizerState {
        &mut self.state
    }
}
// SAFETY: NOTE(review) — unverified. These impls assert that every field of
// `Adam` may be moved to and shared between threads. That depends on `Tensor`,
// `ParameterGroup` and `OptimizerState`, which are defined outside this file;
// nothing here establishes it. TODO confirm those types hold no
// non-thread-safe state (e.g. `Rc`, `RefCell`, raw pointers). If they are
// already `Send + Sync`, these manual impls are redundant because the auto
// traits would apply.
unsafe impl Send for Adam {}
unsafe impl Sync for Adam {}