pub struct FSDP<M: Module<T>, T: Float> { /* private fields */ }
Fully Sharded Data Parallel module wrapper.
Wraps an inner Module and shards each parameter across ranks so that
each rank only stores 1 / world_size of the full parameter tensor.
§Forward pass
Before calling the inner module’s forward(), FSDP all-gathers each
shard to reconstruct the full parameter tensor and installs it into the
module. The full-parameter tensors are stored in [full_params] so
that backward can accumulate gradients on them.
§Gradient synchronization
After backward(), call [sync_gradients] to:
- Read gradients from the full-parameter tensors stored during forward.
- Reduce-scatter the full gradients so each rank gets only its shard portion of the gradient.
- Set each shard parameter’s gradient from the reduce-scattered result.
§Example
let mut fsdp = FSDP::new(model, backend)?;
loop {
    let output = fsdp.forward(&input)?;
    let loss = criterion.forward(&output, &target)?;
    ferrotorch_core::backward(&loss)?;
    fsdp.sync_gradients()?;
    optimizer.step()?;
    optimizer.zero_grad()?;
}
Implementations§
impl<M: Module<T>, T: Float> FSDP<M, T>
pub fn new(module: M, backend: Arc<dyn Backend>) -> FerrotorchResult<Self>
Wrap a module for fully-sharded data-parallel training.
Each parameter is split evenly across world_size ranks. This rank
keeps only its shard (the rank-th chunk). The original parameter
shapes are recorded for reconstruction during forward.
§Panics
Panics if any parameter’s element count is not evenly divisible by
world_size.
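The even split can be pictured with a minimal sketch over a flattened `f32` buffer. This is an illustration only, not the crate's implementation: `shard_of` is a hypothetical helper, and the contiguous chunk layout is an assumption based on the description above.

```rust
/// Returns the chunk of a flattened parameter that `rank` would keep.
/// Hypothetical helper for illustration; not part of the FSDP API.
fn shard_of(full: &[f32], rank: usize, world_size: usize) -> &[f32] {
    assert!(world_size > 0 && rank < world_size, "rank out of range");
    // Mirrors the documented panic: the element count must divide evenly.
    assert!(
        full.len() % world_size == 0,
        "element count {} is not divisible by world_size {}",
        full.len(),
        world_size
    );
    let shard_len = full.len() / world_size;
    // Assumed layout: rank r owns the r-th contiguous chunk.
    &full[rank * shard_len..(rank + 1) * shard_len]
}
```

With six elements and `world_size = 3`, each rank would hold two elements; seven elements would trip the divisibility check.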
pub fn module_mut(&mut self) -> &mut M
Mutable access to the inner module.
pub fn into_inner(self) -> M
Consume the wrapper and return the inner module.
pub fn forward(&mut self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>
Reconstruct full parameters from shards across all ranks and run the inner module’s forward pass.
The all-gathered full-parameter tensors are stored in self.full_params
so their gradients can be read after backward.
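As a rough single-process picture of the all-gather step, the sketch below concatenates per-rank shards back into a flat full parameter and checks the result against the recorded shape. The function name and the rank-ordered concatenation are assumptions for illustration; the real backend communicates across processes.

```rust
/// Simulates all-gather in one process: `shards[r]` is the chunk held by rank `r`.
/// `shape` is the original parameter shape recorded at wrap time.
/// Hypothetical helper for illustration; not part of the FSDP API.
fn all_gather_sketch(shards: &[Vec<f32>], shape: &[usize]) -> Vec<f32> {
    // Concatenate the shards in rank order to rebuild the flat parameter.
    let full: Vec<f32> = shards.iter().flatten().copied().collect();
    // The gathered buffer must hold exactly as many elements as the shape implies.
    let numel: usize = shape.iter().product();
    assert_eq!(full.len(), numel, "gathered size does not match recorded shape");
    full
}
```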
pub fn sync_gradients(&mut self) -> FerrotorchResult<()>
Reduce-scatter gradients from the full-parameter tensors stored during forward, then set each shard parameter’s gradient.
Call this after backward() and before optimizer.step().
§How it works
- For each parameter, read the gradient from the full-param tensor that was used during forward (stored in self.full_params).
- Reduce-scatter the full gradient across ranks (mean reduction) so each rank gets only its shard portion.
- Set the shard parameter’s .grad() to the reduce-scattered result.
Using reduce-scatter (not allreduce) is correct for FSDP because each rank only needs its own shard of the gradient to update its shard of the parameter.
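The effect of a mean reduce-scatter can be sketched in a single process as follows. This is a simulation under stated assumptions (contiguous shards, `f32` data, a hypothetical function name), not the backend's collective.

```rust
/// Simulates a mean reduce-scatter: `grads[r]` is the full gradient computed on rank `r`.
/// Returns one vector per rank holding that rank's shard of the averaged gradient.
/// Hypothetical helper for illustration; not part of the FSDP API.
fn reduce_scatter_mean(grads: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let world_size = grads.len();
    assert!(world_size > 0, "need at least one rank");
    let numel = grads[0].len();
    assert!(grads.iter().all(|g| g.len() == numel), "gradient sizes differ");
    assert!(numel % world_size == 0, "gradient not evenly divisible");
    let shard_len = numel / world_size;
    (0..world_size)
        .map(|rank| {
            // Each rank averages only the elements in its own chunk.
            (rank * shard_len..(rank + 1) * shard_len)
                .map(|i| grads.iter().map(|g| g[i]).sum::<f32>() / world_size as f32)
                .collect()
        })
        .collect()
}
```

An allreduce would leave every rank with the whole averaged gradient; here each rank ends up with `numel / world_size` elements, which is all it needs to update its shard.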
pub fn update_shards(&mut self, flat_data: &[T]) -> FerrotorchResult<()>
Update shard parameters from a flat data slice.
This is used by optimizers that produce a flat parameter buffer. The slice must have exactly the number of elements expected for this rank’s shards.
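One plausible reading of the flat-buffer contract is sketched below. The layout (shards concatenated in parameter order) and the `String` error type are assumptions made for illustration; the source does not specify them, and the real method returns a `FerrotorchResult`.

```rust
/// Copies a flat buffer into per-parameter shard buffers, in order.
/// Hypothetical stand-in for illustration; layout and error type are assumed.
fn update_shards_sketch(shards: &mut [Vec<f32>], flat_data: &[f32]) -> Result<(), String> {
    // The slice must contain exactly the elements of all local shards.
    let expected: usize = shards.iter().map(Vec::len).sum();
    if flat_data.len() != expected {
        return Err(format!(
            "expected {} elements, got {}",
            expected,
            flat_data.len()
        ));
    }
    let mut offset = 0;
    for shard in shards.iter_mut() {
        let len = shard.len();
        // Assumed layout: shards are packed back to back in parameter order.
        shard.copy_from_slice(&flat_data[offset..offset + len]);
        offset += len;
    }
    Ok(())
}
```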
Auto Trait Implementations§
impl<M, T> Freeze for FSDP<M, T> where
M: Freeze,
impl<M, T> !RefUnwindSafe for FSDP<M, T>
impl<M, T> Send for FSDP<M, T>
impl<M, T> Sync for FSDP<M, T>
impl<M, T> Unpin for FSDP<M, T>
impl<M, T> UnsafeUnpin for FSDP<M, T> where
M: UnsafeUnpin,
impl<M, T> !UnwindSafe for FSDP<M, T>
Blanket Implementations§
impl<T> BorrowMut<T> for T where
T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> DistributionExt for T where
T: ?Sized,
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.