//! Distributed data-parallel training: wraps a module so parameters and
//! gradients stay synchronized across a process group, with gradient
//! bucketing utilities for efficient all-reduce.

use crate::backend::ReduceOp;
use crate::process_group::ProcessGroup;
use axonml_autograd::Variable;
use axonml_nn::{Module, Parameter};
use axonml_tensor::Tensor;

/// Wraps a [`Module`] and keeps its parameters and gradients in sync across the ranks of a [`ProcessGroup`].
pub struct DistributedDataParallel<M: Module> {
    /// The wrapped module.
    module: M,
    /// Process group used for broadcast and all-reduce collectives.
    process_group: ProcessGroup,
    /// Broadcast buffers from rank 0 along with parameters (configuration flag).
    broadcast_buffers: bool,
    /// Expose gradients as views into flattened communication buckets (configuration flag).
    gradient_as_bucket_view: bool,
}

impl<M: Module> DistributedDataParallel<M> {
    /// Wraps `module` for data-parallel training over `process_group`.
    pub fn new(module: M, process_group: ProcessGroup) -> Self {
        Self {
            module,
            process_group,
            broadcast_buffers: true,
            gradient_as_bucket_view: true,
        }
    }

    /// Sets whether buffers are broadcast from rank 0 (builder-style).
    pub fn broadcast_buffers(mut self, broadcast: bool) -> Self {
        self.broadcast_buffers = broadcast;
        self
    }

    /// Sets whether gradients are treated as views into communication buckets (builder-style).
    pub fn gradient_as_bucket_view(mut self, bucket_view: bool) -> Self {
        self.gradient_as_bucket_view = bucket_view;
        self
    }

    /// Returns a reference to the wrapped module.
    pub fn module(&self) -> &M {
        &self.module
    }

    /// Returns a mutable reference to the wrapped module.
    pub fn module_mut(&mut self) -> &mut M {
        &mut self.module
    }

    /// Returns the process group used for collectives.
    pub fn process_group(&self) -> &ProcessGroup {
        &self.process_group
    }

    /// Broadcasts each parameter's data from rank 0 so every rank starts
    /// from the same weights.
    pub fn sync_parameters(&mut self) {
        for param in self.module.parameters() {
            // The broadcast operates on a clone of the parameter tensor; this
            // assumes `Tensor::clone` shares the underlying storage.
            let mut tensor = param.data().clone();
            self.process_group.broadcast_tensor(&mut tensor, 0);
        }
    }

    /// All-reduces (averages) each parameter's gradient across ranks.
    pub fn sync_gradients(&self) {
        for param in self.module.parameters() {
            if let Some(grad) = param.grad() {
                // As above, the all-reduce is applied to a clone of the gradient.
                let mut grad_tensor = grad.clone();
                self.process_group
                    .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
            }
        }
    }

    /// Runs the wrapped module's forward pass.
    pub fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }
}

impl<M: Module> Module for DistributedDataParallel<M> {
    fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }

    fn train(&mut self) {
        self.module.train();
    }

    fn eval(&mut self) {
        self.module.eval();
    }

    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}

/// Flat buffer that packs several gradient tensors into one contiguous
/// `Vec<f32>` so they can be reduced in a single collective call.
pub struct GradientBucket {
    /// Flattened gradient values, concatenated in insertion order.
    data: Vec<f32>,
    /// `(shape, element count)` for each tensor added, used by `extract`.
    shapes: Vec<(Vec<usize>, usize)>,
    /// Maximum number of elements the bucket may hold.
    capacity: usize,
}

impl GradientBucket {
    /// Creates an empty bucket that can hold up to `capacity` elements.
    #[must_use]
    pub fn new(capacity: usize) -> Self {
        Self {
            data: Vec::with_capacity(capacity),
            shapes: Vec::new(),
            capacity,
        }
    }

    /// Returns `true` if the bucket has reached its capacity.
    #[must_use]
    pub fn is_full(&self) -> bool {
        self.data.len() >= self.capacity
    }

    /// Returns `true` if no gradients have been added.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Number of elements currently stored.
    #[must_use]
    pub fn size(&self) -> usize {
        self.data.len()
    }

    /// Appends `tensor`'s data to the bucket.
    ///
    /// Returns `false` (and adds nothing) if the tensor would overflow the
    /// bucket's capacity.
    pub fn add(&mut self, tensor: &Tensor<f32>) -> bool {
        let data = tensor.to_vec();
        if self.data.len() + data.len() > self.capacity {
            return false;
        }

        self.shapes.push((tensor.shape().to_vec(), data.len()));
        self.data.extend(data);
        true
    }

    /// Read-only view of the flattened bucket contents.
    #[must_use]
    pub fn data(&self) -> &[f32] {
        &self.data
    }

    /// Mutable view of the flattened bucket contents, e.g. for writing back
    /// reduced values.
    pub fn data_mut(&mut self) -> &mut [f32] {
        &mut self.data
    }

    /// Removes all stored gradients and shape metadata.
    pub fn clear(&mut self) {
        self.data.clear();
        self.shapes.clear();
    }

    /// Splits the flattened buffer back into tensors with their original shapes.
    #[must_use]
    pub fn extract(&self) -> Vec<Tensor<f32>> {
        let mut result = Vec::new();
        let mut offset = 0;

        for (shape, size) in &self.shapes {
            let end = offset + size;
            let data = self.data[offset..end].to_vec();
            result.push(Tensor::from_vec(data, shape).unwrap());
            offset = end;
        }

        result
    }
}

/// How gradients are synchronized across ranks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GradSyncStrategy {
    /// All-reduce every bucket after the backward pass completes.
    Synchronous,
    /// Overlap gradient communication with the backward computation.
    Overlapped,
    /// Skip synchronization entirely (e.g. during gradient accumulation).
    NoSync,
}

/// Groups gradients into buckets and reduces them according to a
/// [`GradSyncStrategy`].
pub struct GradientSynchronizer {
    strategy: GradSyncStrategy,
    bucket_size: usize,
    buckets: Vec<GradientBucket>,
}

impl GradientSynchronizer {
    /// Creates a synchronizer with the given strategy and per-bucket capacity
    /// (in elements).
    #[must_use]
    pub fn new(strategy: GradSyncStrategy, bucket_size: usize) -> Self {
        Self {
            strategy,
            bucket_size,
            buckets: Vec::new(),
        }
    }

    /// Returns the configured synchronization strategy.
    #[must_use]
    pub fn strategy(&self) -> GradSyncStrategy {
        self.strategy
    }

    /// Allocates `ceil(num_params / bucket_size)` empty buckets of capacity `bucket_size`.
    pub fn prepare(&mut self, num_params: usize) {
        let num_buckets = num_params.div_ceil(self.bucket_size);
        self.buckets = (0..num_buckets)
            .map(|_| GradientBucket::new(self.bucket_size))
            .collect();
    }

    /// Adds a gradient tensor to the bucket at `bucket_idx`; out-of-range
    /// indices are ignored.
    pub fn add_gradient(&mut self, bucket_idx: usize, tensor: &Tensor<f32>) {
        if bucket_idx < self.buckets.len() {
            self.buckets[bucket_idx].add(tensor);
        }
    }

    /// All-reduces (averages) every non-empty bucket over `process_group`,
    /// writing the reduced values back into the buckets. Does nothing when
    /// the strategy is [`GradSyncStrategy::NoSync`].
    pub fn sync_all(&mut self, process_group: &ProcessGroup) {
        if self.strategy == GradSyncStrategy::NoSync {
            return;
        }

        for bucket in &mut self.buckets {
            if !bucket.is_empty() {
                let mut data = bucket.data().to_vec();
                let len = data.len();
                process_group
                    .backend()
                    .all_reduce(&mut data, ReduceOp::Average);
                bucket.data_mut()[..len].copy_from_slice(&data);
            }
        }
    }

    /// Empties all buckets so they can be reused for the next step.
    pub fn clear(&mut self) {
        for bucket in &mut self.buckets {
            bucket.clear();
        }
    }
}

impl Default for GradientSynchronizer {
    fn default() -> Self {
        // Synchronous strategy with a 25M-element bucket by default.
        Self::new(GradSyncStrategy::Synchronous, 25_000_000)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    #[test]
    fn test_ddp_creation() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        assert_eq!(ddp.process_group().rank(), 0);
        assert_eq!(ddp.process_group().world_size(), 1);
    }

    #[test]
    fn test_ddp_forward() {
        let module = Linear::new(4, 2);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        let output = ddp.forward(&input);

        assert_eq!(output.data().shape(), &[1, 2]);
    }
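
    // A sketch of a single data-parallel step on the mock (single-rank) process
    // group: broadcast parameters, run the forward pass, then average gradients.
    // With `requires_grad = false` and a world size of one the collectives are
    // effectively no-ops, so this only exercises the call sequence; no
    // numerical results are asserted.
    #[test]
    fn test_ddp_step_flow() {
        let module = Linear::new(4, 2);
        let pg = ProcessGroup::mock();
        let mut ddp = DistributedDataParallel::new(module, pg);

        // Make every rank start from rank 0's weights.
        ddp.sync_parameters();

        // Forward pass through the wrapped module.
        let input = Variable::new(Tensor::from_vec(vec![0.5; 4], &[1, 4]).unwrap(), false);
        let output = ddp.forward(&input);
        assert_eq!(output.data().shape(), &[1, 2]);

        // Average gradients across ranks (no gradients exist here, so the
        // loop finds nothing to reduce).
        ddp.sync_gradients();
    }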

    #[test]
    fn test_ddp_module_access() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let mut ddp = DistributedDataParallel::new(module, pg);

        let _ = ddp.module();
        let _ = ddp.module_mut();
    }

    #[test]
    fn test_ddp_train_eval() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let mut ddp = DistributedDataParallel::new(module, pg);

        assert!(ddp.is_training());

        ddp.train();
        ddp.eval();

        let _ = ddp.is_training();
    }

    #[test]
    fn test_ddp_parameters() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        let params = ddp.parameters();
        assert!(!params.is_empty());
    }

    #[test]
    fn test_gradient_bucket() {
        let mut bucket = GradientBucket::new(100);

        assert!(bucket.is_empty());

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert!(bucket.add(&tensor1));

        assert!(!bucket.is_empty());
        assert_eq!(bucket.size(), 3);

        let tensor2 = Tensor::from_vec(vec![4.0, 5.0], &[2]).unwrap();
        assert!(bucket.add(&tensor2));

        assert_eq!(bucket.size(), 5);
        assert_eq!(bucket.data(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_gradient_bucket_extract() {
        let mut bucket = GradientBucket::new(100);

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let tensor2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3]).unwrap();

        bucket.add(&tensor1);
        bucket.add(&tensor2);

        let extracted = bucket.extract();
        assert_eq!(extracted.len(), 2);
        assert_eq!(extracted[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(extracted[1].to_vec(), vec![3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_gradient_bucket_full() {
        let mut bucket = GradientBucket::new(5);

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert!(bucket.add(&tensor1));

        let tensor2 = Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();
        assert!(!bucket.add(&tensor2));
    }

    #[test]
    fn test_gradient_bucket_clear() {
        let mut bucket = GradientBucket::new(100);
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        bucket.add(&tensor);

        bucket.clear();
        assert!(bucket.is_empty());
    }
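
    // Sketch of how `sync_all` uses a bucket: reduced values are written back
    // through `data_mut` and then split into per-tensor gradients with
    // `extract`. The "reduction" is simulated here by halving each value (as if
    // averaged over two ranks), so no process group or backend is involved.
    #[test]
    fn test_gradient_bucket_write_back() {
        let mut bucket = GradientBucket::new(100);
        bucket.add(&Tensor::from_vec(vec![2.0, 4.0], &[2]).unwrap());
        bucket.add(&Tensor::from_vec(vec![6.0], &[1]).unwrap());

        // Simulate an averaged all-reduce result by scaling in place.
        for value in bucket.data_mut() {
            *value *= 0.5;
        }

        let reduced = bucket.extract();
        assert_eq!(reduced[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(reduced[1].to_vec(), vec![3.0]);
    }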

    #[test]
    fn test_gradient_synchronizer() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 100);
        sync.prepare(10);

        assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
    }
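
    // Sketch of the full synchronizer flow: allocate buckets, stage a gradient,
    // all-reduce, then clear for the next step. It assumes the mock backend
    // implements `all_reduce` locally for a world size of one; the reduced
    // values are not asserted because they depend on that backend behaviour.
    #[test]
    fn test_gradient_synchronizer_flow() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 100);
        sync.prepare(100);

        let grad = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        sync.add_gradient(0, &grad);

        let pg = ProcessGroup::mock();
        sync.sync_all(&pg);
        sync.clear();
    }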

    #[test]
    fn test_gradient_synchronizer_no_sync() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::NoSync, 100);
        sync.prepare(10);

        let pg = ProcessGroup::mock();
        sync.sync_all(&pg);
    }

    #[test]
    fn test_gradient_synchronizer_default() {
        let sync = GradientSynchronizer::default();
        assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
    }

    #[test]
    fn test_grad_sync_strategy() {
        assert_eq!(GradSyncStrategy::Synchronous, GradSyncStrategy::Synchronous);
        assert_ne!(GradSyncStrategy::Synchronous, GradSyncStrategy::NoSync);
    }
}