1use crate::backend::ReduceOp;
9use crate::process_group::ProcessGroup;
10use axonml_autograd::Variable;
11use axonml_nn::{Module, Parameter};
12use axonml_tensor::Tensor;
13
/// Wraps a [`Module`] for data-parallel training across a process group.
///
/// Parameters are kept identical across ranks by broadcasting from rank 0
/// (`sync_parameters`) and gradients are averaged with an all-reduce after
/// the backward pass (`sync_gradients`).
pub struct DistributedDataParallel<M: Module> {
    // The wrapped model; forward calls delegate straight to it.
    module: M,
    // Communicator used for the broadcast / all-reduce collectives.
    process_group: ProcessGroup,
    // Whether non-parameter buffers should be broadcast as well.
    // NOTE(review): stored but never read in this file — confirm it is
    // consumed elsewhere, or wire it into `sync_parameters`.
    broadcast_buffers: bool,
    // Whether gradients should alias bucket storage instead of being copied.
    // NOTE(review): also stored but never read here.
    gradient_as_bucket_view: bool,
}
28
impl<M: Module> DistributedDataParallel<M> {
    /// Wraps `module` for training over `process_group`.
    ///
    /// Buffer broadcasting and bucket-view gradients both default to enabled.
    pub fn new(module: M, process_group: ProcessGroup) -> Self {
        Self {
            module,
            process_group,
            broadcast_buffers: true,
            gradient_as_bucket_view: true,
        }
    }

    /// Builder-style toggle for buffer broadcasting.
    pub fn broadcast_buffers(mut self, broadcast: bool) -> Self {
        self.broadcast_buffers = broadcast;
        self
    }

    /// Builder-style toggle for gradient-as-bucket-view.
    pub fn gradient_as_bucket_view(mut self, bucket_view: bool) -> Self {
        self.gradient_as_bucket_view = bucket_view;
        self
    }

    /// Shared access to the wrapped module.
    pub fn module(&self) -> &M {
        &self.module
    }

    /// Mutable access to the wrapped module.
    pub fn module_mut(&mut self) -> &mut M {
        &mut self.module
    }

    /// The process group this wrapper communicates over.
    pub fn process_group(&self) -> &ProcessGroup {
        &self.process_group
    }

    /// Broadcasts every parameter from rank 0 so all ranks start identical.
    ///
    /// NOTE(review): the broadcast writes into a `clone()` of the parameter
    /// data and the result is then dropped. Unless `Tensor::clone` shares the
    /// underlying storage (copy-on-write / Arc-backed), this never updates
    /// the actual parameters — confirm the clone semantics or write the
    /// received tensor back into `param`.
    pub fn sync_parameters(&mut self) {
        for param in self.module.parameters() {
            let mut tensor = param.data().clone();
            self.process_group.broadcast_tensor(&mut tensor, 0);
        }
    }

    /// All-reduces (averages) every parameter gradient across ranks.
    ///
    /// NOTE(review): same concern as `sync_parameters` — the all-reduce
    /// targets a `clone()` of the gradient whose result is discarded, so the
    /// averaged gradient is never stored back unless clones alias storage.
    /// Also note the receiver is `&self` here but `&mut self` in
    /// `sync_parameters`; consider making them consistent.
    pub fn sync_gradients(&self) {
        for param in self.module.parameters() {
            if let Some(grad) = param.grad() {
                let mut grad_tensor = grad.clone();
                self.process_group
                    .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
            }
        }
    }

    /// Runs the wrapped module's forward pass.
    pub fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }
}
97
98impl<M: Module> Module for DistributedDataParallel<M> {
99 fn forward(&self, input: &Variable) -> Variable {
100 self.module.forward(input)
101 }
102
103 fn parameters(&self) -> Vec<Parameter> {
104 self.module.parameters()
105 }
106
107 fn train(&mut self) {
108 self.module.train();
109 }
110
111 fn eval(&mut self) {
112 self.module.eval();
113 }
114
115 fn is_training(&self) -> bool {
116 self.module.is_training()
117 }
118}
119
/// Flat buffer that packs several gradient tensors into one contiguous
/// `Vec<f32>` so they can be all-reduced in a single collective call.
pub struct GradientBucket {
    // Concatenated elements of every tensor added so far.
    data: Vec<f32>,
    // (shape, element count) per added tensor, in insertion order; used by
    // `extract` to slice `data` back into individual tensors.
    shapes: Vec<(Vec<usize>, usize)>,
    // Maximum number of f32 elements the bucket will accept.
    capacity: usize,
}
133
134impl GradientBucket {
135 #[must_use]
137 pub fn new(capacity: usize) -> Self {
138 Self {
139 data: Vec::with_capacity(capacity),
140 shapes: Vec::new(),
141 capacity,
142 }
143 }
144
145 #[must_use]
147 pub fn is_full(&self) -> bool {
148 self.data.len() >= self.capacity
149 }
150
151 #[must_use]
153 pub fn is_empty(&self) -> bool {
154 self.data.is_empty()
155 }
156
157 #[must_use]
159 pub fn size(&self) -> usize {
160 self.data.len()
161 }
162
163 pub fn add(&mut self, tensor: &Tensor<f32>) -> bool {
165 let data = tensor.to_vec();
166 if self.data.len() + data.len() > self.capacity {
167 return false;
168 }
169
170 self.shapes.push((tensor.shape().to_vec(), data.len()));
171 self.data.extend(data);
172 true
173 }
174
175 #[must_use]
177 pub fn data(&self) -> &[f32] {
178 &self.data
179 }
180
181 pub fn data_mut(&mut self) -> &mut [f32] {
183 &mut self.data
184 }
185
186 pub fn clear(&mut self) {
188 self.data.clear();
189 self.shapes.clear();
190 }
191
192 #[must_use]
194 pub fn extract(&self) -> Vec<Tensor<f32>> {
195 let mut result = Vec::new();
196 let mut offset = 0;
197
198 for (shape, size) in &self.shapes {
199 let end = offset + size;
200 let data = self.data[offset..end].to_vec();
201 result.push(Tensor::from_vec(data, shape).unwrap());
202 offset = end;
203 }
204
205 result
206 }
207}
208
/// How gradients are synchronized across ranks after the backward pass.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GradSyncStrategy {
    /// All-reduce every bucket once the backward pass has completed.
    Synchronous,
    /// Overlap bucket all-reduces with the remaining backward computation.
    /// NOTE(review): nothing in this file branches on this variant —
    /// `GradientSynchronizer::sync_all` treats it like `Synchronous`.
    Overlapped,
    /// Skip gradient synchronization entirely (e.g. during gradient
    /// accumulation steps).
    NoSync,
}
223
/// Drives bucketed gradient all-reduce according to a [`GradSyncStrategy`].
pub struct GradientSynchronizer {
    // Chosen synchronization strategy.
    strategy: GradSyncStrategy,
    // Element capacity of each gradient bucket.
    bucket_size: usize,
    // Buckets created by `prepare` and populated via `add_gradient`.
    buckets: Vec<GradientBucket>,
}
230
231impl GradientSynchronizer {
232 #[must_use]
234 pub fn new(strategy: GradSyncStrategy, bucket_size: usize) -> Self {
235 Self {
236 strategy,
237 bucket_size,
238 buckets: Vec::new(),
239 }
240 }
241
242 #[must_use]
244 pub fn strategy(&self) -> GradSyncStrategy {
245 self.strategy
246 }
247
248 pub fn prepare(&mut self, num_params: usize) {
250 let num_buckets = num_params.div_ceil(self.bucket_size);
251 self.buckets = (0..num_buckets)
252 .map(|_| GradientBucket::new(self.bucket_size))
253 .collect();
254 }
255
256 pub fn add_gradient(&mut self, bucket_idx: usize, tensor: &Tensor<f32>) {
258 if bucket_idx < self.buckets.len() {
259 self.buckets[bucket_idx].add(tensor);
260 }
261 }
262
263 pub fn sync_all(&mut self, process_group: &ProcessGroup) {
265 if self.strategy == GradSyncStrategy::NoSync {
266 return;
267 }
268
269 for bucket in &mut self.buckets {
270 if !bucket.is_empty() {
271 let mut data = bucket.data().to_vec();
272 let len = data.len();
273 process_group
274 .backend()
275 .all_reduce(&mut data, ReduceOp::Average);
276 bucket.data_mut()[..len].copy_from_slice(&data);
277 }
278 }
279 }
280
281 pub fn clear(&mut self) {
283 for bucket in &mut self.buckets {
284 bucket.clear();
285 }
286 }
287}
288
289impl Default for GradientSynchronizer {
290 fn default() -> Self {
291 Self::new(GradSyncStrategy::Synchronous, 25_000_000) }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    #[test]
    fn test_ddp_creation() {
        let ddp = DistributedDataParallel::new(Linear::new(10, 5), ProcessGroup::mock());

        assert_eq!(ddp.process_group().rank(), 0);
        assert_eq!(ddp.process_group().world_size(), 1);
    }

    #[test]
    fn test_ddp_forward() {
        let ddp = DistributedDataParallel::new(Linear::new(4, 2), ProcessGroup::mock());

        let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        let output = ddp.forward(&input);

        assert_eq!(output.data().shape(), &[1, 2]);
    }

    #[test]
    fn test_ddp_module_access() {
        let mut ddp = DistributedDataParallel::new(Linear::new(10, 5), ProcessGroup::mock());

        let _ = ddp.module();
        let _ = ddp.module_mut();
    }

    #[test]
    fn test_ddp_train_eval() {
        let mut ddp = DistributedDataParallel::new(Linear::new(10, 5), ProcessGroup::mock());

        assert!(ddp.is_training());

        ddp.train();
        ddp.eval();

        let _ = ddp.is_training();
    }

    #[test]
    fn test_ddp_parameters() {
        let ddp = DistributedDataParallel::new(Linear::new(10, 5), ProcessGroup::mock());

        assert!(!ddp.parameters().is_empty());
    }

    #[test]
    fn test_gradient_bucket() {
        let mut bucket = GradientBucket::new(100);

        assert!(bucket.is_empty());

        let first = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert!(bucket.add(&first));

        assert!(!bucket.is_empty());
        assert_eq!(bucket.size(), 3);

        let second = Tensor::from_vec(vec![4.0, 5.0], &[2]).unwrap();
        assert!(bucket.add(&second));

        assert_eq!(bucket.size(), 5);
        assert_eq!(bucket.data(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_gradient_bucket_extract() {
        let mut bucket = GradientBucket::new(100);

        bucket.add(&Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap());
        bucket.add(&Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3]).unwrap());

        let extracted = bucket.extract();
        assert_eq!(extracted.len(), 2);
        assert_eq!(extracted[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(extracted[1].to_vec(), vec![3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_gradient_bucket_full() {
        let mut bucket = GradientBucket::new(5);

        // First tensor fits; the second would exceed the 5-element capacity.
        assert!(bucket.add(&Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap()));
        assert!(!bucket.add(&Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap()));
    }

    #[test]
    fn test_gradient_bucket_clear() {
        let mut bucket = GradientBucket::new(100);
        bucket.add(&Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap());

        bucket.clear();
        assert!(bucket.is_empty());
    }

    #[test]
    fn test_gradient_synchronizer() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::Synchronous, 100);
        sync.prepare(10);

        assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
    }

    #[test]
    fn test_gradient_synchronizer_no_sync() {
        let mut sync = GradientSynchronizer::new(GradSyncStrategy::NoSync, 100);
        sync.prepare(10);

        // NoSync must be a no-op even with a live process group.
        sync.sync_all(&ProcessGroup::mock());
    }

    #[test]
    fn test_gradient_synchronizer_default() {
        let sync = GradientSynchronizer::default();
        assert_eq!(sync.strategy(), GradSyncStrategy::Synchronous);
    }

    #[test]
    fn test_grad_sync_strategy() {
        assert_eq!(GradSyncStrategy::Synchronous, GradSyncStrategy::Synchronous);
        assert_ne!(GradSyncStrategy::Synchronous, GradSyncStrategy::NoSync);
    }
}