// axonml_distributed/src/ddp.rs
//! DDP - Distributed Data Parallel
//!
//! # File
//! `crates/axonml-distributed/src/ddp.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
use crate::backend::ReduceOp;
use crate::process_group::ProcessGroup;

use axonml_autograd::Variable;
use axonml_nn::{Module, Parameter};
use axonml_tensor::Tensor;
// =============================================================================
// DistributedDataParallel
// =============================================================================

/// Wrapper that enables distributed data parallel training.
///
/// DDP replicates the model across multiple processes and synchronizes
/// gradients during the backward pass.
pub struct DistributedDataParallel<M: Module> {
    // The wrapped model replica owned by this process.
    module: M,
    // Communicator used for broadcast / all-reduce across ranks.
    process_group: ProcessGroup,
    // Whether buffers are broadcast from rank 0 (defaults to true in `new`).
    broadcast_buffers: bool,
    // Whether gradients are bucketed for communication (defaults to true in `new`).
    // NOTE(review): neither flag is read by the methods visible in this file —
    // presumably consumed by a fuller implementation; confirm before relying on them.
    gradient_as_bucket_view: bool,
}
37
38impl<M: Module> DistributedDataParallel<M> {
39    /// Creates a new DDP wrapper.
40    pub fn new(module: M, process_group: ProcessGroup) -> Self {
41        Self {
42            module,
43            process_group,
44            broadcast_buffers: true,
45            gradient_as_bucket_view: true,
46        }
47    }
48
49    /// Sets whether to broadcast buffers from rank 0.
50    pub fn broadcast_buffers(mut self, broadcast: bool) -> Self {
51        self.broadcast_buffers = broadcast;
52        self
53    }
54
55    /// Sets whether to use gradient bucketing.
56    pub fn gradient_as_bucket_view(mut self, bucket_view: bool) -> Self {
57        self.gradient_as_bucket_view = bucket_view;
58        self
59    }
60
61    /// Returns a reference to the underlying module.
62    pub fn module(&self) -> &M {
63        &self.module
64    }
65
66    /// Returns a mutable reference to the underlying module.
67    pub fn module_mut(&mut self) -> &mut M {
68        &mut self.module
69    }
70
71    /// Returns the process group.
72    pub fn process_group(&self) -> &ProcessGroup {
73        &self.process_group
74    }
75
76    /// Synchronizes model parameters across all processes.
77    /// Should be called once at the start of training.
78    pub fn sync_parameters(&mut self) {
79        // Broadcast parameters from rank 0
80        for param in self.module.parameters() {
81            let mut tensor = param.data().clone();
82            self.process_group.broadcast_tensor(&mut tensor, 0);
83            // In a real implementation, we'd update the parameter
84        }
85    }
86
87    /// Synchronizes gradients across all processes.
88    /// Should be called after the backward pass.
89    pub fn sync_gradients(&self) {
90        // Get all gradients and all-reduce them
91        for param in self.module.parameters() {
92            if let Some(grad) = param.grad() {
93                let mut grad_tensor = grad.clone();
94                self.process_group
95                    .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
96                // In a real implementation, we'd update the gradient
97            }
98        }
99    }
100
101    /// Performs forward pass with gradient synchronization.
102    pub fn forward(&self, input: &Variable) -> Variable {
103        self.module.forward(input)
104    }
105}
106
107impl<M: Module> Module for DistributedDataParallel<M> {
108    fn forward(&self, input: &Variable) -> Variable {
109        self.module.forward(input)
110    }
111
112    fn parameters(&self) -> Vec<Parameter> {
113        self.module.parameters()
114    }
115
116    fn train(&mut self) {
117        self.module.train();
118    }
119
120    fn eval(&mut self) {
121        self.module.eval();
122    }
123
124    fn is_training(&self) -> bool {
125        self.module.is_training()
126    }
127}
128
// =============================================================================
// GradientBucket
// =============================================================================

/// A bucket for accumulating gradients before all-reduce.
///
/// Gradient tensors are flattened into one contiguous buffer so a single
/// collective call can reduce all of them; `extract` rebuilds the tensors.
pub struct GradientBucket {
    /// Flattened gradient data (all added tensors, concatenated in order).
    data: Vec<f32>,
    /// Original shapes and sizes, one `(shape, element_count)` entry per
    /// tensor added, in insertion order — used by `extract` to split
    /// `data` back into tensors.
    shapes: Vec<(Vec<usize>, usize)>,
    /// Capacity in number of elements; `add` refuses tensors that would
    /// push `data` past this limit.
    capacity: usize,
}
142
143impl GradientBucket {
144    /// Creates a new gradient bucket.
145    #[must_use]
146    pub fn new(capacity: usize) -> Self {
147        Self {
148            data: Vec::with_capacity(capacity),
149            shapes: Vec::new(),
150            capacity,
151        }
152    }
153
154    /// Checks if the bucket is full.
155    #[must_use]
156    pub fn is_full(&self) -> bool {
157        self.data.len() >= self.capacity
158    }
159
160    /// Checks if the bucket is empty.
161    #[must_use]
162    pub fn is_empty(&self) -> bool {
163        self.data.is_empty()
164    }
165
166    /// Returns the current size.
167    #[must_use]
168    pub fn size(&self) -> usize {
169        self.data.len()
170    }
171
172    /// Adds a tensor to the bucket.
173    pub fn add(&mut self, tensor: &Tensor<f32>) -> bool {
174        let data = tensor.to_vec();
175        if self.data.len() + data.len() > self.capacity {
176            return false;
177        }
178
179        self.shapes.push((tensor.shape().to_vec(), data.len()));
180        self.data.extend(data);
181        true
182    }
183
184    /// Returns the flattened data.
185    #[must_use]
186    pub fn data(&self) -> &[f32] {
187        &self.data
188    }
189
190    /// Returns mutable flattened data.
191    pub fn data_mut(&mut self) -> &mut [f32] {
192        &mut self.data
193    }
194
195    /// Clears the bucket.
196    pub fn clear(&mut self) {
197        self.data.clear();
198        self.shapes.clear();
199    }
200
201    /// Extracts tensors back from the bucket.
202    #[must_use]
203    pub fn extract(&self) -> Vec<Tensor<f32>> {
204        let mut result = Vec::new();
205        let mut offset = 0;
206
207        for (shape, size) in &self.shapes {
208            let end = offset + size;
209            let data = self.data[offset..end].to_vec();
210            result.push(Tensor::from_vec(data, shape).unwrap());
211            offset = end;
212        }
213
214        result
215    }
216}
217
// =============================================================================
// Gradient Synchronization Strategies
// =============================================================================

/// Strategy for gradient synchronization.
///
/// Consumed by [`GradientSynchronizer::sync_all`], which skips all
/// communication when the strategy is `NoSync`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GradSyncStrategy {
    /// Synchronize after each backward pass.
    Synchronous,
    /// Overlap computation and communication.
    Overlapped,
    /// No gradient synchronization (for debugging).
    NoSync,
}
232
/// Gradient synchronizer.
///
/// Owns a set of [`GradientBucket`]s and all-reduces their contents across
/// a [`ProcessGroup`] according to the configured [`GradSyncStrategy`].
pub struct GradientSynchronizer {
    // Strategy applied by `sync_all`.
    strategy: GradSyncStrategy,
    // Per-bucket capacity in number of f32 elements.
    bucket_size: usize,
    // Buckets populated by `prepare` / `add_gradient`.
    buckets: Vec<GradientBucket>,
}
239
240impl GradientSynchronizer {
241    /// Creates a new gradient synchronizer.
242    #[must_use]
243    pub fn new(strategy: GradSyncStrategy, bucket_size: usize) -> Self {
244        Self {
245            strategy,
246            bucket_size,
247            buckets: Vec::new(),
248        }
249    }
250
251    /// Returns the synchronization strategy.
252    #[must_use]
253    pub fn strategy(&self) -> GradSyncStrategy {
254        self.strategy
255    }
256
257    /// Prepares buckets for gradient accumulation.
258    pub fn prepare(&mut self, num_params: usize) {
259        let num_buckets = num_params.div_ceil(self.bucket_size);
260        self.buckets = (0..num_buckets)
261            .map(|_| GradientBucket::new(self.bucket_size))
262            .collect();
263    }
264
265    /// Adds a gradient to the appropriate bucket.
266    pub fn add_gradient(&mut self, bucket_idx: usize, tensor: &Tensor<f32>) {
267        if bucket_idx < self.buckets.len() {
268            self.buckets[bucket_idx].add(tensor);
269        }
270    }
271
272    /// Synchronizes all buckets.
273    pub fn sync_all(&mut self, process_group: &ProcessGroup) {
274        if self.strategy == GradSyncStrategy::NoSync {
275            return;
276        }
277
278        for bucket in &mut self.buckets {
279            if !bucket.is_empty() {
280                let mut data = bucket.data().to_vec();
281                let len = data.len();
282                process_group
283                    .backend()
284                    .all_reduce(&mut data, ReduceOp::Average);
285                bucket.data_mut()[..len].copy_from_slice(&data);
286            }
287        }
288    }
289
290    /// Clears all buckets.
291    pub fn clear(&mut self) {
292        for bucket in &mut self.buckets {
293            bucket.clear();
294        }
295    }
296}
297
298impl Default for GradientSynchronizer {
299    fn default() -> Self {
300        Self::new(GradSyncStrategy::Synchronous, 25_000_000) // ~100MB for f32
301    }
302}
303
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    // --- DistributedDataParallel -------------------------------------------

    // A mock process group behaves as a single process: rank 0, world size 1.
    #[test]
    fn test_ddp_creation() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        assert_eq!(ddp.process_group().rank(), 0);
        assert_eq!(ddp.process_group().world_size(), 1);
    }

    // Forward delegates to the wrapped Linear: [1, 4] input -> [1, 2] output.
    #[test]
    fn test_ddp_forward() {
        let module = Linear::new(4, 2);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        let output = ddp.forward(&input);

        assert_eq!(output.data().shape(), &[1, 2]);
    }

    // Shared and exclusive accessors to the wrapped module both compile and run.
    #[test]
    fn test_ddp_module_access() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let mut ddp = DistributedDataParallel::new(module, pg);

        // Access module
        let _ = ddp.module();
        let _ = ddp.module_mut();
    }

    #[test]
    fn test_ddp_train_eval() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let mut ddp = DistributedDataParallel::new(module, pg);

        // Module trait's default is_training() returns true
        // Linear doesn't override train/eval behavior
        assert!(ddp.is_training());

        // Call train/eval - they are forwarded to the wrapped module
        // but Linear's default implementation doesn't change state
        ddp.train();
        ddp.eval();

        // Test that methods can be called without panic
        let _ = ddp.is_training();
    }

    // parameters() forwards to the wrapped module, which has weight (and bias).
    #[test]
    fn test_ddp_parameters() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        let params = ddp.parameters();
        assert!(!params.is_empty());
    }

    // --- GradientBucket ----------------------------------------------------

    // Tensors are flattened and concatenated in insertion order.
    #[test]
    fn test_gradient_bucket() {
        let mut bucket = GradientBucket::new(100);

        assert!(bucket.is_empty());

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert!(bucket.add(&tensor1));

        assert!(!bucket.is_empty());
        assert_eq!(bucket.size(), 3);

        let tensor2 = Tensor::from_vec(vec![4.0, 5.0], &[2]).unwrap();
        assert!(bucket.add(&tensor2));

        assert_eq!(bucket.size(), 5);
        assert_eq!(bucket.data(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    // extract() rebuilds the original tensors, preserving order and contents.
    #[test]
    fn test_gradient_bucket_extract() {
        let mut bucket = GradientBucket::new(100);

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let tensor2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3]).unwrap();

        bucket.add(&tensor1);
        bucket.add(&tensor2);

        let extracted = bucket.extract();
        assert_eq!(extracted.len(), 2);
        assert_eq!(extracted[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(extracted[1].to_vec(), vec![3.0, 4.0, 5.0]);
    }

    // add() rejects a tensor that would exceed the element capacity.
    #[test]
    fn test_gradient_bucket_full() {
        let mut bucket = GradientBucket::new(5);

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert!(bucket.add(&tensor1));

        let tensor2 = Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();
        assert!(!bucket.add(&tensor2)); // Won't fit
    }

    #[test]
    fn test_gradient_bucket_clear() {
        let mut bucket = GradientBucket::new(100);
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        bucket.add(&tensor);

        bucket.clear();
        assert!(bucket.is_empty());
    }

    // --- GradientSynchronizer ----------------------------------------------

    #[test]
    fn test_gradient_synchronizer() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 100);
        sync.prepare(10);

        assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
    }

    // NoSync must make sync_all a no-op (no communication attempted).
    #[test]
    fn test_gradient_synchronizer_no_sync() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::NoSync, 100);
        sync.prepare(10);

        let pg = ProcessGroup::mock();
        sync.sync_all(&pg); // Should do nothing
    }

    #[test]
    fn test_gradient_synchronizer_default() {
        let sync = GradientSynchronizer::default();
        assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
    }

    // Derived PartialEq/Eq on the strategy enum.
    #[test]
    fn test_grad_sync_strategy() {
        assert_eq!(GradSyncStrategy::Synchronous, GradSyncStrategy::Synchronous);
        assert_ne!(GradSyncStrategy::Synchronous, GradSyncStrategy::NoSync);
    }
}
459}