Skip to main content

ferrotorch_distributed/
ddp.rs

1//! Distributed Data Parallel (DDP) wrapper.
2//!
3//! [`DDP`] wraps a [`Module`] and synchronizes parameter gradients across
4//! all ranks after each backward pass by allreducing them. This is the
5//! standard data-parallel training strategy: each rank processes a
6//! different mini-batch, then gradients are averaged so that all replicas
7//! stay in sync.
8
9use std::sync::Arc;
10
11use crate::backend::Backend;
12use crate::collective::{ReduceOp, allreduce};
13use ferrotorch_core::storage::TensorStorage;
14use ferrotorch_core::{FerrotorchResult, Float, Tensor};
15use ferrotorch_nn::{Module, Parameter};
16
/// Default bucket size for gradient bucketing: 25 MiB (25 * 1024 * 1024 bytes).
const DEFAULT_BUCKET_SIZE_BYTES: usize = 25 * 1024 * 1024;
19
/// Distributed Data Parallel module wrapper.
///
/// Wraps an inner [`Module`] and provides [`DDP::sync_gradients`] to allreduce
/// parameter gradients across all ranks. Parameters are grouped into
/// buckets (default 25 MiB) for efficient communication.
///
/// Gradient synchronization is explicit: the caller invokes
/// [`DDP::sync_gradients`] after the backward pass and before the optimizer
/// step, as in the loop below.
///
/// ```ignore
/// let ddp = DDP::new(model, backend);
///
/// loop {
///     let output = ddp.module().forward(&input)?;
///     let loss = criterion.forward(&output, &target)?;
///     ferrotorch_core::backward(&loss)?;
///     ddp.sync_gradients()?;
///     optimizer.step()?;
///     optimizer.zero_grad()?;
/// }
/// ```
pub struct DDP<M: Module<T>, T: Float> {
    /// The wrapped model whose parameters are synchronized.
    module: M,
    /// Communication backend shared by every collective this wrapper issues.
    backend: Arc<dyn Backend>,
    /// Bucket assignments: `buckets[i]` is a list of parameter indices in bucket i.
    /// Indices refer to positions in `module.parameters()` as it was when the
    /// wrapper was constructed.
    buckets: Vec<Vec<usize>>,
    /// Ties the element type `T` to the struct without storing a value of it.
    _marker: std::marker::PhantomData<T>,
}
45
46impl<M: Module<T>, T: Float> DDP<M, T> {
47    /// Wrap a module for distributed data-parallel training.
48    ///
49    /// Parameters are assigned to ~25 MB gradient buckets in reverse order
50    /// (matching PyTorch's convention — backward computes gradients in
51    /// reverse parameter order, so the first bucket fills first).
52    pub fn new(module: M, backend: Arc<dyn Backend>) -> Self {
53        Self::with_bucket_size(module, backend, DEFAULT_BUCKET_SIZE_BYTES)
54    }
55
56    /// Wrap a module with a custom bucket size (in bytes).
57    pub fn with_bucket_size(
58        module: M,
59        backend: Arc<dyn Backend>,
60        bucket_size_bytes: usize,
61    ) -> Self {
62        let params = module.parameters();
63        let buckets = compute_buckets::<T>(&params, bucket_size_bytes);
64        Self {
65            module,
66            backend,
67            buckets,
68            _marker: std::marker::PhantomData,
69        }
70    }
71
72    /// Immutable access to the inner module (for forward pass, etc.).
73    pub fn module(&self) -> &M {
74        &self.module
75    }
76
77    /// Mutable access to the inner module (for train/eval mode, etc.).
78    pub fn module_mut(&mut self) -> &mut M {
79        &mut self.module
80    }
81
82    /// Consume the DDP wrapper and return the inner module.
83    pub fn into_inner(self) -> M {
84        self.module
85    }
86
87    /// The backend used for communication.
88    pub fn backend(&self) -> &Arc<dyn Backend> {
89        &self.backend
90    }
91
92    /// Allreduce parameter gradients across ranks using gradient bucketing.
93    ///
94    /// Parameters are grouped into ~25 MB buckets. Each bucket is
95    /// allreduced independently as a single flat buffer. This enables
96    /// future overlapped communication where the first bucket can start
97    /// transferring while backward is still computing later gradients.
98    ///
99    /// Call this after `backward()` and before `optimizer.step()`.
100    pub fn sync_gradients(&self) -> FerrotorchResult<()> {
101        let params = self.module.parameters();
102        for bucket in &self.buckets {
103            sync_one_bucket::<T>(bucket, &params, self.backend.as_ref())?;
104        }
105        Ok(())
106    }
107
108    /// Allreduce parameter gradients with bucket-level parallelism.
109    ///
110    /// Like [`sync_gradients`], but processes all buckets concurrently using
111    /// `std::thread::scope`. Each bucket's allreduce runs in its own thread,
112    /// overlapping communication across buckets. All threads complete before
113    /// this method returns.
114    ///
115    /// This provides communication/computation overlap when backward and
116    /// sync run on different threads, and communication overlap across
117    /// buckets even in the synchronous case.
118    pub fn overlapped_sync_gradients(&self) -> FerrotorchResult<()> {
119        let params = self.module.parameters();
120
121        // Collect errors from threads.
122        let errors: std::sync::Mutex<Vec<ferrotorch_core::error::FerrotorchError>> =
123            std::sync::Mutex::new(Vec::new());
124
125        std::thread::scope(|s| {
126            for bucket in &self.buckets {
127                let params_ref = &params;
128                let backend_ref = self.backend.as_ref();
129                let errors_ref = &errors;
130
131                s.spawn(move || {
132                    let result = sync_one_bucket::<T>(bucket, params_ref, backend_ref);
133                    if let Err(e) = result {
134                        errors_ref.lock().unwrap().push(e);
135                    }
136                });
137            }
138        });
139
140        let errs = errors.into_inner().unwrap();
141        if let Some(e) = errs.into_iter().next() {
142            return Err(e);
143        }
144
145        Ok(())
146    }
147
148    /// Broadcast model parameters from `root` rank to all other ranks.
149    ///
150    /// Ensures all ranks start with identical weights. Call once before
151    /// the training loop begins.
152    ///
153    /// # Warning
154    ///
155    /// This replaces the `Parameter` objects in the module. Any optimizer
156    /// that holds references to the old parameters must be re-initialized
157    /// after calling this method, otherwise optimizer state (momentum,
158    /// adaptive learning rates, etc.) will refer to stale parameters.
159    pub fn broadcast_parameters(&mut self, root: usize) -> FerrotorchResult<()> {
160        let params_mut = self.module.parameters_mut();
161
162        for param in params_mut {
163            let tensor = param.tensor().clone();
164            let synced = crate::collective::broadcast(&tensor, self.backend.as_ref(), root)?;
165            *param = Parameter::new(synced);
166        }
167
168        Ok(())
169    }
170}
171
172// Forward the Module trait through to the inner module so DDP can be
173// used as a drop-in replacement.
174impl<M: Module<T>, T: Float> Module<T> for DDP<M, T> {
175    fn forward(
176        &self,
177        input: &ferrotorch_core::Tensor<T>,
178    ) -> FerrotorchResult<ferrotorch_core::Tensor<T>> {
179        self.module.forward(input)
180    }
181
182    fn parameters(&self) -> Vec<&Parameter<T>> {
183        self.module.parameters()
184    }
185
186    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
187        self.module.parameters_mut()
188    }
189
190    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
191        self.module.named_parameters()
192    }
193
194    fn train(&mut self) {
195        self.module.train();
196    }
197
198    fn eval(&mut self) {
199        self.module.eval();
200    }
201
202    fn is_training(&self) -> bool {
203        self.module.is_training()
204    }
205}
206
207/// Assign parameters to gradient buckets.
208///
209/// Parameters are added in reverse order (matching PyTorch — backward
210/// computes gradients in reverse parameter order, so the last parameters
211/// fill the first bucket). Each bucket holds at most `bucket_size_bytes`
212/// of gradient data.
213fn compute_buckets<T: Float>(
214    params: &[&Parameter<T>],
215    bucket_size_bytes: usize,
216) -> Vec<Vec<usize>> {
217    let elem_size = std::mem::size_of::<T>();
218    let mut buckets: Vec<Vec<usize>> = Vec::new();
219    let mut current_bucket: Vec<usize> = Vec::new();
220    let mut current_bytes: usize = 0;
221
222    // Reverse order: last parameter first.
223    for i in (0..params.len()).rev() {
224        let param_bytes = params[i].tensor().numel() * elem_size;
225
226        if !current_bucket.is_empty() && current_bytes + param_bytes > bucket_size_bytes {
227            buckets.push(current_bucket);
228            current_bucket = Vec::new();
229            current_bytes = 0;
230        }
231
232        current_bucket.push(i);
233        current_bytes += param_bytes;
234    }
235
236    if !current_bucket.is_empty() {
237        buckets.push(current_bucket);
238    }
239
240    buckets
241}
242
243/// Allreduce a single bucket's gradients.
244///
245/// Builds a flat buffer from the bucket's parameter gradients, allreduces,
246/// and scatters the result back. Used by both `sync_gradients` (serial)
247/// and `overlapped_sync_gradients` (parallel).
248fn sync_one_bucket<T: Float>(
249    bucket: &[usize],
250    params: &[&Parameter<T>],
251    backend: &dyn Backend,
252) -> FerrotorchResult<()> {
253    let mut flat_data: Vec<T> = Vec::new();
254    let mut param_numels: Vec<usize> = Vec::new();
255
256    for &pi in bucket {
257        let numel = params[pi].tensor().numel();
258        param_numels.push(numel);
259
260        let grad = params[pi].tensor().grad()?;
261        match grad {
262            Some(g) => flat_data.extend(g.data_vec()?),
263            None => {
264                flat_data.extend(std::iter::repeat_n(<T as num_traits::Zero>::zero(), numel));
265            }
266        }
267    }
268
269    if flat_data.is_empty() {
270        return Ok(());
271    }
272
273    let flat_tensor = Tensor::from_storage(
274        TensorStorage::cpu(flat_data),
275        vec![param_numels.iter().sum()],
276        false,
277    )?;
278    let synced = allreduce(&flat_tensor, backend, ReduceOp::Mean)?;
279    let synced_data = synced.data()?;
280
281    let mut offset = 0;
282    for (&pi, &numel) in bucket.iter().zip(param_numels.iter()) {
283        let grad_slice = &synced_data[offset..offset + numel];
284        let grad_tensor = Tensor::from_storage(
285            TensorStorage::cpu(grad_slice.to_vec()),
286            params[pi].tensor().shape().to_vec(),
287            false,
288        )?;
289        params[pi].tensor().set_grad(Some(grad_tensor))?;
290        offset += numel;
291    }
292
293    Ok(())
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::SimulatedBackend;
    use ferrotorch_core::storage::TensorStorage;
    use ferrotorch_core::{FerrotorchResult, Tensor};
    use ferrotorch_nn::Parameter;
    use std::thread;

    /// Single-parameter module used as the payload wrapped by `DDP` in these tests.
    struct TestModule<T: Float> {
        weight: Parameter<T>,
        training: bool,
    }

    impl<T: Float> TestModule<T> {
        /// Builds a module whose one-dimensional weight holds `data`; starts in training mode.
        fn new(data: &[T]) -> FerrotorchResult<Self> {
            let weight = Parameter::from_slice(data, &[data.len()])?;
            Ok(Self {
                weight,
                training: true,
            })
        }
    }

    impl<T: Float> Module<T> for TestModule<T> {
        // Identity forward: the tests only exercise parameter/gradient plumbing.
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            Ok(input.clone())
        }

        fn parameters(&self) -> Vec<&Parameter<T>> {
            vec![&self.weight]
        }

        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            vec![&mut self.weight]
        }

        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            vec![("weight".into(), &self.weight)]
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    /// Runs `body` once per backend in `group`, each on its own thread.
    ///
    /// Every thread is spawned before any is joined, so collectives issued
    /// inside `body` can rendezvous across ranks. A panic in any thread fails
    /// the calling test.
    fn run_on_each_rank<G, F>(group: G, body: F)
    where
        G: IntoIterator<Item = SimulatedBackend>,
        F: Fn(Arc<SimulatedBackend>) + Copy + Send + 'static,
    {
        let backends: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();

        let mut workers = Vec::with_capacity(backends.len());
        for backend in &backends {
            let backend = Arc::clone(backend);
            workers.push(thread::spawn(move || body(backend)));
        }

        for worker in workers {
            worker.join().unwrap();
        }
    }

    #[test]
    fn test_ddp_sync_gradients() {
        // Four ranks; rank r holds gradient [r, r, r]. Averaging across ranks
        // must leave every rank with [1.5, 1.5, 1.5].
        let group = SimulatedBackend::create_group(4).unwrap();

        run_on_each_rank(group, |backend: Arc<SimulatedBackend>| {
            let rank = backend.rank();
            let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0]).unwrap();
            let ddp = DDP::new(model, backend);

            // Stand in for a backward pass by installing the gradient by hand.
            let fill = rank as f32;
            let local_grad =
                Tensor::from_storage(TensorStorage::cpu(vec![fill, fill, fill]), vec![3], false)
                    .unwrap();
            ddp.module()
                .weight
                .tensor()
                .set_grad(Some(local_grad))
                .unwrap();

            ddp.sync_gradients().unwrap();

            // Mean over ranks: (0 + 1 + 2 + 3) / 4 = 1.5.
            let synced_grad = ddp.module().weight.tensor().grad().unwrap().unwrap();
            let values = synced_grad.data().unwrap();
            for &v in values {
                assert!((v - 1.5).abs() < 1e-5, "rank {rank}: expected 1.5, got {v}");
            }
        });
    }

    #[test]
    fn test_ddp_broadcast_parameters() {
        // Rank 0 starts with [10, 20, 30], every other rank with zeros.
        // Broadcasting from rank 0 must leave all ranks with [10, 20, 30].
        let group = SimulatedBackend::create_group(3).unwrap();

        run_on_each_rank(group, |backend: Arc<SimulatedBackend>| {
            let rank = backend.rank();
            let initial: Vec<f32> = if rank == 0 {
                vec![10.0, 20.0, 30.0]
            } else {
                vec![0.0, 0.0, 0.0]
            };
            let model = TestModule::<f32>::new(&initial).unwrap();
            let mut ddp = DDP::new(model, backend);

            ddp.broadcast_parameters(0).unwrap();

            let param_data = ddp.module().weight.tensor().data().unwrap();
            assert!(
                (param_data[0] - 10.0).abs() < 1e-5,
                "rank {rank}: expected 10.0, got {}",
                param_data[0]
            );
            assert!(
                (param_data[1] - 20.0).abs() < 1e-5,
                "rank {rank}: expected 20.0, got {}",
                param_data[1]
            );
            assert!(
                (param_data[2] - 30.0).abs() < 1e-5,
                "rank {rank}: expected 30.0, got {}",
                param_data[2]
            );
        });
    }

    #[test]
    fn test_ddp_delegates_module_trait() {
        // A single-rank group is enough: no collective is issued here.
        let only_backend = SimulatedBackend::create_group(1)
            .unwrap()
            .into_iter()
            .next()
            .unwrap();
        let backend: Arc<dyn Backend> = Arc::new(only_backend);
        let model = TestModule::<f32>::new(&[1.0, 2.0]).unwrap();
        let mut ddp = DDP::new(model, backend);

        // Mode switches must reach the wrapped module.
        assert!(ddp.is_training());
        ddp.eval();
        assert!(!ddp.is_training());
        ddp.train();
        assert!(ddp.is_training());

        // Parameter accessors must expose the wrapped module's single weight.
        assert_eq!(ddp.parameters().len(), 1);
        assert_eq!(ddp.named_parameters()[0].0, "weight");
    }
}