use std::collections::HashMap;
use smol_str::format_smolstr;
use crate::{
Array, Result,
error::{Error, NonFiniteScalarPayload, OutOfRangePayload},
lm::{
load::Weights,
tuner::optimizers::base::{LearningRate, Optimizer, zeros_like},
},
ops::arithmetic,
};
fn validate_betas(
context_b1: &'static str,
context_b2: &'static str,
betas: (f32, f32),
) -> Result<()> {
let (b1, b2) = betas;
if !b1.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
context_b1, b1 as f64,
)));
}
if !b2.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
context_b2, b2 as f64,
)));
}
if !(0.0..1.0).contains(&b1) {
return Err(Error::OutOfRange(OutOfRangePayload::new(
context_b1,
"must be in [0.0, 1.0)",
format_smolstr!("{b1}"),
)));
}
if !(0.0..1.0).contains(&b2) {
return Err(Error::OutOfRange(OutOfRangePayload::new(
context_b2,
"must be in [0.0, 1.0)",
format_smolstr!("{b2}"),
)));
}
Ok(())
}
fn validate_eps(context: &'static str, eps: f32) -> Result<()> {
if !eps.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
context, eps as f64,
)));
}
if eps < 0.0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
context,
"must be >= 0.0",
format_smolstr!("{eps}"),
)));
}
Ok(())
}
fn validate_weight_decay(context: &'static str, weight_decay: f32) -> Result<()> {
if !weight_decay.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
context,
weight_decay as f64,
)));
}
if weight_decay < 0.0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
context,
"must be >= 0.0",
format_smolstr!("{weight_decay}"),
)));
}
Ok(())
}
/// Per-parameter optimizer state: parameter key -> pair of moment arrays.
type Moments = HashMap<String, (Array, Array)>;

/// Builds a moment table with one entry per key in `params`, where both
/// arrays of each pair come from `zeros_like` applied to that parameter.
///
/// # Errors
///
/// Propagates any error returned by `zeros_like`.
fn fresh_moments(params: &Weights) -> Result<Moments> {
    let mut moments = Moments::with_capacity(params.len());
    for (name, tensor) in params {
        let first = zeros_like(tensor)?;
        let second = zeros_like(tensor)?;
        moments.insert(name.clone(), (first, second));
    }
    Ok(moments)
}
/// Wraps `v` in an `f32` array created with an empty shape (zero dimensions).
///
/// # Errors
///
/// Propagates any error returned by `Array::full`.
fn scalar(v: f32) -> Result<Array> {
    let empty_shape: [i32; 0] = [];
    Array::full::<f32>(&empty_shape, v)
}
/// Adam optimizer state and hyper-parameters.
///
/// Moment estimates are kept per parameter key in `state` and updated by
/// `adam_step` each time gradients are applied.
pub struct Adam {
/// Learning-rate source; `current_lr` is resolved from it per step.
learning_rate: LearningRate,
/// Decay coefficients `(b1, b2)` for the first and second moment estimates;
/// validated to be finite and in `[0.0, 1.0)`.
betas: (f32, f32),
/// Constant added to the denominator of the update; validated to be finite
/// and `>= 0.0`.
eps: f32,
/// When `true`, `adam_step` uses the bias-corrected form of the update.
bias_correction: bool,
/// Incremented once per `apply_gradients` call, before the per-key updates.
step_count: usize,
/// Most recently resolved learning-rate value.
current_lr: f32,
/// Step index for which `current_lr` was resolved; lets `preflight` skip
/// resolving again for the same step.
lr_resolved_for_step: Option<usize>,
/// Per-key `(first moment, second moment)` arrays.
pub(crate) state: Moments,
}
impl Adam {
    /// Creates an Adam optimizer.
    ///
    /// `betas` and `eps` are validated first; the learning rate is then
    /// resolved once for step 0 and cached.
    ///
    /// # Errors
    ///
    /// Returns the validation error for `betas` or `eps`, or the error from
    /// `LearningRate::try_current`.
    pub fn new(
        learning_rate: impl Into<LearningRate>,
        betas: (f32, f32),
        eps: f32,
        bias_correction: bool,
    ) -> Result<Self> {
        validate_betas("Adam: betas.0", "Adam: betas.1", betas)?;
        validate_eps("Adam: eps", eps)?;

        let schedule = learning_rate.into();
        let initial_lr = schedule.try_current(0)?;

        Ok(Self {
            learning_rate: schedule,
            betas,
            eps,
            bias_correction,
            step_count: 0,
            current_lr: initial_lr,
            lr_resolved_for_step: Some(0),
            state: HashMap::new(),
        })
    }

    /// Creates an optimizer with `betas = (0.9, 0.999)`, `eps = 1e-8` and
    /// bias correction disabled.
    ///
    /// # Errors
    ///
    /// Same as [`Adam::new`].
    pub fn default_with_lr(learning_rate: impl Into<LearningRate>) -> Result<Self> {
        let betas = (0.9, 0.999);
        let eps = 1e-8;
        Self::new(learning_rate, betas, eps, false)
    }

    /// Borrows the configured learning-rate source.
    #[inline(always)]
    pub fn learning_rate_ref(&self) -> &LearningRate {
        &self.learning_rate
    }

    /// Returns the `(b1, b2)` decay coefficients.
    #[inline(always)]
    pub fn betas(&self) -> (f32, f32) {
        self.betas
    }

    /// Returns the denominator constant `eps`.
    #[inline(always)]
    pub fn eps(&self) -> f32 {
        self.eps
    }

    /// Returns whether the bias-corrected update is used.
    #[inline(always)]
    pub fn bias_correction(&self) -> bool {
        self.bias_correction
    }

    /// Replaces the learning-rate source and re-resolves the cached rate for
    /// the current step.
    ///
    /// # Errors
    ///
    /// Propagates the error from `LearningRate::try_current`.
    pub fn with_learning_rate(self, learning_rate: impl Into<LearningRate>) -> Result<Self> {
        let schedule = learning_rate.into();
        let resolved = schedule.try_current(self.step_count)?;
        Ok(Self {
            learning_rate: schedule,
            current_lr: resolved,
            lr_resolved_for_step: Some(self.step_count),
            ..self
        })
    }

    /// Replaces the decay coefficients after validating them.
    ///
    /// # Errors
    ///
    /// Returns the validation error when a coefficient is non-finite or
    /// outside `[0.0, 1.0)`.
    pub fn with_betas(self, betas: (f32, f32)) -> Result<Self> {
        validate_betas("Adam: betas.0", "Adam: betas.1", betas)?;
        Ok(Self { betas, ..self })
    }

    /// Replaces `eps` after validating it.
    ///
    /// # Errors
    ///
    /// Returns the validation error when `eps` is non-finite or negative.
    pub fn with_eps(self, eps: f32) -> Result<Self> {
        validate_eps("Adam: eps", eps)?;
        Ok(Self { eps, ..self })
    }

    /// Enables or disables the bias-corrected update.
    #[must_use]
    pub fn with_bias_correction(self, bias_correction: bool) -> Self {
        Self {
            bias_correction,
            ..self
        }
    }

    /// Computes the updated value of one parameter and stores its new moments.
    ///
    /// Moments for `key` are read from `state`; an unseen key starts from
    /// `zeros_like(param)`. The new moments are written back only after every
    /// array operation has succeeded.
    ///
    /// # Errors
    ///
    /// Propagates any error from the array operations.
    fn adam_step(&mut self, key: &str, grad: &Array, param: &Array) -> Result<Array> {
        let (beta1, beta2) = self.betas;
        let step_size = self.current_lr;

        let (m_prev, v_prev) = if let Some((m, v)) = self.state.get(key) {
            (m.try_clone()?, v.try_clone()?)
        } else {
            (zeros_like(param)?, zeros_like(param)?)
        };

        let beta1_arr = scalar(beta1)?;
        let beta2_arr = scalar(beta2)?;
        let beta1_comp = scalar(1.0 - beta1)?;
        let beta2_comp = scalar(1.0 - beta2)?;
        let eps_arr = scalar(self.eps)?;

        // m = b1 * m_prev + (1 - b1) * g
        let first_moment = {
            let decayed = arithmetic::multiply(&beta1_arr, &m_prev)?;
            let fresh = arithmetic::multiply(&beta1_comp, grad)?;
            arithmetic::add(&decayed, &fresh)?
        };

        // v = b2 * v_prev + (1 - b2) * g^2
        let second_moment = {
            let grad_sq = arithmetic::square(grad)?;
            let decayed = arithmetic::multiply(&beta2_arr, &v_prev)?;
            let fresh = arithmetic::multiply(&beta2_comp, &grad_sq)?;
            arithmetic::add(&decayed, &fresh)?
        };

        let update = if self.bias_correction {
            // update = (lr / (1 - b1^t)) * m / (sqrt(v) * (1 - b2^t)^(-1/2) + eps)
            let t = self.step_count as f32;
            let numer_factor = step_size / (1.0 - beta1.powf(t));
            let denom_factor = (1.0_f32 - beta2.powf(t)).powf(-0.5);
            let numer_factor_arr = scalar(numer_factor)?;
            let denom_factor_arr = scalar(denom_factor)?;
            let numerator = arithmetic::multiply(&numer_factor_arr, &first_moment)?;
            let root = arithmetic::sqrt(&second_moment)?;
            let scaled_root = arithmetic::multiply(&root, &denom_factor_arr)?;
            let denominator = arithmetic::add(&scaled_root, &eps_arr)?;
            arithmetic::divide(&numerator, &denominator)?
        } else {
            // update = lr * m / (sqrt(v) + eps)
            let step_size_arr = scalar(step_size)?;
            let numerator = arithmetic::multiply(&step_size_arr, &first_moment)?;
            let root = arithmetic::sqrt(&second_moment)?;
            let denominator = arithmetic::add(&root, &eps_arr)?;
            arithmetic::divide(&numerator, &denominator)?
        };

        let updated = arithmetic::subtract(param, &update)?;
        self.state
            .insert(String::from(key), (first_moment, second_moment));
        Ok(updated)
    }
}
impl Optimizer for Adam {
    /// Replaces the moment table with fresh moments, one entry per key in
    /// `params`.
    fn init(&mut self, params: &Weights) -> Result<()> {
        self.state = fresh_moments(params)?;
        Ok(())
    }

    /// Resolves the learning rate for the current step unless it has already
    /// been resolved for that step.
    fn preflight(&mut self) -> Result<()> {
        let step = self.step_count;
        if self.lr_resolved_for_step != Some(step) {
            self.current_lr = self.learning_rate.try_current(step)?;
            self.lr_resolved_for_step = Some(step);
        }
        Ok(())
    }

    /// Applies one Adam update to every parameter that has a gradient.
    ///
    /// State is initialised from `gradients` when the moment table is empty.
    /// The step counter is advanced before the per-key updates, so the
    /// bias-correction exponent is at least 1. Gradient keys with no matching
    /// entry in `params` are skipped.
    ///
    /// # Errors
    ///
    /// Propagates errors from learning-rate resolution and array operations.
    /// Parameters updated before a failure keep their new values.
    fn apply_gradients(&mut self, gradients: &Weights, params: &mut Weights) -> Result<()> {
        if self.state.is_empty() {
            self.init(gradients)?;
        }
        self.preflight()?;
        self.step_count += 1;

        for (key, grad) in gradients {
            let Some(param) = params.get(key) else {
                continue;
            };
            // The borrow of `params` ends once `adam_step` returns an owned
            // array, so the parameter can be passed by reference without a
            // defensive clone before the entry is overwritten.
            let new_w = self.adam_step(key, grad, param)?;
            params.insert(key.clone(), new_w);
        }
        Ok(())
    }

    /// Returns the number of `apply_gradients` calls started so far.
    fn step(&self) -> usize {
        self.step_count
    }

    /// Returns the most recently resolved learning rate.
    fn learning_rate(&self) -> f32 {
        self.current_lr
    }
}
/// Adam with decoupled weight decay, implemented as a wrapper around [`Adam`].
pub struct AdamW {
/// Underlying Adam optimizer holding hyper-parameters, step counter and moments.
inner: Adam,
/// Decay coefficient; each parameter is multiplied by `1 - lr * weight_decay`
/// before the Adam update. Validated to be finite and `>= 0.0`.
weight_decay: f32,
}
impl AdamW {
    /// Creates an AdamW optimizer.
    ///
    /// `weight_decay` is validated before the inner [`Adam`] is constructed.
    ///
    /// # Errors
    ///
    /// Returns the validation error for `weight_decay`, or any error from
    /// [`Adam::new`].
    pub fn new(
        learning_rate: impl Into<LearningRate>,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        bias_correction: bool,
    ) -> Result<Self> {
        validate_weight_decay("AdamW: weight_decay", weight_decay)?;
        let inner = Adam::new(learning_rate, betas, eps, bias_correction)?;
        Ok(Self {
            inner,
            weight_decay,
        })
    }

    /// Creates an optimizer with `betas = (0.9, 0.999)`, `eps = 1e-8`,
    /// `weight_decay = 0.01` and bias correction disabled.
    ///
    /// # Errors
    ///
    /// Same as [`AdamW::new`].
    pub fn default_with_lr(learning_rate: impl Into<LearningRate>) -> Result<Self> {
        let betas = (0.9, 0.999);
        let eps = 1e-8;
        let weight_decay = 0.01;
        Self::new(learning_rate, betas, eps, weight_decay, false)
    }

    /// Returns the weight-decay coefficient.
    #[inline(always)]
    pub fn weight_decay(&self) -> f32 {
        self.weight_decay
    }

    /// Replaces the learning-rate source and re-resolves the cached rate for
    /// the current step.
    ///
    /// # Errors
    ///
    /// Propagates the error from `LearningRate::try_current`.
    pub fn with_learning_rate(self, learning_rate: impl Into<LearningRate>) -> Result<Self> {
        let Self {
            inner,
            weight_decay,
        } = self;
        // The inner optimizer's builder performs the same resolve-and-cache steps.
        let inner = inner.with_learning_rate(learning_rate)?;
        Ok(Self {
            inner,
            weight_decay,
        })
    }

    /// Replaces the weight-decay coefficient after validating it.
    ///
    /// # Errors
    ///
    /// Returns the validation error when `weight_decay` is non-finite or
    /// negative.
    pub fn with_weight_decay(self, weight_decay: f32) -> Result<Self> {
        validate_weight_decay("AdamW: weight_decay", weight_decay)?;
        Ok(Self {
            weight_decay,
            ..self
        })
    }
}
impl Optimizer for AdamW {
    /// Re-initialises the inner optimizer's moment table from `params`.
    fn init(&mut self, params: &Weights) -> Result<()> {
        self.inner.init(params)
    }

    /// Resolves the inner optimizer's learning rate for the current step.
    fn preflight(&mut self) -> Result<()> {
        self.inner.preflight()
    }

    /// Applies one AdamW update to every parameter that has a gradient.
    ///
    /// Each parameter is first multiplied by `1 - lr * weight_decay`, and the
    /// result is then passed through the inner Adam step. Gradient keys with
    /// no matching entry in `params` are skipped.
    ///
    /// # Errors
    ///
    /// Propagates errors from learning-rate resolution and array operations.
    fn apply_gradients(&mut self, gradients: &Weights, params: &mut Weights) -> Result<()> {
        let weight_decay = self.weight_decay;
        let adam = &mut self.inner;

        if adam.state.is_empty() {
            adam.init(gradients)?;
        }
        adam.preflight()?;
        adam.step_count += 1;

        let step_size = adam.current_lr;
        let shrink = scalar(1.0 - step_size * weight_decay)?;

        for (name, grad) in gradients {
            // Look up the parameter and decay it in one expression; a missing
            // key yields `None` and is skipped.
            let decayed = params
                .get(name)
                .map(|current| arithmetic::multiply(current, &shrink))
                .transpose()?;
            let Some(decayed) = decayed else {
                continue;
            };
            let updated = adam.adam_step(name, grad, &decayed)?;
            params.insert(name.clone(), updated);
        }
        Ok(())
    }

    /// Returns the inner optimizer's step counter.
    fn step(&self) -> usize {
        self.inner.step_count
    }

    /// Returns the inner optimizer's most recently resolved learning rate.
    fn learning_rate(&self) -> f32 {
        self.inner.current_lr
    }
}
/// Adamax optimizer state and hyper-parameters.
///
/// The second array of each state pair is updated with an element-wise maximum
/// of the decayed previous value and the absolute gradient.
pub struct Adamax {
/// Learning-rate source; `current_lr` is resolved from it per step.
learning_rate: LearningRate,
/// Decay coefficients `(b1, b2)`; validated to be finite and in `[0.0, 1.0)`.
betas: (f32, f32),
/// Constant added to the denominator of the update; validated to be finite
/// and `>= 0.0`.
eps: f32,
/// Incremented once per `apply_gradients` call, before the per-key updates.
step_count: usize,
/// Most recently resolved learning-rate value.
current_lr: f32,
/// Step index for which `current_lr` was resolved; lets `preflight` skip
/// resolving again for the same step.
lr_resolved_for_step: Option<usize>,
/// Per-key pair of moment arrays.
state: Moments,
}
impl Adamax {
    /// Creates an Adamax optimizer.
    ///
    /// `betas` and `eps` are validated first; the learning rate is then
    /// resolved once for step 0 and cached.
    ///
    /// # Errors
    ///
    /// Returns the validation error for `betas` or `eps`, or the error from
    /// `LearningRate::try_current`.
    pub fn new(learning_rate: impl Into<LearningRate>, betas: (f32, f32), eps: f32) -> Result<Self> {
        validate_betas("Adamax: betas.0", "Adamax: betas.1", betas)?;
        validate_eps("Adamax: eps", eps)?;

        let schedule = learning_rate.into();
        let initial_lr = schedule.try_current(0)?;

        Ok(Self {
            learning_rate: schedule,
            betas,
            eps,
            step_count: 0,
            current_lr: initial_lr,
            lr_resolved_for_step: Some(0),
            state: HashMap::new(),
        })
    }

    /// Creates an optimizer with `betas = (0.9, 0.999)` and `eps = 1e-8`.
    ///
    /// # Errors
    ///
    /// Same as [`Adamax::new`].
    pub fn default_with_lr(learning_rate: impl Into<LearningRate>) -> Result<Self> {
        let betas = (0.9, 0.999);
        let eps = 1e-8;
        Self::new(learning_rate, betas, eps)
    }

    /// Borrows the configured learning-rate source.
    #[inline(always)]
    pub fn learning_rate_ref(&self) -> &LearningRate {
        &self.learning_rate
    }

    /// Returns the `(b1, b2)` decay coefficients.
    #[inline(always)]
    pub fn betas(&self) -> (f32, f32) {
        self.betas
    }

    /// Returns the denominator constant `eps`.
    #[inline(always)]
    pub fn eps(&self) -> f32 {
        self.eps
    }

    /// Replaces the learning-rate source and re-resolves the cached rate for
    /// the current step.
    ///
    /// # Errors
    ///
    /// Propagates the error from `LearningRate::try_current`.
    pub fn with_learning_rate(self, learning_rate: impl Into<LearningRate>) -> Result<Self> {
        let schedule = learning_rate.into();
        let resolved = schedule.try_current(self.step_count)?;
        Ok(Self {
            learning_rate: schedule,
            current_lr: resolved,
            lr_resolved_for_step: Some(self.step_count),
            ..self
        })
    }

    /// Replaces the decay coefficients after validating them.
    ///
    /// # Errors
    ///
    /// Returns the validation error when a coefficient is non-finite or
    /// outside `[0.0, 1.0)`.
    pub fn with_betas(self, betas: (f32, f32)) -> Result<Self> {
        validate_betas("Adamax: betas.0", "Adamax: betas.1", betas)?;
        Ok(Self { betas, ..self })
    }

    /// Replaces `eps` after validating it.
    ///
    /// # Errors
    ///
    /// Returns the validation error when `eps` is non-finite or negative.
    pub fn with_eps(self, eps: f32) -> Result<Self> {
        validate_eps("Adamax: eps", eps)?;
        Ok(Self { eps, ..self })
    }
}
impl Optimizer for Adamax {
    /// Replaces the moment table with fresh moments, one entry per key in
    /// `params`.
    fn init(&mut self, params: &Weights) -> Result<()> {
        let moments = fresh_moments(params)?;
        self.state = moments;
        Ok(())
    }

    /// Resolves the learning rate for the current step unless it has already
    /// been resolved for that step.
    fn preflight(&mut self) -> Result<()> {
        let step = self.step_count;
        if self.lr_resolved_for_step != Some(step) {
            self.current_lr = self.learning_rate.try_current(step)?;
            self.lr_resolved_for_step = Some(step);
        }
        Ok(())
    }

    /// Applies one Adamax update to every parameter that has a gradient.
    ///
    /// State is initialised from `gradients` when the moment table is empty.
    /// Gradient keys with no matching entry in `params` are skipped; unseen
    /// keys start from `zeros_like(param)`.
    ///
    /// # Errors
    ///
    /// Propagates errors from learning-rate resolution and array operations.
    fn apply_gradients(&mut self, gradients: &Weights, params: &mut Weights) -> Result<()> {
        if self.state.is_empty() {
            self.init(gradients)?;
        }
        self.preflight()?;
        self.step_count += 1;

        // Scalars shared by every parameter in this step.
        let (beta1, beta2) = self.betas;
        let beta1_arr = scalar(beta1)?;
        let beta2_arr = scalar(beta2)?;
        let beta1_comp = scalar(1.0 - beta1)?;
        let eps_arr = scalar(self.eps)?;
        let step_size_arr = scalar(self.current_lr)?;

        for (name, grad) in gradients {
            let current = match params.get(name) {
                Some(found) => found,
                None => continue,
            };

            let (m_prev, v_prev) = if let Some((m, v)) = self.state.get(name) {
                (m.try_clone()?, v.try_clone()?)
            } else {
                (zeros_like(current)?, zeros_like(current)?)
            };

            // m = b1 * m_prev + (1 - b1) * g
            let first_moment = {
                let decayed = arithmetic::multiply(&beta1_arr, &m_prev)?;
                let fresh = arithmetic::multiply(&beta1_comp, grad)?;
                arithmetic::add(&decayed, &fresh)?
            };

            // v = max(b2 * v_prev, |g|)
            let inf_norm = {
                let decayed = arithmetic::multiply(&beta2_arr, &v_prev)?;
                let magnitude = arithmetic::abs(grad)?;
                arithmetic::maximum(&decayed, &magnitude)?
            };

            // w = w - lr * m / (v + eps)
            let numerator = arithmetic::multiply(&step_size_arr, &first_moment)?;
            let denominator = arithmetic::add(&inf_norm, &eps_arr)?;
            let update = arithmetic::divide(&numerator, &denominator)?;
            let updated = arithmetic::subtract(current, &update)?;

            params.insert(name.clone(), updated);
            self.state.insert(name.clone(), (first_moment, inf_norm));
        }
        Ok(())
    }

    /// Returns the number of `apply_gradients` calls started so far.
    fn step(&self) -> usize {
        self.step_count
    }

    /// Returns the most recently resolved learning rate.
    fn learning_rate(&self) -> f32 {
        self.current_lr
    }
}
#[cfg(test)]
mod tests;