/// A learning-rate schedule: maps a step index to the learning rate for that step.
pub trait LrScheduler {
/// Returns the learning rate to use at `step`.
fn get_lr(&self, step: usize) -> f32;
}
/// Schedule that yields the same learning rate at every step.
pub struct ConstantLr {
    /// Learning rate returned for every step.
    pub lr: f32,
}

impl LrScheduler for ConstantLr {
    /// Returns the configured `lr`; the step index is ignored.
    fn get_lr(&self, _step: usize) -> f32 {
        let Self { lr } = *self;
        lr
    }
}
/// Schedule that ramps linearly from zero up to `base_lr`, then stays constant.
pub struct WarmupConstantLr {
    /// Learning rate returned once warmup is over.
    pub base_lr: f32,
    /// Number of steps spent ramping up; `0` disables warmup.
    pub warmup_steps: usize,
}

impl LrScheduler for WarmupConstantLr {
    /// Returns `base_lr * step / warmup_steps` while `step < warmup_steps`,
    /// and `base_lr` from then on.
    fn get_lr(&self, step: usize) -> f32 {
        // This guard also handles `warmup_steps == 0` (every step satisfies it),
        // so the division below never sees a zero denominator.
        if step >= self.warmup_steps {
            return self.base_lr;
        }
        let scaled = self.base_lr * step as f32;
        scaled / self.warmup_steps as f32
    }
}
/// Linear warmup followed by cosine decay from `base_lr` down to `min_lr`.
pub struct CosineAnnealingLr {
    /// Learning rate reached at the end of warmup (the peak of the schedule).
    pub base_lr: f32,
    /// Learning rate the cosine phase decays to.
    pub min_lr: f32,
    /// Step index at which the cosine phase reaches `min_lr` (warmup included).
    pub total_steps: usize,
    /// Number of leading steps spent ramping linearly from zero to `base_lr`.
    pub warmup_steps: usize,
}

impl LrScheduler for CosineAnnealingLr {
    /// Returns the learning rate for `step`.
    ///
    /// * `step < warmup_steps`: `base_lr * step / warmup_steps`.
    /// * afterwards: `min_lr + 0.5 * (base_lr - min_lr) * (1 + cos(pi * progress))`,
    ///   where `progress` is the fraction of the post-warmup steps already
    ///   elapsed, clamped to `1.0` so the rate stays at `min_lr` past `total_steps`.
    /// * if `total_steps <= warmup_steps` there is no cosine phase and `min_lr`
    ///   is returned for every step after warmup.
    fn get_lr(&self, step: usize) -> f32 {
        // `step < warmup_steps` already implies `warmup_steps > 0`, so the
        // division is safe without a separate zero check.
        if step < self.warmup_steps {
            return self.base_lr * (step as f32) / (self.warmup_steps as f32);
        }

        let cosine_steps = self.total_steps.saturating_sub(self.warmup_steps);
        if cosine_steps == 0 {
            return self.min_lr;
        }

        // Cannot underflow: the warmup guard above established `step >= warmup_steps`.
        let elapsed = step - self.warmup_steps;
        let progress = (elapsed as f32 / cosine_steps as f32).min(1.0);
        self.min_lr
            + 0.5 * (self.base_lr - self.min_lr) * (1.0 + (std::f32::consts::PI * progress).cos())
    }
}
/// Stochastic gradient descent with optional momentum, Nesterov look-ahead
/// and weight decay added to the gradient.
pub struct Sgd {
    /// Step size applied to every update.
    pub lr: f32,
    /// Momentum coefficient; exactly `0.0` selects plain SGD with no velocity state.
    pub momentum: f32,
    /// Coefficient of the `weight_decay * param` term added to each gradient.
    pub weight_decay: f32,
    /// When `true` (and momentum is non-zero) the Nesterov form of the update is used.
    pub nesterov: bool,
    /// One velocity buffer per parameter tensor, created lazily by [`Sgd::step`].
    velocity: Vec<Vec<f32>>,
}

impl Sgd {
    /// Creates an optimizer with the given learning rate, momentum `0.9`,
    /// no weight decay and Nesterov disabled.
    pub fn new(lr: f32) -> Self {
        Self {
            lr,
            momentum: 0.9,
            weight_decay: 0.0,
            nesterov: false,
            velocity: Vec::new(),
        }
    }

    /// Sets the momentum coefficient.
    pub fn with_momentum(mut self, m: f32) -> Self {
        self.momentum = m;
        self
    }

    /// Sets the weight-decay coefficient.
    pub fn with_weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    /// Enables the Nesterov form of the momentum update.
    pub fn with_nesterov(mut self) -> Self {
        self.nesterov = true;
        self
    }

    /// Applies one update to `params` in place using the matching `grads`.
    ///
    /// For each element, `g_eff = g + weight_decay * p`. With zero momentum the
    /// update is `p -= lr * g_eff`. Otherwise `v = momentum * v + g_eff` and
    /// `p -= lr * v`, or `p -= lr * (momentum * v + g_eff)` with Nesterov.
    ///
    /// Tensors and elements are paired positionally; anything without a
    /// counterpart on the other side is left untouched. The velocity buffers are
    /// re-created (zeroed) whenever their shape no longer matches `params`, i.e.
    /// when either the number of tensors or the length of any tensor differs.
    pub fn step(&mut self, params: &mut [&mut Vec<f32>], grads: &[Vec<f32>]) {
        // Compare per-tensor lengths too: checking only the tensor count would let
        // a resized tensor run past the end of its stale velocity buffer.
        let shapes_match = self.velocity.len() == params.len()
            && self
                .velocity
                .iter()
                .zip(params.iter())
                .all(|(v, p)| v.len() == p.len());
        if !shapes_match {
            self.velocity = params.iter().map(|p| vec![0.0f32; p.len()]).collect();
        }

        let (lr, momentum, weight_decay, nesterov) =
            (self.lr, self.momentum, self.weight_decay, self.nesterov);

        let tensors = params.iter_mut().zip(grads).zip(self.velocity.iter_mut());
        for ((param, grad), velocity) in tensors {
            for ((p, &g), v) in param.iter_mut().zip(grad).zip(velocity.iter_mut()) {
                let g_eff = g + weight_decay * *p;
                if momentum == 0.0 {
                    // Plain SGD: the velocity buffer is deliberately left untouched.
                    *p -= lr * g_eff;
                    continue;
                }
                *v = momentum * *v + g_eff;
                let update = if nesterov { momentum * *v + g_eff } else { *v };
                *p -= lr * update;
            }
        }
    }
}
/// Adam optimizer with bias-corrected first and second moment estimates.
/// Weight decay, when set, is added to the gradient before the moment updates.
pub struct Adam {
    /// Step size applied to every update.
    pub lr: f32,
    /// Exponential decay rate of the first-moment (mean) estimate.
    pub beta1: f32,
    /// Exponential decay rate of the second-moment (squared gradient) estimate.
    pub beta2: f32,
    /// Constant added to the denominator of the update.
    pub epsilon: f32,
    /// Coefficient of the `weight_decay * param` term added to each gradient.
    pub weight_decay: f32,
    /// Number of steps taken since the moment buffers were last zeroed.
    step_count: usize,
    /// First-moment buffers, one per parameter tensor.
    m: Vec<Vec<f32>>,
    /// Second-moment buffers, one per parameter tensor (always shaped like `m`).
    v: Vec<Vec<f32>>,
}

impl Adam {
    /// Creates an optimizer with the given learning rate, `beta1 = 0.9`,
    /// `beta2 = 0.999`, `epsilon = 1e-8` and no weight decay.
    pub fn new(lr: f32) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.0,
            step_count: 0,
            m: Vec::new(),
            v: Vec::new(),
        }
    }

    /// Sets both moment decay rates.
    pub fn with_betas(mut self, b1: f32, b2: f32) -> Self {
        self.beta1 = b1;
        self.beta2 = b2;
        self
    }

    /// Sets the weight-decay coefficient.
    pub fn with_weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    /// Sets the denominator constant.
    pub fn with_epsilon(mut self, eps: f32) -> Self {
        self.epsilon = eps;
        self
    }

    /// Applies one Adam update to `params` in place using the matching `grads`.
    ///
    /// Tensors and elements are paired positionally; anything without a
    /// counterpart on the other side is left untouched. When the shape of
    /// `params` (tensor count or any tensor length) differs from the stored
    /// moment buffers, the buffers are re-created as zeros and the step counter
    /// restarts, so the bias correction stays consistent with the fresh moments.
    pub fn step(&mut self, params: &mut [&mut Vec<f32>], grads: &[Vec<f32>]) {
        // `v` is always created alongside `m`, so checking `m` covers both.
        if !Self::shapes_match(&self.m, params) {
            self.m = params.iter().map(|p| vec![0.0f32; p.len()]).collect();
            self.v = self.m.clone();
            self.step_count = 0;
        }

        self.step_count += 1;
        let t = self.step_count as f32;
        let bc1 = 1.0 - self.beta1.powf(t);
        let bc2 = 1.0 - self.beta2.powf(t);
        let (lr, beta1, beta2, epsilon, weight_decay) = (
            self.lr,
            self.beta1,
            self.beta2,
            self.epsilon,
            self.weight_decay,
        );

        let moments = self.m.iter_mut().zip(self.v.iter_mut());
        for ((param, grad), (m_buf, v_buf)) in params.iter_mut().zip(grads).zip(moments) {
            let buffers = m_buf.iter_mut().zip(v_buf.iter_mut());
            for ((p, &g), (m, v)) in param.iter_mut().zip(grad).zip(buffers) {
                let g_eff = g + weight_decay * *p;
                *m = beta1 * *m + (1.0 - beta1) * g_eff;
                *v = beta2 * *v + (1.0 - beta2) * g_eff * g_eff;
                let m_hat = *m / bc1;
                let v_hat = *v / bc2;
                *p -= lr * m_hat / (v_hat.sqrt() + epsilon);
            }
        }
    }

    /// Zeroes the moment buffers (keeping their shape) and restarts the step counter.
    pub fn reset(&mut self) {
        self.step_count = 0;
        for buf in self.m.iter_mut().chain(self.v.iter_mut()) {
            buf.fill(0.0);
        }
    }

    /// True when `state` has one buffer per tensor in `params`, each of equal length.
    fn shapes_match(state: &[Vec<f32>], params: &[&mut Vec<f32>]) -> bool {
        state.len() == params.len()
            && state.iter().zip(params).all(|(s, p)| s.len() == p.len())
    }
}
/// Adam with decoupled weight decay: parameters are shrunk directly before the
/// Adam update instead of having the decay term added to the gradient.
pub struct AdamW {
    /// Wrapped Adam optimizer; its `weight_decay` field stores the decoupled coefficient.
    inner: Adam,
}

impl AdamW {
    /// Creates an optimizer wrapping `Adam::new(lr)`.
    pub fn new(lr: f32) -> Self {
        let inner = Adam::new(lr);
        Self { inner }
    }

    /// Sets both moment decay rates on the wrapped optimizer.
    pub fn with_betas(self, b1: f32, b2: f32) -> Self {
        Self {
            inner: self.inner.with_betas(b1, b2),
        }
    }

    /// Sets the decoupled weight-decay coefficient.
    pub fn with_weight_decay(mut self, wd: f32) -> Self {
        self.inner.weight_decay = wd;
        self
    }

    /// Applies one update to `params` in place.
    ///
    /// When the weight-decay coefficient is positive every parameter is first
    /// multiplied by `1 - lr * weight_decay`; the wrapped Adam step then runs
    /// with its own weight decay temporarily set to zero.
    pub fn step(&mut self, params: &mut [&mut Vec<f32>], grads: &[Vec<f32>]) {
        let weight_decay = self.inner.weight_decay;

        if weight_decay > 0.0 {
            let shrink = 1.0 - self.inner.lr * weight_decay;
            for param in params.iter_mut() {
                param.iter_mut().for_each(|p| *p *= shrink);
            }
        }

        // Zero the coefficient for the inner step so the decay is not applied a
        // second time through the gradient, then restore it for the next call.
        self.inner.weight_decay = 0.0;
        self.inner.step(params, grads);
        self.inner.weight_decay = weight_decay;
    }
}
/// Returns the L2 norm of all gradient values taken together, i.e. the square
/// root of the sum of squares over every element of every tensor.
///
/// The squares are accumulated in `f64` so that components whose square does
/// not fit in `f32` (very large or very small magnitudes) still contribute
/// correctly; only the final result is narrowed back to `f32`. An empty input
/// yields `0.0`.
pub fn grad_norm(grads: &[Vec<f32>]) -> f32 {
    let sq_sum: f64 = grads
        .iter()
        .flatten()
        .map(|&x| {
            let wide = f64::from(x);
            wide * wide
        })
        .sum();
    // Intentional narrowing: callers work in f32.
    sq_sum.sqrt() as f32
}
/// Rescales `grads` in place so that their combined L2 norm does not exceed
/// `max_norm`, and returns the norm measured *before* any rescaling.
///
/// Gradients are left untouched when the norm is zero or already within
/// `max_norm`; otherwise every element is multiplied by `max_norm / norm`.
pub fn clip_grad_norm(grads: &mut [Vec<f32>], max_norm: f32) -> f32 {
    let total_norm = grad_norm(grads);

    // The `> 0.0` guard keeps the division below away from a zero norm.
    let needs_clip = total_norm > 0.0 && total_norm > max_norm;
    if needs_clip {
        let factor = max_norm / total_norm;
        grads.iter_mut().flatten().for_each(|g| *g *= factor);
    }

    total_norm
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by [`close`].
    const EPS: f32 = 1e-5;

    /// Returns `true` when `lhs` and `rhs` differ by less than [`EPS`].
    fn close(lhs: f32, rhs: f32) -> bool {
        (lhs - rhs).abs() < EPS
    }

    /// The constant schedule returns the same rate regardless of step.
    #[test]
    fn test_constant_lr() {
        let sched = ConstantLr { lr: 0.01 };
        for step in [0, 1000] {
            assert!(close(sched.get_lr(step), 0.01));
        }
    }

    /// During warmup the rate grows linearly from zero.
    #[test]
    fn test_warmup_lr_before_warmup() {
        let sched = WarmupConstantLr {
            base_lr: 1.0,
            warmup_steps: 100,
        };
        for (step, expected) in [(50, 0.5), (0, 0.0)] {
            assert!(close(sched.get_lr(step), expected));
        }
    }

    /// From the end of warmup onwards the rate equals `base_lr`.
    #[test]
    fn test_warmup_lr_after_warmup() {
        let sched = WarmupConstantLr {
            base_lr: 3e-4,
            warmup_steps: 100,
        };
        for step in [100, 500] {
            assert!(close(sched.get_lr(step), 3e-4));
        }
    }

    /// Without warmup the cosine schedule starts at `base_lr`.
    #[test]
    fn test_cosine_annealing_at_zero() {
        let sched = CosineAnnealingLr {
            base_lr: 1.0,
            min_lr: 0.0,
            total_steps: 100,
            warmup_steps: 0,
        };
        let first = sched.get_lr(0);
        assert!(close(first, 1.0));
    }

    /// At `total_steps` the cosine schedule has decayed to `min_lr`.
    #[test]
    fn test_cosine_annealing_at_total() {
        let sched = CosineAnnealingLr {
            base_lr: 1.0,
            min_lr: 0.1,
            total_steps: 100,
            warmup_steps: 0,
        };
        let lr = sched.get_lr(100);
        let deviation = (lr - 0.1).abs();
        assert!(
            deviation < 1e-4,
            "expected min_lr=0.1 at total_steps, got {lr}"
        );
    }

    /// With zero momentum a step is `p -= lr * g`.
    #[test]
    fn test_sgd_step_basic() {
        let mut weights = vec![1.0f32, 2.0, 3.0];
        let grads = [vec![0.1f32, 0.2, 0.3]];
        let mut sgd = Sgd::new(1.0).with_momentum(0.0);
        sgd.step(&mut [&mut weights], &grads);
        for (got, want) in weights.iter().zip([0.9, 1.8, 2.7]) {
            assert!(close(*got, want));
        }
    }

    /// Velocity accumulates across steps when momentum is non-zero.
    #[test]
    fn test_sgd_with_momentum() {
        let mut weights = vec![1.0f32];
        let grads = [vec![1.0f32]];
        let mut sgd = Sgd::new(0.1).with_momentum(0.9);

        sgd.step(&mut [&mut weights], &grads);
        assert!(
            (weights[0] - 0.9).abs() < 1e-5,
            "after step 1: {}",
            weights[0]
        );

        sgd.step(&mut [&mut weights], &grads);
        let expected = 0.9 - 0.1 * 1.9;
        assert!(
            (weights[0] - expected).abs() < 1e-5,
            "after step 2: {}",
            weights[0]
        );
    }

    /// A positive gradient moves the parameter down.
    #[test]
    fn test_adam_step_basic() {
        let mut weights = vec![1.0f32];
        let grads = [vec![1.0f32]];
        let mut adam = Adam::new(0.01);
        adam.step(&mut [&mut weights], &grads);
        assert!(
            weights[0] < 1.0,
            "Adam must decrease parameter on positive gradient"
        );
    }

    /// Minimising `x^2` drives the parameter towards zero.
    #[test]
    fn test_adam_step_reduces_loss() {
        let mut x = vec![5.0f32];
        let mut adam = Adam::new(0.1);
        (0..200).for_each(|_| {
            let grads = [vec![2.0 * x[0]]];
            adam.step(&mut [&mut x], &grads);
        });
        assert!(
            x[0].abs() < 0.5,
            "Adam should converge x^2 toward 0, got {}",
            x[0]
        );
    }

    /// Decoupled weight decay shrinks the parameter even with a zero gradient.
    #[test]
    fn test_adamw_step_basic() {
        let mut weights = vec![1.0f32];
        let grads = [vec![0.0f32]];
        let mut adamw = AdamW::new(0.01).with_weight_decay(0.1);
        adamw.step(&mut [&mut weights], &grads);
        assert!(
            weights[0] < 1.0,
            "AdamW must shrink parameter via weight decay"
        );
    }

    /// Gradients above the limit are scaled down to it; the old norm is returned.
    #[test]
    fn test_clip_grad_norm_clips() {
        let mut grads = vec![vec![3.0f32, 4.0]];
        let norm_before = clip_grad_norm(&mut grads, 1.0);
        assert!(close(norm_before, 5.0));

        let norm_after = grad_norm(&grads);
        assert!(
            (norm_after - 1.0).abs() < 1e-5,
            "clipped norm = {norm_after}"
        );
    }

    /// Gradients within the limit are left untouched.
    #[test]
    fn test_clip_grad_norm_no_clip() {
        let mut grads = vec![vec![0.3f32, 0.4]];
        let norm_before = clip_grad_norm(&mut grads, 1.0);
        assert!((norm_before - 0.5).abs() < 1e-5);
        for (got, want) in grads[0].iter().zip([0.3, 0.4]) {
            assert!(close(*got, want));
        }
    }

    /// The norm of (3, 4) is 5.
    #[test]
    fn test_grad_norm_correct() {
        let grads = vec![vec![3.0f32, 4.0]];
        let n = grad_norm(&grads);
        assert!(close(n, 5.0), "expected norm 5.0, got {n}");
    }
}