// axonml_distributed/fsdp.rs

//! FSDP - Fully Sharded Data Parallel
//!
//! Implements Fully Sharded Data Parallel training for scaling to multiple GPUs/nodes.
//! FSDP shards model parameters, gradients, and optimizer states across devices.
//!
//! Reference: "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models"
//! https://arxiv.org/abs/1910.02054
//!
//! # Example
//! ```rust,ignore
//! use axonml_distributed::fsdp::{CPUOffload, FullyShardedDataParallel, ShardingStrategy};
//!
//! let model = MyModel::new();
//! let fsdp_model = FullyShardedDataParallel::new(model, process_group)
//!     .sharding_strategy(ShardingStrategy::FullShard)
//!     .cpu_offload(CPUOffload::None);
//! ```
//!
//! @version 0.1.0

use crate::backend::ReduceOp;
use crate::process_group::ProcessGroup;
use axonml_autograd::Variable;
use axonml_nn::{Module, Parameter};
use axonml_tensor::Tensor;

// =============================================================================
// Sharding Strategy
// =============================================================================

/// Strategy for sharding parameters in FSDP.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardingStrategy {
    /// Shard parameters, gradients, and optimizer state (ZeRO-3)
    FullShard,
    /// Shard gradients and optimizer state only (ZeRO-2)
    ShardGradOp,
    /// No sharding, replicate across ranks (DDP-like)
    NoShard,
    /// Hybrid sharding within node, replicate across nodes
    HybridShard,
}

impl Default for ShardingStrategy {
    fn default() -> Self {
        Self::FullShard
    }
}

// =============================================================================
// FSDP State
// =============================================================================

/// State for a sharded parameter.
#[derive(Debug)]
#[allow(dead_code)]
struct ShardedParam {
    /// Local shard of the parameter
    local_shard: Tensor<f32>,
    /// Original shape before sharding
    original_shape: Vec<usize>,
    /// Number of elements in original parameter
    numel: usize,
    /// Padding added for even sharding (for uneven divisions)
    padding: usize,
}

/// CPU offload configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CPUOffload {
    /// No CPU offloading
    None,
    /// Offload parameters to CPU when not in use
    Params,
    /// Offload both parameters and gradients
    Full,
}

impl Default for CPUOffload {
    fn default() -> Self {
        Self::None
    }
}

// =============================================================================
// Fully Sharded Data Parallel
// =============================================================================

/// Fully Sharded Data Parallel wrapper for memory-efficient distributed training.
///
/// FSDP shards model parameters across devices, gathering them only when needed
/// for computation and sharding them again afterward.
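///
/// # Example
/// ```rust,ignore
/// // A minimal sketch of one training step (not a drop-in recipe): `model`,
/// // `process_group`, `input`, `loss_fn`, and `optimizer` are placeholders
/// // assumed to exist in the caller's code.
/// let mut fsdp = FullyShardedDataParallel::new(model, process_group);
/// fsdp.gather_parameters();      // all-gather full parameters before compute
/// let loss = loss_fn(&fsdp.forward(&input));
/// loss.backward();
/// fsdp.sync_gradients();         // reduce gradients across ranks
/// fsdp.clip_grad_norm(1.0);
/// optimizer.step();
/// fsdp.reshard_parameters();     // drop full parameters, keep only local shards
/// ```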
pub struct FullyShardedDataParallel<M: Module> {
    /// Wrapped module
    module: M,
    /// Process group for communication
    process_group: ProcessGroup,
    /// Sharding strategy
    sharding_strategy: ShardingStrategy,
    /// CPU offload configuration
    cpu_offload: CPUOffload,
    /// Sharded parameter states
    sharded_params: Vec<ShardedParam>,
    /// Whether module is currently gathered (unsharded)
    is_gathered: bool,
    /// Mixed precision compute dtype
    mixed_precision: bool,
}

impl<M: Module> FullyShardedDataParallel<M> {
    /// Creates a new FSDP wrapper.
    pub fn new(module: M, process_group: ProcessGroup) -> Self {
        let mut fsdp = Self {
            module,
            process_group,
            sharding_strategy: ShardingStrategy::default(),
            cpu_offload: CPUOffload::default(),
            sharded_params: Vec::new(),
            is_gathered: true,
            mixed_precision: false,
        };

        // Initialize sharding
        fsdp.shard_parameters();
        fsdp
    }

    /// Builder: set sharding strategy.
    pub fn sharding_strategy(mut self, strategy: ShardingStrategy) -> Self {
        self.sharding_strategy = strategy;
        self.shard_parameters();
        self
    }

    /// Builder: set CPU offload configuration.
    pub fn cpu_offload(mut self, offload: CPUOffload) -> Self {
        self.cpu_offload = offload;
        self
    }

    /// Builder: enable mixed precision.
    pub fn mixed_precision(mut self, enabled: bool) -> Self {
        self.mixed_precision = enabled;
        self
    }

    /// Returns reference to wrapped module.
    pub fn module(&self) -> &M {
        &self.module
    }

    /// Returns mutable reference to wrapped module.
    pub fn module_mut(&mut self) -> &mut M {
        &mut self.module
    }

    /// Returns the process group.
    pub fn process_group(&self) -> &ProcessGroup {
        &self.process_group
    }

    /// Returns the sharding strategy.
    pub fn strategy(&self) -> ShardingStrategy {
        self.sharding_strategy
    }

    /// Shards parameters across devices.
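    ///
    /// Worked example (hypothetical sizes): with `numel = 10` and `world_size = 4`,
    /// `shard_size = ceil(10 / 4) = 3` and `padding = 3 * 4 - 10 = 2`, so ranks 0-2
    /// hold three real elements each and rank 3 holds one real element plus two zeros.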
    fn shard_parameters(&mut self) {
        if self.sharding_strategy == ShardingStrategy::NoShard {
            return;
        }

        let world_size = self.process_group.world_size();
        let rank = self.process_group.rank();

        self.sharded_params.clear();

        for param in self.module.parameters() {
            let data = param.data();
            let shape = data.shape().to_vec();
            let numel = data.numel();

            // Calculate shard size with padding for even division
            let shard_size = (numel + world_size - 1) / world_size;
            let padding = shard_size * world_size - numel;

            // Get local shard
            let flat_data = data.to_vec();
            let start = rank * shard_size;
            let end = ((rank + 1) * shard_size).min(flat_data.len());

            let mut shard_data: Vec<f32> = if start < flat_data.len() {
                flat_data[start..end].to_vec()
            } else {
                vec![0.0; shard_size]
            };

            // Pad to shard_size
            while shard_data.len() < shard_size {
                shard_data.push(0.0);
            }

            self.sharded_params.push(ShardedParam {
                local_shard: Tensor::from_vec(shard_data, &[shard_size]).unwrap(),
                original_shape: shape,
                numel,
                padding,
            });
        }

        self.is_gathered = false;
    }

    /// Gathers all parameter shards before forward pass.
    pub fn gather_parameters(&mut self) {
        if self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
            return;
        }

        let _world_size = self.process_group.world_size();
        let params = self.module.parameters();

        for (param, sharded) in params.iter().zip(self.sharded_params.iter()) {
            // All-gather the shards
            let gathered = self.process_group.all_gather_tensor(&sharded.local_shard);

            // Reshape back to original shape (removing padding)
            let flat: Vec<f32> = gathered.to_vec().into_iter().take(sharded.numel).collect();
            let restored = Tensor::from_vec(flat, &sharded.original_shape).unwrap();

            param.update_data(restored);
        }

        self.is_gathered = true;
    }

    /// Shards parameters after forward/backward pass.
    pub fn reshard_parameters(&mut self) {
        if !self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
            return;
        }

        self.shard_parameters();
    }

    /// Synchronizes gradients across all ranks.
    pub fn sync_gradients(&self) {
        match self.sharding_strategy {
            ShardingStrategy::NoShard => {
                // Full all-reduce like DDP
                for param in self.module.parameters() {
                    if let Some(grad) = param.grad() {
                        let mut grad_tensor = grad.clone();
                        self.process_group.all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
                        // In a full implementation, the averaged gradient would be
                        // written back to the parameter.
                    }
                }
            }
            ShardingStrategy::ShardGradOp | ShardingStrategy::FullShard => {
                // Reduce-scatter gradients to get sharded gradients
                for param in self.module.parameters() {
                    if let Some(grad) = param.grad() {
                        let _reduced = self.process_group.reduce_scatter_tensor(&grad, ReduceOp::Average);
                        // In a full implementation, the reduced result would be stored
                        // as the parameter's local gradient shard.
                    }
                }
            }
            ShardingStrategy::HybridShard => {
                // Hybrid sharding would all-reduce within a node and reduce-scatter
                // across nodes; this simplified version performs a single global all-reduce.
                for param in self.module.parameters() {
                    if let Some(grad) = param.grad() {
                        let mut grad_tensor = grad.clone();
                        self.process_group.all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
                    }
                }
            }
        }
    }

    /// Clips gradients by global norm.
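    ///
    /// Returns the pre-clipping global norm. A minimal usage sketch (`optimizer` is a
    /// placeholder assumed to exist in the caller's code):
    ///
    /// ```rust,ignore
    /// fsdp.sync_gradients();
    /// let grad_norm = fsdp.clip_grad_norm(1.0);
    /// optimizer.step();
    /// ```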
    pub fn clip_grad_norm(&self, max_norm: f32) -> f32 {
        let mut total_norm_sq = 0.0f32;

        for param in self.module.parameters() {
            if let Some(grad) = param.grad() {
                let grad_vec = grad.to_vec();
                let norm_sq: f32 = grad_vec.iter().map(|x| x * x).sum();
                total_norm_sq += norm_sq;
            }
        }

        // All-reduce total norm across ranks
        let mut norm_tensor = Tensor::from_vec(vec![total_norm_sq], &[1]).unwrap();
        self.process_group.all_reduce_tensor(&mut norm_tensor, ReduceOp::Sum);
        let global_norm = norm_tensor.to_vec()[0].sqrt();

        // Clip if necessary
        if global_norm > max_norm {
            let clip_coef = max_norm / (global_norm + 1e-6);
            for param in self.module.parameters() {
                if let Some(grad) = param.grad() {
                    let clipped: Vec<f32> = grad.to_vec().iter().map(|x| x * clip_coef).collect();
                    let clipped_tensor = Tensor::from_vec(clipped, grad.shape()).unwrap();
                    param.variable().set_grad(clipped_tensor);
                }
            }
        }

        global_norm
    }

    /// Estimates memory usage with different sharding strategies.
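    ///
    /// The estimate assumes f32 parameters (4 bytes each), a gradient buffer of the
    /// same size, and optimizer state of twice the parameter size (e.g. Adam's two
    /// moment buffers). Rough worked example: 1,000,000 parameters is about 4 MB of
    /// parameters, 4 MB of gradients, and 8 MB of optimizer state (~16 MB total);
    /// under `FullShard` with `world_size = 8`, each rank holds roughly 16 MB / 8 = 2 MB.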
    pub fn memory_estimate(&self) -> FSDPMemoryStats {
        let params = self.module.parameters();
        let total_params: usize = params.iter().map(|p| p.numel()).sum();
        let world_size = self.process_group.world_size();

        let bytes_per_param = 4; // f32
        let param_memory = total_params * bytes_per_param;

        let (sharded_params, sharded_grads, sharded_optim) = match self.sharding_strategy {
            ShardingStrategy::NoShard => (param_memory, param_memory, param_memory * 2),
            ShardingStrategy::ShardGradOp => (
                param_memory,
                param_memory / world_size,
                param_memory * 2 / world_size,
            ),
            ShardingStrategy::FullShard | ShardingStrategy::HybridShard => (
                param_memory / world_size,
                param_memory / world_size,
                param_memory * 2 / world_size,
            ),
        };

        FSDPMemoryStats {
            total_params,
            param_memory_bytes: sharded_params,
            grad_memory_bytes: sharded_grads,
            optim_memory_bytes: sharded_optim,
            world_size,
        }
    }
}

impl<M: Module> Module for FullyShardedDataParallel<M> {
    fn forward(&self, input: &Variable) -> Variable {
        // Note: a full implementation would gather parameters automatically via
        // pre-forward hooks and reshard them after the forward pass; here the
        // caller is expected to call `gather_parameters` / `reshard_parameters`.
        self.module.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }

    fn train(&mut self) {
        self.module.train();
    }

    fn eval(&mut self) {
        self.module.eval();
    }

    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}

/// Memory statistics for FSDP.
#[derive(Debug, Clone)]
pub struct FSDPMemoryStats {
    /// Total number of parameters
    pub total_params: usize,
    /// Memory for parameters (bytes)
    pub param_memory_bytes: usize,
    /// Memory for gradients (bytes)
    pub grad_memory_bytes: usize,
    /// Memory for optimizer state (bytes)
    pub optim_memory_bytes: usize,
    /// World size (number of ranks)
    pub world_size: usize,
}

impl FSDPMemoryStats {
    /// Total memory per rank in MB.
    pub fn total_memory_mb(&self) -> f32 {
        (self.param_memory_bytes + self.grad_memory_bytes + self.optim_memory_bytes) as f32
            / (1024.0 * 1024.0)
    }

    /// Memory savings compared to no sharding.
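    ///
    /// Returned as the fraction of per-rank memory saved for the sharded state, e.g.
    /// `world_size = 8` gives `1.0 - 1.0 / 8.0 = 0.875` (about 87.5% less).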
    pub fn memory_savings(&self) -> f32 {
        if self.world_size > 1 {
            1.0 - (1.0 / self.world_size as f32)
        } else {
            0.0
        }
    }
}

// =============================================================================
// Tensor Parallelism
// =============================================================================

/// Column-parallel linear layer.
///
/// Splits the weight matrix along the column dimension across ranks.
/// Each rank computes a portion of the output features.
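///
/// # Example
/// ```rust,ignore
/// // Sketch: `pg` and `x` are placeholders. With world_size = 1 the local shard
/// // is the whole layer, so the local output already has all 8 columns.
/// let layer = ColumnParallelLinear::new(16, 8, /* bias */ true, pg, /* gather_output */ false);
/// let y = layer.forward(&x); // shape [batch, 8 / world_size]
/// ```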
#[allow(dead_code)]
pub struct ColumnParallelLinear {
    /// Local weight shard
    weight: Parameter,
    /// Bias (replicated on all ranks)
    bias: Option<Parameter>,
    /// Process group
    process_group: ProcessGroup,
    /// Input features
    in_features: usize,
    /// Output features (total across all ranks)
    out_features: usize,
    /// Whether to gather output
    gather_output: bool,
}

impl ColumnParallelLinear {
    /// Creates a new column-parallel linear layer.
    pub fn new(
        in_features: usize,
        out_features: usize,
        bias: bool,
        process_group: ProcessGroup,
        gather_output: bool,
    ) -> Self {
        let world_size = process_group.world_size();
        let local_out_features = out_features / world_size;

        let weight_data = Tensor::randn(&[local_out_features, in_features]);
        let weight = Parameter::new(weight_data, true);

        let bias = if bias {
            let bias_data = Tensor::zeros(&[local_out_features]);
            Some(Parameter::new(bias_data, true))
        } else {
            None
        };

        Self {
            weight,
            bias,
            process_group,
            in_features,
            out_features,
            gather_output,
        }
    }
}

impl Module for ColumnParallelLinear {
    fn forward(&self, input: &Variable) -> Variable {
        // Local matmul: input @ weight.T
        let weight_var = Variable::new(self.weight.data(), false);
        let output = input.matmul(&weight_var.transpose(0, 1));

        // Add bias
        let output = if let Some(ref bias) = self.bias {
            let bias_var = Variable::new(bias.data(), false);
            output.add(&bias_var)
        } else {
            output
        };

        // Optionally gather output across ranks. Note: wrapping the gathered tensor
        // in a fresh Variable detaches it from the autograd graph; a full
        // implementation would use a differentiable all-gather.
        if self.gather_output {
            let gathered = self.process_group.all_gather_tensor(&output.data());
            Variable::new(gathered, output.requires_grad())
        } else {
            output
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }
}

/// Row-parallel linear layer.
///
/// Splits the weight matrix along the row dimension across ranks.
/// Each rank has a portion of the input features.
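///
/// A `RowParallelLinear` is typically paired with a preceding `ColumnParallelLinear`
/// (built with `gather_output = false`), so its input arrives already split along the
/// feature dimension and only one all-reduce is needed for the pair.
///
/// # Example
/// ```rust,ignore
/// // Sketch of a tensor-parallel two-layer block; `pg1` and `pg2` are placeholder
/// // handles to the same process group, and `x` is a placeholder input.
/// let fc1 = ColumnParallelLinear::new(64, 256, true, pg1, /* gather_output */ false);
/// let fc2 = RowParallelLinear::new(256, 64, true, pg2, /* input_is_parallel */ true);
/// let y = fc2.forward(&fc1.forward(&x));
/// ```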
#[allow(dead_code)]
pub struct RowParallelLinear {
    /// Local weight shard
    weight: Parameter,
    /// Bias (only on rank 0)
    bias: Option<Parameter>,
    /// Process group
    process_group: ProcessGroup,
    /// Input features (total across all ranks)
    in_features: usize,
    /// Output features
    out_features: usize,
    /// Whether input is already split
    input_is_parallel: bool,
}

impl RowParallelLinear {
    /// Creates a new row-parallel linear layer.
    pub fn new(
        in_features: usize,
        out_features: usize,
        bias: bool,
        process_group: ProcessGroup,
        input_is_parallel: bool,
    ) -> Self {
        let world_size = process_group.world_size();
        let rank = process_group.rank();
        let local_in_features = in_features / world_size;

        let weight_data = Tensor::randn(&[out_features, local_in_features]);
        let weight = Parameter::new(weight_data, true);

        // Only rank 0 has bias
        let bias = if bias && rank == 0 {
            let bias_data = Tensor::zeros(&[out_features]);
            Some(Parameter::new(bias_data, true))
        } else {
            None
        };

        Self {
            weight,
            bias,
            process_group,
            in_features,
            out_features,
            input_is_parallel,
        }
    }
}

impl Module for RowParallelLinear {
    fn forward(&self, input: &Variable) -> Variable {
        // If input is not parallel, take local shard
        let local_input = if self.input_is_parallel {
            input.clone()
        } else {
            // Split input along feature dimension for row parallelism
            let world_size = self.process_group.world_size();
            let rank = self.process_group.rank();
            let data = input.data();
            let shape = data.shape();
            let feature_dim = shape[shape.len() - 1];
            let local_features = feature_dim / world_size;
            let start = rank * local_features;
            let end = start + local_features;

            // Slice the last dimension
            let sliced = if shape.len() == 2 {
                data.slice(&[0..shape[0], start..end])
            } else {
                data.clone() // Fallback for other shapes
            };
            Variable::new(sliced, input.requires_grad())
        };

        // Local matmul
        let weight_var = Variable::new(self.weight.data(), false);
        let local_output = local_input.matmul(&weight_var.transpose(0, 1));

        // All-reduce to combine partial outputs. Note: wrapping the reduced tensor in
        // a fresh Variable detaches it from the autograd graph; a full implementation
        // would use a differentiable all-reduce.
        let mut output_data = local_output.data().clone();
        self.process_group.all_reduce_tensor(&mut output_data, ReduceOp::Sum);
        let output = Variable::new(output_data, local_output.requires_grad());

        // Add bias. In this simplified implementation the bias lives only on rank 0
        // and no broadcast is performed, so only rank 0's output includes it.
        if let Some(ref bias) = self.bias {
            let bias_var = Variable::new(bias.data(), false);
            output.add(&bias_var)
        } else {
            output
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    #[test]
    fn test_sharding_strategy_default() {
        assert_eq!(ShardingStrategy::default(), ShardingStrategy::FullShard);
    }

    #[test]
    fn test_fsdp_creation() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let fsdp = FullyShardedDataParallel::new(model, pg);

        assert_eq!(fsdp.strategy(), ShardingStrategy::FullShard);
    }

    #[test]
    fn test_fsdp_forward() {
        let model = Linear::new(4, 2);
        let pg = ProcessGroup::mock();
        let mut fsdp = FullyShardedDataParallel::new(model, pg);

        // Gather before forward
        fsdp.gather_parameters();

        let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        let output = fsdp.forward(&input);

        assert_eq!(output.data().shape(), &[1, 2]);
    }

    #[test]
    fn test_fsdp_builder() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();

        let fsdp = FullyShardedDataParallel::new(model, pg)
            .sharding_strategy(ShardingStrategy::ShardGradOp)
            .cpu_offload(CPUOffload::Params)
            .mixed_precision(true);

        assert_eq!(fsdp.strategy(), ShardingStrategy::ShardGradOp);
    }

    #[test]
    fn test_fsdp_memory_stats() {
        let model = Linear::new(100, 50);
        let pg = ProcessGroup::mock();
        let fsdp = FullyShardedDataParallel::new(model, pg);

        let stats = fsdp.memory_estimate();
        assert!(stats.total_params > 0);
        assert!(stats.total_memory_mb() > 0.0);
    }

    #[test]
    fn test_fsdp_no_shard() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let fsdp = FullyShardedDataParallel::new(model, pg)
            .sharding_strategy(ShardingStrategy::NoShard);

        assert_eq!(fsdp.strategy(), ShardingStrategy::NoShard);
    }

    #[test]
    fn test_column_parallel_linear() {
        let pg = ProcessGroup::mock();
        // With world_size=1, local_out_features = out_features / 1 = 4
        let layer = ColumnParallelLinear::new(8, 4, true, pg, false); // Don't gather for simple test

        let input = Variable::new(Tensor::randn(&[2, 8]), false);
        let output = layer.forward(&input);

        // Output shape should be [batch, local_out_features] = [2, 4]
        assert_eq!(output.data().shape(), &[2, 4]);
    }

    #[test]
    fn test_row_parallel_linear() {
        let pg = ProcessGroup::mock();
        let layer = RowParallelLinear::new(8, 4, true, pg, false);

        let input = Variable::new(Tensor::randn(&[2, 8]), false);
        let output = layer.forward(&input);

        assert_eq!(output.data().shape(), &[2, 4]);
    }
}