use crate::common::{OptimizerState, StateMemoryStats};
use crate::traits::StatefulOptimizer;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;
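/// Configuration for [`ScheduleFreeSGD`].
///
/// `learning_rate` is the base step size (scaled linearly over `warmup_steps`),
/// `momentum` controls the interpolation between the fast and averaged iterates,
/// and `weight_decay` is applied as an L2 term added to the gradient. The
/// polynomial-averaging exponent `r` is stored and serialized but not yet used
/// by `update`.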
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScheduleFreeSGDConfig {
pub learning_rate: f32,
pub momentum: f32,
pub weight_decay: f32,
pub warmup_steps: usize,
pub r: f32,
}
impl Default for ScheduleFreeSGDConfig {
fn default() -> Self {
Self {
            learning_rate: 1.0,
            momentum: 0.9,
weight_decay: 0.0,
warmup_steps: 0,
r: 1.0,
}
}
}
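/// Schedule-Free SGD optimizer.
///
/// Maintains two per-parameter buffers, keyed by the parameter's data pointer:
/// `momentum_weights` (the fast iterate updated by gradient steps) and
/// `average_weights` (a running average used for evaluation). Call
/// [`ScheduleFreeSGD::eval_mode`] before evaluation to load the averaged
/// weights and [`ScheduleFreeSGD::train_mode`] afterwards to restore the
/// training weights.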
#[derive(Debug)]
pub struct ScheduleFreeSGD {
config: ScheduleFreeSGDConfig,
state: OptimizerState,
momentum_weights: HashMap<String, Vec<f32>>,
average_weights: HashMap<String, Vec<f32>>,
}
impl ScheduleFreeSGD {
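    /// Creates a Schedule-Free SGD optimizer with the given hyperparameters,
    /// no warmup, and `r = 1.0`.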
pub fn new(learning_rate: f32, momentum: f32, weight_decay: f32) -> Self {
let config = ScheduleFreeSGDConfig {
learning_rate,
momentum,
weight_decay,
warmup_steps: 0,
r: 1.0,
};
Self {
config,
state: OptimizerState::new(),
momentum_weights: HashMap::new(),
average_weights: HashMap::new(),
}
}
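    /// Creates an optimizer from an explicit configuration.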
pub fn with_config(config: ScheduleFreeSGDConfig) -> Self {
Self {
config,
state: OptimizerState::new(),
momentum_weights: HashMap::new(),
average_weights: HashMap::new(),
}
}
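    /// Preset tuned for large models: higher learning rate, stronger momentum,
    /// weight decay of 0.1, and 1000 warmup steps.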
pub fn for_large_models() -> Self {
let config = ScheduleFreeSGDConfig {
learning_rate: 5.0,
momentum: 0.95,
weight_decay: 0.1,
warmup_steps: 1000,
r: 1.0,
};
Self::with_config(config)
}
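    /// Returns the current learning rate, scaled linearly during warmup.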
pub fn get_effective_lr(&self) -> f32 {
if self.config.warmup_steps == 0 || self.state.step >= self.config.warmup_steps {
self.config.learning_rate
} else {
self.config.learning_rate * (self.state.step as f32 / self.config.warmup_steps as f32)
}
}
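    /// Copies the averaged weights into `parameters` for evaluation.
    ///
    /// Only `Tensor::F32` parameters are supported; tensors that have not yet
    /// been seen by `update` are left unchanged.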
pub fn eval_mode(&mut self, parameters: &mut [Tensor]) -> Result<()> {
for param in parameters.iter_mut() {
match param {
Tensor::F32(param_data) => {
let param_id = format!("{:p}", param_data.as_ptr());
                    if let Some(average_weights) = self.average_weights.get(&param_id) {
for (p, &a) in param_data.iter_mut().zip(average_weights.iter()) {
*p = a;
}
}
},
_ => {
return Err(TrustformersError::tensor_op_error(
"Unsupported tensor type for Schedule-Free SGD eval mode",
"ScheduleFreeSGD::eval_mode",
))
},
}
}
Ok(())
}
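    /// Copies the fast (momentum) weights back into `parameters` to resume training.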
pub fn train_mode(&mut self, parameters: &mut [Tensor]) -> Result<()> {
for param in parameters.iter_mut() {
match param {
Tensor::F32(param_data) => {
let param_id = format!("{:p}", param_data.as_ptr());
                    if let Some(momentum_weights) = self.momentum_weights.get(&param_id) {
for (p, &m) in param_data.iter_mut().zip(momentum_weights.iter()) {
*p = m;
}
}
},
_ => {
return Err(TrustformersError::tensor_op_error(
"Unsupported tensor type for Schedule-Free SGD train mode",
"ScheduleFreeSGD::train_mode",
))
},
}
}
Ok(())
}
}
impl Optimizer for ScheduleFreeSGD {
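    /// Applies one Schedule-Free SGD step to a single parameter tensor.
    ///
    /// The fast iterate takes a plain gradient step (with optional L2 weight
    /// decay), the averaged iterate accumulates a running mean of the fast
    /// iterate, and the parameter is set to the momentum-weighted interpolation
    /// of the two so the next gradient is evaluated there.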
fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
match (parameter, grad) {
(Tensor::F32(param), Tensor::F32(grad_arr)) => {
let param_id = format!("{:p}", param.as_ptr());
let size = grad_arr.len();
let effective_lr = self.get_effective_lr();
let beta = self.config.momentum;
let weight_decay = self.config.weight_decay;
let momentum_weights = self
.momentum_weights
.entry(param_id.clone())
.or_insert_with(|| param.iter().cloned().collect());
let average_weights = self
.average_weights
.entry(param_id)
.or_insert_with(|| param.iter().cloned().collect());
if momentum_weights.len() != size || average_weights.len() != size {
return Err(TrustformersError::tensor_op_error(
"Schedule-Free SGD state buffer size mismatch",
"ScheduleFreeSGD::update",
));
}
                // Uniform averaging coefficient c_t = 1 / (t + 1), where `step`
                // counts previously completed optimizer steps.
                let c = 1.0 / (self.state.step as f32 + 1.0);
                for (((p, &g), m), a) in param
                    .iter_mut()
                    .zip(grad_arr.iter())
                    .zip(momentum_weights.iter_mut())
                    .zip(average_weights.iter_mut())
                {
                    let mut grad_with_wd = g;
                    if weight_decay > 0.0 {
                        // L2 regularization: fold the decay term into the gradient.
                        grad_with_wd += weight_decay * *p;
                    }
                    // Gradient descent step on the fast iterate.
                    *m -= effective_lr * grad_with_wd;
                    // Running average of the fast iterate, used as evaluation weights.
                    *a = (1.0 - c) * *a + c * *m;
                    // The next gradient is taken at the momentum-weighted interpolation.
                    *p = (1.0 - beta) * *m + beta * *a;
                }
Ok(())
},
_ => Err(TrustformersError::tensor_op_error(
"Unsupported tensor types for Schedule-Free SGD",
"ScheduleFreeSGD::update",
)),
}
}
    fn zero_grad(&mut self) {
        // Gradients are owned and zeroed by the caller; the optimizer keeps no
        // gradient buffers of its own.
    }
fn step(&mut self) {
self.state.step();
}
fn get_lr(&self) -> f32 {
self.get_effective_lr()
}
fn set_lr(&mut self, lr: f32) {
self.config.learning_rate = lr;
}
}
impl StatefulOptimizer for ScheduleFreeSGD {
type Config = ScheduleFreeSGDConfig;
type State = OptimizerState;
fn config(&self) -> &Self::Config {
&self.config
}
fn state(&self) -> &Self::State {
&self.state
}
fn state_mut(&mut self) -> &mut Self::State {
&mut self.state
}
fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
let mut state_dict = HashMap::new();
state_dict.insert(
"learning_rate".to_string(),
Tensor::new(vec![self.config.learning_rate])?,
);
state_dict.insert(
"momentum".to_string(),
Tensor::new(vec![self.config.momentum])?,
);
state_dict.insert(
"weight_decay".to_string(),
Tensor::new(vec![self.config.weight_decay])?,
);
state_dict.insert(
"warmup_steps".to_string(),
Tensor::new(vec![self.config.warmup_steps as f32])?,
);
state_dict.insert("r".to_string(), Tensor::new(vec![self.config.r])?);
state_dict.insert(
"step".to_string(),
Tensor::new(vec![self.state.step as f32])?,
);
for (param_id, momentum_weights) in &self.momentum_weights {
state_dict.insert(
format!("momentum_weights_{}", param_id),
Tensor::new(momentum_weights.clone())?,
);
}
for (param_id, average_weights) in &self.average_weights {
state_dict.insert(
format!("average_weights_{}", param_id),
Tensor::new(average_weights.clone())?,
);
}
Ok(state_dict)
}
fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
if let Some(lr_tensor) = state.get("learning_rate") {
if let Ok(lr_vec) = lr_tensor.data() {
if !lr_vec.is_empty() {
self.config.learning_rate = lr_vec[0];
}
}
}
if let Some(momentum_tensor) = state.get("momentum") {
if let Ok(momentum_vec) = momentum_tensor.data() {
if !momentum_vec.is_empty() {
self.config.momentum = momentum_vec[0];
}
}
}
if let Some(weight_decay_tensor) = state.get("weight_decay") {
if let Ok(weight_decay_vec) = weight_decay_tensor.data() {
if !weight_decay_vec.is_empty() {
self.config.weight_decay = weight_decay_vec[0];
}
}
}
if let Some(warmup_steps_tensor) = state.get("warmup_steps") {
if let Ok(warmup_steps_vec) = warmup_steps_tensor.data() {
if !warmup_steps_vec.is_empty() {
self.config.warmup_steps = warmup_steps_vec[0] as usize;
}
}
}
if let Some(r_tensor) = state.get("r") {
if let Ok(r_vec) = r_tensor.data() {
if !r_vec.is_empty() {
self.config.r = r_vec[0];
}
}
}
if let Some(step_tensor) = state.get("step") {
if let Ok(step_vec) = step_tensor.data() {
if !step_vec.is_empty() {
self.state.step = step_vec[0] as usize;
}
}
}
for (key, tensor) in state {
if key.starts_with("momentum_weights_") {
let param_id = key
.strip_prefix("momentum_weights_")
.expect("key must have momentum_weights_ prefix")
.to_string();
if let Ok(weights) = tensor.data() {
self.momentum_weights.insert(param_id, weights);
}
} else if key.starts_with("average_weights_") {
let param_id = key
.strip_prefix("average_weights_")
.expect("key must have average_weights_ prefix")
.to_string();
if let Ok(weights) = tensor.data() {
self.average_weights.insert(param_id, weights);
}
}
}
Ok(())
}
fn memory_usage(&self) -> StateMemoryStats {
let momentum_size: usize = self.momentum_weights.values().map(|v| v.len()).sum();
let average_size: usize = self.average_weights.values().map(|v| v.len()).sum();
StateMemoryStats {
momentum_elements: momentum_size,
            variance_elements: 0,
            // The averaged (evaluation) weights are reported in the remaining slot,
            // since plain SGD keeps no variance or third-moment state of its own.
            third_moment_elements: average_size,
            total_bytes: (momentum_size + average_size) * 4,
num_parameters: self.momentum_weights.len(),
}
}
fn reset_state(&mut self) {
self.state.clear();
self.momentum_weights.clear();
self.average_weights.clear();
}
fn num_parameters(&self) -> usize {
self.momentum_weights.values().map(|v| v.len()).sum()
}
}
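/// Configuration for [`ScheduleFreeAdam`].
///
/// `beta1` doubles as the interpolation coefficient between the fast and
/// averaged iterates; `beta2` and `epsilon` follow the usual Adam meaning, and
/// `weight_decay` is applied as an L2 term added to the gradient. The
/// polynomial-averaging exponent `r` is stored and serialized but not yet used
/// by `update`.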
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScheduleFreeAdamConfig {
pub learning_rate: f32,
pub beta1: f32,
pub beta2: f32,
pub epsilon: f32,
pub weight_decay: f32,
pub warmup_steps: usize,
pub r: f32,
}
impl Default for ScheduleFreeAdamConfig {
fn default() -> Self {
Self {
            learning_rate: 0.25,
            beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
weight_decay: 0.0,
warmup_steps: 0,
r: 1.0,
}
}
}
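/// Schedule-Free Adam optimizer.
///
/// In addition to the fast (`momentum_weights`) and averaged
/// (`average_weights`) iterates kept by [`ScheduleFreeSGD`], this optimizer
/// tracks Adam's first and second moment estimates (`exp_avg`, `exp_avg_sq`)
/// to precondition the gradient step.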
#[derive(Debug)]
pub struct ScheduleFreeAdam {
config: ScheduleFreeAdamConfig,
state: OptimizerState,
momentum_weights: HashMap<String, Vec<f32>>,
average_weights: HashMap<String, Vec<f32>>,
exp_avg: HashMap<String, Vec<f32>>,
exp_avg_sq: HashMap<String, Vec<f32>>,
}
impl ScheduleFreeAdam {
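    /// Creates a Schedule-Free Adam optimizer with the given hyperparameters,
    /// no warmup, and `r = 1.0`.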
pub fn new(
learning_rate: f32,
beta1: f32,
beta2: f32,
epsilon: f32,
weight_decay: f32,
) -> Self {
let config = ScheduleFreeAdamConfig {
learning_rate,
beta1,
beta2,
epsilon,
weight_decay,
warmup_steps: 0,
r: 1.0,
};
Self {
config,
state: OptimizerState::new(),
momentum_weights: HashMap::new(),
average_weights: HashMap::new(),
exp_avg: HashMap::new(),
exp_avg_sq: HashMap::new(),
}
}
pub fn with_config(config: ScheduleFreeAdamConfig) -> Self {
Self {
config,
state: OptimizerState::new(),
momentum_weights: HashMap::new(),
average_weights: HashMap::new(),
exp_avg: HashMap::new(),
exp_avg_sq: HashMap::new(),
}
}
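    /// Preset tuned for language models: lower `beta2`, weight decay of 0.1,
    /// and 2000 warmup steps.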
pub fn for_language_models() -> Self {
let config = ScheduleFreeAdamConfig {
learning_rate: 0.5,
beta1: 0.9,
beta2: 0.95,
epsilon: 1e-8,
weight_decay: 0.1,
warmup_steps: 2000,
r: 1.0,
};
Self::with_config(config)
}
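    /// Returns the current learning rate, scaled linearly during warmup.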
pub fn get_effective_lr(&self) -> f32 {
if self.config.warmup_steps == 0 || self.state.step >= self.config.warmup_steps {
self.config.learning_rate
} else {
self.config.learning_rate * (self.state.step as f32 / self.config.warmup_steps as f32)
}
}
}
impl Optimizer for ScheduleFreeAdam {
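    /// Applies one Schedule-Free Adam step to a single parameter tensor.
    ///
    /// The gradient (plus optional L2 weight decay) feeds bias-corrected first
    /// and second moment estimates; the resulting Adam update moves the fast
    /// iterate, the averaged iterate accumulates a running mean of it, and the
    /// parameter is set to the `beta1`-weighted interpolation of the two.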
fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
match (parameter, grad) {
(Tensor::F32(param), Tensor::F32(grad_arr)) => {
let param_id = format!("{:p}", param.as_ptr());
let size = grad_arr.len();
let effective_lr = self.get_effective_lr();
let beta1 = self.config.beta1;
let beta2 = self.config.beta2;
let eps = self.config.epsilon;
let weight_decay = self.config.weight_decay;
let momentum_weights = self
.momentum_weights
.entry(param_id.clone())
.or_insert_with(|| param.iter().cloned().collect());
let average_weights = self
.average_weights
.entry(param_id.clone())
.or_insert_with(|| param.iter().cloned().collect());
let exp_avg =
self.exp_avg.entry(param_id.clone()).or_insert_with(|| vec![0.0; size]);
let exp_avg_sq = self.exp_avg_sq.entry(param_id).or_insert_with(|| vec![0.0; size]);
if momentum_weights.len() != size
|| average_weights.len() != size
|| exp_avg.len() != size
|| exp_avg_sq.len() != size
{
return Err(TrustformersError::tensor_op_error(
"Schedule-Free Adam state buffer size mismatch",
"ScheduleFreeAdam::update",
));
}
                let step = (self.state.step + 1) as f32;
                let bias_correction1 = 1.0 - beta1.powf(step);
                let bias_correction2 = 1.0 - beta2.powf(step);
                // Uniform averaging coefficient for the evaluation weights.
                let c = 1.0 / step;
                for (((((p, &g), m), a), exp_avg_val), exp_avg_sq_val) in param
                    .iter_mut()
                    .zip(grad_arr.iter())
                    .zip(momentum_weights.iter_mut())
                    .zip(average_weights.iter_mut())
                    .zip(exp_avg.iter_mut())
                    .zip(exp_avg_sq.iter_mut())
                {
                    let mut grad_with_wd = g;
                    if weight_decay > 0.0 {
                        // L2 regularization: fold the decay term into the gradient.
                        grad_with_wd += weight_decay * *p;
                    }
                    // First and second moment estimates (standard Adam).
                    *exp_avg_val = beta1 * *exp_avg_val + (1.0 - beta1) * grad_with_wd;
                    *exp_avg_sq_val =
                        beta2 * *exp_avg_sq_val + (1.0 - beta2) * grad_with_wd * grad_with_wd;
                    let m_hat = *exp_avg_val / bias_correction1;
                    let v_hat = *exp_avg_sq_val / bias_correction2;
                    let adam_update = effective_lr * m_hat / (v_hat.sqrt() + eps);
                    // Preconditioned descent step on the fast iterate.
                    *m -= adam_update;
                    // Running average of the fast iterate, used as evaluation weights.
                    *a = (1.0 - c) * *a + c * *m;
                    // The next gradient is taken at the beta1-weighted interpolation.
                    *p = (1.0 - beta1) * *m + beta1 * *a;
                }
Ok(())
},
_ => Err(TrustformersError::tensor_op_error(
"Unsupported tensor types for Schedule-Free Adam",
"ScheduleFreeAdam::update",
)),
}
}
    fn zero_grad(&mut self) {
        // Gradients are owned and zeroed by the caller; the optimizer keeps no
        // gradient buffers of its own.
    }
fn step(&mut self) {
self.state.step();
}
fn get_lr(&self) -> f32 {
self.get_effective_lr()
}
fn set_lr(&mut self, lr: f32) {
self.config.learning_rate = lr;
}
}
impl StatefulOptimizer for ScheduleFreeAdam {
type Config = ScheduleFreeAdamConfig;
type State = OptimizerState;
fn config(&self) -> &Self::Config {
&self.config
}
fn state(&self) -> &Self::State {
&self.state
}
fn state_mut(&mut self) -> &mut Self::State {
&mut self.state
}
fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
let mut state_dict = HashMap::new();
state_dict.insert(
"learning_rate".to_string(),
Tensor::new(vec![self.config.learning_rate])?,
);
state_dict.insert("beta1".to_string(), Tensor::new(vec![self.config.beta1])?);
state_dict.insert("beta2".to_string(), Tensor::new(vec![self.config.beta2])?);
state_dict.insert(
"epsilon".to_string(),
Tensor::new(vec![self.config.epsilon])?,
);
state_dict.insert(
"weight_decay".to_string(),
Tensor::new(vec![self.config.weight_decay])?,
);
state_dict.insert(
"warmup_steps".to_string(),
Tensor::new(vec![self.config.warmup_steps as f32])?,
);
state_dict.insert("r".to_string(), Tensor::new(vec![self.config.r])?);
state_dict.insert(
"step".to_string(),
Tensor::new(vec![self.state.step as f32])?,
);
for (param_id, momentum_weights) in &self.momentum_weights {
state_dict.insert(
format!("momentum_weights_{}", param_id),
Tensor::new(momentum_weights.clone())?,
);
}
for (param_id, average_weights) in &self.average_weights {
state_dict.insert(
format!("average_weights_{}", param_id),
Tensor::new(average_weights.clone())?,
);
}
for (param_id, exp_avg) in &self.exp_avg {
state_dict.insert(
format!("exp_avg_{}", param_id),
Tensor::new(exp_avg.clone())?,
);
}
for (param_id, exp_avg_sq) in &self.exp_avg_sq {
state_dict.insert(
format!("exp_avg_sq_{}", param_id),
Tensor::new(exp_avg_sq.clone())?,
);
}
Ok(state_dict)
}
fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
if let Some(lr_tensor) = state.get("learning_rate") {
if let Ok(lr_vec) = lr_tensor.data() {
if !lr_vec.is_empty() {
self.config.learning_rate = lr_vec[0];
}
}
}
for (key, tensor) in state {
if key.starts_with("momentum_weights_") {
let param_id = key
.strip_prefix("momentum_weights_")
.expect("key must have momentum_weights_ prefix")
.to_string();
if let Ok(weights) = tensor.data() {
self.momentum_weights.insert(param_id, weights);
}
} else if key.starts_with("average_weights_") {
let param_id = key
.strip_prefix("average_weights_")
.expect("key must have average_weights_ prefix")
.to_string();
if let Ok(weights) = tensor.data() {
self.average_weights.insert(param_id, weights);
}
} else if key.starts_with("exp_avg_") && !key.starts_with("exp_avg_sq_") {
let param_id = key
.strip_prefix("exp_avg_")
.expect("key must have exp_avg_ prefix")
.to_string();
if let Ok(weights) = tensor.data() {
self.exp_avg.insert(param_id, weights);
}
} else if key.starts_with("exp_avg_sq_") {
let param_id = key
.strip_prefix("exp_avg_sq_")
.expect("key must have exp_avg_sq_ prefix")
.to_string();
if let Ok(weights) = tensor.data() {
self.exp_avg_sq.insert(param_id, weights);
}
}
}
Ok(())
}
fn memory_usage(&self) -> StateMemoryStats {
let momentum_size: usize = self.momentum_weights.values().map(|v| v.len()).sum();
let average_size: usize = self.average_weights.values().map(|v| v.len()).sum();
let exp_avg_size: usize = self.exp_avg.values().map(|v| v.len()).sum();
let exp_avg_sq_size: usize = self.exp_avg_sq.values().map(|v| v.len()).sum();
let total_params = momentum_size + average_size + exp_avg_size + exp_avg_sq_size;
let _total_buffers = self.momentum_weights.len()
+ self.average_weights.len()
+ self.exp_avg.len()
+ self.exp_avg_sq.len();
StateMemoryStats {
momentum_elements: momentum_size + exp_avg_size,
variance_elements: average_size + exp_avg_sq_size,
third_moment_elements: 0,
total_bytes: total_params * 4,
num_parameters: self.momentum_weights.len(),
}
}
fn reset_state(&mut self) {
self.state.clear();
self.momentum_weights.clear();
self.average_weights.clear();
self.exp_avg.clear();
self.exp_avg_sq.clear();
}
fn num_parameters(&self) -> usize {
self.momentum_weights.values().map(|v| v.len()).sum()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_schedule_free_sgd_creation() {
let optimizer = ScheduleFreeSGD::new(1.0, 0.9, 0.01);
assert_eq!(optimizer.get_lr(), 1.0);
assert_eq!(optimizer.config.momentum, 0.9);
assert_eq!(optimizer.config.weight_decay, 0.01);
}
#[test]
fn test_schedule_free_sgd_for_large_models() {
let optimizer = ScheduleFreeSGD::for_large_models();
assert_eq!(optimizer.config.learning_rate, 5.0);
assert_eq!(optimizer.config.momentum, 0.95);
assert_eq!(optimizer.config.weight_decay, 0.1);
assert_eq!(optimizer.config.warmup_steps, 1000);
}
#[test]
fn test_schedule_free_adam_creation() {
let optimizer = ScheduleFreeAdam::new(0.25, 0.9, 0.999, 1e-8, 0.01);
assert_eq!(optimizer.get_lr(), 0.25);
assert_eq!(optimizer.config.beta1, 0.9);
assert_eq!(optimizer.config.beta2, 0.999);
assert_eq!(optimizer.config.epsilon, 1e-8);
assert_eq!(optimizer.config.weight_decay, 0.01);
}
#[test]
fn test_schedule_free_adam_for_language_models() {
let optimizer = ScheduleFreeAdam::for_language_models();
assert_eq!(optimizer.config.learning_rate, 0.5);
assert_eq!(optimizer.config.beta1, 0.9);
assert_eq!(optimizer.config.beta2, 0.95);
assert_eq!(optimizer.config.weight_decay, 0.1);
assert_eq!(optimizer.config.warmup_steps, 2000);
}
#[test]
fn test_memory_usage() {
let optimizer = ScheduleFreeAdam::new(0.1, 0.9, 0.999, 1e-8, 0.0);
let memory_stats = optimizer.memory_usage();
        assert_eq!(memory_stats.num_parameters, 0);
        assert_eq!(memory_stats.total_bytes, 0);
}
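    // A minimal usage sketch of one update step. It assumes `Tensor::new(Vec<f32>)`
    // produces the `Tensor::F32` variant handled by `update`; if it does not, the
    // update returns an error and the state assertions are skipped.
    #[test]
    fn test_schedule_free_sgd_update_step() {
        let mut optimizer = ScheduleFreeSGD::new(0.1, 0.9, 0.0);
        let mut param = Tensor::new(vec![1.0f32, 2.0, 3.0]).unwrap();
        let grad = Tensor::new(vec![0.1f32, 0.1, 0.1]).unwrap();
        if optimizer.update(&mut param, &grad).is_ok() {
            optimizer.step();
            // One tensor registered, three elements of state per buffer.
            assert_eq!(optimizer.memory_usage().num_parameters, 1);
            assert_eq!(optimizer.num_parameters(), 3);
        }
    }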
}