Skip to main content

axonml_distributed/
ddp.rs

1//! DDP - Distributed Data Parallel
2//!
3//! # File
4//! `crates/axonml-distributed/src/ddp.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use crate::backend::ReduceOp;
18use crate::process_group::ProcessGroup;
19use axonml_autograd::Variable;
20use axonml_nn::{Module, Parameter};
21use axonml_tensor::Tensor;
22
23// =============================================================================
24// DistributedDataParallel
25// =============================================================================
26
27/// Wrapper that enables distributed data parallel training.
28///
29/// DDP replicates the model across multiple processes and synchronizes
30/// gradients during the backward pass.
31pub struct DistributedDataParallel<M: Module> {
32    module: M,
33    process_group: ProcessGroup,
34    broadcast_buffers: bool,
35    gradient_as_bucket_view: bool,
36}
37
38impl<M: Module> DistributedDataParallel<M> {
39    /// Creates a new DDP wrapper.
40    pub fn new(module: M, process_group: ProcessGroup) -> Self {
41        Self {
42            module,
43            process_group,
44            broadcast_buffers: true,
45            gradient_as_bucket_view: true,
46        }
47    }
48
49    /// Sets whether to broadcast buffers from rank 0.
50    pub fn broadcast_buffers(mut self, broadcast: bool) -> Self {
51        self.broadcast_buffers = broadcast;
52        self
53    }
54
55    /// Sets whether to use gradient bucketing.
56    pub fn gradient_as_bucket_view(mut self, bucket_view: bool) -> Self {
57        self.gradient_as_bucket_view = bucket_view;
58        self
59    }
60
61    /// Returns a reference to the underlying module.
62    pub fn module(&self) -> &M {
63        &self.module
64    }
65
66    /// Returns a mutable reference to the underlying module.
67    pub fn module_mut(&mut self) -> &mut M {
68        &mut self.module
69    }
70
71    /// Returns the process group.
72    pub fn process_group(&self) -> &ProcessGroup {
73        &self.process_group
74    }
75
76    /// Synchronizes model parameters across all processes.
77    /// Should be called once at the start of training to ensure all
78    /// ranks start from identical parameters (broadcast from rank 0).
79    pub fn sync_parameters(&mut self) {
80        for param in self.module.parameters() {
81            let mut tensor = param.data().clone();
82            self.process_group.broadcast_tensor(&mut tensor, 0);
83            // Write the broadcast result back to the parameter
84            param.update_data(tensor);
85        }
86    }
87
88    /// Synchronizes gradients across all processes.
89    /// Should be called after the backward pass. All-reduces gradients
90    /// so every rank gets the average gradient across all ranks.
91    pub fn sync_gradients(&self) {
92        for param in self.module.parameters() {
93            if let Some(grad) = param.grad() {
94                let mut grad_tensor = grad.clone();
95                self.process_group
96                    .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
97                // Write the all-reduced gradient back to the parameter
98                param.set_grad(grad_tensor);
99            }
100        }
101    }
102
103    /// Performs forward pass with gradient synchronization.
104    pub fn forward(&self, input: &Variable) -> Variable {
105        self.module.forward(input)
106    }
107}
108
109impl<M: Module> Module for DistributedDataParallel<M> {
110    fn forward(&self, input: &Variable) -> Variable {
111        self.module.forward(input)
112    }
113
114    fn parameters(&self) -> Vec<Parameter> {
115        self.module.parameters()
116    }
117
118    fn train(&mut self) {
119        self.module.train();
120    }
121
122    fn eval(&mut self) {
123        self.module.eval();
124    }
125
126    fn is_training(&self) -> bool {
127        self.module.is_training()
128    }
129}
130
131// =============================================================================
132// GradientBucket
133// =============================================================================
134
/// A fixed-capacity buffer for accumulating gradients before all-reduce.
///
/// Several gradient tensors are packed back to back into one flat `f32`
/// vector, so they can be reduced with a single collective call and then
/// split back apart.
#[derive(Debug, Clone)]
pub struct GradientBucket {
    /// Flattened gradient data, stored in insertion order.
    data: Vec<f32>,
    /// `(shape, element_count)` for each tensor added, in insertion order.
    /// Used to slice `data` back into tensors.
    shapes: Vec<(Vec<usize>, usize)>,
    /// Maximum number of `f32` elements the bucket accepts.
    capacity: usize,
}
144
145impl GradientBucket {
146    /// Creates a new gradient bucket.
147    #[must_use]
148    pub fn new(capacity: usize) -> Self {
149        Self {
150            data: Vec::with_capacity(capacity),
151            shapes: Vec::new(),
152            capacity,
153        }
154    }
155
156    /// Checks if the bucket is full.
157    #[must_use]
158    pub fn is_full(&self) -> bool {
159        self.data.len() >= self.capacity
160    }
161
162    /// Checks if the bucket is empty.
163    #[must_use]
164    pub fn is_empty(&self) -> bool {
165        self.data.is_empty()
166    }
167
168    /// Returns the current size.
169    #[must_use]
170    pub fn size(&self) -> usize {
171        self.data.len()
172    }
173
174    /// Adds a tensor to the bucket.
175    pub fn add(&mut self, tensor: &Tensor<f32>) -> bool {
176        let data = tensor.to_vec();
177        if self.data.len() + data.len() > self.capacity {
178            return false;
179        }
180
181        self.shapes.push((tensor.shape().to_vec(), data.len()));
182        self.data.extend(data);
183        true
184    }
185
186    /// Returns the flattened data.
187    #[must_use]
188    pub fn data(&self) -> &[f32] {
189        &self.data
190    }
191
192    /// Returns mutable flattened data.
193    pub fn data_mut(&mut self) -> &mut [f32] {
194        &mut self.data
195    }
196
197    /// Clears the bucket.
198    pub fn clear(&mut self) {
199        self.data.clear();
200        self.shapes.clear();
201    }
202
203    /// Extracts tensors back from the bucket.
204    #[must_use]
205    pub fn extract(&self) -> Vec<Tensor<f32>> {
206        let mut result = Vec::new();
207        let mut offset = 0;
208
209        for (shape, size) in &self.shapes {
210            let end = offset + size;
211            let data = self.data[offset..end].to_vec();
212            result.push(Tensor::from_vec(data, shape).unwrap());
213            offset = end;
214        }
215
216        result
217    }
218}
219
220// =============================================================================
221// Gradient Synchronization Strategies
222// =============================================================================
223
/// Strategy for gradient synchronization.
///
/// The default is [`GradSyncStrategy::Synchronous`]. This matches the
/// strategy used by `GradientSynchronizer::default()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum GradSyncStrategy {
    /// Synchronize after each backward pass.
    #[default]
    Synchronous,
    /// Overlap computation and communication.
    ///
    /// `GradientSynchronizer::sync_all` currently handles this exactly like
    /// `Synchronous`.
    Overlapped,
    /// No gradient synchronization (for debugging).
    NoSync,
}
234
235/// Gradient synchronizer.
236pub struct GradientSynchronizer {
237    strategy: GradSyncStrategy,
238    bucket_size: usize,
239    buckets: Vec<GradientBucket>,
240}
241
242impl GradientSynchronizer {
243    /// Creates a new gradient synchronizer.
244    #[must_use]
245    pub fn new(strategy: GradSyncStrategy, bucket_size: usize) -> Self {
246        Self {
247            strategy,
248            bucket_size,
249            buckets: Vec::new(),
250        }
251    }
252
253    /// Returns the synchronization strategy.
254    #[must_use]
255    pub fn strategy(&self) -> GradSyncStrategy {
256        self.strategy
257    }
258
259    /// Prepares buckets for gradient accumulation.
260    pub fn prepare(&mut self, num_params: usize) {
261        let num_buckets = num_params.div_ceil(self.bucket_size);
262        self.buckets = (0..num_buckets)
263            .map(|_| GradientBucket::new(self.bucket_size))
264            .collect();
265    }
266
267    /// Adds a gradient to the appropriate bucket.
268    pub fn add_gradient(&mut self, bucket_idx: usize, tensor: &Tensor<f32>) {
269        if bucket_idx < self.buckets.len() {
270            self.buckets[bucket_idx].add(tensor);
271        }
272    }
273
274    /// Synchronizes all buckets.
275    pub fn sync_all(&mut self, process_group: &ProcessGroup) {
276        if self.strategy == GradSyncStrategy::NoSync {
277            return;
278        }
279
280        for bucket in &mut self.buckets {
281            if !bucket.is_empty() {
282                let mut data = bucket.data().to_vec();
283                let len = data.len();
284                process_group
285                    .backend()
286                    .all_reduce(&mut data, ReduceOp::Average);
287                bucket.data_mut()[..len].copy_from_slice(&data);
288            }
289        }
290    }
291
292    /// Clears all buckets.
293    pub fn clear(&mut self) {
294        for bucket in &mut self.buckets {
295            bucket.clear();
296        }
297    }
298}
299
300impl Default for GradientSynchronizer {
301    fn default() -> Self {
302        Self::new(GradSyncStrategy::Synchronous, 25_000_000) // ~100MB for f32
303    }
304}
305
306// =============================================================================
307// Tests
308// =============================================================================
309
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    #[test]
    fn test_ddp_creation() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        assert_eq!(ddp.process_group().rank(), 0);
        assert_eq!(ddp.process_group().world_size(), 1);
    }

    #[test]
    fn test_ddp_forward() {
        let module = Linear::new(4, 2);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        let output = ddp.forward(&input);

        assert_eq!(output.data().shape(), &[1, 2]);
    }

    #[test]
    fn test_ddp_module_access() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let mut ddp = DistributedDataParallel::new(module, pg);

        // The wrapper exposes exactly the wrapped module's parameters.
        let wrapped_count = ddp.module().parameters().len();
        assert_eq!(ddp.parameters().len(), wrapped_count);

        let _ = ddp.module_mut();
    }

    #[test]
    fn test_ddp_train_eval() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let mut ddp = DistributedDataParallel::new(module, pg);

        // Module trait's default is_training() returns true.
        // Linear doesn't override train/eval behavior.
        assert!(ddp.is_training());

        // Call train/eval - they are forwarded to the wrapped module,
        // but Linear's default implementation doesn't change state.
        ddp.train();
        ddp.eval();

        // Test that methods can be called without panic.
        let _ = ddp.is_training();
    }

    #[test]
    fn test_ddp_parameters() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        let params = ddp.parameters();
        assert!(!params.is_empty());
    }

    #[test]
    fn test_gradient_bucket() {
        let mut bucket = GradientBucket::new(100);

        assert!(bucket.is_empty());

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert!(bucket.add(&tensor1));

        assert!(!bucket.is_empty());
        assert_eq!(bucket.size(), 3);

        let tensor2 = Tensor::from_vec(vec![4.0, 5.0], &[2]).unwrap();
        assert!(bucket.add(&tensor2));

        assert_eq!(bucket.size(), 5);
        assert_eq!(bucket.data(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_gradient_bucket_extract() {
        let mut bucket = GradientBucket::new(100);

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let tensor2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3]).unwrap();

        assert!(bucket.add(&tensor1));
        assert!(bucket.add(&tensor2));

        let extracted = bucket.extract();
        assert_eq!(extracted.len(), 2);
        assert_eq!(extracted[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(extracted[1].to_vec(), vec![3.0, 4.0, 5.0]);
        assert_eq!(extracted[0].shape(), &[2]);
        assert_eq!(extracted[1].shape(), &[3]);
    }

    #[test]
    fn test_gradient_bucket_extract_preserves_multidim_shape() {
        let mut bucket = GradientBucket::new(16);
        let matrix = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        assert!(bucket.add(&matrix));

        let extracted = bucket.extract();
        assert_eq!(extracted.len(), 1);
        assert_eq!(extracted[0].shape(), &[2, 3]);
        assert_eq!(extracted[0].to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_gradient_bucket_extract_sees_data_mut() {
        let mut bucket = GradientBucket::new(4);
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        assert!(bucket.add(&tensor));

        // Simulate an in-place reduction writing back into the bucket.
        for value in bucket.data_mut() {
            *value *= 10.0;
        }

        assert_eq!(bucket.extract()[0].to_vec(), vec![10.0, 20.0]);
    }

    #[test]
    fn test_gradient_bucket_full() {
        let mut bucket = GradientBucket::new(5);

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert!(bucket.add(&tensor1));

        let tensor2 = Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();
        assert!(!bucket.add(&tensor2)); // Won't fit

        // A rejected add leaves the bucket untouched.
        assert_eq!(bucket.size(), 3);
        assert_eq!(bucket.data(), &[1.0, 2.0, 3.0]);
        assert!(!bucket.is_full());

        // Filling to exactly the capacity is accepted and marks the bucket full.
        let tensor3 = Tensor::from_vec(vec![4.0, 5.0], &[2]).unwrap();
        assert!(bucket.add(&tensor3));
        assert!(bucket.is_full());
    }

    #[test]
    fn test_gradient_bucket_clear() {
        let mut bucket = GradientBucket::new(100);
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        assert!(bucket.add(&tensor));

        bucket.clear();
        assert!(bucket.is_empty());
        assert_eq!(bucket.size(), 0);
        assert!(bucket.extract().is_empty());
    }

    #[test]
    fn test_gradient_synchronizer() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 100);
        sync.prepare(10);

        assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
    }

    #[test]
    fn test_gradient_synchronizer_no_sync() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::NoSync, 100);
        sync.prepare(10);

        let pg = ProcessGroup::mock();
        sync.sync_all(&pg); // Should do nothing
    }

    #[test]
    fn test_gradient_synchronizer_default() {
        let sync = GradientSynchronizer::default();
        assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
    }

    #[test]
    fn test_grad_sync_strategy() {
        assert_eq!(GradSyncStrategy::Synchronous, GradSyncStrategy::Synchronous);
        assert_ne!(GradSyncStrategy::Synchronous, GradSyncStrategy::NoSync);
    }
}