use crate::optimizer::OptimizerState;
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::tensor::Tensor;
/// Configuration for the quasi-hyperbolic momentum (QHM) optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QHMConfig {
    /// Step size applied to the final update direction.
    pub learning_rate: f32,
    /// Momentum decay factor β for the EMA gradient buffer.
    pub momentum: f32,
    /// Quasi-hyperbolic mixing factor ν blending raw gradient and momentum.
    pub nu: f32,
    /// L2 penalty coefficient added to the gradient (0.0 disables it).
    pub weight_decay: f32,
}
impl Default for QHMConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
momentum: 0.9,
nu: 0.7,
weight_decay: 0.0,
}
}
}
/// Quasi-hyperbolic momentum optimizer state.
#[derive(Debug)]
pub struct QHM {
    // Hyperparameters (mutable via `set_config` / `set_lr`).
    config: QHMConfig,
    // One EMA gradient buffer per parameter, keyed by the parameter's
    // position in the slice passed to `step`.
    momentum_buffers: HashMap<usize, Tensor>,
    // Number of `step` calls performed so far.
    current_step: usize,
}
impl QHM {
pub fn new(config: QHMConfig) -> Self {
Self {
config,
momentum_buffers: HashMap::new(),
current_step: 0,
}
}
pub fn with_defaults(learning_rate: f32, momentum: f32, nu: f32) -> Self {
Self::new(QHMConfig {
learning_rate,
momentum,
nu,
weight_decay: 0.0,
})
}
pub fn get_config(&self) -> &QHMConfig {
&self.config
}
pub fn set_config(&mut self, config: QHMConfig) {
self.config = config;
}
}
impl OptimizerState for QHM {
    /// No optimizer-owned gradient storage to clear; gradients live on the
    /// parameter tensors themselves.
    fn zero_grad(&mut self) -> Result<()> {
        Ok(())
    }

    /// One QHM step over all parameters.
    ///
    /// Per parameter: v ← β·v + (1−β)·g, then θ ← θ − lr·(ν·g + (1−ν)·v).
    /// Parameters whose `grad()` call fails are skipped.
    ///
    /// NOTE(review): the QHM paper (Ma & Yarats) weights the momentum buffer
    /// by ν and the raw gradient by (1−ν); this implementation has the
    /// weights swapped — confirm the inversion is intentional.
    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
        self.current_step += 1;
        for (param_id, parameter) in parameters.iter_mut().enumerate() {
            let gradient = match parameter.grad() {
                Ok(grad) => grad,
                Err(_) => continue,
            };
            // L2 regularization folded into the gradient: g ← g + wd·θ.
            let effective_grad = if self.config.weight_decay > 0.0 {
                gradient.add(&parameter.mul_scalar(self.config.weight_decay)?)?
            } else {
                gradient
            };
            // EMA momentum buffer, seeded with the first observed gradient.
            let momentum_buffer = if let Some(buffer) = self.momentum_buffers.get(&param_id) {
                let updated = buffer
                    .mul_scalar(self.config.momentum)?
                    .add(&effective_grad.mul_scalar(1.0 - self.config.momentum)?)?;
                self.momentum_buffers.insert(param_id, updated.clone());
                updated
            } else {
                let initial_momentum = effective_grad.clone();
                self.momentum_buffers.insert(param_id, initial_momentum.clone());
                initial_momentum
            };
            // Quasi-hyperbolic blend of the raw gradient and the buffer.
            let update_direction = effective_grad
                .mul_scalar(self.config.nu)?
                .add(&momentum_buffer.mul_scalar(1.0 - self.config.nu)?)?;
            *parameter =
                parameter.sub(&update_direction.mul_scalar(self.config.learning_rate)?)?;
        }
        Ok(())
    }

    fn get_lr(&self) -> f32 {
        self.config.learning_rate
    }

    fn set_lr(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }

    /// Serializes hyperparameters, the step counter, and all momentum
    /// buffers (under `momentum_buffer_<param_id>` keys).
    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state = HashMap::new();
        state.insert(
            "learning_rate".to_string(),
            Tensor::scalar(self.config.learning_rate)?,
        );
        state.insert(
            "momentum".to_string(),
            Tensor::scalar(self.config.momentum)?,
        );
        state.insert("nu".to_string(), Tensor::scalar(self.config.nu)?);
        state.insert(
            "weight_decay".to_string(),
            Tensor::scalar(self.config.weight_decay)?,
        );
        state.insert(
            "current_step".to_string(),
            Tensor::scalar(self.current_step as f32)?,
        );
        for (&param_id, buffer) in &self.momentum_buffers {
            state.insert(format!("momentum_buffer_{}", param_id), buffer.clone());
        }
        Ok(state)
    }

    /// Restores state produced by `state_dict`. Missing scalar keys leave
    /// the corresponding field unchanged; momentum buffers are rebuilt from
    /// scratch so stale entries do not survive the load.
    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
        if let Some(lr) = state.get("learning_rate") {
            self.config.learning_rate = lr.to_scalar()?;
        }
        if let Some(momentum) = state.get("momentum") {
            self.config.momentum = momentum.to_scalar()?;
        }
        if let Some(nu) = state.get("nu") {
            self.config.nu = nu.to_scalar()?;
        }
        if let Some(wd) = state.get("weight_decay") {
            self.config.weight_decay = wd.to_scalar()?;
        }
        if let Some(step) = state.get("current_step") {
            self.current_step = step.to_scalar()? as usize;
        }
        self.momentum_buffers.clear();
        for (key, tensor) in state {
            if let Some(param_id_str) = key.strip_prefix("momentum_buffer_") {
                if let Ok(param_id) = param_id_str.parse::<usize>() {
                    self.momentum_buffers.insert(param_id, tensor);
                }
            }
        }
        Ok(())
    }
}
/// Configuration for the Aggregated Momentum (AggMo) optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggMoConfig {
    /// Step size applied to the averaged momentum.
    pub learning_rate: f32,
    /// One decay rate β per momentum buffer; must be non-empty.
    pub momentum_coefficients: Vec<f32>,
    /// L2 penalty coefficient added to the gradient (0.0 disables it).
    pub weight_decay: f32,
}
impl Default for AggMoConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
momentum_coefficients: vec![0.0, 0.9, 0.99],
weight_decay: 0.0,
}
}
}
/// Aggregated Momentum optimizer: maintains several momentum buffers with
/// different decay rates per parameter and averages their contributions.
#[derive(Debug)]
pub struct AggMo {
    // Hyperparameters; `momentum_coefficients` fixes the buffer count.
    config: AggMoConfig,
    // One Vec of buffers (one per coefficient) per parameter index.
    momentum_buffers: HashMap<usize, Vec<Tensor>>,
    // Number of `step` calls performed so far.
    current_step: usize,
}
impl AggMo {
pub fn new(config: AggMoConfig) -> Self {
assert!(
!config.momentum_coefficients.is_empty(),
"Must provide at least one momentum coefficient"
);
Self {
config,
momentum_buffers: HashMap::new(),
current_step: 0,
}
}
pub fn with_defaults(learning_rate: f32, momentum_coefficients: Vec<f32>) -> Self {
Self::new(AggMoConfig {
learning_rate,
momentum_coefficients,
weight_decay: 0.0,
})
}
pub fn get_config(&self) -> &AggMoConfig {
&self.config
}
pub fn num_momentum_buffers(&self) -> usize {
self.config.momentum_coefficients.len()
}
}
impl OptimizerState for AggMo {
    /// No optimizer-owned gradient storage to clear.
    fn zero_grad(&mut self) -> Result<()> {
        Ok(())
    }

    /// One AggMo step: each buffer i is updated as
    /// v_i ← β_i·v_i + (1−β_i)·g, then θ ← θ − lr·mean(v_i).
    /// Parameters whose `grad()` call fails are skipped.
    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
        self.current_step += 1;
        for (param_id, parameter) in parameters.iter_mut().enumerate() {
            let gradient = match parameter.grad() {
                Ok(grad) => grad,
                Err(_) => continue,
            };
            // L2 regularization folded into the gradient: g ← g + wd·θ.
            let effective_grad = if self.config.weight_decay > 0.0 {
                gradient.add(&parameter.mul_scalar(self.config.weight_decay)?)?
            } else {
                gradient
            };
            // Hoisted so the `entry` closure captures a plain usize instead
            // of borrowing `self.config` while `self.momentum_buffers` is
            // mutably borrowed (avoids relying on closure field-capture).
            let num_buffers = self.config.momentum_coefficients.len();
            let buffers = self.momentum_buffers.entry(param_id).or_insert_with(|| {
                (0..num_buffers)
                    .map(|_| {
                        Tensor::zeros(&effective_grad.shape())
                            .expect("zeros should always succeed for valid gradient shape")
                    })
                    .collect()
            });
            let mut aggregated_momentum = Tensor::zeros(&effective_grad.shape())?;
            for (i, &beta) in self.config.momentum_coefficients.iter().enumerate() {
                buffers[i] =
                    buffers[i].mul_scalar(beta)?.add(&effective_grad.mul_scalar(1.0 - beta)?)?;
                aggregated_momentum = aggregated_momentum.add(&buffers[i])?;
            }
            let averaged_momentum = aggregated_momentum.div_scalar(num_buffers as f32)?;
            *parameter =
                parameter.sub(&averaged_momentum.mul_scalar(self.config.learning_rate)?)?;
        }
        Ok(())
    }

    fn get_lr(&self) -> f32 {
        self.config.learning_rate
    }

    fn set_lr(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }

    /// Serializes hyperparameters, the coefficient list (count plus
    /// `momentum_coeff_<i>` entries), and every buffer under
    /// `momentum_buffer_<param_id>_<buffer_idx>` keys.
    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state = HashMap::new();
        state.insert(
            "learning_rate".to_string(),
            Tensor::scalar(self.config.learning_rate)?,
        );
        state.insert(
            "weight_decay".to_string(),
            Tensor::scalar(self.config.weight_decay)?,
        );
        state.insert(
            "current_step".to_string(),
            Tensor::scalar(self.current_step as f32)?,
        );
        state.insert(
            "num_momentum_coeffs".to_string(),
            Tensor::scalar(self.config.momentum_coefficients.len() as f32)?,
        );
        for (i, &coeff) in self.config.momentum_coefficients.iter().enumerate() {
            state.insert(format!("momentum_coeff_{}", i), Tensor::scalar(coeff)?);
        }
        for (&param_id, buffers) in &self.momentum_buffers {
            for (buffer_idx, buffer) in buffers.iter().enumerate() {
                state.insert(
                    format!("momentum_buffer_{}_{}", param_id, buffer_idx),
                    buffer.clone(),
                );
            }
        }
        Ok(state)
    }

    /// Restores state from `state_dict`. A parameter's buffers are only
    /// accepted when a complete set (one per coefficient) is present.
    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
        if let Some(lr) = state.get("learning_rate") {
            self.config.learning_rate = lr.to_scalar()?;
        }
        if let Some(wd) = state.get("weight_decay") {
            self.config.weight_decay = wd.to_scalar()?;
        }
        if let Some(step) = state.get("current_step") {
            self.current_step = step.to_scalar()? as usize;
        }
        if let Some(num_coeffs_tensor) = state.get("num_momentum_coeffs") {
            let num_coeffs = num_coeffs_tensor.to_scalar()? as usize;
            let mut coefficients = Vec::with_capacity(num_coeffs);
            for i in 0..num_coeffs {
                if let Some(coeff_tensor) = state.get(&format!("momentum_coeff_{}", i)) {
                    coefficients.push(coeff_tensor.to_scalar()?);
                }
            }
            self.config.momentum_coefficients = coefficients;
        }
        self.momentum_buffers.clear();
        // Group buffers by parameter id first, then validate completeness.
        let mut param_buffers: HashMap<usize, HashMap<usize, Tensor>> = HashMap::new();
        for (key, tensor) in state {
            if key.starts_with("momentum_buffer_") {
                // Key layout: momentum_buffer_<param_id>_<buffer_idx>, so a
                // '_' split yields the ids at positions 2 and 3.
                let parts: Vec<&str> = key.split('_').collect();
                if parts.len() >= 4 {
                    if let (Ok(param_id), Ok(buffer_idx)) =
                        (parts[2].parse::<usize>(), parts[3].parse::<usize>())
                    {
                        param_buffers.entry(param_id).or_default().insert(buffer_idx, tensor);
                    }
                }
            }
        }
        for (param_id, buffer_map) in param_buffers {
            let mut buffers = Vec::new();
            for i in 0..self.config.momentum_coefficients.len() {
                if let Some(buffer) = buffer_map.get(&i) {
                    buffers.push(buffer.clone());
                }
            }
            // Incomplete sets are discarded rather than loaded misaligned.
            if buffers.len() == self.config.momentum_coefficients.len() {
                self.momentum_buffers.insert(param_id, buffers);
            }
        }
        Ok(())
    }
}
/// Configuration for the `VarianceReduction` optimizer (SVRG / SAG).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VarianceReductionConfig {
    /// Step size applied to the variance-reduced gradient.
    pub learning_rate: f32,
    /// Which variance-reduction scheme to run.
    pub method: VarianceReductionMethod,
    /// Maximum number of past gradients kept per parameter.
    pub history_size: usize,
    /// SVRG only: number of steps between full-gradient snapshots.
    pub full_grad_frequency: usize,
    /// L2 penalty coefficient added to the gradient (0.0 disables it).
    pub weight_decay: f32,
}
impl Default for VarianceReductionConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
method: VarianceReductionMethod::SVRG,
history_size: 100,
full_grad_frequency: 10,
weight_decay: 0.0,
}
}
}
/// Variance-reduction scheme selector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VarianceReductionMethod {
    /// Stochastic Variance Reduced Gradient: periodically snapshots the
    /// current gradients and corrects each step with the snapshot.
    SVRG,
    /// Stochastic Average Gradient: steps along the running average of the
    /// stored gradient history.
    SAG,
}
/// Variance-reduced SGD optimizer implementing SVRG and SAG variants.
#[derive(Debug)]
pub struct VarianceReduction {
    config: VarianceReductionConfig,
    // Bounded FIFO of recent gradients per parameter index.
    gradient_history: HashMap<usize, Vec<Tensor>>,
    // Cache of the last computed history average per parameter.
    average_gradients: HashMap<usize, Tensor>,
    // SVRG snapshot gradients per parameter.
    full_gradients: HashMap<usize, Tensor>,
    // Number of `step` calls performed so far.
    current_step: usize,
    // Step at which the last SVRG snapshot was taken.
    last_full_grad_step: usize,
}
impl VarianceReduction {
    /// Creates a variance-reduction optimizer from an explicit configuration.
    pub fn new(config: VarianceReductionConfig) -> Self {
        Self {
            config,
            gradient_history: HashMap::new(),
            average_gradients: HashMap::new(),
            full_gradients: HashMap::new(),
            current_step: 0,
            last_full_grad_step: 0,
        }
    }

    /// Convenience constructor for the SVRG variant.
    pub fn svrg(learning_rate: f32, history_size: usize, full_grad_frequency: usize) -> Self {
        Self::new(VarianceReductionConfig {
            learning_rate,
            method: VarianceReductionMethod::SVRG,
            history_size,
            full_grad_frequency,
            weight_decay: 0.0,
        })
    }

    /// Convenience constructor for the SAG variant (snapshots unused; the
    /// frequency is set to 1 as a placeholder).
    pub fn sag(learning_rate: f32, history_size: usize) -> Self {
        Self::new(VarianceReductionConfig {
            learning_rate,
            method: VarianceReductionMethod::SAG,
            history_size,
            full_grad_frequency: 1,
            weight_decay: 0.0,
        })
    }

    /// Appends `gradient` to the parameter's bounded history, evicting the
    /// oldest entry when the window is full.
    fn update_gradient_history(&mut self, param_id: usize, gradient: &Tensor) -> Result<()> {
        let history = self.gradient_history.entry(param_id).or_default();
        history.push(gradient.clone());
        if history.len() > self.config.history_size {
            // O(n) shift per eviction; fine for small windows. A VecDeque
            // would make this O(1) if history_size ever grows large.
            history.remove(0);
        }
        Ok(())
    }

    /// Mean of the stored gradient history for `param_id`; also caches the
    /// result in `average_gradients`. Errors when no history exists.
    fn compute_average_gradient(&mut self, param_id: usize) -> Result<Tensor> {
        if let Some(history) = self.gradient_history.get(&param_id) {
            if history.is_empty() {
                return Err(anyhow!("No gradient history available"));
            }
            let mut sum = history[0].clone();
            for grad in history.iter().skip(1) {
                sum = sum.add(grad)?;
            }
            let average = sum.div_scalar(history.len() as f32)?;
            self.average_gradients.insert(param_id, average.clone());
            Ok(average)
        } else {
            Err(anyhow!("No gradient history for parameter {}", param_id))
        }
    }

    /// True when enough steps have elapsed since the last SVRG snapshot.
    fn should_compute_full_gradient(&self) -> bool {
        self.current_step - self.last_full_grad_step >= self.config.full_grad_frequency
    }
}
impl OptimizerState for VarianceReduction {
    /// No optimizer-owned gradient storage to clear.
    fn zero_grad(&mut self) -> Result<()> {
        Ok(())
    }

    /// One variance-reduced step.
    ///
    /// SVRG: every `full_grad_frequency` steps the current gradients are
    /// snapshotted; each update then uses g − avg(history) + snapshot, or
    /// plain g before the first snapshot exists.
    /// SAG: steps along the running average of the gradient history.
    ///
    /// NOTE(review): textbook SVRG subtracts the gradient re-evaluated at
    /// the snapshot point rather than the history average — confirm this
    /// variant is intended.
    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
        self.current_step += 1;
        let compute_full_grad = match self.config.method {
            VarianceReductionMethod::SVRG => self.should_compute_full_gradient(),
            VarianceReductionMethod::SAG => false,
        };
        if compute_full_grad {
            self.last_full_grad_step = self.current_step;
            for (param_id, parameter) in parameters.iter().enumerate() {
                let gradient = match parameter.grad() {
                    Ok(grad) => grad,
                    Err(_) => continue,
                };
                self.full_gradients.insert(param_id, gradient);
            }
        }
        for (param_id, parameter) in parameters.iter_mut().enumerate() {
            let current_gradient = match parameter.grad() {
                Ok(grad) => grad,
                Err(_) => continue,
            };
            // L2 regularization folded into the gradient: g ← g + wd·θ.
            let effective_grad = if self.config.weight_decay > 0.0 {
                current_gradient.add(&parameter.mul_scalar(self.config.weight_decay)?)?
            } else {
                current_gradient
            };
            // History is updated first, so the SAG average below always has
            // at least one entry for this parameter.
            self.update_gradient_history(param_id, &effective_grad)?;
            let variance_reduced_grad = match self.config.method {
                VarianceReductionMethod::SVRG => {
                    let full_grad_opt = self.full_gradients.get(&param_id).cloned();
                    if let Some(full_grad) = full_grad_opt {
                        let avg_grad = self.compute_average_gradient(param_id)?;
                        effective_grad.sub(&avg_grad)?.add(&full_grad)?
                    } else {
                        // No snapshot yet: fall back to plain SGD.
                        effective_grad
                    }
                },
                VarianceReductionMethod::SAG => self.compute_average_gradient(param_id)?,
            };
            *parameter =
                parameter.sub(&variance_reduced_grad.mul_scalar(self.config.learning_rate)?)?;
        }
        Ok(())
    }

    fn get_lr(&self) -> f32 {
        self.config.learning_rate
    }

    fn set_lr(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }

    /// Serializes only scalar state; gradient histories and snapshots are
    /// intentionally NOT persisted and will be rebuilt after loading.
    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state = HashMap::new();
        state.insert(
            "learning_rate".to_string(),
            Tensor::scalar(self.config.learning_rate)?,
        );
        state.insert(
            "current_step".to_string(),
            Tensor::scalar(self.current_step as f32)?,
        );
        state.insert(
            "last_full_grad_step".to_string(),
            Tensor::scalar(self.last_full_grad_step as f32)?,
        );
        Ok(state)
    }

    /// Restores the scalar state written by `state_dict`; missing keys keep
    /// their current values.
    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
        if let Some(lr) = state.get("learning_rate") {
            self.config.learning_rate = lr.to_scalar()?;
        }
        if let Some(step) = state.get("current_step") {
            self.current_step = step.to_scalar()? as usize;
        }
        if let Some(last_step) = state.get("last_full_grad_step") {
            self.last_full_grad_step = last_step.to_scalar()? as usize;
        }
        Ok(())
    }
}
/// Configuration for `NesterovAcceleratedGradient`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NesterovAcceleratedGradientConfig {
    /// Step size applied to the gradient inside the velocity update.
    pub learning_rate: f32,
    /// Velocity decay factor μ.
    pub momentum: f32,
    /// L2 penalty coefficient added to the gradient (0.0 disables it).
    pub weight_decay: f32,
    /// When true, velocity buffers are cleared whenever the loss reported
    /// via `set_current_loss` increases (momentum restart).
    pub restart_on_increase: bool,
}
impl Default for NesterovAcceleratedGradientConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
momentum: 0.9,
weight_decay: 0.0,
restart_on_increase: false,
}
}
}
/// Momentum optimizer with optional loss-triggered restarts.
#[derive(Debug)]
pub struct NesterovAcceleratedGradient {
    config: NesterovAcceleratedGradientConfig,
    // One velocity tensor per parameter index.
    velocity_buffers: HashMap<usize, Tensor>,
    // Number of `step` calls performed so far.
    current_step: usize,
    // Last loss reported via `set_current_loss`, if any.
    previous_loss: Option<f32>,
}
impl NesterovAcceleratedGradient {
pub fn new(config: NesterovAcceleratedGradientConfig) -> Self {
Self {
config,
velocity_buffers: HashMap::new(),
current_step: 0,
previous_loss: None,
}
}
pub fn with_defaults(learning_rate: f32, momentum: f32) -> Self {
Self::new(NesterovAcceleratedGradientConfig {
learning_rate,
momentum,
weight_decay: 0.0,
restart_on_increase: false,
})
}
pub fn get_config(&self) -> &NesterovAcceleratedGradientConfig {
&self.config
}
pub fn set_current_loss(&mut self, loss: f32) {
if self.config.restart_on_increase {
if let Some(prev_loss) = self.previous_loss {
if loss > prev_loss {
self.velocity_buffers.clear();
}
}
}
self.previous_loss = Some(loss);
}
}
impl OptimizerState for NesterovAcceleratedGradient {
    /// No optimizer-owned gradient storage to clear.
    fn zero_grad(&mut self) -> Result<()> {
        Ok(())
    }

    /// One momentum step: v ← μ·v + lr·g, θ ← θ − v.
    ///
    /// NOTE(review): since `grad()` is already evaluated at the current
    /// parameters, this is classical momentum rather than true Nesterov
    /// momentum. The original code computed a lookahead point and then
    /// discarded it; that dead computation has been removed.
    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
        self.current_step += 1;
        for (param_id, parameter) in parameters.iter_mut().enumerate() {
            let gradient = match parameter.grad() {
                Ok(grad) => grad,
                Err(_) => continue,
            };
            // L2 regularization folded into the gradient: g ← g + wd·θ.
            let effective_grad = if self.config.weight_decay > 0.0 {
                gradient.add(&parameter.mul_scalar(self.config.weight_decay)?)?
            } else {
                gradient
            };
            // Velocity starts at zero for parameters not seen before.
            let velocity = if let Some(v) = self.velocity_buffers.get(&param_id) {
                v.clone()
            } else {
                Tensor::zeros_like(parameter)?
            };
            let new_velocity = velocity
                .mul_scalar(self.config.momentum)?
                .add(&effective_grad.mul_scalar(self.config.learning_rate)?)?;
            self.velocity_buffers.insert(param_id, new_velocity.clone());
            *parameter = parameter.sub(&new_velocity)?;
        }
        Ok(())
    }

    fn get_lr(&self) -> f32 {
        self.config.learning_rate
    }

    fn set_lr(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }

    /// Serializes hyperparameters, the step counter, the previous loss (if
    /// any), and all velocity buffers under `velocity_<param_id>` keys.
    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state = HashMap::new();
        state.insert(
            "learning_rate".to_string(),
            Tensor::scalar(self.config.learning_rate)?,
        );
        state.insert(
            "momentum".to_string(),
            Tensor::scalar(self.config.momentum)?,
        );
        state.insert(
            "weight_decay".to_string(),
            Tensor::scalar(self.config.weight_decay)?,
        );
        state.insert(
            "current_step".to_string(),
            Tensor::scalar(self.current_step as f32)?,
        );
        if let Some(loss) = self.previous_loss {
            state.insert("previous_loss".to_string(), Tensor::scalar(loss)?);
        }
        for (&param_id, velocity) in &self.velocity_buffers {
            state.insert(format!("velocity_{}", param_id), velocity.clone());
        }
        Ok(state)
    }

    /// Restores state produced by `state_dict`; missing scalar keys keep
    /// their current values, and velocity buffers are rebuilt from scratch.
    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
        if let Some(lr) = state.get("learning_rate") {
            self.config.learning_rate = lr.to_scalar()?;
        }
        if let Some(momentum) = state.get("momentum") {
            self.config.momentum = momentum.to_scalar()?;
        }
        if let Some(wd) = state.get("weight_decay") {
            self.config.weight_decay = wd.to_scalar()?;
        }
        if let Some(step) = state.get("current_step") {
            self.current_step = step.to_scalar()? as usize;
        }
        if let Some(loss) = state.get("previous_loss") {
            self.previous_loss = Some(loss.to_scalar()?);
        }
        self.velocity_buffers.clear();
        for (key, tensor) in state {
            if let Some(param_id_str) = key.strip_prefix("velocity_") {
                if let Ok(param_id) = param_id_str.parse::<usize>() {
                    self.velocity_buffers.insert(param_id, tensor);
                }
            }
        }
        Ok(())
    }
}
/// Configuration for the `HeavyBall` (Polyak momentum) optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeavyBallConfig {
    /// Step size applied to the gradient inside the velocity update.
    pub learning_rate: f32,
    /// Velocity decay factor β.
    pub beta: f32,
    /// L2 penalty coefficient added to the gradient (0.0 disables it).
    pub weight_decay: f32,
    /// When true, β is scaled per step by the cosine similarity between
    /// successive gradients (see `compute_adaptive_momentum`).
    pub adaptive_momentum: bool,
}
impl Default for HeavyBallConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
beta: 0.9,
weight_decay: 0.0,
adaptive_momentum: false,
}
}
}
/// Heavy-ball (Polyak momentum) optimizer with optional adaptive β.
#[derive(Debug)]
pub struct HeavyBall {
    config: HeavyBallConfig,
    // One velocity tensor per parameter index.
    velocity_buffers: HashMap<usize, Tensor>,
    // Last gradient per parameter; only populated when adaptive_momentum
    // is enabled (used for the cosine-similarity β scaling).
    previous_gradients: HashMap<usize, Tensor>,
    // Number of `step` calls performed so far.
    current_step: usize,
}
impl HeavyBall {
    /// Builds the optimizer from an explicit configuration.
    pub fn new(config: HeavyBallConfig) -> Self {
        Self {
            config,
            velocity_buffers: HashMap::new(),
            previous_gradients: HashMap::new(),
            current_step: 0,
        }
    }

    /// Convenience constructor with weight decay and adaptive momentum off.
    pub fn with_defaults(learning_rate: f32, beta: f32) -> Self {
        Self::new(HeavyBallConfig {
            learning_rate,
            beta,
            weight_decay: 0.0,
            adaptive_momentum: false,
        })
    }

    /// Read-only access to the active configuration.
    pub fn get_config(&self) -> &HeavyBallConfig {
        &self.config
    }

    /// Scales the configured β by the cosine similarity between the current
    /// and previous gradient (clamped at 0), damping momentum when
    /// successive gradients oppose each other. Falls back to the plain β
    /// when there is no previous gradient or the norms are near zero.
    fn compute_adaptive_momentum(&self, param_id: usize, current_grad: &Tensor) -> Result<f32> {
        if let Some(prev_grad) = self.previous_gradients.get(&param_id) {
            let dot_product = current_grad.mul(prev_grad)?.sum(None, false)?;
            let norm_current = current_grad.norm_squared()?.sqrt()?;
            let norm_prev = prev_grad.norm_squared()?.sqrt()?;
            let dot_scalar = dot_product.to_scalar()?;
            let norm_current_scalar = norm_current.to_scalar()?;
            let norm_prev_scalar = norm_prev.to_scalar()?;
            let denominator = norm_current_scalar * norm_prev_scalar;
            if denominator > 1e-8 {
                let cosine_similarity = dot_scalar / denominator;
                // Negative similarity would flip the velocity sign; clamp it
                // to zero instead so momentum is merely switched off.
                Ok(self.config.beta * cosine_similarity.max(0.0))
            } else {
                Ok(self.config.beta)
            }
        } else {
            Ok(self.config.beta)
        }
    }
}
impl OptimizerState for HeavyBall {
    /// No optimizer-owned gradient storage to clear.
    fn zero_grad(&mut self) -> Result<()> {
        Ok(())
    }

    /// One heavy-ball step: v ← β·v − lr·g, θ ← θ + v.
    /// With `adaptive_momentum`, β is first rescaled by the cosine
    /// similarity between this and the previous gradient.
    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
        self.current_step += 1;
        for (param_id, parameter) in parameters.iter_mut().enumerate() {
            let gradient = match parameter.grad() {
                Ok(grad) => grad,
                Err(_) => continue,
            };
            // L2 regularization folded into the gradient: g ← g + wd·θ.
            let effective_grad = if self.config.weight_decay > 0.0 {
                gradient.add(&parameter.mul_scalar(self.config.weight_decay)?)?
            } else {
                gradient
            };
            let beta = if self.config.adaptive_momentum {
                self.compute_adaptive_momentum(param_id, &effective_grad)?
            } else {
                self.config.beta
            };
            // Velocity starts at zero for parameters not seen before.
            let velocity = if let Some(v) = self.velocity_buffers.get(&param_id) {
                v.clone()
            } else {
                Tensor::zeros_like(parameter)?
            };
            let new_velocity = velocity
                .mul_scalar(beta)?
                .sub(&effective_grad.mul_scalar(self.config.learning_rate)?)?;
            self.velocity_buffers.insert(param_id, new_velocity.clone());
            *parameter = parameter.add(&new_velocity)?;
            // The previous gradient is only needed for the adaptive-β path.
            if self.config.adaptive_momentum {
                self.previous_gradients.insert(param_id, effective_grad);
            }
        }
        Ok(())
    }

    fn get_lr(&self) -> f32 {
        self.config.learning_rate
    }

    fn set_lr(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }

    /// Serializes hyperparameters, the step counter, velocity buffers
    /// (`velocity_<id>`), and previous gradients (`prev_grad_<id>`).
    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state = HashMap::new();
        state.insert(
            "learning_rate".to_string(),
            Tensor::scalar(self.config.learning_rate)?,
        );
        state.insert("beta".to_string(), Tensor::scalar(self.config.beta)?);
        state.insert(
            "weight_decay".to_string(),
            Tensor::scalar(self.config.weight_decay)?,
        );
        state.insert(
            "current_step".to_string(),
            Tensor::scalar(self.current_step as f32)?,
        );
        for (&param_id, velocity) in &self.velocity_buffers {
            state.insert(format!("velocity_{}", param_id), velocity.clone());
        }
        for (&param_id, grad) in &self.previous_gradients {
            state.insert(format!("prev_grad_{}", param_id), grad.clone());
        }
        Ok(state)
    }

    /// Restores state produced by `state_dict`; missing scalar keys keep
    /// their current values, and both tensor maps are rebuilt from scratch.
    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
        if let Some(lr) = state.get("learning_rate") {
            self.config.learning_rate = lr.to_scalar()?;
        }
        if let Some(beta) = state.get("beta") {
            self.config.beta = beta.to_scalar()?;
        }
        if let Some(wd) = state.get("weight_decay") {
            self.config.weight_decay = wd.to_scalar()?;
        }
        if let Some(step) = state.get("current_step") {
            self.current_step = step.to_scalar()? as usize;
        }
        self.velocity_buffers.clear();
        self.previous_gradients.clear();
        for (key, tensor) in state {
            if let Some(param_id_str) = key.strip_prefix("velocity_") {
                if let Ok(param_id) = param_id_str.parse::<usize>() {
                    self.velocity_buffers.insert(param_id, tensor);
                }
            } else if let Some(param_id_str) = key.strip_prefix("prev_grad_") {
                if let Ok(param_id) = param_id_str.parse::<usize>() {
                    self.previous_gradients.insert(param_id, tensor);
                }
            }
        }
        Ok(())
    }
}
/// Configuration for the FISTA proximal-gradient optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FISTAConfig {
    /// Step size for the gradient descent half-step.
    pub learning_rate: f32,
    /// Soft-threshold magnitude for the L1 proximal operator.
    pub threshold: f32,
    /// NOTE(review): stored and serialized by name nowhere — this flag is
    /// not consulted anywhere in this file; confirm whether restart logic
    /// was meant to be implemented.
    pub adaptive_restart: bool,
    /// L2 penalty coefficient added to the gradient (0.0 disables it).
    pub weight_decay: f32,
}
impl Default for FISTAConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
threshold: 1e-4,
adaptive_restart: true,
weight_decay: 0.0,
}
}
}
/// FISTA-style accelerated proximal gradient optimizer with L1 shrinkage.
#[derive(Debug)]
pub struct FISTA {
    config: FISTAConfig,
    // Previous parameter value per parameter index (for extrapolation).
    previous_params: HashMap<usize, Tensor>,
    // Number of `step` calls performed so far.
    current_step: usize,
    // Current momentum coefficient t_k of the FISTA sequence.
    momentum_coefficient: f32,
    // Previous coefficient t_{k-1}, used to form β.
    previous_momentum: f32,
}
impl FISTA {
pub fn new(config: FISTAConfig) -> Self {
Self {
config,
previous_params: HashMap::new(),
current_step: 0,
momentum_coefficient: 1.0,
previous_momentum: 1.0,
}
}
pub fn with_defaults(learning_rate: f32, threshold: f32) -> Self {
Self::new(FISTAConfig {
learning_rate,
threshold,
adaptive_restart: true,
weight_decay: 0.0,
})
}
pub fn get_config(&self) -> &FISTAConfig {
&self.config
}
fn soft_threshold(&self, tensor: &Tensor, threshold: f32) -> Result<Tensor> {
let threshold_tensor = Tensor::scalar(threshold)?;
let zero_tensor = Tensor::zeros_like(tensor)?;
let abs_tensor = tensor.abs()?;
let thresholded = abs_tensor.sub(&threshold_tensor)?.max(&zero_tensor)?;
let sign_tensor = tensor.sign()?;
Ok(sign_tensor.mul(&thresholded)?)
}
fn update_momentum_coefficient(&mut self) {
let t = self.current_step as f32;
self.previous_momentum = self.momentum_coefficient;
self.momentum_coefficient = (1.0 + (1.0 + 4.0 * t * t).sqrt()) / 2.0;
}
}
impl OptimizerState for FISTA {
    /// No optimizer-owned gradient storage to clear.
    fn zero_grad(&mut self) -> Result<()> {
        Ok(())
    }

    /// One accelerated proximal step per parameter: extrapolate from the
    /// previous iterate, take a gradient step, then soft-threshold.
    ///
    /// NOTE(review): the extrapolation is x + β(x_prev − x); the standard
    /// FISTA momentum term is x + β(x − x_prev) — confirm the sign.
    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
        self.current_step += 1;
        self.update_momentum_coefficient();
        for (param_id, parameter) in parameters.iter_mut().enumerate() {
            let gradient = match parameter.grad() {
                Ok(grad) => grad,
                Err(_) => continue,
            };
            // L2 regularization folded into the gradient: g ← g + wd·θ.
            let effective_grad = if self.config.weight_decay > 0.0 {
                gradient.add(&parameter.mul_scalar(self.config.weight_decay)?)?
            } else {
                gradient
            };
            // First step for this parameter: no history, extrapolation is a
            // no-op because prev == current.
            let previous_param = if let Some(prev) = self.previous_params.get(&param_id) {
                prev.clone()
            } else {
                parameter.clone()
            };
            let beta = (self.previous_momentum - 1.0) / self.momentum_coefficient;
            let extrapolated = parameter.add(&previous_param.sub(parameter)?.mul_scalar(beta)?)?;
            let grad_step =
                extrapolated.sub(&effective_grad.mul_scalar(self.config.learning_rate)?)?;
            // Proximal (L1 shrinkage) step promotes sparsity.
            let new_parameter = self.soft_threshold(&grad_step, self.config.threshold)?;
            self.previous_params.insert(param_id, parameter.clone());
            *parameter = new_parameter;
        }
        Ok(())
    }

    fn get_lr(&self) -> f32 {
        self.config.learning_rate
    }

    fn set_lr(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }

    /// Serializes hyperparameters, the momentum sequence, the step counter,
    /// and previous parameters under `prev_param_<param_id>` keys.
    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state = HashMap::new();
        state.insert(
            "learning_rate".to_string(),
            Tensor::scalar(self.config.learning_rate)?,
        );
        state.insert(
            "threshold".to_string(),
            Tensor::scalar(self.config.threshold)?,
        );
        state.insert(
            "weight_decay".to_string(),
            Tensor::scalar(self.config.weight_decay)?,
        );
        state.insert(
            "current_step".to_string(),
            Tensor::scalar(self.current_step as f32)?,
        );
        state.insert(
            "momentum_coefficient".to_string(),
            Tensor::scalar(self.momentum_coefficient)?,
        );
        state.insert(
            "previous_momentum".to_string(),
            Tensor::scalar(self.previous_momentum)?,
        );
        for (&param_id, param) in &self.previous_params {
            state.insert(format!("prev_param_{}", param_id), param.clone());
        }
        Ok(state)
    }

    /// Restores state produced by `state_dict`; missing scalar keys keep
    /// their current values, and previous parameters are rebuilt from
    /// scratch.
    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
        if let Some(lr) = state.get("learning_rate") {
            self.config.learning_rate = lr.to_scalar()?;
        }
        if let Some(threshold) = state.get("threshold") {
            self.config.threshold = threshold.to_scalar()?;
        }
        if let Some(wd) = state.get("weight_decay") {
            self.config.weight_decay = wd.to_scalar()?;
        }
        if let Some(step) = state.get("current_step") {
            self.current_step = step.to_scalar()? as usize;
        }
        if let Some(momentum) = state.get("momentum_coefficient") {
            self.momentum_coefficient = momentum.to_scalar()?;
        }
        if let Some(prev_momentum) = state.get("previous_momentum") {
            self.previous_momentum = prev_momentum.to_scalar()?;
        }
        self.previous_params.clear();
        for (key, tensor) in state {
            if let Some(param_id_str) = key.strip_prefix("prev_param_") {
                if let Ok(param_id) = param_id_str.parse::<usize>() {
                    self.previous_params.insert(param_id, tensor);
                }
            }
        }
        Ok(())
    }
}
/// Configuration for the `AdaptiveBatchSizing` controller.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveBatchSizingConfig {
    /// Batch size used at startup and after `reset`.
    pub initial_batch_size: usize,
    /// Lower bound for any adjusted batch size.
    pub min_batch_size: usize,
    /// Upper bound for any adjusted batch size.
    pub max_batch_size: usize,
    // NOTE(review): not referenced anywhere in the visible implementation —
    // confirm it is still meant to exist.
    pub gradient_variance_tolerance: f32,
    /// Damping factor applied to the sqrt-scaling learning-rate adjustment.
    pub lr_adaptation_factor: f32,
    /// Window length for both the variance and loss histories.
    pub variance_window_size: usize,
    /// Variance below this (with improving loss) shrinks the batch.
    pub increase_threshold: f32,
    /// Variance above this (and rising) grows the batch.
    pub decrease_threshold: f32,
}
impl Default for AdaptiveBatchSizingConfig {
fn default() -> Self {
Self {
initial_batch_size: 32,
min_batch_size: 8,
max_batch_size: 512,
gradient_variance_tolerance: 0.1,
lr_adaptation_factor: 0.8,
variance_window_size: 10,
increase_threshold: 0.05,
decrease_threshold: 0.2,
}
}
}
/// Heuristic controller that grows/shrinks the batch size based on observed
/// gradient-variance and loss trends.
#[derive(Debug)]
pub struct AdaptiveBatchSizing {
    config: AdaptiveBatchSizingConfig,
    // Current recommendation, always within [min, max]_batch_size.
    current_batch_size: usize,
    // Bounded FIFO of recent gradient-variance observations.
    gradient_variance_history: Vec<f32>,
    // Bounded FIFO of recent loss observations.
    loss_history: Vec<f32>,
    // Number of `update` calls performed so far.
    current_step: usize,
    // Step at which the batch size was last changed (rate limiting).
    last_adjustment_step: usize,
}
impl AdaptiveBatchSizing {
    /// Builds the controller, starting at `initial_batch_size`.
    pub fn new(config: AdaptiveBatchSizingConfig) -> Self {
        let initial_batch_size = config.initial_batch_size;
        Self {
            config,
            current_batch_size: initial_batch_size,
            gradient_variance_history: Vec::new(),
            loss_history: Vec::new(),
            current_step: 0,
            last_adjustment_step: 0,
        }
    }

    /// Convenience constructor overriding only the batch-size bounds.
    pub fn with_defaults(
        initial_batch_size: usize,
        min_batch_size: usize,
        max_batch_size: usize,
    ) -> Self {
        Self::new(AdaptiveBatchSizingConfig {
            initial_batch_size,
            min_batch_size,
            max_batch_size,
            ..Default::default()
        })
    }

    /// Currently recommended batch size.
    pub fn current_batch_size(&self) -> usize {
        self.current_batch_size
    }

    /// Read-only access to the active configuration.
    pub fn get_config(&self) -> &AdaptiveBatchSizingConfig {
        &self.config
    }

    /// Records one (variance, loss) observation and possibly adjusts the
    /// batch size. Returns the (possibly updated) batch size.
    pub fn update(&mut self, gradient_variance: f32, current_loss: f32) -> Result<usize> {
        self.current_step += 1;
        self.gradient_variance_history.push(gradient_variance);
        self.loss_history.push(current_loss);
        // Keep both histories bounded by the configured window.
        if self.gradient_variance_history.len() > self.config.variance_window_size {
            self.gradient_variance_history.remove(0);
        }
        if self.loss_history.len() > self.config.variance_window_size {
            self.loss_history.remove(0);
        }
        if self.should_adjust_batch_size() {
            self.adjust_batch_size()?;
            self.last_adjustment_step = self.current_step;
        }
        Ok(self.current_batch_size)
    }

    /// Empirical variance of a set of per-sample gradients: mean squared
    /// deviation (summed over elements) from the mean gradient.
    /// Returns 0.0 for an empty slice.
    pub fn compute_gradient_variance(&self, gradients: &[Tensor]) -> Result<f32> {
        if gradients.is_empty() {
            return Ok(0.0);
        }
        let mut mean_grad = gradients[0].clone();
        for grad in gradients.iter().skip(1) {
            mean_grad = mean_grad.add(grad)?;
        }
        mean_grad = mean_grad.div_scalar(gradients.len() as f32)?;
        let mut variance_sum = 0.0;
        for grad in gradients {
            let diff = grad.sub(&mean_grad)?;
            let squared_norm = diff.mul(&diff)?.sum(None, false)?;
            variance_sum += squared_norm.to_scalar()?;
        }
        Ok(variance_sum / gradients.len() as f32)
    }

    /// Rate limiter: adjustments need ≥5 steps since the last one and at
    /// least 3 variance observations.
    fn should_adjust_batch_size(&self) -> bool {
        if self.current_step - self.last_adjustment_step < 5 {
            return false;
        }
        self.gradient_variance_history.len() >= 3
    }

    /// High and rising variance → grow the batch (more averaging); low
    /// variance while the loss is still improving → shrink it.
    fn adjust_batch_size(&mut self) -> Result<()> {
        let recent_variance = self.recent_average_variance();
        let variance_trend = self.variance_trend();
        let loss_trend = self.loss_trend();
        if recent_variance > self.config.decrease_threshold && variance_trend > 0.0 {
            self.increase_batch_size();
        } else if recent_variance < self.config.increase_threshold && loss_trend < -0.01 {
            self.decrease_batch_size();
        }
        Ok(())
    }

    /// Mean of the last (up to) 5 recorded variances; 0.0 with no history.
    fn recent_average_variance(&self) -> f32 {
        if self.gradient_variance_history.is_empty() {
            return 0.0;
        }
        let recent_window = std::cmp::min(5, self.gradient_variance_history.len());
        let start_idx = self.gradient_variance_history.len() - recent_window;
        self.gradient_variance_history[start_idx..].iter().sum::<f32>() / recent_window as f32
    }

    /// Mean of the last two variances minus the mean of the two before them
    /// (positive = variance rising). Returns 0.0 with fewer than 4 samples.
    fn variance_trend(&self) -> f32 {
        // Bug fix: requires 4 samples — the previous `< 3` guard let
        // `len - 4` underflow (usize panic) when exactly 3 were recorded.
        if self.gradient_variance_history.len() < 4 {
            return 0.0;
        }
        let len = self.gradient_variance_history.len();
        let recent = self.gradient_variance_history[len - 2..].iter().sum::<f32>() / 2.0;
        let older = self.gradient_variance_history[len - 4..len - 2].iter().sum::<f32>() / 2.0;
        recent - older
    }

    /// Relative loss change over the same two-vs-two window (negative =
    /// loss decreasing). Returns 0.0 with fewer than 4 samples.
    fn loss_trend(&self) -> f32 {
        // Bug fix: same `len - 4` underflow guard as `variance_trend`.
        if self.loss_history.len() < 4 {
            return 0.0;
        }
        let len = self.loss_history.len();
        let recent = self.loss_history[len - 2..].iter().sum::<f32>() / 2.0;
        let older = self.loss_history[len - 4..len - 2].iter().sum::<f32>() / 2.0;
        (recent - older) / older.max(1e-8)
    }

    /// Grows the batch size by 50%, capped at `max_batch_size`.
    fn increase_batch_size(&mut self) {
        let new_size = (self.current_batch_size as f32 * 1.5) as usize;
        self.current_batch_size = new_size.min(self.config.max_batch_size);
    }

    /// Shrinks the batch size by 20%, floored at `min_batch_size`.
    fn decrease_batch_size(&mut self) {
        let new_size = (self.current_batch_size as f32 * 0.8) as usize;
        self.current_batch_size = new_size.max(self.config.min_batch_size);
    }

    /// Suggested learning-rate multiplier for the changed batch size, using
    /// sqrt scaling damped by `lr_adaptation_factor`.
    pub fn get_lr_adjustment(&self, original_batch_size: usize) -> f32 {
        let ratio = self.current_batch_size as f32 / original_batch_size as f32;
        ratio.sqrt() * self.config.lr_adaptation_factor
    }

    /// Clears all history and restores the initial batch size.
    pub fn reset(&mut self) {
        self.current_batch_size = self.config.initial_batch_size;
        self.gradient_variance_history.clear();
        self.loss_history.clear();
        self.current_step = 0;
        self.last_adjustment_step = 0;
    }
}
/// Configuration for `LossSurfaceSmoothing`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LossSurfaceSmoothingConfig {
    /// Blend factor λ used when mixing the parameter EMA back into the
    /// live parameters in `smooth_parameters`.
    pub smoothing_strength: f32,
    // NOTE(review): presumably scales the injected noise; the consuming
    // code (`apply_noise_injection`) is cut off in this view — confirm.
    pub noise_variance: f32,
    /// Decay used by both the gradient EMA and the parameter EMA.
    pub ema_decay: f32,
    /// Window length for the per-parameter gradient averaging history.
    pub averaging_window: usize,
    /// Enables windowed gradient averaging in `smooth_gradients`.
    pub use_gradient_averaging: bool,
    /// Enables noise injection in `smooth_gradients`.
    pub use_noise_injection: bool,
}
impl Default for LossSurfaceSmoothingConfig {
fn default() -> Self {
Self {
smoothing_strength: 0.1,
noise_variance: 1e-4,
ema_decay: 0.9,
averaging_window: 5,
use_gradient_averaging: true,
use_noise_injection: false,
}
}
}
/// Gradient/parameter smoothing helper intended to be run alongside an
/// optimizer (it rewrites gradients in place rather than stepping itself).
#[derive(Debug)]
pub struct LossSurfaceSmoothing {
    config: LossSurfaceSmoothingConfig,
    // Bounded FIFO of recent gradients per parameter index.
    gradient_history: HashMap<usize, Vec<Tensor>>,
    // EMA of gradients per parameter index.
    ema_gradients: HashMap<usize, Tensor>,
    // EMA of parameters per parameter index (for `smooth_parameters`).
    smoothed_parameters: HashMap<usize, Tensor>,
    // Number of `smooth_gradients` calls performed so far.
    current_step: usize,
}
impl LossSurfaceSmoothing {
pub fn new(config: LossSurfaceSmoothingConfig) -> Self {
Self {
config,
gradient_history: HashMap::new(),
ema_gradients: HashMap::new(),
smoothed_parameters: HashMap::new(),
current_step: 0,
}
}
pub fn with_defaults(smoothing_strength: f32, use_noise: bool) -> Self {
Self::new(LossSurfaceSmoothingConfig {
smoothing_strength,
use_noise_injection: use_noise,
..Default::default()
})
}
    /// Read-only access to the active configuration.
    pub fn get_config(&self) -> &LossSurfaceSmoothingConfig {
        &self.config
    }
    /// Rewrites every parameter's gradient in place with the configured
    /// smoothing pipeline: optional windowed averaging, then EMA smoothing
    /// (always), then optional noise injection.
    ///
    /// NOTE(review): unlike the optimizers' `step` methods, which skip
    /// parameters without gradients, this propagates the `grad()` error via
    /// `?` — confirm that difference is intentional.
    pub fn smooth_gradients(&mut self, parameters: &mut [Tensor]) -> Result<()> {
        self.current_step += 1;
        for (param_id, parameter) in parameters.iter_mut().enumerate() {
            let original_grad = parameter.grad()?;
            let mut smoothed_grad = original_grad.clone();
            if self.config.use_gradient_averaging {
                // Mean over the recent gradient window for this parameter.
                smoothed_grad = self.apply_gradient_averaging(param_id, &original_grad)?;
            }
            // EMA smoothing is applied unconditionally.
            smoothed_grad = self.apply_ema_smoothing(param_id, &smoothed_grad)?;
            if self.config.use_noise_injection {
                smoothed_grad = self.apply_noise_injection(&smoothed_grad)?;
            }
            parameter.set_grad(smoothed_grad)?;
        }
        Ok(())
    }
pub fn smooth_parameters(&mut self, parameters: &mut [Tensor]) -> Result<()> {
for (param_id, parameter) in parameters.iter_mut().enumerate() {
if let Some(smoothed_param) = self.smoothed_parameters.get(¶m_id) {
let new_smoothed = smoothed_param
.mul_scalar(self.config.ema_decay)?
.add(¶meter.mul_scalar(1.0 - self.config.ema_decay)?)?;
*parameter = parameter
.mul_scalar(1.0 - self.config.smoothing_strength)?
.add(&new_smoothed.mul_scalar(self.config.smoothing_strength)?)?;
self.smoothed_parameters.insert(param_id, new_smoothed);
} else {
self.smoothed_parameters.insert(param_id, parameter.clone());
}
}
Ok(())
}
fn apply_gradient_averaging(&mut self, param_id: usize, gradient: &Tensor) -> Result<Tensor> {
let history = self.gradient_history.entry(param_id).or_default();
history.push(gradient.clone());
if history.len() > self.config.averaging_window {
history.remove(0);
}
if history.len() == 1 {
Ok(gradient.clone())
} else {
let mut sum = history[0].clone();
for grad in history.iter().skip(1) {
sum = sum.add(grad)?;
}
Ok(sum.div_scalar(history.len() as f32)?)
}
}
fn apply_ema_smoothing(&mut self, param_id: usize, gradient: &Tensor) -> Result<Tensor> {
if let Some(ema_grad) = self.ema_gradients.get(¶m_id) {
let new_ema = ema_grad
.mul_scalar(self.config.ema_decay)?
.add(&gradient.mul_scalar(1.0 - self.config.ema_decay)?)?;
self.ema_gradients.insert(param_id, new_ema.clone());
Ok(new_ema)
} else {
self.ema_gradients.insert(param_id, gradient.clone());
Ok(gradient.clone())
}
}
fn apply_noise_injection(&self, gradient: &Tensor) -> Result<Tensor> {
let noise = Tensor::randn_like(gradient)
.map_err(|e| anyhow!("Failed to create noise tensor: {}", e))?
.mul_scalar(self.config.noise_variance.sqrt())
.map_err(|e| anyhow!("Failed to scale noise tensor: {}", e))?;
gradient
.add(&noise)
.map_err(|e| anyhow!("Failed to add noise to gradient: {}", e))
}
pub fn reset(&mut self) {
self.gradient_history.clear();
self.ema_gradients.clear();
self.smoothed_parameters.clear();
self.current_step = 0;
}
pub fn get_statistics(&self) -> HashMap<String, f32> {
let mut stats = HashMap::new();
stats.insert("current_step".to_string(), self.current_step as f32);
stats.insert(
"num_tracked_params".to_string(),
self.gradient_history.len() as f32,
);
stats.insert(
"smoothing_strength".to_string(),
self.config.smoothing_strength,
);
stats.insert("ema_decay".to_string(), self.config.ema_decay);
stats
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- configuration defaults ---

    #[test]
    fn test_qhm_config_default() {
        let cfg = QHMConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.momentum, 0.9);
        assert_eq!(cfg.nu, 0.7);
        assert_eq!(cfg.weight_decay, 0.0);
    }

    #[test]
    fn test_aggmo_config_default() {
        let cfg = AggMoConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.momentum_coefficients, vec![0.0, 0.9, 0.99]);
        assert_eq!(cfg.weight_decay, 0.0);
    }

    // --- optimizer construction ---

    #[test]
    fn test_qhm_creation() {
        let opt = QHM::with_defaults(1e-3, 0.9, 0.7);
        assert_eq!(opt.get_lr(), 1e-3);
        assert_eq!(opt.current_step, 0);
    }

    #[test]
    fn test_aggmo_creation() {
        let opt = AggMo::with_defaults(1e-3, vec![0.0, 0.9, 0.99]);
        assert_eq!(opt.get_lr(), 1e-3);
        assert_eq!(opt.num_momentum_buffers(), 3);
    }

    #[test]
    fn test_variance_reduction_svrg() {
        let opt = VarianceReduction::svrg(1e-3, 50, 10);
        assert_eq!(opt.get_lr(), 1e-3);
        assert_eq!(opt.current_step, 0);
    }

    #[test]
    fn test_variance_reduction_sag() {
        let opt = VarianceReduction::sag(1e-3, 100);
        assert_eq!(opt.get_lr(), 1e-3);
        // SAG constructor must select the SAG method variant.
        assert!(matches!(
            opt.config.method,
            VarianceReductionMethod::SAG
        ));
    }

    // --- Nesterov accelerated gradient ---

    #[test]
    fn test_nesterov_accelerated_gradient_config() {
        let cfg = NesterovAcceleratedGradientConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.momentum, 0.9);
        assert_eq!(cfg.weight_decay, 0.0);
        assert!(!cfg.restart_on_increase);
    }

    #[test]
    fn test_nesterov_accelerated_gradient_creation() {
        let opt = NesterovAcceleratedGradient::with_defaults(1e-3, 0.9);
        assert_eq!(opt.get_lr(), 1e-3);
        assert_eq!(opt.current_step, 0);
        assert!(opt.previous_loss.is_none());
    }

    #[test]
    fn test_nesterov_restart_on_increase() {
        let mut opt = NesterovAcceleratedGradient::new(NesterovAcceleratedGradientConfig {
            learning_rate: 1e-3,
            momentum: 0.9,
            weight_decay: 0.0,
            restart_on_increase: true,
        });
        // Each call records the latest loss value.
        opt.set_current_loss(1.0);
        assert_eq!(opt.previous_loss, Some(1.0));
        opt.set_current_loss(1.5);
        assert_eq!(opt.previous_loss, Some(1.5));
    }

    // --- heavy ball ---

    #[test]
    fn test_heavy_ball_config() {
        let cfg = HeavyBallConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.beta, 0.9);
        assert_eq!(cfg.weight_decay, 0.0);
        assert!(!cfg.adaptive_momentum);
    }

    #[test]
    fn test_heavy_ball_creation() {
        let opt = HeavyBall::with_defaults(1e-3, 0.9);
        assert_eq!(opt.get_lr(), 1e-3);
        assert_eq!(opt.current_step, 0);
        assert_eq!(opt.get_config().beta, 0.9);
    }

    #[test]
    fn test_heavy_ball_adaptive_momentum() {
        let opt = HeavyBall::new(HeavyBallConfig {
            learning_rate: 1e-3,
            beta: 0.9,
            weight_decay: 0.0,
            adaptive_momentum: true,
        });
        assert!(opt.config.adaptive_momentum);
    }

    // --- FISTA ---

    #[test]
    fn test_fista_config() {
        let cfg = FISTAConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.threshold, 1e-4);
        assert!(cfg.adaptive_restart);
        assert_eq!(cfg.weight_decay, 0.0);
    }

    #[test]
    fn test_fista_creation() {
        let opt = FISTA::with_defaults(1e-3, 1e-4);
        assert_eq!(opt.get_lr(), 1e-3);
        assert_eq!(opt.current_step, 0);
        assert_eq!(opt.momentum_coefficient, 1.0);
        assert_eq!(opt.previous_momentum, 1.0);
    }

    #[test]
    fn test_fista_momentum_update() {
        let mut opt = FISTA::with_defaults(1e-3, 1e-4);
        // The momentum coefficient must grow strictly each step.
        opt.current_step = 1;
        opt.update_momentum_coefficient();
        assert!(opt.momentum_coefficient > 1.0);
        assert_eq!(opt.previous_momentum, 1.0);
        let first_coefficient = opt.momentum_coefficient;
        opt.current_step = 2;
        opt.update_momentum_coefficient();
        assert!(opt.momentum_coefficient > first_coefficient);
    }

    // --- adaptive batch sizing ---

    #[test]
    fn test_adaptive_batch_sizing_config() {
        let cfg = AdaptiveBatchSizingConfig::default();
        assert_eq!(cfg.initial_batch_size, 32);
        assert_eq!(cfg.min_batch_size, 8);
        assert_eq!(cfg.max_batch_size, 512);
        assert_eq!(cfg.gradient_variance_tolerance, 0.1);
        assert_eq!(cfg.lr_adaptation_factor, 0.8);
        assert_eq!(cfg.variance_window_size, 10);
        assert_eq!(cfg.increase_threshold, 0.05);
        assert_eq!(cfg.decrease_threshold, 0.2);
    }

    #[test]
    fn test_adaptive_batch_sizing_creation() {
        let sizer = AdaptiveBatchSizing::with_defaults(64, 16, 256);
        assert_eq!(sizer.current_batch_size(), 64);
        assert_eq!(sizer.get_config().min_batch_size, 16);
        assert_eq!(sizer.get_config().max_batch_size, 256);
    }

    #[test]
    fn test_adaptive_batch_sizing_lr_adjustment() {
        let sizer = AdaptiveBatchSizing::with_defaults(64, 16, 256);
        // Doubling the batch size should yield a positive, bounded factor.
        let lr_adj = sizer.get_lr_adjustment(32);
        assert!(lr_adj > 0.0);
        assert!(lr_adj < 2.0);
    }

    #[test]
    fn test_adaptive_batch_sizing_reset() {
        let mut sizer = AdaptiveBatchSizing::with_defaults(64, 16, 256);
        sizer.current_step = 10;
        sizer.reset();
        assert_eq!(sizer.current_step, 0);
        assert_eq!(sizer.current_batch_size(), 64);
    }

    // --- loss surface smoothing ---

    #[test]
    fn test_loss_surface_smoothing_config() {
        let cfg = LossSurfaceSmoothingConfig::default();
        assert_eq!(cfg.smoothing_strength, 0.1);
        assert_eq!(cfg.noise_variance, 1e-4);
        assert_eq!(cfg.ema_decay, 0.9);
        assert_eq!(cfg.averaging_window, 5);
        assert!(cfg.use_gradient_averaging);
        assert!(!cfg.use_noise_injection);
    }

    #[test]
    fn test_loss_surface_smoothing_creation() {
        let smoother = LossSurfaceSmoothing::with_defaults(0.2, true);
        assert_eq!(smoother.get_config().smoothing_strength, 0.2);
        assert!(smoother.get_config().use_noise_injection);
        assert_eq!(smoother.current_step, 0);
    }

    #[test]
    fn test_loss_surface_smoothing_statistics() {
        let smoother = LossSurfaceSmoothing::with_defaults(0.1, false);
        let stats = smoother.get_statistics();
        assert_eq!(stats.get("current_step"), Some(&0.0));
        assert_eq!(stats.get("num_tracked_params"), Some(&0.0));
        assert_eq!(stats.get("smoothing_strength"), Some(&0.1));
        assert_eq!(stats.get("ema_decay"), Some(&0.9));
    }

    #[test]
    fn test_loss_surface_smoothing_reset() {
        let mut smoother = LossSurfaceSmoothing::with_defaults(0.1, false);
        smoother.current_step = 5;
        smoother.reset();
        assert_eq!(smoother.current_step, 0);
    }
}