// axonml_distributed/ddp.rs
//! DDP - Distributed Data Parallel
//!
//! Provides `DistributedDataParallel` wrapper for distributed training.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

use crate::backend::ReduceOp;
use crate::process_group::ProcessGroup;
use axonml_autograd::Variable;
use axonml_nn::{Module, Parameter};
use axonml_tensor::Tensor;

// =============================================================================
// DistributedDataParallel
// =============================================================================

/// Wrapper that enables distributed data parallel training.
///
/// DDP replicates the model across multiple processes and synchronizes
/// gradients during the backward pass.
pub struct DistributedDataParallel<M: Module> {
    // The wrapped model (one replica per process).
    module: M,
    // Communicator used for broadcast / all-reduce collectives.
    process_group: ProcessGroup,
    // Intended to control broadcasting of buffers from rank 0.
    // NOTE(review): stored but never read in this file — confirm usage.
    broadcast_buffers: bool,
    // Intended to treat gradients as views into flat communication buckets.
    // NOTE(review): stored but never read in this file — confirm usage.
    gradient_as_bucket_view: bool,
}

29impl<M: Module> DistributedDataParallel<M> {
30    /// Creates a new DDP wrapper.
31    pub fn new(module: M, process_group: ProcessGroup) -> Self {
32        Self {
33            module,
34            process_group,
35            broadcast_buffers: true,
36            gradient_as_bucket_view: true,
37        }
38    }
39
40    /// Sets whether to broadcast buffers from rank 0.
41    pub fn broadcast_buffers(mut self, broadcast: bool) -> Self {
42        self.broadcast_buffers = broadcast;
43        self
44    }
45
46    /// Sets whether to use gradient bucketing.
47    pub fn gradient_as_bucket_view(mut self, bucket_view: bool) -> Self {
48        self.gradient_as_bucket_view = bucket_view;
49        self
50    }
51
52    /// Returns a reference to the underlying module.
53    pub fn module(&self) -> &M {
54        &self.module
55    }
56
57    /// Returns a mutable reference to the underlying module.
58    pub fn module_mut(&mut self) -> &mut M {
59        &mut self.module
60    }
61
62    /// Returns the process group.
63    pub fn process_group(&self) -> &ProcessGroup {
64        &self.process_group
65    }
66
67    /// Synchronizes model parameters across all processes.
68    /// Should be called once at the start of training.
69    pub fn sync_parameters(&mut self) {
70        // Broadcast parameters from rank 0
71        for param in self.module.parameters() {
72            let mut tensor = param.data().clone();
73            self.process_group.broadcast_tensor(&mut tensor, 0);
74            // In a real implementation, we'd update the parameter
75        }
76    }
77
78    /// Synchronizes gradients across all processes.
79    /// Should be called after the backward pass.
80    pub fn sync_gradients(&self) {
81        // Get all gradients and all-reduce them
82        for param in self.module.parameters() {
83            if let Some(grad) = param.grad() {
84                let mut grad_tensor = grad.clone();
85                self.process_group
86                    .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
87                // In a real implementation, we'd update the gradient
88            }
89        }
90    }
91
92    /// Performs forward pass with gradient synchronization.
93    pub fn forward(&self, input: &Variable) -> Variable {
94        self.module.forward(input)
95    }
96}
97
impl<M: Module> Module for DistributedDataParallel<M> {
    /// Delegates the forward pass to the wrapped module.
    ///
    /// NOTE(review): no gradient hook is installed here; gradient
    /// synchronization must be triggered explicitly via `sync_gradients`.
    fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }

    /// Exposes the wrapped module's parameters unchanged.
    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }

    /// Puts the wrapped module into training mode.
    fn train(&mut self) {
        self.module.train();
    }

    /// Puts the wrapped module into evaluation mode.
    fn eval(&mut self) {
        self.module.eval();
    }

    /// Reports the wrapped module's training state.
    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}

// =============================================================================
// GradientBucket
// =============================================================================

/// A bucket for accumulating gradients before all-reduce.
pub struct GradientBucket {
    /// Flattened gradient data of all tensors added so far, in insertion order.
    data: Vec<f32>,
    /// Original shape and element count of each added tensor, in insertion
    /// order; used by `extract` to reconstruct the tensors.
    shapes: Vec<(Vec<usize>, usize)>,
    /// Capacity in number of f32 elements (not bytes).
    capacity: usize,
}

134impl GradientBucket {
135    /// Creates a new gradient bucket.
136    #[must_use]
137    pub fn new(capacity: usize) -> Self {
138        Self {
139            data: Vec::with_capacity(capacity),
140            shapes: Vec::new(),
141            capacity,
142        }
143    }
144
145    /// Checks if the bucket is full.
146    #[must_use]
147    pub fn is_full(&self) -> bool {
148        self.data.len() >= self.capacity
149    }
150
151    /// Checks if the bucket is empty.
152    #[must_use]
153    pub fn is_empty(&self) -> bool {
154        self.data.is_empty()
155    }
156
157    /// Returns the current size.
158    #[must_use]
159    pub fn size(&self) -> usize {
160        self.data.len()
161    }
162
163    /// Adds a tensor to the bucket.
164    pub fn add(&mut self, tensor: &Tensor<f32>) -> bool {
165        let data = tensor.to_vec();
166        if self.data.len() + data.len() > self.capacity {
167            return false;
168        }
169
170        self.shapes.push((tensor.shape().to_vec(), data.len()));
171        self.data.extend(data);
172        true
173    }
174
175    /// Returns the flattened data.
176    #[must_use]
177    pub fn data(&self) -> &[f32] {
178        &self.data
179    }
180
181    /// Returns mutable flattened data.
182    pub fn data_mut(&mut self) -> &mut [f32] {
183        &mut self.data
184    }
185
186    /// Clears the bucket.
187    pub fn clear(&mut self) {
188        self.data.clear();
189        self.shapes.clear();
190    }
191
192    /// Extracts tensors back from the bucket.
193    #[must_use]
194    pub fn extract(&self) -> Vec<Tensor<f32>> {
195        let mut result = Vec::new();
196        let mut offset = 0;
197
198        for (shape, size) in &self.shapes {
199            let end = offset + size;
200            let data = self.data[offset..end].to_vec();
201            result.push(Tensor::from_vec(data, shape).unwrap());
202            offset = end;
203        }
204
205        result
206    }
207}
208
// =============================================================================
// Gradient Synchronization Strategies
// =============================================================================

/// Strategy for gradient synchronization.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GradSyncStrategy {
    /// Synchronize after each backward pass.
    Synchronous,
    /// Overlap computation and communication.
    /// NOTE(review): no code path in this file treats `Overlapped`
    /// differently from `Synchronous` — confirm intended behavior.
    Overlapped,
    /// No gradient synchronization (for debugging).
    NoSync,
}

/// Gradient synchronizer: groups gradients into fixed-size buckets and
/// all-reduces each bucket according to the configured strategy.
pub struct GradientSynchronizer {
    // When/how to synchronize (see `GradSyncStrategy`).
    strategy: GradSyncStrategy,
    // Capacity of each bucket, in f32 elements.
    bucket_size: usize,
    // Buckets created by `prepare`; empty until then.
    buckets: Vec<GradientBucket>,
}

231impl GradientSynchronizer {
232    /// Creates a new gradient synchronizer.
233    #[must_use]
234    pub fn new(strategy: GradSyncStrategy, bucket_size: usize) -> Self {
235        Self {
236            strategy,
237            bucket_size,
238            buckets: Vec::new(),
239        }
240    }
241
242    /// Returns the synchronization strategy.
243    #[must_use]
244    pub fn strategy(&self) -> GradSyncStrategy {
245        self.strategy
246    }
247
248    /// Prepares buckets for gradient accumulation.
249    pub fn prepare(&mut self, num_params: usize) {
250        let num_buckets = num_params.div_ceil(self.bucket_size);
251        self.buckets = (0..num_buckets)
252            .map(|_| GradientBucket::new(self.bucket_size))
253            .collect();
254    }
255
256    /// Adds a gradient to the appropriate bucket.
257    pub fn add_gradient(&mut self, bucket_idx: usize, tensor: &Tensor<f32>) {
258        if bucket_idx < self.buckets.len() {
259            self.buckets[bucket_idx].add(tensor);
260        }
261    }
262
263    /// Synchronizes all buckets.
264    pub fn sync_all(&mut self, process_group: &ProcessGroup) {
265        if self.strategy == GradSyncStrategy::NoSync {
266            return;
267        }
268
269        for bucket in &mut self.buckets {
270            if !bucket.is_empty() {
271                let mut data = bucket.data().to_vec();
272                let len = data.len();
273                process_group
274                    .backend()
275                    .all_reduce(&mut data, ReduceOp::Average);
276                bucket.data_mut()[..len].copy_from_slice(&data);
277            }
278        }
279    }
280
281    /// Clears all buckets.
282    pub fn clear(&mut self) {
283        for bucket in &mut self.buckets {
284            bucket.clear();
285        }
286    }
287}
288
impl Default for GradientSynchronizer {
    /// Synchronous strategy with 25M-element buckets
    /// (25e6 f32 x 4 bytes ≈ 100 MB per bucket).
    fn default() -> Self {
        Self::new(GradSyncStrategy::Synchronous, 25_000_000) // ~100MB for f32
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    // ProcessGroup::mock() provides a single-process group (rank 0,
    // world size 1), so no real communication happens in these tests.
    #[test]
    fn test_ddp_creation() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        assert_eq!(ddp.process_group().rank(), 0);
        assert_eq!(ddp.process_group().world_size(), 1);
    }

    // Forward through DDP must preserve the wrapped Linear's output shape.
    #[test]
    fn test_ddp_forward() {
        let module = Linear::new(4, 2);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        let output = ddp.forward(&input);

        assert_eq!(output.data().shape(), &[1, 2]);
    }

    // Both shared and exclusive access to the wrapped module compile and run.
    #[test]
    fn test_ddp_module_access() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let mut ddp = DistributedDataParallel::new(module, pg);

        // Access module
        let _ = ddp.module();
        let _ = ddp.module_mut();
    }

    #[test]
    fn test_ddp_train_eval() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let mut ddp = DistributedDataParallel::new(module, pg);

        // Module trait's default is_training() returns true
        // Linear doesn't override train/eval behavior
        assert!(ddp.is_training());

        // Call train/eval - they are forwarded to the wrapped module
        // but Linear's default implementation doesn't change state
        ddp.train();
        ddp.eval();

        // Test that methods can be called without panic
        let _ = ddp.is_training();
    }

    // DDP exposes the wrapped module's parameters unchanged.
    #[test]
    fn test_ddp_parameters() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        let params = ddp.parameters();
        assert!(!params.is_empty());
    }

    // Adding tensors appends their flattened elements in order.
    #[test]
    fn test_gradient_bucket() {
        let mut bucket = GradientBucket::new(100);

        assert!(bucket.is_empty());

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert!(bucket.add(&tensor1));

        assert!(!bucket.is_empty());
        assert_eq!(bucket.size(), 3);

        let tensor2 = Tensor::from_vec(vec![4.0, 5.0], &[2]).unwrap();
        assert!(bucket.add(&tensor2));

        assert_eq!(bucket.size(), 5);
        assert_eq!(bucket.data(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    // extract() must round-trip the added tensors (shapes and values).
    #[test]
    fn test_gradient_bucket_extract() {
        let mut bucket = GradientBucket::new(100);

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let tensor2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3]).unwrap();

        bucket.add(&tensor1);
        bucket.add(&tensor2);

        let extracted = bucket.extract();
        assert_eq!(extracted.len(), 2);
        assert_eq!(extracted[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(extracted[1].to_vec(), vec![3.0, 4.0, 5.0]);
    }

    // add() must reject a tensor that exceeds remaining capacity.
    #[test]
    fn test_gradient_bucket_full() {
        let mut bucket = GradientBucket::new(5);

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert!(bucket.add(&tensor1));

        let tensor2 = Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();
        assert!(!bucket.add(&tensor2)); // Won't fit
    }

    #[test]
    fn test_gradient_bucket_clear() {
        let mut bucket = GradientBucket::new(100);
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        bucket.add(&tensor);

        bucket.clear();
        assert!(bucket.is_empty());
    }

    #[test]
    fn test_gradient_synchronizer() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 100);
        sync.prepare(10);

        assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
    }

    // Under NoSync, sync_all must be a no-op (no backend calls, no panic).
    #[test]
    fn test_gradient_synchronizer_no_sync() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::NoSync, 100);
        sync.prepare(10);

        let pg = ProcessGroup::mock();
        sync.sync_all(&pg); // Should do nothing
    }

    #[test]
    fn test_gradient_synchronizer_default() {
        let sync = GradientSynchronizer::default();
        assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
    }

    // Derived PartialEq/Eq on the strategy enum behave as expected.
    #[test]
    fn test_grad_sync_strategy() {
        assert_eq!(GradSyncStrategy::Synchronous, GradSyncStrategy::Synchronous);
        assert_ne!(GradSyncStrategy::Synchronous, GradSyncStrategy::NoSync);
    }
}