use std::sync::Arc;
use ferrotorch_core::{Float, FerrotorchResult};
use ferrotorch_nn::{Module, Parameter};
use crate::backend::Backend;
use crate::collective::{allreduce, ReduceOp};
/// Wrapper that pairs a module with the communication backend used by the
/// collective operations in this file.
///
/// The wrapper owns the module and holds a shared handle to the backend.
/// `T` is the scalar type of the wrapped module's tensors.
pub struct DDP<M, T>
where
    M: Module<T>,
    T: Float,
{
    /// The wrapped module.
    module: M,
    /// Backend handed to the collective operations (`allreduce`, `broadcast`).
    backend: Arc<dyn Backend>,
    /// Ties the scalar type `T` to the struct; no `T` value is stored.
    _marker: std::marker::PhantomData<T>,
}
impl<M, T> DDP<M, T>
where
    M: Module<T>,
    T: Float,
{
    /// Builds a wrapper around `module` that communicates through `backend`.
    pub fn new(module: M, backend: Arc<dyn Backend>) -> Self {
        let _marker = std::marker::PhantomData;
        Self {
            module,
            backend,
            _marker,
        }
    }

    /// Shared access to the wrapped module.
    pub fn module(&self) -> &M {
        let Self { module, .. } = self;
        module
    }

    /// Exclusive access to the wrapped module.
    pub fn module_mut(&mut self) -> &mut M {
        let Self { module, .. } = self;
        module
    }

    /// Consumes the wrapper and hands back the wrapped module.
    ///
    /// The wrapper's handle to the backend is dropped.
    pub fn into_inner(self) -> M {
        let Self { module, .. } = self;
        module
    }

    /// The backend handle this wrapper was constructed with.
    pub fn backend(&self) -> &Arc<dyn Backend> {
        &self.backend
    }

    /// Replaces each parameter's gradient with the `ReduceOp::Mean`
    /// all-reduce of that gradient over the backend.
    ///
    /// Parameters are visited in the order returned by the module's
    /// `parameters()`. A parameter whose gradient is `None` is skipped.
    ///
    /// NOTE(review): a skipped parameter performs no `allreduce` call. If
    /// `allreduce` needs every rank to take part, all ranks must agree on
    /// which parameters carry gradients — confirm against the collective's
    /// contract.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while reading a gradient, running
    /// the all-reduce, or writing the reduced gradient back.
    pub fn sync_gradients(&self) -> FerrotorchResult<()> {
        let backend: &dyn Backend = self.backend.as_ref();
        for param in self.module.parameters() {
            let tensor = param.tensor();
            let local_grad = match tensor.grad()? {
                Some(found) => found,
                None => continue,
            };
            let averaged = allreduce(&local_grad, backend, ReduceOp::Mean)?;
            tensor.set_grad(Some(averaged))?;
        }
        Ok(())
    }

    /// Broadcasts every parameter's tensor from rank `root` and installs the
    /// received tensor as the new parameter.
    ///
    /// Each parameter slot is overwritten with `Parameter::new(received)`.
    ///
    /// NOTE(review): whether the fresh `Parameter` keeps any state of the
    /// parameter it replaces (gradient, flags) depends on `Parameter::new`
    /// — confirm.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by the broadcast; parameters visited
    /// before the failure have already been replaced.
    pub fn broadcast_parameters(&mut self, root: usize) -> FerrotorchResult<()> {
        let backend: &dyn Backend = self.backend.as_ref();
        for slot in self.module.parameters_mut() {
            let local = slot.tensor().clone();
            let received = crate::collective::broadcast(&local, backend, root)?;
            *slot = Parameter::new(received);
        }
        Ok(())
    }
}
/// `Module` implementation that forwards every method to the wrapped module
/// without touching the backend.
impl<M, T> Module<T> for DDP<M, T>
where
    M: Module<T>,
    T: Float,
{
    /// Runs the wrapped module's forward pass on `input`.
    fn forward(
        &self,
        input: &ferrotorch_core::Tensor<T>,
    ) -> FerrotorchResult<ferrotorch_core::Tensor<T>> {
        M::forward(&self.module, input)
    }

    /// Parameters of the wrapped module.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        M::parameters(&self.module)
    }

    /// Mutable parameters of the wrapped module.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        M::parameters_mut(&mut self.module)
    }

    /// Named parameters of the wrapped module, with names unchanged.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        M::named_parameters(&self.module)
    }

    /// Calls `train` on the wrapped module.
    fn train(&mut self) {
        M::train(&mut self.module);
    }

    /// Calls `eval` on the wrapped module.
    fn eval(&mut self) {
        M::eval(&mut self.module);
    }

    /// Reports the wrapped module's training flag.
    fn is_training(&self) -> bool {
        M::is_training(&self.module)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::SimulatedBackend;
    use ferrotorch_core::storage::TensorStorage;
    use ferrotorch_core::{FerrotorchResult, Tensor};
    use ferrotorch_nn::Parameter;
    use std::thread;

    /// Smallest possible module for these tests: one parameter and a
    /// training flag. Its forward pass returns a clone of the input.
    struct TestModule<T: Float> {
        weight: Parameter<T>,
        training: bool,
    }

    impl<T: Float> TestModule<T> {
        /// Builds a module whose single 1-D parameter holds `data`; the
        /// module starts in training mode.
        fn new(data: &[T]) -> FerrotorchResult<Self> {
            let shape = [data.len()];
            let weight = Parameter::from_slice(data, &shape)?;
            Ok(Self {
                weight,
                training: true,
            })
        }
    }

    impl<T: Float> Module<T> for TestModule<T> {
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            Ok(Tensor::clone(input))
        }

        fn parameters(&self) -> Vec<&Parameter<T>> {
            vec![&self.weight]
        }

        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            vec![&mut self.weight]
        }

        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            vec![(String::from("weight"), &self.weight)]
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    /// Spawns one thread per backend, runs `body` on a clone of that
    /// backend's `Arc`, and joins every thread.
    ///
    /// All threads are started before the first join. The caller keeps its
    /// own `Arc`s, so no backend is dropped before every thread is joined.
    fn run_on_each_rank(backends: &[Arc<SimulatedBackend>], body: fn(Arc<SimulatedBackend>)) {
        let workers: Vec<thread::JoinHandle<()>> = backends
            .iter()
            .map(|backend| {
                let backend = Arc::clone(backend);
                thread::spawn(move || body(backend))
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }
    }

    /// Each of four ranks sets its gradient to its own rank number; after
    /// `sync_gradients` every element must equal 1.5 (the mean of
    /// 0, 1, 2, 3, assuming the group numbers its ranks 0..=3).
    #[test]
    fn test_ddp_sync_gradients() {
        let backends: Vec<Arc<SimulatedBackend>> = SimulatedBackend::create_group(4)
            .unwrap()
            .into_iter()
            .map(Arc::new)
            .collect();
        run_on_each_rank(&backends, |backend: Arc<SimulatedBackend>| {
            let rank = backend.rank();
            let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0]).unwrap();
            let ddp = DDP::new(model, backend);

            let grad_val = rank as f32;
            let storage = TensorStorage::cpu(vec![grad_val; 3]);
            let grad = Tensor::from_storage(storage, vec![3], false).unwrap();
            ddp.module().weight.tensor().set_grad(Some(grad)).unwrap();

            ddp.sync_gradients().unwrap();

            let synced_grad = ddp.module().weight.tensor().grad().unwrap().unwrap();
            let data = synced_grad.data().unwrap();
            data.iter().for_each(|&v| {
                assert!(
                    (v - 1.5).abs() < 1e-5,
                    "rank {rank}: expected 1.5, got {v}"
                );
            });
        });
    }

    /// Rank 0 starts with `[10, 20, 30]`, every other rank with zeros; after
    /// broadcasting from root 0 every rank must hold rank 0's values.
    #[test]
    fn test_ddp_broadcast_parameters() {
        let backends: Vec<Arc<SimulatedBackend>> = SimulatedBackend::create_group(3)
            .unwrap()
            .into_iter()
            .map(Arc::new)
            .collect();
        run_on_each_rank(&backends, |backend: Arc<SimulatedBackend>| {
            let rank = backend.rank();
            let data: Vec<f32> = match rank {
                0 => vec![10.0, 20.0, 30.0],
                _ => vec![0.0, 0.0, 0.0],
            };
            let model = TestModule::<f32>::new(&data).unwrap();
            let mut ddp = DDP::new(model, backend);

            ddp.broadcast_parameters(0).unwrap();

            let param_data = ddp.module().weight.tensor().data().unwrap();
            assert!(
                (param_data[0] - 10.0).abs() < 1e-5,
                "rank {rank}: expected 10.0, got {}",
                param_data[0]
            );
            assert!(
                (param_data[1] - 20.0).abs() < 1e-5,
                "rank {rank}: expected 20.0, got {}",
                param_data[1]
            );
            assert!(
                (param_data[2] - 30.0).abs() < 1e-5,
                "rank {rank}: expected 30.0, got {}",
                param_data[2]
            );
        });
    }

    /// The wrapper's `Module` methods must reflect the wrapped module's
    /// training flag and parameter list.
    #[test]
    fn test_ddp_delegates_module_trait() {
        let only_rank = SimulatedBackend::create_group(1)
            .unwrap()
            .into_iter()
            .next()
            .unwrap();
        let backend: Arc<dyn Backend> = Arc::new(only_rank);
        let model = TestModule::<f32>::new(&[1.0, 2.0]).unwrap();
        let mut ddp = DDP::new(model, backend);

        // The training flag follows eval()/train() calls made on the wrapper.
        assert!(ddp.is_training());
        ddp.eval();
        assert!(!ddp.is_training());
        ddp.train();
        assert!(ddp.is_training());

        // Parameter listings come straight from the wrapped module.
        assert_eq!(ddp.parameters().len(), 1);
        assert_eq!(ddp.named_parameters()[0].0, "weight");
    }
}