Skip to main content

ferrotorch_distributed/
ddp.rs

1//! Distributed Data Parallel (DDP) wrapper.
2//!
3//! [`DDP`] wraps a [`Module`] and synchronizes parameter gradients across
4//! all ranks after each backward pass by allreducing them. This is the
5//! standard data-parallel training strategy: each rank processes a
6//! different mini-batch, then gradients are averaged so that all replicas
7//! stay in sync.
8//!
9//! ## REQ status (per `.design/ferrotorch-distributed/ddp.md`)
10//!
11//! Full evidence rows (impl + non-test production consumer + upstream
12//! cites) live in the design doc; this synopsis is a one-line summary per
13//! REQ.
14//!
15//! | REQ | Status | Evidence |
16//! |---|---|---|
17//! | REQ-1 (`DDP<M, T>` struct) | SHIPPED | `pub struct DDP` in `ddp.rs` mirrors `class DistributedDataParallel` in `torch/nn/parallel/distributed.py`; consumer `pub use ddp::DDP` in `lib.rs` |
18//! | REQ-2 (`new` + 25 MB default bucket) | SHIPPED | `pub fn new` + `const DEFAULT_BUCKET_SIZE_BYTES` in `ddp.rs` mirror `_DEFAULT_BUCKET_CAP_MB = 25` in `torch/nn/parallel/distributed.py`; consumer `with_bucket_size` (same file) plus `lib.rs` re-export |
19//! | REQ-3 (`with_bucket_size`) | SHIPPED | `pub fn with_bucket_size` in `ddp.rs` mirrors `bucket_cap_mb` kwarg upstream; consumer `pub fn new` calls it with the default |
20//! | REQ-4 (`sync_gradients` + overlapped variant) | SHIPPED | `pub fn sync_gradients` and `pub fn overlapped_sync_gradients` in `ddp.rs`; consumer of `crate::collective::allreduce` via `fn sync_one_bucket` |
21//! | REQ-5 (`broadcast_parameters`) | SHIPPED | `pub fn broadcast_parameters` in `ddp.rs`; consumer of `crate::collective::broadcast` |
22//! | REQ-6 (`impl Module` forwarding) | SHIPPED | `impl Module<T> for DDP<M, T>` in `ddp.rs`; consumer `lib.rs` re-export makes `DDP` drop-in wherever `impl Module<T>` is accepted |
23//! | REQ-7 (reverse-order bucket assignment) | SHIPPED | `fn compute_buckets` in `ddp.rs` iterates `(0..params.len()).rev()` per `torch/csrc/distributed/c10d/reducer.cpp` reducer comment; consumer `pub fn with_bucket_size` (same file) |
24//! | REQ-8 (`sync_one_bucket` helper) | SHIPPED | `fn sync_one_bucket` in `ddp.rs`; consumer `pub fn sync_gradients` and `pub fn overlapped_sync_gradients` (same file) |
25
26use std::sync::Arc;
27
28use crate::backend::Backend;
29use crate::collective::{ReduceOp, allreduce};
30use ferrotorch_core::storage::TensorStorage;
31use ferrotorch_core::{FerrotorchResult, Float, Tensor};
32use ferrotorch_nn::{Module, Parameter};
33
/// Default gradient-bucket capacity in bytes: 25 MiB (25 × 2^20 bytes),
/// matching PyTorch's default `bucket_cap_mb = 25`.
const DEFAULT_BUCKET_SIZE_BYTES: usize = 25 << 20;
36
37/// Distributed Data Parallel module wrapper.
38///
39/// Wraps an inner [`Module`] and provides [`sync_gradients`] to allreduce
40/// parameter gradients across all ranks. Parameters are grouped into
41/// buckets (default 25 MB) for efficient communication.
42///
43/// ```ignore
44/// let ddp = DDP::new(model, backend);
45///
46/// loop {
47///     let output = ddp.module().forward(&input)?;
48///     let loss = criterion.forward(&output, &target)?;
49///     ferrotorch_core::backward(&loss)?;
50///     ddp.sync_gradients()?;
51///     optimizer.step()?;
52///     optimizer.zero_grad()?;
53/// }
54/// ```
55pub struct DDP<M: Module<T>, T: Float> {
56    module: M,
57    backend: Arc<dyn Backend>,
58    /// Bucket assignments: `buckets[i]` is a list of parameter indices in bucket i.
59    buckets: Vec<Vec<usize>>,
60    _marker: std::marker::PhantomData<T>,
61}
62
63impl<M: Module<T>, T: Float> DDP<M, T> {
64    /// Wrap a module for distributed data-parallel training.
65    ///
66    /// Parameters are assigned to ~25 MB gradient buckets in reverse order
67    /// (matching PyTorch's convention — backward computes gradients in
68    /// reverse parameter order, so the first bucket fills first).
69    pub fn new(module: M, backend: Arc<dyn Backend>) -> Self {
70        Self::with_bucket_size(module, backend, DEFAULT_BUCKET_SIZE_BYTES)
71    }
72
73    /// Wrap a module with a custom bucket size (in bytes).
74    pub fn with_bucket_size(
75        module: M,
76        backend: Arc<dyn Backend>,
77        bucket_size_bytes: usize,
78    ) -> Self {
79        let params = module.parameters();
80        let buckets = compute_buckets::<T>(&params, bucket_size_bytes);
81        Self {
82            module,
83            backend,
84            buckets,
85            _marker: std::marker::PhantomData,
86        }
87    }
88
89    /// Immutable access to the inner module (for forward pass, etc.).
90    pub fn module(&self) -> &M {
91        &self.module
92    }
93
94    /// Mutable access to the inner module (for train/eval mode, etc.).
95    pub fn module_mut(&mut self) -> &mut M {
96        &mut self.module
97    }
98
99    /// Consume the DDP wrapper and return the inner module.
100    pub fn into_inner(self) -> M {
101        self.module
102    }
103
104    /// The backend used for communication.
105    pub fn backend(&self) -> &Arc<dyn Backend> {
106        &self.backend
107    }
108
109    /// Allreduce parameter gradients across ranks using gradient bucketing.
110    ///
111    /// Parameters are grouped into ~25 MB buckets. Each bucket is
112    /// allreduced independently as a single flat buffer. This enables
113    /// future overlapped communication where the first bucket can start
114    /// transferring while backward is still computing later gradients.
115    ///
116    /// Call this after `backward()` and before `optimizer.step()`.
117    pub fn sync_gradients(&self) -> FerrotorchResult<()> {
118        let params = self.module.parameters();
119        for bucket in &self.buckets {
120            sync_one_bucket::<T>(bucket, &params, self.backend.as_ref())?;
121        }
122        Ok(())
123    }
124
125    /// Allreduce parameter gradients with bucket-level parallelism.
126    ///
127    /// Like [`sync_gradients`], but processes all buckets concurrently using
128    /// `std::thread::scope`. Each bucket's allreduce runs in its own thread,
129    /// overlapping communication across buckets. All threads complete before
130    /// this method returns.
131    ///
132    /// This provides communication/computation overlap when backward and
133    /// sync run on different threads, and communication overlap across
134    /// buckets even in the synchronous case.
135    pub fn overlapped_sync_gradients(&self) -> FerrotorchResult<()> {
136        let params = self.module.parameters();
137
138        // Collect errors from threads.
139        let errors: std::sync::Mutex<Vec<ferrotorch_core::error::FerrotorchError>> =
140            std::sync::Mutex::new(Vec::new());
141
142        std::thread::scope(|s| {
143            for bucket in &self.buckets {
144                let params_ref = &params;
145                let backend_ref = self.backend.as_ref();
146                let errors_ref = &errors;
147
148                s.spawn(move || {
149                    let result = sync_one_bucket::<T>(bucket, params_ref, backend_ref);
150                    if let Err(e) = result {
151                        errors_ref.lock().unwrap().push(e);
152                    }
153                });
154            }
155        });
156
157        let errs = errors.into_inner().unwrap();
158        if let Some(e) = errs.into_iter().next() {
159            return Err(e);
160        }
161
162        Ok(())
163    }
164
165    /// Broadcast model parameters from `root` rank to all other ranks.
166    ///
167    /// Ensures all ranks start with identical weights. Call once before
168    /// the training loop begins.
169    ///
170    /// # Warning
171    ///
172    /// This replaces the `Parameter` objects in the module. Any optimizer
173    /// that holds references to the old parameters must be re-initialized
174    /// after calling this method, otherwise optimizer state (momentum,
175    /// adaptive learning rates, etc.) will refer to stale parameters.
176    pub fn broadcast_parameters(&mut self, root: usize) -> FerrotorchResult<()> {
177        let params_mut = self.module.parameters_mut();
178
179        for param in params_mut {
180            let tensor = param.tensor().clone();
181            let synced = crate::collective::broadcast(&tensor, self.backend.as_ref(), root)?;
182            *param = Parameter::new(synced);
183        }
184
185        Ok(())
186    }
187}
188
189// Forward the Module trait through to the inner module so DDP can be
190// used as a drop-in replacement.
191impl<M: Module<T>, T: Float> Module<T> for DDP<M, T> {
192    fn forward(
193        &self,
194        input: &ferrotorch_core::Tensor<T>,
195    ) -> FerrotorchResult<ferrotorch_core::Tensor<T>> {
196        self.module.forward(input)
197    }
198
199    fn parameters(&self) -> Vec<&Parameter<T>> {
200        self.module.parameters()
201    }
202
203    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
204        self.module.parameters_mut()
205    }
206
207    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
208        self.module.named_parameters()
209    }
210
211    fn train(&mut self) {
212        self.module.train();
213    }
214
215    fn eval(&mut self) {
216        self.module.eval();
217    }
218
219    fn is_training(&self) -> bool {
220        self.module.is_training()
221    }
222}
223
224/// Assign parameters to gradient buckets.
225///
226/// Parameters are added in reverse order (matching PyTorch — backward
227/// computes gradients in reverse parameter order, so the last parameters
228/// fill the first bucket). Each bucket holds at most `bucket_size_bytes`
229/// of gradient data.
230fn compute_buckets<T: Float>(
231    params: &[&Parameter<T>],
232    bucket_size_bytes: usize,
233) -> Vec<Vec<usize>> {
234    let elem_size = std::mem::size_of::<T>();
235    let mut buckets: Vec<Vec<usize>> = Vec::new();
236    let mut current_bucket: Vec<usize> = Vec::new();
237    let mut current_bytes: usize = 0;
238
239    // Reverse order: last parameter first.
240    for i in (0..params.len()).rev() {
241        let param_bytes = params[i].tensor().numel() * elem_size;
242
243        if !current_bucket.is_empty() && current_bytes + param_bytes > bucket_size_bytes {
244            buckets.push(current_bucket);
245            current_bucket = Vec::new();
246            current_bytes = 0;
247        }
248
249        current_bucket.push(i);
250        current_bytes += param_bytes;
251    }
252
253    if !current_bucket.is_empty() {
254        buckets.push(current_bucket);
255    }
256
257    buckets
258}
259
260/// Allreduce a single bucket's gradients.
261///
262/// Builds a flat buffer from the bucket's parameter gradients, allreduces,
263/// and scatters the result back. Used by both `sync_gradients` (serial)
264/// and `overlapped_sync_gradients` (parallel).
265fn sync_one_bucket<T: Float>(
266    bucket: &[usize],
267    params: &[&Parameter<T>],
268    backend: &dyn Backend,
269) -> FerrotorchResult<()> {
270    let mut flat_data: Vec<T> = Vec::new();
271    let mut param_numels: Vec<usize> = Vec::new();
272
273    for &pi in bucket {
274        let numel = params[pi].tensor().numel();
275        param_numels.push(numel);
276
277        let grad = params[pi].tensor().grad()?;
278        match grad {
279            Some(g) => flat_data.extend(g.data_vec()?),
280            None => {
281                flat_data.extend(std::iter::repeat_n(<T as num_traits::Zero>::zero(), numel));
282            }
283        }
284    }
285
286    if flat_data.is_empty() {
287        return Ok(());
288    }
289
290    let flat_tensor = Tensor::from_storage(
291        TensorStorage::cpu(flat_data),
292        vec![param_numels.iter().sum()],
293        false,
294    )?;
295    let synced = allreduce(&flat_tensor, backend, ReduceOp::Mean)?;
296    let synced_data = synced.data()?;
297
298    let mut offset = 0;
299    for (&pi, &numel) in bucket.iter().zip(param_numels.iter()) {
300        let grad_slice = &synced_data[offset..offset + numel];
301        let grad_tensor = Tensor::from_storage(
302            TensorStorage::cpu(grad_slice.to_vec()),
303            params[pi].tensor().shape().to_vec(),
304            false,
305        )?;
306        params[pi].tensor().set_grad(Some(grad_tensor))?;
307        offset += numel;
308    }
309
310    Ok(())
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::SimulatedBackend;
    use ferrotorch_core::storage::TensorStorage;
    use ferrotorch_core::{FerrotorchResult, Tensor};
    use ferrotorch_nn::Parameter;
    use std::thread;

    /// Minimal module with one parameter for testing DDP.
    struct TestModule<T: Float> {
        weight: Parameter<T>,
        training: bool,
    }

    impl<T: Float> TestModule<T> {
        fn new(data: &[T]) -> FerrotorchResult<Self> {
            Ok(Self {
                weight: Parameter::from_slice(data, &[data.len()])?,
                training: true,
            })
        }
    }

    impl<T: Float> Module<T> for TestModule<T> {
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            Ok(input.clone())
        }

        fn parameters(&self) -> Vec<&Parameter<T>> {
            vec![&self.weight]
        }

        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            vec![&mut self.weight]
        }

        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            vec![("weight".into(), &self.weight)]
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    /// Module with two parameters, used to exercise multi-bucket paths.
    struct PairModule<T: Float> {
        weight: Parameter<T>,
        bias: Parameter<T>,
        training: bool,
    }

    impl<T: Float> PairModule<T> {
        fn new(weight: &[T], bias: &[T]) -> FerrotorchResult<Self> {
            Ok(Self {
                weight: Parameter::from_slice(weight, &[weight.len()])?,
                bias: Parameter::from_slice(bias, &[bias.len()])?,
                training: true,
            })
        }
    }

    impl<T: Float> Module<T> for PairModule<T> {
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            Ok(input.clone())
        }

        fn parameters(&self) -> Vec<&Parameter<T>> {
            vec![&self.weight, &self.bias]
        }

        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            vec![&mut self.weight, &mut self.bias]
        }

        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            vec![("weight".into(), &self.weight), ("bias".into(), &self.bias)]
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    /// Build a 1-D CPU tensor holding `len` copies of `value` (a stand-in gradient).
    fn filled(value: f32, len: usize) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(vec![value; len]), vec![len], false).unwrap()
    }

    /// Assert every element of `param`'s gradient is within 1e-5 of `expected`.
    fn assert_grad_all(param: &Parameter<f32>, expected: f32, what: &str) {
        let grad = param.tensor().grad().unwrap().unwrap();
        let data = grad.data().unwrap();
        for &v in data {
            assert!((v - expected).abs() < 1e-5, "{what}: expected {expected}, got {v}");
        }
    }

    #[test]
    fn test_ddp_sync_gradients() {
        // 4 ranks. Each rank's parameter has a gradient equal to [rank, rank, rank].
        // After sync_gradients (mean), all should have [1.5, 1.5, 1.5].
        let group = SimulatedBackend::create_group(4).unwrap();
        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();

        let handles: Vec<_> = arcs
            .iter()
            .cloned()
            .map(|b| {
                thread::spawn(move || {
                    let rank = b.rank();
                    let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0]).unwrap();
                    let ddp = DDP::new(model, b);

                    // Simulate a backward pass by manually setting gradients.
                    let grad_val = rank as f32;
                    let grad = Tensor::from_storage(
                        TensorStorage::cpu(vec![grad_val, grad_val, grad_val]),
                        vec![3],
                        false,
                    )
                    .unwrap();
                    ddp.module().weight.tensor().set_grad(Some(grad)).unwrap();

                    // Sync gradients across all ranks.
                    ddp.sync_gradients().unwrap();

                    // All ranks should now have mean gradient = (0+1+2+3)/4 = 1.5
                    let synced_grad = ddp.module().weight.tensor().grad().unwrap().unwrap();
                    let data = synced_grad.data().unwrap();
                    for &v in data {
                        assert!((v - 1.5).abs() < 1e-5, "rank {rank}: expected 1.5, got {v}");
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().unwrap();
        }
    }

    #[test]
    fn test_compute_buckets_reverse_order_and_capacity() {
        let a = Parameter::<f32>::from_slice(&[0.0_f32; 4], &[4]).unwrap(); // 16 bytes
        let b = Parameter::<f32>::from_slice(&[0.0_f32; 2], &[2]).unwrap(); // 8 bytes
        let c = Parameter::<f32>::from_slice(&[0.0_f32; 3], &[3]).unwrap(); // 12 bytes
        let params = [&a, &b, &c];

        // Reverse order: c (12) + b (8) exactly fill 20 bytes; a opens a new bucket.
        assert_eq!(compute_buckets::<f32>(&params, 20), vec![vec![2, 1], vec![0]]);
        // Everything fits in one bucket, still in reverse order.
        assert_eq!(compute_buckets::<f32>(&params, 1024), vec![vec![2, 1, 0]]);
        // Capacity below every parameter size: one parameter per bucket.
        assert_eq!(
            compute_buckets::<f32>(&params, 4),
            vec![vec![2], vec![1], vec![0]]
        );
        // No parameters, no buckets.
        assert!(compute_buckets::<f32>(&[], 20).is_empty());
    }

    #[test]
    fn test_ddp_sync_gradients_multiple_buckets() {
        // A 4-byte capacity forces each f32 parameter into its own bucket, so
        // this exercises one allreduce per bucket plus the scatter-back.
        let group = SimulatedBackend::create_group(4).unwrap();
        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();

        let handles: Vec<_> = arcs
            .iter()
            .cloned()
            .map(|b| {
                thread::spawn(move || {
                    let rank = b.rank();
                    let model = PairModule::<f32>::new(&[1.0, 2.0, 3.0], &[0.5, 0.5]).unwrap();
                    let ddp = DDP::with_bucket_size(model, b, 4);
                    assert_eq!(ddp.buckets, vec![vec![1], vec![0]]);

                    let r = rank as f32;
                    ddp.module().weight.tensor().set_grad(Some(filled(r, 3))).unwrap();
                    ddp.module().bias.tensor().set_grad(Some(filled(10.0 * r, 2))).unwrap();

                    ddp.sync_gradients().unwrap();

                    // weight: mean(0, 1, 2, 3) = 1.5; bias: mean(0, 10, 20, 30) = 15.
                    assert_grad_all(&ddp.module().weight, 1.5, &format!("rank {rank} weight"));
                    assert_grad_all(&ddp.module().bias, 15.0, &format!("rank {rank} bias"));
                })
            })
            .collect();

        for h in handles {
            h.join().unwrap();
        }
    }

    #[test]
    fn test_ddp_sync_gradients_missing_grad_counts_as_zero() {
        // Rank 0 never sets a gradient; rank 1 sets [4, 4]. The missing
        // gradient contributes zeros, so both ranks end with mean = 2.
        let group = SimulatedBackend::create_group(2).unwrap();
        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();

        let handles: Vec<_> = arcs
            .iter()
            .cloned()
            .map(|b| {
                thread::spawn(move || {
                    let rank = b.rank();
                    let model = TestModule::<f32>::new(&[1.0, 2.0]).unwrap();
                    let ddp = DDP::new(model, b);

                    if rank != 0 {
                        ddp.module().weight.tensor().set_grad(Some(filled(4.0, 2))).unwrap();
                    }

                    ddp.sync_gradients().unwrap();

                    assert_grad_all(&ddp.module().weight, 2.0, &format!("rank {rank}"));
                })
            })
            .collect();

        for h in handles {
            h.join().unwrap();
        }
    }

    #[test]
    fn test_ddp_broadcast_parameters() {
        // Rank 0 has weights [10, 20, 30]. Other ranks have [0, 0, 0].
        // After broadcast_parameters(0), all should have [10, 20, 30].
        let group = SimulatedBackend::create_group(3).unwrap();
        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();

        let handles: Vec<_> = arcs
            .iter()
            .cloned()
            .map(|b| {
                thread::spawn(move || {
                    let rank = b.rank();
                    let data: Vec<f32> = if rank == 0 {
                        vec![10.0, 20.0, 30.0]
                    } else {
                        vec![0.0, 0.0, 0.0]
                    };
                    let model = TestModule::<f32>::new(&data).unwrap();
                    let mut ddp = DDP::new(model, b);

                    ddp.broadcast_parameters(0).unwrap();

                    let param_data = ddp.module().weight.tensor().data().unwrap();
                    assert!(
                        (param_data[0] - 10.0).abs() < 1e-5,
                        "rank {rank}: expected 10.0, got {}",
                        param_data[0]
                    );
                    assert!(
                        (param_data[1] - 20.0).abs() < 1e-5,
                        "rank {rank}: expected 20.0, got {}",
                        param_data[1]
                    );
                    assert!(
                        (param_data[2] - 30.0).abs() < 1e-5,
                        "rank {rank}: expected 30.0, got {}",
                        param_data[2]
                    );
                })
            })
            .collect();

        for h in handles {
            h.join().unwrap();
        }
    }

    #[test]
    fn test_ddp_delegates_module_trait() {
        let group = SimulatedBackend::create_group(1).unwrap();
        let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
        let model = TestModule::<f32>::new(&[1.0, 2.0]).unwrap();
        let mut ddp = DDP::new(model, b);

        // Module trait methods should delegate.
        assert!(ddp.is_training());
        ddp.eval();
        assert!(!ddp.is_training());
        ddp.train();
        assert!(ddp.is_training());

        assert_eq!(ddp.parameters().len(), 1);
        assert_eq!(ddp.named_parameters()[0].0, "weight");
    }
}