//! First-order optimizers (SGD, Adam, AdamW, RMSProp) that update parameter
//! tensors in place from gradients recorded by the autograd module.

use crate::autograd::{get_grad, Tensor, TensorId};
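/// Common interface shared by the optimizers in this module.
///
/// Note that `step` only advances internal bookkeeping (the timestep and the
/// lazy-initialization flag); the parameter updates themselves happen in each
/// optimizer's `step_with_params`, which takes mutable references to the
/// tensors being trained.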
pub trait Optimizer {
fn step(&mut self);
fn zero_grad(&mut self);
fn lr(&self) -> f32;
fn set_lr(&mut self, lr: f32);
}
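/// Stochastic gradient descent with optional momentum, Nesterov momentum, and
/// L2 weight decay.
///
/// Velocity buffers are allocated lazily on the first update, so construction
/// only records the parameters' tensor ids.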
#[derive(Debug)]
pub struct SGD {
param_ids: Vec<TensorId>,
lr: f32,
momentum: f32,
weight_decay: f32,
nesterov: bool,
velocities: Vec<Vec<f32>>,
pub(crate) initialized: bool,
}
impl SGD {
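    /// Creates a plain SGD optimizer (no momentum, no weight decay).
    ///
    /// A minimal usage sketch; `Tensor::from_vec` is a hypothetical
    /// constructor standing in for whatever the autograd module provides:
    ///
    /// ```ignore
    /// let mut w = Tensor::from_vec(vec![0.5, -0.3]); // hypothetical constructor
    /// let mut opt = SGD::new(vec![&mut w], 0.01);
    /// // ... forward pass and backward pass populate the gradient ...
    /// opt.step_with_params(&mut [&mut w]); // apply the update
    /// opt.zero_grad(); // clear gradients before the next step
    /// ```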
#[allow(clippy::needless_pass_by_value)]
#[must_use]
pub fn new(params: Vec<&mut Tensor>, lr: f32) -> Self {
let param_ids: Vec<TensorId> = params.iter().map(|p| p.id()).collect();
Self {
param_ids,
lr,
momentum: 0.0,
weight_decay: 0.0,
nesterov: false,
velocities: Vec::new(),
initialized: false,
}
}
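    /// Creates an SGD optimizer with classical (heavy-ball) momentum.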
#[allow(clippy::needless_pass_by_value)]
#[must_use]
pub fn with_momentum(params: Vec<&mut Tensor>, lr: f32, momentum: f32) -> Self {
let param_ids: Vec<TensorId> = params.iter().map(|p| p.id()).collect();
Self {
param_ids,
lr,
momentum,
weight_decay: 0.0,
nesterov: false,
velocities: Vec::new(),
initialized: false,
}
}
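    /// Enables Nesterov momentum (builder-style; returns `self`).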
#[must_use]
pub fn nesterov(mut self) -> Self {
self.nesterov = true;
self
}
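    /// Sets the L2 weight-decay coefficient (builder-style; returns `self`).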
#[must_use]
pub fn weight_decay(mut self, wd: f32) -> Self {
self.weight_decay = wd;
self
}
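    /// Applies one SGD update to a single parameter tensor. With gradient `g`
    /// (including `weight_decay * p` when weight decay is set), the momentum
    /// update is `v <- momentum * v + g` followed by `p <- p - lr * v`, or
    /// `p <- p - lr * (momentum * v + g)` in the Nesterov variant. Parameters
    /// without a recorded gradient are skipped.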
#[allow(clippy::if_not_else)]
fn update_param(&mut self, param: &mut Tensor, idx: usize) {
        let Some(grad) = get_grad(param.id()) else {
            return;
        };
let grad_data = grad.data();
let param_data = param.data_mut();
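        // Lazily (re)allocate the velocity buffer; the gradient is assumed to
        // have the same length as the parameter data.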
if !self.initialized || idx >= self.velocities.len() {
if idx >= self.velocities.len() {
self.velocities.resize(idx + 1, Vec::new());
}
self.velocities[idx] = vec![0.0; param_data.len()];
}
let velocity = &mut self.velocities[idx];
for i in 0..param_data.len() {
let mut g = grad_data[i];
if self.weight_decay != 0.0 {
g += self.weight_decay * param_data[i];
}
if self.momentum != 0.0 {
velocity[i] = self.momentum * velocity[i] + g;
if self.nesterov {
param_data[i] -= self.lr * (self.momentum * velocity[i] + g);
} else {
param_data[i] -= self.lr * velocity[i];
}
} else {
param_data[i] -= self.lr * g;
}
}
}
}
impl Optimizer for SGD {
fn step(&mut self) {
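        // Parameter updates need mutable tensor access, so they happen in
        // `step_with_params`; here we only mark the optimizer as initialized.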
self.initialized = true;
}
fn zero_grad(&mut self) {
for &id in &self.param_ids {
crate::autograd::clear_grad(id);
}
}
fn lr(&self) -> f32 {
self.lr
}
fn set_lr(&mut self, lr: f32) {
self.lr = lr;
}
}
impl SGD {
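    /// Runs one optimization step over `params`. Velocity buffers are keyed by
    /// position, so pass the parameters in the same order on every call.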
pub fn step_with_params(&mut self, params: &mut [&mut Tensor]) {
for (idx, param) in params.iter_mut().enumerate() {
self.update_param(param, idx);
}
self.initialized = true;
}
}
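/// Adam optimizer (Kingma & Ba, 2015): bias-corrected first and second moment
/// estimates with optional L2 weight decay applied to the gradient.
///
/// Moment buffers (`m`, `v`) are allocated lazily on the first update.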
#[derive(Debug)]
pub struct Adam {
param_ids: Vec<TensorId>,
lr: f32,
beta1: f32,
beta2: f32,
eps: f32,
weight_decay: f32,
m: Vec<Vec<f32>>,
v: Vec<Vec<f32>>,
pub(crate) t: usize,
pub(crate) initialized: bool,
}
impl Adam {
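    /// Creates an Adam optimizer with the usual defaults: `beta1 = 0.9`,
    /// `beta2 = 0.999`, `eps = 1e-8`, and no weight decay.
    ///
    /// A usage sketch with builder-style configuration (`Tensor::from_vec` is
    /// again a hypothetical constructor):
    ///
    /// ```ignore
    /// let mut w = Tensor::from_vec(vec![0.0; 8]); // hypothetical constructor
    /// let mut opt = Adam::new(vec![&mut w], 1e-3).betas(0.9, 0.98).eps(1e-6);
    /// // ... backward pass ...
    /// opt.step_with_params(&mut [&mut w]);
    /// ```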
#[allow(clippy::needless_pass_by_value)]
#[must_use]
pub fn new(params: Vec<&mut Tensor>, lr: f32) -> Self {
let param_ids: Vec<TensorId> = params.iter().map(|p| p.id()).collect();
Self {
param_ids,
lr,
beta1: 0.9,
beta2: 0.999,
eps: 1e-8,
weight_decay: 0.0,
m: Vec::new(),
v: Vec::new(),
t: 0,
initialized: false,
}
}
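    /// Sets the exponential decay rates for the first and second moment
    /// estimates (builder-style; returns `self`).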
#[must_use]
pub fn betas(mut self, beta1: f32, beta2: f32) -> Self {
self.beta1 = beta1;
self.beta2 = beta2;
self
}
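    /// Sets the denominator epsilon used for numerical stability
    /// (builder-style; returns `self`).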
#[must_use]
pub fn eps(mut self, eps: f32) -> Self {
self.eps = eps;
self
}
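    /// Sets the L2 weight-decay coefficient, added to the gradient before the
    /// moment updates (builder-style; returns `self`).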
#[must_use]
pub fn weight_decay(mut self, wd: f32) -> Self {
self.weight_decay = wd;
self
}
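    /// Applies one Adam update to a single parameter tensor:
    /// `m <- beta1 * m + (1 - beta1) * g` and `v <- beta2 * v + (1 - beta2) * g^2`,
    /// then `p <- p - lr * m_hat / (sqrt(v_hat) + eps)` with the bias-corrected
    /// `m_hat = m / (1 - beta1^t)` and `v_hat = v / (1 - beta2^t)`. Callers
    /// must increment `t` before updating so the corrections are nonzero.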
fn update_param(&mut self, param: &mut Tensor, idx: usize) {
let Some(grad) = get_grad(param.id()) else {
return;
};
let grad_data = grad.data();
let param_data = param.data_mut();
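        // Lazily (re)allocate the first- and second-moment buffers for this
        // parameter.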
if !self.initialized || idx >= self.m.len() {
if idx >= self.m.len() {
self.m.resize(idx + 1, Vec::new());
self.v.resize(idx + 1, Vec::new());
}
self.m[idx] = vec![0.0; param_data.len()];
self.v[idx] = vec![0.0; param_data.len()];
}
let m = &mut self.m[idx];
let v = &mut self.v[idx];
let bias_correction1 = 1.0 - self.beta1.powi(self.t as i32);
let bias_correction2 = 1.0 - self.beta2.powi(self.t as i32);
for i in 0..param_data.len() {
let mut g = grad_data[i];
if self.weight_decay != 0.0 {
g += self.weight_decay * param_data[i];
}
m[i] = self.beta1 * m[i] + (1.0 - self.beta1) * g;
v[i] = self.beta2 * v[i] + (1.0 - self.beta2) * g * g;
let m_hat = m[i] / bias_correction1;
let v_hat = v[i] / bias_correction2;
param_data[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
}
}
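    /// Runs one optimization step over `params`, incrementing the timestep
    /// first so the bias corrections see `t >= 1`. Moment buffers are keyed by
    /// position, so pass the parameters in the same order on every call.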
pub fn step_with_params(&mut self, params: &mut [&mut Tensor]) {
self.t += 1;
for (idx, param) in params.iter_mut().enumerate() {
self.update_param(param, idx);
}
self.initialized = true;
}
}
impl Optimizer for Adam {
fn step(&mut self) {
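        // As with `SGD::step`, the actual updates happen in `step_with_params`;
        // here we only advance the timestep used for bias correction.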
self.t += 1;
self.initialized = true;
}
fn zero_grad(&mut self) {
for &id in &self.param_ids {
crate::autograd::clear_grad(id);
}
}
fn lr(&self) -> f32 {
self.lr
}
fn set_lr(&mut self, lr: f32) {
self.lr = lr;
}
}
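/// Adam with decoupled weight decay (AdamW, Loshchilov & Hutter, 2019).
///
/// The fields mirror `Adam`; they are `pub(crate)` so the update step can be
/// implemented elsewhere in the crate.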
#[derive(Debug)]
pub struct AdamW {
pub(crate) param_ids: Vec<TensorId>,
pub(crate) lr: f32,
pub(crate) beta1: f32,
pub(crate) beta2: f32,
pub(crate) eps: f32,
pub(crate) weight_decay: f32,
pub(crate) m: Vec<Vec<f32>>,
pub(crate) v: Vec<Vec<f32>>,
pub(crate) t: usize,
pub(crate) initialized: bool,
}
// RMSProp lives in its own file; re-export it alongside the optimizers above.
mod rmsprop;
pub use rmsprop::*;