Skip to main content

ferrotorch_distributed/
fsdp.rs

//! Fully Sharded Data Parallel (FSDP) wrapper.
//!
//! [`FSDP`] wraps a [`Module`] and shards its parameters across ranks.
//! During forward, parameters are all-gathered to form the full tensors.
//! During gradient synchronization, full-parameter gradients are
//! reduce-scattered so each rank only stores its shard's gradient.
//!
//! Between steps this reduces per-rank parameter storage from O(params) to
//! O(params / world_size). The full parameters are still materialized from
//! the start of forward until gradients are synchronized. Gathering and
//! scattering them also adds communication in both forward and backward.
10
11use std::sync::Arc;
12
13use ferrotorch_core::storage::TensorStorage;
14use ferrotorch_core::{FerrotorchResult, Float, Tensor};
15use ferrotorch_nn::{Module, Parameter};
16
17use crate::backend::Backend;
18use crate::collective::{ReduceOp, all_gather, reduce_scatter};
19
20/// Fully Sharded Data Parallel module wrapper.
21///
22/// Wraps an inner [`Module`] and shards each parameter across ranks so that
23/// each rank only stores `1 / world_size` of the full parameter tensor.
24///
25/// # Forward pass
26///
27/// Before calling the inner module's `forward()`, FSDP all-gathers each
28/// shard to reconstruct the full parameter tensor and installs it into the
29/// module. The full-parameter tensors are stored in [`full_params`] so
30/// that backward can accumulate gradients on them.
31///
32/// # Gradient synchronization
33///
34/// After `backward()`, call [`sync_gradients`] to:
35/// 1. Read gradients from the full-parameter tensors stored during forward.
36/// 2. Reduce-scatter the full gradients so each rank gets only its shard
37///    portion of the gradient.
38/// 3. Set each shard parameter's gradient from the reduce-scattered result.
39///
40/// # Example
41///
42/// ```ignore
43/// let mut fsdp = FSDP::new(model, backend)?;
44///
45/// loop {
46///     let output = fsdp.forward(&input)?;
47///     let loss = criterion.forward(&output, &target)?;
48///     ferrotorch_core::backward(&loss)?;
49///     fsdp.sync_gradients()?;
50///     optimizer.step()?;
51///     optimizer.zero_grad()?;
52/// }
53/// ```
54pub struct FSDP<M: Module<T>, T: Float> {
55    module: M,
56    backend: Arc<dyn Backend>,
57    /// Original full-parameter shapes before sharding.
58    original_shapes: Vec<Vec<usize>>,
59    /// Full-param tensors from the last forward pass, kept alive so
60    /// backward can accumulate gradients on them.
61    full_params: Vec<Tensor<T>>,
62    _marker: std::marker::PhantomData<T>,
63}
64
65impl<M: Module<T>, T: Float> FSDP<M, T> {
66    /// Wrap a module for fully-sharded data-parallel training.
67    ///
68    /// Each parameter is split evenly across `world_size` ranks. This rank
69    /// keeps only its shard (the `rank`-th chunk). The original parameter
70    /// shapes are recorded for reconstruction during forward.
71    ///
72    /// # Panics
73    ///
74    /// Panics if any parameter's element count is not evenly divisible by
75    /// `world_size`.
76    pub fn new(mut module: M, backend: Arc<dyn Backend>) -> FerrotorchResult<Self> {
77        let rank = backend.rank();
78        let world_size = backend.world_size();
79        let mut original_shapes = Vec::new();
80
81        {
82            let params = module.parameters_mut();
83            for param in params {
84                let tensor = param.tensor();
85                let shape = tensor.shape().to_vec();
86                let numel = tensor.numel();
87
88                assert!(
89                    numel % world_size == 0,
90                    "FSDP: parameter with {} elements is not evenly divisible by world_size {}",
91                    numel,
92                    world_size,
93                );
94
95                original_shapes.push(shape);
96
97                let data = tensor.data_vec()?;
98                let chunk_size = numel / world_size;
99                let start = rank * chunk_size;
100                let end = start + chunk_size;
101                let shard_data = data[start..end].to_vec();
102
103                let shard_tensor =
104                    Tensor::from_storage(TensorStorage::cpu(shard_data), vec![chunk_size], true)?;
105                // Shard params need requires_grad=true so the optimizer can
106                // update them.
107                *param = Parameter::new(shard_tensor);
108            }
109        }
110
111        Ok(Self {
112            module,
113            backend,
114            original_shapes,
115            full_params: Vec::new(),
116            _marker: std::marker::PhantomData,
117        })
118    }
119
120    /// Immutable access to the inner module.
121    pub fn module(&self) -> &M {
122        &self.module
123    }
124
125    /// Mutable access to the inner module.
126    pub fn module_mut(&mut self) -> &mut M {
127        &mut self.module
128    }
129
130    /// Consume the wrapper and return the inner module.
131    pub fn into_inner(self) -> M {
132        self.module
133    }
134
135    /// The backend used for communication.
136    pub fn backend(&self) -> &Arc<dyn Backend> {
137        &self.backend
138    }
139
140    /// Reconstruct full parameters from shards across all ranks and run
141    /// the inner module's forward pass.
142    ///
143    /// The all-gathered full-parameter tensors are stored in `self.full_params`
144    /// so their gradients can be read after backward.
145    pub fn forward(&mut self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
146        let world_size = self.backend.world_size();
147        self.full_params.clear();
148
149        {
150            let params = self.module.parameters_mut();
151            for (i, param) in params.into_iter().enumerate() {
152                let shard = param.tensor().clone();
153                let orig_shape = &self.original_shapes[i];
154
155                // All-gather the shard to get the full parameter.
156                let full = if world_size == 1 {
157                    shard
158                } else {
159                    all_gather(&shard, self.backend.as_ref())?
160                };
161
162                // Reshape to the original parameter shape and enable grad.
163                let full = Tensor::from_storage(
164                    TensorStorage::cpu(full.data_vec()?),
165                    orig_shape.clone(),
166                    true,
167                )?;
168
169                self.full_params.push(full.clone());
170
171                // Install the full parameter into the module for this forward pass.
172                *param = Parameter::new(full);
173            }
174        }
175
176        let output = self.module.forward(input)?;
177
178        // After forward, restore shard parameters so the module holds only
179        // shards at rest (saves memory).
180        self.restore_shards()?;
181
182        Ok(output)
183    }
184
185    /// Replace full parameters with their local shards to free memory.
186    fn restore_shards(&mut self) -> FerrotorchResult<()> {
187        let rank = self.backend.rank();
188        let world_size = self.backend.world_size();
189
190        let params = self.module.parameters_mut();
191        for (i, param) in params.into_iter().enumerate() {
192            let tensor = param.tensor();
193            let data = tensor.data_vec()?;
194            let numel = data.len();
195            let chunk_size = numel / world_size;
196            let start = rank * chunk_size;
197            let end = start + chunk_size;
198            let shard_data = data[start..end].to_vec();
199
200            let shard_tensor =
201                Tensor::from_storage(TensorStorage::cpu(shard_data), vec![chunk_size], true)?;
202            *param = Parameter::new(shard_tensor);
203
204            // Preserve the original shape metadata.
205            let _ = &self.original_shapes[i];
206        }
207
208        Ok(())
209    }
210
211    /// Reduce-scatter gradients from the full-parameter tensors stored
212    /// during forward, then set each shard parameter's gradient.
213    ///
214    /// Call this after `backward()` and before `optimizer.step()`.
215    ///
216    /// # How it works
217    ///
218    /// 1. For each parameter, read the gradient from the full-param tensor
219    ///    that was used during forward (stored in `self.full_params`).
220    /// 2. Reduce-scatter the full gradient across ranks (mean reduction) so
221    ///    each rank gets only its shard portion.
222    /// 3. Set the shard parameter's `.grad()` to the reduce-scattered result.
223    ///
224    /// Using reduce-scatter (not allreduce) is correct for FSDP because each
225    /// rank only needs its own shard of the gradient to update its shard of
226    /// the parameter.
227    pub fn sync_gradients(&mut self) -> FerrotorchResult<()> {
228        let world_size = self.backend.world_size();
229        let params = self.module.parameters_mut();
230
231        if self.full_params.len() != params.len() {
232            return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
233                message: format!(
234                    "FSDP sync_gradients: expected {} full_params but have {}. \
235                     Was forward() called before backward()?",
236                    params.len(),
237                    self.full_params.len(),
238                ),
239            });
240        }
241
242        for (i, param) in params.into_iter().enumerate() {
243            let full_param = &self.full_params[i];
244
245            // Read the gradient from the full-parameter tensor. If no
246            // gradient was computed (e.g., parameter was unused in forward),
247            // use zeros so all ranks exchange buffers of the same size.
248            let grad = full_param.grad()?;
249            let full_grad = match grad {
250                Some(g) => g,
251                None => {
252                    let numel = full_param.numel();
253                    Tensor::from_storage(
254                        TensorStorage::cpu(vec![<T as num_traits::Zero>::zero(); numel]),
255                        full_param.shape().to_vec(),
256                        false,
257                    )?
258                }
259            };
260
261            // Flatten for reduce-scatter.
262            let grad_data = full_grad.data_vec()?;
263            let flat_grad = Tensor::from_storage(
264                TensorStorage::cpu(grad_data),
265                vec![full_grad.numel()],
266                false,
267            )?;
268
269            // Reduce-scatter: each rank gets its shard of the averaged gradient.
270            let shard_grad = if world_size == 1 {
271                flat_grad
272            } else {
273                reduce_scatter(&flat_grad, self.backend.as_ref(), ReduceOp::Mean)?
274            };
275
276            // Set the shard parameter's gradient.
277            // Interior mutability: set_grad works on &self via Mutex.
278            param.tensor().set_grad(Some(shard_grad))?;
279        }
280
281        // Clear full_params to free memory now that gradients have been read.
282        self.full_params.clear();
283
284        Ok(())
285    }
286
287    /// Update shard parameters from a flat data slice.
288    ///
289    /// This is used by optimizers that produce a flat parameter buffer.
290    /// The slice must have exactly the number of elements expected for
291    /// this rank's shards.
292    pub fn update_shards(&mut self, flat_data: &[T]) -> FerrotorchResult<()> {
293        let params = self.module.parameters_mut();
294        let total_shard_numel: usize = params.iter().map(|p| p.tensor().numel()).sum();
295
296        assert!(
297            flat_data.len() == total_shard_numel,
298            "FSDP update_shards: expected {} elements but got {}",
299            total_shard_numel,
300            flat_data.len(),
301        );
302
303        let mut offset = 0;
304        for param in params {
305            let numel = param.tensor().numel();
306            let shard_data = flat_data[offset..offset + numel].to_vec();
307            let shard_tensor = Tensor::from_storage(
308                TensorStorage::cpu(shard_data),
309                param.tensor().shape().to_vec(),
310                true,
311            )?;
312            *param = Parameter::new(shard_tensor);
313            offset += numel;
314        }
315
316        Ok(())
317    }
318}
319
// FSDP does NOT implement Module<T>, because its forward() needs &mut self.
// It swaps the full parameters into the wrapped module and records them in
// full_params. Callers must use fsdp.forward() directly.
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::SimulatedBackend;
    use ferrotorch_core::storage::TensorStorage;
    use ferrotorch_core::{FerrotorchResult, Tensor};
    use ferrotorch_nn::Parameter;
    use std::thread;

    /// Minimal one-parameter module used to exercise FSDP.
    struct TestModule<T: Float> {
        weight: Parameter<T>,
        training: bool,
    }

    impl<T: Float> TestModule<T> {
        /// Builds a module whose single 1-D weight holds a copy of `data`.
        fn new(data: &[T]) -> FerrotorchResult<Self> {
            let weight = Parameter::from_slice(data, &[data.len()])?;
            Ok(Self {
                weight,
                training: true,
            })
        }
    }

    impl<T: Float> Module<T> for TestModule<T> {
        // Multiplies every input element by the sum of all weight elements,
        // so the output depends on the whole (un-sharded) weight.
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            let weight_sum = self
                .weight
                .tensor()
                .data_vec()?
                .into_iter()
                .fold(<T as num_traits::Zero>::zero(), |acc, w| acc + w);
            let scaled: Vec<T> = input
                .data_vec()?
                .into_iter()
                .map(|x| x * weight_sum)
                .collect();
            Tensor::from_storage(TensorStorage::cpu(scaled), input.shape().to_vec(), false)
        }

        fn parameters(&self) -> Vec<&Parameter<T>> {
            vec![&self.weight]
        }

        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            vec![&mut self.weight]
        }

        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            vec![("weight".into(), &self.weight)]
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    /// Runs `body` on one thread per rank of a fresh simulated group.
    ///
    /// Returns each thread's result in group-creation order. The backends
    /// stay alive in the calling thread until every worker has been joined.
    fn on_every_rank<R, F>(world_size: usize, body: F) -> Vec<R>
    where
        R: Send + 'static,
        F: Fn(Arc<SimulatedBackend>) -> R + Clone + Send + 'static,
    {
        let backends: Vec<Arc<SimulatedBackend>> = SimulatedBackend::create_group(world_size)
            .unwrap()
            .into_iter()
            .map(Arc::new)
            .collect();

        let workers: Vec<thread::JoinHandle<R>> = backends
            .iter()
            .map(|backend| {
                let backend = Arc::clone(backend);
                let body = body.clone();
                thread::spawn(move || body(backend))
            })
            .collect();

        workers
            .into_iter()
            .map(|worker| worker.join().unwrap())
            .collect()
    }

    /// The only rank of a one-rank simulated group, as a trait object.
    fn single_rank_backend() -> Arc<dyn Backend> {
        Arc::new(
            SimulatedBackend::create_group(1)
                .unwrap()
                .into_iter()
                .next()
                .unwrap(),
        )
    }

    /// A 1-D tensor without grad holding `values`, used as a fake gradient.
    fn flat_grad(values: &[f32]) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(values.to_vec()), vec![values.len()], false)
            .unwrap()
    }

    #[test]
    fn test_fsdp_sharding() {
        // 2 ranks, parameter [10, 20, 30, 40].
        // Rank 0 keeps [10, 20], rank 1 keeps [30, 40].
        let results = on_every_rank(2, |backend| {
            let rank = backend.rank();
            let model = TestModule::<f32>::new(&[10.0, 20.0, 30.0, 40.0]).unwrap();
            let fsdp = FSDP::new(model, backend).unwrap();
            (rank, fsdp.module().weight.tensor().data_vec().unwrap())
        });

        for (rank, shard) in results {
            let expected: &[f32] = if rank == 0 {
                &[10.0, 20.0]
            } else {
                &[30.0, 40.0]
            };
            assert_eq!(shard, expected);
        }
    }

    #[test]
    fn test_fsdp_shard_requires_grad() {
        // Shard parameters must have requires_grad=true.
        let flags = on_every_rank(2, |backend| {
            let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
            let fsdp = FSDP::new(model, backend).unwrap();
            fsdp.module().weight.tensor().requires_grad()
        });

        for flag in flags {
            assert!(flag, "shard must have requires_grad=true");
        }
    }

    #[test]
    fn test_fsdp_forward_restores_shards() {
        // After forward(), parameters must be back to shard size.
        on_every_rank(2, |backend| {
            let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
            let mut fsdp = FSDP::new(model, backend).unwrap();

            let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
            let _output = fsdp.forward(&input).unwrap();

            // 4 elements / 2 ranks = 2 elements per shard.
            let shard = fsdp.module().weight.tensor();
            assert_eq!(shard.numel(), 2);
            assert!(shard.requires_grad());
        });
    }

    #[test]
    fn test_fsdp_forward_produces_correct_output() {
        // 2 ranks, param [1, 2, 3, 4] sums to 10, so input [2.0] must give
        // [20.0] on every rank.
        on_every_rank(2, |backend| {
            let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
            let mut fsdp = FSDP::new(model, backend).unwrap();

            let input = ferrotorch_core::from_slice(&[2.0f32], &[1]).unwrap();
            let first = fsdp.forward(&input).unwrap().data_vec().unwrap()[0];
            assert!((first - 20.0).abs() < 1e-6, "expected 20.0, got {}", first);
        });
    }

    #[test]
    fn test_fsdp_update_shards() {
        let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        let mut fsdp = FSDP::new(model, single_rank_backend()).unwrap();

        fsdp.update_shards(&[10.0, 20.0, 30.0, 40.0]).unwrap();

        let data = fsdp.module().weight.tensor().data_vec().unwrap();
        assert_eq!(data, &[10.0, 20.0, 30.0, 40.0]);
    }

    #[test]
    #[should_panic(expected = "expected 4 elements but got 2")]
    fn test_fsdp_update_shards_size_validation() {
        let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        let mut fsdp = FSDP::new(model, single_rank_backend()).unwrap();

        // Too few elements: must panic.
        fsdp.update_shards(&[10.0, 20.0]).unwrap();
    }

    #[test]
    fn test_fsdp_sync_gradients_single_rank() {
        // With a single rank there is nothing to scatter, so the full
        // gradient must be passed through to the shard unchanged.
        let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        let mut fsdp = FSDP::new(model, single_rank_backend()).unwrap();

        // Populate full_params, then fake backward by setting a gradient.
        let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
        let _output = fsdp.forward(&input).unwrap();
        fsdp.full_params[0]
            .set_grad(Some(flat_grad(&[0.1, 0.2, 0.3, 0.4])))
            .unwrap();

        fsdp.sync_gradients().unwrap();

        let shard_grad = fsdp.module().weight.tensor().grad().unwrap().unwrap();
        assert_eq!(shard_grad.data_vec().unwrap(), &[0.1, 0.2, 0.3, 0.4]);
    }

    #[test]
    fn test_fsdp_sync_gradients_multi_rank() {
        // 2 ranks, param size 4, so shard size 2. Both ranks set the same
        // full gradient [1, 2, 3, 4], so reduce_scatter(mean) gives rank 0
        // [1, 2] and rank 1 [3, 4].
        let results = on_every_rank(2, |backend| {
            let rank = backend.rank();
            let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
            let mut fsdp = FSDP::new(model, backend).unwrap();

            let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
            let _output = fsdp.forward(&input).unwrap();
            fsdp.full_params[0]
                .set_grad(Some(flat_grad(&[1.0, 2.0, 3.0, 4.0])))
                .unwrap();

            fsdp.sync_gradients().unwrap();

            let shard_grad = fsdp.module().weight.tensor().grad().unwrap().unwrap();
            (rank, shard_grad.data_vec().unwrap())
        });

        for (rank, data) in results {
            // Identical gradients on both ranks, so the mean equals the input.
            let expected: [f32; 2] = if rank == 0 { [1.0, 2.0] } else { [3.0, 4.0] };
            assert_eq!(data.len(), 2);
            for (&got, want) in data.iter().zip(expected) {
                assert!(
                    (got - want).abs() < 1e-6,
                    "rank {}: expected {:?}, got {}",
                    rank,
                    want,
                    got
                );
            }
        }
    }
}