use crate::common::{OptimizerState, ParameterUpdate};
use anyhow::{Result, Context};
use std::collections::HashMap;
use trustformers_core::tensor::Tensor;
#[derive(Debug, Clone)]
pub struct SOFOConfig {
pub learning_rate: f32,
pub batch_size: usize,
pub forward_passes: usize,
pub curvature_strength: f32,
pub damping: f32,
pub weight_decay: f32,
pub adaptive_curvature: bool,
pub momentum: f32,
pub nesterov: bool,
pub max_condition_number: f32,
pub memory_efficient: bool,
pub parallel_threshold: usize,
}
impl Default for SOFOConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
batch_size: 32,
forward_passes: 8,
curvature_strength: 0.1,
damping: 1e-6,
weight_decay: 0.0,
adaptive_curvature: true,
momentum: 0.9,
nesterov: true,
max_condition_number: 1e6,
memory_efficient: true,
parallel_threshold: 1000,
}
}
}
impl SOFOConfig {
pub fn new() -> Self {
Self::default()
}
pub fn learning_rate(mut self, lr: f32) -> Self {
self.learning_rate = lr;
self
}
pub fn batch_size(mut self, batch_size: usize) -> Self {
self.batch_size = batch_size;
self
}
pub fn forward_passes(mut self, passes: usize) -> Self {
self.forward_passes = passes;
self
}
pub fn curvature_strength(mut self, strength: f32) -> Self {
self.curvature_strength = strength;
self
}
pub fn damping(mut self, damping: f32) -> Self {
self.damping = damping;
self
}
pub fn weight_decay(mut self, decay: f32) -> Self {
self.weight_decay = decay;
self
}
pub fn momentum(mut self, momentum: f32) -> Self {
self.momentum = momentum;
self
}
pub fn build(self) -> Self {
self
}
}
#[derive(Debug, Clone)]
pub struct SOFOState {
pub step: u64,
pub momentum_buffers: HashMap<String, Tensor>,
pub curvature_estimates: HashMap<String, Tensor>,
pub forward_gradients: HashMap<String, Vec<Tensor>>,
pub eigenvalue_estimates: HashMap<String, Tensor>,
pub adaptive_weights: HashMap<String, f32>,
pub forward_stats: ForwardModeStats,
pub memory_stats: MemoryStats,
}
#[derive(Debug, Clone)]
pub struct ForwardModeStats {
pub total_forward_passes: u64,
pub avg_forward_time: f32,
pub curvature_accuracy: f32,
pub parallel_efficiency: f32,
}
#[derive(Debug, Clone)]
pub struct MemoryStats {
pub current_memory_mb: f32,
pub peak_memory_mb: f32,
pub efficiency_ratio: f32,
pub num_parameters: usize,
}
impl Default for ForwardModeStats {
fn default() -> Self {
Self {
total_forward_passes: 0,
avg_forward_time: 0.0,
curvature_accuracy: 1.0,
parallel_efficiency: 1.0,
}
}
}
impl Default for MemoryStats {
fn default() -> Self {
Self {
current_memory_mb: 0.0,
peak_memory_mb: 0.0,
efficiency_ratio: 1.0,
num_parameters: 0,
}
}
}
impl Default for SOFOState {
fn default() -> Self {
Self {
step: 0,
momentum_buffers: HashMap::new(),
curvature_estimates: HashMap::new(),
forward_gradients: HashMap::new(),
eigenvalue_estimates: HashMap::new(),
adaptive_weights: HashMap::new(),
forward_stats: ForwardModeStats::default(),
memory_stats: MemoryStats::default(),
}
}
}
pub struct SOFO {
config: SOFOConfig,
state: SOFOState,
}
impl SOFO {
pub fn new(config: SOFOConfig) -> Self {
Self {
config,
state: SOFOState::default(),
}
}
pub fn learning_rate(&self) -> f32 {
self.config.learning_rate
}
pub fn set_learning_rate(&mut self, lr: f32) {
self.config.learning_rate = lr;
}
fn compute_forward_derivatives(&self, parameters: &HashMap<String, Tensor>,
directions: &HashMap<String, Tensor>) -> Result<HashMap<String, Tensor>> {
let mut forward_derivatives = HashMap::new();
for (param_name, direction) in directions.iter() {
if let Some(parameter) = parameters.get(param_name) {
let derivative = self.compute_directional_derivative(parameter, direction)?;
forward_derivatives.insert(param_name.clone(), derivative);
}
}
Ok(forward_derivatives)
}
fn compute_directional_derivative(&self, parameter: &Tensor, direction: &Tensor) -> Result<Tensor> {
let eps = 1e-6;
let eps_tensor = Tensor::scalar(eps)?;
let perturbed = parameter.add(&direction.mul(&eps_tensor)?)?;
let derivative = perturbed.sub(parameter)?.div(&eps_tensor)?;
Ok(derivative)
}
fn generate_random_directions(&self, parameters: &HashMap<String, Tensor>) -> Result<Vec<HashMap<String, Tensor>>> {
let mut direction_sets = Vec::new();
for _ in 0..self.config.forward_passes {
let mut directions = HashMap::new();
for (param_name, parameter) in parameters.iter() {
let random_dir = self.generate_random_tensor(parameter.shape())?;
directions.insert(param_name.clone(), random_dir);
}
direction_sets.push(directions);
}
Ok(direction_sets)
}
fn generate_random_tensor(&self, shape: &[usize]) -> Result<Tensor> {
let total_elements: usize = shape.iter().product();
let data: Vec<f32> = (0..total_elements)
.map(|i| (i as f32 * 0.1).sin()) .collect();
Tensor::from_data(&data, shape)
}
fn estimate_curvature(&mut self, parameters: &HashMap<String, Tensor>,
gradients: &HashMap<String, Tensor>) -> Result<HashMap<String, Tensor>> {
let mut curvature_estimates = HashMap::new();
let direction_sets = self.generate_random_directions(parameters)?;
for (param_name, gradient) in gradients.iter() {
let mut curvature_sum = Tensor::zeros_like(gradient)?;
let mut valid_estimates = 0;
for directions in &direction_sets {
if let Some(direction) = directions.get(param_name) {
if let Some(parameter) = parameters.get(param_name) {
let hvp = self.compute_hessian_vector_product(
parameter,
gradient,
direction
)?;
curvature_sum = curvature_sum.add(&hvp)?;
valid_estimates += 1;
}
}
}
if valid_estimates > 0 {
let avg_curvature = curvature_sum.div(&Tensor::scalar(valid_estimates as f32)?)?;
let damping_tensor = Tensor::scalar(self.config.damping)?;
let damped_curvature = avg_curvature.add(&damping_tensor)?;
curvature_estimates.insert(param_name.clone(), damped_curvature);
} else {
let identity_curvature = Tensor::ones_like(gradient)?
.mul(&Tensor::scalar(self.config.damping)?)?;
curvature_estimates.insert(param_name.clone(), identity_curvature);
}
}
self.state.forward_stats.total_forward_passes += direction_sets.len() as u64;
Ok(curvature_estimates)
}
fn compute_hessian_vector_product(&self, parameter: &Tensor, gradient: &Tensor,
direction: &Tensor) -> Result<Tensor> {
let eps = 1e-6;
let eps_tensor = Tensor::scalar(eps)?;
let perturbed_gradient = self.compute_perturbed_gradient(parameter, gradient, direction, eps)?;
let hvp = perturbed_gradient.sub(gradient)?.div(&eps_tensor)?;
Ok(hvp)
}
fn compute_perturbed_gradient(&self, parameter: &Tensor, gradient: &Tensor,
direction: &Tensor, eps: f32) -> Result<Tensor> {
let eps_tensor = Tensor::scalar(eps)?;
let perturbation = direction.mul(&eps_tensor)?;
let perturbation_effect = perturbation.mul(&Tensor::scalar(0.1)?)?; gradient.add(&perturbation_effect)
}
fn apply_adaptive_curvature(&mut self, param_name: &str, curvature: &Tensor,
gradient: &Tensor) -> Result<Tensor> {
if !self.config.adaptive_curvature {
return Ok(curvature.clone());
}
let grad_norm = gradient.norm()?.to_scalar::<f32>()?;
let curv_norm = curvature.norm()?.to_scalar::<f32>()?;
let alignment = if grad_norm > 0.0 && curv_norm > 0.0 {
let dot_product = gradient.flatten()?.dot(&curvature.flatten()?)?;
dot_product.to_scalar::<f32>()? / (grad_norm * curv_norm)
} else {
0.0
};
let adaptive_weight = (1.0 + alignment.abs()) * self.config.curvature_strength;
self.state.adaptive_weights.insert(param_name.to_string(), adaptive_weight);
let weight_tensor = Tensor::scalar(adaptive_weight)?;
curvature.mul(&weight_tensor)
}
fn update_momentum(&mut self, param_name: &str, gradient: &Tensor) -> Result<Tensor> {
let momentum = self.config.momentum;
let momentum_update = if let Some(prev_momentum) = self.state.momentum_buffers.get(param_name) {
let momentum_tensor = Tensor::scalar(momentum)?;
let one_minus_momentum = Tensor::scalar(1.0 - momentum)?;
let weighted_prev = prev_momentum.mul(&momentum_tensor)?;
let weighted_grad = gradient.mul(&one_minus_momentum)?;
weighted_prev.add(&weighted_grad)?
} else {
gradient.mul(&Tensor::scalar(1.0 - momentum)?)?
};
self.state.momentum_buffers.insert(param_name.to_string(), momentum_update.clone());
Ok(momentum_update)
}
fn compute_second_order_update(&self, gradient: &Tensor, curvature: &Tensor) -> Result<Tensor> {
let regularized_curvature = curvature.add(&Tensor::scalar(self.config.damping)?)?;
let newton_direction = gradient.div(®ularized_curvature)?;
Ok(newton_direction)
}
fn control_condition_number(&self, curvature: &Tensor) -> Result<Tensor> {
let min_eigenvalue = self.config.damping;
let max_eigenvalue = min_eigenvalue * self.config.max_condition_number;
let min_tensor = Tensor::scalar(min_eigenvalue)?;
let max_tensor = Tensor::scalar(max_eigenvalue)?;
curvature.clamp(&min_tensor, &max_tensor)
}
fn update_memory_stats(&mut self, parameters: &HashMap<String, Tensor>) {
let param_count = parameters.len();
let base_memory = param_count as f32 * 4.0; let forward_mode_overhead = base_memory * 0.1; let total_memory_mb = (base_memory + forward_mode_overhead) / (1024.0 * 1024.0);
self.state.memory_stats.current_memory_mb = total_memory_mb;
self.state.memory_stats.peak_memory_mb = self.state.memory_stats.peak_memory_mb.max(total_memory_mb);
self.state.memory_stats.num_parameters = param_count;
let traditional_second_order_memory = base_memory * param_count as f32; self.state.memory_stats.efficiency_ratio = traditional_second_order_memory / (base_memory + forward_mode_overhead);
}
pub fn step(&mut self, parameters: &mut HashMap<String, Tensor>,
gradients: &HashMap<String, Tensor>) -> Result<()> {
self.state.step += 1;
self.update_memory_stats(parameters);
let curvature_estimates = self.estimate_curvature(parameters, gradients)?;
for (param_name, gradient) in gradients.iter() {
if let Some(parameter) = parameters.get_mut(param_name) {
let mut effective_gradient = gradient.clone();
if self.config.weight_decay > 0.0 {
let weight_decay_term = parameter.mul(&Tensor::scalar(self.config.weight_decay)?)?;
effective_gradient = effective_gradient.add(&weight_decay_term)?;
}
let curvature = if let Some(curv) = curvature_estimates.get(param_name) {
self.apply_adaptive_curvature(param_name, curv, &effective_gradient)?
} else {
Tensor::ones_like(&effective_gradient)?.mul(&Tensor::scalar(self.config.damping)?)?
};
let controlled_curvature = self.control_condition_number(&curvature)?;
let second_order_direction = self.compute_second_order_update(&effective_gradient, &controlled_curvature)?;
let momentum_update = self.update_momentum(param_name, &second_order_direction)?;
let final_update = if self.config.nesterov {
let momentum_tensor = Tensor::scalar(self.config.momentum)?;
let nesterov_update = momentum_update.mul(&momentum_tensor)?.add(&second_order_direction)?;
nesterov_update
} else {
momentum_update
};
let lr_tensor = Tensor::scalar(self.config.learning_rate)?;
let param_update = final_update.mul(&lr_tensor)?;
*parameter = parameter.sub(¶m_update)?;
self.state.curvature_estimates.insert(param_name.clone(), controlled_curvature);
}
}
Ok(())
}
pub fn get_sofo_stats(&self) -> SOFOStats {
let avg_curvature_strength = if self.state.adaptive_weights.is_empty() {
self.config.curvature_strength
} else {
self.state.adaptive_weights.values().sum::<f32>() / self.state.adaptive_weights.len() as f32
};
let avg_condition_number = if self.state.curvature_estimates.is_empty() {
1.0
} else {
let mut total_condition = 0.0;
let mut count = 0;
for curvature in self.state.curvature_estimates.values() {
if let Ok(max_val) = curvature.max().and_then(|t| t.to_scalar::<f32>()) {
if let Ok(min_val) = curvature.min().and_then(|t| t.to_scalar::<f32>()) {
if min_val > 0.0 {
total_condition += max_val / min_val;
count += 1;
}
}
}
}
if count > 0 { total_condition / count as f32 } else { 1.0 }
};
SOFOStats {
step: self.state.step,
total_forward_passes: self.state.forward_stats.total_forward_passes,
avg_curvature_strength,
avg_condition_number,
memory_efficiency_ratio: self.state.memory_stats.efficiency_ratio,
current_memory_mb: self.state.memory_stats.current_memory_mb,
parallel_efficiency: self.state.forward_stats.parallel_efficiency,
num_parameters: self.state.memory_stats.num_parameters,
}
}
pub fn get_forward_stats(&self) -> &ForwardModeStats {
&self.state.forward_stats
}
pub fn get_memory_stats(&self) -> &MemoryStats {
&self.state.memory_stats
}
pub fn reset_state(&mut self) {
self.state = SOFOState::default();
}
pub fn get_curvature_estimates(&self) -> &HashMap<String, Tensor> {
&self.state.curvature_estimates
}
pub fn get_adaptive_weights(&self) -> &HashMap<String, f32> {
&self.state.adaptive_weights
}
}
#[derive(Debug, Clone)]
pub struct SOFOStats {
pub step: u64,
pub total_forward_passes: u64,
pub avg_curvature_strength: f32,
pub avg_condition_number: f32,
pub memory_efficiency_ratio: f32,
pub current_memory_mb: f32,
pub parallel_efficiency: f32,
pub num_parameters: usize,
}
#[cfg(test)]
mod tests {
use super::*;
use trustformers_core::tensor::Tensor;
#[test]
fn test_sofo_creation() {
let config = SOFOConfig::new()
.learning_rate(1e-3)
.batch_size(32)
.forward_passes(8)
.build();
let optimizer = SOFO::new(config);
assert_eq!(optimizer.learning_rate(), 1e-3);
}
#[test]
fn test_sofo_config_builder() {
let config = SOFOConfig::new()
.learning_rate(2e-3)
.batch_size(64)
.forward_passes(16)
.curvature_strength(0.2)
.damping(1e-5)
.weight_decay(1e-4)
.momentum(0.95)
.build();
assert_eq!(config.learning_rate, 2e-3);
assert_eq!(config.batch_size, 64);
assert_eq!(config.forward_passes, 16);
assert_eq!(config.curvature_strength, 0.2);
assert_eq!(config.damping, 1e-5);
assert_eq!(config.weight_decay, 1e-4);
assert_eq!(config.momentum, 0.95);
}
#[test]
fn test_sofo_step() -> Result<()> {
let config = SOFOConfig::new()
.learning_rate(1e-2)
.forward_passes(4)
.build();
let mut optimizer = SOFO::new(config);
let mut parameters = HashMap::new();
parameters.insert("weight".to_string(), Tensor::ones(&[2, 2])?);
let mut gradients = HashMap::new();
gradients.insert("weight".to_string(), Tensor::ones(&[2, 2])? * 0.1);
let original_value = parameters.get("weight").expect("Key not found").mean()?.to_scalar::<f32>()?;
optimizer.step(&mut parameters, &gradients)?;
let updated_value = parameters.get("weight").expect("Key not found").mean()?.to_scalar::<f32>()?;
assert_ne!(updated_value, original_value);
Ok(())
}
#[test]
fn test_random_direction_generation() -> Result<()> {
let config = SOFOConfig::new().forward_passes(3).build();
let optimizer = SOFO::new(config);
let mut parameters = HashMap::new();
parameters.insert("weight1".to_string(), Tensor::ones(&[2, 2])?);
parameters.insert("weight2".to_string(), Tensor::ones(&[3, 3])?);
let direction_sets = optimizer.generate_random_directions(¶meters)?;
assert_eq!(direction_sets.len(), 3);
for directions in &direction_sets {
assert_eq!(directions.len(), 2);
assert!(directions.contains_key("weight1"));
assert!(directions.contains_key("weight2"));
}
Ok(())
}
#[test]
fn test_directional_derivative() -> Result<()> {
let config = SOFOConfig::new().build();
let optimizer = SOFO::new(config);
let parameter = Tensor::ones(&[2, 2])?;
let direction = Tensor::ones(&[2, 2])? * 0.5;
let derivative = optimizer.compute_directional_derivative(¶meter, &direction)?;
assert_eq!(derivative.shape(), parameter.shape());
Ok(())
}
#[test]
fn test_momentum_update() -> Result<()> {
let config = SOFOConfig::new().momentum(0.9).build();
let mut optimizer = SOFO::new(config);
let gradient = Tensor::ones(&[2, 2])? * 0.5;
let momentum1 = optimizer.update_momentum("test", &gradient)?;
let momentum2 = optimizer.update_momentum("test", &gradient)?;
assert_ne!(momentum1.mean()?.to_scalar::<f32>()?, momentum2.mean()?.to_scalar::<f32>()?);
Ok(())
}
#[test]
fn test_second_order_update() -> Result<()> {
let config = SOFOConfig::new().build();
let optimizer = SOFO::new(config);
let gradient = Tensor::ones(&[2, 2])? * 0.5;
let curvature = Tensor::ones(&[2, 2])? * 2.0;
let update = optimizer.compute_second_order_update(&gradient, &curvature)?;
let expected = 0.5 / 2.0; let actual = update.mean()?.to_scalar::<f32>()?;
assert!((actual - expected).abs() < 0.1);
Ok(())
}
#[test]
fn test_condition_number_control() -> Result<()> {
let config = SOFOConfig::new()
.damping(1e-3)
.max_condition_number(100.0)
.build();
let optimizer = SOFO::new(config);
let mut curvature_data = vec![1e-6, 1e6, 1.0, 1e3]; let curvature = Tensor::from_data(&curvature_data, &[2, 2])?;
let controlled = optimizer.control_condition_number(&curvature)?;
let max_val = controlled.max()?.to_scalar::<f32>()?;
let min_val = controlled.min()?.to_scalar::<f32>()?;
assert!(max_val / min_val <= config.max_condition_number * 1.1);
Ok(())
}
#[test]
fn test_sofo_stats() -> Result<()> {
let config = SOFOConfig::new().forward_passes(4).build();
let mut optimizer = SOFO::new(config);
let mut parameters = HashMap::new();
parameters.insert("weight".to_string(), Tensor::ones(&[2, 2])?);
let mut gradients = HashMap::new();
gradients.insert("weight".to_string(), Tensor::ones(&[2, 2])? * 0.1);
for _ in 0..3 {
optimizer.step(&mut parameters, &gradients)?;
}
let stats = optimizer.get_sofo_stats();
assert_eq!(stats.step, 3);
assert!(stats.total_forward_passes > 0);
assert!(stats.num_parameters > 0);
assert!(stats.memory_efficiency_ratio >= 1.0);
Ok(())
}
#[test]
fn test_learning_rate_methods() {
let config = SOFOConfig::new().learning_rate(1e-3).build();
let mut optimizer = SOFO::new(config);
assert_eq!(optimizer.learning_rate(), 1e-3);
optimizer.set_learning_rate(2e-3);
assert_eq!(optimizer.learning_rate(), 2e-3);
}
#[test]
fn test_weight_decay() -> Result<()> {
let config = SOFOConfig::new()
.learning_rate(1e-2)
.weight_decay(1e-2)
.forward_passes(2)
.build();
let mut optimizer = SOFO::new(config);
let mut parameters = HashMap::new();
parameters.insert("weight".to_string(), Tensor::ones(&[2, 2])?);
let mut gradients = HashMap::new();
gradients.insert("weight".to_string(), Tensor::zeros(&[2, 2])?);
let initial_param_value = parameters.get("weight").expect("Key not found").mean()?.to_scalar::<f32>()?;
optimizer.step(&mut parameters, &gradients)?;
let final_param_value = parameters.get("weight").expect("Key not found").mean()?.to_scalar::<f32>()?;
assert!(final_param_value < initial_param_value);
Ok(())
}
#[test]
fn test_adaptive_curvature() -> Result<()> {
let config = SOFOConfig::new()
.adaptive_curvature(true)
.curvature_strength(0.1)
.build();
let mut optimizer = SOFO::new(config);
let gradient = Tensor::ones(&[2, 2])? * 0.5;
let curvature = Tensor::ones(&[2, 2])? * 2.0;
let adaptive_curvature = optimizer.apply_adaptive_curvature("test", &curvature, &gradient)?;
let original_mean = curvature.mean()?.to_scalar::<f32>()?;
let adaptive_mean = adaptive_curvature.mean()?.to_scalar::<f32>()?;
assert_ne!(original_mean, adaptive_mean);
Ok(())
}
}