1use crate::backend::ReduceOp;
18use crate::process_group::ProcessGroup;
19use axonml_autograd::Variable;
20use axonml_nn::{Module, Parameter};
21use axonml_tensor::Tensor;
22
/// Data-parallel wrapper that replicates a module across the ranks of a
/// process group and keeps its parameters and gradients synchronized.
pub struct DistributedDataParallel<M: Module> {
    /// The wrapped module; all `forward` calls are delegated to it.
    module: M,
    /// Communicator used for the broadcast / all-reduce collectives.
    process_group: ProcessGroup,
    /// Whether non-parameter buffers should be broadcast from rank 0.
    /// NOTE(review): set via the builder but never read in this file —
    /// presumably consumed elsewhere; confirm.
    broadcast_buffers: bool,
    /// Whether gradients should be exposed as views into flat buckets.
    /// NOTE(review): also write-only within this file — confirm usage.
    gradient_as_bucket_view: bool,
}
37
38impl<M: Module> DistributedDataParallel<M> {
39 pub fn new(module: M, process_group: ProcessGroup) -> Self {
41 Self {
42 module,
43 process_group,
44 broadcast_buffers: true,
45 gradient_as_bucket_view: true,
46 }
47 }
48
49 pub fn broadcast_buffers(mut self, broadcast: bool) -> Self {
51 self.broadcast_buffers = broadcast;
52 self
53 }
54
55 pub fn gradient_as_bucket_view(mut self, bucket_view: bool) -> Self {
57 self.gradient_as_bucket_view = bucket_view;
58 self
59 }
60
61 pub fn module(&self) -> &M {
63 &self.module
64 }
65
66 pub fn module_mut(&mut self) -> &mut M {
68 &mut self.module
69 }
70
71 pub fn process_group(&self) -> &ProcessGroup {
73 &self.process_group
74 }
75
76 pub fn sync_parameters(&mut self) {
80 for param in self.module.parameters() {
81 let mut tensor = param.data().clone();
82 self.process_group.broadcast_tensor(&mut tensor, 0);
83 param.update_data(tensor);
85 }
86 }
87
88 pub fn sync_gradients(&self) {
92 for param in self.module.parameters() {
93 if let Some(grad) = param.grad() {
94 let mut grad_tensor = grad.clone();
95 self.process_group
96 .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
97 param.set_grad(grad_tensor);
99 }
100 }
101 }
102
103 pub fn forward(&self, input: &Variable) -> Variable {
105 self.module.forward(input)
106 }
107}
108
/// `Module` implementation that transparently delegates every call to the
/// wrapped module, so a `DistributedDataParallel<M>` can be dropped in
/// anywhere an `M` is expected.
impl<M: Module> Module for DistributedDataParallel<M> {
    fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }

    fn train(&mut self) {
        self.module.train();
    }

    fn eval(&mut self) {
        self.module.eval();
    }

    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}
130
/// Flat buffer that packs several gradient tensors into one contiguous
/// `Vec<f32>` so they can be all-reduced with a single collective call.
pub struct GradientBucket {
    /// Concatenated elements of every tensor added so far.
    data: Vec<f32>,
    /// Per-tensor (shape, element-count) records in insertion order;
    /// used by `extract` to rebuild the original tensors.
    shapes: Vec<(Vec<usize>, usize)>,
    /// Maximum number of f32 elements the bucket will accept.
    capacity: usize,
}
144
145impl GradientBucket {
146 #[must_use]
148 pub fn new(capacity: usize) -> Self {
149 Self {
150 data: Vec::with_capacity(capacity),
151 shapes: Vec::new(),
152 capacity,
153 }
154 }
155
156 #[must_use]
158 pub fn is_full(&self) -> bool {
159 self.data.len() >= self.capacity
160 }
161
162 #[must_use]
164 pub fn is_empty(&self) -> bool {
165 self.data.is_empty()
166 }
167
168 #[must_use]
170 pub fn size(&self) -> usize {
171 self.data.len()
172 }
173
174 pub fn add(&mut self, tensor: &Tensor<f32>) -> bool {
176 let data = tensor.to_vec();
177 if self.data.len() + data.len() > self.capacity {
178 return false;
179 }
180
181 self.shapes.push((tensor.shape().to_vec(), data.len()));
182 self.data.extend(data);
183 true
184 }
185
186 #[must_use]
188 pub fn data(&self) -> &[f32] {
189 &self.data
190 }
191
192 pub fn data_mut(&mut self) -> &mut [f32] {
194 &mut self.data
195 }
196
197 pub fn clear(&mut self) {
199 self.data.clear();
200 self.shapes.clear();
201 }
202
203 #[must_use]
205 pub fn extract(&self) -> Vec<Tensor<f32>> {
206 let mut result = Vec::new();
207 let mut offset = 0;
208
209 for (shape, size) in &self.shapes {
210 let end = offset + size;
211 let data = self.data[offset..end].to_vec();
212 result.push(Tensor::from_vec(data, shape).unwrap());
213 offset = end;
214 }
215
216 result
217 }
218}
219
/// How gradients are synchronized across ranks after a backward pass.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GradSyncStrategy {
    /// Reduce all buckets in one pass once backward completes.
    Synchronous,
    /// Overlap reduction with backward computation.
    /// NOTE(review): `sync_all` in this file treats this the same as
    /// `Synchronous`; the overlap presumably happens elsewhere — confirm.
    Overlapped,
    /// Skip synchronization entirely (e.g. gradient-accumulation steps).
    NoSync,
}
234
/// Coordinates bucketed gradient all-reduce according to a strategy.
pub struct GradientSynchronizer {
    /// Active synchronization strategy.
    strategy: GradSyncStrategy,
    /// Capacity (in f32 elements) of each bucket created by `prepare`.
    bucket_size: usize,
    /// Buckets currently holding gradients awaiting reduction.
    buckets: Vec<GradientBucket>,
}
241
242impl GradientSynchronizer {
243 #[must_use]
245 pub fn new(strategy: GradSyncStrategy, bucket_size: usize) -> Self {
246 Self {
247 strategy,
248 bucket_size,
249 buckets: Vec::new(),
250 }
251 }
252
253 #[must_use]
255 pub fn strategy(&self) -> GradSyncStrategy {
256 self.strategy
257 }
258
259 pub fn prepare(&mut self, num_params: usize) {
261 let num_buckets = num_params.div_ceil(self.bucket_size);
262 self.buckets = (0..num_buckets)
263 .map(|_| GradientBucket::new(self.bucket_size))
264 .collect();
265 }
266
267 pub fn add_gradient(&mut self, bucket_idx: usize, tensor: &Tensor<f32>) {
269 if bucket_idx < self.buckets.len() {
270 self.buckets[bucket_idx].add(tensor);
271 }
272 }
273
274 pub fn sync_all(&mut self, process_group: &ProcessGroup) {
276 if self.strategy == GradSyncStrategy::NoSync {
277 return;
278 }
279
280 for bucket in &mut self.buckets {
281 if !bucket.is_empty() {
282 let mut data = bucket.data().to_vec();
283 let len = data.len();
284 process_group
285 .backend()
286 .all_reduce(&mut data, ReduceOp::Average);
287 bucket.data_mut()[..len].copy_from_slice(&data);
288 }
289 }
290 }
291
292 pub fn clear(&mut self) {
294 for bucket in &mut self.buckets {
295 bucket.clear();
296 }
297 }
298}
299
300impl Default for GradientSynchronizer {
301 fn default() -> Self {
302 Self::new(GradSyncStrategy::Synchronous, 25_000_000) }
304}
305
#[cfg(test)]
mod tests {
    //! Single-process tests; `ProcessGroup::mock()` presumably provides a
    //! rank-0 / world-size-1 group so collectives are no-ops — confirm.
    use super::*;
    use axonml_nn::Linear;

    // DDP construction exposes the underlying process group's topology.
    #[test]
    fn test_ddp_creation() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        assert_eq!(ddp.process_group().rank(), 0);
        assert_eq!(ddp.process_group().world_size(), 1);
    }

    // Forward through the wrapper matches the wrapped Linear's output shape.
    #[test]
    fn test_ddp_forward() {
        let module = Linear::new(4, 2);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        let output = ddp.forward(&input);

        assert_eq!(output.data().shape(), &[1, 2]);
    }

    // Smoke test: both accessors compile and are callable.
    #[test]
    fn test_ddp_module_access() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let mut ddp = DistributedDataParallel::new(module, pg);

        let _ = ddp.module();
        let _ = ddp.module_mut();
    }

    // train()/eval() delegate to the wrapped module; modules start in
    // training mode.
    #[test]
    fn test_ddp_train_eval() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let mut ddp = DistributedDataParallel::new(module, pg);

        assert!(ddp.is_training());

        ddp.train();
        ddp.eval();

        let _ = ddp.is_training();
    }

    // The wrapper surfaces the wrapped module's parameters.
    #[test]
    fn test_ddp_parameters() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        let params = ddp.parameters();
        assert!(!params.is_empty());
    }

    // Tensors added to a bucket are concatenated flat, in order.
    #[test]
    fn test_gradient_bucket() {
        let mut bucket = GradientBucket::new(100);

        assert!(bucket.is_empty());

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert!(bucket.add(&tensor1));

        assert!(!bucket.is_empty());
        assert_eq!(bucket.size(), 3);

        let tensor2 = Tensor::from_vec(vec![4.0, 5.0], &[2]).unwrap();
        assert!(bucket.add(&tensor2));

        assert_eq!(bucket.size(), 5);
        assert_eq!(bucket.data(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    // extract() round-trips the tensors packed by add().
    #[test]
    fn test_gradient_bucket_extract() {
        let mut bucket = GradientBucket::new(100);

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let tensor2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3]).unwrap();

        bucket.add(&tensor1);
        bucket.add(&tensor2);

        let extracted = bucket.extract();
        assert_eq!(extracted.len(), 2);
        assert_eq!(extracted[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(extracted[1].to_vec(), vec![3.0, 4.0, 5.0]);
    }

    // A tensor that would exceed capacity is rejected (add returns false).
    #[test]
    fn test_gradient_bucket_full() {
        let mut bucket = GradientBucket::new(5);

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert!(bucket.add(&tensor1));

        let tensor2 = Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();
        assert!(!bucket.add(&tensor2));
    }

    // clear() empties both the data and the shape records.
    #[test]
    fn test_gradient_bucket_clear() {
        let mut bucket = GradientBucket::new(100);
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        bucket.add(&tensor);

        bucket.clear();
        assert!(bucket.is_empty());
    }

    // prepare() keeps the configured strategy intact.
    #[test]
    fn test_gradient_synchronizer() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 100);
        sync.prepare(10);

        assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
    }

    // Under NoSync, sync_all must be a no-op (must not touch the backend).
    #[test]
    fn test_gradient_synchronizer_no_sync() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::NoSync, 100);
        sync.prepare(10);

        let pg = ProcessGroup::mock();
        sync.sync_all(&pg);
    }

    // Default config is Synchronous (with the 25M-element bucket size).
    #[test]
    fn test_gradient_synchronizer_default() {
        let sync = GradientSynchronizer::default();
        assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
    }

    // Derived PartialEq/Eq behave as expected on the strategy enum.
    #[test]
    fn test_grad_sync_strategy() {
        assert_eq!(GradSyncStrategy::Synchronous, GradSyncStrategy::Synchronous);
        assert_ne!(GradSyncStrategy::Synchronous, GradSyncStrategy::NoSync);
    }
}