1use crate::backend::ReduceOp;
19use crate::process_group::ProcessGroup;
20use axonml_autograd::Variable;
21use axonml_nn::{Module, Parameter};
22use axonml_tensor::Tensor;
23
/// How parameters, gradients, and optimizer state are partitioned across
/// the ranks of a process group.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ShardingStrategy {
    /// Shard parameters across ranks (the default); gradients are
    /// reduce-scattered during synchronization.
    #[default]
    FullShard,
    /// Gradients are reduce-scattered; note that `shard_parameters` also
    /// runs for this strategy, so parameters are sharded the same way as
    /// `FullShard` — confirm this matches the intended ZeRO-2-like semantics.
    ShardGradOp,
    /// No sharding at all; gradients are all-reduced (classic DDP behavior).
    NoShard,
    /// Hybrid sharding; currently synchronizes gradients with a plain
    /// all-reduce, same as `NoShard` (no per-node subgrouping yet).
    HybridShard,
}
41
/// This rank's flat slice of one original parameter, plus the metadata
/// needed to reassemble the full tensor in `gather_parameters`.
#[derive(Debug)]
struct ShardedParam {
    /// Flat 1-D shard of length `shard_size` (zero-padded if needed).
    local_shard: Tensor<f32>,
    /// Shape of the full, unsharded parameter.
    original_shape: Vec<usize>,
    /// Element count of the full parameter (excludes padding).
    numel: usize,
    /// Number of zero elements appended so every rank holds an
    /// equally sized shard.
    #[allow(dead_code)]
    pub padding: usize,
}
60
/// CPU offload policy. Currently only stored by the wrapper; the code
/// visible here does not act on it yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CPUOffload {
    /// No offloading (default).
    #[default]
    None,
    /// Offload parameters to CPU (by name; implementation not visible here).
    Params,
    /// Offload parameters plus gradients/optimizer state to CPU (by name;
    /// implementation not visible here).
    Full,
}
72
/// FSDP-style wrapper that shards a module's parameters across the ranks
/// of a [`ProcessGroup`] and coordinates gradient synchronization.
pub struct FullyShardedDataParallel<M: Module> {
    /// The wrapped model.
    module: M,
    /// Communication group used for all collectives.
    process_group: ProcessGroup,
    /// Active partitioning strategy.
    sharding_strategy: ShardingStrategy,
    /// Configured CPU offload policy (stored; unused in visible code).
    cpu_offload: CPUOffload,
    /// Per-parameter local shards, in `module.parameters()` order.
    sharded_params: Vec<ShardedParam>,
    /// True when the module currently holds full (gathered) parameters.
    is_gathered: bool,
    /// Mixed-precision flag (stored; unused in visible code).
    mixed_precision: bool,
}
97
98impl<M: Module> FullyShardedDataParallel<M> {
99 pub fn new(module: M, process_group: ProcessGroup) -> Self {
101 let mut fsdp = Self {
102 module,
103 process_group,
104 sharding_strategy: ShardingStrategy::default(),
105 cpu_offload: CPUOffload::default(),
106 sharded_params: Vec::new(),
107 is_gathered: true,
108 mixed_precision: false,
109 };
110
111 fsdp.shard_parameters();
113 fsdp
114 }
115
116 pub fn sharding_strategy(mut self, strategy: ShardingStrategy) -> Self {
118 self.sharding_strategy = strategy;
119 self.shard_parameters();
120 self
121 }
122
123 pub fn cpu_offload(mut self, offload: CPUOffload) -> Self {
125 self.cpu_offload = offload;
126 self
127 }
128
129 pub fn mixed_precision(mut self, enabled: bool) -> Self {
131 self.mixed_precision = enabled;
132 self
133 }
134
135 pub fn module(&self) -> &M {
137 &self.module
138 }
139
140 pub fn module_mut(&mut self) -> &mut M {
142 &mut self.module
143 }
144
145 pub fn process_group(&self) -> &ProcessGroup {
147 &self.process_group
148 }
149
150 pub fn strategy(&self) -> ShardingStrategy {
152 self.sharding_strategy
153 }
154
155 fn shard_parameters(&mut self) {
157 if self.sharding_strategy == ShardingStrategy::NoShard {
158 return;
159 }
160
161 let world_size = self.process_group.world_size();
162 let rank = self.process_group.rank();
163
164 self.sharded_params.clear();
165
166 for param in self.module.parameters() {
167 let data = param.data();
168 let shape = data.shape().to_vec();
169 let numel = data.numel();
170
171 let shard_size = numel.div_ceil(world_size);
173 let padding = shard_size * world_size - numel;
174
175 let flat_data = data.to_vec();
177 let start = rank * shard_size;
178 let end = ((rank + 1) * shard_size).min(flat_data.len());
179
180 let mut shard_data: Vec<f32> = if start < flat_data.len() {
181 flat_data[start..end].to_vec()
182 } else {
183 vec![0.0; shard_size]
184 };
185
186 while shard_data.len() < shard_size {
188 shard_data.push(0.0);
189 }
190
191 self.sharded_params.push(ShardedParam {
192 local_shard: Tensor::from_vec(shard_data, &[shard_size]).unwrap(),
193 original_shape: shape,
194 numel,
195 padding,
196 });
197 }
198
199 self.is_gathered = false;
200 }
201
202 pub fn gather_parameters(&mut self) {
204 if self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
205 return;
206 }
207
208 let params = self.module.parameters();
209
210 for (param, sharded) in params.iter().zip(self.sharded_params.iter()) {
211 let gathered = self.process_group.all_gather_tensor(&sharded.local_shard);
213
214 let flat: Vec<f32> = gathered.to_vec().into_iter().take(sharded.numel).collect();
216 let restored = Tensor::from_vec(flat, &sharded.original_shape).unwrap();
217
218 param.update_data(restored);
219 }
220
221 self.is_gathered = true;
222 }
223
224 pub fn reshard_parameters(&mut self) {
226 if !self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
227 return;
228 }
229
230 self.shard_parameters();
231 }
232
233 pub fn sync_gradients(&self) {
239 match self.sharding_strategy {
240 ShardingStrategy::NoShard => {
241 for param in self.module.parameters() {
243 if let Some(grad) = param.grad() {
244 let mut grad_tensor = grad.clone();
245 self.process_group
246 .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
247 param.set_grad(grad_tensor);
248 }
249 }
250 }
251 ShardingStrategy::ShardGradOp | ShardingStrategy::FullShard => {
252 for param in self.module.parameters() {
254 if let Some(grad) = param.grad() {
255 let reduced = self
256 .process_group
257 .reduce_scatter_tensor(&grad, ReduceOp::Average);
258 param.set_grad(reduced);
260 }
261 }
262 }
263 ShardingStrategy::HybridShard => {
264 for param in self.module.parameters() {
266 if let Some(grad) = param.grad() {
267 let mut grad_tensor = grad.clone();
268 self.process_group
269 .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
270 param.set_grad(grad_tensor);
271 }
272 }
273 }
274 }
275 }
276
277 pub fn clip_grad_norm(&self, max_norm: f32) -> f32 {
279 let mut total_norm_sq = 0.0f32;
280
281 for param in self.module.parameters() {
282 if let Some(grad) = param.grad() {
283 let grad_vec = grad.to_vec();
284 let norm_sq: f32 = grad_vec.iter().map(|x| x * x).sum();
285 total_norm_sq += norm_sq;
286 }
287 }
288
289 let mut norm_tensor = Tensor::from_vec(vec![total_norm_sq], &[1]).unwrap();
291 self.process_group
292 .all_reduce_tensor(&mut norm_tensor, ReduceOp::Sum);
293 let global_norm = norm_tensor.to_vec()[0].sqrt();
294
295 if global_norm > max_norm {
297 let clip_coef = max_norm / (global_norm + 1e-6);
298 for param in self.module.parameters() {
299 if let Some(grad) = param.grad() {
300 let clipped: Vec<f32> = grad.to_vec().iter().map(|x| x * clip_coef).collect();
301 let clipped_tensor = Tensor::from_vec(clipped, grad.shape()).unwrap();
302 param.variable().set_grad(clipped_tensor);
303 }
304 }
305 }
306
307 global_norm
308 }
309
310 pub fn memory_estimate(&self) -> FSDPMemoryStats {
312 let params = self.module.parameters();
313 let total_params: usize = params.iter().map(|p| p.numel()).sum();
314 let world_size = self.process_group.world_size();
315
316 let bytes_per_param = 4; let param_memory = total_params * bytes_per_param;
318
319 let (sharded_params, sharded_grads, sharded_optim) = match self.sharding_strategy {
320 ShardingStrategy::NoShard => (param_memory, param_memory, param_memory * 2),
321 ShardingStrategy::ShardGradOp => (
322 param_memory,
323 param_memory / world_size,
324 param_memory * 2 / world_size,
325 ),
326 ShardingStrategy::FullShard | ShardingStrategy::HybridShard => (
327 param_memory / world_size,
328 param_memory / world_size,
329 param_memory * 2 / world_size,
330 ),
331 };
332
333 FSDPMemoryStats {
334 total_params,
335 param_memory_bytes: sharded_params,
336 grad_memory_bytes: sharded_grads,
337 optim_memory_bytes: sharded_optim,
338 world_size,
339 }
340 }
341}
342
impl<M: Module> Module for FullyShardedDataParallel<M> {
    /// Delegates the forward pass to the wrapped module.
    ///
    /// NOTE(review): this does not gather sharded parameters first; callers
    /// appear expected to invoke `gather_parameters()` before `forward` when
    /// a sharding strategy other than `NoShard` is active (the tests do so)
    /// — confirm this is the intended contract.
    fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }

    /// Exposes the wrapped module's parameters unchanged.
    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }

    /// Puts the wrapped module into training mode.
    fn train(&mut self) {
        self.module.train();
    }

    /// Puts the wrapped module into evaluation mode.
    fn eval(&mut self) {
        self.module.eval();
    }

    /// Reports whether the wrapped module is in training mode.
    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}
366
/// Per-rank memory estimate produced by
/// `FullyShardedDataParallel::memory_estimate`.
#[derive(Debug, Clone)]
pub struct FSDPMemoryStats {
    /// Total number of parameters in the wrapped module.
    pub total_params: usize,
    /// Estimated bytes for this rank's share of the parameters.
    pub param_memory_bytes: usize,
    /// Estimated bytes for this rank's share of the gradients.
    pub grad_memory_bytes: usize,
    /// Estimated bytes for this rank's share of the optimizer state
    /// (computed as 2x the parameter bytes).
    pub optim_memory_bytes: usize,
    /// Number of ranks the estimate was computed for.
    pub world_size: usize,
}
381
382impl FSDPMemoryStats {
383 pub fn total_memory_mb(&self) -> f32 {
385 (self.param_memory_bytes + self.grad_memory_bytes + self.optim_memory_bytes) as f32
386 / (1024.0 * 1024.0)
387 }
388
389 pub fn memory_savings(&self) -> f32 {
391 if self.world_size > 1 {
392 1.0 - (1.0 / self.world_size as f32)
393 } else {
394 0.0
395 }
396 }
397}
398
/// Linear layer whose weight is split along the output (column) dimension:
/// each rank computes a disjoint slice of the output features.
pub struct ColumnParallelLinear {
    /// This rank's `[out_features / world_size, in_features]` weight shard.
    weight: Parameter,
    /// This rank's bias slice, if enabled.
    bias: Option<Parameter>,
    /// Group used to gather outputs across ranks.
    process_group: ProcessGroup,
    /// Full input feature count (kept for introspection).
    #[allow(dead_code)]
    in_features: usize,
    /// Full output feature count (kept for introspection).
    #[allow(dead_code)]
    out_features: usize,
    /// When true, `forward` all-gathers the per-rank partial outputs.
    gather_output: bool,
}
423
424impl ColumnParallelLinear {
425 pub fn new(
427 in_features: usize,
428 out_features: usize,
429 bias: bool,
430 process_group: ProcessGroup,
431 gather_output: bool,
432 ) -> Self {
433 let world_size = process_group.world_size();
434 let local_out_features = out_features / world_size;
435
436 let weight_data = Tensor::randn(&[local_out_features, in_features]);
437 let weight = Parameter::new(weight_data, true);
438
439 let bias = if bias {
440 let bias_data = Tensor::zeros(&[local_out_features]);
441 Some(Parameter::new(bias_data, true))
442 } else {
443 None
444 };
445
446 Self {
447 weight,
448 bias,
449 process_group,
450 in_features,
451 out_features,
452 gather_output,
453 }
454 }
455}
456
impl Module for ColumnParallelLinear {
    /// Computes `input @ weight^T (+ bias)` over this rank's output-column
    /// shard, optionally all-gathering the partial outputs at the end.
    fn forward(&self, input: &Variable) -> Variable {
        // The weight participates as a constant (requires_grad = false);
        // gradient tracking for the parameter itself is not wired through
        // this Variable.
        let weight_var = Variable::new(self.weight.data(), false);
        let output = input.matmul(&weight_var.transpose(0, 1));

        let output = if let Some(ref bias) = self.bias {
            let bias_var = Variable::new(bias.data(), false);
            output.add(&bias_var)
        } else {
            output
        };

        if self.gather_output {
            // NOTE(review): column-parallel outputs should be concatenated
            // along the feature (last) axis; whether `all_gather_tensor`
            // concatenates along that axis or axis 0 is not visible here —
            // verify the collective's layout for world_size > 1.
            let gathered = self.process_group.all_gather_tensor(&output.data());
            Variable::new(gathered, output.requires_grad())
        } else {
            output
        }
    }

    /// Returns this rank's local weight shard and (if present) bias shard.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }
}
488
/// Linear layer whose weight is split along the input (row) dimension:
/// each rank multiplies its input-feature slice and the partial results
/// are summed across ranks with an all-reduce.
pub struct RowParallelLinear {
    /// This rank's `[out_features, in_features / world_size]` weight shard.
    weight: Parameter,
    /// Bias, if enabled (see `new` for which ranks hold it).
    bias: Option<Parameter>,
    /// Group used to sum partial outputs across ranks.
    process_group: ProcessGroup,
    /// Full input feature count (kept for introspection).
    #[allow(dead_code)]
    in_features: usize,
    /// Full output feature count (kept for introspection).
    #[allow(dead_code)]
    out_features: usize,
    /// When true, the input is already split along the feature dimension
    /// and `forward` skips its own slicing.
    input_is_parallel: bool,
}
509
510impl RowParallelLinear {
511 pub fn new(
513 in_features: usize,
514 out_features: usize,
515 bias: bool,
516 process_group: ProcessGroup,
517 input_is_parallel: bool,
518 ) -> Self {
519 let world_size = process_group.world_size();
520 let rank = process_group.rank();
521 let local_in_features = in_features / world_size;
522
523 let weight_data = Tensor::randn(&[out_features, local_in_features]);
524 let weight = Parameter::new(weight_data, true);
525
526 let bias = if bias && rank == 0 {
528 let bias_data = Tensor::zeros(&[out_features]);
529 Some(Parameter::new(bias_data, true))
530 } else {
531 None
532 };
533
534 Self {
535 weight,
536 bias,
537 process_group,
538 in_features,
539 out_features,
540 input_is_parallel,
541 }
542 }
543}
544
impl Module for RowParallelLinear {
    /// Row-parallel forward: each rank multiplies its input-feature slice
    /// by its weight shard, the partial products are summed across ranks
    /// with an all-reduce, and the bias (when held locally) is added last.
    fn forward(&self, input: &Variable) -> Variable {
        let local_input = if self.input_is_parallel {
            // The caller already split the input along the feature dim.
            input.clone()
        } else {
            // Slice out this rank's contiguous chunk of input features.
            let world_size = self.process_group.world_size();
            let rank = self.process_group.rank();
            let data = input.data();
            let shape = data.shape();
            let feature_dim = shape[shape.len() - 1];
            let local_features = feature_dim / world_size;
            let start = rank * local_features;
            let end = start + local_features;

            let sliced = if shape.len() == 2 {
                data.slice(&[0..shape[0], start..end])
            } else {
                // NOTE(review): non-2D inputs are passed through unsliced,
                // silently skipping the row-parallel split — confirm whether
                // higher-rank inputs are expected to reach this layer.
                data.clone()
            };
            Variable::new(sliced, input.requires_grad())
        };

        // Weight participates as a constant Variable (requires_grad = false).
        let weight_var = Variable::new(self.weight.data(), false);
        let local_output = local_input.matmul(&weight_var.transpose(0, 1));

        // Sum the per-rank partial products into the full output.
        let mut output_data = local_output.data().clone();
        self.process_group
            .all_reduce_tensor(&mut output_data, ReduceOp::Sum);
        let output = Variable::new(output_data, local_output.requires_grad());

        // NOTE(review): the bias is added per-rank after the all-reduce, so
        // it must be replicated on every rank for outputs to agree across
        // ranks; if the constructor instantiates it on rank 0 only, the
        // other ranks skip this branch — verify against `new`.
        if let Some(ref bias) = self.bias {
            let bias_var = Variable::new(bias.data(), false);
            output.add(&bias_var)
        } else {
            output
        }
    }

    /// Returns this rank's local weight shard and (if present) the bias.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }
}
597
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    #[test]
    fn test_sharding_strategy_default() {
        // FullShard is the declared default strategy.
        assert_eq!(ShardingStrategy::default(), ShardingStrategy::FullShard);
    }

    #[test]
    fn test_fsdp_creation() {
        let inner = Linear::new(10, 5);
        let group = ProcessGroup::mock();
        let wrapper = FullyShardedDataParallel::new(inner, group);

        assert_eq!(wrapper.strategy(), ShardingStrategy::FullShard);
    }

    #[test]
    fn test_fsdp_forward() {
        let inner = Linear::new(4, 2);
        let group = ProcessGroup::mock();
        let mut wrapper = FullyShardedDataParallel::new(inner, group);

        // Full parameters must be restored before running a forward pass.
        wrapper.gather_parameters();

        let x = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        let y = wrapper.forward(&x);

        assert_eq!(y.data().shape(), &[1, 2]);
    }

    #[test]
    fn test_fsdp_builder() {
        let inner = Linear::new(10, 5);
        let group = ProcessGroup::mock();

        // Chain every builder method and confirm the strategy sticks.
        let wrapper = FullyShardedDataParallel::new(inner, group)
            .sharding_strategy(ShardingStrategy::ShardGradOp)
            .cpu_offload(CPUOffload::Params)
            .mixed_precision(true);

        assert_eq!(wrapper.strategy(), ShardingStrategy::ShardGradOp);
    }

    #[test]
    fn test_fsdp_memory_stats() {
        let inner = Linear::new(100, 50);
        let group = ProcessGroup::mock();
        let wrapper = FullyShardedDataParallel::new(inner, group);

        let stats = wrapper.memory_estimate();
        assert!(stats.total_params > 0);
        assert!(stats.total_memory_mb() > 0.0);
    }

    #[test]
    fn test_fsdp_no_shard() {
        let inner = Linear::new(10, 5);
        let group = ProcessGroup::mock();
        let wrapper =
            FullyShardedDataParallel::new(inner, group).sharding_strategy(ShardingStrategy::NoShard);

        assert_eq!(wrapper.strategy(), ShardingStrategy::NoShard);
    }

    #[test]
    fn test_column_parallel_linear() {
        let group = ProcessGroup::mock();
        let layer = ColumnParallelLinear::new(8, 4, true, group, false);

        let x = Variable::new(Tensor::randn(&[2, 8]), false);
        let y = layer.forward(&x);

        assert_eq!(y.data().shape(), &[2, 4]);
    }

    #[test]
    fn test_row_parallel_linear() {
        let group = ProcessGroup::mock();
        let layer = RowParallelLinear::new(8, 4, true, group, false);

        let x = Variable::new(Tensor::randn(&[2, 8]), false);
        let y = layer.forward(&x);

        assert_eq!(y.data().shape(), &[2, 4]);
    }
}