use crate::common::{GradientProcessor, OptimizerState, StateMemoryStats};
use crate::traits::StatefulOptimizer;
use serde::{Deserialize, Serialize};
use std::cell::RefCell;
use std::collections::HashMap;
use trustformers_core::errors::Result;
use trustformers_core::{traits::Optimizer, Tensor};
#[derive(Debug, Clone)]
pub struct DiWo {
learning_rate: f32,
beta1: f32,
beta2: f32,
eps: f32,
weight_decay: f32,
ortho_strength: f32,
adaptive_ortho: bool,
step: usize,
momentum_states: HashMap<String, Tensor>,
velocity_states: HashMap<String, Tensor>,
ortho_history: HashMap<String, OrthogonalizationState>,
#[allow(dead_code)]
gradient_processor: GradientProcessor,
}
#[derive(Debug, Clone)]
struct OrthogonalizationState {
#[allow(dead_code)]
previous_basis: Option<Tensor>,
violation_history: Vec<f32>,
adaptation_rate: f32,
}
impl Default for OrthogonalizationState {
fn default() -> Self {
Self {
previous_basis: None,
violation_history: Vec::new(),
adaptation_rate: 0.1,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiWoConfig {
pub learning_rate: f32,
pub beta1: f32,
pub beta2: f32,
pub eps: f32,
pub weight_decay: f32,
pub ortho_strength: f32,
pub adaptive_ortho: bool,
}
impl Default for DiWoConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
beta1: 0.9,
beta2: 0.999,
eps: 1e-8,
weight_decay: 0.01,
ortho_strength: 0.1,
adaptive_ortho: true,
}
}
}
impl DiWo {
pub fn new(
learning_rate: f32,
betas: (f32, f32),
eps: f32,
weight_decay: f32,
ortho_strength: f32,
) -> Self {
Self {
learning_rate,
beta1: betas.0,
beta2: betas.1,
eps,
weight_decay,
ortho_strength,
adaptive_ortho: true,
step: 0,
momentum_states: HashMap::new(),
velocity_states: HashMap::new(),
ortho_history: HashMap::new(),
gradient_processor: GradientProcessor,
}
}
pub fn for_transformer_training() -> Self {
Self::new(2e-4, (0.9, 0.98), 1e-8, 0.01, 0.05)
}
pub fn for_computer_vision() -> Self {
Self::new(1e-3, (0.9, 0.999), 1e-8, 1e-4, 0.1)
}
pub fn from_config(config: DiWoConfig) -> Self {
let mut optimizer = Self::new(
config.learning_rate,
(config.beta1, config.beta2),
config.eps,
config.weight_decay,
config.ortho_strength,
);
optimizer.adaptive_ortho = config.adaptive_ortho;
optimizer
}
fn apply_orthogonalization(
&mut self,
param_name: &str,
param: &Tensor,
update: &Tensor,
) -> Result<Tensor> {
let _ortho_state = self.ortho_history.entry(param_name.to_string()).or_default();
if param.shape().len() >= 2 && param.len() > 100 {
let param_data = param.data()?;
let update_data = update.data()?;
let orthogonalized_update =
self.gram_schmidt_orthogonalization(¶m_data, &update_data, ¶m.shape())?;
let effective_ortho_strength = if self.adaptive_ortho {
self.adapt_orthogonalization_strength(param_name, param, &orthogonalized_update)?
} else {
self.ortho_strength
};
let blended_update_data: Vec<f32> = update_data
.iter()
.zip(orthogonalized_update.iter())
.map(|(orig, ortho)| {
orig * (1.0 - effective_ortho_strength) + ortho * effective_ortho_strength
})
.collect();
Ok(Tensor::new(blended_update_data)?)
} else {
Ok(update.clone())
}
}
fn gram_schmidt_orthogonalization(
&self,
_param_data: &[f32],
update_data: &[f32],
shape: &[usize],
) -> Result<Vec<f32>> {
if shape.len() < 2 {
return Ok(update_data.to_vec());
}
let rows = shape[0];
let cols = shape[1];
let mut orthogonal_update = update_data.to_vec();
for i in 0..rows.min(cols) {
let start_idx = i * cols;
let end_idx = (i + 1) * cols;
if end_idx > orthogonal_update.len() {
break;
}
let mut norm_sq = 0.0f32;
for j in start_idx..end_idx {
norm_sq += orthogonal_update[j] * orthogonal_update[j];
}
if norm_sq > 1e-8 {
let norm = norm_sq.sqrt();
for j in start_idx..end_idx {
orthogonal_update[j] /= norm;
}
}
for k in 0..i {
let k_start = k * cols;
let k_end = (k + 1) * cols;
if k_end > orthogonal_update.len() {
break;
}
let mut dot_product = 0.0f32;
for (j1, j2) in (start_idx..end_idx).zip(k_start..k_end) {
dot_product += orthogonal_update[j1] * orthogonal_update[j2];
}
for (j1, j2) in (start_idx..end_idx).zip(k_start..k_end) {
orthogonal_update[j1] -= dot_product * orthogonal_update[j2];
}
}
}
Ok(orthogonal_update)
}
fn adapt_orthogonalization_strength(
&mut self,
param_name: &str,
param: &Tensor,
orthogonal_update: &[f32],
) -> Result<f32> {
let param_data = param.data()?;
let violation = self.calculate_orthogonality_violation(¶m_data, orthogonal_update)?;
let ortho_state = self
.ortho_history
.get_mut(param_name)
.expect("ortho_history state must exist after initialization");
ortho_state.violation_history.push(violation);
if ortho_state.violation_history.len() > 100 {
ortho_state.violation_history.remove(0);
}
if ortho_state.violation_history.len() >= 2 {
let recent_avg =
ortho_state.violation_history.iter().rev().take(10).sum::<f32>() / 10.0;
let older_avg = if ortho_state.violation_history.len() > 20 {
ortho_state.violation_history.iter().rev().skip(10).take(10).sum::<f32>() / 10.0
} else {
recent_avg
};
if recent_avg > older_avg {
ortho_state.adaptation_rate += 0.001;
} else {
ortho_state.adaptation_rate -= 0.0005;
}
ortho_state.adaptation_rate = ortho_state.adaptation_rate.clamp(0.0, 1.0);
}
Ok((self.ortho_strength * ortho_state.adaptation_rate).clamp(0.0, 0.5))
}
fn calculate_orthogonality_violation(
&self,
param_data: &[f32],
update_data: &[f32],
) -> Result<f32> {
if param_data.len() != update_data.len() || param_data.len() < 4 {
return Ok(0.0);
}
let param_mean = param_data.iter().sum::<f32>() / param_data.len() as f32;
let update_mean = update_data.iter().sum::<f32>() / update_data.len() as f32;
let mut numerator = 0.0f32;
let mut param_var = 0.0f32;
let mut update_var = 0.0f32;
for (p, u) in param_data.iter().zip(update_data.iter()) {
let p_centered = p - param_mean;
let u_centered = u - update_mean;
numerator += p_centered * u_centered;
param_var += p_centered * p_centered;
update_var += u_centered * u_centered;
}
if param_var > 1e-8 && update_var > 1e-8 {
let correlation = numerator / (param_var.sqrt() * update_var.sqrt());
Ok(correlation.abs()) } else {
Ok(0.0)
}
}
}
impl Optimizer for DiWo {
fn update(&mut self, param: &mut Tensor, gradient: &Tensor) -> Result<()> {
let param_id = format!("param_{}", self.momentum_states.len());
let param_data = param.data()?;
let grad_data = gradient.data()?;
if !self.momentum_states.contains_key(¶m_id) {
self.momentum_states.insert(param_id.clone(), Tensor::zeros_like(param)?);
self.velocity_states.insert(param_id.clone(), Tensor::zeros_like(param)?);
}
let momentum = self
.momentum_states
.get_mut(¶m_id)
.expect("momentum_states must exist after initialization");
let velocity = self
.velocity_states
.get_mut(¶m_id)
.expect("velocity_states must exist after initialization");
let momentum_data = momentum.data()?;
let velocity_data = velocity.data()?;
let new_momentum_data: Vec<f32> = momentum_data
.iter()
.zip(grad_data.iter())
.map(|(m, g)| self.beta1 * m + (1.0 - self.beta1) * g)
.collect();
let new_velocity_data: Vec<f32> = velocity_data
.iter()
.zip(grad_data.iter())
.map(|(v, g)| self.beta2 * v + (1.0 - self.beta2) * g * g)
.collect();
*momentum = Tensor::new(new_momentum_data.clone())?;
*velocity = Tensor::new(new_velocity_data.clone())?;
let step_f32 = (self.step + 1) as f32;
let momentum_corrected = 1.0 - self.beta1.powf(step_f32);
let velocity_corrected = 1.0 - self.beta2.powf(step_f32);
let base_update_data: Vec<f32> = new_momentum_data
.iter()
.zip(new_velocity_data.iter())
.zip(param_data.iter())
.map(|((m, v), p)| {
let m_hat = m / momentum_corrected;
let v_hat = v / velocity_corrected;
let base_update = -self.learning_rate * m_hat / (v_hat.sqrt() + self.eps);
base_update - self.learning_rate * self.weight_decay * p
})
.collect();
let base_update = Tensor::new(base_update_data)?;
let orthogonal_update = self.apply_orthogonalization(¶m_id, param, &base_update)?;
*param = param.add(&orthogonal_update)?;
Ok(())
}
fn step(&mut self) {
self.step += 1;
}
fn zero_grad(&mut self) {
}
fn get_lr(&self) -> f32 {
self.learning_rate
}
fn set_lr(&mut self, lr: f32) {
self.learning_rate = lr;
}
}
impl StatefulOptimizer for DiWo {
type Config = DiWoConfig;
type State = OptimizerState;
fn config(&self) -> &Self::Config {
thread_local! {
static TEMP_CONFIG: RefCell<Option<DiWoConfig>> = const { RefCell::new(None) };
}
TEMP_CONFIG.with(|config| {
*config.borrow_mut() = Some(DiWoConfig {
learning_rate: self.learning_rate,
beta1: self.beta1,
beta2: self.beta2,
eps: self.eps,
weight_decay: self.weight_decay,
ortho_strength: self.ortho_strength,
adaptive_ortho: self.adaptive_ortho,
});
unsafe {
std::mem::transmute(
config.borrow().as_ref().expect("config just set to Some above"),
)
}
})
}
fn state(&self) -> &Self::State {
thread_local! {
static TEMP_STATE: RefCell<Option<OptimizerState>> = const { RefCell::new(None) };
}
TEMP_STATE.with(|state| {
*state.borrow_mut() = Some(OptimizerState {
step: self.step,
momentum: HashMap::new(),
variance: HashMap::new(),
third_moment: HashMap::new(),
param_steps: HashMap::new(),
velocity: HashMap::new(),
});
unsafe {
std::mem::transmute(state.borrow().as_ref().expect("state just set to Some above"))
}
})
}
fn state_mut(&mut self) -> &mut Self::State {
thread_local! {
static TEMP_STATE: RefCell<Option<OptimizerState>> = const { RefCell::new(None) };
}
TEMP_STATE.with(|state| {
*state.borrow_mut() = Some(OptimizerState {
step: self.step,
momentum: HashMap::new(),
variance: HashMap::new(),
third_moment: HashMap::new(),
param_steps: HashMap::new(),
velocity: HashMap::new(),
});
unsafe {
std::mem::transmute(
state.borrow_mut().as_mut().expect("state just set to Some above"),
)
}
})
}
fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
let mut state_dict = HashMap::new();
state_dict.insert(
"learning_rate".to_string(),
Tensor::new(vec![self.learning_rate])?,
);
state_dict.insert("beta1".to_string(), Tensor::new(vec![self.beta1])?);
state_dict.insert("beta2".to_string(), Tensor::new(vec![self.beta2])?);
state_dict.insert("eps".to_string(), Tensor::new(vec![self.eps])?);
state_dict.insert(
"weight_decay".to_string(),
Tensor::new(vec![self.weight_decay])?,
);
state_dict.insert(
"ortho_strength".to_string(),
Tensor::new(vec![self.ortho_strength])?,
);
state_dict.insert("step".to_string(), Tensor::new(vec![self.step as f32])?);
for (key, tensor) in &self.momentum_states {
state_dict.insert(format!("momentum_{}", key), tensor.clone());
}
for (key, tensor) in &self.velocity_states {
state_dict.insert(format!("velocity_{}", key), tensor.clone());
}
Ok(state_dict)
}
fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
if let Some(lr_tensor) = state.get("learning_rate") {
if let Ok(lr_vec) = lr_tensor.data() {
if !lr_vec.is_empty() {
self.learning_rate = lr_vec[0];
}
}
}
if let Some(beta1_tensor) = state.get("beta1") {
if let Ok(beta1_vec) = beta1_tensor.data() {
if !beta1_vec.is_empty() {
self.beta1 = beta1_vec[0];
}
}
}
if let Some(beta2_tensor) = state.get("beta2") {
if let Ok(beta2_vec) = beta2_tensor.data() {
if !beta2_vec.is_empty() {
self.beta2 = beta2_vec[0];
}
}
}
if let Some(step_tensor) = state.get("step") {
if let Ok(step_vec) = step_tensor.data() {
if !step_vec.is_empty() {
self.step = step_vec[0] as usize;
}
}
}
for (key, tensor) in state.iter() {
if let Some(param_id) = key.strip_prefix("momentum_") {
self.momentum_states.insert(param_id.to_string(), tensor.clone());
} else if let Some(param_id) = key.strip_prefix("velocity_") {
self.velocity_states.insert(param_id.to_string(), tensor.clone());
}
}
Ok(())
}
fn memory_usage(&self) -> StateMemoryStats {
let momentum_elements: usize = self
.momentum_states
.values()
.map(|t| if let Ok(data) = t.data() { data.len() } else { 0 })
.sum();
let velocity_elements: usize = self
.velocity_states
.values()
.map(|t| if let Ok(data) = t.data() { data.len() } else { 0 })
.sum();
let total_elements = momentum_elements + velocity_elements;
let total_bytes = total_elements * std::mem::size_of::<f32>();
StateMemoryStats {
momentum_elements,
variance_elements: velocity_elements,
third_moment_elements: 0,
total_bytes,
num_parameters: momentum_elements,
}
}
fn reset_state(&mut self) {
self.step = 0;
self.momentum_states.clear();
self.velocity_states.clear();
self.ortho_history.clear();
}
fn num_parameters(&self) -> usize {
self.momentum_states.len()
}
}
#[derive(Debug, Clone)]
pub struct MeZOV2 {
learning_rate: f32,
perturbation_scale: f32,
#[allow(dead_code)]
num_samples: usize,
adaptive_perturbation: bool,
perturbation_levels: Vec<f32>,
step: usize,
random_seed: u64,
parameter_history: HashMap<String, Vec<Tensor>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeZOV2Config {
pub learning_rate: f32,
pub perturbation_scale: f32,
pub num_samples: usize,
pub adaptive_perturbation: bool,
pub perturbation_levels: Vec<f32>,
}
impl Default for MeZOV2Config {
fn default() -> Self {
Self {
learning_rate: 1e-3,
perturbation_scale: 1e-3,
num_samples: 2,
adaptive_perturbation: true,
perturbation_levels: vec![1.0, 0.1, 0.01],
}
}
}
impl MeZOV2 {
pub fn new(learning_rate: f32, perturbation_scale: f32, num_samples: usize) -> Self {
Self {
learning_rate,
perturbation_scale,
num_samples,
adaptive_perturbation: true,
perturbation_levels: vec![1.0, 0.1, 0.01],
step: 0,
random_seed: 42,
parameter_history: HashMap::new(),
}
}
pub fn for_large_language_models() -> Self {
Self::new(1e-4, 1e-4, 1)
}
pub fn for_billion_parameter_models() -> Self {
Self::new(5e-5, 5e-5, 1)
}
pub fn from_config(config: MeZOV2Config) -> Self {
let mut optimizer = Self::new(
config.learning_rate,
config.perturbation_scale,
config.num_samples,
);
optimizer.adaptive_perturbation = config.adaptive_perturbation;
optimizer.perturbation_levels = config.perturbation_levels;
optimizer
}
#[allow(dead_code)]
fn estimate_gradient(
&mut self,
param: &Tensor,
loss_fn: impl Fn(&Tensor) -> Result<f32>,
) -> Result<Tensor> {
let param_data = param.data()?;
let mut gradient_estimate = vec![0.0f32; param_data.len()];
let perturbation_levels = self.perturbation_levels.clone();
for &scale_factor in &perturbation_levels {
let current_scale = self.perturbation_scale * scale_factor;
for _ in 0..self.num_samples {
let perturbation = self.generate_perturbation(¶m_data, current_scale)?;
let perturbed_forward_data: Vec<f32> =
param_data.iter().zip(perturbation.iter()).map(|(p, pert)| p + pert).collect();
let param_forward = Tensor::new(perturbed_forward_data)?;
let loss_forward = loss_fn(¶m_forward)?;
let perturbed_backward_data: Vec<f32> =
param_data.iter().zip(perturbation.iter()).map(|(p, pert)| p - pert).collect();
let param_backward = Tensor::new(perturbed_backward_data)?;
let loss_backward = loss_fn(¶m_backward)?;
let loss_diff = loss_forward - loss_backward;
let gradient_scale =
loss_diff / (2.0 * current_scale) / self.perturbation_levels.len() as f32;
for (i, &pert) in perturbation.iter().enumerate() {
gradient_estimate[i] += gradient_scale * pert.signum();
}
}
}
for grad in gradient_estimate.iter_mut() {
*grad /= self.num_samples as f32;
}
Tensor::new(gradient_estimate)
}
fn generate_perturbation(&mut self, param_data: &[f32], scale: f32) -> Result<Vec<f32>> {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
self.step.hash(&mut hasher);
self.random_seed.hash(&mut hasher);
let seed = hasher.finish();
let mut rng_state = seed;
let perturbation: Vec<f32> = param_data
.iter()
.map(|_| {
rng_state = rng_state.wrapping_mul(1664525).wrapping_add(1013904223);
let random_val = (rng_state >> 16) as f32 / (1u32 << 16) as f32;
scale * (2.0 * random_val - 1.0) })
.collect();
Ok(perturbation)
}
fn adapt_perturbation_scale(&mut self, param: &Tensor, _gradient_norm: f32) {
if !self.adaptive_perturbation {
return;
}
let param_id = format!("param_{}", self.parameter_history.len());
let history = self.parameter_history.entry(param_id).or_default();
history.push(param.clone());
if history.len() > 10 {
history.remove(0);
}
if history.len() >= 2 {
let recent = &history[history.len() - 1];
let previous = &history[history.len() - 2];
if let (Ok(recent_data), Ok(previous_data)) = (recent.data(), previous.data()) {
let param_change_norm: f32 = recent_data
.iter()
.zip(previous_data.iter())
.map(|(r, p)| (r - p).abs())
.sum::<f32>()
/ recent_data.len() as f32;
if param_change_norm < 1e-8 {
self.perturbation_scale *= 1.1;
} else if param_change_norm > 1e-3 {
self.perturbation_scale *= 0.9;
}
self.perturbation_scale = self.perturbation_scale.clamp(1e-6, 1e-1);
}
}
}
}
impl Optimizer for MeZOV2 {
fn update(&mut self, param: &mut Tensor, gradient: &Tensor) -> Result<()> {
let gradient_norm = gradient.norm()?;
let param_data = param.data()?;
let grad_data = gradient.data()?;
let mut rng_state = (self.step as u64).wrapping_mul(1664525).wrapping_add(1013904223);
let noisy_grad_data: Vec<f32> = grad_data
.iter()
.map(|&g| {
rng_state = rng_state.wrapping_mul(1664525).wrapping_add(1013904223);
let noise_factor = ((rng_state >> 16) as f32 / (1u32 << 16) as f32 - 0.5) * 0.1;
g * (1.0 + noise_factor) })
.collect();
let effective_lr = self.learning_rate * (1.0 + self.perturbation_scale);
let updated_param_data: Vec<f32> = param_data
.iter()
.zip(noisy_grad_data.iter())
.map(|(p, g)| p - effective_lr * g)
.collect();
*param = Tensor::new(updated_param_data)?;
self.adapt_perturbation_scale(param, gradient_norm);
Ok(())
}
fn step(&mut self) {
self.step += 1;
}
fn zero_grad(&mut self) {
}
fn get_lr(&self) -> f32 {
self.learning_rate
}
fn set_lr(&mut self, lr: f32) {
self.learning_rate = lr;
}
}
#[derive(Debug, Clone)]
pub struct AdaWin {
learning_rate: f32,
beta: f32,
#[allow(dead_code)]
adaptation_rate: f32,
max_window_size: usize,
step: usize,
parameter_states: HashMap<String, AdaWinParameterState>,
}
#[derive(Debug, Clone)]
struct AdaWinParameterState {
gradient_history: Vec<Tensor>,
window_size: usize,
momentum_weights: Vec<f32>,
momentum: Tensor,
correlation_history: Vec<f32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaWinConfig {
pub learning_rate: f32,
pub beta: f32,
pub adaptation_rate: f32,
pub max_window_size: usize,
}
impl Default for AdaWinConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
beta: 0.9,
adaptation_rate: 0.01,
max_window_size: 20,
}
}
}
impl AdaWin {
pub fn new(
learning_rate: f32,
beta: f32,
adaptation_rate: f32,
max_window_size: usize,
) -> Self {
Self {
learning_rate,
beta,
adaptation_rate,
max_window_size,
step: 0,
parameter_states: HashMap::new(),
}
}
pub fn for_transformer_training() -> Self {
Self::new(1e-4, 0.95, 0.02, 15)
}
pub fn for_computer_vision() -> Self {
Self::new(1e-3, 0.9, 0.01, 25)
}
pub fn from_config(config: AdaWinConfig) -> Self {
Self::new(
config.learning_rate,
config.beta,
config.adaptation_rate,
config.max_window_size,
)
}
fn calculate_gradient_correlation(&self, gradients: &[Tensor]) -> Result<f32> {
if gradients.len() < 2 {
return Ok(0.0);
}
let mut total_correlation = 0.0f32;
let mut count = 0;
for i in 0..gradients.len() - 1 {
for j in i + 1..gradients.len() {
let grad1_data = gradients[i].data()?;
let grad2_data = gradients[j].data()?;
if grad1_data.len() == grad2_data.len() {
let mean1 = grad1_data.iter().sum::<f32>() / grad1_data.len() as f32;
let mean2 = grad2_data.iter().sum::<f32>() / grad2_data.len() as f32;
let mut numerator = 0.0f32;
let mut var1 = 0.0f32;
let mut var2 = 0.0f32;
for (g1, g2) in grad1_data.iter().zip(grad2_data.iter()) {
let d1 = g1 - mean1;
let d2 = g2 - mean2;
numerator += d1 * d2;
var1 += d1 * d1;
var2 += d2 * d2;
}
if var1 > 1e-8 && var2 > 1e-8 {
total_correlation += numerator / (var1.sqrt() * var2.sqrt());
count += 1;
}
}
}
}
Ok(if count > 0 { total_correlation / count as f32 } else { 0.0 })
}
fn adapt_window_size(&self, param_state: &mut AdaWinParameterState) -> Result<()> {
let correlation = self.calculate_gradient_correlation(¶m_state.gradient_history)?;
param_state.correlation_history.push(correlation);
if param_state.correlation_history.len() > 50 {
param_state.correlation_history.remove(0);
}
if param_state.correlation_history.len() >= 5 {
let recent_avg =
param_state.correlation_history.iter().rev().take(3).sum::<f32>() / 3.0;
if recent_avg > 0.5 {
let new_size = (param_state.window_size + 1).min(self.max_window_size);
param_state.window_size = new_size;
} else if recent_avg < 0.1 {
let new_size = (param_state.window_size.saturating_sub(1)).max(2);
param_state.window_size = new_size;
}
}
self.update_momentum_weights(param_state);
Ok(())
}
fn update_momentum_weights(&self, param_state: &mut AdaWinParameterState) {
param_state.momentum_weights.clear();
let window_size = param_state.window_size;
for i in 0..window_size {
let age = i as f32;
let adaptive_beta = self.beta * (1.0 - age / (window_size as f32 + 1.0));
param_state.momentum_weights.push(adaptive_beta.powf(age));
}
let sum: f32 = param_state.momentum_weights.iter().sum();
if sum > 1e-8 {
for weight in param_state.momentum_weights.iter_mut() {
*weight /= sum;
}
}
}
fn calculate_weighted_momentum(&self, param_state: &AdaWinParameterState) -> Result<Tensor> {
if param_state.gradient_history.is_empty() {
return Ok(param_state.momentum.clone());
}
let first_grad = ¶m_state.gradient_history[0];
let mut weighted_momentum = vec![0.0f32; first_grad.len()];
for (i, gradient) in param_state.gradient_history.iter().enumerate() {
if i < param_state.momentum_weights.len() {
let weight = param_state.momentum_weights[i];
let grad_data = gradient.data()?;
for (j, &g) in grad_data.iter().enumerate() {
if j < weighted_momentum.len() {
weighted_momentum[j] += weight * g;
}
}
}
}
Tensor::new(weighted_momentum)
}
}
impl Optimizer for AdaWin {
fn update(&mut self, param: &mut Tensor, gradient: &Tensor) -> Result<()> {
let param_id = format!("param_{}", self.parameter_states.len());
if !self.parameter_states.contains_key(¶m_id) {
let momentum = Tensor::zeros_like(param)?;
self.parameter_states.insert(
param_id.clone(),
AdaWinParameterState {
gradient_history: Vec::new(),
window_size: 3, momentum_weights: vec![1.0, 0.5, 0.25], momentum,
correlation_history: Vec::new(),
},
);
}
{
let param_state = self
.parameter_states
.get_mut(¶m_id)
.expect("parameter_states must exist after step");
param_state.gradient_history.push(gradient.clone());
if param_state.gradient_history.len() > self.max_window_size {
param_state.gradient_history.remove(0);
}
}
{
let mut param_state = self
.parameter_states
.remove(¶m_id)
.expect("parameter_states must exist for param_id");
self.adapt_window_size(&mut param_state)?;
self.parameter_states.insert(param_id.clone(), param_state);
}
let weighted_momentum = {
let param_state = self
.parameter_states
.get(¶m_id)
.expect("parameter_states must exist for param_id");
self.calculate_weighted_momentum(param_state)?
};
let momentum_data = {
let param_state = self
.parameter_states
.get(¶m_id)
.expect("parameter_states must exist for param_id");
param_state.momentum.data()?
};
let weighted_data = weighted_momentum.data()?;
let new_momentum_data: Vec<f32> = momentum_data
.iter()
.zip(weighted_data.iter())
.map(|(m, w)| self.beta * m + (1.0 - self.beta) * w)
.collect();
{
let param_state = self
.parameter_states
.get_mut(¶m_id)
.expect("parameter_states must exist after step");
param_state.momentum = Tensor::new(new_momentum_data.clone())?;
}
let param_data = param.data()?;
let updated_param_data: Vec<f32> = param_data
.iter()
.zip(new_momentum_data.iter())
.map(|(p, m)| p - self.learning_rate * m)
.collect();
*param = Tensor::new(updated_param_data)?;
Ok(())
}
fn step(&mut self) {
self.step += 1;
}
fn zero_grad(&mut self) {
for param_state in self.parameter_states.values_mut() {
param_state.gradient_history.clear();
}
}
fn get_lr(&self) -> f32 {
self.learning_rate
}
fn set_lr(&mut self, lr: f32) {
self.learning_rate = lr;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_diwo_creation() {
let optimizer = DiWo::new(1e-3, (0.9, 0.999), 1e-8, 0.01, 0.1);
assert_eq!(optimizer.learning_rate, 1e-3);
assert_eq!(optimizer.beta1, 0.9);
assert_eq!(optimizer.ortho_strength, 0.1);
}
#[test]
fn test_diwo_presets() {
let transformer_opt = DiWo::for_transformer_training();
assert_eq!(transformer_opt.learning_rate, 2e-4);
assert!(transformer_opt.adaptive_ortho);
let vision_opt = DiWo::for_computer_vision();
assert_eq!(vision_opt.learning_rate, 1e-3);
assert_eq!(vision_opt.ortho_strength, 0.1);
}
#[test]
fn test_diwo_config() -> Result<()> {
let config = DiWoConfig {
learning_rate: 5e-4,
ortho_strength: 0.2,
adaptive_ortho: false,
..Default::default()
};
let optimizer = DiWo::from_config(config);
assert_eq!(optimizer.learning_rate, 5e-4);
assert_eq!(optimizer.ortho_strength, 0.2);
assert!(!optimizer.adaptive_ortho);
Ok(())
}
#[test]
fn test_mezov2_creation() {
let optimizer = MeZOV2::new(1e-3, 1e-4, 2);
assert_eq!(optimizer.learning_rate, 1e-3);
assert_eq!(optimizer.perturbation_scale, 1e-4);
assert_eq!(optimizer.num_samples, 2);
}
#[test]
fn test_mezov2_presets() {
let llm_opt = MeZOV2::for_large_language_models();
assert_eq!(llm_opt.learning_rate, 1e-4);
let billion_opt = MeZOV2::for_billion_parameter_models();
assert_eq!(billion_opt.learning_rate, 5e-5);
}
#[test]
fn test_adawin_creation() {
let optimizer = AdaWin::new(1e-3, 0.9, 0.01, 20);
assert_eq!(optimizer.learning_rate, 1e-3);
assert_eq!(optimizer.beta, 0.9);
assert_eq!(optimizer.max_window_size, 20);
}
#[test]
fn test_adawin_presets() {
let transformer_opt = AdaWin::for_transformer_training();
assert_eq!(transformer_opt.learning_rate, 1e-4);
assert_eq!(transformer_opt.max_window_size, 15);
let vision_opt = AdaWin::for_computer_vision();
assert_eq!(vision_opt.learning_rate, 1e-3);
assert_eq!(vision_opt.max_window_size, 25);
}
#[test]
fn test_diwo_update() -> Result<()> {
let mut optimizer = DiWo::new(1e-3, (0.9, 0.999), 1e-8, 0.01, 0.1);
let mut param = Tensor::new(vec![1.0, 2.0, 3.0])?;
let gradient = Tensor::new(vec![0.1, 0.2, 0.1])?;
let original_param = param.data()?;
optimizer.update(&mut param, &gradient)?;
optimizer.step();
let updated_param = param.data()?;
assert_ne!(original_param, updated_param);
assert_eq!(optimizer.step, 1);
Ok(())
}
#[test]
fn test_mezov2_update() -> Result<()> {
let mut optimizer = MeZOV2::new(1e-3, 1e-4, 1);
let mut param = Tensor::new(vec![1.0, 2.0, 3.0])?;
let gradient = Tensor::new(vec![0.1, 0.2, 0.1])?;
let original_param = param.data()?;
optimizer.update(&mut param, &gradient)?;
optimizer.step();
let updated_param = param.data()?;
assert_ne!(original_param, updated_param);
assert_eq!(optimizer.step, 1);
Ok(())
}
#[test]
fn test_adawin_update() -> Result<()> {
let mut optimizer = AdaWin::new(1e-3, 0.9, 0.01, 10);
let mut param = Tensor::new(vec![1.0, 2.0, 3.0])?;
let gradient = Tensor::new(vec![0.1, 0.2, 0.1])?;
let original_param = param.data()?;
optimizer.update(&mut param, &gradient)?;
optimizer.step();
let updated_param = param.data()?;
assert_ne!(original_param, updated_param);
assert_eq!(optimizer.step, 1);
Ok(())
}
}