use crate::backend::ReduceOp;
use crate::process_group::ProcessGroup;
use axonml_autograd::Variable;
use axonml_nn::{Module, Parameter};
use axonml_tensor::Tensor;

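/// Wraps a [`Module`] for data-parallel training: parameters are broadcast from
/// rank 0 and gradients are averaged across the ranks of a [`ProcessGroup`].
///
/// A minimal usage sketch, assuming the single-process mock group and the `Linear`
/// layer used by the tests in this module:
///
/// ```ignore
/// use axonml_nn::Linear;
///
/// let mut ddp = DistributedDataParallel::new(Linear::new(10, 5), ProcessGroup::mock())
///     .broadcast_buffers(true)
///     .gradient_as_bucket_view(false);
///
/// ddp.sync_parameters();   // start all ranks from rank 0's weights
/// // ... forward / backward as usual ...
/// ddp.sync_gradients();    // average gradients before the optimizer step
/// ```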
pub struct DistributedDataParallel<M: Module> {
    module: M,
    process_group: ProcessGroup,
    broadcast_buffers: bool,
    gradient_as_bucket_view: bool,
}

impl<M: Module> DistributedDataParallel<M> {
    /// Wraps `module` for data-parallel training over `process_group`.
    pub fn new(module: M, process_group: ProcessGroup) -> Self {
        Self {
            module,
            process_group,
            broadcast_buffers: true,
            gradient_as_bucket_view: true,
        }
    }

    /// Builder-style setter for the `broadcast_buffers` flag.
    pub fn broadcast_buffers(mut self, broadcast: bool) -> Self {
        self.broadcast_buffers = broadcast;
        self
    }

    /// Builder-style setter for the `gradient_as_bucket_view` flag.
    pub fn gradient_as_bucket_view(mut self, bucket_view: bool) -> Self {
        self.gradient_as_bucket_view = bucket_view;
        self
    }

    /// Returns a shared reference to the wrapped module.
    pub fn module(&self) -> &M {
        &self.module
    }

    /// Returns a mutable reference to the wrapped module.
    pub fn module_mut(&mut self) -> &mut M {
        &mut self.module
    }

    /// Returns the process group used for collective communication.
    pub fn process_group(&self) -> &ProcessGroup {
        &self.process_group
    }

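    /// Broadcasts every parameter from rank 0 so that all ranks start from the same
    /// weights. Typically called once right after constructing the wrapper.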
    pub fn sync_parameters(&mut self) {
        for param in self.module.parameters() {
            let mut tensor = param.data().clone();
            self.process_group.broadcast_tensor(&mut tensor, 0);
        }
    }

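    /// Averages every parameter gradient across ranks with an all-reduce. Typically
    /// called after the backward pass and before the optimizer step.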
    pub fn sync_gradients(&self) {
        for param in self.module.parameters() {
            if let Some(grad) = param.grad() {
                let mut grad_tensor = grad.clone();
                self.process_group
                    .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
            }
        }
    }

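    /// Runs a forward pass through the wrapped module.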
    pub fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }
}

impl<M: Module> Module for DistributedDataParallel<M> {
    fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }

    fn train(&mut self) {
        self.module.train();
    }

    fn eval(&mut self) {
        self.module.eval();
    }

    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}

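/// Flat buffer that packs several gradient tensors into one contiguous `Vec<f32>` so
/// they can be reduced in a single collective call and unpacked afterwards.
///
/// A small sketch of the pack/unpack round trip, mirroring the tests in this module:
///
/// ```ignore
/// let mut bucket = GradientBucket::new(100);
/// bucket.add(&Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap());
/// bucket.add(&Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3]).unwrap());
/// let grads = bucket.extract(); // two tensors with their original shapes
/// bucket.clear();
/// ```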
pub struct GradientBucket {
    /// Flattened gradient values, in insertion order.
    data: Vec<f32>,
    /// Shape and element count of each tensor added to the bucket.
    shapes: Vec<(Vec<usize>, usize)>,
    /// Maximum number of elements the bucket may hold.
    capacity: usize,
}

impl GradientBucket {
    /// Creates an empty bucket that can hold up to `capacity` elements.
    #[must_use]
    pub fn new(capacity: usize) -> Self {
        Self {
            data: Vec::with_capacity(capacity),
            shapes: Vec::new(),
            capacity,
        }
    }

    /// Returns `true` if the bucket has reached its capacity.
    #[must_use]
    pub fn is_full(&self) -> bool {
        self.data.len() >= self.capacity
    }

    /// Returns `true` if no gradients have been added yet.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Returns the number of elements currently stored.
    #[must_use]
    pub fn size(&self) -> usize {
        self.data.len()
    }

    /// Appends a tensor's values to the bucket, recording its shape so it can be
    /// reconstructed later. Returns `false` if the tensor would exceed the capacity.
    pub fn add(&mut self, tensor: &Tensor<f32>) -> bool {
        let data = tensor.to_vec();
        if self.data.len() + data.len() > self.capacity {
            return false;
        }

        self.shapes.push((tensor.shape().to_vec(), data.len()));
        self.data.extend(data);
        true
    }

    /// Returns the flattened bucket contents.
    #[must_use]
    pub fn data(&self) -> &[f32] {
        &self.data
    }

    /// Returns the flattened bucket contents mutably (e.g. to write back reduced values).
    pub fn data_mut(&mut self) -> &mut [f32] {
        &mut self.data
    }

    /// Removes all stored values and shapes.
    pub fn clear(&mut self) {
        self.data.clear();
        self.shapes.clear();
    }

    /// Rebuilds the individual tensors from the flattened data, in insertion order.
    #[must_use]
    pub fn extract(&self) -> Vec<Tensor<f32>> {
        let mut result = Vec::new();
        let mut offset = 0;

        for (shape, size) in &self.shapes {
            let end = offset + size;
            let data = self.data[offset..end].to_vec();
            result.push(Tensor::from_vec(data, shape).unwrap());
            offset = end;
        }

        result
    }
}

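/// Strategy for exchanging gradients across ranks after the backward pass.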
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GradSyncStrategy {
    /// All-reduce every bucket once the full backward pass has finished.
    Synchronous,
    /// Overlap bucket reduction with the backward pass (currently handled the same
    /// way as `Synchronous` by `sync_all`).
    Overlapped,
    /// Skip gradient synchronization entirely (e.g. while accumulating gradients).
    NoSync,
}

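/// Packs parameter gradients into [`GradientBucket`]s and reduces each non-empty
/// bucket through a [`ProcessGroup`], following the configured [`GradSyncStrategy`].
///
/// A minimal sketch of the intended flow, assuming the mock process group used by the
/// tests below and a placeholder gradient tensor:
///
/// ```ignore
/// let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 100);
/// sync.prepare(10);                     // allocate buckets for 10 parameters
/// sync.add_gradient(0, &grad_tensor);   // `grad_tensor`: any Tensor<f32> (placeholder)
/// sync.sync_all(&ProcessGroup::mock()); // average every non-empty bucket across ranks
/// sync.clear();
/// ```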
pub struct GradientSynchronizer {
    strategy: GradSyncStrategy,
    bucket_size: usize,
    buckets: Vec<GradientBucket>,
}

impl GradientSynchronizer {
    /// Creates a synchronizer with the given strategy and per-bucket capacity.
    #[must_use]
    pub fn new(strategy: GradSyncStrategy, bucket_size: usize) -> Self {
        Self {
            strategy,
            bucket_size,
            buckets: Vec::new(),
        }
    }

    /// Returns the configured synchronization strategy.
    #[must_use]
    pub fn strategy(&self) -> GradSyncStrategy {
        self.strategy
    }

    /// Allocates enough buckets to cover `num_params` parameters, discarding any
    /// previously prepared buckets.
    pub fn prepare(&mut self, num_params: usize) {
        let num_buckets = num_params.div_ceil(self.bucket_size);
        self.buckets = (0..num_buckets)
            .map(|_| GradientBucket::new(self.bucket_size))
            .collect();
    }

    /// Adds a gradient tensor to the bucket at `bucket_idx`, if that bucket exists.
    pub fn add_gradient(&mut self, bucket_idx: usize, tensor: &Tensor<f32>) {
        if bucket_idx < self.buckets.len() {
            self.buckets[bucket_idx].add(tensor);
        }
    }

    /// All-reduces every non-empty bucket across ranks and writes the averaged
    /// values back into the bucket. Does nothing under `GradSyncStrategy::NoSync`.
    pub fn sync_all(&mut self, process_group: &ProcessGroup) {
        if self.strategy == GradSyncStrategy::NoSync {
            return;
        }

        for bucket in &mut self.buckets {
            if !bucket.is_empty() {
                let mut data = bucket.data().to_vec();
                let len = data.len();
                process_group
                    .backend()
                    .all_reduce(&mut data, ReduceOp::Average);
                bucket.data_mut()[..len].copy_from_slice(&data);
            }
        }
    }

    /// Clears every bucket so the synchronizer can be reused for the next step.
    pub fn clear(&mut self) {
        for bucket in &mut self.buckets {
            bucket.clear();
        }
    }
}

impl Default for GradientSynchronizer {
    fn default() -> Self {
        // 25 million f32 elements (roughly 100 MB of gradient data) per bucket by default.
        Self::new(GradSyncStrategy::Synchronous, 25_000_000)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    #[test]
    fn test_ddp_creation() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        assert_eq!(ddp.process_group().rank(), 0);
        assert_eq!(ddp.process_group().world_size(), 1);
    }

    #[test]
    fn test_ddp_forward() {
        let module = Linear::new(4, 2);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        let output = ddp.forward(&input);

        assert_eq!(output.data().shape(), &[1, 2]);
    }

    #[test]
    fn test_ddp_module_access() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let mut ddp = DistributedDataParallel::new(module, pg);

        let _ = ddp.module();
        let _ = ddp.module_mut();
    }

    #[test]
    fn test_ddp_train_eval() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let mut ddp = DistributedDataParallel::new(module, pg);

        assert!(ddp.is_training());

        ddp.train();
        ddp.eval();

        let _ = ddp.is_training();
    }

    #[test]
    fn test_ddp_parameters() {
        let module = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let ddp = DistributedDataParallel::new(module, pg);

        let params = ddp.parameters();
        assert!(!params.is_empty());
    }

    #[test]
    fn test_gradient_bucket() {
        let mut bucket = GradientBucket::new(100);

        assert!(bucket.is_empty());

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert!(bucket.add(&tensor1));

        assert!(!bucket.is_empty());
        assert_eq!(bucket.size(), 3);

        let tensor2 = Tensor::from_vec(vec![4.0, 5.0], &[2]).unwrap();
        assert!(bucket.add(&tensor2));

        assert_eq!(bucket.size(), 5);
        assert_eq!(bucket.data(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_gradient_bucket_extract() {
        let mut bucket = GradientBucket::new(100);

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let tensor2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3]).unwrap();

        bucket.add(&tensor1);
        bucket.add(&tensor2);

        let extracted = bucket.extract();
        assert_eq!(extracted.len(), 2);
        assert_eq!(extracted[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(extracted[1].to_vec(), vec![3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_gradient_bucket_full() {
        let mut bucket = GradientBucket::new(5);

        let tensor1 = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert!(bucket.add(&tensor1));

        // The second tensor would push the bucket past its 5-element capacity.
        let tensor2 = Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();
        assert!(!bucket.add(&tensor2));
    }

    #[test]
    fn test_gradient_bucket_clear() {
        let mut bucket = GradientBucket::new(100);
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        bucket.add(&tensor);

        bucket.clear();
        assert!(bucket.is_empty());
    }

    #[test]
    fn test_gradient_synchronizer() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 100);
        sync.prepare(10);

        assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
    }

    #[test]
    fn test_gradient_synchronizer_no_sync() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::NoSync, 100);
        sync.prepare(10);

        // With `NoSync`, syncing should be a no-op and must not panic.
        let pg = ProcessGroup::mock();
        sync.sync_all(&pg);
    }

    #[test]
    fn test_gradient_synchronizer_default() {
        let sync = GradientSynchronizer::default();
        assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
    }

    #[test]
    fn test_grad_sync_strategy() {
        assert_eq!(GradSyncStrategy::Synchronous, GradSyncStrategy::Synchronous);
        assert_ne!(GradSyncStrategy::Synchronous, GradSyncStrategy::NoSync);
    }
}