use super::communication::{
AsyncCommunicator, CommunicationError, CompressionStrategy, MessagePriority, TensorMessage,
};
use super::coordinator::{CoordinatorError, ParameterServer, RingAllReduce};
use super::process::{Communicator, ProcessError};
use crate::error::NumRs2Error;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{thread_rng, Rng};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::{Mutex, RwLock};
/// Errors produced by the data-parallel training utilities in this module.
///
/// Lower-level process, communication and coordinator errors are wrapped via
/// `#[from]`, so `?` converts them automatically.
#[derive(Error, Debug)]
pub enum DataParallelError {
    /// An error from the process / communicator layer.
    #[error("Process error: {0}")]
    Process(#[from] ProcessError),
    /// An error from the communication layer.
    #[error("Communication error: {0}")]
    Communication(#[from] CommunicationError),
    /// An error from the coordinator module (ring all-reduce, parameter server, ...).
    #[error("Coordinator error: {0}")]
    Coordinator(#[from] CoordinatorError),
    /// Returned by `DistributedDataLoader::new` when the batch size is zero.
    #[error("Invalid batch size: {0}")]
    InvalidBatchSize(usize),
    /// A dataset length did not match the expected length.
    #[error("Dataset size mismatch: expected {expected}, got {actual}")]
    DatasetSizeMismatch { expected: usize, actual: usize },
    /// Gradient aggregation failed; the string describes why.
    #[error("Gradient aggregation error: {0}")]
    AggregationError(String),
    /// Data could not be partitioned across ranks; the string describes why.
    #[error("Sharding error: {0}")]
    ShardingError(String),
}
impl From<DataParallelError> for NumRs2Error {
fn from(err: DataParallelError) -> Self {
NumRs2Error::DistributedComputing(err.to_string())
}
}
/// How a global dataset is partitioned across the ranks of a process group.
///
/// NOTE(review): this type derives `oxicode::Encode`/`Decode`; the encoding presumably
/// depends on variant order, so avoid reordering variants — confirm against oxicode docs.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, oxicode::Encode, oxicode::Decode,
)]
pub enum ShardingStrategy {
    /// Contiguous chunks of `ceil(len / world_size)` items, assigned in rank order;
    /// trailing ranks may receive a shorter chunk or nothing at all.
    Block,
    /// Item `i` is assigned to rank `i % world_size`.
    Cyclic,
    /// Deterministic hash-based assignment of item indices to ranks. No RNG state is
    /// used, so every rank computes the same assignment.
    Random,
}
/// Strategy used by `SyncDataParallel` to combine gradients across ranks.
///
/// Derives `Hash` (in addition to `Eq`) so the strategy can be used as a key in
/// hash maps / sets, e.g. for per-strategy configuration or metrics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GradientAggregation {
    /// Scales the local gradients by `1 / world_size` in `SyncDataParallel`.
    AllReduce,
    /// Delegates to the coordinator's ring all-reduce implementation.
    RingAllReduce,
    /// Currently aggregated exactly like `AllReduce` by `SyncDataParallel`.
    Hierarchical,
}
/// This process's slice of a dataset that has been partitioned across a process
/// group, served in fixed-size batches.
pub struct DistributedDataLoader<T> {
    /// Rank of the owning process within its group.
    rank: usize,
    /// Number of processes the dataset was partitioned across.
    world_size: usize,
    /// Rule used to pick which items this rank keeps.
    strategy: ShardingStrategy,
    /// Number of items in the dataset before partitioning.
    global_size: usize,
    /// The items assigned to this rank.
    local_data: Vec<T>,
    /// Upper bound on the number of items per batch.
    batch_size: usize,
    /// Index into `local_data` of the next item to hand out.
    position: usize,
}
impl<T: Clone> DistributedDataLoader<T> {
    /// Shards `global_data` for the calling process and prepares batched iteration.
    ///
    /// The rank and world size are read from `communicator`. The shards of all ranks
    /// are disjoint and together cover the dataset only if every rank passes the same
    /// `global_data` and `strategy`.
    ///
    /// # Errors
    ///
    /// * [`DataParallelError::InvalidBatchSize`] if `batch_size` is zero.
    /// * [`DataParallelError::ShardingError`] if the communicator reports a world size
    ///   of zero or a rank outside `0..world_size`.
    pub fn new(
        global_data: Vec<T>,
        batch_size: usize,
        communicator: &Communicator,
        strategy: ShardingStrategy,
    ) -> Result<Self, DataParallelError> {
        if batch_size == 0 {
            return Err(DataParallelError::InvalidBatchSize(batch_size));
        }
        let rank = communicator.rank();
        let world_size = communicator.size();
        let global_size = global_data.len();
        let local_data = Self::shard_data(global_data, rank, world_size, strategy)?;
        Ok(Self {
            local_data,
            batch_size,
            position: 0,
            global_size,
            rank,
            world_size,
            strategy,
        })
    }

    /// Keeps the items of `data` that belong to `rank` out of `world_size` ranks.
    ///
    /// * `Block`: contiguous chunks of `ceil(len / world_size)` items in rank order;
    ///   trailing ranks may receive a shorter chunk or nothing.
    /// * `Cyclic`: item `i` goes to rank `i % world_size`.
    /// * `Random`: item `i` goes to rank `mix_index(i) % world_size`. This is a
    ///   deterministic hash, so every rank computes the same assignment without
    ///   sharing RNG state.
    ///
    /// Items are moved out of `data`, never cloned.
    ///
    /// # Errors
    ///
    /// Returns [`DataParallelError::ShardingError`] if `world_size` is zero (which
    /// would otherwise panic with a division by zero) or `rank >= world_size`.
    fn shard_data(
        mut data: Vec<T>,
        rank: usize,
        world_size: usize,
        strategy: ShardingStrategy,
    ) -> Result<Vec<T>, DataParallelError> {
        if world_size == 0 {
            return Err(DataParallelError::ShardingError(
                "world size must be at least 1".to_string(),
            ));
        }
        if rank >= world_size {
            return Err(DataParallelError::ShardingError(format!(
                "rank {rank} is out of range for world size {world_size}"
            )));
        }
        let shard = match strategy {
            ShardingStrategy::Block => {
                let chunk_size = data.len().div_ceil(world_size);
                let start = rank * chunk_size;
                if start >= data.len() {
                    Vec::new()
                } else {
                    let end = (start + chunk_size).min(data.len());
                    data.drain(start..end).collect()
                }
            }
            ShardingStrategy::Cyclic => data
                .into_iter()
                .enumerate()
                .filter(|(i, _)| i % world_size == rank)
                .map(|(_, item)| item)
                .collect(),
            ShardingStrategy::Random => data
                .into_iter()
                .enumerate()
                .filter(|(i, _)| (Self::mix_index(*i) % world_size as u64) == rank as u64)
                .map(|(_, item)| item)
                .collect(),
        };
        Ok(shard)
    }

    /// SplitMix64 finalizer applied to an item index; used by `Random` sharding.
    ///
    /// A plain multiplicative hash taken modulo `world_size` degenerates into a
    /// relabelled cyclic split. For many world sizes it is exactly `Cyclic`. So the
    /// index bits are fully avalanched before the modulo is taken.
    fn mix_index(index: usize) -> u64 {
        let mut z = (index as u64).wrapping_add(0x9E37_79B9_7F4A_7C15);
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Returns the next batch of at most `batch_size` items. Returns `None` once this
    /// rank's shard is exhausted, and keeps doing so until [`reset`](Self::reset).
    pub fn next_batch(&mut self) -> Option<Vec<T>> {
        if self.position >= self.local_data.len() {
            return None;
        }
        let end = (self.position + self.batch_size).min(self.local_data.len());
        let batch = self.local_data[self.position..end].to_vec();
        self.position = end;
        Some(batch)
    }

    /// Rewinds to the first batch of this rank's shard.
    pub fn reset(&mut self) {
        self.position = 0;
    }

    /// Total number of batches in this rank's shard, including a final partial batch.
    pub fn num_batches(&self) -> usize {
        // `batch_size` is non-zero: `new` rejects zero.
        self.local_data.len().div_ceil(self.batch_size)
    }

    /// Number of items assigned to this rank.
    pub fn local_size(&self) -> usize {
        self.local_data.len()
    }

    /// Number of items in the dataset before sharding.
    pub fn global_size(&self) -> usize {
        self.global_size
    }

    /// Rank of this process, as reported by the communicator at construction.
    pub fn rank(&self) -> usize {
        self.rank
    }

    /// Number of ranks the dataset was sharded across.
    pub fn world_size(&self) -> usize {
        self.world_size
    }

    /// Strategy used to shard the dataset.
    pub fn strategy(&self) -> ShardingStrategy {
        self.strategy
    }

    /// Returns an iterator over the remaining batches, starting from the current
    /// position. It advances this loader and does not rewind it.
    pub fn iter(&mut self) -> DataLoaderIterator<'_, T> {
        DataLoaderIterator { loader: self }
    }
}
/// Borrowing iterator over the remaining batches of a [`DistributedDataLoader`].
///
/// Created by [`DistributedDataLoader::iter`]. Each step advances the loader's cursor,
/// so batches consumed here are not yielded again until the loader is reset.
pub struct DataLoaderIterator<'a, T> {
    loader: &'a mut DistributedDataLoader<T>,
}

impl<'a, T: Clone> Iterator for DataLoaderIterator<'a, T> {
    type Item = Vec<T>;

    fn next(&mut self) -> Option<Self::Item> {
        self.loader.next_batch()
    }

    /// Exact number of batches still to be yielded, so `collect` can preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining_items = self
            .loader
            .local_data
            .len()
            .saturating_sub(self.loader.position);
        // `batch_size` is non-zero: `DistributedDataLoader::new` rejects zero.
        let remaining_batches = remaining_items.div_ceil(self.loader.batch_size);
        (remaining_batches, Some(remaining_batches))
    }
}
/// Synchronous data-parallel gradient aggregation with optional local gradient
/// accumulation.
pub struct SyncDataParallel {
    /// Communicator for this process; its `size()` is used as the world size.
    communicator: Arc<Communicator>,
    /// Strategy used to combine gradients.
    aggregation: GradientAggregation,
    /// Ring reducer; only `Some` when `aggregation` is `RingAllReduce`.
    ring_reducer: Option<RingAllReduce>,
    /// Number of calls summed locally before one aggregation (values <= 1 disable
    /// accumulation).
    accumulation_steps: usize,
    /// Calls made so far in the current accumulation window.
    current_step: Arc<Mutex<usize>>,
    /// Running element-wise sum of gradients in the current accumulation window.
    gradient_buffer: Arc<Mutex<Vec<f32>>>,
}
impl SyncDataParallel {
pub fn new(
communicator: Arc<Communicator>,
aggregation: GradientAggregation,
) -> Result<Self, DataParallelError> {
let ring_reducer = if aggregation == GradientAggregation::RingAllReduce {
Some(RingAllReduce::new(communicator.clone())?)
} else {
None
};
Ok(Self {
communicator,
aggregation,
ring_reducer,
accumulation_steps: 1,
current_step: Arc::new(Mutex::new(0)),
gradient_buffer: Arc::new(Mutex::new(Vec::new())),
})
}
pub fn with_accumulation(mut self, steps: usize) -> Self {
self.accumulation_steps = steps;
self
}
pub async fn aggregate_gradients(
&self,
local_gradients: &[f32],
) -> Result<Vec<f32>, DataParallelError> {
if self.accumulation_steps > 1 {
let mut buffer = self.gradient_buffer.lock().await;
if buffer.is_empty() {
buffer.resize(local_gradients.len(), 0.0);
}
for (acc, &grad) in buffer.iter_mut().zip(local_gradients.iter()) {
*acc += grad;
}
let mut step = self.current_step.lock().await;
*step += 1;
if *step < self.accumulation_steps {
return Ok(vec![0.0; local_gradients.len()]);
}
*step = 0;
let accumulated = buffer.clone();
buffer.fill(0.0);
drop(buffer);
drop(step);
self.aggregate_impl(&accumulated).await
} else {
self.aggregate_impl(local_gradients).await
}
}
async fn aggregate_impl(&self, gradients: &[f32]) -> Result<Vec<f32>, DataParallelError> {
match self.aggregation {
GradientAggregation::RingAllReduce => {
if let Some(ref reducer) = self.ring_reducer {
Ok(reducer.allreduce(gradients).await?)
} else {
Err(DataParallelError::AggregationError(
"Ring reducer not initialized".to_string(),
))
}
}
GradientAggregation::AllReduce | GradientAggregation::Hierarchical => {
let world_size = self.communicator.size() as f32;
Ok(gradients.iter().map(|&g| g / world_size).collect())
}
}
}
pub async fn current_accumulation_step(&self) -> usize {
*self.current_step.lock().await
}
pub fn accumulation_steps(&self) -> usize {
self.accumulation_steps
}
}
/// Asynchronous data-parallel training backed by a parameter server.
pub struct AsyncDataParallel {
    /// Parameter server used to push gradients, pull parameters and query versions.
    ps: ParameterServer,
    /// Asynchronous communicator built from the same base communicator; not read by
    /// any method of this type in this file.
    async_comm: AsyncCommunicator,
    /// Maximum allowed version lag; `None` disables staleness checks.
    staleness_threshold: Option<u64>,
}
impl AsyncDataParallel {
    /// Creates an asynchronous data-parallel trainer backed by a parameter server.
    ///
    /// `num_ps` is forwarded to `ParameterServer::new`. Staleness checking starts
    /// disabled.
    ///
    /// # Errors
    ///
    /// Propagates errors from constructing the parameter server or the asynchronous
    /// communicator.
    pub fn new(communicator: Arc<Communicator>, num_ps: usize) -> Result<Self, DataParallelError> {
        let ps = ParameterServer::new(communicator.clone(), num_ps)?;
        let async_comm = AsyncCommunicator::new(communicator)?;
        Ok(Self {
            ps,
            async_comm,
            staleness_threshold: None,
        })
    }

    /// Enables staleness checks in [`is_stale`](Self::is_stale) with the given maximum
    /// allowed version lag.
    pub fn with_staleness_threshold(mut self, threshold: u64) -> Self {
        self.staleness_threshold = Some(threshold);
        self
    }

    /// Sends `gradients` for `parameter_key` to the parameter server.
    pub async fn push_gradients(
        &self,
        parameter_key: &str,
        gradients: &[f32],
    ) -> Result<(), DataParallelError> {
        self.ps.push_gradients(parameter_key, gradients).await?;
        Ok(())
    }

    /// Fetches the current values of `parameter_key` from the parameter server.
    pub async fn pull_parameters(
        &self,
        parameter_key: &str,
    ) -> Result<Vec<f32>, DataParallelError> {
        Ok(self.ps.pull_parameters(parameter_key).await?)
    }

    /// Reports whether a copy of `parameter_key` at `local_version` lags the server's
    /// version by more than the configured threshold.
    ///
    /// Returns `Ok(false)` without querying the server when no threshold is set.
    pub async fn is_stale(
        &self,
        parameter_key: &str,
        local_version: u64,
    ) -> Result<bool, DataParallelError> {
        let Some(threshold) = self.staleness_threshold else {
            return Ok(false);
        };
        let current_version = self.ps.get_version(parameter_key).await?;
        // Equivalent to `current_version > local_version + threshold`, but cannot
        // overflow when `threshold` (or `local_version`) is very large.
        Ok(current_version.saturating_sub(local_version) > threshold)
    }

    /// Asks the parameter server to apply gradients for `parameter_key` using
    /// `learning_rate`.
    pub async fn apply_gradients(
        &self,
        parameter_key: &str,
        learning_rate: f32,
    ) -> Result<(), DataParallelError> {
        self.ps
            .apply_gradients(parameter_key, learning_rate)
            .await?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::process::{ProcessGroup, ProcessInfo};
    use std::collections::HashMap;
    use std::net::SocketAddr;

    /// Builds a `Communicator` for `rank` in a group of `size` processes whose peers
    /// are addressed as `127.0.0.1:(8000 + i)`.
    fn create_mock_comm(rank: usize, size: usize) -> Result<Communicator, ProcessError> {
        let addr: SocketAddr = format!("127.0.0.1:{}", 8000 + rank)
            .parse()
            .map_err(|e| ProcessError::ConfigError(format!("Invalid address: {}", e)))?;
        let info = ProcessInfo::new(rank, size, addr, format!("localhost-{}", rank))?;
        let ranks: Vec<usize> = (0..size).collect();
        let group = ProcessGroup::new(ranks)?;
        let mut addresses = HashMap::new();
        for i in 0..size {
            let peer_addr: SocketAddr = format!("127.0.0.1:{}", 8000 + i)
                .parse()
                .map_err(|e| ProcessError::ConfigError(format!("Invalid address: {}", e)))?;
            addresses.insert(i, peer_addr);
        }
        Communicator::new(info, group, addresses)
    }

    #[test]
    fn test_sharding_strategy_serialization() {
        let strategies = [
            ShardingStrategy::Block,
            ShardingStrategy::Cyclic,
            ShardingStrategy::Random,
        ];
        for strategy in strategies {
            let bytes = oxicode::encode_to_vec(&strategy).expect("serialization failed");
            let (deserialized, _) = oxicode::decode_from_slice::<ShardingStrategy>(&bytes)
                .expect("deserialization failed");
            assert_eq!(strategy, deserialized);
        }
    }

    #[test]
    fn test_gradient_aggregation_eq() {
        assert_eq!(
            GradientAggregation::AllReduce,
            GradientAggregation::AllReduce
        );
        assert_ne!(
            GradientAggregation::AllReduce,
            GradientAggregation::RingAllReduce
        );
    }

    #[test]
    fn test_data_loader_block_sharding() {
        let data: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let local =
            DistributedDataLoader::<f32>::shard_data(data, 0, 4, ShardingStrategy::Block)
                .expect("sharding failed");
        assert_eq!(local.len(), 25);
        assert_eq!(local[0], 0.0);
        assert_eq!(local[24], 24.0);
    }

    #[test]
    fn test_block_sharding_more_ranks_than_items() {
        let data = vec![10, 20];
        let shards: Vec<Vec<i32>> = (0..4)
            .map(|rank| {
                DistributedDataLoader::<i32>::shard_data(
                    data.clone(),
                    rank,
                    4,
                    ShardingStrategy::Block,
                )
                .expect("sharding failed")
            })
            .collect();
        assert_eq!(shards, vec![vec![10], vec![20], vec![], vec![]]);
    }

    #[test]
    fn test_data_loader_cyclic_sharding() {
        let data: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let local =
            DistributedDataLoader::<f32>::shard_data(data, 0, 4, ShardingStrategy::Cyclic)
                .expect("sharding failed");
        assert_eq!(local.len(), 25);
        assert_eq!(local[0], 0.0);
        assert_eq!(local[1], 4.0);
        assert_eq!(local[2], 8.0);
    }

    #[test]
    fn test_data_loader_random_sharding() {
        let data: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let local =
            DistributedDataLoader::<f32>::shard_data(data, 0, 4, ShardingStrategy::Random)
                .expect("sharding failed");
        assert!(!local.is_empty());
        assert!(local.len() <= 100);
    }

    /// Every strategy must assign each item to exactly one rank.
    #[test]
    fn test_sharding_strategies_partition_data() {
        let world_size = 3;
        let data: Vec<usize> = (0..10).collect();
        for strategy in [
            ShardingStrategy::Block,
            ShardingStrategy::Cyclic,
            ShardingStrategy::Random,
        ] {
            let mut seen: Vec<usize> = Vec::new();
            for rank in 0..world_size {
                let shard = DistributedDataLoader::<usize>::shard_data(
                    data.clone(),
                    rank,
                    world_size,
                    strategy,
                )
                .expect("sharding failed");
                seen.extend(shard);
            }
            seen.sort_unstable();
            assert_eq!(
                seen, data,
                "{:?} must assign every item to exactly one rank",
                strategy
            );
        }
    }

    #[test]
    fn test_invalid_batch_size() {
        let comm = create_mock_comm(0, 1).expect("Failed to create mock communicator");
        let data = vec![1.0; 100];
        let result = DistributedDataLoader::new(data, 0, &comm, ShardingStrategy::Block);
        match result.err() {
            Some(DataParallelError::InvalidBatchSize(0)) => (),
            _ => panic!("Expected InvalidBatchSize error"),
        }
    }

    #[test]
    fn test_data_loader_num_batches() {
        let comm = create_mock_comm(0, 1).expect("Failed to create mock communicator");
        let data = vec![1.0; 100];
        let dl = DistributedDataLoader::new(data, 32, &comm, ShardingStrategy::Block)
            .expect("loader creation failed");
        assert_eq!(dl.num_batches(), 4);
    }

    #[test]
    fn test_data_loader_sizes() {
        let comm = create_mock_comm(0, 1).expect("Failed to create mock communicator");
        let data = vec![1.0; 100];
        let dl = DistributedDataLoader::new(data, 32, &comm, ShardingStrategy::Block)
            .expect("loader creation failed");
        assert_eq!(dl.local_size(), 100);
        assert_eq!(dl.global_size(), 100);
    }

    #[test]
    fn test_data_loader_next_batch() {
        let comm = create_mock_comm(0, 1).expect("Failed to create mock communicator");
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut loader = DistributedDataLoader::new(data, 2, &comm, ShardingStrategy::Block)
            .expect("loader creation failed");
        assert_eq!(loader.next_batch(), Some(vec![1.0, 2.0]));
        assert_eq!(loader.next_batch(), Some(vec![3.0, 4.0]));
        assert_eq!(loader.next_batch(), Some(vec![5.0]));
        assert!(loader.next_batch().is_none());
    }

    #[test]
    fn test_data_loader_iter_yields_all_batches() {
        let comm = create_mock_comm(0, 1).expect("Failed to create mock communicator");
        let data: Vec<i32> = (1..=5).collect();
        let mut loader = DistributedDataLoader::new(data, 2, &comm, ShardingStrategy::Cyclic)
            .expect("loader creation failed");
        let batches: Vec<Vec<i32>> = loader.iter().collect();
        assert_eq!(batches, vec![vec![1, 2], vec![3, 4], vec![5]]);
        assert!(loader.next_batch().is_none());
    }

    #[test]
    fn test_data_loader_reset() {
        let comm = create_mock_comm(0, 1).expect("Failed to create mock communicator");
        let data = vec![1.0, 2.0, 3.0];
        let mut loader = DistributedDataLoader::new(data, 2, &comm, ShardingStrategy::Block)
            .expect("loader creation failed");
        let _ = loader.next_batch();
        let _ = loader.next_batch();
        assert!(loader.next_batch().is_none());
        loader.reset();
        assert_eq!(loader.next_batch(), Some(vec![1.0, 2.0]));
    }
}