//! FSDP - Fully Sharded Data Parallel
//!
//! # File
//! `crates/axonml-distributed/src/fsdp.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use crate::backend::ReduceOp;
use crate::process_group::ProcessGroup;
use axonml_autograd::Variable;
use axonml_nn::{Module, Parameter};
use axonml_tensor::Tensor;

// =============================================================================
// Sharding Strategy
// =============================================================================

/// Strategy for sharding parameters in FSDP.
///
/// Mirrors the ZeRO optimization stages: more sharding trades extra
/// communication for lower per-rank memory.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ShardingStrategy {
    /// Shard parameters, gradients, and optimizer state (ZeRO-3).
    #[default]
    FullShard,
    /// Shard gradients and optimizer state only (ZeRO-2).
    ShardGradOp,
    /// No sharding, replicate across ranks (DDP-like).
    NoShard,
    /// Hybrid sharding within node, replicate across nodes.
    HybridShard,
}

// =============================================================================
// FSDP State
// =============================================================================

/// State for a sharded parameter: this rank's slice of the flattened data
/// plus the metadata needed to reconstruct the full tensor on all-gather.
#[derive(Debug)]
struct ShardedParam {
    /// Local shard of the parameter (flattened; length is the per-rank
    /// shard size, zero-padded when the element count divides unevenly)
    local_shard: Tensor<f32>,
    /// Original shape before sharding; used to restore the tensor on gather
    original_shape: Vec<usize>,
    /// Number of elements in original parameter (without padding)
    numel: usize,
    /// Padding added for even sharding (for uneven divisions).
    /// Used for debugging/diagnostics.
    #[allow(dead_code)]
    pub padding: usize,
}

/// CPU offload configuration.
///
/// NOTE(review): this setting is stored by the FSDP wrapper but not acted
/// upon anywhere in this file — offloading would need to be wired into the
/// gather/reshard path.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CPUOffload {
    /// No CPU offloading.
    #[default]
    None,
    /// Offload parameters to CPU when not in use.
    Params,
    /// Offload both parameters and gradients.
    Full,
}

// =============================================================================
// Fully Sharded Data Parallel
// =============================================================================

/// Fully Sharded Data Parallel wrapper for memory-efficient distributed training.
///
/// FSDP shards model parameters across devices, gathering them only when needed
/// for computation and sharding them again afterward.
pub struct FullyShardedDataParallel<M: Module> {
    /// Wrapped module
    module: M,
    /// Process group for communication
    process_group: ProcessGroup,
    /// Sharding strategy
    sharding_strategy: ShardingStrategy,
    /// CPU offload configuration (currently informational only)
    cpu_offload: CPUOffload,
    /// Sharded parameter states, one per module parameter
    sharded_params: Vec<ShardedParam>,
    /// Whether module is currently gathered (unsharded)
    is_gathered: bool,
    /// Mixed precision compute dtype (currently informational only)
    mixed_precision: bool,
}

97impl<M: Module> FullyShardedDataParallel<M> {
98    /// Creates a new FSDP wrapper.
99    pub fn new(module: M, process_group: ProcessGroup) -> Self {
100        let mut fsdp = Self {
101            module,
102            process_group,
103            sharding_strategy: ShardingStrategy::default(),
104            cpu_offload: CPUOffload::default(),
105            sharded_params: Vec::new(),
106            is_gathered: true,
107            mixed_precision: false,
108        };
109
110        // Initialize sharding
111        fsdp.shard_parameters();
112        fsdp
113    }
114
115    /// Builder: set sharding strategy.
116    pub fn sharding_strategy(mut self, strategy: ShardingStrategy) -> Self {
117        self.sharding_strategy = strategy;
118        self.shard_parameters();
119        self
120    }
121
122    /// Builder: set CPU offload configuration.
123    pub fn cpu_offload(mut self, offload: CPUOffload) -> Self {
124        self.cpu_offload = offload;
125        self
126    }
127
128    /// Builder: enable mixed precision.
129    pub fn mixed_precision(mut self, enabled: bool) -> Self {
130        self.mixed_precision = enabled;
131        self
132    }
133
134    /// Returns reference to wrapped module.
135    pub fn module(&self) -> &M {
136        &self.module
137    }
138
139    /// Returns mutable reference to wrapped module.
140    pub fn module_mut(&mut self) -> &mut M {
141        &mut self.module
142    }
143
144    /// Returns the process group.
145    pub fn process_group(&self) -> &ProcessGroup {
146        &self.process_group
147    }
148
149    /// Returns the sharding strategy.
150    pub fn strategy(&self) -> ShardingStrategy {
151        self.sharding_strategy
152    }
153
154    /// Shards parameters across devices.
155    fn shard_parameters(&mut self) {
156        if self.sharding_strategy == ShardingStrategy::NoShard {
157            return;
158        }
159
160        let world_size = self.process_group.world_size();
161        let rank = self.process_group.rank();
162
163        self.sharded_params.clear();
164
165        for param in self.module.parameters() {
166            let data = param.data();
167            let shape = data.shape().to_vec();
168            let numel = data.numel();
169
170            // Calculate shard size with padding for even division
171            let shard_size = numel.div_ceil(world_size);
172            let padding = shard_size * world_size - numel;
173
174            // Get local shard
175            let flat_data = data.to_vec();
176            let start = rank * shard_size;
177            let end = ((rank + 1) * shard_size).min(flat_data.len());
178
179            let mut shard_data: Vec<f32> = if start < flat_data.len() {
180                flat_data[start..end].to_vec()
181            } else {
182                vec![0.0; shard_size]
183            };
184
185            // Pad to shard_size
186            while shard_data.len() < shard_size {
187                shard_data.push(0.0);
188            }
189
190            self.sharded_params.push(ShardedParam {
191                local_shard: Tensor::from_vec(shard_data, &[shard_size]).unwrap(),
192                original_shape: shape,
193                numel,
194                padding,
195            });
196        }
197
198        self.is_gathered = false;
199    }
200
201    /// Gathers all parameter shards before forward pass.
202    pub fn gather_parameters(&mut self) {
203        if self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
204            return;
205        }
206
207        let params = self.module.parameters();
208
209        for (param, sharded) in params.iter().zip(self.sharded_params.iter()) {
210            // All-gather the shards
211            let gathered = self.process_group.all_gather_tensor(&sharded.local_shard);
212
213            // Reshape back to original shape (removing padding)
214            let flat: Vec<f32> = gathered.to_vec().into_iter().take(sharded.numel).collect();
215            let restored = Tensor::from_vec(flat, &sharded.original_shape).unwrap();
216
217            param.update_data(restored);
218        }
219
220        self.is_gathered = true;
221    }
222
223    /// Shards parameters after forward/backward pass.
224    pub fn reshard_parameters(&mut self) {
225        if !self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
226            return;
227        }
228
229        self.shard_parameters();
230    }
231
232    /// Synchronizes gradients across all ranks.
233    ///
234    /// For NoShard/HybridShard: all-reduces gradients so every rank has the
235    /// averaged gradient. For ShardGradOp/FullShard: reduce-scatters to produce
236    /// sharded gradient slices (each rank gets its shard of the averaged gradient).
237    pub fn sync_gradients(&self) {
238        match self.sharding_strategy {
239            ShardingStrategy::NoShard => {
240                // Full all-reduce like DDP
241                for param in self.module.parameters() {
242                    if let Some(grad) = param.grad() {
243                        let mut grad_tensor = grad.clone();
244                        self.process_group
245                            .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
246                        param.set_grad(grad_tensor);
247                    }
248                }
249            }
250            ShardingStrategy::ShardGradOp | ShardingStrategy::FullShard => {
251                // Reduce-scatter gradients to get sharded gradient slices
252                for param in self.module.parameters() {
253                    if let Some(grad) = param.grad() {
254                        let reduced = self
255                            .process_group
256                            .reduce_scatter_tensor(&grad, ReduceOp::Average);
257                        // Write the reduce-scattered gradient shard back
258                        param.set_grad(reduced);
259                    }
260                }
261            }
262            ShardingStrategy::HybridShard => {
263                // All-reduce within node, reduce-scatter across nodes
264                for param in self.module.parameters() {
265                    if let Some(grad) = param.grad() {
266                        let mut grad_tensor = grad.clone();
267                        self.process_group
268                            .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
269                        param.set_grad(grad_tensor);
270                    }
271                }
272            }
273        }
274    }
275
276    /// Clips gradients by global norm.
277    pub fn clip_grad_norm(&self, max_norm: f32) -> f32 {
278        let mut total_norm_sq = 0.0f32;
279
280        for param in self.module.parameters() {
281            if let Some(grad) = param.grad() {
282                let grad_vec = grad.to_vec();
283                let norm_sq: f32 = grad_vec.iter().map(|x| x * x).sum();
284                total_norm_sq += norm_sq;
285            }
286        }
287
288        // All-reduce total norm across ranks
289        let mut norm_tensor = Tensor::from_vec(vec![total_norm_sq], &[1]).unwrap();
290        self.process_group
291            .all_reduce_tensor(&mut norm_tensor, ReduceOp::Sum);
292        let global_norm = norm_tensor.to_vec()[0].sqrt();
293
294        // Clip if necessary
295        if global_norm > max_norm {
296            let clip_coef = max_norm / (global_norm + 1e-6);
297            for param in self.module.parameters() {
298                if let Some(grad) = param.grad() {
299                    let clipped: Vec<f32> = grad.to_vec().iter().map(|x| x * clip_coef).collect();
300                    let clipped_tensor = Tensor::from_vec(clipped, grad.shape()).unwrap();
301                    param.variable().set_grad(clipped_tensor);
302                }
303            }
304        }
305
306        global_norm
307    }
308
309    /// Estimates memory usage with different sharding strategies.
310    pub fn memory_estimate(&self) -> FSDPMemoryStats {
311        let params = self.module.parameters();
312        let total_params: usize = params.iter().map(|p| p.numel()).sum();
313        let world_size = self.process_group.world_size();
314
315        let bytes_per_param = 4; // f32
316        let param_memory = total_params * bytes_per_param;
317
318        let (sharded_params, sharded_grads, sharded_optim) = match self.sharding_strategy {
319            ShardingStrategy::NoShard => (param_memory, param_memory, param_memory * 2),
320            ShardingStrategy::ShardGradOp => (
321                param_memory,
322                param_memory / world_size,
323                param_memory * 2 / world_size,
324            ),
325            ShardingStrategy::FullShard | ShardingStrategy::HybridShard => (
326                param_memory / world_size,
327                param_memory / world_size,
328                param_memory * 2 / world_size,
329            ),
330        };
331
332        FSDPMemoryStats {
333            total_params,
334            param_memory_bytes: sharded_params,
335            grad_memory_bytes: sharded_grads,
336            optim_memory_bytes: sharded_optim,
337            world_size,
338        }
339    }
340}
341
impl<M: Module> Module for FullyShardedDataParallel<M> {
    /// Forwards the input through the wrapped module.
    ///
    /// NOTE(review): parameters are NOT gathered automatically here (the
    /// trait takes `&self`); callers must invoke `gather_parameters()`
    /// before `forward` and `reshard_parameters()` afterward.
    fn forward(&self, input: &Variable) -> Variable {
        // Note: In a real implementation, gather would be called automatically
        // through hooks before forward and reshard after
        self.module.forward(input)
    }

    /// Returns the wrapped module's parameters (full tensors when gathered,
    /// whatever the module currently holds otherwise).
    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }

    /// Puts the wrapped module into training mode.
    fn train(&mut self) {
        self.module.train();
    }

    /// Puts the wrapped module into evaluation mode.
    fn eval(&mut self) {
        self.module.eval();
    }

    /// Returns whether the wrapped module is in training mode.
    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}

/// Memory statistics for FSDP, as produced by
/// `FullyShardedDataParallel::memory_estimate` (f32 parameters assumed).
#[derive(Debug, Clone)]
pub struct FSDPMemoryStats {
    /// Total number of parameters
    pub total_params: usize,
    /// Memory for parameters (bytes, per rank)
    pub param_memory_bytes: usize,
    /// Memory for gradients (bytes, per rank)
    pub grad_memory_bytes: usize,
    /// Memory for optimizer state (bytes, per rank)
    pub optim_memory_bytes: usize,
    /// World size (number of ranks)
    pub world_size: usize,
}

381impl FSDPMemoryStats {
382    /// Total memory per rank in MB.
383    pub fn total_memory_mb(&self) -> f32 {
384        (self.param_memory_bytes + self.grad_memory_bytes + self.optim_memory_bytes) as f32
385            / (1024.0 * 1024.0)
386    }
387
388    /// Memory savings compared to no sharding.
389    pub fn memory_savings(&self) -> f32 {
390        if self.world_size > 1 {
391            1.0 - (1.0 / self.world_size as f32)
392        } else {
393            0.0
394        }
395    }
396}
397
// =============================================================================
// Tensor Parallelism
// =============================================================================

/// Column-parallel linear layer.
///
/// Splits the weight matrix along the column (output-feature) dimension
/// across ranks. Each rank computes a portion of the output features.
pub struct ColumnParallelLinear {
    /// Local weight shard, shape `[out_features / world_size, in_features]`
    weight: Parameter,
    /// Bias (replicated on all ranks), one entry per local output feature
    bias: Option<Parameter>,
    /// Process group used for the optional output all-gather
    process_group: ProcessGroup,
    /// Input features
    #[allow(dead_code)]
    in_features: usize,
    /// Output features (total across all ranks)
    #[allow(dead_code)]
    out_features: usize,
    /// Whether to all-gather the per-rank outputs into the full output
    gather_output: bool,
}

423impl ColumnParallelLinear {
424    /// Creates a new column-parallel linear layer.
425    pub fn new(
426        in_features: usize,
427        out_features: usize,
428        bias: bool,
429        process_group: ProcessGroup,
430        gather_output: bool,
431    ) -> Self {
432        let world_size = process_group.world_size();
433        let local_out_features = out_features / world_size;
434
435        let weight_data = Tensor::randn(&[local_out_features, in_features]);
436        let weight = Parameter::new(weight_data, true);
437
438        let bias = if bias {
439            let bias_data = Tensor::zeros(&[local_out_features]);
440            Some(Parameter::new(bias_data, true))
441        } else {
442            None
443        };
444
445        Self {
446            weight,
447            bias,
448            process_group,
449            in_features,
450            out_features,
451            gather_output,
452        }
453    }
454}
455
456impl Module for ColumnParallelLinear {
457    fn forward(&self, input: &Variable) -> Variable {
458        // Local matmul: input @ weight.T
459        let weight_var = Variable::new(self.weight.data(), false);
460        let output = input.matmul(&weight_var.transpose(0, 1));
461
462        // Add bias
463        let output = if let Some(ref bias) = self.bias {
464            let bias_var = Variable::new(bias.data(), false);
465            output.add(&bias_var)
466        } else {
467            output
468        };
469
470        // Optionally gather output across ranks
471        if self.gather_output {
472            let gathered = self.process_group.all_gather_tensor(&output.data());
473            Variable::new(gathered, output.requires_grad())
474        } else {
475            output
476        }
477    }
478
479    fn parameters(&self) -> Vec<Parameter> {
480        let mut params = vec![self.weight.clone()];
481        if let Some(ref bias) = self.bias {
482            params.push(bias.clone());
483        }
484        params
485    }
486}
487
/// Row-parallel linear layer.
///
/// Splits the weight matrix along the row (input-feature) dimension across
/// ranks. Each rank has a portion of the input features; partial outputs
/// are combined with an all-reduce.
pub struct RowParallelLinear {
    /// Local weight shard, shape `[out_features, in_features / world_size]`
    weight: Parameter,
    /// Bias (only on rank 0, added after the output all-reduce)
    bias: Option<Parameter>,
    /// Process group used for the output all-reduce
    process_group: ProcessGroup,
    /// Input features (total across all ranks)
    #[allow(dead_code)]
    in_features: usize,
    /// Output features
    #[allow(dead_code)]
    out_features: usize,
    /// Whether input is already split along the feature dimension
    input_is_parallel: bool,
}

509impl RowParallelLinear {
510    /// Creates a new row-parallel linear layer.
511    pub fn new(
512        in_features: usize,
513        out_features: usize,
514        bias: bool,
515        process_group: ProcessGroup,
516        input_is_parallel: bool,
517    ) -> Self {
518        let world_size = process_group.world_size();
519        let rank = process_group.rank();
520        let local_in_features = in_features / world_size;
521
522        let weight_data = Tensor::randn(&[out_features, local_in_features]);
523        let weight = Parameter::new(weight_data, true);
524
525        // Only rank 0 has bias
526        let bias = if bias && rank == 0 {
527            let bias_data = Tensor::zeros(&[out_features]);
528            Some(Parameter::new(bias_data, true))
529        } else {
530            None
531        };
532
533        Self {
534            weight,
535            bias,
536            process_group,
537            in_features,
538            out_features,
539            input_is_parallel,
540        }
541    }
542}
543
impl Module for RowParallelLinear {
    /// Computes the local partial matmul and all-reduces the partial
    /// outputs across ranks; the bias (rank 0 only) is added afterward.
    ///
    /// NOTE(review): the bias is added only on rank 0 with no subsequent
    /// broadcast, so for world_size > 1 the ranks' outputs would differ by
    /// the bias — confirm intended semantics.
    fn forward(&self, input: &Variable) -> Variable {
        // If input is not parallel, take local shard
        let local_input = if self.input_is_parallel {
            input.clone()
        } else {
            // Split input along feature dimension for row parallelism
            let world_size = self.process_group.world_size();
            let rank = self.process_group.rank();
            let data = input.data();
            let shape = data.shape();
            let feature_dim = shape[shape.len() - 1];
            let local_features = feature_dim / world_size;
            let start = rank * local_features;
            let end = start + local_features;

            // Slice the last dimension (2-D inputs only; other shapes fall
            // back to the full, unsliced tensor)
            let sliced = if shape.len() == 2 {
                data.slice(&[0..shape[0], start..end])
            } else {
                data.clone() // Fallback for other shapes
            };
            Variable::new(sliced, input.requires_grad())
        };

        // Local matmul
        let weight_var = Variable::new(self.weight.data(), false);
        let local_output = local_input.matmul(&weight_var.transpose(0, 1));

        // All-reduce to combine partial outputs
        // NOTE(review): wrapping the reduced tensor in a fresh Variable
        // detaches it from the autograd graph — confirm gradient flow.
        let mut output_data = local_output.data().clone();
        self.process_group
            .all_reduce_tensor(&mut output_data, ReduceOp::Sum);
        let output = Variable::new(output_data, local_output.requires_grad());

        // Add bias (only on rank 0, then broadcast)
        if let Some(ref bias) = self.bias {
            let bias_var = Variable::new(bias.data(), false);
            output.add(&bias_var)
        } else {
            output
        }
    }

    /// Returns the local weight shard and, on rank 0, the bias.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    // All tests use a mock single-rank process group, so collective ops
    // (all-gather/all-reduce/reduce-scatter) are effectively identities.

    #[test]
    fn test_sharding_strategy_default() {
        assert_eq!(ShardingStrategy::default(), ShardingStrategy::FullShard);
    }

    #[test]
    fn test_fsdp_creation() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let fsdp = FullyShardedDataParallel::new(model, pg);

        assert_eq!(fsdp.strategy(), ShardingStrategy::FullShard);
    }

    #[test]
    fn test_fsdp_forward() {
        let model = Linear::new(4, 2);
        let pg = ProcessGroup::mock();
        let mut fsdp = FullyShardedDataParallel::new(model, pg);

        // Gather before forward
        fsdp.gather_parameters();

        let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        let output = fsdp.forward(&input);

        assert_eq!(output.data().shape(), &[1, 2]);
    }

    #[test]
    fn test_fsdp_builder() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();

        let fsdp = FullyShardedDataParallel::new(model, pg)
            .sharding_strategy(ShardingStrategy::ShardGradOp)
            .cpu_offload(CPUOffload::Params)
            .mixed_precision(true);

        assert_eq!(fsdp.strategy(), ShardingStrategy::ShardGradOp);
    }

    #[test]
    fn test_fsdp_memory_stats() {
        let model = Linear::new(100, 50);
        let pg = ProcessGroup::mock();
        let fsdp = FullyShardedDataParallel::new(model, pg);

        let stats = fsdp.memory_estimate();
        assert!(stats.total_params > 0);
        assert!(stats.total_memory_mb() > 0.0);
    }

    #[test]
    fn test_fsdp_no_shard() {
        let model = Linear::new(10, 5);
        let pg = ProcessGroup::mock();
        let fsdp =
            FullyShardedDataParallel::new(model, pg).sharding_strategy(ShardingStrategy::NoShard);

        assert_eq!(fsdp.strategy(), ShardingStrategy::NoShard);
    }

    #[test]
    fn test_column_parallel_linear() {
        let pg = ProcessGroup::mock();
        // With world_size=1, local_out_features = out_features / 1 = 4
        let layer = ColumnParallelLinear::new(8, 4, true, pg, false); // Don't gather for simple test

        let input = Variable::new(Tensor::randn(&[2, 8]), false);
        let output = layer.forward(&input);

        // Output shape should be [batch, local_out_features] = [2, 4]
        assert_eq!(output.data().shape(), &[2, 4]);
    }

    #[test]
    fn test_row_parallel_linear() {
        let pg = ProcessGroup::mock();
        let layer = RowParallelLinear::new(8, 4, true, pg, false);

        let input = Variable::new(Tensor::randn(&[2, 8]), false);
        let output = layer.forward(&input);

        assert_eq!(output.data().shape(), &[2, 4]);
    }
}