use crate::OptimizerState;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use trustformers_core::errors::Result;
use trustformers_core::Tensor;
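/// A parameter-update kernel that [`FusedOptimizer`] can group with other
/// updates of the same variant and execute as a single fused batch.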
#[derive(Debug, Clone)]
pub enum FusedOperation {
FusedAdam {
lr: f64,
beta1: f64,
beta2: f64,
eps: f64,
weight_decay: f64,
},
FusedAdamW {
lr: f64,
beta1: f64,
beta2: f64,
eps: f64,
weight_decay: f64,
},
FusedSGDMomentum {
lr: f64,
momentum: f64,
dampening: f64,
weight_decay: f64,
nesterov: bool,
},
FusedGradientClipping { max_norm: f64, scale_factor: f64 },
FusedBatchNorm { eps: f64, momentum: f64 },
}
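/// Tuning knobs for operation fusion. `batch_size` sets how many queued
/// updates accumulate before a fused batch executes automatically; the
/// remaining flags gate individual fusion optimizations.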
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionConfig {
pub enable_memory_coalescing: bool,
pub enable_vectorization: bool,
pub batch_size: usize,
pub enable_kernel_fusion: bool,
pub buffer_size: usize,
pub enable_async_updates: bool,
}
impl Default for FusionConfig {
fn default() -> Self {
Self {
enable_memory_coalescing: true,
enable_vectorization: true,
batch_size: 64,
enable_kernel_fusion: true,
buffer_size: 1024,
enable_async_updates: false,
}
}
}
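/// Serializable snapshot of the fused optimizer: per-parameter optimizer
/// state, scratch buffers, and accumulated fusion statistics.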
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusedOptimizerState {
pub parameter_states: HashMap<String, OptimizerState>,
pub operation_buffers: HashMap<String, Vec<f64>>,
pub fusion_stats: FusionStats,
}
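/// Counters describing the work saved by fusion. Note that `fused_operations`
/// counts executed fused *batches*, not individual parameter updates.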
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionStats {
pub fused_operations: u64,
pub memory_bandwidth_saved: u64,
pub flops_saved: u64,
pub avg_batch_size: f64,
pub fusion_efficiency: f64,
}
impl Default for FusionStats {
fn default() -> Self {
Self {
fused_operations: 0,
memory_bandwidth_saved: 0,
flops_saved: 0,
avg_batch_size: 0.0,
fusion_efficiency: 0.0,
}
}
}
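/// Batches parameter updates and executes them as fused kernels.
///
/// Updates are queued per parameter and executed automatically once
/// `FusionConfig::batch_size` operations have accumulated; `flush()` forces
/// execution of whatever is still pending. A minimal sketch (import paths and
/// tensor construction elided, as they depend on the crate layout):
///
/// ```ignore
/// let mut optimizer = FusedOptimizer::new(FusionConfig::default())?;
/// let op = FusedOperation::FusedAdam {
///     lr: 0.001,
///     beta1: 0.9,
///     beta2: 0.999,
///     eps: 1e-8,
///     weight_decay: 0.0,
/// };
/// optimizer.queue_operation("weight".to_string(), op, param, grad)?;
/// optimizer.flush()?;
/// println!("{:?}", optimizer.get_fusion_stats());
/// ```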
#[derive(Debug)]
pub struct FusedOptimizer {
config: FusionConfig,
state: Arc<Mutex<FusedOptimizerState>>,
pending_operations: Arc<Mutex<Vec<(String, FusedOperation, Tensor, Tensor)>>>,
#[allow(dead_code)]
operation_queue: Arc<Mutex<HashMap<String, Vec<FusedOperation>>>>,
}
impl FusedOptimizer {
pub fn new(config: FusionConfig) -> Result<Self> {
let state = FusedOptimizerState {
parameter_states: HashMap::new(),
operation_buffers: HashMap::new(),
fusion_stats: FusionStats::default(),
};
Ok(Self {
config,
state: Arc::new(Mutex::new(state)),
pending_operations: Arc::new(Mutex::new(Vec::new())),
operation_queue: Arc::new(Mutex::new(HashMap::new())),
})
}
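    /// Queues one update for `param_name` and automatically executes the
    /// pending batch once `config.batch_size` operations have accumulated.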
pub fn queue_operation(
&mut self,
param_name: String,
operation: FusedOperation,
parameter: Tensor,
gradient: Tensor,
) -> Result<()> {
let should_execute = {
let mut pending = self.pending_operations.lock().expect("Mutex lock poisoned");
pending.push((param_name, operation, parameter, gradient));
pending.len() >= self.config.batch_size
};
if should_execute {
self.execute_fused_batch()?;
}
Ok(())
}
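    /// Drains the pending queue, groups operations by variant, and dispatches
    /// each group as one fused batch. Variants without a batched path
    /// (currently `FusedBatchNorm`) fall back to `execute_single_operation`.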
pub fn execute_fused_batch(&mut self) -> Result<()> {
let mut pending = self.pending_operations.lock().expect("Mutex lock poisoned");
if pending.is_empty() {
return Ok(());
}
let operations = std::mem::take(&mut *pending);
drop(pending);
let mut adam_ops = Vec::new();
let mut adamw_ops = Vec::new();
let mut sgd_ops = Vec::new();
let mut clip_ops = Vec::new();
for (param_name, op, param, grad) in operations {
match op {
FusedOperation::FusedAdam { .. } => adam_ops.push((param_name, op, param, grad)),
FusedOperation::FusedAdamW { .. } => adamw_ops.push((param_name, op, param, grad)),
FusedOperation::FusedSGDMomentum { .. } => {
sgd_ops.push((param_name, op, param, grad))
},
FusedOperation::FusedGradientClipping { .. } => {
clip_ops.push((param_name, op, param, grad))
},
_ => {
self.execute_single_operation(param_name, op, param, grad)?;
},
}
}
if !adam_ops.is_empty() {
self.execute_fused_adam_batch(adam_ops)?;
}
if !adamw_ops.is_empty() {
self.execute_fused_adamw_batch(adamw_ops)?;
}
if !sgd_ops.is_empty() {
self.execute_fused_sgd_batch(sgd_ops)?;
}
if !clip_ops.is_empty() {
self.execute_fused_clipping_batch(clip_ops)?;
}
Ok(())
}
fn execute_fused_adam_batch(
&mut self,
operations: Vec<(String, FusedOperation, Tensor, Tensor)>,
) -> Result<()> {
let mut state = self.state.lock().expect("Mutex lock poisoned");
let batch_size = operations.len();
for (param_name, op, param, grad) in operations {
if let FusedOperation::FusedAdam {
lr,
beta1,
beta2,
eps,
weight_decay,
} = op
{
let opt_state =
state.parameter_states.entry(param_name.clone()).or_insert_with(|| {
OptimizerState {
step: 0,
momentum: HashMap::new(),
variance: HashMap::new(),
..Default::default()
}
});
self.fused_adam_update(
                    &param,
&grad,
opt_state,
lr,
beta1,
beta2,
eps,
weight_decay,
)?;
}
}
state.fusion_stats.fused_operations += 1;
state.fusion_stats.avg_batch_size = (state.fusion_stats.avg_batch_size
* (state.fusion_stats.fused_operations - 1) as f64
+ batch_size as f64)
/ state.fusion_stats.fused_operations as f64;
        // Rough heuristic: 4 buffers (param, grad, momentum, variance) x 8
        // bytes saved per batched operation.
        let bandwidth_saved = batch_size * 4 * 8;
        state.fusion_stats.memory_bandwidth_saved += bandwidth_saved as u64;
Ok(())
}
fn execute_fused_adamw_batch(
&mut self,
operations: Vec<(String, FusedOperation, Tensor, Tensor)>,
) -> Result<()> {
let mut state = self.state.lock().expect("Mutex lock poisoned");
let batch_size = operations.len();
for (param_name, op, param, grad) in operations {
if let FusedOperation::FusedAdamW {
lr,
beta1,
beta2,
eps,
weight_decay,
} = op
{
let opt_state =
state.parameter_states.entry(param_name.clone()).or_insert_with(|| {
OptimizerState {
step: 0,
momentum: HashMap::new(),
variance: HashMap::new(),
..Default::default()
}
});
self.fused_adamw_update(
                    &param,
&grad,
opt_state,
lr,
beta1,
beta2,
eps,
weight_decay,
)?;
}
}
state.fusion_stats.fused_operations += 1;
        // Rough heuristic: 4 buffers x 8 bytes saved per batched operation.
        let bandwidth_saved = batch_size * 4 * 8;
state.fusion_stats.memory_bandwidth_saved += bandwidth_saved as u64;
Ok(())
}
fn execute_fused_sgd_batch(
&mut self,
operations: Vec<(String, FusedOperation, Tensor, Tensor)>,
) -> Result<()> {
let mut state = self.state.lock().expect("Mutex lock poisoned");
let batch_size = operations.len();
for (param_name, op, param, grad) in operations {
if let FusedOperation::FusedSGDMomentum {
lr,
momentum,
dampening,
weight_decay,
nesterov,
} = op
{
let opt_state =
state.parameter_states.entry(param_name.clone()).or_insert_with(|| {
OptimizerState {
step: 0,
momentum: HashMap::new(),
..Default::default()
}
});
self.fused_sgd_update(
                    &param,
&grad,
opt_state,
lr,
momentum,
dampening,
weight_decay,
nesterov,
)?;
}
}
state.fusion_stats.fused_operations += 1;
        // Rough heuristic: 2 buffers (param, momentum) x 8 bytes saved per
        // batched operation.
        let bandwidth_saved = batch_size * 2 * 8;
        state.fusion_stats.memory_bandwidth_saved += bandwidth_saved as u64;
Ok(())
}
fn execute_fused_clipping_batch(
&mut self,
operations: Vec<(String, FusedOperation, Tensor, Tensor)>,
) -> Result<()> {
let mut state = self.state.lock().expect("Mutex lock poisoned");
let batch_size = operations.len();
        // Compute the global gradient norm over the whole batch once, then
        // scale every gradient with the same coefficient.
        let gradients: Vec<Tensor> =
            operations.iter().map(|(_, _, _, grad)| grad.clone()).collect();
        let global_norm = self.compute_global_norm(&gradients)?;
        for (_, op, _, grad) in operations {
            if let FusedOperation::FusedGradientClipping {
                max_norm,
                scale_factor,
            } = op
            {
                // Clip only when the global norm exceeds the threshold; the
                // scale factor is applied unconditionally.
                let clip_coef = if global_norm > max_norm {
                    max_norm / global_norm
                } else {
                    1.0
                };
                let mut grad_mut = grad;
                grad_mut.mul_scalar((clip_coef * scale_factor) as f32)?;
            }
}
state.fusion_stats.fused_operations += 1;
        // Rough heuristic: 8 bytes saved per batched operation.
        let bandwidth_saved = batch_size * 8;
        state.fusion_stats.memory_bandwidth_saved += bandwidth_saved as u64;
Ok(())
}
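    /// Fallback for operation variants that have no batched implementation
    /// yet (currently `FusedBatchNorm`); intentionally a no-op placeholder.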
fn execute_single_operation(
&mut self,
_param_name: String,
_operation: FusedOperation,
_parameter: Tensor,
_gradient: Tensor,
) -> Result<()> {
Ok(())
}
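    /// Scalar Adam step with classic (coupled) L2 weight decay folded into
    /// the gradient, following Kingma & Ba (2014):
    ///
    /// ```text
    /// g_t    = grad + weight_decay * param
    /// m_t    = beta1 * m_{t-1} + (1 - beta1) * g_t
    /// v_t    = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    /// param -= lr * (m_t / (1 - beta1^t)) / (sqrt(v_t / (1 - beta2^t)) + eps)
    /// ```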
fn fused_adam_update(
&self,
param: &Tensor,
grad: &Tensor,
state: &mut OptimizerState,
lr: f64,
beta1: f64,
beta2: f64,
eps: f64,
weight_decay: f64,
) -> Result<()> {
use crate::common::ParameterIds;
state.step += 1;
let param_id = ParameterIds::from_tensor(param)?;
let param_len = param.data()?.len();
let momentum =
state.momentum.entry(param_id.clone()).or_insert_with(|| vec![0.0; param_len]);
let variance = state.variance.entry(param_id).or_insert_with(|| vec![0.0; param_len]);
let grad_data = grad.data()?;
let mut param_data = param.data()?;
let bias_correction1 = 1.0 - beta1.powi(state.step as i32);
let bias_correction2 = 1.0 - beta2.powi(state.step as i32);
for i in 0..param_data.len() {
let mut grad_val = grad_data[i];
if weight_decay > 0.0 {
grad_val += weight_decay as f32 * param_data[i];
}
momentum[i] = beta1 as f32 * momentum[i] + (1.0 - beta1 as f32) * grad_val;
variance[i] = beta2 as f32 * variance[i] + (1.0 - beta2 as f32) * grad_val * grad_val;
let m_hat = momentum[i] / bias_correction1 as f32;
let v_hat = variance[i] / bias_correction2 as f32;
param_data[i] -= lr as f32 * m_hat / (v_hat.sqrt() + eps as f32);
}
Ok(())
}
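    /// AdamW step (Loshchilov & Hutter, 2017): the moment updates match Adam,
    /// but weight decay is decoupled from the gradient and applied directly
    /// to the parameters as `param -= lr * weight_decay * param`.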
fn fused_adamw_update(
&self,
param: &Tensor,
grad: &Tensor,
state: &mut OptimizerState,
lr: f64,
beta1: f64,
beta2: f64,
eps: f64,
weight_decay: f64,
) -> Result<()> {
use crate::common::ParameterIds;
state.step += 1;
let param_id = ParameterIds::from_tensor(param)?;
let param_len = param.data()?.len();
let momentum =
state.momentum.entry(param_id.clone()).or_insert_with(|| vec![0.0; param_len]);
let variance = state.variance.entry(param_id).or_insert_with(|| vec![0.0; param_len]);
let grad_data = grad.data()?;
let mut param_data = param.data()?;
let bias_correction1 = 1.0 - beta1.powi(state.step as i32);
let bias_correction2 = 1.0 - beta2.powi(state.step as i32);
for i in 0..param_data.len() {
let grad_val = grad_data[i];
momentum[i] = beta1 as f32 * momentum[i] + (1.0 - beta1 as f32) * grad_val;
variance[i] = beta2 as f32 * variance[i] + (1.0 - beta2 as f32) * grad_val * grad_val;
let m_hat = momentum[i] / bias_correction1 as f32;
let v_hat = variance[i] / bias_correction2 as f32;
let adaptive_step = lr as f32 * m_hat / (v_hat.sqrt() + eps as f32);
let weight_decay_step = lr as f32 * weight_decay as f32 * param_data[i];
param_data[i] -= adaptive_step + weight_decay_step;
}
Ok(())
}
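    /// SGD step with momentum, dampening, and optional Nesterov lookahead.
    /// On the first step the momentum buffer is seeded with the (possibly
    /// weight-decayed) gradient rather than dampened, matching the common
    /// PyTorch formulation.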
fn fused_sgd_update(
&self,
param: &Tensor,
grad: &Tensor,
state: &mut OptimizerState,
lr: f64,
momentum_coef: f64,
dampening: f64,
weight_decay: f64,
nesterov: bool,
) -> Result<()> {
use crate::common::ParameterIds;
state.step += 1;
let param_id = ParameterIds::from_tensor(param)?;
let param_len = param.data()?.len();
let momentum = state.momentum.entry(param_id).or_insert_with(|| vec![0.0; param_len]);
let grad_data = grad.data()?;
let mut param_data = param.data()?;
for i in 0..param_data.len() {
let mut grad_val = grad_data[i];
if weight_decay > 0.0 {
grad_val += weight_decay as f32 * param_data[i];
}
if momentum_coef > 0.0 {
if state.step == 1 {
momentum[i] = grad_val;
} else {
momentum[i] =
momentum_coef as f32 * momentum[i] + (1.0 - dampening as f32) * grad_val;
}
let update_direction = if nesterov {
grad_val + momentum_coef as f32 * momentum[i]
} else {
momentum[i]
};
param_data[i] -= lr as f32 * update_direction;
} else {
param_data[i] -= lr as f32 * grad_val;
}
}
Ok(())
}
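    /// Global L2 norm across all gradients: `sqrt(sum_i ||g_i||^2)`.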
fn compute_global_norm(&self, gradients: &[Tensor]) -> Result<f64> {
let mut total_norm_sq = 0.0;
for grad in gradients {
let norm = grad.norm()?;
total_norm_sq += norm * norm;
}
Ok(total_norm_sq.sqrt() as f64)
}
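    /// Executes any still-pending operations immediately, regardless of how
    /// full the current batch is.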
pub fn flush(&mut self) -> Result<()> {
self.execute_fused_batch()
}
pub fn get_fusion_stats(&self) -> FusionStats {
let state = self.state.lock().expect("Mutex lock poisoned");
state.fusion_stats.clone()
}
pub fn reset_stats(&mut self) {
let mut state = self.state.lock().expect("Mutex lock poisoned");
state.fusion_stats = FusionStats::default();
}
pub fn update_config(&mut self, config: FusionConfig) {
self.config = config;
}
}
#[cfg(target_arch = "x86_64")]
pub mod simd {
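    //! Hand-vectorized x86_64 update kernels. This is a sketch of how the
    //! scalar loops above map onto 8-wide `f32` AVX2 arithmetic; callers get
    //! a scalar fallback when the CPU lacks AVX2/FMA support.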
    /// Vectorized Adam update over raw `f32` slices. Dispatches to the
    /// AVX2/FMA kernel when the CPU supports it and otherwise falls back to
    /// scalar code, so calling this function is always safe.
    pub fn simd_adam_update(
        param: &mut [f32],
        grad: &[f32],
        momentum: &mut [f32],
        velocity: &mut [f32],
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        step: i32,
    ) {
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // Safety: AVX2 and FMA availability was just verified at runtime.
            unsafe {
                simd_adam_update_avx2(param, grad, momentum, velocity, lr, beta1, beta2, eps, step)
            }
        } else {
            let corrected_lr = corrected_learning_rate(lr, beta1, beta2, step);
            for i in 0..param.len() {
                let g = grad[i];
                momentum[i] = beta1 * momentum[i] + (1.0 - beta1) * g;
                velocity[i] = beta2 * velocity[i] + (1.0 - beta2) * g * g;
                param[i] -= corrected_lr * momentum[i] / (velocity[i].sqrt() + eps);
            }
        }
    }
    /// Folds both Adam bias corrections into the learning rate so the inner
    /// loops avoid two extra divisions per element.
    fn corrected_learning_rate(lr: f32, beta1: f32, beta2: f32, step: i32) -> f32 {
        let bias_correction1 = 1.0 - beta1.powi(step);
        let bias_correction2 = 1.0 - beta2.powi(step);
        lr * (bias_correction2.sqrt() / bias_correction1)
    }
    #[target_feature(enable = "avx2,fma")]
    unsafe fn simd_adam_update_avx2(
        param: &mut [f32],
        grad: &[f32],
        momentum: &mut [f32],
        velocity: &mut [f32],
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        step: i32,
    ) {
        use std::arch::x86_64::*;
        let corrected_lr = corrected_learning_rate(lr, beta1, beta2, step);
        let beta1_vec = _mm256_set1_ps(beta1);
        let beta2_vec = _mm256_set1_ps(beta2);
        let one_minus_beta1 = _mm256_set1_ps(1.0 - beta1);
        let one_minus_beta2 = _mm256_set1_ps(1.0 - beta2);
        let eps_vec = _mm256_set1_ps(eps);
        let lr_vec = _mm256_set1_ps(corrected_lr);
        let chunks = param.len() / 8;
        for i in 0..chunks {
            let idx = i * 8;
            let p = _mm256_loadu_ps(param.as_ptr().add(idx));
            let g = _mm256_loadu_ps(grad.as_ptr().add(idx));
            let m = _mm256_loadu_ps(momentum.as_ptr().add(idx));
            let v = _mm256_loadu_ps(velocity.as_ptr().add(idx));
            // m_new = beta1 * m + (1 - beta1) * g
            let m_new = _mm256_fmadd_ps(beta1_vec, m, _mm256_mul_ps(one_minus_beta1, g));
            // v_new = beta2 * v + (1 - beta2) * g^2
            let g_sq = _mm256_mul_ps(g, g);
            let v_new = _mm256_fmadd_ps(beta2_vec, v, _mm256_mul_ps(one_minus_beta2, g_sq));
            // p_new = p - corrected_lr * m_new / (sqrt(v_new) + eps)
            let v_sqrt_eps = _mm256_add_ps(_mm256_sqrt_ps(v_new), eps_vec);
            let update = _mm256_div_ps(m_new, v_sqrt_eps);
            let p_new = _mm256_fnmadd_ps(lr_vec, update, p);
            _mm256_storeu_ps(param.as_mut_ptr().add(idx), p_new);
            _mm256_storeu_ps(momentum.as_mut_ptr().add(idx), m_new);
            _mm256_storeu_ps(velocity.as_mut_ptr().add(idx), v_new);
        }
        // Scalar remainder for lengths that are not a multiple of 8.
        for i in (chunks * 8)..param.len() {
            let g = grad[i];
            momentum[i] = beta1 * momentum[i] + (1.0 - beta1) * g;
            velocity[i] = beta2 * velocity[i] + (1.0 - beta2) * g * g;
            param[i] -= corrected_lr * momentum[i] / (velocity[i].sqrt() + eps);
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use trustformers_core::Tensor;
#[test]
fn test_fused_optimizer_creation() {
let config = FusionConfig::default();
let optimizer = FusedOptimizer::new(config).expect("Failed to create fused optimizer");
let stats = optimizer.get_fusion_stats();
assert_eq!(stats.fused_operations, 0);
}
#[test]
fn test_fused_adam_operation() {
let config = FusionConfig::default();
let mut optimizer = FusedOptimizer::new(config).expect("Failed to create fused optimizer");
let param = Tensor::ones(&[10, 10]).expect("Failed to create tensor");
let grad = Tensor::ones(&[10, 10]).expect("Failed to create tensor");
let operation = FusedOperation::FusedAdam {
lr: 0.001,
beta1: 0.9,
beta2: 0.999,
eps: 1e-8,
weight_decay: 0.0,
};
optimizer
.queue_operation("param1".to_string(), operation, param, grad)
.expect("Failed to queue operation");
optimizer.flush().expect("Flush failed");
let stats = optimizer.get_fusion_stats();
assert_eq!(stats.fused_operations, 1);
}
#[test]
fn test_fused_adamw_operation() {
let config = FusionConfig::default();
let mut optimizer = FusedOptimizer::new(config).expect("Failed to create fused optimizer");
let param = Tensor::ones(&[5, 5]).expect("Failed to create tensor");
let grad = Tensor::ones(&[5, 5]).expect("Failed to create tensor");
let operation = FusedOperation::FusedAdamW {
lr: 0.001,
beta1: 0.9,
beta2: 0.999,
eps: 1e-8,
weight_decay: 0.01,
};
optimizer
.queue_operation("param2".to_string(), operation, param, grad)
.expect("Failed to queue operation");
optimizer.flush().expect("Flush failed");
let stats = optimizer.get_fusion_stats();
assert_eq!(stats.fused_operations, 1);
}
#[test]
fn test_fused_sgd_operation() {
let config = FusionConfig::default();
let mut optimizer = FusedOptimizer::new(config).expect("Failed to create fused optimizer");
let param = Tensor::ones(&[3, 3]).expect("Failed to create tensor");
let grad = Tensor::ones(&[3, 3]).expect("Failed to create tensor");
let operation = FusedOperation::FusedSGDMomentum {
lr: 0.01,
momentum: 0.9,
dampening: 0.0,
weight_decay: 0.0,
nesterov: false,
};
optimizer
.queue_operation("param3".to_string(), operation, param, grad)
.expect("Failed to queue operation");
optimizer.flush().expect("Flush failed");
let stats = optimizer.get_fusion_stats();
assert_eq!(stats.fused_operations, 1);
}
#[test]
fn test_batch_fusion() {
let config = FusionConfig {
batch_size: 2,
..FusionConfig::default()
};
let mut optimizer = FusedOptimizer::new(config).expect("Failed to create fused optimizer");
for i in 0..3 {
let param = Tensor::ones(&[2, 2]).expect("Failed to create tensor");
let grad = Tensor::ones(&[2, 2]).expect("Failed to create tensor");
let operation = FusedOperation::FusedAdam {
lr: 0.001,
beta1: 0.9,
beta2: 0.999,
eps: 1e-8,
weight_decay: 0.0,
};
optimizer
.queue_operation(format!("param_{}", i), operation, param, grad)
.expect("Operation failed in test");
}
let stats = optimizer.get_fusion_stats();
assert!(stats.fused_operations > 0);
}
#[test]
fn test_fusion_stats() {
let config = FusionConfig::default();
let mut optimizer = FusedOptimizer::new(config).expect("Failed to create fused optimizer");
let param = Tensor::ones(&[10, 10]).expect("Failed to create tensor");
let grad = Tensor::ones(&[10, 10]).expect("Failed to create tensor");
let operation = FusedOperation::FusedAdam {
lr: 0.001,
beta1: 0.9,
beta2: 0.999,
eps: 1e-8,
weight_decay: 0.0,
};
optimizer
.queue_operation("param1".to_string(), operation, param, grad)
.expect("Failed to queue operation");
optimizer.flush().expect("Flush failed");
let stats = optimizer.get_fusion_stats();
assert_eq!(stats.fused_operations, 1);
assert!(stats.memory_bandwidth_saved > 0);
optimizer.reset_stats();
let reset_stats = optimizer.get_fusion_stats();
assert_eq!(reset_stats.fused_operations, 0);
assert_eq!(reset_stats.memory_bandwidth_saved, 0);
}
#[test]
fn test_global_norm_computation() {
let config = FusionConfig::default();
let optimizer = FusedOptimizer::new(config).expect("Failed to create fused optimizer");
let grad1 = Tensor::ones(&[3, 3]).expect("Failed to create tensor");
let grad2 = Tensor::ones(&[2, 2]).expect("Failed to create tensor");
let gradients = vec![grad1, grad2];
let global_norm = optimizer
.compute_global_norm(&gradients)
.expect("Failed to compute global norm");
assert!((global_norm - 3.606).abs() < 0.01);
}
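    // Additional illustrative tests, relying only on APIs already exercised
    // above (`Tensor::ones`, `queue_operation`, `flush`). `Tensor::ones(&[10,
    // 10])` has an L2 norm of 10, so `max_norm: 1.0` forces the clipping
    // branch.
    #[test]
    fn test_fused_gradient_clipping() {
        let config = FusionConfig::default();
        let mut optimizer = FusedOptimizer::new(config).expect("Failed to create fused optimizer");
        let param = Tensor::ones(&[10, 10]).expect("Failed to create tensor");
        let grad = Tensor::ones(&[10, 10]).expect("Failed to create tensor");
        let operation = FusedOperation::FusedGradientClipping {
            max_norm: 1.0,
            scale_factor: 1.0,
        };
        optimizer
            .queue_operation("param_clip".to_string(), operation, param, grad)
            .expect("Failed to queue operation");
        optimizer.flush().expect("Flush failed");
        let stats = optimizer.get_fusion_stats();
        assert_eq!(stats.fused_operations, 1);
    }
    // Smoke test for the SIMD kernel: with uniform inputs every element must
    // receive an identical update, and the scalar remainder loop (length 19
    // is not a multiple of 8) must agree with the vectorized lanes.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_simd_adam_update_uniform() {
        let n = 19;
        let mut param = vec![1.0f32; n];
        let grad = vec![0.5f32; n];
        let mut momentum = vec![0.0f32; n];
        let mut velocity = vec![0.0f32; n];
        simd::simd_adam_update(
            &mut param,
            &grad,
            &mut momentum,
            &mut velocity,
            0.001,
            0.9,
            0.999,
            1e-8,
            1,
        );
        for &p in &param {
            assert!((p - param[0]).abs() < 1e-6);
            assert!(p < 1.0);
        }
    }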
}