use std::collections::HashMap;
use smol_str::format_smolstr;
use crate::{
Array, Result,
error::{Error, InvariantViolationPayload, NonFiniteScalarPayload, OutOfRangePayload},
lm::{
load::Weights,
tuner::optimizers::base::{LearningRate, Optimizer, zeros_like},
},
ops::{
arithmetic,
reduction::{mean, mean_axes},
shape::expand_dims_axes,
},
};
fn validate_clip_threshold(clip_threshold: f32) -> Result<()> {
if !clip_threshold.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"Adafactor: clip_threshold",
clip_threshold as f64,
)));
}
if clip_threshold <= 0.0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"Adafactor: clip_threshold",
"must be > 0.0 (used as a divisor)",
format_smolstr!("{clip_threshold}"),
)));
}
Ok(())
}
fn validate_decay_rate(decay_rate: f32) -> Result<()> {
if !decay_rate.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"Adafactor: decay_rate",
decay_rate as f64,
)));
}
if decay_rate > 0.0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"Adafactor: decay_rate",
"must be <= 0.0 (so that 1 - step^decay_rate stays in [0, 1))",
format_smolstr!("{decay_rate}"),
)));
}
Ok(())
}
fn validate_eps(eps: (f32, f32)) -> Result<()> {
let (e1, e2) = eps;
if !e1.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"Adafactor: eps.0",
e1 as f64,
)));
}
if !e2.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"Adafactor: eps.1",
e2 as f64,
)));
}
if e1 < 0.0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"Adafactor: eps.0",
"must be >= 0.0",
format_smolstr!("{e1}"),
)));
}
if e2 < 0.0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"Adafactor: eps.1",
"must be >= 0.0",
format_smolstr!("{e2}"),
)));
}
Ok(())
}
fn validate_beta_1(beta_1: Option<f32>) -> Result<()> {
if let Some(b) = beta_1 {
if !b.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"Adafactor: beta_1",
b as f64,
)));
}
if !(0.0..1.0).contains(&b) {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"Adafactor: beta_1",
"must be None or in [0.0, 1.0)",
format_smolstr!("{b}"),
)));
}
}
Ok(())
}
fn validate_weight_decay(weight_decay: f32) -> Result<()> {
if !weight_decay.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"Adafactor: weight_decay",
weight_decay as f64,
)));
}
if weight_decay < 0.0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"Adafactor: weight_decay",
"must be >= 0.0",
format_smolstr!("{weight_decay}"),
)));
}
Ok(())
}
/// Builds a 0-dimensional (empty-shape) `f32` array filled with `v`.
fn scalar(v: f32) -> Result<Array> {
    let empty_shape: [i32; 0] = [];
    Array::full::<f32>(&empty_shape, v)
}
/// Per-parameter state for tensors of rank >= 2.
///
/// The second moment is kept as two reduced running averages instead of one
/// full-size accumulator.
struct FactoredState {
    /// Running average reduced over the last axis.
    /// Its shape is the parameter shape without its last dimension.
    row: Array,
    /// Running average reduced over the second-to-last axis.
    /// Its shape is the parameter shape without its second-to-last dimension.
    col: Array,
    /// First-moment buffer. It is `Some` only when `beta_1` was set when the state
    /// was created.
    exp_avg: Option<Array>,
}
impl FactoredState {
    /// Bundles the three buffers into a state value.
    fn new(row: Array, col: Array, exp_avg: Option<Array>) -> Self {
        FactoredState {
            exp_avg,
            col,
            row,
        }
    }
}
/// Per-parameter state for tensors of rank < 2, using a full-size accumulator.
struct NonFactoredState {
    /// Full-size running average of `grad^2 + eps.0`. It starts as `zeros_like(param)`.
    exp_avg_sq: Array,
    /// First-moment buffer. It is `Some` only when `beta_1` was set when the state
    /// was created.
    exp_avg: Option<Array>,
}
impl NonFactoredState {
    /// Bundles the second-moment accumulator and optional first moment.
    fn new(exp_avg_sq: Array, exp_avg: Option<Array>) -> Self {
        NonFactoredState {
            exp_avg,
            exp_avg_sq,
        }
    }
}
/// Per-parameter optimizer state, selected by parameter rank when it is created.
enum AdafactorState {
    /// Rank < 2: full-size second-moment accumulator.
    NonFactored(NonFactoredState),
    /// Rank >= 2: factored row/column second-moment statistics.
    Factored(FactoredState),
}
/// Adafactor optimizer: hyperparameters, step bookkeeping and per-parameter state.
pub struct Adafactor {
    // --- step-size controls ---
    /// Optional learning-rate schedule, resolved into `current_lr`.
    /// The applied step size only uses it when `relative_step` is `false`.
    learning_rate: Option<LearningRate>,
    /// When `true`, the step size is `min(min_step, 1/sqrt(step))` instead of `current_lr`.
    relative_step: bool,
    /// For relative steps, selects `min_step = 1e-6 * step` instead of `1e-2`.
    warmup_init: bool,
    /// When `true`, the step size is multiplied by `max(eps.1, rms(param))`.
    scale_parameter: bool,

    // --- update hyperparameters ---
    /// `(added to squared gradients, lower bound on the parameter RMS)`.
    eps: (f32, f32),
    /// Updates are divided by `max(1, rms(update) / clip_threshold)`.
    clip_threshold: f32,
    /// Exponent in `beta_2 = 1 - step^decay_rate`.
    decay_rate: f32,
    /// First-moment coefficient; `None` disables the first moment.
    beta_1: Option<f32>,
    /// Parameters are shrunk by `weight_decay * step_size` each step.
    weight_decay: f32,

    // --- bookkeeping ---
    /// Incremented once per `apply_gradients` call.
    step_count: usize,
    /// Learning rate resolved from `learning_rate` (`0.0` when there is none).
    current_lr: f32,
    /// Step index at which `current_lr` was last resolved; used to skip re-resolution.
    lr_resolved_for_step: Option<usize>,
    /// Per-parameter state keyed by parameter name.
    state: HashMap<String, AdafactorState>,
}
impl Adafactor {
    /// Creates an Adafactor optimizer after validating every hyperparameter.
    ///
    /// * `learning_rate`: optional schedule, resolved at step 0 (`None` resolves to
    ///   `0.0`). The update only uses it when `relative_step` is `false`.
    /// * `eps`: `eps.0` is added to squared gradients. `eps.1` is the lower bound on
    ///   the parameter RMS when `scale_parameter` is set. Both must be finite and
    ///   `>= 0.0`.
    /// * `clip_threshold`: update-RMS clipping threshold; finite and `> 0.0`.
    /// * `decay_rate`: exponent in `beta_2 = 1 - step^decay_rate`; finite and `<= 0.0`.
    /// * `beta_1`: `Some` enables a first-moment average; must lie in `[0.0, 1.0)`.
    /// * `weight_decay`: finite and `>= 0.0`.
    /// * `scale_parameter`, `relative_step`, `warmup_init`: step-size controls; see
    ///   `compute_learning_rate`.
    ///
    /// # Errors
    /// Returns the first validation error, checked in the order eps, clip_threshold,
    /// decay_rate, beta_1, weight_decay. Also returns any error from resolving the
    /// learning-rate schedule at step 0.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        learning_rate: Option<LearningRate>,
        eps: (f32, f32),
        clip_threshold: f32,
        decay_rate: f32,
        beta_1: Option<f32>,
        weight_decay: f32,
        scale_parameter: bool,
        relative_step: bool,
        warmup_init: bool,
    ) -> Result<Self> {
        validate_eps(eps)?;
        validate_clip_threshold(clip_threshold)?;
        validate_decay_rate(decay_rate)?;
        validate_beta_1(beta_1)?;
        validate_weight_decay(weight_decay)?;
        let current_lr = match learning_rate.as_ref() {
            Some(lr) => lr.try_current(0)?,
            None => 0.0,
        };
        Ok(Self {
            learning_rate,
            eps,
            clip_threshold,
            decay_rate,
            beta_1,
            weight_decay,
            scale_parameter,
            relative_step,
            warmup_init,
            step_count: 0,
            current_lr,
            lr_resolved_for_step: Some(0),
            state: HashMap::new(),
        })
    }

    /// Constructs an optimizer with these settings:
    /// no learning rate, `eps = (1e-30, 1e-3)`, `clip_threshold = 1.0`,
    /// `decay_rate = -0.8`, no first moment, no weight decay, and
    /// `scale_parameter` / `relative_step` on with `warmup_init` off.
    /// The name suggests these mirror a Python reference implementation's defaults.
    pub fn default_python() -> Result<Self> {
        Self::new(None, (1e-30, 1e-3), 1.0, -0.8, None, 0.0, true, true, false)
    }

    /// The configured learning-rate schedule, if any.
    #[inline]
    pub fn learning_rate_ref(&self) -> Option<&LearningRate> {
        self.learning_rate.as_ref()
    }

    /// `(eps.0, eps.1)`: squared-gradient epsilon and parameter-RMS floor.
    #[inline]
    pub fn eps(&self) -> (f32, f32) {
        self.eps
    }

    /// Threshold used to clip the RMS of each update.
    #[inline]
    pub fn clip_threshold(&self) -> f32 {
        self.clip_threshold
    }

    /// Exponent in `beta_2 = 1 - step^decay_rate`.
    #[inline]
    pub fn decay_rate(&self) -> f32 {
        self.decay_rate
    }

    /// First-moment coefficient; `None` when the first moment is disabled.
    #[inline]
    pub fn beta_1(&self) -> Option<f32> {
        self.beta_1
    }

    /// Weight-decay coefficient.
    #[inline]
    pub fn weight_decay(&self) -> f32 {
        self.weight_decay
    }

    /// Whether the step size is scaled by the parameter RMS.
    #[inline]
    pub fn scale_parameter(&self) -> bool {
        self.scale_parameter
    }

    /// Whether the step size is derived from the step count rather than the schedule.
    #[inline]
    pub fn relative_step(&self) -> bool {
        self.relative_step
    }

    /// Whether relative steps use the warm-up minimum (`1e-6 * step`).
    #[inline]
    pub fn warmup_init(&self) -> bool {
        self.warmup_init
    }

    /// Replaces the learning-rate schedule and resolves it at the current step.
    ///
    /// # Errors
    /// Propagates errors from resolving the schedule. The optimizer is consumed in
    /// that case.
    pub fn with_learning_rate(mut self, learning_rate: Option<LearningRate>) -> Result<Self> {
        let current_lr = match learning_rate.as_ref() {
            Some(lr) => lr.try_current(self.step_count)?,
            None => 0.0,
        };
        self.learning_rate = learning_rate;
        self.current_lr = current_lr;
        self.lr_resolved_for_step = Some(self.step_count);
        Ok(self)
    }

    /// Replaces `eps` after validation.
    ///
    /// # Errors
    /// Same as `validate_eps`.
    pub fn with_eps(mut self, eps: (f32, f32)) -> Result<Self> {
        validate_eps(eps)?;
        self.eps = eps;
        Ok(self)
    }

    /// Replaces `clip_threshold` after validation.
    ///
    /// # Errors
    /// Same as `validate_clip_threshold`.
    pub fn with_clip_threshold(mut self, clip_threshold: f32) -> Result<Self> {
        validate_clip_threshold(clip_threshold)?;
        self.clip_threshold = clip_threshold;
        Ok(self)
    }

    /// Replaces `decay_rate` after validation.
    ///
    /// # Errors
    /// Same as `validate_decay_rate`.
    pub fn with_decay_rate(mut self, decay_rate: f32) -> Result<Self> {
        validate_decay_rate(decay_rate)?;
        self.decay_rate = decay_rate;
        Ok(self)
    }

    /// Replaces `beta_1`. This is only allowed before any per-parameter state exists.
    ///
    /// # Errors
    /// `Error::InvariantViolation` if state is already initialized, otherwise the
    /// errors of `validate_beta_1`. The optimizer is consumed on error.
    pub fn with_beta_1(mut self, beta_1: Option<f32>) -> Result<Self> {
        if !self.state.is_empty() {
            return Err(Error::InvariantViolation(InvariantViolationPayload::new(
                "Adafactor::with_beta_1",
                "cannot toggle beta_1 after parameter state is initialized (would desynchronize existing \
vs new parameters' exp_avg shape); construct a fresh Adafactor or use try_set_beta_1 \
(which preserves state on error)",
            )));
        }
        validate_beta_1(beta_1)?;
        self.beta_1 = beta_1;
        Ok(self)
    }

    /// In-place variant of `with_beta_1`. The optimizer is left unchanged on error.
    ///
    /// # Errors
    /// `Error::InvariantViolation` if state is already initialized, otherwise the
    /// errors of `validate_beta_1`.
    pub fn try_set_beta_1(&mut self, beta_1: Option<f32>) -> Result<()> {
        if !self.state.is_empty() {
            return Err(Error::InvariantViolation(InvariantViolationPayload::new(
                "Adafactor::try_set_beta_1",
                "cannot toggle beta_1 after parameter state is initialized (would desynchronize existing \
vs new parameters' exp_avg shape); construct a fresh Adafactor instead",
            )));
        }
        validate_beta_1(beta_1)?;
        self.beta_1 = beta_1;
        Ok(())
    }

    /// Replaces `weight_decay` after validation.
    ///
    /// # Errors
    /// Same as `validate_weight_decay`.
    pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self> {
        validate_weight_decay(weight_decay)?;
        self.weight_decay = weight_decay;
        Ok(self)
    }

    /// Toggles scaling of the step size by the parameter RMS.
    #[must_use]
    pub fn with_scale_parameter(mut self, scale_parameter: bool) -> Self {
        self.scale_parameter = scale_parameter;
        self
    }

    /// Toggles the step-count-derived (relative) step size.
    #[must_use]
    pub fn with_relative_step(mut self, relative_step: bool) -> Self {
        self.relative_step = relative_step;
        self
    }

    /// Toggles the warm-up minimum for relative steps.
    #[must_use]
    pub fn with_warmup_init(mut self, warmup_init: bool) -> Self {
        self.warmup_init = warmup_init;
        self
    }

    /// Builds zeroed state for `param`.
    ///
    /// For rank >= 2 the state is factored:
    /// - `row` has the parameter shape without its last dimension;
    /// - `col` has the parameter shape without its second-to-last dimension.
    /// Both are cast to the parameter's dtype.
    ///
    /// For lower ranks a full-size `zeros_like` accumulator is used. A `zeros_like`
    /// first-moment buffer is added whenever `beta_1` is set.
    fn init_state_for(&self, param: &Array) -> Result<AdafactorState> {
        let shape = param.shape();
        let exp_avg = if self.beta_1.is_some() {
            Some(zeros_like(param)?)
        } else {
            None
        };
        if param.ndim() >= 2 {
            let row_shape: Vec<usize> = shape[..shape.len() - 1].to_vec();
            let mut col_shape: Vec<usize> = shape[..shape.len() - 2].to_vec();
            col_shape.push(shape[shape.len() - 1]);
            let dtype = param.dtype()?;
            let row = Array::full::<f32>(&row_shape.as_slice(), 0.0)?.astype(dtype)?;
            let col = Array::full::<f32>(&col_shape.as_slice(), 0.0)?.astype(dtype)?;
            Ok(AdafactorState::Factored(FactoredState::new(
                row, col, exp_avg,
            )))
        } else {
            Ok(AdafactorState::NonFactored(NonFactoredState::new(
                zeros_like(param)?,
                exp_avg,
            )))
        }
    }

    /// Root-mean-square of `a`: `sqrt(mean(a^2))`.
    /// Assumes `mean` with no axes reduces over every element.
    fn compute_rms(&self, a: &Array) -> Result<Array> {
        let sq = arithmetic::square(a)?;
        let m = mean(&sq, false)?;
        arithmetic::sqrt(&m)
    }

    /// Step size for the current `step_count`.
    ///
    /// With `relative_step`, the base is `min(min_step, 1/sqrt(step))`, where
    /// `min_step` is `1e-6 * step` under `warmup_init` and `1e-2` otherwise.
    /// Without it, the base is `current_lr`.
    /// With `scale_parameter`, the base is multiplied by `max(eps.1, parameter_rms)`.
    fn compute_learning_rate(&self, parameter_rms: &Array) -> Result<Array> {
        let step = self.step_count as f32;
        let relative_step = if self.relative_step {
            let min_step = if self.warmup_init { 1e-6 * step } else { 1e-2 };
            let rsqrt_step = step.sqrt().recip();
            min_step.min(rsqrt_step)
        } else {
            self.current_lr
        };
        let rel_s = scalar(relative_step)?;
        if self.scale_parameter {
            let eps2_s = scalar(self.eps.1)?;
            let param_scale = arithmetic::maximum(&eps2_s, parameter_rms)?;
            // Was `¶m_scale` (HTML-entity corruption of `&param_scale`), a compile error.
            arithmetic::multiply(&param_scale, &rel_s)
        } else {
            Ok(rel_s)
        }
    }
}
impl Optimizer for Adafactor {
fn init(&mut self, params: &Weights) -> Result<()> {
let mut out = HashMap::with_capacity(params.len());
for (key, value) in params {
out.insert(key.clone(), self.init_state_for(value)?);
}
self.state = out;
Ok(())
}
fn preflight(&mut self) -> Result<()> {
if self.lr_resolved_for_step == Some(self.step_count) {
return Ok(()); }
self.current_lr = match self.learning_rate.as_ref() {
Some(lr) => lr.try_current(self.step_count)?,
None => 0.0,
};
self.lr_resolved_for_step = Some(self.step_count);
Ok(())
}
fn apply_gradients(&mut self, gradients: &Weights, params: &mut Weights) -> Result<()> {
if self.state.is_empty() {
self.init(gradients)?;
}
self.preflight()?;
self.step_count += 1;
let step_f = self.step_count as f32;
let beta_2_val = 1.0 - step_f.powf(self.decay_rate);
let beta_2_s = scalar(beta_2_val)?;
let one_minus_beta_2 = scalar(1.0 - beta_2_val)?;
let eps0 = scalar(self.eps.0)?;
let one = scalar(1.0)?;
let clip = scalar(self.clip_threshold)?;
for (key, grad) in gradients {
let Some(param) = params.get(key) else {
continue;
};
let parameter_rms = self.compute_rms(param)?;
let learning_rate = self.compute_learning_rate(¶meter_rms)?;
let g_sq = arithmetic::square(grad)?;
let update = arithmetic::add(&g_sq, &eps0)?;
let st = self
.state
.remove(key)
.unwrap_or(self.init_state_for(param)?);
let (new_state, mut update_arr) = match st {
AdafactorState::Factored(fs) => {
let row = fs.row;
let col = fs.col;
let exp_avg = fs.exp_avg;
let ndim = grad.ndim();
let row_axis = (ndim - 1) as i32;
let col_axis = (ndim - 2) as i32;
let upd_row_mean = mean_axes(&update, &[row_axis], false)?;
let row_scaled = arithmetic::multiply(&beta_2_s, &row)?;
let row_term = arithmetic::multiply(&one_minus_beta_2, &upd_row_mean)?;
let row_new = arithmetic::add(&row_scaled, &row_term)?;
let upd_col_mean = mean_axes(&update, &[col_axis], false)?;
let col_scaled = arithmetic::multiply(&beta_2_s, &col)?;
let col_term = arithmetic::multiply(&one_minus_beta_2, &upd_col_mean)?;
let col_new = arithmetic::add(&col_scaled, &col_term)?;
let row_inner_axis = (row_new.ndim() as i32) - 1;
let row_mean = mean_axes(&row_new, &[row_inner_axis], true)?;
let row_norm = arithmetic::divide(&row_new, &row_mean)?;
let r_factor = arithmetic::rsqrt(&row_norm)?;
let c_factor = arithmetic::rsqrt(&col_new)?;
let r_expanded = expand_dims_axes(&r_factor, &[-1])?;
let c_expand_at = (ndim as i32) - 2;
let c_expanded = expand_dims_axes(&c_factor, &[c_expand_at])?;
let approx = crate::ops::linalg_basic::matmul(&r_expanded, &c_expanded)?;
let update_calc = arithmetic::multiply(&approx, grad)?;
(
AdafactorState::Factored(FactoredState::new(row_new, col_new, exp_avg)),
update_calc,
)
}
AdafactorState::NonFactored(nfs) => {
let exp_avg_sq = nfs.exp_avg_sq;
let exp_avg = nfs.exp_avg;
let old_scaled = arithmetic::multiply(&beta_2_s, &exp_avg_sq)?;
let upd_scaled = arithmetic::multiply(&one_minus_beta_2, &update)?;
let new_eas = arithmetic::add(&old_scaled, &upd_scaled)?;
let rs = arithmetic::rsqrt(&new_eas)?;
let update_calc = arithmetic::multiply(&rs, grad)?;
(
AdafactorState::NonFactored(NonFactoredState::new(new_eas, exp_avg)),
update_calc,
)
}
};
let upd_rms = self.compute_rms(&update_arr)?;
let rms_over_clip = arithmetic::divide(&upd_rms, &clip)?;
let denom = arithmetic::maximum(&one, &rms_over_clip)?;
update_arr = arithmetic::divide(&update_arr, &denom)?;
update_arr = arithmetic::multiply(&learning_rate, &update_arr)?;
let final_state = match new_state {
AdafactorState::Factored(fs) if self.beta_1.is_some() && fs.exp_avg.is_some() => {
let row = fs.row;
let col = fs.col;
let prev_ea = fs.exp_avg.unwrap();
let b1 = self.beta_1.unwrap();
let b1_s = scalar(b1)?;
let one_minus_b1 = scalar(1.0 - b1)?;
let prev_scaled = arithmetic::multiply(&b1_s, &prev_ea)?;
let upd_scaled = arithmetic::multiply(&one_minus_b1, &update_arr)?;
let new_ea = arithmetic::add(&prev_scaled, &upd_scaled)?;
update_arr = new_ea.try_clone()?;
AdafactorState::Factored(FactoredState::new(row, col, Some(new_ea)))
}
AdafactorState::NonFactored(nfs) if self.beta_1.is_some() && nfs.exp_avg.is_some() => {
let exp_avg_sq = nfs.exp_avg_sq;
let prev_ea = nfs.exp_avg.unwrap();
let b1 = self.beta_1.unwrap();
let b1_s = scalar(b1)?;
let one_minus_b1 = scalar(1.0 - b1)?;
let prev_scaled = arithmetic::multiply(&b1_s, &prev_ea)?;
let upd_scaled = arithmetic::multiply(&one_minus_b1, &update_arr)?;
let new_ea = arithmetic::add(&prev_scaled, &upd_scaled)?;
update_arr = new_ea.try_clone()?;
AdafactorState::NonFactored(NonFactoredState::new(exp_avg_sq, Some(new_ea)))
}
other => other,
};
let param_after_decay = if self.weight_decay != 0.0 {
let neg_wd_lr_s = arithmetic::multiply(&scalar(-self.weight_decay)?, &learning_rate)?;
let extra = arithmetic::multiply(param, &neg_wd_lr_s)?;
arithmetic::add(param, &extra)?
} else {
param.try_clone()?
};
let new_w = arithmetic::subtract(¶m_after_decay, &update_arr)?;
params.insert(key.clone(), new_w);
self.state.insert(key.clone(), final_state);
}
Ok(())
}
fn step(&self) -> usize {
self.step_count
}
fn learning_rate(&self) -> f32 {
self.current_lr
}
}
#[cfg(test)]
mod tests;