use super::{
api, DistributedDataParallelTrait, DistributedOps, DistributedScalar, ProcessGroup, ReduceOp,
};
use crate::autograd::Variable;
use crate::error::{RusTorchError, RusTorchResult};
use crate::nn::Module;
use crate::tensor::Tensor;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Data-parallel wrapper around a [`Module`].
///
/// Its configuration options mirror the names used by PyTorch's
/// `torch.nn.parallel.DistributedDataParallel`. The wrapped module is kept
/// behind an `Arc<Mutex<_>>` so callers can obtain a shared handle to it via
/// `module()`. Gradients are averaged across the process group with
/// `api::all_reduce` in `sync_gradients`.
#[derive(Debug)]
pub struct DistributedDataParallel<T: DistributedScalar, M: Module<T>> {
    /// The wrapped module; shared with anyone holding the `Arc` from `module()`.
    module: Arc<Mutex<M>>,
    /// Group passed (as `Option<&ProcessGroup>`) to `api::all_reduce`.
    /// NOTE(review): `None` presumably selects the default group — confirm in `api`.
    process_group: Option<ProcessGroup>,
    /// Device indices for this replica; `new` defaults this to `[0]`.
    device_ids: Vec<usize>,
    /// Requested output device.
    /// NOTE(review): stored but not read anywhere in this section.
    output_device: Option<usize>,
    /// Gradient bucket capacity in megabytes; `new` defaults this to 25.
    /// NOTE(review): not currently read by any code in this section.
    bucket_cap_mb: usize,
    /// Stored configuration flag (default `false`); not read in this section.
    find_unused_parameters: bool,
    /// Stored configuration flag (default `false`); not read in this section.
    gradient_as_bucket_view: bool,
    /// Stored configuration flag (default `false`); not read in this section.
    static_graph: bool,
    /// Gradient bookkeeping (accumulated gradients and buckets).
    /// NOTE(review): initialized empty in `new` and not otherwise used here.
    gradient_state: Arc<Mutex<GradientState<T>>>,
}
/// Gradient bookkeeping owned by a `DistributedDataParallel` instance, which
/// holds it behind an `Arc<Mutex<_>>`.
#[derive(Debug)]
struct GradientState<T: DistributedScalar> {
    /// Gradients keyed by parameter name; `create_buckets` partitions these.
    accumulated_grads: HashMap<String, Tensor<T>>,
    /// Flag for whether gradients may be synchronized.
    /// NOTE(review): only ever initialized to `false` in the visible code.
    ready_for_sync: bool,
    /// Buckets produced by `create_buckets`.
    buckets: Vec<GradientBucket<T>>,
}
/// A group of parameters whose combined gradient size stays within a cap
/// (see `GradientState::create_buckets`). Presumably intended to be
/// all-reduced as a unit — TODO confirm once bucketed sync is wired up.
#[derive(Debug, Clone)]
struct GradientBucket<T: DistributedScalar> {
    /// Names of the parameters in this bucket (keys of `accumulated_grads`).
    parameters: Vec<String>,
    /// Flattened bucket gradient; always `None` in the visible code.
    gradient: Option<Tensor<T>>,
    /// Sum of `numel() * size_of::<T>()` over the bucket's gradients.
    size_bytes: usize,
}
impl<T: DistributedScalar, M: Module<T> + Send + Sync + 'static> DistributedDataParallel<T, M> {
    /// Wraps `module` for data-parallel training.
    ///
    /// The argument list mirrors PyTorch's `DistributedDataParallel`
    /// constructor. Defaults applied when an option is `None`:
    /// - `device_ids`: `[0]`
    /// - `bucket_cap_mb`: `25`
    /// - `find_unused_parameters`, `gradient_as_bucket_view`, `static_graph`: `false`
    ///
    /// `dim`, `broadcast_buffers` and `check_reduction` are accepted for API
    /// compatibility but are currently ignored.
    ///
    /// # Errors
    /// Returns a distributed error if the distributed process group has not
    /// been initialized (`api::is_initialized()` is false).
    // The long parameter list deliberately mirrors the PyTorch API.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        module: M,
        device_ids: Option<Vec<usize>>,
        output_device: Option<usize>,
        dim: Option<usize>,
        broadcast_buffers: bool,
        process_group: Option<ProcessGroup>,
        bucket_cap_mb: Option<usize>,
        find_unused_parameters: Option<bool>,
        check_reduction: Option<bool>,
        gradient_as_bucket_view: Option<bool>,
        static_graph: Option<bool>,
    ) -> RusTorchResult<Self> {
        if !api::is_initialized() {
            return Err(RusTorchError::distributed(
                "Distributed process group not initialized. Call distributed::init_process_group() first."
            ));
        }
        let device_ids = device_ids.unwrap_or_else(|| vec![0]);
        let bucket_cap_mb = bucket_cap_mb.unwrap_or(25);
        let find_unused_parameters = find_unused_parameters.unwrap_or(false);
        let gradient_as_bucket_view = gradient_as_bucket_view.unwrap_or(false);
        let static_graph = static_graph.unwrap_or(false);
        let gradient_state = Arc::new(Mutex::new(GradientState {
            accumulated_grads: HashMap::new(),
            ready_for_sync: false,
            buckets: Vec::new(),
        }));
        // Accepted for signature compatibility only; not used yet.
        let _ = (dim, broadcast_buffers, check_reduction);
        Ok(Self {
            module: Arc::new(Mutex::new(module)),
            process_group,
            device_ids,
            output_device,
            bucket_cap_mb,
            find_unused_parameters,
            gradient_as_bucket_view,
            static_graph,
            gradient_state,
        })
    }

    /// Runs the wrapped module's forward pass on `input`.
    ///
    /// # Errors
    /// Returns a distributed error if the module mutex is poisoned (a previous
    /// holder panicked), or propagates any error from hook registration.
    pub fn forward(&self, input: &Variable<T>) -> RusTorchResult<Variable<T>> {
        // The guard is a temporary, so the module lock is released as soon as
        // the forward pass returns.
        let output = self.lock_module()?.forward(input);
        self.register_grad_hooks()?;
        Ok(output)
    }

    /// Hook-registration point for automatic gradient synchronization.
    ///
    /// Currently a no-op: gradients are synchronized explicitly via
    /// [`Self::sync_gradients`].
    fn register_grad_hooks(&self) -> RusTorchResult<()> {
        Ok(())
    }

    /// Averages every parameter's gradient across the process group.
    ///
    /// Parameters without a gradient are skipped. Each present gradient is
    /// all-reduced in place with [`ReduceOp::Average`].
    ///
    /// # Errors
    /// Returns a distributed error if the module mutex or a gradient lock is
    /// poisoned, and propagates any error from `api::all_reduce`.
    pub fn sync_gradients(&self) -> RusTorchResult<()> {
        let module = self.lock_module()?;
        for param in module.parameters() {
            let grad_lock = param.grad();
            let mut grad_guard = grad_lock.write().map_err(|_| {
                RusTorchError::distributed("Gradient lock poisoned during gradient synchronization")
            })?;
            if let Some(ref mut grad) = *grad_guard {
                api::all_reduce(grad, ReduceOp::Average, self.process_group.as_ref(), false)?;
            }
        }
        Ok(())
    }

    /// Returns a shared handle to the wrapped module.
    pub fn module(&self) -> Arc<Mutex<M>> {
        Arc::clone(&self.module)
    }

    /// Device indices this wrapper was configured with.
    pub fn device_ids(&self) -> &[usize] {
        &self.device_ids
    }

    /// Marker used to identify DDP-wrapped modules; always `true`.
    pub fn is_ddp_module() -> bool {
        true
    }

    /// Locks the wrapped module, turning mutex poisoning into an error
    /// instead of a panic.
    fn lock_module(&self) -> RusTorchResult<std::sync::MutexGuard<'_, M>> {
        self.module.lock().map_err(|_| {
            RusTorchError::distributed("DistributedDataParallel module mutex is poisoned")
        })
    }
}
impl<T: DistributedScalar, M: Module<T> + Send + Sync + 'static> Module<T>
    for DistributedDataParallel<T, M>
{
    /// Infallible forward pass required by [`Module`].
    ///
    /// Delegates to the fallible inherent `forward`. Because this trait method
    /// cannot return an error, any error is replaced by a zero tensor of
    /// shape `[1]`. Callers that need to observe failures should use
    /// `DistributedDataParallelTrait::distributed_forward` instead.
    fn forward(&self, input: &Variable<T>) -> Variable<T> {
        // Inherent methods take precedence over trait methods in method-call
        // syntax, so this calls the fallible inherent `forward`, not itself.
        self.forward(input)
            .unwrap_or_else(|_| Variable::new(Tensor::zeros(&[1]), false))
    }

    /// Parameters of the wrapped module.
    ///
    /// # Panics
    /// Panics if the module mutex is poisoned.
    fn parameters(&self) -> Vec<Variable<T>> {
        let module = self
            .module
            .lock()
            .expect("DistributedDataParallel module mutex poisoned");
        module.parameters()
    }

    /// Switches the wrapped module to evaluation mode.
    ///
    /// Previously this was a no-op, so the inner module silently stayed in
    /// its current mode. The call goes through the mutex because the module
    /// may be shared via `module()`.
    ///
    /// # Panics
    /// Panics if the module mutex is poisoned.
    fn eval(&mut self) {
        self.module
            .lock()
            .expect("DistributedDataParallel module mutex poisoned")
            .eval();
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
impl<T: DistributedScalar, M: Module<T> + Send + Sync + 'static> DistributedDataParallelTrait<T>
    for DistributedDataParallel<T, M>
{
    /// Device indices this wrapper was configured with.
    fn device_ids(&self) -> &[usize] {
        self.device_ids.as_slice()
    }

    /// Fallible forward pass; delegates to the inherent `forward`.
    fn distributed_forward(&self, input: &Variable<T>) -> RusTorchResult<Variable<T>> {
        // Method-call syntax picks the inherent method before any trait
        // method of the same name.
        let output = self.forward(input)?;
        Ok(output)
    }

    /// Averages gradients across the process group; delegates to the
    /// inherent `sync_gradients`.
    fn sync_gradients(&self) -> RusTorchResult<()> {
        self.sync_gradients()?;
        Ok(())
    }
}
impl<T: DistributedScalar> GradientState<T> {
fn create_buckets(&mut self, bucket_size_mb: usize) -> RusTorchResult<()> {
let bucket_size_bytes = bucket_size_mb * 1024 * 1024;
let mut current_bucket = GradientBucket {
parameters: Vec::new(),
gradient: None,
size_bytes: 0,
};
for param_name in self.accumulated_grads.keys() {
if let Some(grad) = self.accumulated_grads.get(param_name) {
let grad_size = grad.numel() * std::mem::size_of::<T>();
if current_bucket.size_bytes + grad_size > bucket_size_bytes
&& !current_bucket.parameters.is_empty()
{
self.buckets.push(current_bucket.clone());
current_bucket = GradientBucket {
parameters: Vec::new(),
gradient: None,
size_bytes: 0,
};
}
current_bucket.parameters.push(param_name.clone());
current_bucket.size_bytes += grad_size;
}
}
if !current_bucket.parameters.is_empty() {
self.buckets.push(current_bucket);
}
Ok(())
}
}
pub fn wrap_module<T: DistributedScalar, M: Module<T> + Send + Sync + 'static>(
    module: M,
    device_ids: Option<Vec<usize>>,
) -> RusTorchResult<DistributedDataParallel<T, M>> {
    //! Wraps `module` in a `DistributedDataParallel` with default settings.
    //!
    //! Only `device_ids` is configurable. `broadcast_buffers` is `true` and
    //! every other optional argument is `None`. Errors from
    //! `DistributedDataParallel::new` (e.g. an uninitialized process group)
    //! are propagated.
    let broadcast_buffers = true;
    DistributedDataParallel::new(
        module,
        device_ids,
        None, // output_device
        None, // dim
        broadcast_buffers,
        None, // process_group
        None, // bucket_cap_mb
        None, // find_unused_parameters
        None, // check_reduction
        None, // gradient_as_bucket_view
        None, // static_graph
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nn::Linear;

    #[test]
    fn test_ddp_creation() {
        // Assumes nothing in this test binary has initialized the process
        // group, so construction must be rejected.
        let layer: Linear<f32> = Linear::new(10, 5);
        let outcome = DistributedDataParallel::new(
            layer,
            Some(vec![0]),
            None,
            None,
            true,
            None,
            None,
            None,
            None,
            None,
            None,
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn test_is_ddp_module() {
        let flagged = DistributedDataParallel::<f32, Linear<f32>>::is_ddp_module();
        assert!(flagged);
    }
}