use super::communication::{
compress_tensor, decompress_tensor, AsyncCommunicator, CommunicationError, CompressionStrategy,
MessagePriority, TensorMessage,
};
use super::coordinator::{CoordinatorError, RingAllReduce};
use super::data_parallel::GradientAggregation;
use super::process::{Communicator, ProcessError};
use crate::error::NumRs2Error;
use scirs2_core::ndarray::Array1;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::{Mutex, RwLock};
/// Errors produced by the distributed optimizers in this module.
///
/// Validation failures carry their own message; failures from the process,
/// communication and coordinator layers are wrapped via `#[from]` so they can
/// be propagated with `?`.
#[derive(Error, Debug)]
pub enum OptimizerError {
    /// A learning rate rejected by an optimizer constructor.
    #[error("Invalid learning rate: {0}")]
    InvalidLearningRate(f32),

    /// A caller-supplied argument was not acceptable.
    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),

    /// Internal optimizer state did not match what the request required
    /// (for example, a ring reducer that was never initialized).
    #[error("State mismatch: {0}")]
    StateMismatch(String),

    /// Gradient compression or decompression failed; holds the underlying message.
    #[error("Compression error: {0}")]
    CompressionError(String),

    /// Wrapped error from the process layer.
    #[error("Process error: {0}")]
    Process(#[from] ProcessError),

    /// Wrapped error from the communication layer.
    #[error("Communication error: {0}")]
    Communication(#[from] CommunicationError),

    /// Wrapped error from the coordinator layer.
    #[error("Coordinator error: {0}")]
    Coordinator(#[from] CoordinatorError),
}
impl From<OptimizerError> for NumRs2Error {
fn from(err: OptimizerError) -> Self {
NumRs2Error::DistributedComputing(err.to_string())
}
}
/// Per-parameter optimizer state, keyed by parameter name in the optimizers below.
#[derive(Debug, Clone)]
struct OptimizerState {
    /// First-moment buffer: the momentum buffer for SGD, the running gradient
    /// average for Adam.
    m: Vec<f32>,
    /// Second-moment buffer; only allocated when requested at construction.
    v: Option<Vec<f32>>,
    /// Number of optimizer steps taken with this state.
    t: u64,
}

impl OptimizerState {
    /// Returns zero-initialized state for a parameter with `size` elements.
    ///
    /// The second-moment buffer `v` is allocated only when `use_second_moment`
    /// is `true`; otherwise it is `None`.
    fn new(size: usize, use_second_moment: bool) -> Self {
        let zeros = || vec![0.0; size];
        Self {
            m: zeros(),
            v: use_second_moment.then(zeros),
            t: 0,
        }
    }
}
/// Running sum of local gradients for one named parameter, used to implement
/// gradient accumulation in [`DistributedSGD`].
#[derive(Debug, Clone, Default)]
struct GradientAccumulator {
    /// Number of micro-steps folded into `sum` since the last optimizer update.
    count: usize,
    /// Element-wise sum of the gradients received during the current window.
    sum: Vec<f32>,
}

/// Stochastic gradient descent whose gradients are aggregated across the
/// processes of a [`Communicator`] before each update.
///
/// Supports classical and Nesterov momentum, L2 weight decay (added to the
/// aggregated gradient), optional gradient compression, and gradient
/// accumulation over several micro-steps.
pub struct DistributedSGD {
    /// Learning rate; validated as finite and positive in [`DistributedSGD::new`].
    lr: f32,
    /// Momentum factor; momentum is only applied when this is `> 0`.
    momentum: f32,
    /// Use the Nesterov form of the momentum update.
    nesterov: bool,
    /// L2 penalty coefficient; only applied when `> 0`.
    weight_decay: f32,
    communicator: Arc<Communicator>,
    /// Created in `new`; not otherwise used by this type yet.
    async_comm: AsyncCommunicator,
    aggregation: GradientAggregation,
    /// Only populated by `with_aggregation(GradientAggregation::RingAllReduce)`.
    ring_reducer: Option<RingAllReduce>,
    /// Per-parameter momentum state keyed by parameter name.
    state: Arc<RwLock<HashMap<String, OptimizerState>>>,
    compression: CompressionStrategy,
    /// Number of micro-steps whose gradients are averaged before one update;
    /// values `<= 1` update on every call.
    accumulation_steps: usize,
    /// Per-parameter accumulation buffers keyed by parameter name.
    accumulators: Arc<Mutex<HashMap<String, GradientAccumulator>>>,
}

impl DistributedSGD {
    /// Creates a plain SGD optimizer.
    ///
    /// Defaults: no momentum, no weight decay, no compression, `AllReduce`
    /// aggregation, and no accumulation.
    ///
    /// # Errors
    /// * [`OptimizerError::InvalidLearningRate`] if `lr` is not a finite,
    ///   strictly positive number.
    /// * Any error from constructing the [`AsyncCommunicator`], converted with `?`.
    pub fn new(lr: f32, communicator: Arc<Communicator>) -> Result<Self, OptimizerError> {
        // `lr <= 0.0` alone lets NaN through, since every comparison with NaN is false.
        if !lr.is_finite() || lr <= 0.0 {
            return Err(OptimizerError::InvalidLearningRate(lr));
        }
        let async_comm = AsyncCommunicator::new(communicator.clone())?;
        Ok(Self {
            lr,
            momentum: 0.0,
            nesterov: false,
            weight_decay: 0.0,
            communicator,
            async_comm,
            aggregation: GradientAggregation::AllReduce,
            ring_reducer: None,
            state: Arc::new(RwLock::new(HashMap::new())),
            compression: CompressionStrategy::None,
            accumulation_steps: 1,
            accumulators: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    /// Sets the momentum factor. Values `<= 0` disable momentum; the value is not validated.
    pub fn with_momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum;
        self
    }

    /// Enables Nesterov momentum. This only has an effect when momentum is `> 0`.
    pub fn with_nesterov(mut self) -> Self {
        self.nesterov = true;
        self
    }

    /// Sets the L2 penalty that is added to the aggregated gradient (`g += wd * p`).
    pub fn with_weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Sets the strategy used to compress gradients before aggregation.
    pub fn with_compression(mut self, compression: CompressionStrategy) -> Self {
        self.compression = compression;
        self
    }

    /// Selects the gradient aggregation scheme.
    ///
    /// # Errors
    /// Propagates a failure to construct the [`RingAllReduce`] reducer when
    /// `RingAllReduce` is selected.
    pub fn with_aggregation(
        mut self,
        aggregation: GradientAggregation,
    ) -> Result<Self, OptimizerError> {
        self.aggregation = aggregation;
        if aggregation == GradientAggregation::RingAllReduce {
            self.ring_reducer = Some(RingAllReduce::new(self.communicator.clone())?);
        }
        Ok(self)
    }

    /// Sets how many micro-steps are averaged before each update.
    ///
    /// Both `0` and `1` mean "update on every call".
    pub fn with_accumulation(mut self, steps: usize) -> Self {
        self.accumulation_steps = steps;
        self
    }

    /// Performs one step for the parameter named `"default"`; see [`Self::step_named`].
    pub async fn step(
        &mut self,
        params: &[f32],
        gradients: &[f32],
    ) -> Result<Vec<f32>, OptimizerError> {
        self.step_named("default", params, gradients).await
    }

    /// Performs one SGD step for the parameter `param_name` and returns the
    /// updated parameters.
    ///
    /// With accumulation enabled, gradients from successive calls for the same
    /// `param_name` are summed. `params` is returned unchanged until the window
    /// is full; the update then uses the mean gradient over the window.
    ///
    /// # Errors
    /// * [`OptimizerError::InvalidParameter`] if `gradients` and `params` differ in length.
    /// * [`OptimizerError::StateMismatch`] if the aggregated gradient, accumulation
    ///   buffer or stored state has a different length than `params`.
    /// * Compression and aggregation errors.
    pub async fn step_named(
        &mut self,
        param_name: &str,
        params: &[f32],
        gradients: &[f32],
    ) -> Result<Vec<f32>, OptimizerError> {
        if gradients.len() != params.len() {
            return Err(OptimizerError::InvalidParameter(format!(
                "gradient length {} does not match parameter length {} for '{}'",
                gradients.len(),
                params.len(),
                param_name
            )));
        }

        // While the accumulation window is still filling, nothing is
        // communicated and the parameters stay as they are.
        let accumulated = if self.accumulation_steps > 1 {
            match self.accumulate(param_name, gradients).await? {
                Some(mean) => Some(mean),
                None => return Ok(params.to_vec()),
            }
        } else {
            None
        };
        let local_grads = accumulated.as_deref().unwrap_or(gradients);

        let mut grads = self.exchange_gradients(local_grads).await?;
        if grads.len() != params.len() {
            return Err(OptimizerError::StateMismatch(format!(
                "aggregated gradient length {} does not match parameter length {} for '{}'",
                grads.len(),
                params.len(),
                param_name
            )));
        }

        let mut state_map = self.state.write().await;
        let state = state_map
            .entry(param_name.to_string())
            .or_insert_with(|| OptimizerState::new(params.len(), false));
        if state.m.len() != params.len() {
            return Err(OptimizerError::StateMismatch(format!(
                "optimizer state for '{}' has {} elements but {} parameters were given",
                param_name,
                state.m.len(),
                params.len()
            )));
        }
        state.t += 1;

        // L2 regularisation folded into the gradient after aggregation.
        if self.weight_decay > 0.0 {
            for (g, &p) in grads.iter_mut().zip(params) {
                *g += self.weight_decay * p;
            }
        }

        let mut new_params = params.to_vec();
        if self.momentum > 0.0 {
            for ((p, buf), &g) in new_params.iter_mut().zip(state.m.iter_mut()).zip(&grads) {
                *buf = self.momentum * *buf + g;
                // Nesterov looks ahead along the freshly updated momentum buffer.
                let update = if self.nesterov {
                    g + self.momentum * *buf
                } else {
                    *buf
                };
                *p -= self.lr * update;
            }
        } else {
            for (p, &g) in new_params.iter_mut().zip(&grads) {
                *p -= self.lr * g;
            }
        }
        Ok(new_params)
    }

    /// Folds `gradients` into the accumulation buffer for `param_name`.
    ///
    /// Returns `Ok(None)` while the window of `accumulation_steps` micro-steps is
    /// still filling. Once the window is complete it returns `Ok(Some(mean))`,
    /// the element-wise mean over the window, and resets the buffer; the window
    /// counts as consumed even if a later exchange fails.
    ///
    /// # Errors
    /// [`OptimizerError::StateMismatch`] if the gradient length changes in the
    /// middle of a window.
    async fn accumulate(
        &self,
        param_name: &str,
        gradients: &[f32],
    ) -> Result<Option<Vec<f32>>, OptimizerError> {
        let window = self.accumulation_steps;
        let mut accumulators = self.accumulators.lock().await;
        let acc = accumulators.entry(param_name.to_string()).or_default();
        if acc.count == 0 {
            // Start a new window, reusing the previous allocation.
            acc.sum.clear();
            acc.sum.resize(gradients.len(), 0.0);
        } else if acc.sum.len() != gradients.len() {
            return Err(OptimizerError::StateMismatch(format!(
                "gradient length changed from {} to {} during accumulation for '{}'",
                acc.sum.len(),
                gradients.len(),
                param_name
            )));
        }
        for (s, &g) in acc.sum.iter_mut().zip(gradients) {
            *s += g;
        }
        acc.count += 1;
        if acc.count < window {
            return Ok(None);
        }
        acc.count = 0;
        let scale = window as f32;
        Ok(Some(acc.sum.iter().map(|&s| s / scale).collect()))
    }

    /// Compresses `gradients`, aggregates them with `aggregate_gradients`, and
    /// decompresses the result back to `gradients.len()` elements when the
    /// compression strategy produced sparse indices.
    async fn exchange_gradients(&self, gradients: &[f32]) -> Result<Vec<f32>, OptimizerError> {
        let (compressed, indices) = compress_tensor(gradients, &self.compression)
            .map_err(|e| OptimizerError::CompressionError(e.to_string()))?;
        // NOTE(review): if sparse compression selects different indices on each
        // rank, reducing the value vectors element-wise mixes coordinates from
        // different positions. Confirm that the reducer/compression scheme accounts for this.
        let aggregated = self.aggregate_gradients(&compressed).await?;
        if indices.is_some() {
            decompress_tensor(&aggregated, indices.as_deref(), gradients.len())
                .map_err(|e| OptimizerError::CompressionError(e.to_string()))
        } else {
            Ok(aggregated)
        }
    }

    /// Aggregates (possibly compressed) gradients according to `self.aggregation`.
    ///
    /// # Errors
    /// [`OptimizerError::StateMismatch`] if `RingAllReduce` is selected but no
    /// reducer was created; reducer failures are converted with `?`.
    async fn aggregate_gradients(&self, gradients: &[f32]) -> Result<Vec<f32>, OptimizerError> {
        match self.aggregation {
            GradientAggregation::RingAllReduce => {
                if let Some(ref reducer) = self.ring_reducer {
                    Ok(reducer.allreduce(gradients).await?)
                } else {
                    Err(OptimizerError::StateMismatch(
                        "Ring reducer not initialized".to_string(),
                    ))
                }
            }
            GradientAggregation::AllReduce | GradientAggregation::Hierarchical => {
                // NOTE(review): no communication happens here. The local gradient
                // is only scaled by 1/world_size, so with world_size > 1 ranks never
                // see each other's gradients. Confirm whether a real all-reduce
                // through `self.communicator` is intended.
                let world_size = self.communicator.size() as f32;
                Ok(gradients.iter().map(|&g| g / world_size).collect())
            }
        }
    }

    /// Current learning rate.
    pub fn lr(&self) -> f32 {
        self.lr
    }

    /// Replaces the learning rate. Unlike [`Self::new`], this does not validate the value.
    pub fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    /// Current momentum factor.
    pub fn momentum(&self) -> f32 {
        self.momentum
    }
}
pub struct DistributedAdam {
lr: f32,
beta1: f32,
beta2: f32,
epsilon: f32,
weight_decay: f32,
communicator: Arc<Communicator>,
async_comm: AsyncCommunicator,
aggregation: GradientAggregation,
ring_reducer: Option<RingAllReduce>,
state: Arc<RwLock<HashMap<String, OptimizerState>>>,
compression: CompressionStrategy,
amsgrad: bool,
}
impl DistributedAdam {
pub fn new(lr: f32, communicator: Arc<Communicator>) -> Result<Self, OptimizerError> {
if lr <= 0.0 {
return Err(OptimizerError::InvalidLearningRate(lr));
}
let async_comm = AsyncCommunicator::new(communicator.clone())?;
Ok(Self {
lr,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
weight_decay: 0.0,
communicator,
async_comm,
aggregation: GradientAggregation::AllReduce,
ring_reducer: None,
state: Arc::new(RwLock::new(HashMap::new())),
compression: CompressionStrategy::None,
amsgrad: false,
})
}
pub fn with_betas(mut self, beta1: f32, beta2: f32) -> Self {
self.beta1 = beta1;
self.beta2 = beta2;
self
}
pub fn with_epsilon(mut self, epsilon: f32) -> Self {
self.epsilon = epsilon;
self
}
pub fn with_weight_decay(mut self, weight_decay: f32) -> Self {
self.weight_decay = weight_decay;
self
}
pub fn with_amsgrad(mut self) -> Self {
self.amsgrad = true;
self
}
pub fn with_compression(mut self, compression: CompressionStrategy) -> Self {
self.compression = compression;
self
}
pub fn with_aggregation(
mut self,
aggregation: GradientAggregation,
) -> Result<Self, OptimizerError> {
self.aggregation = aggregation;
if aggregation == GradientAggregation::RingAllReduce {
self.ring_reducer = Some(RingAllReduce::new(self.communicator.clone())?);
}
Ok(self)
}
pub async fn step(
&mut self,
params: &[f32],
gradients: &[f32],
) -> Result<Vec<f32>, OptimizerError> {
self.step_named("default", params, gradients).await
}
pub async fn step_named(
&mut self,
param_name: &str,
params: &[f32],
gradients: &[f32],
) -> Result<Vec<f32>, OptimizerError> {
let (compressed, indices) = compress_tensor(gradients, &self.compression)
.map_err(|e| OptimizerError::CompressionError(e.to_string()))?;
let aggregated = self.aggregate_gradients(&compressed).await?;
let grads = if indices.is_some() {
decompress_tensor(&aggregated, indices.as_deref(), gradients.len())
.map_err(|e| OptimizerError::CompressionError(e.to_string()))?
} else {
aggregated
};
let mut state_map = self.state.write().await;
let state = state_map
.entry(param_name.to_string())
.or_insert_with(|| OptimizerState::new(params.len(), true));
state.t += 1;
let t = state.t as f32;
let bias_correction1 = 1.0 - self.beta1.powi(state.t as i32);
let bias_correction2 = 1.0 - self.beta2.powi(state.t as i32);
let mut new_params = params.to_vec();
for i in 0..params.len() {
state.m[i] = self.beta1 * state.m[i] + (1.0 - self.beta1) * grads[i];
if let Some(ref mut v) = state.v {
v[i] = self.beta2 * v[i] + (1.0 - self.beta2) * grads[i] * grads[i];
let m_hat = state.m[i] / bias_correction1;
let v_hat = v[i] / bias_correction2;
if self.weight_decay > 0.0 {
new_params[i] -= self.lr
* (m_hat / (v_hat.sqrt() + self.epsilon) + self.weight_decay * params[i]);
} else {
new_params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.epsilon);
}
}
}
Ok(new_params)
}
async fn aggregate_gradients(&self, gradients: &[f32]) -> Result<Vec<f32>, OptimizerError> {
match self.aggregation {
GradientAggregation::RingAllReduce => {
if let Some(ref reducer) = self.ring_reducer {
Ok(reducer.allreduce(gradients).await?)
} else {
Err(OptimizerError::StateMismatch(
"Ring reducer not initialized".to_string(),
))
}
}
GradientAggregation::AllReduce | GradientAggregation::Hierarchical => {
let world_size = self.communicator.size() as f32;
Ok(gradients.iter().map(|&g| g / world_size).collect())
}
}
}
pub fn lr(&self) -> f32 {
self.lr
}
pub fn set_lr(&mut self, lr: f32) {
self.lr = lr;
}
pub fn beta1(&self) -> f32 {
self.beta1
}
pub fn beta2(&self) -> f32 {
self.beta2
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::process::{ProcessGroup, ProcessInfo};
    use std::collections::HashMap;
    use std::net::SocketAddr;

    /// Parses the loopback address `127.0.0.1:<8000 + rank>` assigned to `rank`.
    fn loopback_addr(rank: usize) -> Result<SocketAddr, ProcessError> {
        format!("127.0.0.1:{}", 8000 + rank)
            .parse::<SocketAddr>()
            .map_err(|e| ProcessError::ConfigError(format!("Invalid address: {}", e)))
    }

    /// Builds a communicator for `rank` in a `size`-process world where every
    /// peer sits on a loopback port.
    fn create_mock_comm(rank: usize, size: usize) -> Result<Arc<Communicator>, ProcessError> {
        let own_addr = loopback_addr(rank)?;
        let info = ProcessInfo::new(rank, size, own_addr, format!("localhost-{}", rank))?;
        let group = ProcessGroup::new((0..size).collect::<Vec<usize>>())?;
        let addresses = (0..size)
            .map(|peer| loopback_addr(peer).map(|addr| (peer, addr)))
            .collect::<Result<HashMap<usize, SocketAddr>, ProcessError>>()?;
        Ok(Arc::new(Communicator::new(info, group, addresses)?))
    }

    /// Single-process communicator shared by the tests below.
    fn single_rank_comm() -> Arc<Communicator> {
        create_mock_comm(0, 1).expect("Failed to create mock communicator")
    }

    #[test]
    fn test_optimizer_error_from_invalid_lr() {
        let result = DistributedSGD::new(-0.01, single_rank_comm());
        assert!(result.is_err());
    }

    #[test]
    fn test_distributed_sgd_creation() {
        let sgd = DistributedSGD::new(0.01, single_rank_comm());
        assert!(sgd.is_ok());
        let built = sgd.expect("optimizer creation failed");
        assert_eq!(built.lr(), 0.01);
        assert_eq!(built.momentum(), 0.0);
    }

    #[test]
    fn test_distributed_sgd_with_momentum() {
        let sgd = DistributedSGD::new(0.01, single_rank_comm())
            .expect("optimizer creation failed")
            .with_momentum(0.9);
        assert_eq!(sgd.momentum(), 0.9);
    }

    #[test]
    fn test_distributed_sgd_set_lr() {
        let mut sgd =
            DistributedSGD::new(0.01, single_rank_comm()).expect("optimizer creation failed");
        sgd.set_lr(0.001);
        assert_eq!(sgd.lr(), 0.001);
    }

    #[test]
    fn test_distributed_adam_creation() {
        let adam = DistributedAdam::new(0.001, single_rank_comm());
        assert!(adam.is_ok());
        let built = adam.expect("optimizer creation failed");
        assert_eq!(built.lr(), 0.001);
        assert_eq!(built.beta1(), 0.9);
        assert_eq!(built.beta2(), 0.999);
    }

    #[test]
    fn test_distributed_adam_with_betas() {
        let adam = DistributedAdam::new(0.001, single_rank_comm())
            .expect("optimizer creation failed")
            .with_betas(0.95, 0.9999);
        assert_eq!(adam.beta1(), 0.95);
        assert_eq!(adam.beta2(), 0.9999);
    }

    #[test]
    fn test_distributed_adam_set_lr() {
        let mut adam =
            DistributedAdam::new(0.001, single_rank_comm()).expect("optimizer creation failed");
        adam.set_lr(0.0001);
        assert_eq!(adam.lr(), 0.0001);
    }

    #[test]
    fn test_optimizer_state_creation() {
        let first_only = OptimizerState::new(10, false);
        assert_eq!(first_only.m.len(), 10);
        assert!(first_only.v.is_none());
        assert_eq!(first_only.t, 0);

        let state_with_v = OptimizerState::new(10, true);
        assert!(state_with_v.v.is_some());
        assert_eq!(state_with_v.v.as_ref().expect("v missing").len(), 10);
    }

    #[test]
    fn test_compression_strategy_with_optimizer() {
        let _sgd = DistributedSGD::new(0.01, single_rank_comm())
            .expect("optimizer creation failed")
            .with_compression(CompressionStrategy::TopK { k: 100 });
    }

    #[test]
    fn test_accumulation_steps() {
        let sgd = DistributedSGD::new(0.01, single_rank_comm())
            .expect("optimizer creation failed")
            .with_accumulation(4);
        assert_eq!(sgd.accumulation_steps, 4);
    }

    #[test]
    fn test_distributed_sgd_with_nesterov() {
        let sgd = DistributedSGD::new(0.01, single_rank_comm())
            .expect("optimizer creation failed")
            .with_nesterov();
        assert!(sgd.nesterov);
    }

    #[test]
    fn test_distributed_adam_with_weight_decay() {
        let adam = DistributedAdam::new(0.001, single_rank_comm())
            .expect("optimizer creation failed")
            .with_weight_decay(0.01);
        assert_eq!(adam.weight_decay, 0.01);
    }

    #[test]
    fn test_distributed_adam_with_amsgrad() {
        let adam = DistributedAdam::new(0.001, single_rank_comm())
            .expect("optimizer creation failed")
            .with_amsgrad();
        assert!(adam.amsgrad);
    }
}