use crate::backend::ReduceOp;
use crate::process_group::ProcessGroup;
use axonml_autograd::Variable;
use axonml_nn::{Module, Parameter};
use axonml_tensor::Tensor;

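/// Controls how parameters, gradients, and optimizer state are partitioned
/// across the ranks of the process group.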
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardingStrategy {
    /// Shard parameters, gradients, and optimizer state across all ranks.
    FullShard,
    /// Keep full parameters on every rank; shard only gradients and optimizer state.
    ShardGradOp,
    /// No sharding; every rank holds a full replica (plain data parallelism).
    NoShard,
    /// Combines sharding within a subgroup with replication across subgroups.
    HybridShard,
}

impl Default for ShardingStrategy {
    fn default() -> Self {
        Self::FullShard
    }
}

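/// Bookkeeping for one parameter that has been flattened and sharded across ranks.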
#[derive(Debug)]
#[allow(dead_code)]
struct ShardedParam {
    /// This rank's flattened slice of the parameter.
    local_shard: Tensor<f32>,
    /// Shape of the full (unsharded) parameter.
    original_shape: Vec<usize>,
    /// Number of elements in the full parameter.
    numel: usize,
    /// Zero elements appended so every rank holds an equally sized shard.
    padding: usize,
}

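/// Policy for offloading sharded state to host (CPU) memory.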
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CPUOffload {
    None,
    Params,
    Full,
}

impl Default for CPUOffload {
    fn default() -> Self {
        Self::None
    }
}

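/// Fully Sharded Data Parallel wrapper: flattens each parameter of the wrapped
/// module and keeps only the local rank's shard. Full parameters are restored
/// with [`FullyShardedDataParallel::gather_parameters`] before running the
/// wrapped module.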
pub struct FullyShardedDataParallel<M: Module> {
    /// The wrapped module.
    module: M,
    /// Process group used for collective communication.
    process_group: ProcessGroup,
    /// How parameters, gradients, and optimizer state are sharded.
    sharding_strategy: ShardingStrategy,
    /// CPU offloading policy.
    cpu_offload: CPUOffload,
    /// Per-parameter shard bookkeeping.
    sharded_params: Vec<ShardedParam>,
    /// Whether full parameters are currently materialized on this rank.
    is_gathered: bool,
    /// Whether mixed-precision execution is requested.
    mixed_precision: bool,
}

impl<M: Module> FullyShardedDataParallel<M> {
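    /// Wraps `module` and immediately shards its parameters across
    /// `process_group` using the default [`ShardingStrategy::FullShard`].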
    pub fn new(module: M, process_group: ProcessGroup) -> Self {
        let mut fsdp = Self {
            module,
            process_group,
            sharding_strategy: ShardingStrategy::default(),
            cpu_offload: CPUOffload::default(),
            sharded_params: Vec::new(),
            is_gathered: true,
            mixed_precision: false,
        };

        fsdp.shard_parameters();
        fsdp
    }

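    /// Sets the sharding strategy and immediately re-shards the wrapped
    /// module's parameters under the new strategy (builder-style).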
    pub fn sharding_strategy(mut self, strategy: ShardingStrategy) -> Self {
        self.sharding_strategy = strategy;
        self.shard_parameters();
        self
    }

    pub fn cpu_offload(mut self, offload: CPUOffload) -> Self {
        self.cpu_offload = offload;
        self
    }

    pub fn mixed_precision(mut self, enabled: bool) -> Self {
        self.mixed_precision = enabled;
        self
    }

    pub fn module(&self) -> &M {
        &self.module
    }

    pub fn module_mut(&mut self) -> &mut M {
        &mut self.module
    }

    pub fn process_group(&self) -> &ProcessGroup {
        &self.process_group
    }

    pub fn strategy(&self) -> ShardingStrategy {
        self.sharding_strategy
    }

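    /// Flattens every parameter, splits it into `world_size` equally sized
    /// chunks (zero-padding the last chunk), and keeps only this rank's chunk.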
    fn shard_parameters(&mut self) {
        if self.sharding_strategy == ShardingStrategy::NoShard {
            return;
        }

        let world_size = self.process_group.world_size();
        let rank = self.process_group.rank();

        self.sharded_params.clear();

        for param in self.module.parameters() {
            let data = param.data();
            let shape = data.shape().to_vec();
            let numel = data.numel();

            // Ceil-divide so every rank gets a shard of the same size.
            let shard_size = (numel + world_size - 1) / world_size;
            let padding = shard_size * world_size - numel;

            let flat_data = data.to_vec();
            let start = rank * shard_size;
            let end = ((rank + 1) * shard_size).min(flat_data.len());

            let mut shard_data: Vec<f32> = if start < flat_data.len() {
                flat_data[start..end].to_vec()
            } else {
                vec![0.0; shard_size]
            };

            // Zero-pad the final shard so all shards have identical length.
            while shard_data.len() < shard_size {
                shard_data.push(0.0);
            }

            self.sharded_params.push(ShardedParam {
                local_shard: Tensor::from_vec(shard_data, &[shard_size]).unwrap(),
                original_shape: shape,
                numel,
                padding,
            });
        }

        self.is_gathered = false;
    }

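    /// All-gathers every local shard and restores the full parameter tensors
    /// on the wrapped module. No-op when already gathered or when not sharding.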
    pub fn gather_parameters(&mut self) {
        if self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
            return;
        }

        let _world_size = self.process_group.world_size();
        let params = self.module.parameters();

        for (param, sharded) in params.iter().zip(self.sharded_params.iter()) {
            let gathered = self.process_group.all_gather_tensor(&sharded.local_shard);

            // Drop the zero padding and restore the original shape.
            let flat: Vec<f32> = gathered.to_vec().into_iter().take(sharded.numel).collect();
            let restored = Tensor::from_vec(flat, &sharded.original_shape).unwrap();

            param.update_data(restored);
        }

        self.is_gathered = true;
    }

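    /// Discards the gathered full parameters and returns to holding only the
    /// local shards.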
    pub fn reshard_parameters(&mut self) {
        if !self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
            return;
        }

        self.shard_parameters();
    }

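    /// Synchronizes gradients across ranks after a backward pass, using the
    /// collective appropriate for the current sharding strategy.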
    pub fn sync_gradients(&self) {
        match self.sharding_strategy {
            ShardingStrategy::NoShard => {
                // Replicated parameters: average gradients with an all-reduce.
                for param in self.module.parameters() {
                    if let Some(grad) = param.grad() {
                        let mut grad_tensor = grad.clone();
                        self.process_group.all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
                    }
                }
            }
            ShardingStrategy::ShardGradOp | ShardingStrategy::FullShard => {
                // Sharded gradients: reduce-scatter so each rank keeps only its shard.
                for param in self.module.parameters() {
                    if let Some(grad) = param.grad() {
                        let _reduced = self.process_group.reduce_scatter_tensor(&grad, ReduceOp::Average);
                    }
                }
            }
            ShardingStrategy::HybridShard => {
                // Hybrid sharding: gradients are averaged with an all-reduce here.
                for param in self.module.parameters() {
                    if let Some(grad) = param.grad() {
                        let mut grad_tensor = grad.clone();
                        self.process_group.all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
                    }
                }
            }
        }
    }

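    /// Clips gradients by global L2 norm across all ranks and returns the
    /// (pre-clipping) global norm.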
    pub fn clip_grad_norm(&self, max_norm: f32) -> f32 {
        let mut total_norm_sq = 0.0f32;

        for param in self.module.parameters() {
            if let Some(grad) = param.grad() {
                let grad_vec = grad.to_vec();
                let norm_sq: f32 = grad_vec.iter().map(|x| x * x).sum();
                total_norm_sq += norm_sq;
            }
        }

        // Sum the squared norms across ranks, then take the square root.
        let mut norm_tensor = Tensor::from_vec(vec![total_norm_sq], &[1]).unwrap();
        self.process_group.all_reduce_tensor(&mut norm_tensor, ReduceOp::Sum);
        let global_norm = norm_tensor.to_vec()[0].sqrt();

        if global_norm > max_norm {
            let clip_coef = max_norm / (global_norm + 1e-6);
            for param in self.module.parameters() {
                if let Some(grad) = param.grad() {
                    let clipped: Vec<f32> = grad.to_vec().iter().map(|x| x * clip_coef).collect();
                    let clipped_tensor = Tensor::from_vec(clipped, grad.shape()).unwrap();
                    param.variable().set_grad(clipped_tensor);
                }
            }
        }

        global_norm
    }

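    /// Estimates per-rank memory use for parameters, gradients, and optimizer
    /// state under the current sharding strategy (f32 storage; optimizer state
    /// counted as twice the parameter memory).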
    pub fn memory_estimate(&self) -> FSDPMemoryStats {
        let params = self.module.parameters();
        let total_params: usize = params.iter().map(|p| p.numel()).sum();
        let world_size = self.process_group.world_size();

        let bytes_per_param = 4; // f32
        let param_memory = total_params * bytes_per_param;

        let (sharded_params, sharded_grads, sharded_optim) = match self.sharding_strategy {
            ShardingStrategy::NoShard => (param_memory, param_memory, param_memory * 2),
            ShardingStrategy::ShardGradOp => (
                param_memory,
                param_memory / world_size,
                param_memory * 2 / world_size,
            ),
            ShardingStrategy::FullShard | ShardingStrategy::HybridShard => (
                param_memory / world_size,
                param_memory / world_size,
                param_memory * 2 / world_size,
            ),
        };

        FSDPMemoryStats {
            total_params,
            param_memory_bytes: sharded_params,
            grad_memory_bytes: sharded_grads,
            optim_memory_bytes: sharded_optim,
            world_size,
        }
    }
}

impl<M: Module> Module for FullyShardedDataParallel<M> {
    fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }

    fn train(&mut self) {
        self.module.train();
    }

    fn eval(&mut self) {
        self.module.eval();
    }

    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}

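/// Per-rank memory statistics produced by
/// [`FullyShardedDataParallel::memory_estimate`].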
#[derive(Debug, Clone)]
pub struct FSDPMemoryStats {
    /// Total number of parameters in the wrapped module.
    pub total_params: usize,
    /// Bytes of parameter storage held on this rank.
    pub param_memory_bytes: usize,
    /// Bytes of gradient storage held on this rank.
    pub grad_memory_bytes: usize,
    /// Bytes of optimizer-state storage held on this rank.
    pub optim_memory_bytes: usize,
    /// Number of ranks in the process group.
    pub world_size: usize,
}

impl FSDPMemoryStats {
    /// Total per-rank memory (parameters + gradients + optimizer state) in MiB.
    pub fn total_memory_mb(&self) -> f32 {
        (self.param_memory_bytes + self.grad_memory_bytes + self.optim_memory_bytes) as f32
            / (1024.0 * 1024.0)
    }

    /// Fraction of memory saved relative to keeping a full replica on every rank.
    pub fn memory_savings(&self) -> f32 {
        if self.world_size > 1 {
            1.0 - (1.0 / self.world_size as f32)
        } else {
            0.0
        }
    }
}

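/// Linear layer with its weight split along the output-feature (column)
/// dimension: each rank holds `out_features / world_size` rows of the weight
/// and produces the corresponding slice of the output, which can optionally be
/// all-gathered.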
#[allow(dead_code)]
pub struct ColumnParallelLinear {
    /// Local shard of the weight: `[out_features / world_size, in_features]`.
    weight: Parameter,
    bias: Option<Parameter>,
    process_group: ProcessGroup,
    in_features: usize,
    out_features: usize,
    /// Whether to all-gather the per-rank output slices in `forward`.
    gather_output: bool,
}

impl ColumnParallelLinear {
    pub fn new(
        in_features: usize,
        out_features: usize,
        bias: bool,
        process_group: ProcessGroup,
        gather_output: bool,
    ) -> Self {
        let world_size = process_group.world_size();
        let local_out_features = out_features / world_size;

        let weight_data = Tensor::randn(&[local_out_features, in_features]);
        let weight = Parameter::new(weight_data, true);

        let bias = if bias {
            let bias_data = Tensor::zeros(&[local_out_features]);
            Some(Parameter::new(bias_data, true))
        } else {
            None
        };

        Self {
            weight,
            bias,
            process_group,
            in_features,
            out_features,
            gather_output,
        }
    }
}

impl Module for ColumnParallelLinear {
    fn forward(&self, input: &Variable) -> Variable {
        // Local matmul produces this rank's slice of the output features.
        let weight_var = Variable::new(self.weight.data(), false);
        let output = input.matmul(&weight_var.transpose(0, 1));

        let output = if let Some(ref bias) = self.bias {
            let bias_var = Variable::new(bias.data(), false);
            output.add(&bias_var)
        } else {
            output
        };

        if self.gather_output {
            // Combine the per-rank output slices into the full output.
            let gathered = self.process_group.all_gather_tensor(&output.data());
            Variable::new(gathered, output.requires_grad())
        } else {
            output
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }
}

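/// Linear layer with its weight split along the input-feature (row) dimension:
/// each rank multiplies its slice of the input by its weight shard, and the
/// partial outputs are summed with an all-reduce.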
#[allow(dead_code)]
pub struct RowParallelLinear {
    /// Local shard of the weight: `[out_features, in_features / world_size]`.
    weight: Parameter,
    bias: Option<Parameter>,
    process_group: ProcessGroup,
    in_features: usize,
    out_features: usize,
    /// Whether the input is already split along the feature dimension.
    input_is_parallel: bool,
}

impl RowParallelLinear {
    pub fn new(
        in_features: usize,
        out_features: usize,
        bias: bool,
        process_group: ProcessGroup,
        input_is_parallel: bool,
    ) -> Self {
        let world_size = process_group.world_size();
        let rank = process_group.rank();
        let local_in_features = in_features / world_size;

        let weight_data = Tensor::randn(&[out_features, local_in_features]);
        let weight = Parameter::new(weight_data, true);

        // Only rank 0 allocates the bias.
        let bias = if bias && rank == 0 {
            let bias_data = Tensor::zeros(&[out_features]);
            Some(Parameter::new(bias_data, true))
        } else {
            None
        };

        Self {
            weight,
            bias,
            process_group,
            in_features,
            out_features,
            input_is_parallel,
        }
    }
}

impl Module for RowParallelLinear {
    fn forward(&self, input: &Variable) -> Variable {
        // Take this rank's slice of the input features unless the input is
        // already partitioned.
        let local_input = if self.input_is_parallel {
            input.clone()
        } else {
            let world_size = self.process_group.world_size();
            let rank = self.process_group.rank();
            let data = input.data();
            let shape = data.shape();
            let feature_dim = shape[shape.len() - 1];
            let local_features = feature_dim / world_size;
            let start = rank * local_features;
            let end = start + local_features;

            let sliced = if shape.len() == 2 {
                data.slice(&[0..shape[0], start..end])
            } else {
                data.clone()
            };
            Variable::new(sliced, input.requires_grad())
        };

        let weight_var = Variable::new(self.weight.data(), false);
        let local_output = local_input.matmul(&weight_var.transpose(0, 1));

        // Sum the partial outputs from all ranks.
        let mut output_data = local_output.data().clone();
        self.process_group.all_reduce_tensor(&mut output_data, ReduceOp::Sum);
        let output = Variable::new(output_data, local_output.requires_grad());

        if let Some(ref bias) = self.bias {
            let bias_var = Variable::new(bias.data(), false);
            output.add(&bias_var)
        } else {
            output
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    #[test]
    fn test_sharding_strategy_default() {
        assert_eq!(ShardingStrategy::default(), ShardingStrategy::FullShard);
    }

    #[test]
    fn test_fsdp_creation() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let fsdp = FullyShardedDataParallel::new(model, pg);

        assert_eq!(fsdp.strategy(), ShardingStrategy::FullShard);
    }

    #[test]
    fn test_fsdp_forward() {
        let model = Linear::new(4, 2);
        let pg = ProcessGroup::mock();
        let mut fsdp = FullyShardedDataParallel::new(model, pg);

        fsdp.gather_parameters();

        let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        let output = fsdp.forward(&input);

        assert_eq!(output.data().shape(), &[1, 2]);
    }

    #[test]
    fn test_fsdp_builder() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();

        let fsdp = FullyShardedDataParallel::new(model, pg)
            .sharding_strategy(ShardingStrategy::ShardGradOp)
            .cpu_offload(CPUOffload::Params)
            .mixed_precision(true);

        assert_eq!(fsdp.strategy(), ShardingStrategy::ShardGradOp);
    }

    #[test]
    fn test_fsdp_memory_stats() {
        let model = Linear::new(100, 50);
        let pg = ProcessGroup::mock();
        let fsdp = FullyShardedDataParallel::new(model, pg);

        let stats = fsdp.memory_estimate();
        assert!(stats.total_params > 0);
        assert!(stats.total_memory_mb() > 0.0);
    }

    #[test]
    fn test_fsdp_no_shard() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let fsdp = FullyShardedDataParallel::new(model, pg)
            .sharding_strategy(ShardingStrategy::NoShard);

        assert_eq!(fsdp.strategy(), ShardingStrategy::NoShard);
    }

    #[test]
    fn test_column_parallel_linear() {
        let pg = ProcessGroup::mock();
        let layer = ColumnParallelLinear::new(8, 4, true, pg, false);

        let input = Variable::new(Tensor::randn(&[2, 8]), false);
        let output = layer.forward(&input);

        assert_eq!(output.data().shape(), &[2, 4]);
    }

    #[test]
    fn test_row_parallel_linear() {
        let pg = ProcessGroup::mock();
        let layer = RowParallelLinear::new(8, 4, true, pg, false);

        let input = Variable::new(Tensor::randn(&[2, 8]), false);
        let output = layer.forward(&input);

        assert_eq!(output.data().shape(), &[2, 4]);
    }
}