use crate::distributed::DistributedOps;
use crate::error::{RusTorchError, RusTorchResult};
use crate::optim::{Adam, Optimizer, SGD};
use crate::tensor::Tensor;
use num_traits::Float;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Kinds of reduction that can be requested for a collective operation.
///
/// NOTE(review): the synchronization code visible in this file passes
/// `crate::distributed::ReduceOp` to the backend rather than this enum —
/// confirm whether this local definition is still needed.
#[derive(Clone, Copy, Debug)]
pub enum ReduceOp {
    /// Sum reduction.
    Sum,
    /// Average reduction.
    Average,
    /// Maximum reduction.
    Max,
    /// Minimum reduction.
    Min,
}
/// An optimizer wrapper that synchronizes gradients across processes through
/// a distributed backend before delegating the update to a base optimizer.
pub struct DistributedOptimizer<T>
where
    T: Float + Send + Sync + 'static,
{
    /// Wrapped optimizer; the `Optimizer` trait methods delegate to it.
    base_optimizer: Box<dyn Optimizer + Send + Sync>,
    /// Communication backend used for the `all_reduce` calls.
    backend: Arc<dyn DistributedOps<T> + Send + Sync>,
    /// Selects which synchronization routine `sync_gradients` runs.
    sync_strategy: GradientSyncStrategy,
    /// Buckets grouping the tensors registered through `add_to_bucket`.
    gradient_buckets: Vec<GradientBucket<T>>,
    /// Initialised to 1 and changed via `set_communication_frequency`.
    /// Not read by any code visible in this file.
    communication_freq: usize,
    /// Step counter, advanced by `Optimizer::step` and by the LocalSGD
    /// synchronization path.
    step_count: usize,
    /// Gradient snapshots cached between LocalSGD synchronizations, keyed by
    /// a generated string; emptied by `zero_grad`.
    local_gradients: HashMap<String, Tensor<T>>,
}
/// Strategy used by `DistributedOptimizer::sync_gradients` to decide which
/// gradients are all-reduced and when.
#[derive(Clone, Copy, Debug)]
pub enum GradientSyncStrategy {
    /// All-reduce every non-empty gradient bucket on each synchronization.
    Synchronous,
    /// All-reduce only the buckets currently flagged as ready. In the code
    /// visible in this file the reduction is still performed inline during
    /// `sync_gradients`.
    Asynchronous,
    /// Communicate only when the optimizer's step counter is a multiple of
    /// `sync_frequency`.
    LocalSGD {
        /// Number of steps between two synchronizations.
        sync_frequency: usize,
    },
    /// All-reduce ready buckets with gradient compression.
    ///
    /// NOTE(review): the implementation visible in this file accepts the
    /// ratio but does not apply any compression yet.
    Compressed {
        /// Requested compression ratio.
        compression_ratio: f32,
    },
    /// Currently handled exactly like [`GradientSyncStrategy::Synchronous`].
    Hierarchical,
}
/// A group of gradient tensors that are communicated together.
pub struct GradientBucket<T>
where
    T: Float,
{
    /// Position of the bucket in the optimizer's bucket list at creation time.
    id: usize,
    /// Shared handles to the tensors assigned to this bucket.
    tensors: Vec<Arc<Mutex<Tensor<T>>>>,
    /// Sum of the element counts (product of the shape) of `tensors`.
    total_size: usize,
    /// Capacity threshold that `total_size` is compared against.
    ///
    /// NOTE(review): the comparison is made against element counts, while the
    /// builder default of `25 * 1024 * 1024` reads like a byte size — confirm
    /// the intended unit.
    max_size: usize,
    /// Set once `total_size` reaches `max_size`; cleared after the bucket has
    /// been synchronized.
    ready: bool,
}
impl<T: Float + Send + Sync + 'static> DistributedOptimizer<T> {
    /// Wraps `base_optimizer` so that gradients can be synchronized through
    /// `backend` according to `sync_strategy`.
    ///
    /// The new optimizer has no gradient buckets, a communication frequency
    /// of 1 and a step count of 0. Call `init_gradient_buckets` before
    /// `add_to_bucket`.
    pub fn new(
        base_optimizer: Box<dyn Optimizer + Send + Sync>,
        backend: Arc<dyn DistributedOps<T> + Send + Sync>,
        sync_strategy: GradientSyncStrategy,
    ) -> Self {
        Self {
            base_optimizer,
            backend,
            sync_strategy,
            gradient_buckets: Vec::new(),
            communication_freq: 1,
            step_count: 0,
            local_gradients: HashMap::new(),
        }
    }

    /// Builds a distributed optimizer around an SGD base optimizer.
    ///
    /// The hyper-parameters are converted to `f32`; a value that cannot be
    /// converted falls back to 0.001 (learning rate), 0.9 (momentum) or 0.0
    /// (weight decay). `SGD::with_weight_decay` is used only when the
    /// converted weight decay is strictly positive, `SGD::with_momentum`
    /// otherwise.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub fn sgd(
        learning_rate: T,
        momentum: T,
        weight_decay: T,
        backend: Arc<dyn DistributedOps<T> + Send + Sync>,
        sync_strategy: GradientSyncStrategy,
    ) -> RusTorchResult<Self> {
        let lr_f32 = learning_rate.to_f32().unwrap_or(0.001);
        let momentum_f32 = momentum.to_f32().unwrap_or(0.9);
        let wd_f32 = weight_decay.to_f32().unwrap_or(0.0);

        let sgd = if wd_f32 > 0.0 {
            SGD::with_weight_decay(lr_f32, momentum_f32, wd_f32)
        } else {
            SGD::with_momentum(lr_f32, momentum_f32)
        };
        Ok(Self::new(Box::new(sgd), backend, sync_strategy))
    }

    /// Builds a distributed optimizer around an Adam base optimizer.
    ///
    /// The hyper-parameters are converted to `f32`; a value that cannot be
    /// converted falls back to 0.001 (learning rate), 0.9 (`beta1`), 0.999
    /// (`beta2`), 1e-8 (`epsilon`) or 0.0 (weight decay).
    /// `Adam::with_weight_decay` is used only when the converted weight decay
    /// is strictly positive, `Adam::new` otherwise.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub fn adam(
        learning_rate: T,
        beta1: T,
        beta2: T,
        epsilon: T,
        weight_decay: T,
        backend: Arc<dyn DistributedOps<T> + Send + Sync>,
        sync_strategy: GradientSyncStrategy,
    ) -> RusTorchResult<Self> {
        let lr_f32 = learning_rate.to_f32().unwrap_or(0.001);
        let beta1_f32 = beta1.to_f32().unwrap_or(0.9);
        let beta2_f32 = beta2.to_f32().unwrap_or(0.999);
        let eps_f32 = epsilon.to_f32().unwrap_or(1e-8);
        let wd_f32 = weight_decay.to_f32().unwrap_or(0.0);

        let adam = if wd_f32 > 0.0 {
            Adam::with_weight_decay(lr_f32, beta1_f32, beta2_f32, eps_f32, wd_f32)
        } else {
            Adam::new(lr_f32, beta1_f32, beta2_f32, eps_f32)
        };
        Ok(Self::new(Box::new(adam), backend, sync_strategy))
    }

    /// Discards any existing buckets and creates a single empty bucket whose
    /// capacity is `max_bucket_size`. Buckets created later reuse that
    /// capacity.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub fn init_gradient_buckets(&mut self, max_bucket_size: usize) -> RusTorchResult<()> {
        self.gradient_buckets.clear();
        self.gradient_buckets.push(GradientBucket {
            id: 0,
            tensors: Vec::new(),
            total_size: 0,
            max_size: max_bucket_size,
            ready: false,
        });
        Ok(())
    }

    /// Registers `tensor` in the first bucket with enough remaining capacity,
    /// creating a new bucket when none fits. The bucket is flagged ready once
    /// its accumulated element count reaches its capacity.
    ///
    /// # Errors
    ///
    /// Returns `RusTorchError::ConfigurationError` if `init_gradient_buckets`
    /// has not been called.
    ///
    /// # Panics
    ///
    /// Panics if the tensor's mutex is poisoned.
    pub fn add_to_bucket(&mut self, tensor: Arc<Mutex<Tensor<T>>>) -> RusTorchResult<()> {
        // The guard is a temporary, so the lock is released at the end of
        // this statement.
        let tensor_size: usize = tensor.lock().unwrap().shape().iter().product();

        let bucket_idx = self.find_or_create_bucket(tensor_size)?;
        let bucket = &mut self.gradient_buckets[bucket_idx];
        bucket.tensors.push(tensor);
        bucket.total_size += tensor_size;
        if bucket.total_size >= bucket.max_size {
            bucket.ready = true;
        }
        Ok(())
    }

    /// Returns the index of the first bucket that can take `tensor_size` more
    /// elements, appending a new empty bucket (with the capacity of the first
    /// bucket) when none can.
    ///
    /// # Errors
    ///
    /// Returns `RusTorchError::ConfigurationError` when no bucket exists yet.
    /// The bucket capacity is only known from the first bucket, so indexing
    /// it unconditionally used to panic on an uninitialised optimizer.
    fn find_or_create_bucket(&mut self, tensor_size: usize) -> RusTorchResult<usize> {
        let max_size = self
            .gradient_buckets
            .first()
            .map(|bucket| bucket.max_size)
            .ok_or_else(|| {
                RusTorchError::ConfigurationError(
                    "Gradient buckets not initialized; call init_gradient_buckets first"
                        .to_string(),
                )
            })?;

        let reusable = self
            .gradient_buckets
            .iter()
            .position(|bucket| bucket.total_size + tensor_size <= bucket.max_size);
        if let Some(idx) = reusable {
            return Ok(idx);
        }

        // The new bucket's id is its index in the list.
        let id = self.gradient_buckets.len();
        self.gradient_buckets.push(GradientBucket {
            id,
            tensors: Vec::new(),
            total_size: 0,
            max_size,
            ready: false,
        });
        Ok(id)
    }

    /// Synchronizes gradients using the configured `GradientSyncStrategy`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the backend's `all_reduce`, and
    /// returns `RusTorchError::ConfigurationError` for a `LocalSGD` strategy
    /// whose `sync_frequency` is zero.
    ///
    /// # Panics
    ///
    /// Panics if the mutex of a bucketed tensor is poisoned.
    pub fn sync_gradients(&mut self) -> RusTorchResult<()> {
        match self.sync_strategy {
            GradientSyncStrategy::Synchronous => self.sync_gradients_synchronous(),
            GradientSyncStrategy::Asynchronous => self.sync_gradients_asynchronous(),
            GradientSyncStrategy::LocalSGD { sync_frequency } => {
                self.sync_gradients_local_sgd(sync_frequency)
            }
            GradientSyncStrategy::Compressed { compression_ratio } => {
                self.sync_gradients_compressed(compression_ratio)
            }
            GradientSyncStrategy::Hierarchical => self.sync_gradients_hierarchical(),
        }
    }

    /// All-reduces every non-empty bucket exactly once and clears its ready
    /// flag.
    ///
    /// A ready bucket is always non-empty (the flag is only set right after a
    /// tensor is pushed), so one pass over the non-empty buckets covers the
    /// ready ones too. A separate "ready" pass followed by a "non-empty" pass
    /// would all-reduce ready buckets twice.
    fn sync_gradients_synchronous(&mut self) -> RusTorchResult<()> {
        // `backend` and `gradient_buckets` are disjoint fields, so no Arc
        // clone is needed to borrow both.
        let backend = &self.backend;
        for bucket in self
            .gradient_buckets
            .iter_mut()
            .filter(|bucket| !bucket.tensors.is_empty())
        {
            Self::sync_bucket_with_backend(backend, bucket)?;
            bucket.ready = false;
        }
        Ok(())
    }

    /// All-reduces only the buckets flagged ready and clears their flag.
    fn sync_gradients_asynchronous(&mut self) -> RusTorchResult<()> {
        let backend = &self.backend;
        for bucket in self.gradient_buckets.iter_mut().filter(|bucket| bucket.ready) {
            Self::sync_bucket_with_backend(backend, bucket)?;
            bucket.ready = false;
        }
        Ok(())
    }

    /// Advances the step counter; on every `sync_frequency`-th call the
    /// cached gradient snapshots are all-reduced (average) and the cache is
    /// emptied, on the other calls a clone of every bucketed tensor that is
    /// not cached yet is stored.
    ///
    /// NOTE(review): the averaged snapshots are dropped by `clear()` without
    /// being written back to the bucketed tensors, and tensors that are
    /// already cached are not accumulated into. This looks unfinished —
    /// confirm the intended LocalSGD semantics.
    ///
    /// # Errors
    ///
    /// Returns `RusTorchError::ConfigurationError` when `sync_frequency` is
    /// zero (the modulo below would otherwise panic), and propagates backend
    /// errors.
    fn sync_gradients_local_sgd(&mut self, sync_frequency: usize) -> RusTorchResult<()> {
        if sync_frequency == 0 {
            return Err(RusTorchError::ConfigurationError(
                "LocalSGD sync_frequency must be greater than zero".to_string(),
            ));
        }

        self.step_count += 1;
        if self.step_count % sync_frequency == 0 {
            for gradient in self.local_gradients.values_mut() {
                self.backend
                    .all_reduce(gradient, crate::distributed::ReduceOp::Average)?;
            }
            self.local_gradients.clear();
        } else {
            for bucket in &self.gradient_buckets {
                for (slot, tensor_ref) in bucket.tensors.iter().enumerate() {
                    // One key per tensor: keying on the bucket id alone made
                    // all tensors of a bucket share a single cache slot, so
                    // only the first one was ever stored.
                    let key = format!("tensor_{}_{}", bucket.id, slot);
                    let tensor = tensor_ref.lock().unwrap();
                    self.local_gradients
                        .entry(key)
                        .or_insert_with(|| tensor.clone());
                }
            }
        }
        Ok(())
    }

    /// All-reduces the ready buckets through the compression path and clears
    /// their flag.
    fn sync_gradients_compressed(&mut self, compression_ratio: f32) -> RusTorchResult<()> {
        let backend = &self.backend;
        for bucket in self.gradient_buckets.iter_mut().filter(|bucket| bucket.ready) {
            Self::compress_and_sync_bucket_with_backend(backend, bucket, compression_ratio)?;
            bucket.ready = false;
        }
        Ok(())
    }

    /// Hierarchical synchronization; currently identical to the synchronous
    /// strategy.
    fn sync_gradients_hierarchical(&mut self) -> RusTorchResult<()> {
        self.sync_gradients_synchronous()
    }

    /// All-reduces (average) every tensor of `bucket` in place, locking each
    /// tensor for the duration of its reduction.
    ///
    /// # Panics
    ///
    /// Panics if a tensor's mutex is poisoned.
    fn sync_bucket_with_backend(
        backend: &Arc<dyn DistributedOps<T> + Send + Sync>,
        bucket: &mut GradientBucket<T>,
    ) -> RusTorchResult<()> {
        for tensor_ref in &bucket.tensors {
            let mut tensor = tensor_ref.lock().unwrap();
            backend.all_reduce(&mut tensor, crate::distributed::ReduceOp::Average)?;
        }
        Ok(())
    }

    /// Placeholder for compressed synchronization: `_compression_ratio` is
    /// accepted but not applied, and every tensor of `bucket` is all-reduced
    /// (average) uncompressed.
    ///
    /// # Panics
    ///
    /// Panics if a tensor's mutex is poisoned.
    fn compress_and_sync_bucket_with_backend(
        backend: &Arc<dyn DistributedOps<T> + Send + Sync>,
        bucket: &mut GradientBucket<T>,
        _compression_ratio: f32,
    ) -> RusTorchResult<()> {
        for tensor_ref in &bucket.tensors {
            let mut tensor = tensor_ref.lock().unwrap();
            backend.all_reduce(&mut tensor, crate::distributed::ReduceOp::Average)?;
        }
        Ok(())
    }

    /// Stores the communication frequency.
    ///
    /// The stored value is not read by any code visible in this file.
    pub fn set_communication_frequency(&mut self, freq: usize) {
        self.communication_freq = freq;
    }

    /// Returns the current step counter.
    pub fn step_count(&self) -> usize {
        self.step_count
    }

    /// Resets the step counter to zero.
    pub fn reset_step_count(&mut self) {
        self.step_count = 0;
    }
}
impl<T: Float + Send + Sync + 'static> Optimizer for DistributedOptimizer<T> {
fn step(&mut self, param: &Tensor<f32>, grad: &Tensor<f32>) {
if let Err(_e) = self.sync_gradients() {
return;
}
self.base_optimizer.step(param, grad);
self.step_count += 1;
}
fn zero_grad(&mut self) {
self.base_optimizer.zero_grad();
self.local_gradients.clear();
}
fn learning_rate(&self) -> f32 {
self.base_optimizer.learning_rate()
}
fn set_learning_rate(&mut self, lr: f32) {
self.base_optimizer.set_learning_rate(lr);
}
fn state_dict(&self) -> std::collections::HashMap<String, f32> {
self.base_optimizer.state_dict()
}
fn load_state_dict(&mut self, state: std::collections::HashMap<String, f32>) {
self.base_optimizer.load_state_dict(state);
}
}
/// Builder for [`DistributedOptimizer`].
pub struct DistributedOptimizerBuilder<T>
where
    T: Float + Send + Sync + 'static,
{
    /// Which base optimizer to construct, with its hyper-parameters.
    optimizer_type: OptimizerType<T>,
    /// Communication backend; `None` until `backend` is called, and required
    /// by `build`.
    backend: Option<Arc<dyn DistributedOps<T> + Send + Sync>>,
    /// Gradient synchronization strategy handed to the optimizer.
    sync_strategy: GradientSyncStrategy,
    /// Capacity given to the gradient buckets when the optimizer is built.
    bucket_size: usize,
}
/// Base optimizer selection and hyper-parameters used by
/// `DistributedOptimizerBuilder`.
pub enum OptimizerType<T>
where
    T: Float,
{
    /// Stochastic gradient descent.
    SGD {
        /// Learning rate.
        learning_rate: T,
        /// Momentum factor.
        momentum: T,
        /// Weight decay; only applied by the builder when strictly positive.
        weight_decay: T,
    },
    /// Adam.
    Adam {
        /// Learning rate.
        learning_rate: T,
        /// First `beta` coefficient passed to the Adam constructor.
        beta1: T,
        /// Second `beta` coefficient passed to the Adam constructor.
        beta2: T,
        /// Epsilon term passed to the Adam constructor.
        epsilon: T,
        /// Weight decay; only applied by the builder when strictly positive.
        weight_decay: T,
    },
}
impl<T: Float + Send + Sync + 'static> DistributedOptimizerBuilder<T> {
pub fn sgd(learning_rate: T, momentum: T, weight_decay: T) -> Self {
Self {
optimizer_type: OptimizerType::SGD {
learning_rate,
momentum,
weight_decay,
},
backend: None,
sync_strategy: GradientSyncStrategy::Synchronous,
bucket_size: 25 * 1024 * 1024, }
}
pub fn adam(learning_rate: T, beta1: T, beta2: T, epsilon: T, weight_decay: T) -> Self {
Self {
optimizer_type: OptimizerType::Adam {
learning_rate,
beta1,
beta2,
epsilon,
weight_decay,
},
backend: None,
sync_strategy: GradientSyncStrategy::Synchronous,
bucket_size: 25 * 1024 * 1024, }
}
pub fn backend(mut self, backend: Arc<dyn DistributedOps<T> + Send + Sync>) -> Self {
self.backend = Some(backend);
self
}
pub fn sync_strategy(mut self, strategy: GradientSyncStrategy) -> Self {
self.sync_strategy = strategy;
self
}
pub fn bucket_size(mut self, size: usize) -> Self {
self.bucket_size = size;
self
}
pub fn build(self) -> RusTorchResult<DistributedOptimizer<T>> {
let backend = self.backend.ok_or_else(|| {
RusTorchError::ConfigurationError("Backend not specified".to_string())
})?;
let base_optimizer: Box<dyn Optimizer + Send + Sync> = match self.optimizer_type {
OptimizerType::SGD {
learning_rate,
momentum,
weight_decay,
} => {
let lr_f32 = learning_rate.to_f32().unwrap_or(0.001);
let momentum_f32 = momentum.to_f32().unwrap_or(0.9);
let wd_f32 = weight_decay.to_f32().unwrap_or(0.0);
if wd_f32 > 0.0 {
Box::new(SGD::with_weight_decay(lr_f32, momentum_f32, wd_f32))
} else {
Box::new(SGD::with_momentum(lr_f32, momentum_f32))
}
}
OptimizerType::Adam {
learning_rate,
beta1,
beta2,
epsilon,
weight_decay,
} => {
let lr_f32 = learning_rate.to_f32().unwrap_or(0.001);
let beta1_f32 = beta1.to_f32().unwrap_or(0.9);
let beta2_f32 = beta2.to_f32().unwrap_or(0.999);
let eps_f32 = epsilon.to_f32().unwrap_or(1e-8);
let wd_f32 = weight_decay.to_f32().unwrap_or(0.0);
if wd_f32 > 0.0 {
Box::new(Adam::with_weight_decay(
lr_f32, beta1_f32, beta2_f32, eps_f32, wd_f32,
))
} else {
Box::new(Adam::new(lr_f32, beta1_f32, beta2_f32, eps_f32))
}
}
};
let mut optimizer = DistributedOptimizer::new(base_optimizer, backend, self.sync_strategy);
optimizer.init_gradient_buckets(self.bucket_size)?;
Ok(optimizer)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The payload-free strategies can be constructed and recognised.
    #[test]
    fn test_gradient_sync_strategy_creation() {
        let sync = GradientSyncStrategy::Synchronous;
        let async_strategy = GradientSyncStrategy::Asynchronous;

        assert!(matches!(sync, GradientSyncStrategy::Synchronous));
        assert!(matches!(async_strategy, GradientSyncStrategy::Asynchronous));
        assert!(!matches!(sync, GradientSyncStrategy::Asynchronous));
    }

    /// Every strategy can be constructed, and the variants carrying data keep
    /// the values they were built with.
    #[test]
    fn test_gradient_sync_strategies() {
        let strategies = [
            GradientSyncStrategy::Synchronous,
            GradientSyncStrategy::Asynchronous,
            GradientSyncStrategy::LocalSGD { sync_frequency: 10 },
            GradientSyncStrategy::Compressed {
                compression_ratio: 0.1,
            },
            GradientSyncStrategy::Hierarchical,
        ];

        // Exhaustive match without a wildcard: a `matches!` listing every
        // variant can never fail, whereas this checks the payloads and stops
        // compiling when a variant is added without updating the test.
        for strategy in strategies.iter() {
            match *strategy {
                GradientSyncStrategy::Synchronous
                | GradientSyncStrategy::Asynchronous
                | GradientSyncStrategy::Hierarchical => {}
                GradientSyncStrategy::LocalSGD { sync_frequency } => {
                    assert_eq!(sync_frequency, 10);
                }
                GradientSyncStrategy::Compressed { compression_ratio } => {
                    assert!((compression_ratio - 0.1).abs() < f32::EPSILON);
                }
            }
        }
    }
}