1use crate::backend::ReduceOp;
19use crate::process_group::ProcessGroup;
20use axonml_autograd::Variable;
21use axonml_nn::{Module, Parameter};
22use axonml_tensor::Tensor;
23
/// Data-parallel wrapper that replicates a module across the ranks of a
/// process group and keeps its parameters and gradients in sync.
pub struct DistributedDataParallel<M: Module> {
    // The wrapped model; forward calls delegate to it.
    module: M,
    // Communication group used for broadcast / all-reduce collectives.
    process_group: ProcessGroup,
    // Configuration flag set via the builder method of the same name.
    // NOTE(review): not read anywhere in this file — confirm it is consumed
    // elsewhere in the crate.
    broadcast_buffers: bool,
    // Configuration flag set via the builder method of the same name.
    // NOTE(review): also write-only in this file — confirm intended use.
    gradient_as_bucket_view: bool,
}
38
39impl<M: Module> DistributedDataParallel<M> {
40 pub fn new(module: M, process_group: ProcessGroup) -> Self {
42 Self {
43 module,
44 process_group,
45 broadcast_buffers: true,
46 gradient_as_bucket_view: true,
47 }
48 }
49
50 pub fn broadcast_buffers(mut self, broadcast: bool) -> Self {
52 self.broadcast_buffers = broadcast;
53 self
54 }
55
56 pub fn gradient_as_bucket_view(mut self, bucket_view: bool) -> Self {
58 self.gradient_as_bucket_view = bucket_view;
59 self
60 }
61
62 pub fn module(&self) -> &M {
64 &self.module
65 }
66
67 pub fn module_mut(&mut self) -> &mut M {
69 &mut self.module
70 }
71
72 pub fn process_group(&self) -> &ProcessGroup {
74 &self.process_group
75 }
76
77 pub fn sync_parameters(&mut self) {
81 for param in self.module.parameters() {
82 let mut tensor = param.data().clone();
83 self.process_group.broadcast_tensor(&mut tensor, 0);
84 param.update_data(tensor);
86 }
87 }
88
89 pub fn sync_gradients(&self) {
93 for param in self.module.parameters() {
94 if let Some(grad) = param.grad() {
95 let mut grad_tensor = grad.clone();
96 self.process_group
97 .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
98 param.set_grad(grad_tensor);
100 }
101 }
102 }
103
104 pub fn forward(&self, input: &Variable) -> Variable {
106 self.module.forward(input)
107 }
108}
109
/// `Module` implementation that delegates every call to the wrapped module,
/// so a `DistributedDataParallel` can be used anywhere a plain `Module` is
/// expected.
impl<M: Module> Module for DistributedDataParallel<M> {
    fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }

    fn train(&mut self) {
        self.module.train();
    }

    fn eval(&mut self) {
        self.module.eval();
    }

    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}
131
/// Flat buffer that packs several gradient tensors into one contiguous
/// `Vec<f32>` so they can be reduced in a single collective call.
pub struct GradientBucket {
    // Concatenated, flattened gradient values in insertion order.
    data: Vec<f32>,
    // (original shape, flattened length) per added tensor, used by
    // `extract` to rebuild the individual tensors.
    shapes: Vec<(Vec<usize>, usize)>,
    // Maximum total number of f32 elements this bucket accepts.
    capacity: usize,
}
145
146impl GradientBucket {
147 #[must_use]
149 pub fn new(capacity: usize) -> Self {
150 Self {
151 data: Vec::with_capacity(capacity),
152 shapes: Vec::new(),
153 capacity,
154 }
155 }
156
157 #[must_use]
159 pub fn is_full(&self) -> bool {
160 self.data.len() >= self.capacity
161 }
162
163 #[must_use]
165 pub fn is_empty(&self) -> bool {
166 self.data.is_empty()
167 }
168
169 #[must_use]
171 pub fn size(&self) -> usize {
172 self.data.len()
173 }
174
175 pub fn add(&mut self, tensor: &Tensor<f32>) -> bool {
177 let data = tensor.to_vec();
178 if self.data.len() + data.len() > self.capacity {
179 return false;
180 }
181
182 self.shapes.push((tensor.shape().to_vec(), data.len()));
183 self.data.extend(data);
184 true
185 }
186
187 #[must_use]
189 pub fn data(&self) -> &[f32] {
190 &self.data
191 }
192
193 pub fn data_mut(&mut self) -> &mut [f32] {
195 &mut self.data
196 }
197
198 pub fn clear(&mut self) {
200 self.data.clear();
201 self.shapes.clear();
202 }
203
204 #[must_use]
206 pub fn extract(&self) -> Vec<Tensor<f32>> {
207 let mut result = Vec::new();
208 let mut offset = 0;
209
210 for (shape, size) in &self.shapes {
211 let end = offset + size;
212 let data = self.data[offset..end].to_vec();
213 result.push(Tensor::from_vec(data, shape).unwrap());
214 offset = end;
215 }
216
217 result
218 }
219}
220
/// How gradients are synchronized across ranks after the backward pass.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GradSyncStrategy {
    /// All-reduce every bucket when `sync_all` is called.
    Synchronous,
    /// Presumably intended to overlap communication with backward compute;
    /// NOTE(review): `sync_all` currently treats this identically to
    /// `Synchronous` — confirm whether overlap is implemented elsewhere.
    Overlapped,
    /// Skip synchronization entirely (`sync_all` becomes a no-op).
    NoSync,
}
235
/// Drives bucketed gradient all-reduce according to a `GradSyncStrategy`.
pub struct GradientSynchronizer {
    // Strategy consulted by `sync_all`.
    strategy: GradSyncStrategy,
    // Capacity of each bucket, in f32 elements.
    bucket_size: usize,
    // Buckets allocated by `prepare`.
    buckets: Vec<GradientBucket>,
}
242
243impl GradientSynchronizer {
244 #[must_use]
246 pub fn new(strategy: GradSyncStrategy, bucket_size: usize) -> Self {
247 Self {
248 strategy,
249 bucket_size,
250 buckets: Vec::new(),
251 }
252 }
253
254 #[must_use]
256 pub fn strategy(&self) -> GradSyncStrategy {
257 self.strategy
258 }
259
260 pub fn prepare(&mut self, num_params: usize) {
262 let num_buckets = num_params.div_ceil(self.bucket_size);
263 self.buckets = (0..num_buckets)
264 .map(|_| GradientBucket::new(self.bucket_size))
265 .collect();
266 }
267
268 pub fn add_gradient(&mut self, bucket_idx: usize, tensor: &Tensor<f32>) {
270 if bucket_idx < self.buckets.len() {
271 self.buckets[bucket_idx].add(tensor);
272 }
273 }
274
275 pub fn sync_all(&mut self, process_group: &ProcessGroup) {
277 if self.strategy == GradSyncStrategy::NoSync {
278 return;
279 }
280
281 for bucket in &mut self.buckets {
282 if !bucket.is_empty() {
283 let mut data = bucket.data().to_vec();
284 let len = data.len();
285 process_group
286 .backend()
287 .all_reduce(&mut data, ReduceOp::Average);
288 bucket.data_mut()[..len].copy_from_slice(&data);
289 }
290 }
291 }
292
293 pub fn clear(&mut self) {
295 for bucket in &mut self.buckets {
296 bucket.clear();
297 }
298 }
299}
300
301impl Default for GradientSynchronizer {
302 fn default() -> Self {
303 Self::new(GradSyncStrategy::Synchronous, 25_000_000) }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    // DDP over a mock (single-rank) process group reports rank 0 / world 1.
    #[test]
    fn test_ddp_creation() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        assert_eq!(ddp.process_group().rank(), 0);
        assert_eq!(ddp.process_group().world_size(), 1);
    }

    // Forward through the wrapper preserves the wrapped Linear's output shape.
    #[test]
    fn test_ddp_forward() {
        let module = Linear::new(4, 2);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        let output = ddp.forward(&input);

        assert_eq!(output.data().shape(), &[1, 2]);
    }

    // Smoke test: both module accessors are callable.
    #[test]
    fn test_ddp_module_access() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let mut ddp = DistributedDataParallel::new(module, pg);

        let _ = ddp.module();
        let _ = ddp.module_mut();
    }

    // Train/eval delegate to the inner module; a fresh module is assumed to
    // start in training mode.
    #[test]
    fn test_ddp_train_eval() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let mut ddp = DistributedDataParallel::new(module, pg);

        assert!(ddp.is_training());

        ddp.train();
        ddp.eval();

        let _ = ddp.is_training();
    }

    // The wrapper exposes the inner module's parameters.
    #[test]
    fn test_ddp_parameters() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        let params = ddp.parameters();
        assert!(!params.is_empty());
    }

    // Adding tensors accumulates their flattened values in order.
    #[test]
    fn test_gradient_bucket() {
        let mut bucket = GradientBucket::new(100);

        assert!(bucket.is_empty());

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert!(bucket.add(&tensor1));

        assert!(!bucket.is_empty());
        assert_eq!(bucket.size(), 3);

        let tensor2 = Tensor::from_vec(vec![4.0, 5.0], &[2]).unwrap();
        assert!(bucket.add(&tensor2));

        assert_eq!(bucket.size(), 5);
        assert_eq!(bucket.data(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    // extract() splits the flat buffer back into the original tensors.
    #[test]
    fn test_gradient_bucket_extract() {
        let mut bucket = GradientBucket::new(100);

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let tensor2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3]).unwrap();

        bucket.add(&tensor1);
        bucket.add(&tensor2);

        let extracted = bucket.extract();
        assert_eq!(extracted.len(), 2);
        assert_eq!(extracted[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(extracted[1].to_vec(), vec![3.0, 4.0, 5.0]);
    }

    // A tensor that would overflow the remaining capacity is rejected.
    #[test]
    fn test_gradient_bucket_full() {
        let mut bucket = GradientBucket::new(5);

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert!(bucket.add(&tensor1));

        let tensor2 = Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();
        assert!(!bucket.add(&tensor2)); }

    // clear() returns the bucket to the empty state.
    #[test]
    fn test_gradient_bucket_clear() {
        let mut bucket = GradientBucket::new(100);
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        bucket.add(&tensor);

        bucket.clear();
        assert!(bucket.is_empty());
    }

    // prepare() does not alter the configured strategy.
    #[test]
    fn test_gradient_synchronizer() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 100);
        sync.prepare(10);

        assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
    }

    // Under NoSync, sync_all must be a no-op (and must not panic).
    #[test]
    fn test_gradient_synchronizer_no_sync() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::NoSync, 100);
        sync.prepare(10);

        let pg = ProcessGroup::mock();
        sync.sync_all(&pg); }

    // Default synchronizer uses the synchronous strategy.
    #[test]
    fn test_gradient_synchronizer_default() {
        let sync = GradientSynchronizer::default();
        assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
    }

    // Derived PartialEq behaves as expected on the strategy enum.
    #[test]
    fn test_grad_sync_strategy() {
        assert_eq!(GradSyncStrategy::Synchronous, GradSyncStrategy::Synchronous);
        assert_ne!(GradSyncStrategy::Synchronous, GradSyncStrategy::NoSync);
    }
}