1use crate::backend::ReduceOp;
18use crate::process_group::ProcessGroup;
19use axonml_autograd::Variable;
20use axonml_nn::{Module, Parameter};
21use axonml_tensor::Tensor;
22
/// How parameters, gradients, and optimizer state are partitioned across
/// the data-parallel process group (see `memory_estimate` for the exact
/// per-strategy accounting used in this file).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ShardingStrategy {
    /// Shard parameters, gradients, and optimizer state across all ranks.
    #[default]
    FullShard,
    /// Keep full parameters on every rank; shard gradients and optimizer state.
    ShardGradOp,
    /// No sharding at all — behaves like classic replicated data parallelism.
    NoShard,
    /// Intended as intra-node sharding with inter-node replication.
    /// NOTE(review): currently treated identically to `FullShard` for memory
    /// accounting and uses a plain all-reduce for gradients — confirm intent.
    HybridShard,
}
40
/// This rank's 1-D slice of one flattened parameter, plus the metadata
/// required to reassemble the full tensor in `gather_parameters`.
#[derive(Debug)]
#[allow(dead_code)]
struct ShardedParam {
    /// The local shard, stored flat with shape `[shard_size]`.
    local_shard: Tensor<f32>,
    /// Shape of the original, unsharded parameter.
    original_shape: Vec<usize>,
    /// Element count of the original parameter (before zero padding).
    numel: usize,
    /// Zeros appended so `numel + padding` divides evenly by the world size.
    padding: usize,
}
58
/// CPU offload policy.
/// NOTE(review): recorded by the builder (`FullyShardedDataParallel::cpu_offload`)
/// but not yet consulted anywhere in this file — the variants below describe
/// intent inferred from their names, to be confirmed once implemented.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CPUOffload {
    /// No offloading.
    #[default]
    None,
    /// Offload parameters only.
    Params,
    /// Offload parameters and gradients/optimizer state.
    Full,
}
70
/// Fully Sharded Data Parallel wrapper: partitions the wrapped module's
/// parameters across a process group to reduce per-rank memory, gathering
/// the full tensors on demand for computation.
pub struct FullyShardedDataParallel<M: Module> {
    /// The wrapped model.
    module: M,
    /// Communication group used for all collectives.
    process_group: ProcessGroup,
    /// How parameters/gradients/optimizer state are partitioned.
    sharding_strategy: ShardingStrategy,
    /// CPU offload policy (stored; not yet acted on in this file).
    cpu_offload: CPUOffload,
    /// One entry per module parameter, in `module.parameters()` order.
    sharded_params: Vec<ShardedParam>,
    /// True when the full parameters are currently materialized in `module`.
    is_gathered: bool,
    /// Mixed-precision flag (stored; not yet acted on in this file).
    mixed_precision: bool,
}
95
96impl<M: Module> FullyShardedDataParallel<M> {
97 pub fn new(module: M, process_group: ProcessGroup) -> Self {
99 let mut fsdp = Self {
100 module,
101 process_group,
102 sharding_strategy: ShardingStrategy::default(),
103 cpu_offload: CPUOffload::default(),
104 sharded_params: Vec::new(),
105 is_gathered: true,
106 mixed_precision: false,
107 };
108
109 fsdp.shard_parameters();
111 fsdp
112 }
113
114 pub fn sharding_strategy(mut self, strategy: ShardingStrategy) -> Self {
116 self.sharding_strategy = strategy;
117 self.shard_parameters();
118 self
119 }
120
121 pub fn cpu_offload(mut self, offload: CPUOffload) -> Self {
123 self.cpu_offload = offload;
124 self
125 }
126
127 pub fn mixed_precision(mut self, enabled: bool) -> Self {
129 self.mixed_precision = enabled;
130 self
131 }
132
133 pub fn module(&self) -> &M {
135 &self.module
136 }
137
138 pub fn module_mut(&mut self) -> &mut M {
140 &mut self.module
141 }
142
143 pub fn process_group(&self) -> &ProcessGroup {
145 &self.process_group
146 }
147
148 pub fn strategy(&self) -> ShardingStrategy {
150 self.sharding_strategy
151 }
152
153 fn shard_parameters(&mut self) {
155 if self.sharding_strategy == ShardingStrategy::NoShard {
156 return;
157 }
158
159 let world_size = self.process_group.world_size();
160 let rank = self.process_group.rank();
161
162 self.sharded_params.clear();
163
164 for param in self.module.parameters() {
165 let data = param.data();
166 let shape = data.shape().to_vec();
167 let numel = data.numel();
168
169 let shard_size = numel.div_ceil(world_size);
171 let padding = shard_size * world_size - numel;
172
173 let flat_data = data.to_vec();
175 let start = rank * shard_size;
176 let end = ((rank + 1) * shard_size).min(flat_data.len());
177
178 let mut shard_data: Vec<f32> = if start < flat_data.len() {
179 flat_data[start..end].to_vec()
180 } else {
181 vec![0.0; shard_size]
182 };
183
184 while shard_data.len() < shard_size {
186 shard_data.push(0.0);
187 }
188
189 self.sharded_params.push(ShardedParam {
190 local_shard: Tensor::from_vec(shard_data, &[shard_size]).unwrap(),
191 original_shape: shape,
192 numel,
193 padding,
194 });
195 }
196
197 self.is_gathered = false;
198 }
199
200 pub fn gather_parameters(&mut self) {
202 if self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
203 return;
204 }
205
206 let _world_size = self.process_group.world_size();
207 let params = self.module.parameters();
208
209 for (param, sharded) in params.iter().zip(self.sharded_params.iter()) {
210 let gathered = self.process_group.all_gather_tensor(&sharded.local_shard);
212
213 let flat: Vec<f32> = gathered.to_vec().into_iter().take(sharded.numel).collect();
215 let restored = Tensor::from_vec(flat, &sharded.original_shape).unwrap();
216
217 param.update_data(restored);
218 }
219
220 self.is_gathered = true;
221 }
222
223 pub fn reshard_parameters(&mut self) {
225 if !self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
226 return;
227 }
228
229 self.shard_parameters();
230 }
231
232 pub fn sync_gradients(&self) {
234 match self.sharding_strategy {
235 ShardingStrategy::NoShard => {
236 for param in self.module.parameters() {
238 if let Some(grad) = param.grad() {
239 let mut grad_tensor = grad.clone();
240 self.process_group
241 .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
242 }
243 }
244 }
245 ShardingStrategy::ShardGradOp | ShardingStrategy::FullShard => {
246 for param in self.module.parameters() {
248 if let Some(grad) = param.grad() {
249 let _reduced = self
250 .process_group
251 .reduce_scatter_tensor(&grad, ReduceOp::Average);
252 }
254 }
255 }
256 ShardingStrategy::HybridShard => {
257 for param in self.module.parameters() {
259 if let Some(grad) = param.grad() {
260 let mut grad_tensor = grad.clone();
261 self.process_group
262 .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
263 }
264 }
265 }
266 }
267 }
268
269 pub fn clip_grad_norm(&self, max_norm: f32) -> f32 {
271 let mut total_norm_sq = 0.0f32;
272
273 for param in self.module.parameters() {
274 if let Some(grad) = param.grad() {
275 let grad_vec = grad.to_vec();
276 let norm_sq: f32 = grad_vec.iter().map(|x| x * x).sum();
277 total_norm_sq += norm_sq;
278 }
279 }
280
281 let mut norm_tensor = Tensor::from_vec(vec![total_norm_sq], &[1]).unwrap();
283 self.process_group
284 .all_reduce_tensor(&mut norm_tensor, ReduceOp::Sum);
285 let global_norm = norm_tensor.to_vec()[0].sqrt();
286
287 if global_norm > max_norm {
289 let clip_coef = max_norm / (global_norm + 1e-6);
290 for param in self.module.parameters() {
291 if let Some(grad) = param.grad() {
292 let clipped: Vec<f32> = grad.to_vec().iter().map(|x| x * clip_coef).collect();
293 let clipped_tensor = Tensor::from_vec(clipped, grad.shape()).unwrap();
294 param.variable().set_grad(clipped_tensor);
295 }
296 }
297 }
298
299 global_norm
300 }
301
302 pub fn memory_estimate(&self) -> FSDPMemoryStats {
304 let params = self.module.parameters();
305 let total_params: usize = params.iter().map(|p| p.numel()).sum();
306 let world_size = self.process_group.world_size();
307
308 let bytes_per_param = 4; let param_memory = total_params * bytes_per_param;
310
311 let (sharded_params, sharded_grads, sharded_optim) = match self.sharding_strategy {
312 ShardingStrategy::NoShard => (param_memory, param_memory, param_memory * 2),
313 ShardingStrategy::ShardGradOp => (
314 param_memory,
315 param_memory / world_size,
316 param_memory * 2 / world_size,
317 ),
318 ShardingStrategy::FullShard | ShardingStrategy::HybridShard => (
319 param_memory / world_size,
320 param_memory / world_size,
321 param_memory * 2 / world_size,
322 ),
323 };
324
325 FSDPMemoryStats {
326 total_params,
327 param_memory_bytes: sharded_params,
328 grad_memory_bytes: sharded_grads,
329 optim_memory_bytes: sharded_optim,
330 world_size,
331 }
332 }
333}
334
impl<M: Module> Module for FullyShardedDataParallel<M> {
    /// Delegates to the wrapped module.
    ///
    /// NOTE(review): no implicit gather happens here — callers are expected
    /// to have full parameters materialized first (see `gather_parameters`).
    fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }

    /// Parameters of the wrapped module.
    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }

    /// Puts the wrapped module into training mode.
    fn train(&mut self) {
        self.module.train();
    }

    /// Puts the wrapped module into evaluation mode.
    fn eval(&mut self) {
        self.module.eval();
    }

    /// Whether the wrapped module is in training mode.
    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}
358
/// Per-rank memory estimate returned by
/// `FullyShardedDataParallel::memory_estimate`.
#[derive(Debug, Clone)]
pub struct FSDPMemoryStats {
    /// Total parameter count of the wrapped module.
    pub total_params: usize,
    /// Estimated bytes for parameters held by this rank.
    pub param_memory_bytes: usize,
    /// Estimated bytes for gradients held by this rank.
    pub grad_memory_bytes: usize,
    /// Estimated bytes for optimizer state held by this rank.
    pub optim_memory_bytes: usize,
    /// Number of ranks in the process group.
    pub world_size: usize,
}
373
374impl FSDPMemoryStats {
375 pub fn total_memory_mb(&self) -> f32 {
377 (self.param_memory_bytes + self.grad_memory_bytes + self.optim_memory_bytes) as f32
378 / (1024.0 * 1024.0)
379 }
380
381 pub fn memory_savings(&self) -> f32 {
383 if self.world_size > 1 {
384 1.0 - (1.0 / self.world_size as f32)
385 } else {
386 0.0
387 }
388 }
389}
390
/// Linear layer whose output features are split across ranks
/// (column parallelism): each rank holds a
/// `[out_features / world_size, in_features]` slice of the weight.
#[allow(dead_code)]
pub struct ColumnParallelLinear {
    /// Local weight shard, `[local_out_features, in_features]`.
    weight: Parameter,
    /// Optional per-rank bias for the local output slice.
    bias: Option<Parameter>,
    /// Group used for the optional output all-gather.
    process_group: ProcessGroup,
    /// Full (global) input feature count.
    in_features: usize,
    /// Full (global) output feature count.
    out_features: usize,
    /// When true, `forward` all-gathers so every rank sees the full output.
    gather_output: bool,
}
414
415impl ColumnParallelLinear {
416 pub fn new(
418 in_features: usize,
419 out_features: usize,
420 bias: bool,
421 process_group: ProcessGroup,
422 gather_output: bool,
423 ) -> Self {
424 let world_size = process_group.world_size();
425 let local_out_features = out_features / world_size;
426
427 let weight_data = Tensor::randn(&[local_out_features, in_features]);
428 let weight = Parameter::new(weight_data, true);
429
430 let bias = if bias {
431 let bias_data = Tensor::zeros(&[local_out_features]);
432 Some(Parameter::new(bias_data, true))
433 } else {
434 None
435 };
436
437 Self {
438 weight,
439 bias,
440 process_group,
441 in_features,
442 out_features,
443 gather_output,
444 }
445 }
446}
447
448impl Module for ColumnParallelLinear {
449 fn forward(&self, input: &Variable) -> Variable {
450 let weight_var = Variable::new(self.weight.data(), false);
452 let output = input.matmul(&weight_var.transpose(0, 1));
453
454 let output = if let Some(ref bias) = self.bias {
456 let bias_var = Variable::new(bias.data(), false);
457 output.add(&bias_var)
458 } else {
459 output
460 };
461
462 if self.gather_output {
464 let gathered = self.process_group.all_gather_tensor(&output.data());
465 Variable::new(gathered, output.requires_grad())
466 } else {
467 output
468 }
469 }
470
471 fn parameters(&self) -> Vec<Parameter> {
472 let mut params = vec![self.weight.clone()];
473 if let Some(ref bias) = self.bias {
474 params.push(bias.clone());
475 }
476 params
477 }
478}
479
/// Linear layer whose input features are split across ranks
/// (row parallelism): each rank holds a
/// `[out_features, in_features / world_size]` slice of the weight, and the
/// per-rank partial products are summed with an all-reduce in `forward`.
#[allow(dead_code)]
pub struct RowParallelLinear {
    /// Local weight shard, `[out_features, local_in_features]`.
    weight: Parameter,
    /// Bias; allocated on rank 0 only (see `new`).
    bias: Option<Parameter>,
    /// Group used for the all-reduce of partial outputs.
    process_group: ProcessGroup,
    /// Full (global) input feature count.
    in_features: usize,
    /// Full (global) output feature count.
    out_features: usize,
    /// When true, `forward` assumes the input is already split across ranks.
    input_is_parallel: bool,
}
499
500impl RowParallelLinear {
501 pub fn new(
503 in_features: usize,
504 out_features: usize,
505 bias: bool,
506 process_group: ProcessGroup,
507 input_is_parallel: bool,
508 ) -> Self {
509 let world_size = process_group.world_size();
510 let rank = process_group.rank();
511 let local_in_features = in_features / world_size;
512
513 let weight_data = Tensor::randn(&[out_features, local_in_features]);
514 let weight = Parameter::new(weight_data, true);
515
516 let bias = if bias && rank == 0 {
518 let bias_data = Tensor::zeros(&[out_features]);
519 Some(Parameter::new(bias_data, true))
520 } else {
521 None
522 };
523
524 Self {
525 weight,
526 bias,
527 process_group,
528 in_features,
529 out_features,
530 input_is_parallel,
531 }
532 }
533}
534
impl Module for RowParallelLinear {
    /// Computes `y = all_reduce_sum(x_local · W_localᵀ)`, then adds the bias
    /// on the rank that owns it.
    ///
    /// NOTE(review): the bias is added after the all-reduce and only rank 0
    /// holds it, so non-zero ranks return outputs without the bias when
    /// world size > 1 — confirm this is intended.
    fn forward(&self, input: &Variable) -> Variable {
        // If the input is not already split, slice out this rank's chunk
        // of the last (feature) dimension.
        let local_input = if self.input_is_parallel {
            input.clone()
        } else {
            let world_size = self.process_group.world_size();
            let rank = self.process_group.rank();
            let data = input.data();
            let shape = data.shape();
            let feature_dim = shape[shape.len() - 1];
            let local_features = feature_dim / world_size;
            let start = rank * local_features;
            let end = start + local_features;

            // NOTE(review): only 2-D inputs are sliced; higher-rank inputs
            // fall through whole and would mismatch the local weight shard
            // whenever world size > 1 — confirm callers only pass 2-D.
            let sliced = if shape.len() == 2 {
                data.slice(&[0..shape[0], start..end])
            } else {
                data.clone() };
            Variable::new(sliced, input.requires_grad())
        };

        let weight_var = Variable::new(self.weight.data(), false);
        let local_output = local_input.matmul(&weight_var.transpose(0, 1));

        // Sum the per-rank partial products into the full output.
        // NOTE(review): rebuilding a Variable from raw tensor data (and
        // creating `weight_var` with requires_grad = false) detaches the
        // result from the autograd graph — gradients will not flow back
        // through this layer. Confirm whether that is intended.
        let mut output_data = local_output.data().clone();
        self.process_group
            .all_reduce_tensor(&mut output_data, ReduceOp::Sum);
        let output = Variable::new(output_data, local_output.requires_grad());

        if let Some(ref bias) = self.bias {
            let bias_var = Variable::new(bias.data(), false);
            output.add(&bias_var)
        } else {
            output
        }
    }

    /// The local weight shard plus the bias (present on rank 0 only).
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }
}
587
// Unit tests. `ProcessGroup::mock()` provides a single-process stand-in
// group (presumably world_size == 1 — the expected shapes below only hold
// in that case).
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    /// The default strategy is FullShard.
    #[test]
    fn test_sharding_strategy_default() {
        assert_eq!(ShardingStrategy::default(), ShardingStrategy::FullShard);
    }

    /// Construction succeeds and picks up the default strategy.
    #[test]
    fn test_fsdp_creation() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let fsdp = FullyShardedDataParallel::new(model, pg);

        assert_eq!(fsdp.strategy(), ShardingStrategy::FullShard);
    }

    /// Forward works after an explicit gather and preserves output shape.
    #[test]
    fn test_fsdp_forward() {
        let model = Linear::new(4, 2);
        let pg = ProcessGroup::mock();
        let mut fsdp = FullyShardedDataParallel::new(model, pg);

        fsdp.gather_parameters();

        let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        let output = fsdp.forward(&input);

        assert_eq!(output.data().shape(), &[1, 2]);
    }

    /// Builder methods chain and the final strategy sticks.
    #[test]
    fn test_fsdp_builder() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();

        let fsdp = FullyShardedDataParallel::new(model, pg)
            .sharding_strategy(ShardingStrategy::ShardGradOp)
            .cpu_offload(CPUOffload::Params)
            .mixed_precision(true);

        assert_eq!(fsdp.strategy(), ShardingStrategy::ShardGradOp);
    }

    /// Memory estimate reports non-trivial totals.
    #[test]
    fn test_fsdp_memory_stats() {
        let model = Linear::new(100, 50);
        let pg = ProcessGroup::mock();
        let fsdp = FullyShardedDataParallel::new(model, pg);

        let stats = fsdp.memory_estimate();
        assert!(stats.total_params > 0);
        assert!(stats.total_memory_mb() > 0.0);
    }

    /// NoShard can be selected via the builder.
    #[test]
    fn test_fsdp_no_shard() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let fsdp =
            FullyShardedDataParallel::new(model, pg).sharding_strategy(ShardingStrategy::NoShard);

        assert_eq!(fsdp.strategy(), ShardingStrategy::NoShard);
    }

    /// Column-parallel forward yields the full output shape (mock group).
    #[test]
    fn test_column_parallel_linear() {
        let pg = ProcessGroup::mock();
        let layer = ColumnParallelLinear::new(8, 4, true, pg, false); let input = Variable::new(Tensor::randn(&[2, 8]), false);
        let output = layer.forward(&input);

        assert_eq!(output.data().shape(), &[2, 4]);
    }

    /// Row-parallel forward yields the full output shape (mock group).
    #[test]
    fn test_row_parallel_linear() {
        let pg = ProcessGroup::mock();
        let layer = RowParallelLinear::new(8, 4, true, pg, false);

        let input = Variable::new(Tensor::randn(&[2, 8]), false);
        let output = layer.forward(&input);

        assert_eq!(output.data().shape(), &[2, 4]);
    }
}