
axonml_distributed/ddp.rs

//! DDP - Distributed Data Parallel
//!
//! Provides the `DistributedDataParallel` wrapper for distributed training.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

use crate::backend::ReduceOp;
use crate::process_group::ProcessGroup;
use axonml_autograd::Variable;
use axonml_nn::{Module, Parameter};
use axonml_tensor::Tensor;

// =============================================================================
// DistributedDataParallel
// =============================================================================

/// Wrapper that enables distributed data parallel training.
///
/// DDP replicates the model across multiple processes and synchronizes
/// gradients during the backward pass.
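///
/// # Example
///
/// An illustrative sketch, not a doctest; it assumes a `Linear` layer and the
/// mock process group used in the tests below, and `input` is a hypothetical
/// batch `Variable`.
///
/// ```ignore
/// let model = Linear::new(10, 5);
/// let pg = ProcessGroup::mock();
/// let mut ddp = DistributedDataParallel::new(model, pg);
///
/// // Broadcast initial parameters from rank 0 so all ranks start identical.
/// ddp.sync_parameters();
///
/// let output = ddp.forward(&input);
/// // ... compute the loss and run the backward pass ...
///
/// // All-reduce (average) gradients across ranks before the optimizer step.
/// ddp.sync_gradients();
/// ```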
pub struct DistributedDataParallel<M: Module> {
    module: M,
    process_group: ProcessGroup,
    broadcast_buffers: bool,
    gradient_as_bucket_view: bool,
}

impl<M: Module> DistributedDataParallel<M> {
    /// Creates a new DDP wrapper.
    pub fn new(module: M, process_group: ProcessGroup) -> Self {
        Self {
            module,
            process_group,
            broadcast_buffers: true,
            gradient_as_bucket_view: true,
        }
    }

    /// Sets whether to broadcast buffers from rank 0.
    pub fn broadcast_buffers(mut self, broadcast: bool) -> Self {
        self.broadcast_buffers = broadcast;
        self
    }

    /// Sets whether gradients are stored as views into all-reduce buckets.
    pub fn gradient_as_bucket_view(mut self, bucket_view: bool) -> Self {
        self.gradient_as_bucket_view = bucket_view;
        self
    }

    /// Returns a reference to the underlying module.
    pub fn module(&self) -> &M {
        &self.module
    }

    /// Returns a mutable reference to the underlying module.
    pub fn module_mut(&mut self) -> &mut M {
        &mut self.module
    }

    /// Returns the process group.
    pub fn process_group(&self) -> &ProcessGroup {
        &self.process_group
    }

    /// Synchronizes model parameters across all processes.
    /// Should be called once at the start of training.
    pub fn sync_parameters(&mut self) {
        // Broadcast parameters from rank 0
        for param in self.module.parameters() {
            let mut tensor = param.data().clone();
            self.process_group.broadcast_tensor(&mut tensor, 0);
            // In a real implementation, we'd update the parameter
        }
    }

    /// Synchronizes gradients across all processes.
    /// Should be called after the backward pass.
    pub fn sync_gradients(&self) {
        // Get all gradients and all-reduce them
        for param in self.module.parameters() {
            if let Some(grad) = param.grad() {
                let mut grad_tensor = grad.clone();
                self.process_group
                    .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
                // In a real implementation, we'd update the gradient
            }
        }
    }

    /// Performs a forward pass through the wrapped module.
    pub fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }
}

impl<M: Module> Module for DistributedDataParallel<M> {
    fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }

    fn train(&mut self) {
        self.module.train();
    }

    fn eval(&mut self) {
        self.module.eval();
    }

    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}

// =============================================================================
// GradientBucket
// =============================================================================

/// A bucket for accumulating gradients before all-reduce.
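///
/// # Example
///
/// A minimal sketch of packing two gradient tensors into one bucket and
/// unpacking them again (mirrors the unit tests below); not a doctest.
///
/// ```ignore
/// let mut bucket = GradientBucket::new(100);
/// let g1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
/// let g2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3]).unwrap();
///
/// // `add` returns false once a tensor no longer fits in the bucket.
/// assert!(bucket.add(&g1));
/// assert!(bucket.add(&g2));
///
/// // The flat buffer would be all-reduced in one call, then split back
/// // into per-parameter tensors with their original shapes.
/// let grads = bucket.extract();
/// assert_eq!(grads.len(), 2);
/// ```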
pub struct GradientBucket {
    /// Flattened gradient data.
    data: Vec<f32>,
    /// Original shapes and sizes.
    shapes: Vec<(Vec<usize>, usize)>,
    /// Capacity in number of elements.
    capacity: usize,
}

impl GradientBucket {
    /// Creates a new gradient bucket.
    #[must_use] pub fn new(capacity: usize) -> Self {
        Self {
            data: Vec::with_capacity(capacity),
            shapes: Vec::new(),
            capacity,
        }
    }

    /// Checks if the bucket is full.
    #[must_use] pub fn is_full(&self) -> bool {
        self.data.len() >= self.capacity
    }

    /// Checks if the bucket is empty.
    #[must_use] pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Returns the current size.
    #[must_use] pub fn size(&self) -> usize {
        self.data.len()
    }

    /// Adds a tensor to the bucket.
    pub fn add(&mut self, tensor: &Tensor<f32>) -> bool {
        let data = tensor.to_vec();
        if self.data.len() + data.len() > self.capacity {
            return false;
        }

        self.shapes.push((tensor.shape().to_vec(), data.len()));
        self.data.extend(data);
        true
    }

    /// Returns the flattened data.
    #[must_use] pub fn data(&self) -> &[f32] {
        &self.data
    }

    /// Returns mutable flattened data.
    pub fn data_mut(&mut self) -> &mut [f32] {
        &mut self.data
    }

    /// Clears the bucket.
    pub fn clear(&mut self) {
        self.data.clear();
        self.shapes.clear();
    }

    /// Extracts tensors back from the bucket.
    #[must_use] pub fn extract(&self) -> Vec<Tensor<f32>> {
        let mut result = Vec::new();
        let mut offset = 0;

        for (shape, size) in &self.shapes {
            let end = offset + size;
            let data = self.data[offset..end].to_vec();
            result.push(Tensor::from_vec(data, shape).unwrap());
            offset = end;
        }

        result
    }
}

// =============================================================================
// Gradient Synchronization Strategies
// =============================================================================

/// Strategy for gradient synchronization.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GradSyncStrategy {
    /// Synchronize after each backward pass.
    Synchronous,
    /// Overlap computation and communication.
    Overlapped,
    /// No gradient synchronization (for debugging).
    NoSync,
}

/// Synchronizes gradients by accumulating them into buckets and all-reducing
/// each bucket according to the configured strategy.
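///
/// # Example
///
/// A minimal sketch, assuming the mock process group from the tests below;
/// the bucket index is chosen by the caller. Not a doctest.
///
/// ```ignore
/// let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 100);
/// sync.prepare(10); // allocate buckets sized for 10 parameters
///
/// let grad = Tensor::from_vec(vec![0.5, 0.25], &[2]).unwrap();
/// sync.add_gradient(0, &grad);
///
/// // Average the bucketed gradients across all ranks, then reset the buckets.
/// let pg = ProcessGroup::mock();
/// sync.sync_all(&pg);
/// sync.clear();
/// ```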
pub struct GradientSynchronizer {
    strategy: GradSyncStrategy,
    bucket_size: usize,
    buckets: Vec<GradientBucket>,
}

impl GradientSynchronizer {
    /// Creates a new gradient synchronizer.
    #[must_use] pub fn new(strategy: GradSyncStrategy, bucket_size: usize) -> Self {
        Self {
            strategy,
            bucket_size,
            buckets: Vec::new(),
        }
    }

    /// Returns the synchronization strategy.
    #[must_use] pub fn strategy(&self) -> GradSyncStrategy {
        self.strategy
    }

    /// Prepares buckets for gradient accumulation.
    pub fn prepare(&mut self, num_params: usize) {
        let num_buckets = num_params.div_ceil(self.bucket_size);
        self.buckets = (0..num_buckets)
            .map(|_| GradientBucket::new(self.bucket_size))
            .collect();
    }

    /// Adds a gradient to the appropriate bucket.
    pub fn add_gradient(&mut self, bucket_idx: usize, tensor: &Tensor<f32>) {
        if bucket_idx < self.buckets.len() {
            self.buckets[bucket_idx].add(tensor);
        }
    }

    /// Synchronizes all buckets.
    pub fn sync_all(&mut self, process_group: &ProcessGroup) {
        if self.strategy == GradSyncStrategy::NoSync {
            return;
        }

        for bucket in &mut self.buckets {
            if !bucket.is_empty() {
                let mut data = bucket.data().to_vec();
                let len = data.len();
                process_group
                    .backend()
                    .all_reduce(&mut data, ReduceOp::Average);
                bucket.data_mut()[..len].copy_from_slice(&data);
            }
        }
    }

    /// Clears all buckets.
    pub fn clear(&mut self) {
        for bucket in &mut self.buckets {
            bucket.clear();
        }
    }
}

impl Default for GradientSynchronizer {
    fn default() -> Self {
        Self::new(GradSyncStrategy::Synchronous, 25_000_000) // ~100MB for f32
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    #[test]
    fn test_ddp_creation() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        assert_eq!(ddp.process_group().rank(), 0);
        assert_eq!(ddp.process_group().world_size(), 1);
    }

    #[test]
    fn test_ddp_forward() {
        let module = Linear::new(4, 2);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        let output = ddp.forward(&input);

        assert_eq!(output.data().shape(), &[1, 2]);
    }

    #[test]
    fn test_ddp_module_access() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let mut ddp = DistributedDataParallel::new(module, pg);

        // Access module
        let _ = ddp.module();
        let _ = ddp.module_mut();
    }

    #[test]
    fn test_ddp_train_eval() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let mut ddp = DistributedDataParallel::new(module, pg);

        // Module trait's default is_training() returns true
        // Linear doesn't override train/eval behavior
        assert!(ddp.is_training());

        // Call train/eval - they are forwarded to the wrapped module
        // but Linear's default implementation doesn't change state
        ddp.train();
        ddp.eval();

        // Test that methods can be called without panic
        let _ = ddp.is_training();
    }

    #[test]
    fn test_ddp_parameters() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        let params = ddp.parameters();
        assert!(!params.is_empty());
    }

    #[test]
    fn test_gradient_bucket() {
        let mut bucket = GradientBucket::new(100);

        assert!(bucket.is_empty());

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert!(bucket.add(&tensor1));

        assert!(!bucket.is_empty());
        assert_eq!(bucket.size(), 3);

        let tensor2 = Tensor::from_vec(vec![4.0, 5.0], &[2]).unwrap();
        assert!(bucket.add(&tensor2));

        assert_eq!(bucket.size(), 5);
        assert_eq!(bucket.data(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_gradient_bucket_extract() {
        let mut bucket = GradientBucket::new(100);

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let tensor2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3]).unwrap();

        bucket.add(&tensor1);
        bucket.add(&tensor2);

        let extracted = bucket.extract();
        assert_eq!(extracted.len(), 2);
        assert_eq!(extracted[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(extracted[1].to_vec(), vec![3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_gradient_bucket_full() {
        let mut bucket = GradientBucket::new(5);

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert!(bucket.add(&tensor1));

        let tensor2 = Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();
        assert!(!bucket.add(&tensor2)); // Won't fit
    }

    #[test]
    fn test_gradient_bucket_clear() {
        let mut bucket = GradientBucket::new(100);
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        bucket.add(&tensor);

        bucket.clear();
        assert!(bucket.is_empty());
    }

    #[test]
    fn test_gradient_synchronizer() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 100);
        sync.prepare(10);

        assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
    }

    #[test]
    fn test_gradient_synchronizer_no_sync() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::NoSync, 100);
        sync.prepare(10);

        let pg = ProcessGroup::mock();
        sync.sync_all(&pg); // Should do nothing
    }

    #[test]
    fn test_gradient_synchronizer_default() {
        let sync = GradientSynchronizer::default();
        assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
    }

    #[test]
    fn test_grad_sync_strategy() {
        assert_eq!(GradSyncStrategy::Synchronous, GradSyncStrategy::Synchronous);
        assert_ne!(GradSyncStrategy::Synchronous, GradSyncStrategy::NoSync);
    }
}