1use crate::backend::ReduceOp;
22use crate::process_group::ProcessGroup;
23use axonml_autograd::Variable;
24use axonml_nn::{Module, Parameter};
25use axonml_tensor::Tensor;
26
/// Controls how model state is partitioned across the ranks of a process
/// group (analogous to PyTorch FSDP / DeepSpeed ZeRO stages).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardingStrategy {
    /// Shard parameters (and, per `memory_estimate`, gradients and optimizer
    /// state) across ranks — the most memory-efficient option.
    FullShard,
    /// Keep full parameters on every rank; shard gradients/optimizer state.
    ShardGradOp,
    /// No sharding at all — plain data-parallel replication.
    NoShard,
    /// Hybrid scheme; in this file it is treated like `FullShard` for memory
    /// accounting and falls back to a full all-reduce for gradient sync.
    HybridShard,
}
43
44impl Default for ShardingStrategy {
45 fn default() -> Self {
46 Self::FullShard
47 }
48}
49
/// One rank's slice of a flattened parameter, together with the metadata
/// needed to reconstruct the full tensor in `gather_parameters`.
#[derive(Debug)]
#[allow(dead_code)]
struct ShardedParam {
    // This rank's flat shard; length = ceil(numel / world_size), zero-padded.
    local_shard: Tensor<f32>,
    // Shape of the original, unsharded parameter.
    original_shape: Vec<usize>,
    // Element count of the original parameter (used to trim padding on gather).
    numel: usize,
    // Number of zero elements appended so every rank's shard is equal length.
    padding: usize,
}
67
/// CPU offload policy. NOTE(review): in this file the value is only recorded
/// by the builder; no offloading is actually performed here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CPUOffload {
    /// No offloading.
    None,
    /// Offload parameters to host memory.
    Params,
    /// Offload parameters plus gradients/optimizer state.
    Full,
}
78
79impl Default for CPUOffload {
80 fn default() -> Self {
81 Self::None
82 }
83}
84
/// Wrapper that shards a module's parameters across a process group,
/// in the spirit of PyTorch FSDP / DeepSpeed ZeRO.
pub struct FullyShardedDataParallel<M: Module> {
    // The wrapped user module.
    module: M,
    // Communicator used for all collective operations.
    process_group: ProcessGroup,
    // How parameters/gradients/optimizer state are partitioned.
    sharding_strategy: ShardingStrategy,
    // CPU offload policy (recorded by the builder; not acted on in this file).
    cpu_offload: CPUOffload,
    // Per-parameter local shards, in `module.parameters()` order.
    sharded_params: Vec<ShardedParam>,
    // True while the full parameters are materialized on the module.
    is_gathered: bool,
    // Mixed-precision flag (recorded by the builder; not acted on in this file).
    mixed_precision: bool,
}
109
110impl<M: Module> FullyShardedDataParallel<M> {
111 pub fn new(module: M, process_group: ProcessGroup) -> Self {
113 let mut fsdp = Self {
114 module,
115 process_group,
116 sharding_strategy: ShardingStrategy::default(),
117 cpu_offload: CPUOffload::default(),
118 sharded_params: Vec::new(),
119 is_gathered: true,
120 mixed_precision: false,
121 };
122
123 fsdp.shard_parameters();
125 fsdp
126 }
127
128 pub fn sharding_strategy(mut self, strategy: ShardingStrategy) -> Self {
130 self.sharding_strategy = strategy;
131 self.shard_parameters();
132 self
133 }
134
135 pub fn cpu_offload(mut self, offload: CPUOffload) -> Self {
137 self.cpu_offload = offload;
138 self
139 }
140
141 pub fn mixed_precision(mut self, enabled: bool) -> Self {
143 self.mixed_precision = enabled;
144 self
145 }
146
147 pub fn module(&self) -> &M {
149 &self.module
150 }
151
152 pub fn module_mut(&mut self) -> &mut M {
154 &mut self.module
155 }
156
157 pub fn process_group(&self) -> &ProcessGroup {
159 &self.process_group
160 }
161
162 pub fn strategy(&self) -> ShardingStrategy {
164 self.sharding_strategy
165 }
166
167 fn shard_parameters(&mut self) {
169 if self.sharding_strategy == ShardingStrategy::NoShard {
170 return;
171 }
172
173 let world_size = self.process_group.world_size();
174 let rank = self.process_group.rank();
175
176 self.sharded_params.clear();
177
178 for param in self.module.parameters() {
179 let data = param.data();
180 let shape = data.shape().to_vec();
181 let numel = data.numel();
182
183 let shard_size = (numel + world_size - 1) / world_size;
185 let padding = shard_size * world_size - numel;
186
187 let flat_data = data.to_vec();
189 let start = rank * shard_size;
190 let end = ((rank + 1) * shard_size).min(flat_data.len());
191
192 let mut shard_data: Vec<f32> = if start < flat_data.len() {
193 flat_data[start..end].to_vec()
194 } else {
195 vec![0.0; shard_size]
196 };
197
198 while shard_data.len() < shard_size {
200 shard_data.push(0.0);
201 }
202
203 self.sharded_params.push(ShardedParam {
204 local_shard: Tensor::from_vec(shard_data, &[shard_size]).unwrap(),
205 original_shape: shape,
206 numel,
207 padding,
208 });
209 }
210
211 self.is_gathered = false;
212 }
213
214 pub fn gather_parameters(&mut self) {
216 if self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
217 return;
218 }
219
220 let _world_size = self.process_group.world_size();
221 let params = self.module.parameters();
222
223 for (param, sharded) in params.iter().zip(self.sharded_params.iter()) {
224 let gathered = self.process_group.all_gather_tensor(&sharded.local_shard);
226
227 let flat: Vec<f32> = gathered.to_vec().into_iter().take(sharded.numel).collect();
229 let restored = Tensor::from_vec(flat, &sharded.original_shape).unwrap();
230
231 param.update_data(restored);
232 }
233
234 self.is_gathered = true;
235 }
236
237 pub fn reshard_parameters(&mut self) {
239 if !self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
240 return;
241 }
242
243 self.shard_parameters();
244 }
245
246 pub fn sync_gradients(&self) {
248 match self.sharding_strategy {
249 ShardingStrategy::NoShard => {
250 for param in self.module.parameters() {
252 if let Some(grad) = param.grad() {
253 let mut grad_tensor = grad.clone();
254 self.process_group
255 .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
256 }
257 }
258 }
259 ShardingStrategy::ShardGradOp | ShardingStrategy::FullShard => {
260 for param in self.module.parameters() {
262 if let Some(grad) = param.grad() {
263 let _reduced = self
264 .process_group
265 .reduce_scatter_tensor(&grad, ReduceOp::Average);
266 }
268 }
269 }
270 ShardingStrategy::HybridShard => {
271 for param in self.module.parameters() {
273 if let Some(grad) = param.grad() {
274 let mut grad_tensor = grad.clone();
275 self.process_group
276 .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
277 }
278 }
279 }
280 }
281 }
282
283 pub fn clip_grad_norm(&self, max_norm: f32) -> f32 {
285 let mut total_norm_sq = 0.0f32;
286
287 for param in self.module.parameters() {
288 if let Some(grad) = param.grad() {
289 let grad_vec = grad.to_vec();
290 let norm_sq: f32 = grad_vec.iter().map(|x| x * x).sum();
291 total_norm_sq += norm_sq;
292 }
293 }
294
295 let mut norm_tensor = Tensor::from_vec(vec![total_norm_sq], &[1]).unwrap();
297 self.process_group
298 .all_reduce_tensor(&mut norm_tensor, ReduceOp::Sum);
299 let global_norm = norm_tensor.to_vec()[0].sqrt();
300
301 if global_norm > max_norm {
303 let clip_coef = max_norm / (global_norm + 1e-6);
304 for param in self.module.parameters() {
305 if let Some(grad) = param.grad() {
306 let clipped: Vec<f32> = grad.to_vec().iter().map(|x| x * clip_coef).collect();
307 let clipped_tensor = Tensor::from_vec(clipped, grad.shape()).unwrap();
308 param.variable().set_grad(clipped_tensor);
309 }
310 }
311 }
312
313 global_norm
314 }
315
316 pub fn memory_estimate(&self) -> FSDPMemoryStats {
318 let params = self.module.parameters();
319 let total_params: usize = params.iter().map(|p| p.numel()).sum();
320 let world_size = self.process_group.world_size();
321
322 let bytes_per_param = 4; let param_memory = total_params * bytes_per_param;
324
325 let (sharded_params, sharded_grads, sharded_optim) = match self.sharding_strategy {
326 ShardingStrategy::NoShard => (param_memory, param_memory, param_memory * 2),
327 ShardingStrategy::ShardGradOp => (
328 param_memory,
329 param_memory / world_size,
330 param_memory * 2 / world_size,
331 ),
332 ShardingStrategy::FullShard | ShardingStrategy::HybridShard => (
333 param_memory / world_size,
334 param_memory / world_size,
335 param_memory * 2 / world_size,
336 ),
337 };
338
339 FSDPMemoryStats {
340 total_params,
341 param_memory_bytes: sharded_params,
342 grad_memory_bytes: sharded_grads,
343 optim_memory_bytes: sharded_optim,
344 world_size,
345 }
346 }
347}
348
impl<M: Module> Module for FullyShardedDataParallel<M> {
    /// Delegates the forward pass to the wrapped module.
    ///
    /// NOTE(review): when sharding is active, callers are expected to call
    /// `gather_parameters()` first (as the tests in this file do); forward
    /// does not gather automatically — confirm this contract with callers.
    fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }

    /// Returns the wrapped module's parameters (full tensors if gathered).
    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }

    /// Puts the wrapped module into training mode.
    fn train(&mut self) {
        self.module.train();
    }

    /// Puts the wrapped module into evaluation mode.
    fn eval(&mut self) {
        self.module.eval();
    }

    /// Reports the wrapped module's training flag.
    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}
372
/// Per-rank memory estimate produced by
/// `FullyShardedDataParallel::memory_estimate`.
#[derive(Debug, Clone)]
pub struct FSDPMemoryStats {
    /// Total number of parameters in the wrapped module (unsharded count).
    pub total_params: usize,
    /// Estimated bytes for this rank's share of parameters.
    pub param_memory_bytes: usize,
    /// Estimated bytes for this rank's share of gradients.
    pub grad_memory_bytes: usize,
    /// Estimated bytes for this rank's share of optimizer state.
    pub optim_memory_bytes: usize,
    /// Number of ranks the estimate was computed against.
    pub world_size: usize,
}
387
388impl FSDPMemoryStats {
389 pub fn total_memory_mb(&self) -> f32 {
391 (self.param_memory_bytes + self.grad_memory_bytes + self.optim_memory_bytes) as f32
392 / (1024.0 * 1024.0)
393 }
394
395 pub fn memory_savings(&self) -> f32 {
397 if self.world_size > 1 {
398 1.0 - (1.0 / self.world_size as f32)
399 } else {
400 0.0
401 }
402 }
403}
404
/// Linear layer whose output (column) dimension is partitioned across ranks,
/// in the style of Megatron-LM column parallelism.
#[allow(dead_code)]
pub struct ColumnParallelLinear {
    // Local weight shard: `[out_features / world_size, in_features]`.
    weight: Parameter,
    // Local bias shard (same column split), if enabled.
    bias: Option<Parameter>,
    // Communicator used to all-gather outputs when requested.
    process_group: ProcessGroup,
    // Full (unsharded) input feature count.
    in_features: usize,
    // Full (unsharded) output feature count.
    out_features: usize,
    // If true, `forward` all-gathers the partial outputs into the full result.
    gather_output: bool,
}
428
429impl ColumnParallelLinear {
430 pub fn new(
432 in_features: usize,
433 out_features: usize,
434 bias: bool,
435 process_group: ProcessGroup,
436 gather_output: bool,
437 ) -> Self {
438 let world_size = process_group.world_size();
439 let local_out_features = out_features / world_size;
440
441 let weight_data = Tensor::randn(&[local_out_features, in_features]);
442 let weight = Parameter::new(weight_data, true);
443
444 let bias = if bias {
445 let bias_data = Tensor::zeros(&[local_out_features]);
446 Some(Parameter::new(bias_data, true))
447 } else {
448 None
449 };
450
451 Self {
452 weight,
453 bias,
454 process_group,
455 in_features,
456 out_features,
457 gather_output,
458 }
459 }
460}
461
impl Module for ColumnParallelLinear {
    /// Computes `input @ weight^T (+ bias)` using this rank's column shard,
    /// optionally all-gathering the partial outputs into the full result.
    ///
    /// NOTE(review): the weight/bias are wrapped in fresh Variables with
    /// `requires_grad = false`, so gradients do not flow into this layer's
    /// parameters through this path — confirm whether that is intended.
    fn forward(&self, input: &Variable) -> Variable {
        let weight_var = Variable::new(self.weight.data(), false);
        let output = input.matmul(&weight_var.transpose(0, 1));

        // Bias is sharded the same way as the output columns.
        let output = if let Some(ref bias) = self.bias {
            let bias_var = Variable::new(bias.data(), false);
            output.add(&bias_var)
        } else {
            output
        };

        if self.gather_output {
            // NOTE(review): all_gather on a 2-D partial output — verify the
            // gathered layout matches `[batch, out_features]` for world > 1.
            let gathered = self.process_group.all_gather_tensor(&output.data());
            Variable::new(gathered, output.requires_grad())
        } else {
            output
        }
    }

    /// Returns this rank's local weight (and bias shard, if present).
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }
}
493
/// Linear layer whose input (row) dimension is partitioned across ranks,
/// in the style of Megatron-LM row parallelism.
#[allow(dead_code)]
pub struct RowParallelLinear {
    // Local weight shard: `[out_features, in_features / world_size]`.
    weight: Parameter,
    // Full-width bias; only allocated on rank 0 (see `new`).
    bias: Option<Parameter>,
    // Communicator used to all-reduce the partial outputs.
    process_group: ProcessGroup,
    // Full (unsharded) input feature count.
    in_features: usize,
    // Full (unsharded) output feature count.
    out_features: usize,
    // If true, the caller already supplies this rank's input shard.
    input_is_parallel: bool,
}
513
514impl RowParallelLinear {
515 pub fn new(
517 in_features: usize,
518 out_features: usize,
519 bias: bool,
520 process_group: ProcessGroup,
521 input_is_parallel: bool,
522 ) -> Self {
523 let world_size = process_group.world_size();
524 let rank = process_group.rank();
525 let local_in_features = in_features / world_size;
526
527 let weight_data = Tensor::randn(&[out_features, local_in_features]);
528 let weight = Parameter::new(weight_data, true);
529
530 let bias = if bias && rank == 0 {
532 let bias_data = Tensor::zeros(&[out_features]);
533 Some(Parameter::new(bias_data, true))
534 } else {
535 None
536 };
537
538 Self {
539 weight,
540 bias,
541 process_group,
542 in_features,
543 out_features,
544 input_is_parallel,
545 }
546 }
547}
548
impl Module for RowParallelLinear {
    /// Computes a partial matmul against this rank's input-feature shard and
    /// all-reduces (sums) the partial products into the full output.
    fn forward(&self, input: &Variable) -> Variable {
        // Obtain this rank's slice of the input features.
        let local_input = if self.input_is_parallel {
            // Caller already provides the local shard (e.g. the output of a
            // non-gathering ColumnParallelLinear).
            input.clone()
        } else {
            let world_size = self.process_group.world_size();
            let rank = self.process_group.rank();
            let data = input.data();
            let shape = data.shape();
            let feature_dim = shape[shape.len() - 1];
            let local_features = feature_dim / world_size;
            let start = rank * local_features;
            let end = start + local_features;

            let sliced = if shape.len() == 2 {
                data.slice(&[0..shape[0], start..end])
            } else {
                // NOTE(review): non-2D inputs are not sliced, so the full
                // feature dim is multiplied against the local weight shard —
                // shapes only line up when world_size == 1. Confirm intent.
                data.clone() };
            Variable::new(sliced, input.requires_grad())
        };

        // Partial product with the local [out_features, local_in] weight.
        // NOTE(review): wrapping with `requires_grad = false` means no
        // gradient flows to the weight through this path — confirm intent.
        let weight_var = Variable::new(self.weight.data(), false);
        let local_output = local_input.matmul(&weight_var.transpose(0, 1));

        // Sum the partial outputs across all ranks.
        let mut output_data = local_output.data().clone();
        self.process_group
            .all_reduce_tensor(&mut output_data, ReduceOp::Sum);
        let output = Variable::new(output_data, local_output.requires_grad());

        // NOTE(review): the bias exists only on rank 0 (see `new`) and is
        // added after the all-reduce, so only rank 0's local result includes
        // it — verify this cross-rank divergence is intended.
        if let Some(ref bias) = self.bias {
            let bias_var = Variable::new(bias.data(), false);
            output.add(&bias_var)
        } else {
            output
        }
    }

    /// Returns this rank's local weight (and, on rank 0, the bias).
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }
}
601
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    #[test]
    fn test_sharding_strategy_default() {
        // FullShard is the documented default strategy.
        assert_eq!(ShardingStrategy::default(), ShardingStrategy::FullShard);
    }

    #[test]
    fn test_fsdp_creation() {
        let pg = ProcessGroup::mock();
        let wrapped = FullyShardedDataParallel::new(Linear::new(10, 5), pg);
        assert_eq!(wrapped.strategy(), ShardingStrategy::FullShard);
    }

    #[test]
    fn test_fsdp_forward() {
        let pg = ProcessGroup::mock();
        let mut wrapped = FullyShardedDataParallel::new(Linear::new(4, 2), pg);

        // Full parameters must be re-materialized before running forward.
        wrapped.gather_parameters();

        let x = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        let y = wrapped.forward(&x);
        assert_eq!(y.data().shape(), &[1, 2]);
    }

    #[test]
    fn test_fsdp_builder() {
        let pg = ProcessGroup::mock();
        let wrapped = FullyShardedDataParallel::new(Linear::new(10, 5), pg)
            .sharding_strategy(ShardingStrategy::ShardGradOp)
            .cpu_offload(CPUOffload::Params)
            .mixed_precision(true);
        assert_eq!(wrapped.strategy(), ShardingStrategy::ShardGradOp);
    }

    #[test]
    fn test_fsdp_memory_stats() {
        let pg = ProcessGroup::mock();
        let wrapped = FullyShardedDataParallel::new(Linear::new(100, 50), pg);

        let stats = wrapped.memory_estimate();
        assert!(stats.total_params > 0);
        assert!(stats.total_memory_mb() > 0.0);
    }

    #[test]
    fn test_fsdp_no_shard() {
        let pg = ProcessGroup::mock();
        let wrapped = FullyShardedDataParallel::new(Linear::new(10, 5), pg)
            .sharding_strategy(ShardingStrategy::NoShard);
        assert_eq!(wrapped.strategy(), ShardingStrategy::NoShard);
    }

    #[test]
    fn test_column_parallel_linear() {
        let pg = ProcessGroup::mock();
        let layer = ColumnParallelLinear::new(8, 4, true, pg, false);

        let x = Variable::new(Tensor::randn(&[2, 8]), false);
        let y = layer.forward(&x);
        assert_eq!(y.data().shape(), &[2, 4]);
    }

    #[test]
    fn test_row_parallel_linear() {
        let pg = ProcessGroup::mock();
        let layer = RowParallelLinear::new(8, 4, true, pg, false);

        let x = Variable::new(Tensor::randn(&[2, 8]), false);
        let y = layer.forward(&x);
        assert_eq!(y.data().shape(), &[2, 4]);
    }
}