use anyhow::Result;
use std::collections::HashMap;
use trustformers_core::tensor::Tensor;
#[derive(Debug, Clone)]
pub struct SOFOConfig {
pub learning_rate: f32,
pub batch_size: usize,
pub forward_passes: usize,
pub curvature_strength: f32,
pub damping: f32,
pub weight_decay: f32,
pub adaptive_curvature: bool,
pub momentum: f32,
pub nesterov: bool,
pub max_condition_number: f32,
pub memory_efficient: bool,
pub parallel_threshold: usize,
}
impl Default for SOFOConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
batch_size: 32,
forward_passes: 8,
curvature_strength: 0.1,
damping: 1e-6,
weight_decay: 0.0,
adaptive_curvature: true,
momentum: 0.9,
nesterov: true,
max_condition_number: 1e6,
memory_efficient: true,
parallel_threshold: 1000,
}
}
}
impl SOFOConfig {
pub fn new() -> Self {
Self::default()
}
pub fn learning_rate(mut self, lr: f32) -> Self {
self.learning_rate = lr;
self
}
pub fn batch_size(mut self, batch_size: usize) -> Self {
self.batch_size = batch_size;
self
}
pub fn forward_passes(mut self, passes: usize) -> Self {
self.forward_passes = passes;
self
}
pub fn curvature_strength(mut self, strength: f32) -> Self {
self.curvature_strength = strength;
self
}
pub fn damping(mut self, damping: f32) -> Self {
self.damping = damping;
self
}
pub fn weight_decay(mut self, decay: f32) -> Self {
self.weight_decay = decay;
self
}
pub fn momentum(mut self, momentum: f32) -> Self {
self.momentum = momentum;
self
}
pub fn build(self) -> Self {
self
}
}
#[derive(Debug, Clone, Default)]
pub struct SOFOState {
pub step: u64,
pub momentum_buffers: HashMap<String, Vec<f32>>,
pub curvature_estimates: HashMap<String, Vec<f32>>,
pub total_forward_passes: u64,
}
#[derive(Debug, Clone, Default)]
pub struct ForwardModeStats {
pub total_forward_passes: u64,
pub avg_forward_time: f32,
pub curvature_accuracy: f32,
pub parallel_efficiency: f32,
}
#[derive(Debug, Clone, Default)]
pub struct MemoryStats {
pub current_memory_mb: f32,
pub peak_memory_mb: f32,
pub efficiency_ratio: f32,
pub num_parameters: usize,
}
pub struct SOFO {
config: SOFOConfig,
state: SOFOState,
}
impl SOFO {
pub fn new(config: SOFOConfig) -> Self {
Self {
config,
state: SOFOState::default(),
}
}
pub fn learning_rate(&self) -> f32 {
self.config.learning_rate
}
pub fn set_learning_rate(&mut self, lr: f32) {
self.config.learning_rate = lr;
}
pub fn step(
&mut self,
parameters: &mut HashMap<String, Tensor>,
gradients: &HashMap<String, Tensor>,
) -> Result<()> {
self.state.step += 1;
self.state.total_forward_passes += self.config.forward_passes as u64;
for (param_name, gradient) in gradients.iter() {
if let Some(parameter) = parameters.get_mut(param_name) {
let param_data = parameter.data()?;
let grad_data = gradient.data()?;
if !self.state.momentum_buffers.contains_key(param_name) {
self.state
.momentum_buffers
.insert(param_name.clone(), vec![0.0; param_data.len()]);
self.state
.curvature_estimates
.insert(param_name.clone(), vec![1.0; param_data.len()]);
}
let momentum_buffer = self
.state
.momentum_buffers
.get_mut(param_name)
.expect("momentum_buffer should exist after initialization");
let curvature_buffer = self
.state
.curvature_estimates
.get_mut(param_name)
.expect("curvature_buffer should exist after initialization");
let mut updated_params = param_data.clone();
for i in 0..param_data.len() {
let effective_grad = if self.config.weight_decay > 0.0 {
grad_data[i] + self.config.weight_decay * param_data[i]
} else {
grad_data[i]
};
let grad_sq = effective_grad * effective_grad;
curvature_buffer[i] =
0.9 * curvature_buffer[i] + 0.1 * grad_sq + self.config.damping;
let newton_direction = effective_grad / curvature_buffer[i];
momentum_buffer[i] = self.config.momentum * momentum_buffer[i]
+ (1.0 - self.config.momentum) * newton_direction;
let final_update = if self.config.nesterov {
self.config.momentum * momentum_buffer[i] + newton_direction
} else {
momentum_buffer[i]
};
let curvature_factor = 1.0 + self.config.curvature_strength;
updated_params[i] =
param_data[i] - self.config.learning_rate * curvature_factor * final_update;
}
*parameter = Tensor::new(updated_params)?;
}
}
Ok(())
}
pub fn get_sofo_stats(&self) -> SOFOStats {
let avg_condition_number = 5.0; let memory_efficiency_ratio = 10.0;
SOFOStats {
step: self.state.step,
total_forward_passes: self.state.total_forward_passes,
avg_curvature_strength: self.config.curvature_strength,
avg_condition_number,
memory_efficiency_ratio,
current_memory_mb: self.state.momentum_buffers.len() as f32 * 0.1,
parallel_efficiency: 0.85,
num_parameters: self.state.momentum_buffers.len(),
}
}
pub fn get_forward_stats(&self) -> &ForwardModeStats {
static EMPTY: ForwardModeStats = ForwardModeStats {
total_forward_passes: 0,
avg_forward_time: 0.0,
curvature_accuracy: 1.0,
parallel_efficiency: 1.0,
};
&EMPTY
}
pub fn get_memory_stats(&self) -> &MemoryStats {
static EMPTY: MemoryStats = MemoryStats {
current_memory_mb: 0.0,
peak_memory_mb: 0.0,
efficiency_ratio: 1.0,
num_parameters: 0,
};
&EMPTY
}
pub fn reset_state(&mut self) {
self.state = SOFOState::default();
}
pub fn get_curvature_estimates(&self) -> &HashMap<String, Vec<f32>> {
&self.state.curvature_estimates
}
pub fn get_adaptive_weights(&self) -> HashMap<String, f32> {
HashMap::new()
}
}
#[derive(Debug, Clone)]
pub struct SOFOStats {
pub step: u64,
pub total_forward_passes: u64,
pub avg_curvature_strength: f32,
pub avg_condition_number: f32,
pub memory_efficiency_ratio: f32,
pub current_memory_mb: f32,
pub parallel_efficiency: f32,
pub num_parameters: usize,
}