use crate::{error::AutogradError, Float, NdArray, Result};
use std::sync::{Arc, Mutex};
pub mod communication;
pub mod data_parallel;
pub mod model_parallel;
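/// Per-worker gradient accumulator used to simulate an allreduce.
///
/// Each worker pushes its locally computed gradients with [`add_local`](Self::add_local);
/// [`allreduce`](Self::allreduce) then averages them over `num_workers` and caches the
/// result until [`clear`](Self::clear) is called.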
pub struct DistributedGradient<T: Float> {
local_gradients: Arc<Mutex<Vec<NdArray<T>>>>,
accumulated: Arc<Mutex<Option<Vec<NdArray<T>>>>>,
num_workers: usize,
rank: usize,
}
impl<T: Float + scirs2_core::ndarray::ScalarOperand> DistributedGradient<T> {
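    /// Creates an accumulator for the worker with the given `rank` out of `num_workers`.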
pub fn new(num_workers: usize, rank: usize) -> Self {
Self {
local_gradients: Arc::new(Mutex::new(Vec::new())),
accumulated: Arc::new(Mutex::new(None)),
num_workers,
rank,
}
}
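    /// Appends a locally computed gradient to this worker's buffer.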
pub fn add_local(&self, gradient: NdArray<T>) -> Result<()> {
let mut local = self
.local_gradients
.lock()
.map_err(|_| AutogradError::internal_error("Failed to lock local gradients"))?;
local.push(gradient);
Ok(())
}
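    /// Averages the buffered gradients over `num_workers`, caches the result,
    /// and returns it. No cross-process communication happens here; this is a
    /// single-process stand-in for a collective allreduce.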
pub fn allreduce(&self) -> Result<Vec<NdArray<T>>> {
let local = self
.local_gradients
.lock()
.map_err(|_| AutogradError::internal_error("Failed to lock local gradients"))?;
let num_grads = local.len();
if num_grads == 0 {
return Ok(Vec::new());
}
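        // Simulate the averaging step of an allreduce by dividing each
        // buffered gradient by the total number of workers.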
let mut result = Vec::with_capacity(num_grads);
for grad in local.iter() {
let averaged = grad
/ T::from(self.num_workers).ok_or_else(|| {
AutogradError::compute_error("Failed to convert num_workers".to_string())
})?;
result.push(averaged);
}
let mut accumulated = self
.accumulated
.lock()
.map_err(|_| AutogradError::internal_error("Failed to lock accumulated gradients"))?;
*accumulated = Some(result.clone());
Ok(result)
}
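    /// Returns this worker's rank.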
pub fn rank(&self) -> usize {
self.rank
}
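    /// Returns the total number of workers.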
pub fn num_workers(&self) -> usize {
self.num_workers
}
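    /// Discards both the local gradient buffer and any cached allreduce result.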
pub fn clear(&self) -> Result<()> {
let mut local = self
.local_gradients
.lock()
.map_err(|_| AutogradError::internal_error("Failed to lock local gradients"))?;
local.clear();
let mut accumulated = self
.accumulated
.lock()
.map_err(|_| AutogradError::internal_error("Failed to lock accumulated gradients"))?;
*accumulated = None;
Ok(())
}
}
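/// Strategy for splitting training work across workers.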
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParallelStrategy {
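    /// Replicate the model on every worker and split the input data.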
DataParallel,
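    /// Split the model's parameters across workers.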
ModelParallel,
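    /// Split the model into stages that run on different workers in a pipeline.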
PipelineParallel,
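    /// Combine data and model parallelism.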
Hybrid,
}
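/// Configuration for distributed gradient synchronization.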
pub struct DistributedConfig {
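    /// Parallelism strategy to use.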
pub strategy: ParallelStrategy,
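    /// Total number of workers participating in training.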
pub num_workers: usize,
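    /// Rank of this worker (zero-based).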
pub rank: usize,
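    /// Number of local steps to accumulate before synchronizing gradients.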
pub grad_accumulation_steps: usize,
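    /// Whether gradients should be compressed before communication.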
pub compress_gradients: bool,
}
impl Default for DistributedConfig {
fn default() -> Self {
Self {
strategy: ParallelStrategy::DataParallel,
num_workers: 1,
rank: 0,
grad_accumulation_steps: 1,
compress_gradients: false,
}
}
}
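/// Collective operations required to synchronize gradients and parameters
/// across workers.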
pub trait SyncBackend<T: Float>: Send + Sync {
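    /// Reduces gradients across all workers and returns the result visible to this worker.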
fn allreduce(&self, gradients: &[NdArray<T>]) -> Result<Vec<NdArray<T>>>;
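    /// Distributes parameters so that every worker holds the same values.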
fn broadcast(&self, parameters: &[NdArray<T>]) -> Result<Vec<NdArray<T>>>;
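    /// Collects one gradient from every worker.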
fn gather(&self, gradient: &NdArray<T>) -> Result<Vec<NdArray<T>>>;
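    /// Splits `data` across workers and returns this worker's share.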
fn scatter(&self, data: &[NdArray<T>]) -> Result<NdArray<T>>;
}
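/// In-process [`SyncBackend`] for testing and single-machine use: all
/// collectives are simulated locally without any network communication.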
pub struct LocalSyncBackend<T: Float> {
num_workers: usize,
_phantom: std::marker::PhantomData<T>,
}
impl<T: Float> LocalSyncBackend<T> {
pub fn new(num_workers: usize) -> Self {
Self {
num_workers,
_phantom: std::marker::PhantomData,
}
}
}
impl<T: Float + scirs2_core::ndarray::ScalarOperand> SyncBackend<T> for LocalSyncBackend<T> {
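    // With a single process there is nothing to sum across ranks, so the
    // "allreduce" reduces to dividing each gradient by the worker count.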
fn allreduce(&self, gradients: &[NdArray<T>]) -> Result<Vec<NdArray<T>>> {
let divisor = T::from(self.num_workers).ok_or_else(|| {
AutogradError::compute_error("Failed to convert num_workers".to_string())
})?;
Ok(gradients.iter().map(|g| g / divisor).collect())
}
fn broadcast(&self, parameters: &[NdArray<T>]) -> Result<Vec<NdArray<T>>> {
Ok(parameters.to_vec())
}
fn gather(&self, gradient: &NdArray<T>) -> Result<Vec<NdArray<T>>> {
Ok(vec![gradient.clone(); self.num_workers])
}
fn scatter(&self, data: &[NdArray<T>]) -> Result<NdArray<T>> {
data.first()
.cloned()
.ok_or_else(|| AutogradError::invalid_argument("Empty data for scatter".to_string()))
}
}
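/// Combines a [`SyncBackend`] with a [`DistributedConfig`] to accumulate
/// gradients over several local steps before synchronizing them across workers.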
pub struct DistributedOptimizer<T: Float> {
backend: Arc<dyn SyncBackend<T>>,
config: DistributedConfig,
grad_buffer: Arc<Mutex<Vec<Vec<NdArray<T>>>>>,
}
impl<T: Float + scirs2_core::ndarray::ScalarOperand> DistributedOptimizer<T> {
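    /// Creates an optimizer that synchronizes gradients through the given backend.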
pub fn new(backend: Arc<dyn SyncBackend<T>>, config: DistributedConfig) -> Self {
Self {
backend,
config,
grad_buffer: Arc::new(Mutex::new(Vec::new())),
}
}
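    /// Buffers the gradients produced by one local training step.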
pub fn accumulate_gradient(&self, gradients: Vec<NdArray<T>>) -> Result<()> {
let mut buffer = self
.grad_buffer
.lock()
.map_err(|_| AutogradError::internal_error("Failed to lock gradient buffer"))?;
buffer.push(gradients);
Ok(())
}
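    /// Returns `true` once `grad_accumulation_steps` steps have been buffered.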
pub fn should_sync(&self) -> Result<bool> {
let buffer = self
.grad_buffer
.lock()
.map_err(|_| AutogradError::internal_error("Failed to lock gradient buffer"))?;
Ok(buffer.len() >= self.config.grad_accumulation_steps)
}
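    /// Averages the buffered per-step gradients element-wise, passes the result
    /// through the backend's `allreduce`, clears the buffer, and returns the
    /// synchronized gradients.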
pub fn sync_gradients(&self) -> Result<Vec<NdArray<T>>> {
let mut buffer = self
.grad_buffer
.lock()
.map_err(|_| AutogradError::internal_error("Failed to lock gradient buffer"))?;
if buffer.is_empty() {
return Ok(Vec::new());
}
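        // Element-wise average over the accumulated steps; every step is
        // assumed to contain the same number of gradient tensors.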
let num_grads = buffer[0].len();
let num_steps = buffer.len();
let mut averaged = Vec::with_capacity(num_grads);
for i in 0..num_grads {
let mut sum = buffer[0][i].clone();
for step in buffer.iter().skip(1) {
sum += &step[i];
}
let avg = sum
/ T::from(num_steps).ok_or_else(|| {
AutogradError::compute_error("Failed to convert num_steps".to_string())
})?;
averaged.push(avg);
}
let synced = self.backend.allreduce(&averaged)?;
buffer.clear();
Ok(synced)
}
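    /// Returns the distributed configuration.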
pub fn config(&self) -> &DistributedConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array1;
#[test]
fn test_distributed_gradient() {
let grad_acc: DistributedGradient<f32> = DistributedGradient::new(4, 0);
let grad1: Array1<f32> = Array1::from_vec(vec![1.0, 2.0, 3.0]);
grad_acc.add_local(grad1.into_dyn()).expect("Should add");
let result = grad_acc.allreduce().expect("Should allreduce");
assert_eq!(result.len(), 1);
let result_vals = result[0].as_slice().expect("Should get slice");
assert!((result_vals[0] - 0.25).abs() < 1e-6);
}
#[test]
fn test_parallel_strategy() {
assert_eq!(
ParallelStrategy::DataParallel,
ParallelStrategy::DataParallel
);
assert_ne!(
ParallelStrategy::DataParallel,
ParallelStrategy::ModelParallel
);
}
#[test]
fn test_local_sync_backend() {
let backend: LocalSyncBackend<f64> = LocalSyncBackend::new(2);
let grad: Array1<f64> = Array1::from_vec(vec![4.0, 6.0]);
let result = backend
.allreduce(&[grad.into_dyn()])
.expect("Should allreduce");
let result_vals = result[0].as_slice().expect("Should get slice");
assert_eq!(result_vals[0], 2.0);
assert_eq!(result_vals[1], 3.0);
}
#[test]
fn test_distributed_optimizer() {
let backend = Arc::new(LocalSyncBackend::<f32>::new(1));
let config = DistributedConfig {
grad_accumulation_steps: 2,
..Default::default()
};
let optimizer = DistributedOptimizer::new(backend, config);
let grad1: Array1<f32> = Array1::from_vec(vec![1.0]);
optimizer
.accumulate_gradient(vec![grad1.into_dyn()])
.expect("Should accumulate");
assert!(!optimizer.should_sync().expect("Should check"));
let grad2: Array1<f32> = Array1::from_vec(vec![3.0]);
optimizer
.accumulate_gradient(vec![grad2.into_dyn()])
.expect("Should accumulate");
assert!(optimizer.should_sync().expect("Should check"));
let synced = optimizer.sync_gradients().expect("Should sync");
let synced_val = synced[0].as_slice().expect("Should get slice")[0];
        assert_eq!(synced_val, 2.0);
    }
}