use super::{api, DistributedDataParallelTrait, DistributedScalar, ReduceOp};
use crate::autograd::Variable;
use crate::error::{RusTorchError, RusTorchResult};
use crate::nn::Module;
use crate::tensor::Tensor;
use std::marker::PhantomData;
use std::sync::{Arc, Mutex};
/// A minimal data-parallel wrapper around a [`Module`].
///
/// The wrapped module lives behind an `Arc<Mutex<_>>` so that callers can
/// obtain a shared handle to it while the wrapper keeps running forward passes.
pub struct SimpleDistributedDataParallel<
    T: DistributedScalar,
    M: Module<T> + Send + Sync + 'static,
> {
    /// Shared, lock-protected handle to the wrapped module.
    module: Arc<Mutex<M>>,
    /// Device ids this wrapper was configured with.
    device_ids: Vec<usize>,
    /// Whether gradient synchronization is enabled (toggled by callers).
    sync_gradients: bool,
    /// Ties the scalar type `T` to the wrapper without storing a value of it.
    _phantom: PhantomData<T>,
}
impl<T, M> SimpleDistributedDataParallel<T, M>
where
    T: DistributedScalar,
    M: Module<T> + Send + Sync + 'static,
{
    /// Wraps `module` for data-parallel execution.
    ///
    /// When `device_ids` is `None`, the wrapper defaults to device `0`.
    ///
    /// # Errors
    ///
    /// Returns a distributed error if the process group has not been
    /// initialized (i.e. `api::is_initialized()` is `false`).
    pub fn new(module: M, device_ids: Option<Vec<usize>>) -> RusTorchResult<Self> {
        if !api::is_initialized() {
            return Err(RusTorchError::distributed(
                "Distributed not initialized. Call distributed::init_process_group() first.",
            ));
        }
        let device_ids = device_ids.unwrap_or_else(|| vec![0]);
        Ok(Self {
            module: Arc::new(Mutex::new(module)),
            device_ids,
            sync_gradients: true,
            _phantom: PhantomData,
        })
    }

    /// Runs the wrapped module's forward pass on `input`.
    ///
    /// The module mutex is held only for the duration of this call.
    ///
    /// # Errors
    ///
    /// Returns a distributed error if the module mutex is poisoned, which
    /// happens when another holder of the lock panicked. The original code
    /// panicked here instead.
    pub fn forward(&self, input: &Variable<T>) -> RusTorchResult<Variable<T>> {
        let module = self.module.lock().map_err(|_| {
            RusTorchError::distributed(
                "Module mutex poisoned: a thread panicked while holding the module lock",
            )
        })?;
        Ok(module.forward(input))
    }

    /// Synchronizes gradients across the process group.
    ///
    /// NOTE(review): this is currently a placeholder. It performs no
    /// communication and always returns `Ok(())`, regardless of the value set
    /// via `set_gradient_sync`.
    pub fn sync_gradients(&self) -> RusTorchResult<()> {
        Ok(())
    }

    /// Returns a shared handle to the wrapped module.
    pub fn module(&self) -> Arc<Mutex<M>> {
        Arc::clone(&self.module)
    }

    /// Returns the device ids this wrapper was configured with.
    pub fn device_ids(&self) -> &[usize] {
        &self.device_ids
    }

    /// Returns whether gradient synchronization is enabled.
    pub fn gradient_sync_enabled(&self) -> bool {
        self.sync_gradients
    }

    /// Enables or disables gradient synchronization.
    pub fn set_gradient_sync(&mut self, enabled: bool) {
        self.sync_gradients = enabled;
    }
}
impl<T, M> DistributedDataParallelTrait<T> for SimpleDistributedDataParallel<T, M>
where
    T: DistributedScalar,
    M: Module<T> + Send + Sync + 'static,
{
    /// Delegates to the inherent `forward` method.
    fn distributed_forward(&self, input: &Variable<T>) -> RusTorchResult<Variable<T>> {
        self.forward(input)
    }

    /// Delegates to the inherent `sync_gradients` method.
    ///
    /// In method-call resolution, inherent methods take precedence over trait
    /// methods, so this call does not recurse into itself.
    fn sync_gradients(&self) -> RusTorchResult<()> {
        self.sync_gradients()
    }

    /// Returns the configured device ids as a slice.
    fn device_ids(&self) -> &[usize] {
        self.device_ids.as_slice()
    }
}
/// Convenience constructor that wraps `module` in a
/// [`SimpleDistributedDataParallel`].
///
/// This is equivalent to calling `SimpleDistributedDataParallel::new` directly.
///
/// # Errors
///
/// Propagates any error returned by `SimpleDistributedDataParallel::new`.
pub fn wrap_simple<T, M>(
    module: M,
    device_ids: Option<Vec<usize>>,
) -> RusTorchResult<SimpleDistributedDataParallel<T, M>>
where
    T: DistributedScalar,
    M: Module<T> + Send + Sync + 'static,
{
    SimpleDistributedDataParallel::<T, M>::new(module, device_ids)
}
#[cfg(test)]
mod tests {
    use super::super::DistributedBackend;
    use super::*;
    use crate::nn::Linear;

    /// Serializes the tests in this module.
    ///
    /// They touch shared state: environment variables, which are
    /// process-global, and the distributed process group, which is presumably
    /// global as well (TODO confirm). The default test harness runs tests in
    /// parallel, so without this lock `test_simple_ddp_creation` could observe
    /// a group initialized by `test_device_ids` and fail spuriously.
    static PROCESS_GROUP_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires the module-wide test lock.
    ///
    /// Poisoning is ignored so that one failed test does not cascade into
    /// failures of the others.
    fn serial_guard() -> std::sync::MutexGuard<'static, ()> {
        PROCESS_GROUP_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    #[test]
    fn test_simple_ddp_creation() {
        let _guard = serial_guard();
        let linear: Linear<f32> = Linear::new(10, 5);
        let ddp_result = SimpleDistributedDataParallel::new(linear, Some(vec![0]));
        // Without an initialized process group, construction must fail.
        assert!(ddp_result.is_err());
    }

    #[test]
    fn test_device_ids() {
        let _guard = serial_guard();
        std::env::set_var("RANK", "0");
        std::env::set_var("WORLD_SIZE", "1");
        std::env::set_var("MASTER_ADDR", "localhost");
        std::env::set_var("MASTER_PORT", "29510");
        // Best-effort: initialization may fail in environments without networking.
        let _ = api::init_process_group(DistributedBackend::TCP, None, None, None, None);
        let linear: Linear<f32> = Linear::new(5, 3);
        if let Ok(ddp) = SimpleDistributedDataParallel::new(linear, Some(vec![0, 1])) {
            assert_eq!(ddp.device_ids(), &[0, 1]);
        }
        let _ = api::destroy_process_group();
    }
}