use crate::common::{OptimizerState, StateMemoryStats};
use crate::traits::StatefulOptimizer;
use std::collections::HashMap;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;
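/// Hyperparameters for the [`Adam`] optimizer.
///
/// Defaults follow common practice: `lr = 1e-3`, `betas = (0.9, 0.999)`,
/// `eps = 1e-8`, and no weight decay.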
#[derive(Debug, Clone)]
pub struct AdamConfig {
pub lr: f32,
pub betas: (f32, f32),
pub eps: f32,
pub weight_decay: f32,
}
impl Default for AdamConfig {
fn default() -> Self {
Self {
lr: 1e-3,
betas: (0.9, 0.999),
eps: 1e-8,
weight_decay: 0.0,
}
}
}
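/// Hyperparameters for the [`AdamW`] optimizer.
///
/// Defaults differ from [`AdamConfig`]: a smaller learning rate (`1e-4`) and a
/// nonzero decoupled weight decay (`0.01`).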
#[derive(Debug, Clone)]
pub struct AdamWConfig {
pub lr: f32,
pub betas: (f32, f32),
pub eps: f32,
pub weight_decay: f32,
}
impl Default for AdamWConfig {
fn default() -> Self {
Self {
lr: 1e-4,
betas: (0.9, 0.999),
eps: 1e-8,
weight_decay: 0.01,
}
}
}
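/// Adam optimizer (Kingma & Ba, 2015) with coupled (L2-style) weight decay.
///
/// For each parameter element `p` with gradient `g` at step `t`:
///
/// ```text
/// g'    = g + weight_decay * p
/// m_t   = beta1 * m_{t-1} + (1 - beta1) * g'
/// v_t   = beta2 * v_{t-1} + (1 - beta2) * g'^2
/// m_hat = m_t / (1 - beta1^t),   v_hat = v_t / (1 - beta2^t)
/// p    <- p - lr * m_hat / (sqrt(v_hat) + eps)
/// ```
///
/// Per-parameter state (`exp_avg`, `exp_avg_sq`) is keyed by the address of the
/// parameter's data buffer, so parameter tensors are assumed to keep a stable
/// storage location across update calls.
///
/// A minimal usage sketch (not compiled as a doctest), assuming `Tensor::new`
/// produces the `F32` tensor variant that `update` expects:
///
/// ```ignore
/// let mut param = Tensor::new(vec![1.0_f32, 2.0, 3.0])?;
/// let grad = Tensor::new(vec![0.1_f32, 0.1, 0.1])?;
/// let mut opt = Adam::from_config(AdamConfig::default());
/// opt.update(&mut param, &grad)?; // one Adam step for this parameter
/// opt.step();                     // advance the shared step counter once per iteration
/// ```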
#[derive(Debug, Clone)]
pub struct Adam {
config: AdamConfig,
state: OptimizerState,
exp_avg: HashMap<String, Vec<f32>>,
exp_avg_sq: HashMap<String, Vec<f32>>,
}
impl Adam {
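/// Creates an Adam optimizer from explicit hyperparameters.
///
/// `betas` holds the exponential decay rates `(beta1, beta2)` for the first and
/// second moment estimates; `weight_decay` is folded into the gradient
/// (coupled L2 regularization), matching the original Adam formulation.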
pub fn new(lr: f32, betas: (f32, f32), eps: f32, weight_decay: f32) -> Self {
Self {
config: AdamConfig {
lr,
betas,
eps,
weight_decay,
},
state: OptimizerState::new(),
exp_avg: HashMap::new(),
exp_avg_sq: HashMap::new(),
}
}
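/// Creates an Adam optimizer from an [`AdamConfig`].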
pub fn from_config(config: AdamConfig) -> Self {
Self {
config,
state: OptimizerState::new(),
exp_avg: HashMap::new(),
exp_avg_sq: HashMap::new(),
}
}
}
impl Optimizer for Adam {
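/// Applies one Adam update to `parameter` in place using `grad`.
///
/// Only `Tensor::F32` parameters and gradients are supported. Bias corrections
/// use `state.step + 1` as the current step number, so callers are expected to
/// invoke `step()` once per optimization iteration after all parameters have
/// been updated.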
fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
match (parameter, grad) {
(Tensor::F32(param), Tensor::F32(grad_arr)) => {
let param_id = format!("{:p}", param.as_ptr());
let size = grad_arr.len();
let exp_avg =
self.exp_avg.entry(param_id.clone()).or_insert_with(|| vec![0.0; size]);
let exp_avg_sq = self.exp_avg_sq.entry(param_id).or_insert_with(|| vec![0.0; size]);
if exp_avg.len() != size || exp_avg_sq.len() != size {
return Err(TrustformersError::tensor_op_error(
"Adam state buffer size mismatch",
"Adam::update",
));
}
let step = (self.state.step + 1) as f32;
let bias_correction1 = 1.0 - self.config.betas.0.powf(step);
let bias_correction2 = 1.0 - self.config.betas.1.powf(step);
for ((p, g), (m, v)) in param
.iter_mut()
.zip(grad_arr.iter())
.zip(exp_avg.iter_mut().zip(exp_avg_sq.iter_mut()))
{
let grad_with_wd = if self.config.weight_decay != 0.0 {
g + self.config.weight_decay * *p
} else {
*g
};
*m = self.config.betas.0 * *m + (1.0 - self.config.betas.0) * grad_with_wd;
*v = self.config.betas.1 * *v
+ (1.0 - self.config.betas.1) * grad_with_wd * grad_with_wd;
let m_hat = *m / bias_correction1;
let v_hat = *v / bias_correction2;
*p -= self.config.lr * m_hat / (v_hat.sqrt() + self.config.eps);
}
Ok(())
},
_ => Err(TrustformersError::tensor_op_error(
"Unsupported tensor types for Adam",
"Adam::update",
)),
}
}
// Gradient buffers are managed by the caller; there is nothing for the optimizer to clear.
fn zero_grad(&mut self) {}
fn step(&mut self) {
self.state.step += 1;
}
fn get_lr(&self) -> f32 {
self.config.lr
}
fn set_lr(&mut self, lr: f32) {
self.config.lr = lr;
}
}
impl StatefulOptimizer for Adam {
type Config = AdamConfig;
type State = OptimizerState;
fn config(&self) -> &Self::Config {
&self.config
}
fn state(&self) -> &Self::State {
&self.state
}
fn state_mut(&mut self) -> &mut Self::State {
&mut self.state
}
fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
let mut state_dict = HashMap::new();
state_dict.insert("lr".to_string(), Tensor::new(vec![self.config.lr])?);
state_dict.insert("beta1".to_string(), Tensor::new(vec![self.config.betas.0])?);
state_dict.insert("beta2".to_string(), Tensor::new(vec![self.config.betas.1])?);
state_dict.insert("eps".to_string(), Tensor::new(vec![self.config.eps])?);
state_dict.insert(
"weight_decay".to_string(),
Tensor::new(vec![self.config.weight_decay])?,
);
state_dict.insert(
"step".to_string(),
Tensor::new(vec![self.state.step as f32])?,
);
for (param_id, exp_avg) in &self.exp_avg {
state_dict.insert(
format!("exp_avg_{}", param_id),
Tensor::new(exp_avg.clone())?,
);
}
for (param_id, exp_avg_sq) in &self.exp_avg_sq {
state_dict.insert(
format!("exp_avg_sq_{}", param_id),
Tensor::new(exp_avg_sq.clone())?,
);
}
Ok(state_dict)
}
fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
if let Some(lr_tensor) = state.get("lr") {
if let Ok(lr_vec) = lr_tensor.data() {
if !lr_vec.is_empty() {
self.config.lr = lr_vec[0];
}
}
}
if let Some(beta1_tensor) = state.get("beta1") {
if let Ok(beta1_vec) = beta1_tensor.data() {
if !beta1_vec.is_empty() {
self.config.betas.0 = beta1_vec[0];
}
}
}
if let Some(beta2_tensor) = state.get("beta2") {
if let Ok(beta2_vec) = beta2_tensor.data() {
if !beta2_vec.is_empty() {
self.config.betas.1 = beta2_vec[0];
}
}
}
if let Some(eps_tensor) = state.get("eps") {
if let Ok(eps_vec) = eps_tensor.data() {
if !eps_vec.is_empty() {
self.config.eps = eps_vec[0];
}
}
}
if let Some(weight_decay_tensor) = state.get("weight_decay") {
if let Ok(weight_decay_vec) = weight_decay_tensor.data() {
if !weight_decay_vec.is_empty() {
self.config.weight_decay = weight_decay_vec[0];
}
}
}
if let Some(step_tensor) = state.get("step") {
if let Ok(step_vec) = step_tensor.data() {
if !step_vec.is_empty() {
self.state.step = step_vec[0] as usize;
}
}
}
for (key, tensor) in state.iter() {
// Restore per-parameter moment buffers; match the longer prefix first so
// "exp_avg_sq_*" keys are not misread as "exp_avg_*" keys.
if let Some(param_id) = key.strip_prefix("exp_avg_sq_") {
if let Ok(exp_avg_sq) = tensor.data() {
self.exp_avg_sq.insert(param_id.to_string(), exp_avg_sq.clone());
}
} else if let Some(param_id) = key.strip_prefix("exp_avg_") {
if let Ok(exp_avg) = tensor.data() {
self.exp_avg.insert(param_id.to_string(), exp_avg.clone());
}
}
}
Ok(())
}
fn memory_usage(&self) -> StateMemoryStats {
let mut momentum_elements = 0;
let mut variance_elements = 0;
for exp_avg in self.exp_avg.values() {
momentum_elements += exp_avg.len();
}
for exp_avg_sq in self.exp_avg_sq.values() {
variance_elements += exp_avg_sq.len();
}
let total_elements = momentum_elements + variance_elements;
let total_bytes = total_elements * std::mem::size_of::<f32>();
StateMemoryStats {
momentum_elements,
variance_elements,
third_moment_elements: 0,
total_bytes,
num_parameters: momentum_elements,
}
}
fn reset_state(&mut self) {
self.state.step = 0;
self.exp_avg.clear();
self.exp_avg_sq.clear();
}
fn num_parameters(&self) -> usize {
self.exp_avg.values().map(|v| v.len()).sum()
}
}
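/// AdamW optimizer (Loshchilov & Hutter, 2019) with decoupled weight decay.
///
/// Identical to [`Adam`] except that weight decay is applied directly to the
/// parameter (`p <- p - lr * weight_decay * p`) before the adaptive step,
/// instead of being folded into the gradient, which keeps the second-moment
/// estimate free of the regularization term.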
#[derive(Debug)]
pub struct AdamW {
config: AdamWConfig,
state: OptimizerState,
exp_avg: HashMap<String, Vec<f32>>,
exp_avg_sq: HashMap<String, Vec<f32>>,
}
impl AdamW {
pub fn new(lr: f32, betas: (f32, f32), eps: f32, weight_decay: f32) -> Self {
Self {
config: AdamWConfig {
lr,
betas,
eps,
weight_decay,
},
state: OptimizerState::new(),
exp_avg: HashMap::new(),
exp_avg_sq: HashMap::new(),
}
}
pub fn from_config(config: AdamWConfig) -> Self {
Self {
config,
state: OptimizerState::new(),
exp_avg: HashMap::new(),
exp_avg_sq: HashMap::new(),
}
}
}
impl Optimizer for AdamW {
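/// Applies one AdamW update: decoupled weight decay first, then the standard
/// bias-corrected Adam step. As in [`Adam`], per-parameter state is keyed by
/// the parameter buffer's address and `state.step + 1` is used as the current
/// step number.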
fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
match (parameter, grad) {
(Tensor::F32(param), Tensor::F32(grad_arr)) => {
let param_id = format!("{:p}", param.as_ptr());
let size = grad_arr.len();
let exp_avg =
self.exp_avg.entry(param_id.clone()).or_insert_with(|| vec![0.0; size]);
let exp_avg_sq = self.exp_avg_sq.entry(param_id).or_insert_with(|| vec![0.0; size]);
if exp_avg.len() != size || exp_avg_sq.len() != size {
return Err(TrustformersError::tensor_op_error(
"AdamW state buffer size mismatch",
"AdamW::update",
));
}
let step = (self.state.step + 1) as f32;
let bias_correction1 = 1.0 - self.config.betas.0.powf(step);
let bias_correction2 = 1.0 - self.config.betas.1.powf(step);
for ((p, g), (m, v)) in param
.iter_mut()
.zip(grad_arr.iter())
.zip(exp_avg.iter_mut().zip(exp_avg_sq.iter_mut()))
{
*m = self.config.betas.0 * *m + (1.0 - self.config.betas.0) * g;
*v = self.config.betas.1 * *v + (1.0 - self.config.betas.1) * g * g;
let m_hat = *m / bias_correction1;
let v_hat = *v / bias_correction2;
if self.config.weight_decay != 0.0 {
*p -= self.config.lr * self.config.weight_decay * *p;
}
*p -= self.config.lr * m_hat / (v_hat.sqrt() + self.config.eps);
}
Ok(())
},
_ => Err(TrustformersError::tensor_op_error(
"Unsupported tensor types for AdamW",
"AdamW::update",
)),
}
}
fn zero_grad(&mut self) {}
fn step(&mut self) {
self.state.step += 1;
}
fn get_lr(&self) -> f32 {
self.config.lr
}
fn set_lr(&mut self, lr: f32) {
self.config.lr = lr;
}
}
impl StatefulOptimizer for AdamW {
type Config = AdamWConfig;
type State = OptimizerState;
fn config(&self) -> &Self::Config {
&self.config
}
fn state(&self) -> &Self::State {
&self.state
}
fn state_mut(&mut self) -> &mut Self::State {
&mut self.state
}
fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
let mut state_dict = HashMap::new();
state_dict.insert("lr".to_string(), Tensor::new(vec![self.config.lr])?);
state_dict.insert("beta1".to_string(), Tensor::new(vec![self.config.betas.0])?);
state_dict.insert("beta2".to_string(), Tensor::new(vec![self.config.betas.1])?);
state_dict.insert("eps".to_string(), Tensor::new(vec![self.config.eps])?);
state_dict.insert(
"weight_decay".to_string(),
Tensor::new(vec![self.config.weight_decay])?,
);
state_dict.insert(
"step".to_string(),
Tensor::new(vec![self.state.step as f32])?,
);
for (param_id, exp_avg) in &self.exp_avg {
state_dict.insert(
format!("exp_avg_{}", param_id),
Tensor::new(exp_avg.clone())?,
);
}
for (param_id, exp_avg_sq) in &self.exp_avg_sq {
state_dict.insert(
format!("exp_avg_sq_{}", param_id),
Tensor::new(exp_avg_sq.clone())?,
);
}
Ok(state_dict)
}
fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
if let Some(lr_tensor) = state.get("lr") {
if let Ok(lr_vec) = lr_tensor.data() {
if !lr_vec.is_empty() {
self.config.lr = lr_vec[0];
}
}
}
if let Some(beta1_tensor) = state.get("beta1") {
if let Ok(beta1_vec) = beta1_tensor.data() {
if !beta1_vec.is_empty() {
self.config.betas.0 = beta1_vec[0];
}
}
}
if let Some(beta2_tensor) = state.get("beta2") {
if let Ok(beta2_vec) = beta2_tensor.data() {
if !beta2_vec.is_empty() {
self.config.betas.1 = beta2_vec[0];
}
}
}
if let Some(eps_tensor) = state.get("eps") {
if let Ok(eps_vec) = eps_tensor.data() {
if !eps_vec.is_empty() {
self.config.eps = eps_vec[0];
}
}
}
if let Some(weight_decay_tensor) = state.get("weight_decay") {
if let Ok(weight_decay_vec) = weight_decay_tensor.data() {
if !weight_decay_vec.is_empty() {
self.config.weight_decay = weight_decay_vec[0];
}
}
}
if let Some(step_tensor) = state.get("step") {
if let Ok(step_vec) = step_tensor.data() {
if !step_vec.is_empty() {
self.state.step = step_vec[0] as usize;
}
}
}
for (key, tensor) in state.iter() {
// Restore per-parameter moment buffers; match the longer prefix first so
// "exp_avg_sq_*" keys are not misread as "exp_avg_*" keys.
if let Some(param_id) = key.strip_prefix("exp_avg_sq_") {
if let Ok(exp_avg_sq) = tensor.data() {
self.exp_avg_sq.insert(param_id.to_string(), exp_avg_sq.clone());
}
} else if let Some(param_id) = key.strip_prefix("exp_avg_") {
if let Ok(exp_avg) = tensor.data() {
self.exp_avg.insert(param_id.to_string(), exp_avg.clone());
}
}
}
Ok(())
}
fn memory_usage(&self) -> StateMemoryStats {
let mut momentum_elements = 0;
let mut variance_elements = 0;
for exp_avg in self.exp_avg.values() {
momentum_elements += exp_avg.len();
}
for exp_avg_sq in self.exp_avg_sq.values() {
variance_elements += exp_avg_sq.len();
}
let total_elements = momentum_elements + variance_elements;
let total_bytes = total_elements * std::mem::size_of::<f32>();
StateMemoryStats {
momentum_elements,
variance_elements,
third_moment_elements: 0,
total_bytes,
num_parameters: momentum_elements,
}
}
fn reset_state(&mut self) {
self.state.step = 0;
self.exp_avg.clear();
self.exp_avg_sq.clear();
}
fn num_parameters(&self) -> usize {
self.exp_avg.values().map(|v| v.len()).sum()
}
}
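/// Rectified Adam (RAdam, Liu et al., 2020).
///
/// RAdam rectifies the variance of the adaptive learning rate during the early
/// steps, when the second-moment estimate is unreliable. While the
/// rectification term is undefined (`rho_t <= 4`), the update falls back to
/// plain bias-corrected momentum; afterwards it behaves like Adam scaled by
/// the rectification factor `r_t`.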
#[derive(Debug)]
pub struct RAdam {
lr: f32,
betas: (f32, f32),
eps: f32,
weight_decay: f32,
state: OptimizerState,
exp_avg: HashMap<String, Vec<f32>>,
exp_avg_sq: HashMap<String, Vec<f32>>,
}
impl RAdam {
pub fn new(lr: f32, betas: (f32, f32), eps: f32, weight_decay: f32) -> Self {
Self {
lr,
betas,
eps,
weight_decay,
state: OptimizerState::new(),
exp_avg: HashMap::new(),
exp_avg_sq: HashMap::new(),
}
}
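/// Computes the RAdam variance-rectification term for the given step:
///
/// ```text
/// rho_inf = 2 / (1 - beta2) - 1
/// rho_t   = rho_inf - 2 * t * beta2^t / (1 - beta2^t)
/// r_t     = sqrt(((rho_t - 4)(rho_t - 2) * rho_inf) / ((rho_inf - 4)(rho_inf - 2) * rho_t))
/// ```
///
/// Returns `(r_t, true)` when `rho_t > 4` and the variance-rectified update
/// should be used, and `(1.0, false)` otherwise, in which case the caller
/// takes the momentum-only step.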
fn compute_variance_rectification(&self, step: f32) -> (f32, bool) {
let beta2 = self.betas.1;
let rho_inf = 2.0 / (1.0 - beta2) - 1.0;
let rho_t = rho_inf - 2.0 * step * beta2.powf(step) / (1.0 - beta2.powf(step));
if rho_t > 4.0 {
let r_t = ((rho_t - 4.0) * (rho_t - 2.0) * rho_inf
/ ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t))
.sqrt();
(r_t, true)
} else {
(1.0, false)
}
}
}
impl Optimizer for RAdam {
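/// Applies one RAdam update to `parameter` in place, switching between the
/// variance-rectified Adam step and the momentum-only fallback based on the
/// rectification term for the current step.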
fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
match (parameter, grad) {
(Tensor::F32(param), Tensor::F32(grad_arr)) => {
let param_id = format!("{:p}", param.as_ptr());
let size = grad_arr.len();
let step = (self.state.step + 1) as f32;
let bias_correction1 = 1.0 - self.betas.0.powf(step);
let (r_t, use_adaptive) = self.compute_variance_rectification(step);
let exp_avg =
self.exp_avg.entry(param_id.clone()).or_insert_with(|| vec![0.0; size]);
let exp_avg_sq = self.exp_avg_sq.entry(param_id).or_insert_with(|| vec![0.0; size]);
if exp_avg.len() != size || exp_avg_sq.len() != size {
return Err(TrustformersError::tensor_op_error(
"RAdam state buffer size mismatch",
"RAdam::update",
));
}
for ((p, g), (m, v)) in param
.iter_mut()
.zip(grad_arr.iter())
.zip(exp_avg.iter_mut().zip(exp_avg_sq.iter_mut()))
{
let grad_with_wd =
if self.weight_decay != 0.0 { g + self.weight_decay * *p } else { *g };
*m = self.betas.0 * *m + (1.0 - self.betas.0) * grad_with_wd;
*v = self.betas.1 * *v + (1.0 - self.betas.1) * grad_with_wd * grad_with_wd;
let m_hat = *m / bias_correction1;
if use_adaptive {
let bias_correction2 = 1.0 - self.betas.1.powf(step);
let v_hat = *v / bias_correction2;
*p -= self.lr * r_t * m_hat / (v_hat.sqrt() + self.eps);
} else {
*p -= self.lr * m_hat;
}
}
Ok(())
},
_ => Err(TrustformersError::tensor_op_error(
"Unsupported tensor types for RAdam",
"RAdam::update",
)),
}
}
fn zero_grad(&mut self) {}
fn step(&mut self) {
self.state.step += 1;
}
fn get_lr(&self) -> f32 {
self.lr
}
fn set_lr(&mut self, lr: f32) {
self.lr = lr;
}
}
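/// NAdam: Adam with Nesterov momentum (Dozat, 2016).
///
/// The update direction mixes the bias-corrected momentum with the current
/// gradient in a Nesterov-style look-ahead,
/// `(beta1 * m_t + (1 - beta1) * g) / (1 - beta1^t)`, before the usual
/// division by `sqrt(v_hat) + eps`.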
#[derive(Debug)]
pub struct NAdam {
lr: f32,
betas: (f32, f32),
eps: f32,
weight_decay: f32,
state: OptimizerState,
exp_avg: HashMap<String, Vec<f32>>,
exp_avg_sq: HashMap<String, Vec<f32>>,
}
impl NAdam {
pub fn new(lr: f32, betas: (f32, f32), eps: f32, weight_decay: f32) -> Self {
Self {
lr,
betas,
eps,
weight_decay,
state: OptimizerState::new(),
exp_avg: HashMap::new(),
exp_avg_sq: HashMap::new(),
}
}
}
impl Optimizer for NAdam {
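/// Applies one NAdam update to `parameter` in place using `grad`, with coupled
/// L2 weight decay folded into the gradient as in [`Adam`].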
fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
match (parameter, grad) {
(Tensor::F32(param), Tensor::F32(grad_arr)) => {
let param_id = format!("{:p}", param.as_ptr());
let size = grad_arr.len();
let exp_avg =
self.exp_avg.entry(param_id.clone()).or_insert_with(|| vec![0.0; size]);
let exp_avg_sq = self.exp_avg_sq.entry(param_id).or_insert_with(|| vec![0.0; size]);
if exp_avg.len() != size || exp_avg_sq.len() != size {
return Err(TrustformersError::tensor_op_error(
"NAdam state buffer size mismatch",
"NAdam::update",
));
}
let step = (self.state.step + 1) as f32;
let bias_correction1 = 1.0 - self.betas.0.powf(step);
let bias_correction2 = 1.0 - self.betas.1.powf(step);
for ((p, g), (m, v)) in param
.iter_mut()
.zip(grad_arr.iter())
.zip(exp_avg.iter_mut().zip(exp_avg_sq.iter_mut()))
{
let grad_with_wd =
if self.weight_decay != 0.0 { g + self.weight_decay * *p } else { *g };
*m = self.betas.0 * *m + (1.0 - self.betas.0) * grad_with_wd;
*v = self.betas.1 * *v + (1.0 - self.betas.1) * grad_with_wd * grad_with_wd;
let m_hat = *m / bias_correction1;
let v_hat = *v / bias_correction2;
let nesterov_m = self.betas.0 * m_hat
+ (1.0 - self.betas.0) * grad_with_wd / bias_correction1;
*p -= self.lr * nesterov_m / (v_hat.sqrt() + self.eps);
}
Ok(())
},
_ => Err(TrustformersError::tensor_op_error(
"Unsupported tensor types for NAdam",
"NAdam::update",
)),
}
}
fn zero_grad(&mut self) {}
fn step(&mut self) {
self.state.step += 1;
}
fn get_lr(&self) -> f32 {
self.lr
}
fn set_lr(&mut self, lr: f32) {
self.lr = lr;
}
}
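/// AdaBelief optimizer (Zhuang et al., 2020).
///
/// AdaBelief replaces Adam's second moment of the gradient with the second
/// moment of the prediction error `g - m` (the "belief" in the gradient
/// direction): steps are larger when gradients agree with the momentum and
/// smaller when they deviate from it.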
#[derive(Debug)]
pub struct AdaBelief {
lr: f32,
betas: (f32, f32),
eps: f32,
weight_decay: f32,
state: OptimizerState,
exp_avg: HashMap<String, Vec<f32>>,
exp_avg_var: HashMap<String, Vec<f32>>,
}
impl AdaBelief {
pub fn new(lr: f32, betas: (f32, f32), eps: f32, weight_decay: f32) -> Self {
Self {
lr,
betas,
eps,
weight_decay,
state: OptimizerState::new(),
exp_avg: HashMap::new(),
exp_avg_var: HashMap::new(),
}
}
}
impl Optimizer for AdaBelief {
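/// Applies one AdaBelief update to `parameter` in place: the first moment is
/// updated as in Adam, while the denominator tracks the variance of `g - m`
/// rather than of `g` itself.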
fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
match (parameter, grad) {
(Tensor::F32(param), Tensor::F32(grad_arr)) => {
let param_id = format!("{:p}", param.as_ptr());
let size = grad_arr.len();
let exp_avg =
self.exp_avg.entry(param_id.clone()).or_insert_with(|| vec![0.0; size]);
let exp_avg_var =
self.exp_avg_var.entry(param_id).or_insert_with(|| vec![0.0; size]);
if exp_avg.len() != size || exp_avg_var.len() != size {
return Err(TrustformersError::tensor_op_error(
"AdaBelief state buffer size mismatch",
"AdaBelief::update",
));
}
let step = (self.state.step + 1) as f32;
let bias_correction1 = 1.0 - self.betas.0.powf(step);
let bias_correction2 = 1.0 - self.betas.1.powf(step);
for ((p, g), (m, s)) in param
.iter_mut()
.zip(grad_arr.iter())
.zip(exp_avg.iter_mut().zip(exp_avg_var.iter_mut()))
{
let grad_with_wd =
if self.weight_decay != 0.0 { g + self.weight_decay * *p } else { *g };
*m = self.betas.0 * *m + (1.0 - self.betas.0) * grad_with_wd;
let grad_residual = grad_with_wd - *m;
*s = self.betas.1 * *s + (1.0 - self.betas.1) * grad_residual * grad_residual;
let m_hat = *m / bias_correction1;
let s_hat = *s / bias_correction2;
*p -= self.lr * m_hat / (s_hat.sqrt() + self.eps);
}
Ok(())
},
_ => Err(TrustformersError::tensor_op_error(
"Unsupported tensor types for AdaBelief",
"AdaBelief::update",
)),
}
}
fn zero_grad(&mut self) {}
fn step(&mut self) {
self.state.step += 1;
}
fn get_lr(&self) -> f32 {
self.lr
}
fn set_lr(&mut self, lr: f32) {
self.lr = lr;
}
}
#[cfg(test)]
#[path = "adam_tests.rs"]
mod adam_tests;