// axonml_distributed/fsdp.rs
//! FSDP - Fully Sharded Data Parallel
//!
//! # File
//! `crates/axonml-distributed/src/fsdp.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml_autograd::Variable;
use axonml_nn::{Module, Parameter};
use axonml_tensor::Tensor;

use crate::backend::ReduceOp;
use crate::process_group::ProcessGroup;

// =============================================================================
// Sharding Strategy
// =============================================================================

/// Strategy for sharding parameters in FSDP.
///
/// The variants mirror the ZeRO optimization stages: `FullShard` (ZeRO-3)
/// shards the most state and uses the least memory per rank, while `NoShard`
/// replicates everything like classic DDP.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ShardingStrategy {
    /// Shard parameters, gradients, and optimizer state (ZeRO-3).
    #[default]
    FullShard,
    /// Shard gradients and optimizer state only; parameters stay replicated (ZeRO-2).
    ShardGradOp,
    /// No sharding; everything is replicated across ranks (DDP-like).
    NoShard,
    /// Hybrid: shard within a node, replicate across nodes.
    HybridShard,
}

// =============================================================================
// FSDP State
// =============================================================================

/// Per-parameter bookkeeping for a sharded parameter.
///
/// Each rank holds only a 1-D slice (`local_shard`) of the flattened
/// parameter; `original_shape` and `numel` allow the full tensor to be
/// reconstructed (and the zero padding dropped) after an all-gather.
#[derive(Debug)]
struct ShardedParam {
    /// This rank's 1-D slice of the flattened parameter (zero-padded).
    local_shard: Tensor<f32>,
    /// Original shape before flattening/sharding.
    original_shape: Vec<usize>,
    /// Number of elements in the original (unpadded) parameter.
    numel: usize,
    /// Zero padding added so the total length divides evenly by world size.
    /// Kept for debugging/diagnostics; reconstruction uses `numel` instead.
    pub padding: usize,
}
/// CPU offload configuration.
///
/// NOTE(review): this setting is currently stored by
/// `FullyShardedDataParallel` but not acted upon anywhere in this file —
/// confirm offloading is implemented elsewhere before relying on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CPUOffload {
    /// No CPU offloading.
    #[default]
    None,
    /// Offload parameters to CPU when not in use.
    Params,
    /// Offload both parameters and gradients.
    Full,
}

// =============================================================================
// Fully Sharded Data Parallel
// =============================================================================

/// Fully Sharded Data Parallel wrapper for memory-efficient distributed training.
///
/// FSDP shards model parameters across devices, gathering them only when needed
/// for computation and sharding them again afterward.
pub struct FullyShardedDataParallel<M: Module> {
    /// Wrapped module; its own parameter tensors are not modified by sharding
    /// (shard bookkeeping lives in `sharded_params`).
    module: M,
    /// Process group used for all collective communication.
    process_group: ProcessGroup,
    /// Active sharding strategy (ZeRO-stage equivalent).
    sharding_strategy: ShardingStrategy,
    /// CPU offload configuration (stored; see note on [`CPUOffload`]).
    cpu_offload: CPUOffload,
    /// Per-parameter shard state, parallel to `module.parameters()` order.
    sharded_params: Vec<ShardedParam>,
    /// Whether module parameters are currently gathered (unsharded).
    is_gathered: bool,
    /// Whether mixed-precision compute is enabled (stored; not used in this file).
    mixed_precision: bool,
}
96impl<M: Module> FullyShardedDataParallel<M> {
97    /// Creates a new FSDP wrapper.
98    pub fn new(module: M, process_group: ProcessGroup) -> Self {
99        let mut fsdp = Self {
100            module,
101            process_group,
102            sharding_strategy: ShardingStrategy::default(),
103            cpu_offload: CPUOffload::default(),
104            sharded_params: Vec::new(),
105            is_gathered: true,
106            mixed_precision: false,
107        };
108
109        // Initialize sharding
110        fsdp.shard_parameters();
111        fsdp
112    }
113
114    /// Builder: set sharding strategy.
115    pub fn sharding_strategy(mut self, strategy: ShardingStrategy) -> Self {
116        self.sharding_strategy = strategy;
117        self.shard_parameters();
118        self
119    }
120
121    /// Builder: set CPU offload configuration.
122    pub fn cpu_offload(mut self, offload: CPUOffload) -> Self {
123        self.cpu_offload = offload;
124        self
125    }
126
127    /// Builder: enable mixed precision.
128    pub fn mixed_precision(mut self, enabled: bool) -> Self {
129        self.mixed_precision = enabled;
130        self
131    }
132
133    /// Returns reference to wrapped module.
134    pub fn module(&self) -> &M {
135        &self.module
136    }
137
138    /// Returns mutable reference to wrapped module.
139    pub fn module_mut(&mut self) -> &mut M {
140        &mut self.module
141    }
142
143    /// Returns the process group.
144    pub fn process_group(&self) -> &ProcessGroup {
145        &self.process_group
146    }
147
148    /// Returns the sharding strategy.
149    pub fn strategy(&self) -> ShardingStrategy {
150        self.sharding_strategy
151    }
152
153    /// Shards parameters across devices.
154    fn shard_parameters(&mut self) {
155        if self.sharding_strategy == ShardingStrategy::NoShard {
156            return;
157        }
158
159        let world_size = self.process_group.world_size();
160        let rank = self.process_group.rank();
161
162        self.sharded_params.clear();
163
164        for param in self.module.parameters() {
165            let data = param.data();
166            let shape = data.shape().to_vec();
167            let numel = data.numel();
168
169            // Calculate shard size with padding for even division
170            let shard_size = numel.div_ceil(world_size);
171            let padding = shard_size * world_size - numel;
172
173            // Get local shard
174            let flat_data = data.to_vec();
175            let start = rank * shard_size;
176            let end = ((rank + 1) * shard_size).min(flat_data.len());
177
178            let mut shard_data: Vec<f32> = if start < flat_data.len() {
179                flat_data[start..end].to_vec()
180            } else {
181                vec![0.0; shard_size]
182            };
183
184            // Pad to shard_size
185            while shard_data.len() < shard_size {
186                shard_data.push(0.0);
187            }
188
189            self.sharded_params.push(ShardedParam {
190                local_shard: Tensor::from_vec(shard_data, &[shard_size]).unwrap(),
191                original_shape: shape,
192                numel,
193                padding,
194            });
195        }
196
197        self.is_gathered = false;
198    }
199
200    /// Gathers all parameter shards before forward pass.
201    pub fn gather_parameters(&mut self) {
202        if self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
203            return;
204        }
205
206        let params = self.module.parameters();
207
208        for (param, sharded) in params.iter().zip(self.sharded_params.iter()) {
209            // All-gather the shards
210            let gathered = self.process_group.all_gather_tensor(&sharded.local_shard);
211
212            // Reshape back to original shape (removing padding)
213            let flat: Vec<f32> = gathered.to_vec().into_iter().take(sharded.numel).collect();
214            let restored = Tensor::from_vec(flat, &sharded.original_shape).unwrap();
215
216            param.update_data(restored);
217        }
218
219        self.is_gathered = true;
220    }
221
222    /// Shards parameters after forward/backward pass.
223    pub fn reshard_parameters(&mut self) {
224        if !self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
225            return;
226        }
227
228        self.shard_parameters();
229    }
230
231    /// Synchronizes gradients across all ranks.
232    ///
233    /// For NoShard/HybridShard: all-reduces gradients so every rank has the
234    /// averaged gradient. For ShardGradOp/FullShard: reduce-scatters to produce
235    /// sharded gradient slices (each rank gets its shard of the averaged gradient).
236    pub fn sync_gradients(&self) {
237        match self.sharding_strategy {
238            ShardingStrategy::NoShard => {
239                // Full all-reduce like DDP
240                for param in self.module.parameters() {
241                    if let Some(grad) = param.grad() {
242                        let mut grad_tensor = grad.clone();
243                        self.process_group
244                            .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
245                        param.set_grad(grad_tensor);
246                    }
247                }
248            }
249            ShardingStrategy::ShardGradOp | ShardingStrategy::FullShard => {
250                // Reduce-scatter gradients to get sharded gradient slices
251                for param in self.module.parameters() {
252                    if let Some(grad) = param.grad() {
253                        let reduced = self
254                            .process_group
255                            .reduce_scatter_tensor(&grad, ReduceOp::Average);
256                        // Write the reduce-scattered gradient shard back
257                        param.set_grad(reduced);
258                    }
259                }
260            }
261            ShardingStrategy::HybridShard => {
262                // All-reduce within node, reduce-scatter across nodes
263                for param in self.module.parameters() {
264                    if let Some(grad) = param.grad() {
265                        let mut grad_tensor = grad.clone();
266                        self.process_group
267                            .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
268                        param.set_grad(grad_tensor);
269                    }
270                }
271            }
272        }
273    }
274
275    /// Clips gradients by global norm.
276    pub fn clip_grad_norm(&self, max_norm: f32) -> f32 {
277        let mut total_norm_sq = 0.0f32;
278
279        for param in self.module.parameters() {
280            if let Some(grad) = param.grad() {
281                let grad_vec = grad.to_vec();
282                let norm_sq: f32 = grad_vec.iter().map(|x| x * x).sum();
283                total_norm_sq += norm_sq;
284            }
285        }
286
287        // All-reduce total norm across ranks
288        let mut norm_tensor = Tensor::from_vec(vec![total_norm_sq], &[1]).unwrap();
289        self.process_group
290            .all_reduce_tensor(&mut norm_tensor, ReduceOp::Sum);
291        let global_norm = norm_tensor.to_vec()[0].sqrt();
292
293        // Clip if necessary
294        if global_norm > max_norm {
295            let clip_coef = max_norm / (global_norm + 1e-6);
296            for param in self.module.parameters() {
297                if let Some(grad) = param.grad() {
298                    let clipped: Vec<f32> = grad.to_vec().iter().map(|x| x * clip_coef).collect();
299                    let clipped_tensor = Tensor::from_vec(clipped, grad.shape()).unwrap();
300                    param.variable().set_grad(clipped_tensor);
301                }
302            }
303        }
304
305        global_norm
306    }
307
308    /// Estimates memory usage with different sharding strategies.
309    pub fn memory_estimate(&self) -> FSDPMemoryStats {
310        let params = self.module.parameters();
311        let total_params: usize = params.iter().map(|p| p.numel()).sum();
312        let world_size = self.process_group.world_size();
313
314        let bytes_per_param = 4; // f32
315        let param_memory = total_params * bytes_per_param;
316
317        let (sharded_params, sharded_grads, sharded_optim) = match self.sharding_strategy {
318            ShardingStrategy::NoShard => (param_memory, param_memory, param_memory * 2),
319            ShardingStrategy::ShardGradOp => (
320                param_memory,
321                param_memory / world_size,
322                param_memory * 2 / world_size,
323            ),
324            ShardingStrategy::FullShard | ShardingStrategy::HybridShard => (
325                param_memory / world_size,
326                param_memory / world_size,
327                param_memory * 2 / world_size,
328            ),
329        };
330
331        FSDPMemoryStats {
332            total_params,
333            param_memory_bytes: sharded_params,
334            grad_memory_bytes: sharded_grads,
335            optim_memory_bytes: sharded_optim,
336            world_size,
337        }
338    }
339}
impl<M: Module> Module for FullyShardedDataParallel<M> {
    /// Delegates the forward pass to the wrapped module.
    ///
    /// Callers must invoke `gather_parameters()` first when sharded; a full
    /// implementation would gather before forward and reshard afterward
    /// automatically via hooks.
    fn forward(&self, input: &Variable) -> Variable {
        self.module.forward(input)
    }

    /// Returns the wrapped module's parameters (full tensors; shard
    /// bookkeeping is kept separately in the wrapper).
    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }

    /// Puts the wrapped module into training mode.
    fn train(&mut self) {
        self.module.train();
    }

    /// Puts the wrapped module into evaluation mode.
    fn eval(&mut self) {
        self.module.eval();
    }

    /// Reports whether the wrapped module is in training mode.
    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}
/// Per-rank memory statistics reported by FSDP.
#[derive(Debug, Clone)]
pub struct FSDPMemoryStats {
    /// Total number of parameters in the wrapped module.
    pub total_params: usize,
    /// Bytes held on this rank for parameters.
    pub param_memory_bytes: usize,
    /// Bytes held on this rank for gradients.
    pub grad_memory_bytes: usize,
    /// Bytes held on this rank for optimizer state.
    pub optim_memory_bytes: usize,
    /// Number of ranks in the process group.
    pub world_size: usize,
}

impl FSDPMemoryStats {
    /// Total memory footprint per rank, in megabytes.
    pub fn total_memory_mb(&self) -> f32 {
        let total_bytes =
            self.param_memory_bytes + self.grad_memory_bytes + self.optim_memory_bytes;
        total_bytes as f32 / (1024.0 * 1024.0)
    }

    /// Fractional memory saving versus a fully replicated (no-shard) setup.
    pub fn memory_savings(&self) -> f32 {
        match self.world_size {
            0 | 1 => 0.0,
            ws => 1.0 - 1.0 / ws as f32,
        }
    }
}

// =============================================================================
// Tensor Parallelism
// =============================================================================

/// Column-parallel linear layer.
///
/// Splits the weight matrix along the output-feature (column) dimension
/// across ranks; each rank computes `out_features / world_size` of the
/// output features, optionally all-gathering them into the full output.
pub struct ColumnParallelLinear {
    /// Local weight shard, shape `[out_features / world_size, in_features]`.
    weight: Parameter,
    /// Bias for the local output slice (created on every rank).
    bias: Option<Parameter>,
    /// Process group used for the optional output all-gather.
    process_group: ProcessGroup,
    /// Input features (same on every rank).
    in_features: usize,
    /// Output features (total across all ranks).
    out_features: usize,
    /// Whether `forward` all-gathers the output across ranks.
    gather_output: bool,
}
420impl ColumnParallelLinear {
421    /// Creates a new column-parallel linear layer.
422    pub fn new(
423        in_features: usize,
424        out_features: usize,
425        bias: bool,
426        process_group: ProcessGroup,
427        gather_output: bool,
428    ) -> Self {
429        let world_size = process_group.world_size();
430        let local_out_features = out_features / world_size;
431
432        let weight_data = Tensor::randn(&[local_out_features, in_features]);
433        let weight = Parameter::new(weight_data, true);
434
435        let bias = if bias {
436            let bias_data = Tensor::zeros(&[local_out_features]);
437            Some(Parameter::new(bias_data, true))
438        } else {
439            None
440        };
441
442        Self {
443            weight,
444            bias,
445            process_group,
446            in_features,
447            out_features,
448            gather_output,
449        }
450    }
451}
impl Module for ColumnParallelLinear {
    /// Computes the local output slice `input @ weight_shard.T (+ bias)`,
    /// optionally all-gathering the slices into the full output.
    fn forward(&self, input: &Variable) -> Variable {
        // Local matmul: input @ weight.T
        // NOTE(review): the weight is wrapped with requires_grad = false,
        // which looks like it would stop gradients flowing to the weight
        // shard — confirm how Parameter/Variable share gradient state.
        let weight_var = Variable::new(self.weight.data(), false);
        let output = input.matmul(&weight_var.transpose(0, 1));

        // Add the local slice's bias, if present.
        let output = if let Some(ref bias) = self.bias {
            let bias_var = Variable::new(bias.data(), false);
            output.add(&bias_var)
        } else {
            output
        };

        // Optionally gather output slices from all ranks.
        // NOTE(review): rebuilding a fresh Variable from the gathered tensor
        // appears to detach the result from the autograd graph — verify that
        // backward works through the gathered path.
        if self.gather_output {
            let gathered = self.process_group.all_gather_tensor(&output.data());
            Variable::new(gathered, output.requires_grad())
        } else {
            output
        }
    }

    /// Returns the local weight shard and (if present) the local bias.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }
}
/// Row-parallel linear layer.
///
/// Splits the weight matrix along the input-feature (row) dimension across
/// ranks; each rank multiplies its input slice by its weight shard and the
/// partial outputs are summed with an all-reduce.
pub struct RowParallelLinear {
    /// Local weight shard, shape `[out_features, in_features / world_size]`.
    weight: Parameter,
    /// Bias, allocated only on rank 0 (added after the all-reduce).
    bias: Option<Parameter>,
    /// Process group used for input slicing and the output all-reduce.
    process_group: ProcessGroup,
    /// Input features (total across all ranks).
    in_features: usize,
    /// Output features (same on every rank).
    out_features: usize,
    /// Whether the input is already split along the feature dimension.
    input_is_parallel: bool,
}
504impl RowParallelLinear {
505    /// Creates a new row-parallel linear layer.
506    pub fn new(
507        in_features: usize,
508        out_features: usize,
509        bias: bool,
510        process_group: ProcessGroup,
511        input_is_parallel: bool,
512    ) -> Self {
513        let world_size = process_group.world_size();
514        let rank = process_group.rank();
515        let local_in_features = in_features / world_size;
516
517        let weight_data = Tensor::randn(&[out_features, local_in_features]);
518        let weight = Parameter::new(weight_data, true);
519
520        // Only rank 0 has bias
521        let bias = if bias && rank == 0 {
522            let bias_data = Tensor::zeros(&[out_features]);
523            Some(Parameter::new(bias_data, true))
524        } else {
525            None
526        };
527
528        Self {
529            weight,
530            bias,
531            process_group,
532            in_features,
533            out_features,
534            input_is_parallel,
535        }
536    }
537}
impl Module for RowParallelLinear {
    /// Multiplies the local input slice by the local weight shard, then
    /// all-reduces (sums) the partial outputs across ranks.
    fn forward(&self, input: &Variable) -> Variable {
        // If input is not already parallel, take this rank's feature slice.
        let local_input = if self.input_is_parallel {
            input.clone()
        } else {
            // Split input along the last (feature) dimension.
            let world_size = self.process_group.world_size();
            let rank = self.process_group.rank();
            let data = input.data();
            let shape = data.shape();
            let feature_dim = shape[shape.len() - 1];
            let local_features = feature_dim / world_size;
            let start = rank * local_features;
            let end = start + local_features;

            // Slice the last dimension.
            // NOTE(review): only 2-D inputs are sliced; other ranks fall back
            // to the full tensor, whose feature count will not match the
            // weight shard when world_size > 1 — confirm the intended scope.
            let sliced = if shape.len() == 2 {
                data.slice(&[0..shape[0], start..end])
            } else {
                data.clone() // Fallback for other shapes
            };
            Variable::new(sliced, input.requires_grad())
        };

        // Local partial matmul: input_slice @ weight_shard.T
        let weight_var = Variable::new(self.weight.data(), false);
        let local_output = local_input.matmul(&weight_var.transpose(0, 1));

        // Sum partial outputs from all ranks.
        // NOTE(review): wrapping the reduced tensor in a fresh Variable
        // appears to detach it from the autograd graph — verify backward.
        let mut output_data = local_output.data().clone();
        self.process_group
            .all_reduce_tensor(&mut output_data, ReduceOp::Sum);
        let output = Variable::new(output_data, local_output.requires_grad());

        // Add bias. NOTE(review): only rank 0 holds a bias and no broadcast
        // is performed here, so non-zero ranks return an un-biased output —
        // confirm whether a broadcast is expected elsewhere.
        if let Some(ref bias) = self.bias {
            let bias_var = Variable::new(bias.data(), false);
            output.add(&bias_var)
        } else {
            output
        }
    }

    /// Returns the local weight shard and (on rank 0) the bias.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    #[test]
    fn test_sharding_strategy_default() {
        // FullShard (ZeRO-3) is the default strategy.
        assert_eq!(ShardingStrategy::default(), ShardingStrategy::FullShard);
    }

    #[test]
    fn test_fsdp_creation() {
        let net = Linear::new(10, 5);
        let group = ProcessGroup::mock();
        let wrapped = FullyShardedDataParallel::new(net, group);

        assert_eq!(wrapped.strategy(), ShardingStrategy::FullShard);
    }

    #[test]
    fn test_fsdp_forward() {
        let net = Linear::new(4, 2);
        let group = ProcessGroup::mock();
        let mut wrapped = FullyShardedDataParallel::new(net, group);

        // Parameters must be gathered before running forward.
        wrapped.gather_parameters();

        let x = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        let y = wrapped.forward(&x);

        assert_eq!(y.data().shape(), &[1, 2]);
    }

    #[test]
    fn test_fsdp_builder() {
        let net = Linear::new(10, 5);
        let group = ProcessGroup::mock();

        let wrapped = FullyShardedDataParallel::new(net, group)
            .sharding_strategy(ShardingStrategy::ShardGradOp)
            .cpu_offload(CPUOffload::Params)
            .mixed_precision(true);

        assert_eq!(wrapped.strategy(), ShardingStrategy::ShardGradOp);
    }

    #[test]
    fn test_fsdp_memory_stats() {
        let net = Linear::new(100, 50);
        let group = ProcessGroup::mock();
        let wrapped = FullyShardedDataParallel::new(net, group);

        let stats = wrapped.memory_estimate();
        assert!(stats.total_params > 0);
        assert!(stats.total_memory_mb() > 0.0);
    }

    #[test]
    fn test_fsdp_no_shard() {
        let net = Linear::new(10, 5);
        let group = ProcessGroup::mock();
        let wrapped =
            FullyShardedDataParallel::new(net, group).sharding_strategy(ShardingStrategy::NoShard);

        assert_eq!(wrapped.strategy(), ShardingStrategy::NoShard);
    }

    #[test]
    fn test_column_parallel_linear() {
        let group = ProcessGroup::mock();
        // Mock group has world_size = 1, so local_out_features == 4.
        // gather_output = false keeps the test single-rank.
        let layer = ColumnParallelLinear::new(8, 4, true, group, false);

        let x = Variable::new(Tensor::randn(&[2, 8]), false);
        let y = layer.forward(&x);

        // Output shape is [batch, local_out_features] = [2, 4].
        assert_eq!(y.data().shape(), &[2, 4]);
    }

    #[test]
    fn test_row_parallel_linear() {
        let group = ProcessGroup::mock();
        let layer = RowParallelLinear::new(8, 4, true, group, false);

        let x = Variable::new(Tensor::randn(&[2, 8]), false);
        let y = layer.forward(&x);

        assert_eq!(y.data().shape(), &[2, 4]);
    }
}