Skip to main content

axonml_distributed/
fsdp.rs

//! FSDP - Fully Sharded Data Parallel
//!
//! Implements Fully Sharded Data Parallel training for scaling to multiple GPUs/nodes.
//! FSDP shards model parameters, gradients, and optimizer states across devices.
//!
//! Reference: "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models"
//! <https://arxiv.org/abs/1910.02054>
//!
//! # Example
//! ```rust,ignore
//! use axonml_distributed::fsdp::{CPUOffload, FullyShardedDataParallel, ShardingStrategy};
//!
//! let model = MyModel::new();
//! let fsdp_model = FullyShardedDataParallel::new(model, process_group)
//!     .sharding_strategy(ShardingStrategy::FullShard)
//!     .cpu_offload(CPUOffload::None);
//! ```
//!
//! @version 0.1.0
20
21use crate::backend::ReduceOp;
22use crate::process_group::ProcessGroup;
23use axonml_autograd::Variable;
24use axonml_nn::{Module, Parameter};
25use axonml_tensor::Tensor;
26
27// =============================================================================
28// Sharding Strategy
29// =============================================================================
30
/// How FSDP distributes model state across the ranks of a process group.
///
/// The default is [`ShardingStrategy::FullShard`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ShardingStrategy {
    /// Shard parameters, gradients, and optimizer state (ZeRO-3)
    #[default]
    FullShard,
    /// Shard gradients and optimizer state only (ZeRO-2)
    ShardGradOp,
    /// No sharding, replicate across ranks (DDP-like)
    NoShard,
    /// Hybrid sharding within node, replicate across nodes
    HybridShard,
}
49
50// =============================================================================
51// FSDP State
52// =============================================================================
53
54/// State for a sharded parameter.
55#[derive(Debug)]
56#[allow(dead_code)]
57struct ShardedParam {
58    /// Local shard of the parameter
59    local_shard: Tensor<f32>,
60    /// Original shape before sharding
61    original_shape: Vec<usize>,
62    /// Number of elements in original parameter
63    numel: usize,
64    /// Padding added for even sharding (for uneven divisions)
65    padding: usize,
66}
67
/// Which tensors FSDP should keep in host memory while they are not in use.
///
/// The default is [`CPUOffload::None`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CPUOffload {
    /// No CPU offloading
    #[default]
    None,
    /// Offload parameters to CPU when not in use
    Params,
    /// Offload both parameters and gradients
    Full,
}
84
85// =============================================================================
86// Fully Sharded Data Parallel
87// =============================================================================
88
89/// Fully Sharded Data Parallel wrapper for memory-efficient distributed training.
90///
91/// FSDP shards model parameters across devices, gathering them only when needed
92/// for computation and sharding them again afterward.
93pub struct FullyShardedDataParallel<M: Module> {
94    /// Wrapped module
95    module: M,
96    /// Process group for communication
97    process_group: ProcessGroup,
98    /// Sharding strategy
99    sharding_strategy: ShardingStrategy,
100    /// CPU offload configuration
101    cpu_offload: CPUOffload,
102    /// Sharded parameter states
103    sharded_params: Vec<ShardedParam>,
104    /// Whether module is currently gathered (unsharded)
105    is_gathered: bool,
106    /// Mixed precision compute dtype
107    mixed_precision: bool,
108}
109
110impl<M: Module> FullyShardedDataParallel<M> {
111    /// Creates a new FSDP wrapper.
112    pub fn new(module: M, process_group: ProcessGroup) -> Self {
113        let mut fsdp = Self {
114            module,
115            process_group,
116            sharding_strategy: ShardingStrategy::default(),
117            cpu_offload: CPUOffload::default(),
118            sharded_params: Vec::new(),
119            is_gathered: true,
120            mixed_precision: false,
121        };
122
123        // Initialize sharding
124        fsdp.shard_parameters();
125        fsdp
126    }
127
128    /// Builder: set sharding strategy.
129    pub fn sharding_strategy(mut self, strategy: ShardingStrategy) -> Self {
130        self.sharding_strategy = strategy;
131        self.shard_parameters();
132        self
133    }
134
135    /// Builder: set CPU offload configuration.
136    pub fn cpu_offload(mut self, offload: CPUOffload) -> Self {
137        self.cpu_offload = offload;
138        self
139    }
140
141    /// Builder: enable mixed precision.
142    pub fn mixed_precision(mut self, enabled: bool) -> Self {
143        self.mixed_precision = enabled;
144        self
145    }
146
147    /// Returns reference to wrapped module.
148    pub fn module(&self) -> &M {
149        &self.module
150    }
151
152    /// Returns mutable reference to wrapped module.
153    pub fn module_mut(&mut self) -> &mut M {
154        &mut self.module
155    }
156
157    /// Returns the process group.
158    pub fn process_group(&self) -> &ProcessGroup {
159        &self.process_group
160    }
161
162    /// Returns the sharding strategy.
163    pub fn strategy(&self) -> ShardingStrategy {
164        self.sharding_strategy
165    }
166
167    /// Shards parameters across devices.
168    fn shard_parameters(&mut self) {
169        if self.sharding_strategy == ShardingStrategy::NoShard {
170            return;
171        }
172
173        let world_size = self.process_group.world_size();
174        let rank = self.process_group.rank();
175
176        self.sharded_params.clear();
177
178        for param in self.module.parameters() {
179            let data = param.data();
180            let shape = data.shape().to_vec();
181            let numel = data.numel();
182
183            // Calculate shard size with padding for even division
184            let shard_size = (numel + world_size - 1) / world_size;
185            let padding = shard_size * world_size - numel;
186
187            // Get local shard
188            let flat_data = data.to_vec();
189            let start = rank * shard_size;
190            let end = ((rank + 1) * shard_size).min(flat_data.len());
191
192            let mut shard_data: Vec<f32> = if start < flat_data.len() {
193                flat_data[start..end].to_vec()
194            } else {
195                vec![0.0; shard_size]
196            };
197
198            // Pad to shard_size
199            while shard_data.len() < shard_size {
200                shard_data.push(0.0);
201            }
202
203            self.sharded_params.push(ShardedParam {
204                local_shard: Tensor::from_vec(shard_data, &[shard_size]).unwrap(),
205                original_shape: shape,
206                numel,
207                padding,
208            });
209        }
210
211        self.is_gathered = false;
212    }
213
214    /// Gathers all parameter shards before forward pass.
215    pub fn gather_parameters(&mut self) {
216        if self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
217            return;
218        }
219
220        let _world_size = self.process_group.world_size();
221        let params = self.module.parameters();
222
223        for (param, sharded) in params.iter().zip(self.sharded_params.iter()) {
224            // All-gather the shards
225            let gathered = self.process_group.all_gather_tensor(&sharded.local_shard);
226
227            // Reshape back to original shape (removing padding)
228            let flat: Vec<f32> = gathered.to_vec().into_iter().take(sharded.numel).collect();
229            let restored = Tensor::from_vec(flat, &sharded.original_shape).unwrap();
230
231            param.update_data(restored);
232        }
233
234        self.is_gathered = true;
235    }
236
237    /// Shards parameters after forward/backward pass.
238    pub fn reshard_parameters(&mut self) {
239        if !self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
240            return;
241        }
242
243        self.shard_parameters();
244    }
245
246    /// Synchronizes gradients across all ranks.
247    pub fn sync_gradients(&self) {
248        match self.sharding_strategy {
249            ShardingStrategy::NoShard => {
250                // Full all-reduce like DDP
251                for param in self.module.parameters() {
252                    if let Some(grad) = param.grad() {
253                        let mut grad_tensor = grad.clone();
254                        self.process_group
255                            .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
256                    }
257                }
258            }
259            ShardingStrategy::ShardGradOp | ShardingStrategy::FullShard => {
260                // Reduce-scatter gradients to get sharded gradients
261                for param in self.module.parameters() {
262                    if let Some(grad) = param.grad() {
263                        let _reduced = self
264                            .process_group
265                            .reduce_scatter_tensor(&grad, ReduceOp::Average);
266                        // In full implementation, would update parameter's gradient shard
267                    }
268                }
269            }
270            ShardingStrategy::HybridShard => {
271                // All-reduce within node, reduce-scatter across nodes
272                for param in self.module.parameters() {
273                    if let Some(grad) = param.grad() {
274                        let mut grad_tensor = grad.clone();
275                        self.process_group
276                            .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
277                    }
278                }
279            }
280        }
281    }
282
283    /// Clips gradients by global norm.
284    pub fn clip_grad_norm(&self, max_norm: f32) -> f32 {
285        let mut total_norm_sq = 0.0f32;
286
287        for param in self.module.parameters() {
288            if let Some(grad) = param.grad() {
289                let grad_vec = grad.to_vec();
290                let norm_sq: f32 = grad_vec.iter().map(|x| x * x).sum();
291                total_norm_sq += norm_sq;
292            }
293        }
294
295        // All-reduce total norm across ranks
296        let mut norm_tensor = Tensor::from_vec(vec![total_norm_sq], &[1]).unwrap();
297        self.process_group
298            .all_reduce_tensor(&mut norm_tensor, ReduceOp::Sum);
299        let global_norm = norm_tensor.to_vec()[0].sqrt();
300
301        // Clip if necessary
302        if global_norm > max_norm {
303            let clip_coef = max_norm / (global_norm + 1e-6);
304            for param in self.module.parameters() {
305                if let Some(grad) = param.grad() {
306                    let clipped: Vec<f32> = grad.to_vec().iter().map(|x| x * clip_coef).collect();
307                    let clipped_tensor = Tensor::from_vec(clipped, grad.shape()).unwrap();
308                    param.variable().set_grad(clipped_tensor);
309                }
310            }
311        }
312
313        global_norm
314    }
315
316    /// Estimates memory usage with different sharding strategies.
317    pub fn memory_estimate(&self) -> FSDPMemoryStats {
318        let params = self.module.parameters();
319        let total_params: usize = params.iter().map(|p| p.numel()).sum();
320        let world_size = self.process_group.world_size();
321
322        let bytes_per_param = 4; // f32
323        let param_memory = total_params * bytes_per_param;
324
325        let (sharded_params, sharded_grads, sharded_optim) = match self.sharding_strategy {
326            ShardingStrategy::NoShard => (param_memory, param_memory, param_memory * 2),
327            ShardingStrategy::ShardGradOp => (
328                param_memory,
329                param_memory / world_size,
330                param_memory * 2 / world_size,
331            ),
332            ShardingStrategy::FullShard | ShardingStrategy::HybridShard => (
333                param_memory / world_size,
334                param_memory / world_size,
335                param_memory * 2 / world_size,
336            ),
337        };
338
339        FSDPMemoryStats {
340            total_params,
341            param_memory_bytes: sharded_params,
342            grad_memory_bytes: sharded_grads,
343            optim_memory_bytes: sharded_optim,
344            world_size,
345        }
346    }
347}
348
349impl<M: Module> Module for FullyShardedDataParallel<M> {
350    fn forward(&self, input: &Variable) -> Variable {
351        // Note: In a real implementation, gather would be called automatically
352        // through hooks before forward and reshard after
353        self.module.forward(input)
354    }
355
356    fn parameters(&self) -> Vec<Parameter> {
357        self.module.parameters()
358    }
359
360    fn train(&mut self) {
361        self.module.train();
362    }
363
364    fn eval(&mut self) {
365        self.module.eval();
366    }
367
368    fn is_training(&self) -> bool {
369        self.module.is_training()
370    }
371}
372
/// Per-rank memory estimate produced by `FullyShardedDataParallel::memory_estimate`.
#[derive(Debug, Clone)]
pub struct FSDPMemoryStats {
    /// Total number of parameters
    pub total_params: usize,
    /// Memory for parameters (bytes)
    pub param_memory_bytes: usize,
    /// Memory for gradients (bytes)
    pub grad_memory_bytes: usize,
    /// Memory for optimizer state (bytes)
    pub optim_memory_bytes: usize,
    /// World size (number of ranks)
    pub world_size: usize,
}

impl FSDPMemoryStats {
    /// Total memory per rank in MB (parameters + gradients + optimizer state).
    pub fn total_memory_mb(&self) -> f32 {
        (self.param_memory_bytes + self.grad_memory_bytes + self.optim_memory_bytes) as f32
            / (1024.0 * 1024.0)
    }

    /// Fraction of memory saved per rank compared to no sharding, in `[0, 1]`.
    ///
    /// The baseline is the fully replicated footprint used by
    /// `memory_estimate`: `f32` parameters, `f32` gradients and an optimizer
    /// state twice the parameter size, i.e. 16 bytes per parameter. The result
    /// is derived from the byte counts stored in `self`, so it reflects the
    /// strategy that produced them (for example `0.0` when nothing is sharded)
    /// instead of depending on the world size alone.
    ///
    /// Returns `0.0` when there are no parameters.
    pub fn memory_savings(&self) -> f32 {
        // 4 (params) + 4 (grads) + 8 (optimizer state) bytes per parameter.
        const UNSHARDED_BYTES_PER_PARAM: usize = 16;

        let baseline = self.total_params * UNSHARDED_BYTES_PER_PARAM;
        if baseline == 0 {
            return 0.0;
        }

        let per_rank =
            self.param_memory_bytes + self.grad_memory_bytes + self.optim_memory_bytes;
        // Computed in f64 to limit rounding for large byte counts.
        let saved = 1.0 - per_rank as f64 / baseline as f64;
        saved.clamp(0.0, 1.0) as f32
    }
}
404
405// =============================================================================
406// Tensor Parallelism
407// =============================================================================
408
409/// Column-parallel linear layer.
410///
411/// Splits the weight matrix along the column dimension across ranks.
412/// Each rank computes a portion of the output features.
413#[allow(dead_code)]
414pub struct ColumnParallelLinear {
415    /// Local weight shard
416    weight: Parameter,
417    /// Bias (replicated on all ranks)
418    bias: Option<Parameter>,
419    /// Process group
420    process_group: ProcessGroup,
421    /// Input features
422    in_features: usize,
423    /// Output features (total across all ranks)
424    out_features: usize,
425    /// Whether to gather output
426    gather_output: bool,
427}
428
429impl ColumnParallelLinear {
430    /// Creates a new column-parallel linear layer.
431    pub fn new(
432        in_features: usize,
433        out_features: usize,
434        bias: bool,
435        process_group: ProcessGroup,
436        gather_output: bool,
437    ) -> Self {
438        let world_size = process_group.world_size();
439        let local_out_features = out_features / world_size;
440
441        let weight_data = Tensor::randn(&[local_out_features, in_features]);
442        let weight = Parameter::new(weight_data, true);
443
444        let bias = if bias {
445            let bias_data = Tensor::zeros(&[local_out_features]);
446            Some(Parameter::new(bias_data, true))
447        } else {
448            None
449        };
450
451        Self {
452            weight,
453            bias,
454            process_group,
455            in_features,
456            out_features,
457            gather_output,
458        }
459    }
460}
461
462impl Module for ColumnParallelLinear {
463    fn forward(&self, input: &Variable) -> Variable {
464        // Local matmul: input @ weight.T
465        let weight_var = Variable::new(self.weight.data(), false);
466        let output = input.matmul(&weight_var.transpose(0, 1));
467
468        // Add bias
469        let output = if let Some(ref bias) = self.bias {
470            let bias_var = Variable::new(bias.data(), false);
471            output.add(&bias_var)
472        } else {
473            output
474        };
475
476        // Optionally gather output across ranks
477        if self.gather_output {
478            let gathered = self.process_group.all_gather_tensor(&output.data());
479            Variable::new(gathered, output.requires_grad())
480        } else {
481            output
482        }
483    }
484
485    fn parameters(&self) -> Vec<Parameter> {
486        let mut params = vec![self.weight.clone()];
487        if let Some(ref bias) = self.bias {
488            params.push(bias.clone());
489        }
490        params
491    }
492}
493
494/// Row-parallel linear layer.
495///
496/// Splits the weight matrix along the row dimension across ranks.
497/// Each rank has a portion of the input features.
498#[allow(dead_code)]
499pub struct RowParallelLinear {
500    /// Local weight shard
501    weight: Parameter,
502    /// Bias (only on rank 0)
503    bias: Option<Parameter>,
504    /// Process group
505    process_group: ProcessGroup,
506    /// Input features (total across all ranks)
507    in_features: usize,
508    /// Output features
509    out_features: usize,
510    /// Whether input is already split
511    input_is_parallel: bool,
512}
513
514impl RowParallelLinear {
515    /// Creates a new row-parallel linear layer.
516    pub fn new(
517        in_features: usize,
518        out_features: usize,
519        bias: bool,
520        process_group: ProcessGroup,
521        input_is_parallel: bool,
522    ) -> Self {
523        let world_size = process_group.world_size();
524        let rank = process_group.rank();
525        let local_in_features = in_features / world_size;
526
527        let weight_data = Tensor::randn(&[out_features, local_in_features]);
528        let weight = Parameter::new(weight_data, true);
529
530        // Only rank 0 has bias
531        let bias = if bias && rank == 0 {
532            let bias_data = Tensor::zeros(&[out_features]);
533            Some(Parameter::new(bias_data, true))
534        } else {
535            None
536        };
537
538        Self {
539            weight,
540            bias,
541            process_group,
542            in_features,
543            out_features,
544            input_is_parallel,
545        }
546    }
547}
548
549impl Module for RowParallelLinear {
550    fn forward(&self, input: &Variable) -> Variable {
551        // If input is not parallel, take local shard
552        let local_input = if self.input_is_parallel {
553            input.clone()
554        } else {
555            // Split input along feature dimension for row parallelism
556            let world_size = self.process_group.world_size();
557            let rank = self.process_group.rank();
558            let data = input.data();
559            let shape = data.shape();
560            let feature_dim = shape[shape.len() - 1];
561            let local_features = feature_dim / world_size;
562            let start = rank * local_features;
563            let end = start + local_features;
564
565            // Slice the last dimension
566            let sliced = if shape.len() == 2 {
567                data.slice(&[0..shape[0], start..end])
568            } else {
569                data.clone() // Fallback for other shapes
570            };
571            Variable::new(sliced, input.requires_grad())
572        };
573
574        // Local matmul
575        let weight_var = Variable::new(self.weight.data(), false);
576        let local_output = local_input.matmul(&weight_var.transpose(0, 1));
577
578        // All-reduce to combine partial outputs
579        let mut output_data = local_output.data().clone();
580        self.process_group
581            .all_reduce_tensor(&mut output_data, ReduceOp::Sum);
582        let output = Variable::new(output_data, local_output.requires_grad());
583
584        // Add bias (only on rank 0, then broadcast)
585        if let Some(ref bias) = self.bias {
586            let bias_var = Variable::new(bias.data(), false);
587            output.add(&bias_var)
588        } else {
589            output
590        }
591    }
592
593    fn parameters(&self) -> Vec<Parameter> {
594        let mut params = vec![self.weight.clone()];
595        if let Some(ref bias) = self.bias {
596            params.push(bias.clone());
597        }
598        params
599    }
600}
601
602// =============================================================================
603// Tests
604// =============================================================================
605
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    /// Wraps a fresh `Linear` layer in FSDP on top of the mock process group.
    fn wrap_linear(inputs: usize, outputs: usize) -> FullyShardedDataParallel<Linear> {
        FullyShardedDataParallel::new(Linear::new(inputs, outputs), ProcessGroup::mock())
    }

    #[test]
    fn test_sharding_strategy_default() {
        let default_strategy = ShardingStrategy::default();
        assert_eq!(default_strategy, ShardingStrategy::FullShard);
    }

    #[test]
    fn test_fsdp_creation() {
        let wrapped = wrap_linear(10, 5);

        assert_eq!(wrapped.strategy(), ShardingStrategy::FullShard);
    }

    #[test]
    fn test_fsdp_forward() {
        let mut wrapped = wrap_linear(4, 2);

        // Full parameters must be in place before running the module.
        wrapped.gather_parameters();

        let ones = Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap();
        let result = wrapped.forward(&Variable::new(ones, false));

        assert_eq!(result.data().shape(), &[1, 2]);
    }

    #[test]
    fn test_fsdp_builder() {
        let wrapped = wrap_linear(10, 5)
            .sharding_strategy(ShardingStrategy::ShardGradOp)
            .cpu_offload(CPUOffload::Params)
            .mixed_precision(true);

        assert_eq!(wrapped.strategy(), ShardingStrategy::ShardGradOp);
    }

    #[test]
    fn test_fsdp_memory_stats() {
        let estimate = wrap_linear(100, 50).memory_estimate();

        assert!(estimate.total_params > 0);
        assert!(estimate.total_memory_mb() > 0.0);
    }

    #[test]
    fn test_fsdp_no_shard() {
        let wrapped = wrap_linear(10, 5).sharding_strategy(ShardingStrategy::NoShard);

        assert_eq!(wrapped.strategy(), ShardingStrategy::NoShard);
    }

    #[test]
    fn test_column_parallel_linear() {
        // With world_size=1, local_out_features = out_features / 1 = 4.
        // Output gathering is switched off to keep the test simple.
        let layer = ColumnParallelLinear::new(8, 4, true, ProcessGroup::mock(), false);

        let batch = Variable::new(Tensor::randn(&[2, 8]), false);
        let result = layer.forward(&batch);

        // Output shape should be [batch, local_out_features] = [2, 4]
        assert_eq!(result.data().shape(), &[2, 4]);
    }

    #[test]
    fn test_row_parallel_linear() {
        let layer = RowParallelLinear::new(8, 4, true, ProcessGroup::mock(), false);

        let batch = Variable::new(Tensor::randn(&[2, 8]), false);
        let result = layer.forward(&batch);

        assert_eq!(result.data().shape(), &[2, 4]);
    }
}