1use crate::backend::ReduceOp;
18use crate::process_group::ProcessGroup;
19use axonml_autograd::Variable;
20use axonml_nn::{Module, Parameter};
21use axonml_tensor::Tensor;
22
/// Strategy controlling how parameters, gradients, and optimizer state are
/// partitioned across the data-parallel process group.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ShardingStrategy {
    /// Shard parameters, gradients, and optimizer state across all ranks
    /// (memory estimate divides all three by world size).
    #[default]
    FullShard,
    /// Keep full parameters on every rank; shard only gradients and
    /// optimizer state.
    ShardGradOp,
    /// No sharding at all — plain data parallelism.
    NoShard,
    /// Hybrid sharding (e.g. shard within a node, replicate across nodes).
    /// NOTE(review): currently estimated like `FullShard` and gradients are
    /// synced with a plain all-reduce — confirm intended semantics.
    HybridShard,
}
40
/// One rank's flat shard of a single module parameter.
#[derive(Debug)]
struct ShardedParam {
    /// This rank's 1-D slice of the flattened parameter (zero-padded so all
    /// ranks hold shards of identical length).
    local_shard: Tensor<f32>,
    /// Shape of the full parameter, used to restore it on gather.
    original_shape: Vec<usize>,
    /// Element count of the full (unpadded) parameter.
    numel: usize,
    /// Number of zero elements appended across all shards combined.
    #[allow(dead_code)]
    pub padding: usize,
}
59
/// CPU offloading policy for sharded training state.
///
/// NOTE(review): this flag is stored but not yet consulted by the visible
/// FSDP logic in this file — confirm where it takes effect.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CPUOffload {
    /// Keep everything on the accelerator.
    #[default]
    None,
    /// Offload parameters to CPU.
    Params,
    /// Offload parameters, gradients, and optimizer state to CPU.
    Full,
}
71
/// Fully Sharded Data Parallel (FSDP) wrapper around a [`Module`].
///
/// Each parameter is flattened and split into equal chunks across the ranks
/// of a [`ProcessGroup`]; full parameters are reconstructed on demand via
/// all-gather before use.
pub struct FullyShardedDataParallel<M: Module> {
    /// The wrapped module.
    module: M,
    /// Communication group used for all collectives.
    process_group: ProcessGroup,
    /// Active sharding strategy.
    sharding_strategy: ShardingStrategy,
    /// CPU offload policy (stored; not yet consulted by visible logic).
    cpu_offload: CPUOffload,
    /// This rank's shard of each parameter, in `module.parameters()` order.
    sharded_params: Vec<ShardedParam>,
    /// True when the module currently holds full (gathered) parameters.
    is_gathered: bool,
    /// Mixed-precision flag (stored; not yet consulted by visible logic).
    mixed_precision: bool,
}
96
97impl<M: Module> FullyShardedDataParallel<M> {
98 pub fn new(module: M, process_group: ProcessGroup) -> Self {
100 let mut fsdp = Self {
101 module,
102 process_group,
103 sharding_strategy: ShardingStrategy::default(),
104 cpu_offload: CPUOffload::default(),
105 sharded_params: Vec::new(),
106 is_gathered: true,
107 mixed_precision: false,
108 };
109
110 fsdp.shard_parameters();
112 fsdp
113 }
114
115 pub fn sharding_strategy(mut self, strategy: ShardingStrategy) -> Self {
117 self.sharding_strategy = strategy;
118 self.shard_parameters();
119 self
120 }
121
122 pub fn cpu_offload(mut self, offload: CPUOffload) -> Self {
124 self.cpu_offload = offload;
125 self
126 }
127
128 pub fn mixed_precision(mut self, enabled: bool) -> Self {
130 self.mixed_precision = enabled;
131 self
132 }
133
134 pub fn module(&self) -> &M {
136 &self.module
137 }
138
139 pub fn module_mut(&mut self) -> &mut M {
141 &mut self.module
142 }
143
144 pub fn process_group(&self) -> &ProcessGroup {
146 &self.process_group
147 }
148
149 pub fn strategy(&self) -> ShardingStrategy {
151 self.sharding_strategy
152 }
153
154 fn shard_parameters(&mut self) {
156 if self.sharding_strategy == ShardingStrategy::NoShard {
157 return;
158 }
159
160 let world_size = self.process_group.world_size();
161 let rank = self.process_group.rank();
162
163 self.sharded_params.clear();
164
165 for param in self.module.parameters() {
166 let data = param.data();
167 let shape = data.shape().to_vec();
168 let numel = data.numel();
169
170 let shard_size = numel.div_ceil(world_size);
172 let padding = shard_size * world_size - numel;
173
174 let flat_data = data.to_vec();
176 let start = rank * shard_size;
177 let end = ((rank + 1) * shard_size).min(flat_data.len());
178
179 let mut shard_data: Vec<f32> = if start < flat_data.len() {
180 flat_data[start..end].to_vec()
181 } else {
182 vec![0.0; shard_size]
183 };
184
185 while shard_data.len() < shard_size {
187 shard_data.push(0.0);
188 }
189
190 self.sharded_params.push(ShardedParam {
191 local_shard: Tensor::from_vec(shard_data, &[shard_size]).unwrap(),
192 original_shape: shape,
193 numel,
194 padding,
195 });
196 }
197
198 self.is_gathered = false;
199 }
200
201 pub fn gather_parameters(&mut self) {
203 if self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
204 return;
205 }
206
207 let params = self.module.parameters();
208
209 for (param, sharded) in params.iter().zip(self.sharded_params.iter()) {
210 let gathered = self.process_group.all_gather_tensor(&sharded.local_shard);
212
213 let flat: Vec<f32> = gathered.to_vec().into_iter().take(sharded.numel).collect();
215 let restored = Tensor::from_vec(flat, &sharded.original_shape).unwrap();
216
217 param.update_data(restored);
218 }
219
220 self.is_gathered = true;
221 }
222
223 pub fn reshard_parameters(&mut self) {
225 if !self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
226 return;
227 }
228
229 self.shard_parameters();
230 }
231
232 pub fn sync_gradients(&self) {
238 match self.sharding_strategy {
239 ShardingStrategy::NoShard => {
240 for param in self.module.parameters() {
242 if let Some(grad) = param.grad() {
243 let mut grad_tensor = grad.clone();
244 self.process_group
245 .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
246 param.set_grad(grad_tensor);
247 }
248 }
249 }
250 ShardingStrategy::ShardGradOp | ShardingStrategy::FullShard => {
251 for param in self.module.parameters() {
253 if let Some(grad) = param.grad() {
254 let reduced = self
255 .process_group
256 .reduce_scatter_tensor(&grad, ReduceOp::Average);
257 param.set_grad(reduced);
259 }
260 }
261 }
262 ShardingStrategy::HybridShard => {
263 for param in self.module.parameters() {
265 if let Some(grad) = param.grad() {
266 let mut grad_tensor = grad.clone();
267 self.process_group
268 .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
269 param.set_grad(grad_tensor);
270 }
271 }
272 }
273 }
274 }
275
276 pub fn clip_grad_norm(&self, max_norm: f32) -> f32 {
278 let mut total_norm_sq = 0.0f32;
279
280 for param in self.module.parameters() {
281 if let Some(grad) = param.grad() {
282 let grad_vec = grad.to_vec();
283 let norm_sq: f32 = grad_vec.iter().map(|x| x * x).sum();
284 total_norm_sq += norm_sq;
285 }
286 }
287
288 let mut norm_tensor = Tensor::from_vec(vec![total_norm_sq], &[1]).unwrap();
290 self.process_group
291 .all_reduce_tensor(&mut norm_tensor, ReduceOp::Sum);
292 let global_norm = norm_tensor.to_vec()[0].sqrt();
293
294 if global_norm > max_norm {
296 let clip_coef = max_norm / (global_norm + 1e-6);
297 for param in self.module.parameters() {
298 if let Some(grad) = param.grad() {
299 let clipped: Vec<f32> = grad.to_vec().iter().map(|x| x * clip_coef).collect();
300 let clipped_tensor = Tensor::from_vec(clipped, grad.shape()).unwrap();
301 param.variable().set_grad(clipped_tensor);
302 }
303 }
304 }
305
306 global_norm
307 }
308
309 pub fn memory_estimate(&self) -> FSDPMemoryStats {
311 let params = self.module.parameters();
312 let total_params: usize = params.iter().map(|p| p.numel()).sum();
313 let world_size = self.process_group.world_size();
314
315 let bytes_per_param = 4; let param_memory = total_params * bytes_per_param;
317
318 let (sharded_params, sharded_grads, sharded_optim) = match self.sharding_strategy {
319 ShardingStrategy::NoShard => (param_memory, param_memory, param_memory * 2),
320 ShardingStrategy::ShardGradOp => (
321 param_memory,
322 param_memory / world_size,
323 param_memory * 2 / world_size,
324 ),
325 ShardingStrategy::FullShard | ShardingStrategy::HybridShard => (
326 param_memory / world_size,
327 param_memory / world_size,
328 param_memory * 2 / world_size,
329 ),
330 };
331
332 FSDPMemoryStats {
333 total_params,
334 param_memory_bytes: sharded_params,
335 grad_memory_bytes: sharded_grads,
336 optim_memory_bytes: sharded_optim,
337 world_size,
338 }
339 }
340}
341
impl<M: Module> Module for FullyShardedDataParallel<M> {
    /// Forwards through the wrapped module.
    ///
    /// NOTE(review): this assumes the module currently holds full parameters;
    /// when sharding is active, call `gather_parameters` first — confirm.
    fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }

    /// Delegates to the wrapped module's parameters.
    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }

    /// Puts the wrapped module in training mode.
    fn train(&mut self) {
        self.module.train();
    }

    /// Puts the wrapped module in evaluation mode.
    fn eval(&mut self) {
        self.module.eval();
    }

    /// Reports the wrapped module's training flag.
    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}
365
/// Per-rank memory estimate produced by
/// `FullyShardedDataParallel::memory_estimate`.
#[derive(Debug, Clone)]
pub struct FSDPMemoryStats {
    /// Total number of parameter elements in the wrapped module.
    pub total_params: usize,
    /// Estimated bytes for parameters held on this rank.
    pub param_memory_bytes: usize,
    /// Estimated bytes for gradients held on this rank.
    pub grad_memory_bytes: usize,
    /// Estimated bytes for optimizer state held on this rank.
    pub optim_memory_bytes: usize,
    /// Number of ranks the state is sharded over.
    pub world_size: usize,
}
380
381impl FSDPMemoryStats {
382 pub fn total_memory_mb(&self) -> f32 {
384 (self.param_memory_bytes + self.grad_memory_bytes + self.optim_memory_bytes) as f32
385 / (1024.0 * 1024.0)
386 }
387
388 pub fn memory_savings(&self) -> f32 {
390 if self.world_size > 1 {
391 1.0 - (1.0 / self.world_size as f32)
392 } else {
393 0.0
394 }
395 }
396}
397
/// A linear layer whose output features are split column-wise across ranks.
///
/// Each rank holds a `[out_features / world_size, in_features]` weight shard;
/// the sharded outputs can optionally be all-gathered back into the full
/// feature dimension.
pub struct ColumnParallelLinear {
    /// Local weight shard, shape `[out_features / world_size, in_features]`.
    weight: Parameter,
    /// Optional local bias shard, shape `[out_features / world_size]`.
    bias: Option<Parameter>,
    /// Process group used for the output all-gather.
    process_group: ProcessGroup,
    #[allow(dead_code)]
    in_features: usize,
    #[allow(dead_code)]
    out_features: usize,
    /// When true, `forward` all-gathers the sharded outputs across ranks.
    gather_output: bool,
}
422
423impl ColumnParallelLinear {
424 pub fn new(
426 in_features: usize,
427 out_features: usize,
428 bias: bool,
429 process_group: ProcessGroup,
430 gather_output: bool,
431 ) -> Self {
432 let world_size = process_group.world_size();
433 let local_out_features = out_features / world_size;
434
435 let weight_data = Tensor::randn(&[local_out_features, in_features]);
436 let weight = Parameter::new(weight_data, true);
437
438 let bias = if bias {
439 let bias_data = Tensor::zeros(&[local_out_features]);
440 Some(Parameter::new(bias_data, true))
441 } else {
442 None
443 };
444
445 Self {
446 weight,
447 bias,
448 process_group,
449 in_features,
450 out_features,
451 gather_output,
452 }
453 }
454}
455
456impl Module for ColumnParallelLinear {
457 fn forward(&self, input: &Variable) -> Variable {
458 let weight_var = Variable::new(self.weight.data(), false);
460 let output = input.matmul(&weight_var.transpose(0, 1));
461
462 let output = if let Some(ref bias) = self.bias {
464 let bias_var = Variable::new(bias.data(), false);
465 output.add(&bias_var)
466 } else {
467 output
468 };
469
470 if self.gather_output {
472 let gathered = self.process_group.all_gather_tensor(&output.data());
473 Variable::new(gathered, output.requires_grad())
474 } else {
475 output
476 }
477 }
478
479 fn parameters(&self) -> Vec<Parameter> {
480 let mut params = vec![self.weight.clone()];
481 if let Some(ref bias) = self.bias {
482 params.push(bias.clone());
483 }
484 params
485 }
486}
487
/// A linear layer whose input features are split row-wise across ranks.
///
/// Each rank holds a `[out_features, in_features / world_size]` weight shard;
/// partial products are summed across ranks with an all-reduce.
pub struct RowParallelLinear {
    /// Local weight shard, shape `[out_features, in_features / world_size]`.
    weight: Parameter,
    /// Full bias, allocated only on rank 0 (see `new`).
    bias: Option<Parameter>,
    /// Process group used for the all-reduce of partial outputs.
    process_group: ProcessGroup,
    #[allow(dead_code)]
    in_features: usize,
    #[allow(dead_code)]
    out_features: usize,
    /// When true, the caller already supplies this rank's input slice.
    input_is_parallel: bool,
}
508
509impl RowParallelLinear {
510 pub fn new(
512 in_features: usize,
513 out_features: usize,
514 bias: bool,
515 process_group: ProcessGroup,
516 input_is_parallel: bool,
517 ) -> Self {
518 let world_size = process_group.world_size();
519 let rank = process_group.rank();
520 let local_in_features = in_features / world_size;
521
522 let weight_data = Tensor::randn(&[out_features, local_in_features]);
523 let weight = Parameter::new(weight_data, true);
524
525 let bias = if bias && rank == 0 {
527 let bias_data = Tensor::zeros(&[out_features]);
528 Some(Parameter::new(bias_data, true))
529 } else {
530 None
531 };
532
533 Self {
534 weight,
535 bias,
536 process_group,
537 in_features,
538 out_features,
539 input_is_parallel,
540 }
541 }
542}
543
impl Module for RowParallelLinear {
    /// Computes a row-parallel matmul: each rank multiplies its slice of the
    /// input features by its weight shard, the partial products are summed
    /// with an all-reduce, and the bias (present on rank 0 only) is added
    /// afterwards.
    fn forward(&self, input: &Variable) -> Variable {
        // Obtain this rank's slice of the input features.
        let local_input = if self.input_is_parallel {
            // The caller already supplies the local shard (e.g. the output of
            // a column-parallel layer with gather_output == false).
            input.clone()
        } else {
            let world_size = self.process_group.world_size();
            let rank = self.process_group.rank();
            let data = input.data();
            let shape = data.shape();
            let feature_dim = shape[shape.len() - 1];
            let local_features = feature_dim / world_size;
            let start = rank * local_features;
            let end = start + local_features;

            let sliced = if shape.len() == 2 {
                // Slice [batch, features] down to this rank's feature columns.
                data.slice(&[0..shape[0], start..end])
            } else {
                // NOTE(review): non-2D inputs are passed through unsliced,
                // which mismatches the local weight shard's input dimension —
                // presumably unsupported; confirm intended behavior.
                data.clone()
            };
            Variable::new(sliced, input.requires_grad())
        };

        let weight_var = Variable::new(self.weight.data(), false);
        let local_output = local_input.matmul(&weight_var.transpose(0, 1));

        // Sum the partial products contributed by every rank.
        let mut output_data = local_output.data().clone();
        self.process_group
            .all_reduce_tensor(&mut output_data, ReduceOp::Sum);
        let output = Variable::new(output_data, local_output.requires_grad());

        // Bias exists only on rank 0 (see `new`), so it is added at most once.
        if let Some(ref bias) = self.bias {
            let bias_var = Variable::new(bias.data(), false);
            output.add(&bias_var)
        } else {
            output
        }
    }

    /// Returns the local weight shard and, on rank 0, the bias parameter.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }
}
596
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    // The default strategy is full (ZeRO-3 style) sharding.
    #[test]
    fn test_sharding_strategy_default() {
        assert_eq!(ShardingStrategy::default(), ShardingStrategy::FullShard);
    }

    // Wrapping a module picks up the default strategy.
    #[test]
    fn test_fsdp_creation() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let fsdp = FullyShardedDataParallel::new(model, pg);

        assert_eq!(fsdp.strategy(), ShardingStrategy::FullShard);
    }

    // Forward works after gathering parameters back from their shards.
    #[test]
    fn test_fsdp_forward() {
        let model = Linear::new(4, 2);
        let pg = ProcessGroup::mock();
        let mut fsdp = FullyShardedDataParallel::new(model, pg);

        fsdp.gather_parameters();

        let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        let output = fsdp.forward(&input);

        assert_eq!(output.data().shape(), &[1, 2]);
    }

    // Builder methods chain and record their settings.
    #[test]
    fn test_fsdp_builder() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();

        let fsdp = FullyShardedDataParallel::new(model, pg)
            .sharding_strategy(ShardingStrategy::ShardGradOp)
            .cpu_offload(CPUOffload::Params)
            .mixed_precision(true);

        assert_eq!(fsdp.strategy(), ShardingStrategy::ShardGradOp);
    }

    // Memory estimation reports non-trivial totals for a real module.
    #[test]
    fn test_fsdp_memory_stats() {
        let model = Linear::new(100, 50);
        let pg = ProcessGroup::mock();
        let fsdp = FullyShardedDataParallel::new(model, pg);

        let stats = fsdp.memory_estimate();
        assert!(stats.total_params > 0);
        assert!(stats.total_memory_mb() > 0.0);
    }

    // Switching to NoShard via the builder is recorded.
    #[test]
    fn test_fsdp_no_shard() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let fsdp =
            FullyShardedDataParallel::new(model, pg).sharding_strategy(ShardingStrategy::NoShard);

        assert_eq!(fsdp.strategy(), ShardingStrategy::NoShard);
    }

    // With a mock (single-rank) group, the column-parallel layer produces the
    // full output shape.
    #[test]
    fn test_column_parallel_linear() {
        let pg = ProcessGroup::mock();
        let layer = ColumnParallelLinear::new(8, 4, true, pg, false);
        let input = Variable::new(Tensor::randn(&[2, 8]), false);
        let output = layer.forward(&input);

        assert_eq!(output.data().shape(), &[2, 4]);
    }

    // With a mock (single-rank) group, the row-parallel layer produces the
    // full output shape.
    #[test]
    fn test_row_parallel_linear() {
        let pg = ProcessGroup::mock();
        let layer = RowParallelLinear::new(8, 4, true, pg, false);

        let input = Variable::new(Tensor::randn(&[2, 8]), false);
        let output = layer.forward(&input);

        assert_eq!(output.data().shape(), &[2, 4]);
    }
}