//! FSDP - Fully Sharded Data Parallel
//!
//! # File
//! `crates/axonml-distributed/src/fsdp.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use crate::backend::ReduceOp;
use crate::process_group::ProcessGroup;
use axonml_autograd::Variable;
use axonml_nn::{Module, Parameter};
use axonml_tensor::Tensor;

// =============================================================================
// Sharding Strategy
// =============================================================================
/// Strategy for sharding parameters in FSDP.
///
/// Mirrors the DeepSpeed ZeRO stages: `FullShard` corresponds to ZeRO-3,
/// `ShardGradOp` to ZeRO-2, and `NoShard` to plain DDP-style replication.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ShardingStrategy {
    /// Shard parameters, gradients, and optimizer state (ZeRO-3).
    #[default]
    FullShard,
    /// Shard gradients and optimizer state only (ZeRO-2).
    ShardGradOp,
    /// No sharding; everything is replicated across ranks (DDP-like).
    NoShard,
    /// Shard within a node, replicate across nodes.
    HybridShard,
}

// =============================================================================
// FSDP State
// =============================================================================
45/// State for a sharded parameter.
46#[derive(Debug)]
47#[allow(dead_code)]
48struct ShardedParam {
49    /// Local shard of the parameter
50    local_shard: Tensor<f32>,
51    /// Original shape before sharding
52    original_shape: Vec<usize>,
53    /// Number of elements in original parameter
54    numel: usize,
55    /// Padding added for even sharding (for uneven divisions)
56    padding: usize,
57}
/// CPU offload configuration.
///
/// Controls which tensors may be moved to host memory when not in active use.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CPUOffload {
    /// Keep everything on the accelerator (no offloading).
    #[default]
    None,
    /// Offload parameters to CPU when not in use.
    Params,
    /// Offload both parameters and gradients to CPU.
    Full,
}

// =============================================================================
// Fully Sharded Data Parallel
// =============================================================================
75/// Fully Sharded Data Parallel wrapper for memory-efficient distributed training.
76///
77/// FSDP shards model parameters across devices, gathering them only when needed
78/// for computation and sharding them again afterward.
79pub struct FullyShardedDataParallel<M: Module> {
80    /// Wrapped module
81    module: M,
82    /// Process group for communication
83    process_group: ProcessGroup,
84    /// Sharding strategy
85    sharding_strategy: ShardingStrategy,
86    /// CPU offload configuration
87    cpu_offload: CPUOffload,
88    /// Sharded parameter states
89    sharded_params: Vec<ShardedParam>,
90    /// Whether module is currently gathered (unsharded)
91    is_gathered: bool,
92    /// Mixed precision compute dtype
93    mixed_precision: bool,
94}
96impl<M: Module> FullyShardedDataParallel<M> {
97    /// Creates a new FSDP wrapper.
98    pub fn new(module: M, process_group: ProcessGroup) -> Self {
99        let mut fsdp = Self {
100            module,
101            process_group,
102            sharding_strategy: ShardingStrategy::default(),
103            cpu_offload: CPUOffload::default(),
104            sharded_params: Vec::new(),
105            is_gathered: true,
106            mixed_precision: false,
107        };
108
109        // Initialize sharding
110        fsdp.shard_parameters();
111        fsdp
112    }
113
114    /// Builder: set sharding strategy.
115    pub fn sharding_strategy(mut self, strategy: ShardingStrategy) -> Self {
116        self.sharding_strategy = strategy;
117        self.shard_parameters();
118        self
119    }
120
121    /// Builder: set CPU offload configuration.
122    pub fn cpu_offload(mut self, offload: CPUOffload) -> Self {
123        self.cpu_offload = offload;
124        self
125    }
126
127    /// Builder: enable mixed precision.
128    pub fn mixed_precision(mut self, enabled: bool) -> Self {
129        self.mixed_precision = enabled;
130        self
131    }
132
133    /// Returns reference to wrapped module.
134    pub fn module(&self) -> &M {
135        &self.module
136    }
137
138    /// Returns mutable reference to wrapped module.
139    pub fn module_mut(&mut self) -> &mut M {
140        &mut self.module
141    }
142
143    /// Returns the process group.
144    pub fn process_group(&self) -> &ProcessGroup {
145        &self.process_group
146    }
147
148    /// Returns the sharding strategy.
149    pub fn strategy(&self) -> ShardingStrategy {
150        self.sharding_strategy
151    }
152
153    /// Shards parameters across devices.
154    fn shard_parameters(&mut self) {
155        if self.sharding_strategy == ShardingStrategy::NoShard {
156            return;
157        }
158
159        let world_size = self.process_group.world_size();
160        let rank = self.process_group.rank();
161
162        self.sharded_params.clear();
163
164        for param in self.module.parameters() {
165            let data = param.data();
166            let shape = data.shape().to_vec();
167            let numel = data.numel();
168
169            // Calculate shard size with padding for even division
170            let shard_size = numel.div_ceil(world_size);
171            let padding = shard_size * world_size - numel;
172
173            // Get local shard
174            let flat_data = data.to_vec();
175            let start = rank * shard_size;
176            let end = ((rank + 1) * shard_size).min(flat_data.len());
177
178            let mut shard_data: Vec<f32> = if start < flat_data.len() {
179                flat_data[start..end].to_vec()
180            } else {
181                vec![0.0; shard_size]
182            };
183
184            // Pad to shard_size
185            while shard_data.len() < shard_size {
186                shard_data.push(0.0);
187            }
188
189            self.sharded_params.push(ShardedParam {
190                local_shard: Tensor::from_vec(shard_data, &[shard_size]).unwrap(),
191                original_shape: shape,
192                numel,
193                padding,
194            });
195        }
196
197        self.is_gathered = false;
198    }
199
200    /// Gathers all parameter shards before forward pass.
201    pub fn gather_parameters(&mut self) {
202        if self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
203            return;
204        }
205
206        let _world_size = self.process_group.world_size();
207        let params = self.module.parameters();
208
209        for (param, sharded) in params.iter().zip(self.sharded_params.iter()) {
210            // All-gather the shards
211            let gathered = self.process_group.all_gather_tensor(&sharded.local_shard);
212
213            // Reshape back to original shape (removing padding)
214            let flat: Vec<f32> = gathered.to_vec().into_iter().take(sharded.numel).collect();
215            let restored = Tensor::from_vec(flat, &sharded.original_shape).unwrap();
216
217            param.update_data(restored);
218        }
219
220        self.is_gathered = true;
221    }
222
223    /// Shards parameters after forward/backward pass.
224    pub fn reshard_parameters(&mut self) {
225        if !self.is_gathered || self.sharding_strategy == ShardingStrategy::NoShard {
226            return;
227        }
228
229        self.shard_parameters();
230    }
231
232    /// Synchronizes gradients across all ranks.
233    pub fn sync_gradients(&self) {
234        match self.sharding_strategy {
235            ShardingStrategy::NoShard => {
236                // Full all-reduce like DDP
237                for param in self.module.parameters() {
238                    if let Some(grad) = param.grad() {
239                        let mut grad_tensor = grad.clone();
240                        self.process_group
241                            .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
242                    }
243                }
244            }
245            ShardingStrategy::ShardGradOp | ShardingStrategy::FullShard => {
246                // Reduce-scatter gradients to get sharded gradients
247                for param in self.module.parameters() {
248                    if let Some(grad) = param.grad() {
249                        let _reduced = self
250                            .process_group
251                            .reduce_scatter_tensor(&grad, ReduceOp::Average);
252                        // In full implementation, would update parameter's gradient shard
253                    }
254                }
255            }
256            ShardingStrategy::HybridShard => {
257                // All-reduce within node, reduce-scatter across nodes
258                for param in self.module.parameters() {
259                    if let Some(grad) = param.grad() {
260                        let mut grad_tensor = grad.clone();
261                        self.process_group
262                            .all_reduce_tensor(&mut grad_tensor, ReduceOp::Average);
263                    }
264                }
265            }
266        }
267    }
268
269    /// Clips gradients by global norm.
270    pub fn clip_grad_norm(&self, max_norm: f32) -> f32 {
271        let mut total_norm_sq = 0.0f32;
272
273        for param in self.module.parameters() {
274            if let Some(grad) = param.grad() {
275                let grad_vec = grad.to_vec();
276                let norm_sq: f32 = grad_vec.iter().map(|x| x * x).sum();
277                total_norm_sq += norm_sq;
278            }
279        }
280
281        // All-reduce total norm across ranks
282        let mut norm_tensor = Tensor::from_vec(vec![total_norm_sq], &[1]).unwrap();
283        self.process_group
284            .all_reduce_tensor(&mut norm_tensor, ReduceOp::Sum);
285        let global_norm = norm_tensor.to_vec()[0].sqrt();
286
287        // Clip if necessary
288        if global_norm > max_norm {
289            let clip_coef = max_norm / (global_norm + 1e-6);
290            for param in self.module.parameters() {
291                if let Some(grad) = param.grad() {
292                    let clipped: Vec<f32> = grad.to_vec().iter().map(|x| x * clip_coef).collect();
293                    let clipped_tensor = Tensor::from_vec(clipped, grad.shape()).unwrap();
294                    param.variable().set_grad(clipped_tensor);
295                }
296            }
297        }
298
299        global_norm
300    }
301
302    /// Estimates memory usage with different sharding strategies.
303    pub fn memory_estimate(&self) -> FSDPMemoryStats {
304        let params = self.module.parameters();
305        let total_params: usize = params.iter().map(|p| p.numel()).sum();
306        let world_size = self.process_group.world_size();
307
308        let bytes_per_param = 4; // f32
309        let param_memory = total_params * bytes_per_param;
310
311        let (sharded_params, sharded_grads, sharded_optim) = match self.sharding_strategy {
312            ShardingStrategy::NoShard => (param_memory, param_memory, param_memory * 2),
313            ShardingStrategy::ShardGradOp => (
314                param_memory,
315                param_memory / world_size,
316                param_memory * 2 / world_size,
317            ),
318            ShardingStrategy::FullShard | ShardingStrategy::HybridShard => (
319                param_memory / world_size,
320                param_memory / world_size,
321                param_memory * 2 / world_size,
322            ),
323        };
324
325        FSDPMemoryStats {
326            total_params,
327            param_memory_bytes: sharded_params,
328            grad_memory_bytes: sharded_grads,
329            optim_memory_bytes: sharded_optim,
330            world_size,
331        }
332    }
333}
impl<M: Module> Module for FullyShardedDataParallel<M> {
    /// Delegates the forward pass to the wrapped module.
    fn forward(&self, input: &Variable) -> Variable {
        // Note: In a real implementation, gather would be called automatically
        // through hooks before forward and reshard after; here callers must
        // call `gather_parameters()` themselves first.
        self.module.forward(input)
    }

    /// Returns the wrapped module's parameters.
    fn parameters(&self) -> Vec<Parameter> {
        self.module.parameters()
    }

    /// Switches the wrapped module to training mode.
    fn train(&mut self) {
        self.module.train();
    }

    /// Switches the wrapped module to evaluation mode.
    fn eval(&mut self) {
        self.module.eval();
    }

    /// Reports whether the wrapped module is in training mode.
    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}
/// Memory statistics for FSDP.
#[derive(Debug, Clone)]
pub struct FSDPMemoryStats {
    /// Total number of parameters in the wrapped module.
    pub total_params: usize,
    /// Per-rank memory for parameters, in bytes.
    pub param_memory_bytes: usize,
    /// Per-rank memory for gradients, in bytes.
    pub grad_memory_bytes: usize,
    /// Per-rank memory for optimizer state, in bytes.
    pub optim_memory_bytes: usize,
    /// Number of ranks participating in sharding.
    pub world_size: usize,
}

impl FSDPMemoryStats {
    /// Total per-rank memory footprint (params + grads + optimizer) in MB.
    pub fn total_memory_mb(&self) -> f32 {
        let total_bytes =
            self.param_memory_bytes + self.grad_memory_bytes + self.optim_memory_bytes;
        total_bytes as f32 / (1024.0 * 1024.0)
    }

    /// Fraction of memory saved relative to full replication.
    ///
    /// With `world_size` ranks the ideal saving is `1 - 1/world_size`;
    /// a single rank (or zero) saves nothing.
    pub fn memory_savings(&self) -> f32 {
        match self.world_size {
            0 | 1 => 0.0,
            n => 1.0 - (1.0 / n as f32),
        }
    }
}

// =============================================================================
// Tensor Parallelism
// =============================================================================
395/// Column-parallel linear layer.
396///
397/// Splits the weight matrix along the column dimension across ranks.
398/// Each rank computes a portion of the output features.
399#[allow(dead_code)]
400pub struct ColumnParallelLinear {
401    /// Local weight shard
402    weight: Parameter,
403    /// Bias (replicated on all ranks)
404    bias: Option<Parameter>,
405    /// Process group
406    process_group: ProcessGroup,
407    /// Input features
408    in_features: usize,
409    /// Output features (total across all ranks)
410    out_features: usize,
411    /// Whether to gather output
412    gather_output: bool,
413}
415impl ColumnParallelLinear {
416    /// Creates a new column-parallel linear layer.
417    pub fn new(
418        in_features: usize,
419        out_features: usize,
420        bias: bool,
421        process_group: ProcessGroup,
422        gather_output: bool,
423    ) -> Self {
424        let world_size = process_group.world_size();
425        let local_out_features = out_features / world_size;
426
427        let weight_data = Tensor::randn(&[local_out_features, in_features]);
428        let weight = Parameter::new(weight_data, true);
429
430        let bias = if bias {
431            let bias_data = Tensor::zeros(&[local_out_features]);
432            Some(Parameter::new(bias_data, true))
433        } else {
434            None
435        };
436
437        Self {
438            weight,
439            bias,
440            process_group,
441            in_features,
442            out_features,
443            gather_output,
444        }
445    }
446}
448impl Module for ColumnParallelLinear {
449    fn forward(&self, input: &Variable) -> Variable {
450        // Local matmul: input @ weight.T
451        let weight_var = Variable::new(self.weight.data(), false);
452        let output = input.matmul(&weight_var.transpose(0, 1));
453
454        // Add bias
455        let output = if let Some(ref bias) = self.bias {
456            let bias_var = Variable::new(bias.data(), false);
457            output.add(&bias_var)
458        } else {
459            output
460        };
461
462        // Optionally gather output across ranks
463        if self.gather_output {
464            let gathered = self.process_group.all_gather_tensor(&output.data());
465            Variable::new(gathered, output.requires_grad())
466        } else {
467            output
468        }
469    }
470
471    fn parameters(&self) -> Vec<Parameter> {
472        let mut params = vec![self.weight.clone()];
473        if let Some(ref bias) = self.bias {
474            params.push(bias.clone());
475        }
476        params
477    }
478}
480/// Row-parallel linear layer.
481///
482/// Splits the weight matrix along the row dimension across ranks.
483/// Each rank has a portion of the input features.
484#[allow(dead_code)]
485pub struct RowParallelLinear {
486    /// Local weight shard
487    weight: Parameter,
488    /// Bias (only on rank 0)
489    bias: Option<Parameter>,
490    /// Process group
491    process_group: ProcessGroup,
492    /// Input features (total across all ranks)
493    in_features: usize,
494    /// Output features
495    out_features: usize,
496    /// Whether input is already split
497    input_is_parallel: bool,
498}
500impl RowParallelLinear {
501    /// Creates a new row-parallel linear layer.
502    pub fn new(
503        in_features: usize,
504        out_features: usize,
505        bias: bool,
506        process_group: ProcessGroup,
507        input_is_parallel: bool,
508    ) -> Self {
509        let world_size = process_group.world_size();
510        let rank = process_group.rank();
511        let local_in_features = in_features / world_size;
512
513        let weight_data = Tensor::randn(&[out_features, local_in_features]);
514        let weight = Parameter::new(weight_data, true);
515
516        // Only rank 0 has bias
517        let bias = if bias && rank == 0 {
518            let bias_data = Tensor::zeros(&[out_features]);
519            Some(Parameter::new(bias_data, true))
520        } else {
521            None
522        };
523
524        Self {
525            weight,
526            bias,
527            process_group,
528            in_features,
529            out_features,
530            input_is_parallel,
531        }
532    }
533}
535impl Module for RowParallelLinear {
536    fn forward(&self, input: &Variable) -> Variable {
537        // If input is not parallel, take local shard
538        let local_input = if self.input_is_parallel {
539            input.clone()
540        } else {
541            // Split input along feature dimension for row parallelism
542            let world_size = self.process_group.world_size();
543            let rank = self.process_group.rank();
544            let data = input.data();
545            let shape = data.shape();
546            let feature_dim = shape[shape.len() - 1];
547            let local_features = feature_dim / world_size;
548            let start = rank * local_features;
549            let end = start + local_features;
550
551            // Slice the last dimension
552            let sliced = if shape.len() == 2 {
553                data.slice(&[0..shape[0], start..end])
554            } else {
555                data.clone() // Fallback for other shapes
556            };
557            Variable::new(sliced, input.requires_grad())
558        };
559
560        // Local matmul
561        let weight_var = Variable::new(self.weight.data(), false);
562        let local_output = local_input.matmul(&weight_var.transpose(0, 1));
563
564        // All-reduce to combine partial outputs
565        let mut output_data = local_output.data().clone();
566        self.process_group
567            .all_reduce_tensor(&mut output_data, ReduceOp::Sum);
568        let output = Variable::new(output_data, local_output.requires_grad());
569
570        // Add bias (only on rank 0, then broadcast)
571        if let Some(ref bias) = self.bias {
572            let bias_var = Variable::new(bias.data(), false);
573            output.add(&bias_var)
574        } else {
575            output
576        }
577    }
578
579    fn parameters(&self) -> Vec<Parameter> {
580        let mut params = vec![self.weight.clone()];
581        if let Some(ref bias) = self.bias {
582            params.push(bias.clone());
583        }
584        params
585    }
586}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_nn::Linear;

    /// The default sharding strategy is ZeRO-3 style full sharding.
    #[test]
    fn test_sharding_strategy_default() {
        assert_eq!(ShardingStrategy::default(), ShardingStrategy::FullShard);
    }

    /// Wrapping a module picks up the default (FullShard) strategy.
    #[test]
    fn test_fsdp_creation() {
        let net = Linear::new(10, 5);
        let group = ProcessGroup::mock();
        let wrapper = FullyShardedDataParallel::new(net, group);

        assert_eq!(wrapper.strategy(), ShardingStrategy::FullShard);
    }

    /// Forward works after an explicit gather and preserves the layer shape.
    #[test]
    fn test_fsdp_forward() {
        let net = Linear::new(4, 2);
        let group = ProcessGroup::mock();
        let mut wrapper = FullyShardedDataParallel::new(net, group);

        // Parameters must be gathered (unsharded) before running forward.
        wrapper.gather_parameters();

        let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        let output = wrapper.forward(&input);

        assert_eq!(output.data().shape(), &[1, 2]);
    }

    /// Builder methods chain and record the chosen configuration.
    #[test]
    fn test_fsdp_builder() {
        let net = Linear::new(10, 5);
        let group = ProcessGroup::mock();

        let wrapper = FullyShardedDataParallel::new(net, group)
            .sharding_strategy(ShardingStrategy::ShardGradOp)
            .cpu_offload(CPUOffload::Params)
            .mixed_precision(true);

        assert_eq!(wrapper.strategy(), ShardingStrategy::ShardGradOp);
    }

    /// Memory estimation reports non-trivial totals for a real module.
    #[test]
    fn test_fsdp_memory_stats() {
        let net = Linear::new(100, 50);
        let group = ProcessGroup::mock();
        let wrapper = FullyShardedDataParallel::new(net, group);

        let stats = wrapper.memory_estimate();
        assert!(stats.total_params > 0);
        assert!(stats.total_memory_mb() > 0.0);
    }

    /// The NoShard (DDP-like) strategy can be selected via the builder.
    #[test]
    fn test_fsdp_no_shard() {
        let net = Linear::new(10, 5);
        let group = ProcessGroup::mock();
        let wrapper =
            FullyShardedDataParallel::new(net, group).sharding_strategy(ShardingStrategy::NoShard);

        assert_eq!(wrapper.strategy(), ShardingStrategy::NoShard);
    }

    /// Column-parallel output keeps only the local shard when not gathering.
    #[test]
    fn test_column_parallel_linear() {
        let group = ProcessGroup::mock();
        // With a mock (world_size == 1) group, local_out_features == 4.
        let layer = ColumnParallelLinear::new(8, 4, true, group, false);

        let input = Variable::new(Tensor::randn(&[2, 8]), false);
        let output = layer.forward(&input);

        // Output shape is [batch, local_out_features] == [2, 4].
        assert_eq!(output.data().shape(), &[2, 4]);
    }

    /// Row-parallel layer produces the full output shape.
    #[test]
    fn test_row_parallel_linear() {
        let group = ProcessGroup::mock();
        let layer = RowParallelLinear::new(8, 4, true, group, false);

        let input = Variable::new(Tensor::randn(&[2, 8]), false);
        let output = layer.forward(&input);

        assert_eq!(output.data().shape(), &[2, 4]);
    }
}