Skip to main content

axonml_distributed/
fsdp.rs

//! FSDP - Fully Sharded Data Parallel
//!
//! # File
//! `crates/axonml-distributed/src/fsdp.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
17
18use crate::backend::ReduceOp;
19use crate::process_group::ProcessGroup;
20use axonml_autograd::Variable;
21use axonml_nn::{Module, Parameter};
22use axonml_tensor::Tensor;
23
// =============================================================================
// Sharding Strategy
// =============================================================================
27
/// How FSDP partitions model state across the ranks of a process group.
///
/// The variants mirror the ZeRO stages: the more that is sharded, the less
/// memory each rank needs and the more communication a step requires.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ShardingStrategy {
    /// ZeRO-3: parameters, gradients and optimizer state are all sharded.
    /// This is the default.
    #[default]
    FullShard,
    /// ZeRO-2: only gradients and optimizer state are sharded; every rank
    /// keeps the full parameters.
    ShardGradOp,
    /// Nothing is sharded; each rank holds a full replica (DDP-like).
    NoShard,
    /// Shard inside a node and replicate between nodes.
    HybridShard,
}
41
// =============================================================================
// FSDP State
// =============================================================================
45
46/// State for a sharded parameter.
47#[derive(Debug)]
48struct ShardedParam {
49    /// Local shard of the parameter
50    local_shard: Tensor<f32>,
51    /// Original shape before sharding
52    original_shape: Vec<usize>,
53    /// Number of elements in original parameter
54    numel: usize,
55    /// Padding added for even sharding (for uneven divisions).
56    /// Used for debugging/diagnostics.
57    #[allow(dead_code)]
58    pub padding: usize,
59}
60
/// Which tensors FSDP may move to host (CPU) memory between uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CPUOffload {
    /// Keep everything on the device (the default).
    #[default]
    None,
    /// Move parameters to CPU while they are not being used.
    Params,
    /// Move both parameters and gradients to CPU.
    Full,
}
72
// =============================================================================
// Fully Sharded Data Parallel
// =============================================================================
76
77/// Fully Sharded Data Parallel wrapper for memory-efficient distributed training.
78///
79/// FSDP shards model parameters across devices, gathering them only when needed
80/// for computation and sharding them again afterward.
81pub struct FullyShardedDataParallel<M: Module> {
82    /// Wrapped module
83    module: M,
84    /// Process group for communication
85    process_group: ProcessGroup,
86    /// Sharding strategy
87    sharding_strategy: ShardingStrategy,
88    /// CPU offload configuration
89    cpu_offload: CPUOffload,
90    /// Sharded parameter states
91    sharded_params: Vec<ShardedParam>,
92    /// Whether module is currently gathered (unsharded)
93    is_gathered: bool,
94    /// Mixed precision compute dtype
95    mixed_precision: bool,
96}
97
98impl<M: Module> FullyShardedDataParallel<M> {
99    /// Creates a new FSDP wrapper.
100    pub fn new(module: M, process_group: ProcessGroup) -> Self {
101        let mut fsdp = Self {
102            module,
103            process_group,
104            sharding_strategy: ShardingStrategy::default(),
105            cpu_offload: CPUOffload::default(),
106            sharded_params: Vec::new(),
107            is_gathered: true,
108            mixed_precision: false,
109        };
110
111        // Initialize sharding
112        fsdp.shard_parameters();
113        fsdp
114    }
115
116    /// Builder: set sharding strategy.
117    pub fn sharding_strategy(mut self, strategy: ShardingStrategy) -> Self {
118        self.sharding_strategy = strategy;
119        self.shard_parameters();
120        self
121    }
122
123    /// Builder: set CPU offload configuration.
124    pub fn cpu_offload(mut self, offload: CPUOffload) -> Self {
125        self.cpu_offload = offload;
126        self
127    }
128
129    /// Builder: enable mixed precision.
130    pub fn mixed_precision(mut self, enabled: bool) -> Self {
131        self.mixed_precision = enabled;
132        self
133    }
134
135    /// Returns reference to wrapped module.
136    pub fn module(&self) -> &M {
137        &self.module
138    }
139
140    /// Returns mutable reference to wrapped module.
141    pub fn module_mut(&mut self) -> &mut M {
142        &mut self.module
143    }
144
145    /// Returns the process group.
146    pub fn process_group(&self) -> &ProcessGroup {
147        &self.process_group
148    }
149
150    /// Returns the sharding strategy.
151    pub fn strategy(&self) -> ShardingStrategy {
152        self.sharding_strategy
153    }
154
155    /// Shards parameters across devices.
156    fn shard_parameters(&mut self) {
157        if self.sharding_strategy == ShardingStrategy::NoShard {
158            return;
159        }
160
161        let world_size = self.process_group.world_size();
162        let rank = self.process_group.rank();
163
164        self.sharded_params.clear();
165
166        for param in self.module.parameters() {
167            let data = param.data();
168            let shape = data.shape().to_vec();
169            let numel = data.numel();
170
171            // Calculate shard size with padding for even division
172            let shard_size = numel.div_ceil(world_size);
173            let padding = shard_size * world_size - numel;
174
175            // Get local shard
176            let flat_data = data.to_vec();
177            let start = rank * shard_size;
178            let end = ((rank + 1) * shard_size).min(flat_data.len());
179
180            let mut shard_data: Vec<f32> = if start < flat_data.len() {
181                flat_data[start..end].to_vec()
182            } else {
183                vec![0.0; shard_size]
184            };
185
186            // Pad to shard_size
187            while shard_data.len() < shard_size {
188                shard_data.push(0.0);
189            }
190
191            self.sharded_params.push(ShardedParam {
192                local_shard: Tensor::from_vec(shard_data, &[shard_size]).unwrap(),
193                original_shape: shape,
194                numel,
195                padding,
196            });
197        }
198
199        self.is_gathered = false;
200    }
201
202    /// Gathers all parameter shards before forward pass.
203    pub fn gather_parameters(&mut self) {
204        if self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
205            return;
206        }
207
208        let params = self.module.parameters();
209
210        for (param, sharded) in params.iter().zip(self.sharded_params.iter()) {
211            // All-gather the shards
212            let gathered = self.process_group.all_gather_tensor(&sharded.local_shard);
213
214            // Reshape back to original shape (removing padding)
215            let flat: Vec<f32> = gathered.to_vec().into_iter().take(sharded.numel).collect();
216            let restored = Tensor::from_vec(flat, &sharded.original_shape).unwrap();
217
218            param.update_data(restored);
219        }
220
221        self.is_gathered = true;
222    }
223
224    /// Shards parameters after forward/backward pass.
225    pub fn reshard_parameters(&mut self) {
226        if !self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
227            return;
228        }
229
230        self.shard_parameters();
231    }
232
233    /// Synchronizes gradients across all ranks.
234    ///
235    /// For NoShard/HybridShard: all-reduces gradients so every rank has the
236    /// averaged gradient. For ShardGradOp/FullShard: reduce-scatters to produce
237    /// sharded gradient slices (each rank gets its shard of the averaged gradient).
238    pub fn sync_gradients(&self) {
239        match self.sharding_strategy {
240            ShardingStrategy::NoShard => {
241                // Full all-reduce like DDP
242                for param in self.module.parameters() {
243                    if let Some(grad) = param.grad() {
244                        let mut grad_tensor = grad.clone();
245                        self.process_group
246                            .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
247                        param.set_grad(grad_tensor);
248                    }
249                }
250            }
251            ShardingStrategy::ShardGradOp | ShardingStrategy::FullShard => {
252                // Reduce-scatter gradients to get sharded gradient slices
253                for param in self.module.parameters() {
254                    if let Some(grad) = param.grad() {
255                        let reduced = self
256                            .process_group
257                            .reduce_scatter_tensor(&grad, ReduceOp::Average);
258                        // Write the reduce-scattered gradient shard back
259                        param.set_grad(reduced);
260                    }
261                }
262            }
263            ShardingStrategy::HybridShard => {
264                // All-reduce within node, reduce-scatter across nodes
265                for param in self.module.parameters() {
266                    if let Some(grad) = param.grad() {
267                        let mut grad_tensor = grad.clone();
268                        self.process_group
269                            .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
270                        param.set_grad(grad_tensor);
271                    }
272                }
273            }
274        }
275    }
276
277    /// Clips gradients by global norm.
278    pub fn clip_grad_norm(&self, max_norm: f32) -> f32 {
279        let mut total_norm_sq = 0.0f32;
280
281        for param in self.module.parameters() {
282            if let Some(grad) = param.grad() {
283                let grad_vec = grad.to_vec();
284                let norm_sq: f32 = grad_vec.iter().map(|x| x * x).sum();
285                total_norm_sq += norm_sq;
286            }
287        }
288
289        // All-reduce total norm across ranks
290        let mut norm_tensor = Tensor::from_vec(vec![total_norm_sq], &[1]).unwrap();
291        self.process_group
292            .all_reduce_tensor(&mut norm_tensor, ReduceOp::Sum);
293        let global_norm = norm_tensor.to_vec()[0].sqrt();
294
295        // Clip if necessary
296        if global_norm > max_norm {
297            let clip_coef = max_norm / (global_norm + 1e-6);
298            for param in self.module.parameters() {
299                if let Some(grad) = param.grad() {
300                    let clipped: Vec<f32> = grad.to_vec().iter().map(|x| x * clip_coef).collect();
301                    let clipped_tensor = Tensor::from_vec(clipped, grad.shape()).unwrap();
302                    param.variable().set_grad(clipped_tensor);
303                }
304            }
305        }
306
307        global_norm
308    }
309
310    /// Estimates memory usage with different sharding strategies.
311    pub fn memory_estimate(&self) -> FSDPMemoryStats {
312        let params = self.module.parameters();
313        let total_params: usize = params.iter().map(|p| p.numel()).sum();
314        let world_size = self.process_group.world_size();
315
316        let bytes_per_param = 4; // f32
317        let param_memory = total_params * bytes_per_param;
318
319        let (sharded_params, sharded_grads, sharded_optim) = match self.sharding_strategy {
320            ShardingStrategy::NoShard => (param_memory, param_memory, param_memory * 2),
321            ShardingStrategy::ShardGradOp => (
322                param_memory,
323                param_memory / world_size,
324                param_memory * 2 / world_size,
325            ),
326            ShardingStrategy::FullShard | ShardingStrategy::HybridShard => (
327                param_memory / world_size,
328                param_memory / world_size,
329                param_memory * 2 / world_size,
330            ),
331        };
332
333        FSDPMemoryStats {
334            total_params,
335            param_memory_bytes: sharded_params,
336            grad_memory_bytes: sharded_grads,
337            optim_memory_bytes: sharded_optim,
338            world_size,
339        }
340    }
341}
342
343impl<M: Module> Module for FullyShardedDataParallel<M> {
344    fn forward(&self, input: &Variable) -> Variable {
345        // Note: In a real implementation, gather would be called automatically
346        // through hooks before forward and reshard after
347        self.module.forward(input)
348    }
349
350    fn parameters(&self) -> Vec<Parameter> {
351        self.module.parameters()
352    }
353
354    fn train(&mut self) {
355        self.module.train();
356    }
357
358    fn eval(&mut self) {
359        self.module.eval();
360    }
361
362    fn is_training(&self) -> bool {
363        self.module.is_training()
364    }
365}
366
/// Per-rank memory statistics for FSDP.
///
/// As produced by `FullyShardedDataParallel::memory_estimate`, byte counts
/// assume 4-byte parameters and gradients, plus optimizer state of twice the
/// parameter size.
#[derive(Debug, Clone)]
pub struct FSDPMemoryStats {
    /// Total number of parameters
    pub total_params: usize,
    /// Memory for parameters (bytes)
    pub param_memory_bytes: usize,
    /// Memory for gradients (bytes)
    pub grad_memory_bytes: usize,
    /// Memory for optimizer state (bytes)
    pub optim_memory_bytes: usize,
    /// World size (number of ranks)
    pub world_size: usize,
}

impl FSDPMemoryStats {
    /// Bytes per parameter held by an unsharded replica: 4 for the parameter,
    /// 4 for its gradient and 8 for optimizer state.
    const UNSHARDED_BYTES_PER_PARAM: usize = 4 * (1 + 1 + 2);

    /// Total memory per rank in MB.
    pub fn total_memory_mb(&self) -> f32 {
        (self.param_memory_bytes + self.grad_memory_bytes + self.optim_memory_bytes) as f32
            / (1024.0 * 1024.0)
    }

    /// Fraction of per-rank memory saved compared to no sharding, in `[0, 1]`.
    ///
    /// Derived from the recorded byte counts rather than from `world_size`
    /// alone. As a result, `NoShard` reports 0 and `ShardGradOp` reports its
    /// partial savings, instead of every strategy claiming `1 - 1/world_size`.
    /// Returns 0 when there are no parameters.
    pub fn memory_savings(&self) -> f32 {
        let baseline = self
            .total_params
            .saturating_mul(Self::UNSHARDED_BYTES_PER_PARAM);
        if baseline == 0 {
            return 0.0;
        }
        let used = self.param_memory_bytes + self.grad_memory_bytes + self.optim_memory_bytes;
        (1.0 - used as f32 / baseline as f32).max(0.0)
    }
}
398
// =============================================================================
// Tensor Parallelism
// =============================================================================
402
403/// Column-parallel linear layer.
404///
405/// Splits the weight matrix along the column dimension across ranks.
406/// Each rank computes a portion of the output features.
407pub struct ColumnParallelLinear {
408    /// Local weight shard
409    weight: Parameter,
410    /// Bias (replicated on all ranks)
411    bias: Option<Parameter>,
412    /// Process group
413    process_group: ProcessGroup,
414    /// Input features
415    #[allow(dead_code)]
416    in_features: usize,
417    /// Output features (total across all ranks)
418    #[allow(dead_code)]
419    out_features: usize,
420    /// Whether to gather output
421    gather_output: bool,
422}
423
424impl ColumnParallelLinear {
425    /// Creates a new column-parallel linear layer.
426    pub fn new(
427        in_features: usize,
428        out_features: usize,
429        bias: bool,
430        process_group: ProcessGroup,
431        gather_output: bool,
432    ) -> Self {
433        let world_size = process_group.world_size();
434        let local_out_features = out_features / world_size;
435
436        let weight_data = Tensor::randn(&[local_out_features, in_features]);
437        let weight = Parameter::new(weight_data, true);
438
439        let bias = if bias {
440            let bias_data = Tensor::zeros(&[local_out_features]);
441            Some(Parameter::new(bias_data, true))
442        } else {
443            None
444        };
445
446        Self {
447            weight,
448            bias,
449            process_group,
450            in_features,
451            out_features,
452            gather_output,
453        }
454    }
455}
456
457impl Module for ColumnParallelLinear {
458    fn forward(&self, input: &Variable) -> Variable {
459        // Local matmul: input @ weight.T
460        let weight_var = Variable::new(self.weight.data(), false);
461        let output = input.matmul(&weight_var.transpose(0, 1));
462
463        // Add bias
464        let output = if let Some(ref bias) = self.bias {
465            let bias_var = Variable::new(bias.data(), false);
466            output.add(&bias_var)
467        } else {
468            output
469        };
470
471        // Optionally gather output across ranks
472        if self.gather_output {
473            let gathered = self.process_group.all_gather_tensor(&output.data());
474            Variable::new(gathered, output.requires_grad())
475        } else {
476            output
477        }
478    }
479
480    fn parameters(&self) -> Vec<Parameter> {
481        let mut params = vec![self.weight.clone()];
482        if let Some(ref bias) = self.bias {
483            params.push(bias.clone());
484        }
485        params
486    }
487}
488
489/// Row-parallel linear layer.
490///
491/// Splits the weight matrix along the row dimension across ranks.
492/// Each rank has a portion of the input features.
493pub struct RowParallelLinear {
494    /// Local weight shard
495    weight: Parameter,
496    /// Bias (only on rank 0)
497    bias: Option<Parameter>,
498    /// Process group
499    process_group: ProcessGroup,
500    /// Input features (total across all ranks)
501    #[allow(dead_code)]
502    in_features: usize,
503    /// Output features
504    #[allow(dead_code)]
505    out_features: usize,
506    /// Whether input is already split
507    input_is_parallel: bool,
508}
509
510impl RowParallelLinear {
511    /// Creates a new row-parallel linear layer.
512    pub fn new(
513        in_features: usize,
514        out_features: usize,
515        bias: bool,
516        process_group: ProcessGroup,
517        input_is_parallel: bool,
518    ) -> Self {
519        let world_size = process_group.world_size();
520        let rank = process_group.rank();
521        let local_in_features = in_features / world_size;
522
523        let weight_data = Tensor::randn(&[out_features, local_in_features]);
524        let weight = Parameter::new(weight_data, true);
525
526        // Only rank 0 has bias
527        let bias = if bias && rank == 0 {
528            let bias_data = Tensor::zeros(&[out_features]);
529            Some(Parameter::new(bias_data, true))
530        } else {
531            None
532        };
533
534        Self {
535            weight,
536            bias,
537            process_group,
538            in_features,
539            out_features,
540            input_is_parallel,
541        }
542    }
543}
544
545impl Module for RowParallelLinear {
546    fn forward(&self, input: &Variable) -> Variable {
547        // If input is not parallel, take local shard
548        let local_input = if self.input_is_parallel {
549            input.clone()
550        } else {
551            // Split input along feature dimension for row parallelism
552            let world_size = self.process_group.world_size();
553            let rank = self.process_group.rank();
554            let data = input.data();
555            let shape = data.shape();
556            let feature_dim = shape[shape.len() - 1];
557            let local_features = feature_dim / world_size;
558            let start = rank * local_features;
559            let end = start + local_features;
560
561            // Slice the last dimension
562            let sliced = if shape.len() == 2 {
563                data.slice(&[0..shape[0], start..end])
564            } else {
565                data.clone() // Fallback for other shapes
566            };
567            Variable::new(sliced, input.requires_grad())
568        };
569
570        // Local matmul
571        let weight_var = Variable::new(self.weight.data(), false);
572        let local_output = local_input.matmul(&weight_var.transpose(0, 1));
573
574        // All-reduce to combine partial outputs
575        let mut output_data = local_output.data().clone();
576        self.process_group
577            .all_reduce_tensor(&mut output_data, ReduceOp::Sum);
578        let output = Variable::new(output_data, local_output.requires_grad());
579
580        // Add bias (only on rank 0, then broadcast)
581        if let Some(ref bias) = self.bias {
582            let bias_var = Variable::new(bias.data(), false);
583            output.add(&bias_var)
584        } else {
585            output
586        }
587    }
588
589    fn parameters(&self) -> Vec<Parameter> {
590        let mut params = vec![self.weight.clone()];
591        if let Some(ref bias) = self.bias {
592            params.push(bias.clone());
593        }
594        params
595    }
596}
597
// =============================================================================
// Tests
// =============================================================================
601
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    #[test]
    fn test_sharding_strategy_default() {
        let strategy: ShardingStrategy = Default::default();
        assert_eq!(strategy, ShardingStrategy::FullShard);
    }

    #[test]
    fn test_fsdp_creation() {
        let wrapped = FullyShardedDataParallel::new(Linear::new(10, 5), ProcessGroup::mock());

        assert!(matches!(wrapped.strategy(), ShardingStrategy::FullShard));
    }

    #[test]
    fn test_fsdp_forward() {
        let mut wrapped = FullyShardedDataParallel::new(Linear::new(4, 2), ProcessGroup::mock());

        // Parameters must be materialised in full before running the module.
        wrapped.gather_parameters();

        let ones = Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap();
        let result = wrapped.forward(&Variable::new(ones, false));

        assert_eq!(result.data().shape(), &[1, 2]);
    }

    #[test]
    fn test_fsdp_builder() {
        let configured = FullyShardedDataParallel::new(Linear::new(10, 5), ProcessGroup::mock())
            .sharding_strategy(ShardingStrategy::ShardGradOp)
            .cpu_offload(CPUOffload::Params)
            .mixed_precision(true);

        assert!(matches!(configured.strategy(), ShardingStrategy::ShardGradOp));
    }

    #[test]
    fn test_fsdp_memory_stats() {
        let wrapped = FullyShardedDataParallel::new(Linear::new(100, 50), ProcessGroup::mock());

        let estimate = wrapped.memory_estimate();
        assert!(estimate.total_params > 0, "a Linear layer has parameters");
        assert!(estimate.total_memory_mb() > 0.0, "non-empty model uses memory");
    }

    #[test]
    fn test_fsdp_no_shard() {
        let replicated = FullyShardedDataParallel::new(Linear::new(10, 5), ProcessGroup::mock())
            .sharding_strategy(ShardingStrategy::NoShard);

        assert!(matches!(replicated.strategy(), ShardingStrategy::NoShard));
    }

    #[test]
    fn test_column_parallel_linear() {
        // The mock group has one rank, so the local width equals out_features (4).
        // Output is not gathered to keep the check local.
        let layer = ColumnParallelLinear::new(8, 4, true, ProcessGroup::mock(), false);

        let batch = Variable::new(Tensor::randn(&[2, 8]), false);
        let result = layer.forward(&batch);

        // [batch, local_out_features]
        assert_eq!(result.data().shape(), &[2, 4]);
    }

    #[test]
    fn test_row_parallel_linear() {
        let layer = RowParallelLinear::new(8, 4, true, ProcessGroup::mock(), false);

        let batch = Variable::new(Tensor::randn(&[2, 8]), false);
        let result = layer.forward(&batch);

        assert_eq!(result.data().shape(), &[2, 4]);
    }
}