1use crate::backend::ReduceOp;
18use crate::process_group::ProcessGroup;
19use axonml_autograd::Variable;
20use axonml_nn::{Module, Parameter};
21use axonml_tensor::Tensor;
22
/// Strategy controlling how parameters, gradients, and optimizer state are
/// partitioned across the ranks of the process group.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ShardingStrategy {
    /// Shard parameters across ranks; gradients are reduce-scattered
    /// (see `sync_gradients`). This is the default.
    #[default]
    FullShard,
    /// Gradients are reduce-scattered like `FullShard`; memory estimates
    /// (see `memory_estimate`) assume full parameters stay on every rank.
    ShardGradOp,
    /// No sharding: plain data parallelism, gradients are all-reduced.
    NoShard,
    /// Hybrid sharding. NOTE(review): gradient sync currently treats this
    /// identically to `NoShard` (all-reduce) — confirm intended.
    HybridShard,
}
40
/// A rank-local slice of one flattened module parameter.
#[derive(Debug)]
struct ShardedParam {
    /// This rank's contiguous, zero-padded slice of the flattened parameter.
    local_shard: Tensor<f32>,
    /// Shape of the full (unsharded) parameter, used to restore it on gather.
    original_shape: Vec<usize>,
    /// Element count of the full parameter (excludes padding).
    numel: usize,
    /// Number of trailing zero elements appended so every rank's shard has
    /// the same length.
    pub padding: usize,
}
58
/// Controls which FSDP state is offloaded to host (CPU) memory.
///
/// NOTE(review): this setting is stored by the wrapper but not acted upon
/// anywhere in this file — verify against the rest of the crate.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CPUOffload {
    /// Keep everything on the compute device (default).
    #[default]
    None,
    /// Presumably: offload parameters only.
    Params,
    /// Presumably: offload parameters, gradients, and optimizer state.
    Full,
}
70
/// ZeRO-style fully sharded data parallel wrapper around a module.
///
/// Partitions parameter storage across the ranks of a process group and
/// provides gather/reshard and gradient-synchronization helpers.
pub struct FullyShardedDataParallel<M: Module> {
    /// The wrapped module; its own parameter tensors are never resized here.
    module: M,
    /// Communicator used for all collective operations.
    process_group: ProcessGroup,
    /// How parameters/gradients/optimizer state are partitioned.
    sharding_strategy: ShardingStrategy,
    /// CPU offload configuration (stored; not yet acted upon in this file).
    cpu_offload: CPUOffload,
    /// This rank's local shard of each parameter, in `parameters()` order.
    sharded_params: Vec<ShardedParam>,
    /// True when the module currently holds full (gathered) parameters.
    is_gathered: bool,
    /// Mixed-precision flag (stored; not yet acted upon in this file).
    mixed_precision: bool,
}
95
96impl<M: Module> FullyShardedDataParallel<M> {
97 pub fn new(module: M, process_group: ProcessGroup) -> Self {
99 let mut fsdp = Self {
100 module,
101 process_group,
102 sharding_strategy: ShardingStrategy::default(),
103 cpu_offload: CPUOffload::default(),
104 sharded_params: Vec::new(),
105 is_gathered: true,
106 mixed_precision: false,
107 };
108
109 fsdp.shard_parameters();
111 fsdp
112 }
113
114 pub fn sharding_strategy(mut self, strategy: ShardingStrategy) -> Self {
116 self.sharding_strategy = strategy;
117 self.shard_parameters();
118 self
119 }
120
121 pub fn cpu_offload(mut self, offload: CPUOffload) -> Self {
123 self.cpu_offload = offload;
124 self
125 }
126
127 pub fn mixed_precision(mut self, enabled: bool) -> Self {
129 self.mixed_precision = enabled;
130 self
131 }
132
133 pub fn module(&self) -> &M {
135 &self.module
136 }
137
138 pub fn module_mut(&mut self) -> &mut M {
140 &mut self.module
141 }
142
143 pub fn process_group(&self) -> &ProcessGroup {
145 &self.process_group
146 }
147
148 pub fn strategy(&self) -> ShardingStrategy {
150 self.sharding_strategy
151 }
152
153 fn shard_parameters(&mut self) {
155 if self.sharding_strategy == ShardingStrategy::NoShard {
156 return;
157 }
158
159 let world_size = self.process_group.world_size();
160 let rank = self.process_group.rank();
161
162 self.sharded_params.clear();
163
164 for param in self.module.parameters() {
165 let data = param.data();
166 let shape = data.shape().to_vec();
167 let numel = data.numel();
168
169 let shard_size = numel.div_ceil(world_size);
171 let padding = shard_size * world_size - numel;
172
173 let flat_data = data.to_vec();
175 let start = rank * shard_size;
176 let end = ((rank + 1) * shard_size).min(flat_data.len());
177
178 let mut shard_data: Vec<f32> = if start < flat_data.len() {
179 flat_data[start..end].to_vec()
180 } else {
181 vec![0.0; shard_size]
182 };
183
184 while shard_data.len() < shard_size {
186 shard_data.push(0.0);
187 }
188
189 self.sharded_params.push(ShardedParam {
190 local_shard: Tensor::from_vec(shard_data, &[shard_size]).unwrap(),
191 original_shape: shape,
192 numel,
193 padding,
194 });
195 }
196
197 self.is_gathered = false;
198 }
199
200 pub fn gather_parameters(&mut self) {
202 if self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
203 return;
204 }
205
206 let params = self.module.parameters();
207
208 for (param, sharded) in params.iter().zip(self.sharded_params.iter()) {
209 let gathered = self.process_group.all_gather_tensor(&sharded.local_shard);
211
212 let flat: Vec<f32> = gathered.to_vec().into_iter().take(sharded.numel).collect();
214 let restored = Tensor::from_vec(flat, &sharded.original_shape).unwrap();
215
216 param.update_data(restored);
217 }
218
219 self.is_gathered = true;
220 }
221
222 pub fn reshard_parameters(&mut self) {
224 if !self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
225 return;
226 }
227
228 self.shard_parameters();
229 }
230
231 pub fn sync_gradients(&self) {
237 match self.sharding_strategy {
238 ShardingStrategy::NoShard => {
239 for param in self.module.parameters() {
241 if let Some(grad) = param.grad() {
242 let mut grad_tensor = grad.clone();
243 self.process_group
244 .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
245 param.set_grad(grad_tensor);
246 }
247 }
248 }
249 ShardingStrategy::ShardGradOp | ShardingStrategy::FullShard => {
250 for param in self.module.parameters() {
252 if let Some(grad) = param.grad() {
253 let reduced = self
254 .process_group
255 .reduce_scatter_tensor(&grad, ReduceOp::Average);
256 param.set_grad(reduced);
258 }
259 }
260 }
261 ShardingStrategy::HybridShard => {
262 for param in self.module.parameters() {
264 if let Some(grad) = param.grad() {
265 let mut grad_tensor = grad.clone();
266 self.process_group
267 .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
268 param.set_grad(grad_tensor);
269 }
270 }
271 }
272 }
273 }
274
275 pub fn clip_grad_norm(&self, max_norm: f32) -> f32 {
277 let mut total_norm_sq = 0.0f32;
278
279 for param in self.module.parameters() {
280 if let Some(grad) = param.grad() {
281 let grad_vec = grad.to_vec();
282 let norm_sq: f32 = grad_vec.iter().map(|x| x * x).sum();
283 total_norm_sq += norm_sq;
284 }
285 }
286
287 let mut norm_tensor = Tensor::from_vec(vec![total_norm_sq], &[1]).unwrap();
289 self.process_group
290 .all_reduce_tensor(&mut norm_tensor, ReduceOp::Sum);
291 let global_norm = norm_tensor.to_vec()[0].sqrt();
292
293 if global_norm > max_norm {
295 let clip_coef = max_norm / (global_norm + 1e-6);
296 for param in self.module.parameters() {
297 if let Some(grad) = param.grad() {
298 let clipped: Vec<f32> = grad.to_vec().iter().map(|x| x * clip_coef).collect();
299 let clipped_tensor = Tensor::from_vec(clipped, grad.shape()).unwrap();
300 param.variable().set_grad(clipped_tensor);
301 }
302 }
303 }
304
305 global_norm
306 }
307
308 pub fn memory_estimate(&self) -> FSDPMemoryStats {
310 let params = self.module.parameters();
311 let total_params: usize = params.iter().map(|p| p.numel()).sum();
312 let world_size = self.process_group.world_size();
313
314 let bytes_per_param = 4; let param_memory = total_params * bytes_per_param;
316
317 let (sharded_params, sharded_grads, sharded_optim) = match self.sharding_strategy {
318 ShardingStrategy::NoShard => (param_memory, param_memory, param_memory * 2),
319 ShardingStrategy::ShardGradOp => (
320 param_memory,
321 param_memory / world_size,
322 param_memory * 2 / world_size,
323 ),
324 ShardingStrategy::FullShard | ShardingStrategy::HybridShard => (
325 param_memory / world_size,
326 param_memory / world_size,
327 param_memory * 2 / world_size,
328 ),
329 };
330
331 FSDPMemoryStats {
332 total_params,
333 param_memory_bytes: sharded_params,
334 grad_memory_bytes: sharded_grads,
335 optim_memory_bytes: sharded_optim,
336 world_size,
337 }
338 }
339}
340
impl<M: Module> Module for FullyShardedDataParallel<M> {
    /// Delegates to the wrapped module.
    ///
    /// NOTE(review): this does not gather sharded parameters first — callers
    /// must invoke `gather_parameters` before `forward` while sharding is
    /// active (see `test_fsdp_forward`).
    fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }

    /// The wrapped module's parameters (full tensors, not the local shards).
    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }

    /// Puts the wrapped module into training mode.
    fn train(&mut self) {
        self.module.train();
    }

    /// Puts the wrapped module into evaluation mode.
    fn eval(&mut self) {
        self.module.eval();
    }

    /// Whether the wrapped module is in training mode.
    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}
364
/// Per-rank memory usage estimate produced by
/// `FullyShardedDataParallel::memory_estimate`.
#[derive(Debug, Clone)]
pub struct FSDPMemoryStats {
    /// Total number of parameters in the wrapped module (all ranks combined).
    pub total_params: usize,
    /// Estimated per-rank bytes for parameter storage.
    pub param_memory_bytes: usize,
    /// Estimated per-rank bytes for gradient storage.
    pub grad_memory_bytes: usize,
    /// Estimated per-rank bytes for optimizer state (assumed 2x params).
    pub optim_memory_bytes: usize,
    /// Number of ranks the estimate was computed for.
    pub world_size: usize,
}
379
380impl FSDPMemoryStats {
381 pub fn total_memory_mb(&self) -> f32 {
383 (self.param_memory_bytes + self.grad_memory_bytes + self.optim_memory_bytes) as f32
384 / (1024.0 * 1024.0)
385 }
386
387 pub fn memory_savings(&self) -> f32 {
389 if self.world_size > 1 {
390 1.0 - (1.0 / self.world_size as f32)
391 } else {
392 0.0
393 }
394 }
395}
396
/// Linear layer whose weight is partitioned column-wise — i.e. along the
/// output-feature dimension — across the ranks of a process group.
pub struct ColumnParallelLinear {
    /// Local weight shard, shape `[out_features / world_size, in_features]`.
    weight: Parameter,
    /// Optional local bias shard, shape `[out_features / world_size]`.
    bias: Option<Parameter>,
    /// Communicator used to gather outputs when `gather_output` is set.
    process_group: ProcessGroup,
    /// Full input feature count.
    /// NOTE(review): stored but unused outside `new` in this file.
    in_features: usize,
    /// Full (global) output feature count.
    /// NOTE(review): stored but unused outside `new` in this file.
    out_features: usize,
    /// When true, `forward` all-gathers the sharded outputs across ranks.
    gather_output: bool,
}
419
420impl ColumnParallelLinear {
421 pub fn new(
423 in_features: usize,
424 out_features: usize,
425 bias: bool,
426 process_group: ProcessGroup,
427 gather_output: bool,
428 ) -> Self {
429 let world_size = process_group.world_size();
430 let local_out_features = out_features / world_size;
431
432 let weight_data = Tensor::randn(&[local_out_features, in_features]);
433 let weight = Parameter::new(weight_data, true);
434
435 let bias = if bias {
436 let bias_data = Tensor::zeros(&[local_out_features]);
437 Some(Parameter::new(bias_data, true))
438 } else {
439 None
440 };
441
442 Self {
443 weight,
444 bias,
445 process_group,
446 in_features,
447 out_features,
448 gather_output,
449 }
450 }
451}
452
453impl Module for ColumnParallelLinear {
454 fn forward(&self, input: &Variable) -> Variable {
455 let weight_var = Variable::new(self.weight.data(), false);
457 let output = input.matmul(&weight_var.transpose(0, 1));
458
459 let output = if let Some(ref bias) = self.bias {
461 let bias_var = Variable::new(bias.data(), false);
462 output.add(&bias_var)
463 } else {
464 output
465 };
466
467 if self.gather_output {
469 let gathered = self.process_group.all_gather_tensor(&output.data());
470 Variable::new(gathered, output.requires_grad())
471 } else {
472 output
473 }
474 }
475
476 fn parameters(&self) -> Vec<Parameter> {
477 let mut params = vec![self.weight.clone()];
478 if let Some(ref bias) = self.bias {
479 params.push(bias.clone());
480 }
481 params
482 }
483}
484
/// Linear layer whose weight is partitioned row-wise — i.e. along the
/// input-feature dimension — across the ranks of a process group.
pub struct RowParallelLinear {
    /// Local weight shard, shape `[out_features, in_features / world_size]`.
    weight: Parameter,
    /// Bias; allocated on rank 0 only (see `new`).
    bias: Option<Parameter>,
    /// Communicator used to all-reduce the partial outputs.
    process_group: ProcessGroup,
    /// Full (global) input feature count.
    in_features: usize,
    /// Output feature count (not sharded).
    out_features: usize,
    /// When true, `forward` assumes the input is already sliced to this
    /// rank's feature range and skips slicing.
    input_is_parallel: bool,
}
503
504impl RowParallelLinear {
505 pub fn new(
507 in_features: usize,
508 out_features: usize,
509 bias: bool,
510 process_group: ProcessGroup,
511 input_is_parallel: bool,
512 ) -> Self {
513 let world_size = process_group.world_size();
514 let rank = process_group.rank();
515 let local_in_features = in_features / world_size;
516
517 let weight_data = Tensor::randn(&[out_features, local_in_features]);
518 let weight = Parameter::new(weight_data, true);
519
520 let bias = if bias && rank == 0 {
522 let bias_data = Tensor::zeros(&[out_features]);
523 Some(Parameter::new(bias_data, true))
524 } else {
525 None
526 };
527
528 Self {
529 weight,
530 bias,
531 process_group,
532 in_features,
533 out_features,
534 input_is_parallel,
535 }
536 }
537}
538
impl Module for RowParallelLinear {
    /// Computes `x_local @ W_local^T`, all-reduces (sums) the partial results
    /// across ranks, then adds the bias on the rank that owns it.
    fn forward(&self, input: &Variable) -> Variable {
        // Obtain this rank's slice of the input along the last dimension.
        let local_input = if self.input_is_parallel {
            input.clone()
        } else {
            let world_size = self.process_group.world_size();
            let rank = self.process_group.rank();
            let data = input.data();
            let shape = data.shape();
            let feature_dim = shape[shape.len() - 1];
            let local_features = feature_dim / world_size;
            let start = rank * local_features;
            let end = start + local_features;

            let sliced = if shape.len() == 2 {
                data.slice(&[0..shape[0], start..end])
            } else {
                // NOTE(review): non-2D inputs are used unsliced; with
                // world_size > 1 this would not match the
                // [out_features, local_in_features] weight — confirm intended.
                data.clone()
            };
            Variable::new(sliced, input.requires_grad())
        };

        let weight_var = Variable::new(self.weight.data(), false);
        let local_output = local_input.matmul(&weight_var.transpose(0, 1));

        // Sum the partial outputs so every rank holds the full result.
        let mut output_data = local_output.data().clone();
        self.process_group
            .all_reduce_tensor(&mut output_data, ReduceOp::Sum);
        let output = Variable::new(output_data, local_output.requires_grad());

        // NOTE(review): the bias exists only on rank 0 (see `new`) and is
        // added AFTER the all-reduce, so with world_size > 1 rank 0's output
        // would differ from the other ranks' — verify whether the bias should
        // be added before the reduction or replicated on every rank.
        if let Some(ref bias) = self.bias {
            let bias_var = Variable::new(bias.data(), false);
            output.add(&bias_var)
        } else {
            output
        }
    }

    /// The local weight shard plus the bias, if this rank owns it.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }
}
591
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    #[test]
    fn test_sharding_strategy_default() {
        // FullShard is the declared #[default] variant.
        assert_eq!(ShardingStrategy::default(), ShardingStrategy::FullShard);
    }

    #[test]
    fn test_fsdp_creation() {
        let wrapped = FullyShardedDataParallel::new(Linear::new(10, 5), ProcessGroup::mock());
        assert_eq!(wrapped.strategy(), ShardingStrategy::FullShard);
    }

    #[test]
    fn test_fsdp_forward() {
        let mut wrapped = FullyShardedDataParallel::new(Linear::new(4, 2), ProcessGroup::mock());

        // Full parameters must be restored before running a forward pass.
        wrapped.gather_parameters();

        let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        let output = wrapped.forward(&input);

        assert_eq!(output.data().shape(), &[1, 2]);
    }

    #[test]
    fn test_fsdp_builder() {
        let wrapped = FullyShardedDataParallel::new(Linear::new(10, 5), ProcessGroup::mock())
            .sharding_strategy(ShardingStrategy::ShardGradOp)
            .cpu_offload(CPUOffload::Params)
            .mixed_precision(true);

        assert_eq!(wrapped.strategy(), ShardingStrategy::ShardGradOp);
    }

    #[test]
    fn test_fsdp_memory_stats() {
        let wrapped = FullyShardedDataParallel::new(Linear::new(100, 50), ProcessGroup::mock());
        let stats = wrapped.memory_estimate();

        assert!(stats.total_params > 0);
        assert!(stats.total_memory_mb() > 0.0);
    }

    #[test]
    fn test_fsdp_no_shard() {
        let wrapped = FullyShardedDataParallel::new(Linear::new(10, 5), ProcessGroup::mock())
            .sharding_strategy(ShardingStrategy::NoShard);

        assert_eq!(wrapped.strategy(), ShardingStrategy::NoShard);
    }

    #[test]
    fn test_column_parallel_linear() {
        // gather_output = false; with the mock group the local shard is full-width.
        let layer = ColumnParallelLinear::new(8, 4, true, ProcessGroup::mock(), false);

        let input = Variable::new(Tensor::randn(&[2, 8]), false);
        let output = layer.forward(&input);

        assert_eq!(output.data().shape(), &[2, 4]);
    }

    #[test]
    fn test_row_parallel_linear() {
        let layer = RowParallelLinear::new(8, 4, true, ProcessGroup::mock(), false);

        let input = Variable::new(Tensor::randn(&[2, 8]), false);
        let output = layer.forward(&input);

        assert_eq!(output.data().shape(), &[2, 4]);
    }
}