use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use std::collections::{HashMap, VecDeque};
/// Stateless numerical helpers shared by the optimizer implementations in this module.
///
/// Helpers that read raw element data obtain it through `Tensor::as_slice`. The helpers that
/// return a plain `Tensor` or scalar panic when that slice is unavailable, while `tensor_max`
/// and `tensor_abs` report the same condition as an `InvalidParameters` error.
pub struct OptimizerUtils;

impl OptimizerUtils {
    /// Borrows the raw element data of `tensor`.
    ///
    /// # Panics
    /// Panics with a message naming `operation` if `tensor.as_slice()` returns `None`.
    fn data_of<'a>(tensor: &'a Tensor<f32>, operation: &str) -> &'a [f32] {
        match tensor.as_slice() {
            Some(data) => data,
            None => panic!(
                "OptimizerUtils::{}: tensor data is not available as a slice",
                operation
            ),
        }
    }

    /// Returns a copy of `tensor` with every element clamped to `[min_val, max_val]`.
    ///
    /// A NaN element becomes `min_val` (for `min_val <= max_val`), because `f32::max` returns
    /// the non-NaN operand.
    ///
    /// # Panics
    /// Panics if the tensor's data is not available as a slice.
    pub fn clamp_for_stability(tensor: &Tensor<f32>, min_val: f32, max_val: f32) -> Tensor<f32> {
        let clamped_data: Vec<f32> = Self::data_of(tensor, "clamp_for_stability")
            .iter()
            .map(|&x| x.max(min_val).min(max_val))
            .collect();
        Tensor::from_vec(clamped_data, tensor.shape().to_vec())
    }

    /// Rescales `grad` so that its L2 norm does not exceed `max_norm`.
    ///
    /// If the norm is already `<= max_norm`, a clone of `grad` is returned unchanged. The same
    /// happens if the norm is NaN, because NaN fails the comparison.
    pub fn clip_gradient_norm(grad: &Tensor<f32>, max_norm: f32) -> Tensor<f32> {
        let grad_norm = Self::l2_norm(grad);
        if grad_norm > max_norm {
            grad * (max_norm / grad_norm)
        } else {
            grad.clone()
        }
    }

    /// Clips every gradient element to `[min_val, max_val]` (see [`Self::clamp_for_stability`]).
    pub fn clip_gradient_value(grad: &Tensor<f32>, min_val: f32, max_val: f32) -> Tensor<f32> {
        Self::clamp_for_stability(grad, min_val, max_val)
    }

    /// Element-wise `sqrt(x + eps)`.
    ///
    /// Adding `eps` before the root keeps zero entries strictly positive. Entries below `-eps`
    /// still yield NaN.
    ///
    /// # Panics
    /// Panics if the tensor's data is not available as a slice.
    pub fn stable_sqrt(tensor: &Tensor<f32>, eps: f32) -> Tensor<f32> {
        let sqrt_data: Vec<f32> = Self::data_of(tensor, "stable_sqrt")
            .iter()
            .map(|&x| (x + eps).sqrt())
            .collect();
        Tensor::from_vec(sqrt_data, tensor.shape().to_vec())
    }

    /// Updates `current` in place to `decay * current + (1 - decay) * new_value`.
    pub fn ema_update(current: &mut Tensor<f32>, new_value: &Tensor<f32>, decay: f32) {
        let decay_term = &*current * decay;
        let new_term = new_value * (1.0 - decay);
        let updated = &decay_term + &new_term;
        current.copy_from(&updated);
    }

    /// Returns `true` if any element is NaN or +/-infinity.
    ///
    /// # Panics
    /// Panics if the tensor's data is not available as a slice.
    pub fn has_invalid_values(tensor: &Tensor<f32>) -> bool {
        Self::data_of(tensor, "has_invalid_values")
            .iter()
            .any(|x| !x.is_finite())
    }

    /// Returns a copy of `tensor` with non-finite elements replaced.
    ///
    /// NaN becomes `nan_replacement`, `+inf` becomes `inf_replacement` and `-inf` becomes
    /// `-inf_replacement`. Finite elements are copied unchanged.
    ///
    /// # Panics
    /// Panics if the tensor's data is not available as a slice.
    pub fn sanitize_tensor(
        tensor: &Tensor<f32>,
        nan_replacement: f32,
        inf_replacement: f32,
    ) -> Tensor<f32> {
        let sanitized_data: Vec<f32> = Self::data_of(tensor, "sanitize_tensor")
            .iter()
            .map(|&x| {
                if x.is_nan() {
                    nan_replacement
                } else if x.is_infinite() {
                    // Preserve the sign of the infinity.
                    if x > 0.0 {
                        inf_replacement
                    } else {
                        -inf_replacement
                    }
                } else {
                    x
                }
            })
            .collect();
        Tensor::from_vec(sanitized_data, tensor.shape().to_vec())
    }

    /// Euclidean (L2) norm: `sqrt(sum(x^2))`.
    pub fn l2_norm(tensor: &Tensor<f32>) -> f32 {
        let squared = tensor * tensor;
        squared.sum().sqrt()
    }

    /// Manhattan (L1) norm: `sum(|x|)`.
    ///
    /// # Panics
    /// Panics if the tensor's data is not available as a slice.
    pub fn l1_norm(tensor: &Tensor<f32>) -> f32 {
        Self::data_of(tensor, "l1_norm")
            .iter()
            .map(|x| x.abs())
            .sum()
    }

    /// Multiplicative weight decay applied directly to a parameter: `param * (1 - weight_decay)`.
    pub fn apply_weight_decay(param: &Tensor<f32>, weight_decay: f32) -> Tensor<f32> {
        param * (1.0 - weight_decay)
    }

    /// L2-style weight decay folded into the gradient: `grad + weight_decay * param`.
    ///
    /// When `weight_decay <= 0`, a clone of `grad` is returned.
    pub fn apply_weight_decay_to_grad(
        grad: &Tensor<f32>,
        param: &Tensor<f32>,
        weight_decay: f32,
    ) -> Tensor<f32> {
        if weight_decay > 0.0 {
            let weight_decay_term = param * weight_decay;
            grad + &weight_decay_term
        } else {
            grad.clone()
        }
    }

    /// Updates the first-moment buffer in place: `momentum = beta1 * momentum + (1 - beta1) * grad`.
    ///
    /// # Errors
    /// Returns `InvalidParameters` if `beta1` is outside `[0, 1)`, NaN included. It also does so
    /// if the update would produce NaN or infinite values; `momentum` is left unchanged in
    /// that case.
    pub fn update_momentum(
        momentum: &mut Tensor<f32>,
        grad: &Tensor<f32>,
        beta1: f32,
    ) -> RusTorchResult<()> {
        if !(0.0..1.0).contains(&beta1) {
            return Err(RusTorchError::InvalidParameters {
                operation: "momentum update".to_string(),
                message: format!("Beta1 must be in [0, 1), got {}", beta1),
            });
        }
        let beta1_term = &*momentum * beta1;
        let grad_term = grad * (1.0 - beta1);
        let updated = &beta1_term + &grad_term;
        if Self::has_invalid_values(&updated) {
            return Err(RusTorchError::InvalidParameters {
                operation: "momentum update".to_string(),
                message: "Numerical instability detected in momentum update".to_string(),
            });
        }
        momentum.copy_from(&updated);
        Ok(())
    }

    /// Updates the second-moment buffer in place: `velocity = beta2 * velocity + (1 - beta2) * grad^2`.
    ///
    /// # Errors
    /// Returns `InvalidParameters` if `beta2` is outside `[0, 1)`, NaN included. It also does so
    /// if the update would produce NaN or infinite values; `velocity` is left unchanged in
    /// that case.
    pub fn update_velocity(
        velocity: &mut Tensor<f32>,
        grad: &Tensor<f32>,
        beta2: f32,
    ) -> RusTorchResult<()> {
        if !(0.0..1.0).contains(&beta2) {
            return Err(RusTorchError::InvalidParameters {
                operation: "velocity update".to_string(),
                message: format!("Beta2 must be in [0, 1), got {}", beta2),
            });
        }
        let beta2_term = &*velocity * beta2;
        let grad_squared = grad * grad;
        let grad_term = &grad_squared * (1.0 - beta2);
        let updated = &beta2_term + &grad_term;
        if Self::has_invalid_values(&updated) {
            return Err(RusTorchError::InvalidParameters {
                operation: "velocity update".to_string(),
                message: "Numerical instability detected in velocity update".to_string(),
            });
        }
        velocity.copy_from(&updated);
        Ok(())
    }

    /// Adam-style bias-correction factor `1 - beta^step`.
    ///
    /// Returns `1.0` when `step == 0` or `beta == 0`.
    pub fn bias_correction(beta: f32, step: usize) -> f32 {
        if step == 0 || beta == 0.0 {
            return 1.0;
        }
        // `powi` takes an `i32`. A plain `step as i32` wraps large steps to negative exponents,
        // giving -inf or negative corrections, so saturate instead. For beta in [0, 1],
        // beta^i32::MAX is already the limiting value (0 or 1).
        let exponent = step.min(i32::MAX as usize) as i32;
        1.0 - beta.powi(exponent)
    }

    /// Divides `tensor` by `correction`.
    ///
    /// If `|correction| < 1e-12`, a clone of `tensor` is returned unchanged to avoid dividing
    /// by (nearly) zero.
    pub fn apply_bias_correction(tensor: &Tensor<f32>, correction: f32) -> Tensor<f32> {
        if correction.abs() < 1e-12 {
            tensor.clone()
        } else {
            tensor / correction
        }
    }

    /// Adam update direction `momentum / sqrt(velocity + eps)`.
    ///
    /// Note that `eps` is added *inside* the square root (see [`Self::stable_sqrt`]), not to
    /// `sqrt(velocity)`.
    pub fn compute_adam_update(
        momentum: &Tensor<f32>,
        velocity: &Tensor<f32>,
        eps: f32,
    ) -> Tensor<f32> {
        let denominator = &Self::stable_sqrt(velocity, eps);
        momentum / denominator
    }

    /// Element-wise maximum of two tensors with identical shapes.
    ///
    /// # Errors
    /// Returns `InvalidParameters` if either tensor's data is not available as a slice, or if
    /// the shapes (or element counts) differ.
    pub fn tensor_max(a: &Tensor<f32>, b: &Tensor<f32>) -> RusTorchResult<Tensor<f32>> {
        let a_data = a
            .as_slice()
            .ok_or_else(|| RusTorchError::InvalidParameters {
                operation: "tensor_max".to_string(),
                message: "Failed to access tensor A data".to_string(),
            })?;
        let b_data = b
            .as_slice()
            .ok_or_else(|| RusTorchError::InvalidParameters {
                operation: "tensor_max".to_string(),
                message: "Failed to access tensor B data".to_string(),
            })?;
        if a_data.len() != b_data.len() || a.shape() != b.shape() {
            return Err(RusTorchError::InvalidParameters {
                operation: "tensor_max".to_string(),
                message: "Tensor shapes must match".to_string(),
            });
        }
        let max_data: Vec<f32> = a_data
            .iter()
            .zip(b_data.iter())
            .map(|(&a_val, &b_val)| a_val.max(b_val))
            .collect();
        Ok(Tensor::from_vec(max_data, a.shape().to_vec()))
    }

    /// Element-wise absolute value.
    ///
    /// # Errors
    /// Returns `InvalidParameters` if the tensor's data is not available as a slice.
    pub fn tensor_abs(tensor: &Tensor<f32>) -> RusTorchResult<Tensor<f32>> {
        let data = tensor
            .as_slice()
            .ok_or_else(|| RusTorchError::InvalidParameters {
                operation: "tensor_abs".to_string(),
                message: "Failed to access tensor data".to_string(),
            })?;
        let abs_data: Vec<f32> = data.iter().map(|&x| x.abs()).collect();
        Ok(Tensor::from_vec(abs_data, tensor.shape().to_vec()))
    }

    /// EMA update whose decay ramps up linearly during a warm-up period.
    ///
    /// While `step < warmup_steps`, the effective decay is `base_decay * step / warmup_steps`,
    /// so early updates track `new_value` closely. After warm-up it is `base_decay`.
    ///
    /// # Errors
    /// Returns `InvalidParameters` if the update would produce NaN or infinite values;
    /// `current` is left unchanged in that case.
    pub fn advanced_ema_update(
        current: &mut Tensor<f32>,
        new_value: &Tensor<f32>,
        base_decay: f32,
        step: usize,
        warmup_steps: usize,
    ) -> RusTorchResult<()> {
        // `step < warmup_steps` implies `warmup_steps > 0`, so the division is safe.
        let effective_decay = if step < warmup_steps {
            let warmup_factor = (step as f32) / (warmup_steps as f32);
            base_decay * warmup_factor
        } else {
            base_decay
        };
        let decay_term = &*current * effective_decay;
        let new_term = new_value * (1.0 - effective_decay);
        let updated = &decay_term + &new_term;
        if Self::has_invalid_values(&updated) {
            return Err(RusTorchError::InvalidParameters {
                operation: "advanced_ema_update".to_string(),
                message: "Numerical instability detected".to_string(),
            });
        }
        current.copy_from(&updated);
        Ok(())
    }

    /// Cosine-annealed learning rate going from `base_lr` (step 0) to `min_lr` (step `total_steps`).
    ///
    /// Steps past `total_steps` are clamped to `min_lr`. If `total_steps == 0`, returns `base_lr`.
    pub fn cosine_annealing_lr(
        base_lr: f32,
        current_step: usize,
        total_steps: usize,
        min_lr: f32,
    ) -> f32 {
        if total_steps == 0 {
            return base_lr;
        }
        let progress = (current_step.min(total_steps) as f32) / (total_steps as f32);
        let cosine_factor = 0.5 * (1.0 + (std::f32::consts::PI * progress).cos());
        min_lr + (base_lr - min_lr) * cosine_factor
    }

    /// Scales `trust_ratio` by `min(grad_norm / velocity_norm, 1)`.
    ///
    /// Returns `trust_ratio` unchanged when `velocity_norm < 1e-12`.
    pub fn adaptive_lr_scaling(grad_norm: f32, velocity_norm: f32, trust_ratio: f32) -> f32 {
        if velocity_norm < 1e-12 {
            trust_ratio
        } else {
            trust_ratio * (grad_norm / velocity_norm).min(1.0)
        }
    }
}
/// Per-parameter optimizer buffers keyed by a caller-chosen parameter id, plus a global step
/// counter used to evict states that have not been accessed recently.
#[derive(Debug, Clone)]
pub struct OptimizerState {
    // Buffers for each parameter id.
    states: HashMap<usize, ParameterState>,
    // Number of `step()` calls so far.
    global_step: usize,
    // Memory budget in MiB passed to `new`. It is stored and exposed through a getter, but it
    // is not enforced by this type.
    memory_threshold_mb: usize,
}

/// Optimizer buffers tracked for a single parameter.
#[derive(Debug, Clone)]
pub struct ParameterState {
    /// First-moment buffer; `None` until allocated (e.g. by `OptimizerState::init_momentum`).
    pub momentum: Option<Tensor<f32>>,
    /// Second-moment buffer; `None` until allocated (e.g. by `OptimizerState::init_velocity`).
    pub velocity: Option<Tensor<f32>>,
    /// Additional named buffers for optimizer-specific state.
    pub extra_state: HashMap<String, Tensor<f32>>,
    /// Global step at which this state was last accessed through its `OptimizerState`.
    /// `OptimizerState::cleanup_stale_states` uses it to decide eviction.
    pub last_step: usize,
}

impl OptimizerState {
    /// Number of steps between automatic sweeps for stale states.
    const CLEANUP_INTERVAL: usize = 1000;
    /// States not accessed for at least this many steps are evicted by the automatic sweep.
    const STALE_AFTER_STEPS: usize = 5000;

    /// Creates an empty store at global step 0 with the given memory budget (MiB).
    pub fn new(memory_threshold_mb: usize) -> Self {
        Self {
            states: HashMap::new(),
            global_step: 0,
            memory_threshold_mb,
        }
    }

    /// Memory budget (MiB) supplied to [`Self::new`].
    pub fn memory_threshold_mb(&self) -> usize {
        self.memory_threshold_mb
    }

    /// Returns the state for `param_id`, creating an empty one if none exists.
    ///
    /// Every call records the current global step as the state's `last_step`, which marks it
    /// as in use so the periodic cleanup does not evict it. `_param_shape` is accepted for API
    /// compatibility; buffers are allocated by `init_momentum` / `init_velocity`.
    pub fn get_or_create_state(
        &mut self,
        param_id: usize,
        _param_shape: &[usize],
    ) -> &mut ParameterState {
        let current_step = self.global_step;
        let state = self
            .states
            .entry(param_id)
            .or_insert_with(|| ParameterState {
                momentum: None,
                velocity: None,
                extra_state: HashMap::new(),
                last_step: current_step,
            });
        state.last_step = current_step;
        state
    }

    /// Ensures `param_id` has a zero-filled momentum buffer of shape `param_shape`.
    ///
    /// Creates the parameter state if necessary. An existing buffer is left untouched.
    pub fn init_momentum(&mut self, param_id: usize, param_shape: &[usize]) {
        let state = self.get_or_create_state(param_id, param_shape);
        if state.momentum.is_none() {
            state.momentum = Some(Tensor::zeros(param_shape));
        }
    }

    /// Ensures `param_id` has a zero-filled velocity buffer of shape `param_shape`.
    ///
    /// Creates the parameter state if necessary. An existing buffer is left untouched.
    pub fn init_velocity(&mut self, param_id: usize, param_shape: &[usize]) {
        let state = self.get_or_create_state(param_id, param_shape);
        if state.velocity.is_none() {
            state.velocity = Some(Tensor::zeros(param_shape));
        }
    }

    /// Removes every state whose last access is `steps_threshold` or more steps in the past.
    pub fn cleanup_stale_states(&mut self, steps_threshold: usize) {
        let current_step = self.global_step;
        // `last_step` is public, so it may exceed the global step; saturate instead of
        // underflowing (which panics in debug builds).
        self.states
            .retain(|_, state| current_step.saturating_sub(state.last_step) < steps_threshold);
    }

    /// Advances the global step. Every `CLEANUP_INTERVAL` (1000) steps, it evicts states that
    /// have not been accessed for `STALE_AFTER_STEPS` (5000) steps.
    pub fn step(&mut self) {
        self.global_step += 1;
        if self.global_step % Self::CLEANUP_INTERVAL == 0 {
            self.cleanup_stale_states(Self::STALE_AFTER_STEPS);
        }
    }

    /// Current global step.
    pub fn get_step(&self) -> usize {
        self.global_step
    }

    /// Approximate memory held by all buffers, in whole MiB (rounded down).
    ///
    /// Element counts come from each tensor's shape, so tensors whose data is not exposed as a
    /// slice are counted too.
    pub fn estimate_memory_mb(&self) -> usize {
        fn element_count(tensor: &Tensor<f32>) -> usize {
            tensor.shape().iter().product()
        }
        let total_elements: usize = self
            .states
            .values()
            .flat_map(|state| {
                state
                    .momentum
                    .iter()
                    .chain(state.velocity.iter())
                    .chain(state.extra_state.values())
            })
            .map(element_count)
            .sum();
        total_elements * std::mem::size_of::<f32>() / (1024 * 1024)
    }
}
/// Settings that control how gradients and parameters are sanitized for numerical stability.
#[derive(Debug, Clone)]
pub struct StabilityConfig {
    /// Smallest epsilon intended for denominators (not consulted by the methods in this impl).
    pub min_eps: f32,
    /// L2-norm ceiling for gradient clipping; also the magnitude substituted for +/-inf gradients.
    pub max_grad_norm: f32,
    /// Magnitude bound for parameters; also the magnitude substituted for +/-inf parameters.
    pub max_param_value: f32,
    /// When `true`, non-finite elements are sanitized (NaN -> 0, +/-inf -> +/- the relevant bound).
    pub auto_nan_correction: bool,
    /// When `true`, `stabilize_gradient` rescales gradients to L2 norm <= `max_grad_norm`.
    pub gradient_clipping: bool,
}

impl Default for StabilityConfig {
    /// Conservative defaults: `min_eps = 1e-8`, `max_grad_norm = 10`, `max_param_value = 1e6`,
    /// with both NaN correction and gradient clipping enabled.
    fn default() -> Self {
        let enabled = true;
        Self {
            max_grad_norm: 10.0,
            max_param_value: 1e6,
            min_eps: 1e-8,
            gradient_clipping: enabled,
            auto_nan_correction: enabled,
        }
    }
}

impl StabilityConfig {
    /// Returns a stabilized copy of `grad`.
    ///
    /// If `auto_nan_correction` is set and `grad` contains NaN or +/-inf, those elements are
    /// replaced (NaN -> 0, +/-inf -> +/-`max_grad_norm`). If `gradient_clipping` is set, the
    /// result is then rescaled so that its L2 norm is at most `max_grad_norm`.
    pub fn stabilize_gradient(&self, grad: &Tensor<f32>) -> Tensor<f32> {
        let working = grad.clone();
        let working =
            if self.auto_nan_correction && OptimizerUtils::has_invalid_values(&working) {
                OptimizerUtils::sanitize_tensor(&working, 0.0, self.max_grad_norm)
            } else {
                working
            };
        if !self.gradient_clipping {
            return working;
        }
        OptimizerUtils::clip_gradient_norm(&working, self.max_grad_norm)
    }

    /// Returns a stabilized copy of `param`.
    ///
    /// If `auto_nan_correction` is set and `param` contains NaN or +/-inf, those elements are
    /// replaced (NaN -> 0, +/-inf -> +/-`max_param_value`). The result is always clamped to
    /// `[-max_param_value, max_param_value]`.
    pub fn stabilize_parameter(&self, param: &Tensor<f32>) -> Tensor<f32> {
        let bound = self.max_param_value;
        let working = param.clone();
        let working =
            if self.auto_nan_correction && OptimizerUtils::has_invalid_values(&working) {
                OptimizerUtils::sanitize_tensor(&working, 0.0, bound)
            } else {
                working
            };
        OptimizerUtils::clamp_for_stability(&working, -bound, bound)
    }
}
#[derive(Debug, Clone)]
pub struct OptimizerMetrics {
gradient_norms: Vec<f32>,
param_change_norms: Vec<f32>,
learning_rates: Vec<f32>,
step_times: Vec<f32>,
max_history: usize,
step_count: usize,
}
impl OptimizerMetrics {
pub fn new(max_history: usize) -> Self {
Self {
gradient_norms: Vec::with_capacity(max_history),
param_change_norms: Vec::with_capacity(max_history),
learning_rates: Vec::with_capacity(max_history),
step_times: Vec::with_capacity(max_history),
max_history,
step_count: 0,
}
}
pub fn record_step(
&mut self,
grad_norm: f32,
param_change_norm: f32,
learning_rate: f32,
step_time: f32,
) {
self.gradient_norms.push(grad_norm);
self.param_change_norms.push(param_change_norm);
self.learning_rates.push(learning_rate);
self.step_times.push(step_time);
if self.gradient_norms.len() > self.max_history {
self.gradient_norms.remove(0);
self.param_change_norms.remove(0);
self.learning_rates.remove(0);
self.step_times.remove(0);
}
self.step_count += 1;
}
pub fn gradient_stats(&self) -> (f32, f32, f32) {
if self.gradient_norms.is_empty() {
return (0.0, 0.0, 0.0);
}
let mean = self.gradient_norms.iter().sum::<f32>() / self.gradient_norms.len() as f32;
let min = *self
.gradient_norms
.iter()
.min_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap();
let max = *self
.gradient_norms
.iter()
.max_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap();
(mean, min, max)
}
pub fn check_convergence(&self, threshold: f32, window_size: usize) -> bool {
if self.gradient_norms.len() < window_size {
return false;
}
let recent_norms =
&self.gradient_norms[self.gradient_norms.len().saturating_sub(window_size)..];
let avg_norm = recent_norms.iter().sum::<f32>() / recent_norms.len() as f32;
avg_norm < threshold
}
pub fn detect_issues(&self) -> Vec<String> {
let mut issues = Vec::new();
if let Some(&latest_grad_norm) = self.gradient_norms.last() {
if latest_grad_norm > 100.0 {
issues.push("Gradient explosion detected".to_string());
}
if latest_grad_norm < 1e-8 {
issues.push("Vanishing gradients detected".to_string());
}
}
if self.gradient_norms.len() > 10 {
let recent = &self.gradient_norms[self.gradient_norms.len() - 10..];
let variance = {
let mean = recent.iter().sum::<f32>() / recent.len() as f32;
recent.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / recent.len() as f32
};
if variance < 1e-12 {
issues.push("Optimization stagnation detected".to_string());
}
}
issues
}
pub fn step_count(&self) -> usize {
self.step_count
}
pub fn reset(&mut self) {
self.gradient_norms.clear();
self.param_change_norms.clear();
self.learning_rates.clear();
self.step_times.clear();
self.step_count = 0;
}
}
pub struct OptimizerFactory;
impl OptimizerFactory {
pub fn validate_config(learning_rate: f32, weight_decay: f32, eps: f32) -> RusTorchResult<()> {
if learning_rate <= 0.0 {
return Err(RusTorchError::InvalidParameters {
operation: "optimizer creation".to_string(),
message: "Learning rate must be positive".to_string(),
});
}
if weight_decay < 0.0 {
return Err(RusTorchError::InvalidParameters {
operation: "optimizer creation".to_string(),
message: "Weight decay must be non-negative".to_string(),
});
}
if eps <= 0.0 {
return Err(RusTorchError::InvalidParameters {
operation: "optimizer creation".to_string(),
message: "Epsilon must be positive".to_string(),
});
}
Ok(())
}
pub fn suggest_parameters(problem_type: &str, model_size: usize) -> (f32, f32, f32) {
match problem_type {
"vision" => {
if model_size > 50_000_000 {
(1e-4, 1e-4, 1e-8)
} else {
(1e-3, 1e-4, 1e-8)
}
}
"nlp" => {
if model_size > 100_000_000 {
(5e-5, 1e-2, 1e-6)
} else {
(2e-4, 1e-3, 1e-8)
}
}
"reinforcement_learning" => (3e-4, 0.0, 1e-5),
"fine_tuning" => (5e-5, 1e-5, 1e-8),
_ => (1e-3, 1e-4, 1e-8), }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_clamp_for_stability() {
        let tensor = Tensor::from_vec(vec![-10.0, 0.0, 5.0, 15.0], vec![4]);
        let clamped = OptimizerUtils::clamp_for_stability(&tensor, -5.0, 10.0);
        let data = clamped.as_slice().unwrap();
        assert_eq!(data, &[-5.0, 0.0, 5.0, 10.0]);
    }

    #[test]
    fn test_clip_gradient_norm() {
        // Norm of [3, 4] is 5; clipping to 2 scales by 0.4.
        let grad = Tensor::from_vec(vec![3.0, 4.0], vec![2]);
        let clipped = OptimizerUtils::clip_gradient_norm(&grad, 2.0);
        let clipped_data = clipped.as_slice().unwrap();
        assert!((clipped_data[0] - 1.2).abs() < 1e-5);
        assert!((clipped_data[1] - 1.6).abs() < 1e-5);
    }

    #[test]
    fn test_stable_sqrt() {
        let tensor = Tensor::from_vec(vec![0.0, 4.0, 9.0], vec![3]);
        let sqrt_tensor = OptimizerUtils::stable_sqrt(&tensor, 1e-8);
        let data = sqrt_tensor.as_slice().unwrap();
        assert!((data[0] - (1e-8_f32).sqrt()).abs() < 1e-10);
        assert!((data[1] - 2.0).abs() < 1e-6);
        assert!((data[2] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_l2_norm() {
        let tensor = Tensor::from_vec(vec![3.0, 4.0], vec![2]);
        let norm = OptimizerUtils::l2_norm(&tensor);
        assert!((norm - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_l1_norm() {
        let tensor = Tensor::from_vec(vec![-3.0, 4.0], vec![2]);
        let norm = OptimizerUtils::l1_norm(&tensor);
        assert!((norm - 7.0).abs() < 1e-6);
    }

    #[test]
    fn test_sanitize_tensor() {
        let tensor = Tensor::from_vec(vec![1.0, f32::NAN, f32::INFINITY, -f32::INFINITY], vec![4]);
        let sanitized = OptimizerUtils::sanitize_tensor(&tensor, 0.0, 100.0);
        let data = sanitized.as_slice().unwrap();
        assert_eq!(data[0], 1.0);
        assert_eq!(data[1], 0.0);
        assert_eq!(data[2], 100.0);
        assert_eq!(data[3], -100.0);
    }

    #[test]
    fn test_optimizer_state_management() {
        let mut state = OptimizerState::new(100);
        let param_id = 12345;
        let param_shape = vec![2, 3];
        state.get_or_create_state(param_id, &param_shape);
        state.init_momentum(param_id, &param_shape);
        state.init_velocity(param_id, &param_shape);
        let param_state = state.get_or_create_state(param_id, &param_shape);
        assert!(param_state.momentum.is_some());
        assert!(param_state.velocity.is_some());
        assert_eq!(state.get_step(), 0);
        state.step();
        assert_eq!(state.get_step(), 1);
    }

    #[test]
    fn test_stability_config() {
        let config = StabilityConfig::default();
        let grad = Tensor::from_vec(vec![100.0, -100.0], vec![2]);
        let stabilized = config.stabilize_gradient(&grad);
        let norm = OptimizerUtils::l2_norm(&stabilized);
        assert!(norm <= config.max_grad_norm + 1e-5);
    }

    #[test]
    fn test_enhanced_momentum_update() {
        let mut momentum = Tensor::zeros(&[2]);
        let grad = Tensor::from_vec(vec![0.1, 0.2], vec![2]);
        let result = OptimizerUtils::update_momentum(&mut momentum, &grad, 0.9);
        assert!(result.is_ok());
        let momentum_data = momentum.as_slice().unwrap();
        assert!((momentum_data[0] - 0.01).abs() < 1e-6);
        assert!((momentum_data[1] - 0.02).abs() < 1e-6);
    }

    #[test]
    fn test_enhanced_velocity_update() {
        let mut velocity = Tensor::zeros(&[2]);
        let grad = Tensor::from_vec(vec![0.1, 0.2], vec![2]);
        let result = OptimizerUtils::update_velocity(&mut velocity, &grad, 0.999);
        assert!(result.is_ok());
        let velocity_data = velocity.as_slice().unwrap();
        assert!((velocity_data[0] - 0.00001).abs() < 1e-6);
        assert!((velocity_data[1] - 0.00004).abs() < 1e-6);
    }

    #[test]
    fn test_tensor_max() {
        let a = Tensor::from_vec(vec![1.0, 5.0, 2.0], vec![3]);
        let b = Tensor::from_vec(vec![3.0, 1.0, 4.0], vec![3]);
        let max_tensor = OptimizerUtils::tensor_max(&a, &b).unwrap();
        let max_data = max_tensor.as_slice().unwrap();
        assert_eq!(max_data, &[3.0, 5.0, 4.0]);
    }

    #[test]
    fn test_tensor_abs() {
        let tensor = Tensor::from_vec(vec![-1.0, 2.0, -3.0, 4.0], vec![4]);
        let abs_tensor = OptimizerUtils::tensor_abs(&tensor).unwrap();
        let abs_data = abs_tensor.as_slice().unwrap();
        assert_eq!(abs_data, &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_bias_correction() {
        assert!((OptimizerUtils::bias_correction(0.9, 1) - 0.1).abs() < 1e-6);
        assert_eq!(OptimizerUtils::bias_correction(0.9, 0), 1.0);
        let correction = OptimizerUtils::bias_correction(0.99, 100);
        assert!(correction > 0.6);
    }

    #[test]
    fn test_cosine_annealing_lr() {
        let base_lr = 1e-3;
        let min_lr = 1e-5;
        let lr_start = OptimizerUtils::cosine_annealing_lr(base_lr, 0, 1000, min_lr);
        assert!((lr_start - base_lr).abs() < 1e-8);
        let lr_middle = OptimizerUtils::cosine_annealing_lr(base_lr, 500, 1000, min_lr);
        assert!(lr_middle < base_lr && lr_middle > min_lr);
        let lr_end = OptimizerUtils::cosine_annealing_lr(base_lr, 1000, 1000, min_lr);
        assert!((lr_end - min_lr).abs() < 1e-8);
    }

    #[test]
    fn test_optimizer_metrics() {
        let mut metrics = OptimizerMetrics::new(100);
        metrics.record_step(1.0, 0.1, 1e-3, 0.01);
        metrics.record_step(0.5, 0.05, 1e-3, 0.012);
        metrics.record_step(0.1, 0.01, 1e-3, 0.009);
        assert_eq!(metrics.step_count(), 3);
        let (mean, min, max) = metrics.gradient_stats();
        assert!((mean - (1.0 + 0.5 + 0.1) / 3.0).abs() < 1e-6);
        assert_eq!(min, 0.1);
        assert_eq!(max, 1.0);
        assert!(metrics.check_convergence(2.0, 3));
        assert!(!metrics.check_convergence(0.05, 3));
    }

    #[test]
    fn test_optimizer_factory_validation() {
        assert!(OptimizerFactory::validate_config(1e-3, 1e-4, 1e-8).is_ok());
        assert!(OptimizerFactory::validate_config(-1e-3, 1e-4, 1e-8).is_err());
        assert!(OptimizerFactory::validate_config(1e-3, -1e-4, 1e-8).is_err());
        assert!(OptimizerFactory::validate_config(1e-3, 1e-4, -1e-8).is_err());
    }

    #[test]
    fn test_parameter_suggestions() {
        let (lr, wd, eps) = OptimizerFactory::suggest_parameters("vision", 10_000_000);
        assert!(lr > 0.0 && wd >= 0.0 && eps > 0.0);
        let (lr_large, _, _) = OptimizerFactory::suggest_parameters("vision", 100_000_000);
        assert!(lr_large < lr);
        let (lr_nlp, _, _) = OptimizerFactory::suggest_parameters("nlp", 500_000_000);
        assert!(lr_nlp > 0.0);
    }
}