use crate::autograd::Variable;
use crate::distributed::get_distributed_state;
use crate::error::{RusTorchError, RusTorchResult};
use crate::gpu::DeviceType;
use crate::nn::Module;
use crate::tensor::Tensor;
use num_traits::Float;
use std::sync::Arc;
/// Data-parallel wrapper around a module.
///
/// Holds the wrapped module behind an `Arc`, the list of devices it is
/// configured for, and the strategy used when synchronizing gradients.
#[derive(Debug)]
pub struct DataParallel<T, M>
where
T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
M: Module<T> + Send + Sync,
{
/// The wrapped module, shared through an `Arc`.
module: Arc<M>,
/// Devices this wrapper is configured for; `num_devices` reports its length.
device_ids: Vec<DeviceType>,
/// Strategy that `sync_gradients` dispatches on; `new` starts with `Synchronous`.
sync_strategy: GradientSyncStrategy,
/// Ties the scalar type `T` to the struct, which stores no `T` value directly.
_phantom: std::marker::PhantomData<T>,
}
/// Selects which gradient synchronization routine `DataParallel::sync_gradients`
/// dispatches to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GradientSyncStrategy {
/// Synchronous path; in this file it fails unless a process group is initialized.
Synchronous,
/// Asynchronous path; in this file it currently does nothing and succeeds.
Asynchronous,
/// Local-SGD path; in this file it currently does nothing and succeeds.
LocalSGD {
/// Synchronization interval, presumably counted in steps — NOTE(review): the
/// dispatcher in this file ignores this value; confirm the intended unit.
sync_frequency: usize,
},
}
impl<T, M> DataParallel<T, M>
where
T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
M: Module<T> + Send + Sync + 'static,
{
pub fn new(module: M, device_ids: Vec<DeviceType>) -> Self {
Self {
module: Arc::new(module),
device_ids,
sync_strategy: GradientSyncStrategy::Synchronous,
_phantom: std::marker::PhantomData,
}
}
pub fn set_sync_strategy(&mut self, strategy: GradientSyncStrategy) {
self.sync_strategy = strategy;
}
pub fn num_devices(&self) -> usize {
self.device_ids.len()
}
pub fn replicate_input(&self, input: &Variable<T>) -> RusTorchResult<Vec<Variable<T>>> {
let batch_size = input.data().read().unwrap().shape()[0];
let chunk_size = batch_size.div_ceil(self.device_ids.len());
let mut replicated_inputs = Vec::new();
for (i, _device) in self.device_ids.iter().enumerate() {
let start_idx = i * chunk_size;
let end_idx = ((i + 1) * chunk_size).min(batch_size);
if start_idx < batch_size {
let chunk_shape = {
let input_data = input.data();
let data_guard = input_data.read().unwrap();
let mut shape = data_guard.shape().to_vec();
shape[0] = end_idx - start_idx;
shape
};
let chunk_tensor = Tensor::zeros(&chunk_shape);
let chunk_var = Variable::new(chunk_tensor, input.requires_grad());
replicated_inputs.push(chunk_var);
}
}
Ok(replicated_inputs)
}
pub fn gather_outputs(&self, outputs: Vec<Variable<T>>) -> RusTorchResult<Variable<T>> {
if outputs.is_empty() {
return Err(RusTorchError::ProcessGroupError("No outputs to gather"));
}
let total_batch_size: usize = outputs
.iter()
.map(|o| o.data().read().unwrap().shape()[0])
.sum();
let mut output_shape = outputs[0].data().read().unwrap().shape().to_vec();
output_shape[0] = total_batch_size;
let output_tensor = Tensor::zeros(&output_shape);
let output_var = Variable::new(output_tensor, outputs[0].requires_grad());
Ok(output_var)
}
pub fn sync_gradients(&self) -> RusTorchResult<()> {
match self.sync_strategy {
GradientSyncStrategy::Synchronous => self.sync_gradients_sync(),
GradientSyncStrategy::Asynchronous => self.sync_gradients_async(),
GradientSyncStrategy::LocalSGD { sync_frequency: _ } => self.sync_gradients_local_sgd(),
}
}
fn sync_gradients_sync(&self) -> RusTorchResult<()> {
let state = get_distributed_state();
let state_guard = state.lock().unwrap();
if let Some(_pg) = &state_guard.process_group {
drop(state_guard);
Ok(())
} else {
Err(RusTorchError::ProcessGroupError(
"Process group not initialized",
))
}
}
fn sync_gradients_async(&self) -> RusTorchResult<()> {
Ok(())
}
fn sync_gradients_local_sgd(&self) -> RusTorchResult<()> {
Ok(())
}
}
impl<T, M> Module<T> for DataParallel<T, M>
where
    T: Float
        + Send
        + Sync
        + 'static
        + std::fmt::Debug
        + ndarray::ScalarOperand
        + num_traits::FromPrimitive,
    M: Module<T> + Send + Sync + 'static,
{
    /// Splits `input` into per-device chunks, runs the wrapped module on each
    /// chunk in turn, and gathers the results into one variable.
    ///
    /// Failures are absorbed rather than reported: if splitting fails the
    /// module runs once on a clone of the whole input, and if gathering fails
    /// a clone of `input` is returned.
    fn forward(&self, input: &Variable<T>) -> Variable<T> {
        let shards = match self.replicate_input(input) {
            Ok(shards) => shards,
            Err(_) => vec![input.clone()],
        };

        let mut per_shard_outputs: Vec<Variable<T>> = Vec::with_capacity(shards.len());
        for shard in &shards {
            per_shard_outputs.push(self.module.forward(shard));
        }

        match self.gather_outputs(per_shard_outputs) {
            Ok(gathered) => gathered,
            Err(_) => input.clone(),
        }
    }

    /// Parameters of the wrapped module.
    fn parameters(&self) -> Vec<Variable<T>> {
        self.module.parameters()
    }

    /// Exposes this wrapper as `Any` for downcasting.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
/// Owns a dataset (samples plus labels) and hands out batched iterators over it.
pub struct DistributedDataLoader<T: Float> {
/// Sample tensors, one entry per sample.
data: Vec<Tensor<T>>,
/// Label tensors, looked up with the same index as the matching entry in `data`.
labels: Vec<Tensor<T>>,
/// Maximum number of samples per batch.
batch_size: usize,
/// Epoch set through `set_epoch`; passed on to each iterator created by `iter`.
current_epoch: usize,
/// Whether iterators visit the samples in shuffled order.
shuffle: bool,
/// Optional base seed passed on to each iterator for shuffling.
seed: Option<u64>,
}
impl<T: Float + Send + Sync + 'static> DistributedDataLoader<T> {
    /// Creates a loader over `data` and `labels`, starting at epoch 0.
    ///
    /// `labels` is indexed with the same positions as `data`, so it should
    /// hold at least as many entries.
    ///
    /// # Panics
    ///
    /// Panics if `batch_size` is zero. A zero batch size would otherwise make
    /// [`Self::len`] divide by zero and make iteration yield empty batches
    /// forever.
    pub fn new(
        data: Vec<Tensor<T>>,
        labels: Vec<Tensor<T>>,
        batch_size: usize,
        shuffle: bool,
        seed: Option<u64>,
    ) -> Self {
        assert!(batch_size > 0, "batch_size must be greater than zero");
        Self {
            data,
            labels,
            batch_size,
            current_epoch: 0,
            shuffle,
            seed,
        }
    }

    /// Sets the epoch handed to iterators created afterwards, which changes
    /// the shuffle order when shuffling is enabled.
    pub fn set_epoch(&mut self, epoch: usize) {
        self.current_epoch = epoch;
    }

    /// Creates an iterator over batches, configured with the loader's batch
    /// size, current epoch, shuffle flag and seed.
    pub fn iter(&self) -> DistributedDataIterator<'_, T> {
        DistributedDataIterator::new(
            &self.data,
            &self.labels,
            self.batch_size,
            self.current_epoch,
            self.shuffle,
            self.seed,
        )
    }

    /// Number of batches in one pass over the data (the last batch may be
    /// smaller than `batch_size`).
    pub fn len(&self) -> usize {
        // `batch_size` is guaranteed non-zero by `new`.
        self.data.len().div_ceil(self.batch_size)
    }

    /// Returns `true` when the loader holds no samples.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
}
/// Borrowing iterator that yields batches of sample and label references.
pub struct DistributedDataIterator<'a, T: Float> {
/// Borrowed sample tensors.
data: &'a [Tensor<T>],
/// Borrowed label tensors, indexed with the same positions as `data`.
labels: &'a [Tensor<T>],
/// Maximum number of samples per yielded batch.
batch_size: usize,
/// Position in `indices` where the next batch starts.
current_index: usize,
/// Order in which samples are visited; shuffled at construction when requested.
indices: Vec<usize>,
}
impl<'a, T: Float> DistributedDataIterator<'a, T> {
    /// Creates an iterator positioned at the first batch.
    ///
    /// The visiting order is `0..data.len()`. When `shuffle` is set, that
    /// order is permuted with an RNG seeded from `seed` (0 when absent) plus
    /// `epoch`, so the same seed and epoch always give the same order.
    fn new(
        data: &'a [Tensor<T>],
        labels: &'a [Tensor<T>],
        batch_size: usize,
        epoch: usize,
        shuffle: bool,
        seed: Option<u64>,
    ) -> Self {
        let mut indices: Vec<usize> = (0..data.len()).collect();

        if shuffle {
            use rand::rngs::StdRng;
            use rand::{Rng, SeedableRng};

            // Wrapping addition: a plain `+` panics in debug builds when a large
            // seed plus the epoch exceeds `u64::MAX`.
            let actual_seed = seed.unwrap_or(0).wrapping_add(epoch as u64);
            let mut rng = StdRng::seed_from_u64(actual_seed);

            // Fisher-Yates shuffle, kept hand-written so the order produced for a
            // given seed does not change.
            for i in (1..indices.len()).rev() {
                let j = rng.gen_range(0..=i);
                indices.swap(i, j);
            }
        }

        Self {
            data,
            labels,
            batch_size,
            current_index: 0,
            indices,
        }
    }
}
impl<'a, T: Float> Iterator for DistributedDataIterator<'a, T> {
    type Item = (Vec<&'a Tensor<T>>, Vec<&'a Tensor<T>>);

    /// Yields the next batch as `(samples, labels)` references, or `None` once
    /// every index has been consumed. The final batch may hold fewer than
    /// `batch_size` entries.
    ///
    /// A `batch_size` of zero never advances the cursor, so it would yield
    /// empty batches indefinitely.
    ///
    /// Panics if a visited index is out of range for `labels`.
    fn next(&mut self) -> Option<Self::Item> {
        let total = self.indices.len();
        let start = self.current_index;
        if start >= total {
            return None;
        }

        let stop = (start + self.batch_size).min(total);
        let window = &self.indices[start..stop];

        let mut batch_data: Vec<&'a Tensor<T>> = Vec::with_capacity(window.len());
        for &sample in window {
            batch_data.push(&self.data[sample]);
        }

        let mut batch_labels: Vec<&'a Tensor<T>> = Vec::with_capacity(window.len());
        for &sample in window {
            batch_labels.push(&self.labels[sample]);
        }

        self.current_index = stop;
        Some((batch_data, batch_labels))
    }
}
/// Assigns each replica (rank) its own share of the sample indices
/// `0..num_samples`.
pub struct DistributedSampler {
/// Total number of samples; indices are drawn from `0..num_samples`.
num_samples: usize,
/// Number of replicas the indices are divided among.
num_replicas: usize,
/// This replica's position; `new` rejects values that are not below `num_replicas`.
rank: usize,
/// Epoch set through `set_epoch`; added to `seed` when seeding the shuffle RNG.
epoch: usize,
/// When true the index list is cut down to a multiple of `num_replicas`;
/// otherwise it is padded up to one.
drop_last: bool,
/// Base seed for the shuffle RNG.
seed: u64,
}
impl DistributedSampler {
    /// Creates a sampler for `num_samples` samples, starting at epoch 0.
    ///
    /// When `num_replicas` or `rank` is `None`, the value is taken from the
    /// global distributed state, falling back to 1 replica and rank 0 when the
    /// state does not provide one.
    ///
    /// # Errors
    ///
    /// Returns a `ProcessGroupError` when `rank >= num_replicas`. This also
    /// rejects a replica count of zero.
    ///
    /// # Panics
    ///
    /// Panics if locking the distributed state fails.
    pub fn new(
        num_samples: usize,
        num_replicas: Option<usize>,
        rank: Option<usize>,
        drop_last: bool,
        seed: u64,
    ) -> RusTorchResult<Self> {
        let state = get_distributed_state();
        let state_guard = state.lock().unwrap();

        let num_replicas = num_replicas
            .or_else(|| state_guard.world_size())
            .unwrap_or(1);
        let rank = rank.or_else(|| state_guard.rank()).unwrap_or(0);

        if rank >= num_replicas {
            return Err(RusTorchError::ProcessGroupError(format!(
                "Rank {} is greater than or equal to num_replicas {}",
                rank, num_replicas
            )));
        }

        Ok(Self {
            num_samples,
            num_replicas,
            rank,
            epoch: 0,
            drop_last,
            seed,
        })
    }

    /// Sets the epoch, which changes the shuffle order of subsequent calls to
    /// [`Self::sample_indices`].
    pub fn set_epoch(&mut self, epoch: usize) {
        self.epoch = epoch;
    }

    /// Returns the indices assigned to this rank for the current epoch.
    ///
    /// The indices `0..num_samples` are shuffled with an RNG seeded from
    /// `seed` plus the epoch. The list is then cut down (`drop_last`) or padded
    /// by repeating itself from the start so its length is a multiple of
    /// `num_replicas`, and this rank's contiguous slice is returned.
    pub fn sample_indices(&self) -> Vec<usize> {
        use rand::rngs::StdRng;
        use rand::{Rng, SeedableRng};

        // Wrapping addition: a plain `+` panics in debug builds when a large
        // seed plus the epoch exceeds `u64::MAX`.
        let mut rng = StdRng::seed_from_u64(self.seed.wrapping_add(self.epoch as u64));

        let mut indices: Vec<usize> = (0..self.num_samples).collect();
        // Fisher-Yates shuffle, kept hand-written so the order produced for a
        // given seed does not change.
        for i in (1..indices.len()).rev() {
            let j = rng.gen_range(0..=i);
            indices.swap(i, j);
        }

        let total_size = self.len() * self.num_replicas;

        // Pad by cycling through the shuffled list from its start. Cycling an
        // empty list yields nothing, so this cannot spin.
        if indices.len() < total_size {
            let missing = total_size - indices.len();
            let padding: Vec<usize> = indices.iter().copied().cycle().take(missing).collect();
            indices.extend(padding);
        }
        indices.truncate(total_size);

        let num_samples_per_replica = total_size / self.num_replicas;
        let start = self.rank * num_samples_per_replica;
        let end = start + num_samples_per_replica;
        indices[start..end].to_vec()
    }

    /// Number of indices each replica receives per epoch.
    pub fn len(&self) -> usize {
        if self.drop_last {
            self.num_samples / self.num_replicas
        } else {
            self.num_samples.div_ceil(self.num_replicas)
        }
    }

    /// Returns `true` when this sampler yields no indices.
    ///
    /// Defined through [`Self::len`] so the two always agree, including with
    /// `drop_last` and fewer samples than replicas.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nn::Linear;

    /// Every strategy variant can be constructed and matched on.
    #[test]
    fn test_gradient_sync_strategy() {
        let strategies = [
            GradientSyncStrategy::Synchronous,
            GradientSyncStrategy::Asynchronous,
            GradientSyncStrategy::LocalSGD { sync_frequency: 10 },
        ];

        for strategy in strategies.iter() {
            match *strategy {
                GradientSyncStrategy::Synchronous => {}
                GradientSyncStrategy::Asynchronous => {}
                GradientSyncStrategy::LocalSGD { sync_frequency } => {
                    assert!(sync_frequency > 0)
                }
            }
        }
    }

    /// Three samples with a batch size of two give two batches, the first of
    /// which is full.
    #[test]
    fn test_distributed_data_loader() {
        let data: Vec<Tensor<f32>> = (0..3).map(|_| Tensor::<f32>::ones(&[10, 5])).collect();
        let labels: Vec<Tensor<f32>> = (0..3).map(|_| Tensor::<f32>::zeros(&[10, 1])).collect();

        let loader = DistributedDataLoader::new(data, labels, 2, true, Some(42));
        assert_eq!(loader.len(), 2);

        let first_batch = loader.iter().next();
        assert!(first_batch.is_some());

        let (batch_data, batch_labels) = first_batch.unwrap();
        assert_eq!(batch_data.len(), 2);
        assert_eq!(batch_labels.len(), 2);
    }

    /// Rank 0 of 4 replicas over 100 samples receives 25 in-range indices.
    #[test]
    fn test_distributed_sampler() {
        let created = DistributedSampler::new(100, Some(4), Some(0), false, 42);
        assert!(created.is_ok());

        let indices = created.unwrap().sample_indices();
        assert_eq!(indices.len(), 25);
        for idx in indices.iter().copied() {
            assert!(idx < 100);
        }
    }

    /// The wrapper reports as many devices as it was given.
    #[test]
    fn test_data_parallel_creation() {
        let devices = vec![DeviceType::Cpu, DeviceType::Cpu];
        let dp = DataParallel::new(Linear::<f32>::new(10, 5), devices);
        assert_eq!(dp.num_devices(), 2);
    }
}