// axonml_distributed/ddp.rs
1//! DDP - Distributed Data Parallel
2//!
3//! # File
4//! `crates/axonml-distributed/src/ddp.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr. — AutomataNexus LLC
8//! ORCID: 0009-0005-2158-7060
9//!
10//! # Updated
11//! April 14, 2026 11:15 PM EST
12//!
13//! # Disclaimer
14//! Use at own risk. This software is provided "as is", without warranty of any
15//! kind, express or implied. The author and AutomataNexus shall not be held
16//! liable for any damages arising from the use of this software.
17
18use crate::backend::ReduceOp;
19use crate::process_group::ProcessGroup;
20use axonml_autograd::Variable;
21use axonml_nn::{Module, Parameter};
22use axonml_tensor::Tensor;
23
24// =============================================================================
25// DistributedDataParallel
26// =============================================================================
27
/// Wrapper that enables distributed data parallel training.
///
/// DDP replicates the model across multiple processes and synchronizes
/// gradients during the backward pass.
pub struct DistributedDataParallel<M: Module> {
    /// The wrapped model replica owned by this rank.
    module: M,
    /// Communicator used for broadcast / all-reduce collectives.
    process_group: ProcessGroup,
    /// Whether buffers should be broadcast from rank 0 (defaults to `true`).
    /// NOTE(review): stored but not read by any method in this file — confirm
    /// it is consumed elsewhere or wire it into `sync_parameters`.
    broadcast_buffers: bool,
    /// Whether gradients should be treated as flattened bucket views
    /// (defaults to `true`).
    /// NOTE(review): stored but not read by any method in this file.
    gradient_as_bucket_view: bool,
}
38
impl<M: Module> DistributedDataParallel<M> {
    /// Creates a new DDP wrapper around `module`, using `process_group`
    /// for collective communication. Buffer broadcasting and bucket-view
    /// gradients are enabled by default.
    pub fn new(module: M, process_group: ProcessGroup) -> Self {
        Self {
            module,
            process_group,
            broadcast_buffers: true,
            gradient_as_bucket_view: true,
        }
    }

    /// Builder-style setter: whether to broadcast buffers from rank 0.
    /// Consumes and returns `self` so calls can be chained after `new`.
    pub fn broadcast_buffers(mut self, broadcast: bool) -> Self {
        self.broadcast_buffers = broadcast;
        self
    }

    /// Builder-style setter: whether to use gradient bucketing.
    pub fn gradient_as_bucket_view(mut self, bucket_view: bool) -> Self {
        self.gradient_as_bucket_view = bucket_view;
        self
    }

    /// Returns a reference to the underlying module.
    pub fn module(&self) -> &M {
        &self.module
    }

    /// Returns a mutable reference to the underlying module.
    pub fn module_mut(&mut self) -> &mut M {
        &mut self.module
    }

    /// Returns the process group used for collectives.
    pub fn process_group(&self) -> &ProcessGroup {
        &self.process_group
    }

    /// Synchronizes model parameters across all processes.
    /// Should be called once at the start of training to ensure all
    /// ranks start from identical parameters (broadcast from rank 0).
    pub fn sync_parameters(&mut self) {
        for param in self.module.parameters() {
            // Broadcast operates on a detached copy of the parameter data;
            // rank 0's values overwrite every other rank's copy.
            let mut tensor = param.data().clone();
            self.process_group.broadcast_tensor(&mut tensor, 0);
            // Write the broadcast result back to the parameter
            param.update_data(tensor);
        }
    }

    /// Synchronizes gradients across all processes.
    /// Should be called after the backward pass. All-reduces gradients
    /// so every rank gets the average gradient across all ranks.
    /// Parameters with no gradient (e.g. frozen) are skipped.
    pub fn sync_gradients(&self) {
        for param in self.module.parameters() {
            if let Some(grad) = param.grad() {
                let mut grad_tensor = grad.clone();
                self.process_group
                    .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
                // Write the all-reduced gradient back to the parameter
                param.set_grad(grad_tensor);
            }
        }
    }

    /// Runs the forward pass on the wrapped module.
    ///
    /// NOTE(review): despite earlier docs, no gradient synchronization
    /// happens here — callers must invoke `sync_gradients` after the
    /// backward pass.
    pub fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }
}
109
// Transparent `Module` implementation: every method delegates to the
// wrapped module so a DDP-wrapped model can be used anywhere a plain
// `Module` is expected.
impl<M: Module> Module for DistributedDataParallel<M> {
    /// Forwards the input through the wrapped module (no synchronization).
    fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }

    /// Exposes the wrapped module's parameters (e.g. for optimizers).
    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }

    /// Puts the wrapped module into training mode.
    fn train(&mut self) {
        self.module.train();
    }

    /// Puts the wrapped module into evaluation mode.
    fn eval(&mut self) {
        self.module.eval();
    }

    /// Reports the wrapped module's training-mode flag.
    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}
131
132// =============================================================================
133// GradientBucket
134// =============================================================================
135
/// A bucket for accumulating gradients before all-reduce.
///
/// Tensors are flattened into one contiguous `Vec<f32>` so a single
/// collective call can cover many gradients at once.
pub struct GradientBucket {
    /// Flattened gradient data for every tensor added so far.
    data: Vec<f32>,
    /// Original shape and element count of each added tensor, in insertion
    /// order; used by `extract` to rebuild the individual tensors.
    shapes: Vec<(Vec<usize>, usize)>,
    /// Capacity in number of elements (not bytes).
    capacity: usize,
}
145
146impl GradientBucket {
147    /// Creates a new gradient bucket.
148    #[must_use]
149    pub fn new(capacity: usize) -> Self {
150        Self {
151            data: Vec::with_capacity(capacity),
152            shapes: Vec::new(),
153            capacity,
154        }
155    }
156
157    /// Checks if the bucket is full.
158    #[must_use]
159    pub fn is_full(&self) -> bool {
160        self.data.len() >= self.capacity
161    }
162
163    /// Checks if the bucket is empty.
164    #[must_use]
165    pub fn is_empty(&self) -> bool {
166        self.data.is_empty()
167    }
168
169    /// Returns the current size.
170    #[must_use]
171    pub fn size(&self) -> usize {
172        self.data.len()
173    }
174
175    /// Adds a tensor to the bucket.
176    pub fn add(&mut self, tensor: &Tensor<f32>) -> bool {
177        let data = tensor.to_vec();
178        if self.data.len() + data.len() > self.capacity {
179            return false;
180        }
181
182        self.shapes.push((tensor.shape().to_vec(), data.len()));
183        self.data.extend(data);
184        true
185    }
186
187    /// Returns the flattened data.
188    #[must_use]
189    pub fn data(&self) -> &[f32] {
190        &self.data
191    }
192
193    /// Returns mutable flattened data.
194    pub fn data_mut(&mut self) -> &mut [f32] {
195        &mut self.data
196    }
197
198    /// Clears the bucket.
199    pub fn clear(&mut self) {
200        self.data.clear();
201        self.shapes.clear();
202    }
203
204    /// Extracts tensors back from the bucket.
205    #[must_use]
206    pub fn extract(&self) -> Vec<Tensor<f32>> {
207        let mut result = Vec::new();
208        let mut offset = 0;
209
210        for (shape, size) in &self.shapes {
211            let end = offset + size;
212            let data = self.data[offset..end].to_vec();
213            result.push(Tensor::from_vec(data, shape).unwrap());
214            offset = end;
215        }
216
217        result
218    }
219}
220
221// =============================================================================
222// Gradient Synchronization Strategies
223// =============================================================================
224
/// Strategy for gradient synchronization.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GradSyncStrategy {
    /// Synchronize after each backward pass.
    Synchronous,
    /// Overlap computation and communication.
    /// NOTE(review): no code path in this file implements overlap yet;
    /// `GradientSynchronizer::sync_all` treats it the same as `Synchronous`.
    Overlapped,
    /// No gradient synchronization (for debugging).
    NoSync,
}
235
/// Gradient synchronizer: owns a set of `GradientBucket`s and applies the
/// configured `GradSyncStrategy` when flushing them across ranks.
pub struct GradientSynchronizer {
    /// How/when gradients are synchronized across ranks.
    strategy: GradSyncStrategy,
    /// Per-bucket capacity in number of f32 elements.
    bucket_size: usize,
    /// Buckets allocated by `prepare`; empty until then.
    buckets: Vec<GradientBucket>,
}
242
243impl GradientSynchronizer {
244    /// Creates a new gradient synchronizer.
245    #[must_use]
246    pub fn new(strategy: GradSyncStrategy, bucket_size: usize) -> Self {
247        Self {
248            strategy,
249            bucket_size,
250            buckets: Vec::new(),
251        }
252    }
253
254    /// Returns the synchronization strategy.
255    #[must_use]
256    pub fn strategy(&self) -> GradSyncStrategy {
257        self.strategy
258    }
259
260    /// Prepares buckets for gradient accumulation.
261    pub fn prepare(&mut self, num_params: usize) {
262        let num_buckets = num_params.div_ceil(self.bucket_size);
263        self.buckets = (0..num_buckets)
264            .map(|_| GradientBucket::new(self.bucket_size))
265            .collect();
266    }
267
268    /// Adds a gradient to the appropriate bucket.
269    pub fn add_gradient(&mut self, bucket_idx: usize, tensor: &Tensor<f32>) {
270        if bucket_idx < self.buckets.len() {
271            self.buckets[bucket_idx].add(tensor);
272        }
273    }
274
275    /// Synchronizes all buckets.
276    pub fn sync_all(&mut self, process_group: &ProcessGroup) {
277        if self.strategy == GradSyncStrategy::NoSync {
278            return;
279        }
280
281        for bucket in &mut self.buckets {
282            if !bucket.is_empty() {
283                let mut data = bucket.data().to_vec();
284                let len = data.len();
285                process_group
286                    .backend()
287                    .all_reduce(&mut data, ReduceOp::Average);
288                bucket.data_mut()[..len].copy_from_slice(&data);
289            }
290        }
291    }
292
293    /// Clears all buckets.
294    pub fn clear(&mut self) {
295        for bucket in &mut self.buckets {
296            bucket.clear();
297        }
298    }
299}
300
impl Default for GradientSynchronizer {
    /// Synchronous strategy with 25M-element buckets (25M × 4 bytes ≈ 100 MB).
    fn default() -> Self {
        Self::new(GradSyncStrategy::Synchronous, 25_000_000) // ~100MB for f32
    }
}
306
307// =============================================================================
308// Tests
309// =============================================================================
310
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    /// Builds a DDP wrapper around a fresh `Linear(inp, out)` on a
    /// single-rank mock process group.
    fn make_ddp(inp: usize, out: usize) -> DistributedDataParallel<Linear> {
        DistributedDataParallel::new(Linear::new(inp, out), ProcessGroup::mock())
    }

    #[test]
    fn test_ddp_creation() {
        let ddp = make_ddp(10, 5);
        // The mock group is always rank 0 of a world of 1.
        assert_eq!(ddp.process_group().rank(), 0);
        assert_eq!(ddp.process_group().world_size(), 1);
    }

    #[test]
    fn test_ddp_forward() {
        let ddp = make_ddp(4, 2);
        let x = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        // A [1, 4] input through Linear(4, 2) yields a [1, 2] output.
        assert_eq!(ddp.forward(&x).data().shape(), &[1, 2]);
    }

    #[test]
    fn test_ddp_module_access() {
        let mut ddp = make_ddp(10, 5);
        // Both accessors must be callable without panicking.
        let _ = ddp.module();
        let _ = ddp.module_mut();
    }

    #[test]
    fn test_ddp_train_eval() {
        let mut ddp = make_ddp(10, 5);

        // Module trait's default is_training() returns true and Linear
        // does not override train/eval behavior.
        assert!(ddp.is_training());

        // train/eval are forwarded to the wrapped module; verify they
        // can be called without panicking.
        ddp.train();
        ddp.eval();
        let _ = ddp.is_training();
    }

    #[test]
    fn test_ddp_parameters() {
        let ddp = make_ddp(10, 5);
        assert!(!ddp.parameters().is_empty());
    }

    #[test]
    fn test_gradient_bucket() {
        let mut bucket = GradientBucket::new(100);
        assert!(bucket.is_empty());

        assert!(bucket.add(&Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap()));
        assert!(!bucket.is_empty());
        assert_eq!(bucket.size(), 3);

        assert!(bucket.add(&Tensor::from_vec(vec![4.0, 5.0], &[2]).unwrap()));
        assert_eq!(bucket.size(), 5);
        assert_eq!(bucket.data(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_gradient_bucket_extract() {
        let mut bucket = GradientBucket::new(100);
        bucket.add(&Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap());
        bucket.add(&Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3]).unwrap());

        // extract() must rebuild both tensors in insertion order.
        let tensors = bucket.extract();
        assert_eq!(tensors.len(), 2);
        assert_eq!(tensors[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(tensors[1].to_vec(), vec![3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_gradient_bucket_full() {
        let mut bucket = GradientBucket::new(5);
        assert!(bucket.add(&Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap()));
        // A second 3-element tensor exceeds the 5-element capacity.
        assert!(!bucket.add(&Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap()));
    }

    #[test]
    fn test_gradient_bucket_clear() {
        let mut bucket = GradientBucket::new(100);
        bucket.add(&Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap());
        bucket.clear();
        assert!(bucket.is_empty());
    }

    #[test]
    fn test_gradient_synchronizer() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 100);
        sync.prepare(10);
        assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
    }

    #[test]
    fn test_gradient_synchronizer_no_sync() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::NoSync, 100);
        sync.prepare(10);
        // With NoSync, sync_all must be a no-op even with prepared buckets.
        sync.sync_all(&ProcessGroup::mock());
    }

    #[test]
    fn test_gradient_synchronizer_default() {
        assert_eq!(
            GradientSynchronizer::default().strategy(),
            GradSyncStrategy::Synchronous
        );
    }

    #[test]
    fn test_grad_sync_strategy() {
        assert_eq!(GradSyncStrategy::Synchronous, GradSyncStrategy::Synchronous);
        assert_ne!(GradSyncStrategy::Synchronous, GradSyncStrategy::NoSync);
    }
}