use crate::nn::ParameterStore;
use std::collections::HashMap;

/// Common interface for gradient-based optimizers: `step` applies an update
/// rule using the gradients currently stored in `params`, and `zero_grad`
/// clears those gradients before the next backward pass.
pub trait Optimizer {
    fn step(&mut self, params: &mut ParameterStore);
    fn zero_grad(&mut self, params: &mut ParameterStore);
}

/// Stochastic gradient descent, optionally with classical momentum.
pub struct SGD {
    learning_rate: f32,
    momentum: f32,
    velocity: HashMap<String, f32>, // per-parameter velocity, keyed by name
}

impl SGD {
    pub fn new(learning_rate: f32) -> Self {
        Self {
            learning_rate,
            momentum: 0.0,
            velocity: HashMap::new(),
        }
    }

    pub fn with_momentum(learning_rate: f32, momentum: f32) -> Self {
        Self {
            learning_rate,
            momentum,
            velocity: HashMap::new(),
        }
    }
}

impl Optimizer for SGD {
    fn step(&mut self, params: &mut ParameterStore) {
        // Snapshot (name, gradient) pairs first so the immutable borrow of
        // `params` ends before the mutable borrows below.
        let param_updates: Vec<(String, f32)> = params
            .parameters()
            .iter()
            .map(|(name, param)| (name.clone(), param.gradient))
            .collect();
        for (name, grad) in param_updates {
            if self.momentum > 0.0 {
                // Velocity form: v <- momentum * v - lr * grad, then theta <- theta + v.
                let velocity = self.velocity.entry(name.clone()).or_insert(0.0);
                *velocity = self.momentum * (*velocity) - self.learning_rate * grad;
                if let Some(param_mut) = params.get_parameter_mut(&name) {
                    param_mut.value += *velocity;
                }
            } else {
                // Vanilla SGD: theta <- theta - lr * grad.
                if let Some(param_mut) = params.get_parameter_mut(&name) {
                    param_mut.value -= self.learning_rate * grad;
                }
            }
        }
    }

    fn zero_grad(&mut self, params: &mut ParameterStore) {
        params.zero_grad();
    }
}
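
// A self-contained illustration (a test sketch, not part of the optimizer API)
// of the velocity recurrence used by `SGD::step` above: each step decays the
// previous velocity by `momentum` and adds a fresh `-lr * grad` term, so under
// a constant gradient the velocity converges to -lr * grad / (1 - momentum).
#[cfg(test)]
mod sgd_momentum_sketch {
    #[test]
    fn velocity_converges_under_constant_gradient() {
        let (lr, momentum, grad) = (0.1_f32, 0.9_f32, 1.0_f32);
        let mut velocity = 0.0_f32;
        for _ in 0..500 {
            velocity = momentum * velocity - lr * grad;
        }
        // Geometric-series limit: -0.1 * 1.0 / (1.0 - 0.9) = -1.0.
        assert!((velocity + 1.0).abs() < 1e-3);
    }
}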

/// Adam (adaptive moment estimation): keeps exponential moving averages of
/// each parameter's gradient and squared gradient, with bias correction.
pub struct Adam {
    learning_rate: f32,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
    t: u64,                  // step counter, used for bias correction
    m: HashMap<String, f32>, // first-moment estimates, keyed by parameter name
    v: HashMap<String, f32>, // second-moment estimates, keyed by parameter name
}

impl Adam {
    /// Adam with the commonly used defaults: beta1 = 0.9, beta2 = 0.999,
    /// epsilon = 1e-8.
    pub fn new(learning_rate: f32) -> Self {
        Self {
            learning_rate,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            t: 0,
            m: HashMap::new(),
            v: HashMap::new(),
        }
    }

    pub fn with_params(learning_rate: f32, beta1: f32, beta2: f32, epsilon: f32) -> Self {
        Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            t: 0,
            m: HashMap::new(),
            v: HashMap::new(),
        }
    }
}

impl Optimizer for Adam {
    fn step(&mut self, params: &mut ParameterStore) {
        self.t += 1;
        let t_f = self.t as f32;
        // Snapshot gradients so the borrow of `params` ends before mutation.
        let param_updates: Vec<(String, f32)> = params
            .parameters()
            .iter()
            .map(|(name, param)| (name.clone(), param.gradient))
            .collect();
        for (name, grad) in param_updates {
            // Exponential moving averages of the gradient and its square.
            let m_t = self.m.entry(name.clone()).or_insert(0.0);
            *m_t = self.beta1 * (*m_t) + (1.0 - self.beta1) * grad;
            let v_t = self.v.entry(name.clone()).or_insert(0.0);
            *v_t = self.beta2 * (*v_t) + (1.0 - self.beta2) * grad * grad;
            // Bias correction: both averages start at zero, so early steps
            // underestimate the true moments by a factor of (1 - beta^t).
            let m_hat = *m_t / (1.0 - self.beta1.powf(t_f));
            let v_hat = *v_t / (1.0 - self.beta2.powf(t_f));
            if let Some(param_mut) = params.get_parameter_mut(&name) {
                param_mut.value -= self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon);
            }
        }
    }

    fn zero_grad(&mut self, params: &mut ParameterStore) {
        params.zero_grad();
    }
}
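
// Worked illustration (a sketch, not part of the optimizer API): on the very
// first Adam step (t = 1), the bias correction exactly cancels the (1 - beta)
// factors, so m_hat == grad and v_hat == grad^2, making the first update
// roughly lr * sign(grad) regardless of the gradient's magnitude.
#[cfg(test)]
mod adam_bias_correction_sketch {
    #[test]
    fn first_step_moments_equal_raw_gradient() {
        let (beta1, beta2, grad) = (0.9_f32, 0.999_f32, 0.25_f32);
        // First update of the moving averages, starting from zero.
        let m = beta1 * 0.0 + (1.0 - beta1) * grad;
        let v = beta2 * 0.0 + (1.0 - beta2) * grad * grad;
        // Bias correction at t = 1 divides by (1 - beta), undoing the scaling.
        let m_hat = m / (1.0 - beta1);
        let v_hat = v / (1.0 - beta2);
        assert!((m_hat - grad).abs() < 1e-6);
        assert!((v_hat - grad * grad).abs() < 1e-6);
    }
}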

/// RMSprop: divides each step by a running average of squared gradient
/// magnitudes, so parameters with large gradients take smaller steps.
pub struct RMSprop {
    learning_rate: f32,
    alpha: f32, // smoothing constant for the running average
    epsilon: f32,
    v: HashMap<String, f32>, // mean-squared-gradient estimates per parameter
}

impl RMSprop {
    /// RMSprop with the commonly used defaults: alpha = 0.99, epsilon = 1e-8.
    pub fn new(learning_rate: f32) -> Self {
        Self {
            learning_rate,
            alpha: 0.99,
            epsilon: 1e-8,
            v: HashMap::new(),
        }
    }

    pub fn with_params(learning_rate: f32, alpha: f32, epsilon: f32) -> Self {
        Self {
            learning_rate,
            alpha,
            epsilon,
            v: HashMap::new(),
        }
    }
}

impl Optimizer for RMSprop {
    fn step(&mut self, params: &mut ParameterStore) {
        // Snapshot gradients so the borrow of `params` ends before mutation.
        let param_updates: Vec<(String, f32)> = params
            .parameters()
            .iter()
            .map(|(name, param)| (name.clone(), param.gradient))
            .collect();
        for (name, grad) in param_updates {
            // v <- alpha * v + (1 - alpha) * grad^2, then scale the step by 1/sqrt(v).
            let v_t = self.v.entry(name.clone()).or_insert(0.0);
            *v_t = self.alpha * (*v_t) + (1.0 - self.alpha) * grad * grad;
            if let Some(param_mut) = params.get_parameter_mut(&name) {
                param_mut.value -= self.learning_rate * grad / (v_t.sqrt() + self.epsilon);
            }
        }
    }

    fn zero_grad(&mut self, params: &mut ParameterStore) {
        params.zero_grad();
    }
}
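
// A minimal sketch of how these optimizers are typically driven. It relies
// only on the `Optimizer` trait and the `ParameterStore` methods already used
// above; `compute_gradients` stands in for a model-specific backward pass and
// is a hypothetical placeholder, not part of this crate.
#[allow(dead_code)]
fn training_step<O: Optimizer>(
    optimizer: &mut O,
    params: &mut ParameterStore,
    compute_gradients: impl Fn(&mut ParameterStore),
) {
    optimizer.zero_grad(params); // clear stale gradients from the last step
    compute_gradients(params);   // model-specific backward pass (not shown)
    optimizer.step(params);      // apply the update rule
}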