pub struct FSDP<M: Module<T>, T: Float> { /* private fields */ }
Fully Sharded Data Parallel module wrapper.
Wraps an inner Module and shards each parameter across ranks so that
each rank only stores 1 / world_size of the full parameter tensor.
§Forward pass
Before calling the inner module’s forward(), FSDP all-gathers each
shard to reconstruct the full parameter tensor and installs it into the
module. The full-parameter tensors are stored in [full_params] so
that backward can accumulate gradients on them.
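The reconstruction step can be pictured with plain vectors. This is a minimal single-process sketch, not the crate's implementation: `all_gather` is a hypothetical helper that simulates the collective by concatenating per-rank shards in rank order, and `f32` stands in for the generic `T`.

```rust
/// Simulate an all-gather: `shards[r]` is the slice held by rank `r`.
/// Every rank ends up with the same concatenation in rank order.
/// (Hypothetical helper for illustration; not part of the FSDP API.)
fn all_gather(shards: &[Vec<f32>]) -> Vec<f32> {
    shards.iter().flat_map(|s| s.iter().copied()).collect()
}
```

With two ranks holding `[1.0, 2.0]` and `[3.0, 4.0]`, the gathered parameter is `[1.0, 2.0, 3.0, 4.0]`, which is what the inner module would see during forward.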
§Gradient synchronization
After backward(), call [sync_gradients] to:
- Read gradients from the full-parameter tensors stored during forward.
- Reduce-scatter the full gradients so each rank gets only its shard portion of the gradient.
- Set each shard parameter’s gradient from the reduce-scattered result.
§Example
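let mut fsdp = FSDP::new(model, backend)?;
loop {
    // Forward: all-gathers the shards and runs the inner module.
    let output = fsdp.forward(&input)?;
    let loss = criterion.forward(&output, &target)?;
    ferrotorch_core::backward(&loss)?;
    // Reduce-scatter the gradients onto this rank's shards before stepping.
    fsdp.sync_gradients()?;
    optimizer.step()?;
    optimizer.zero_grad()?;
}
Implementations§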
impl<M: Module<T>, T: Float> FSDP<M, T>
pub fn new(module: M, backend: Arc<dyn Backend>) -> FerrotorchResult<Self>
Wrap a module for fully-sharded data-parallel training.
Each parameter is split evenly across world_size ranks. This rank
keeps only its shard (the rank-th chunk). The original parameter
shapes are recorded for reconstruction during forward.
§Panics
Panics if any parameter’s element count is not evenly divisible by
world_size.
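The even split and the divisibility requirement can be sketched as follows. `shard_for_rank` is a hypothetical helper operating on a flat `f32` slice; it assumes contiguous chunks in rank order, which matches the "rank-th chunk" wording above but is not confirmed against the actual implementation.

```rust
/// Return the chunk of `full` owned by `rank`, assuming an even,
/// contiguous split. Panics when the element count is not divisible
/// by `world_size`, mirroring the documented panic condition.
/// (Hypothetical helper for illustration; not part of the FSDP API.)
fn shard_for_rank(full: &[f32], rank: usize, world_size: usize) -> &[f32] {
    assert!(world_size > 0 && rank < world_size, "rank out of range");
    assert!(
        full.len() % world_size == 0,
        "element count {} is not divisible by world_size {}",
        full.len(),
        world_size
    );
    let chunk = full.len() / world_size;
    &full[rank * chunk..(rank + 1) * chunk]
}
```

For a six-element parameter and `world_size = 3`, rank 1 keeps elements 2 and 3; a seven-element parameter would trip the second assertion.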
pub fn new_with_strategy(
    module: M,
    backend: Arc<dyn Backend>,
    strategy: ShardingStrategy,
) -> FerrotorchResult<Self>
Wrap a module for data-parallel training with a specific
ShardingStrategy.
- FullShard: shard parameters, gradients, and optimizer state (the classic FSDP / ZeRO-3 behavior; identical to [new]).
- ShardGradOp: keep parameters replicated on every rank and shard only gradients and optimizer state (ZeRO-2). After calling the optimizer step on the shard gradients, the caller must call [broadcast_updated_params] to re-sync the updated parameter shards back to every rank. CL-372.
- NoShard: no sharding (ZeRO-0 / DDP equivalent). Gradients are all-reduced across ranks in sync_gradients and all ranks update the full parameters locally.
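The practical difference between the three strategies can be summarized in code. The sketch below uses a local stand-in enum `Strategy` (not the crate's `ShardingStrategy`) and two hypothetical helpers; it only restates what the descriptions above say about parameter storage and the extra re-sync step.

```rust
/// Local stand-in for the three documented strategies.
/// (Illustrative only; the real type is `ShardingStrategy`.)
#[derive(Clone, Copy, Debug, PartialEq)]
enum Strategy {
    FullShard,
    ShardGradOp,
    NoShard,
}

/// Parameter elements each rank stores for a parameter of `numel` elements.
/// Only FullShard stores a 1 / world_size slice; the others keep the
/// parameter replicated.
fn param_elems_per_rank(s: Strategy, numel: usize, world_size: usize) -> usize {
    assert!(world_size > 0, "world_size must be positive");
    match s {
        Strategy::FullShard => numel / world_size,
        Strategy::ShardGradOp | Strategy::NoShard => numel,
    }
}

/// Whether the caller must re-sync parameters after the optimizer step.
/// Per the description above, only ShardGradOp requires it.
fn needs_param_resync(s: Strategy) -> bool {
    matches!(s, Strategy::ShardGradOp)
}
```

For a 1024-element parameter on 4 ranks, FullShard stores 256 elements per rank while the other two strategies store all 1024.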
pub fn strategy(&self) -> ShardingStrategy
Return the active sharding strategy.
pub fn prefetch_forward_params(&mut self) -> FerrotorchResult<()>
Kick off asynchronous all-gathers for every parameter so the
next forward call consumes the pre-gathered
tensors instead of blocking on a fresh all-gather.
This is FSDP’s equivalent of PyTorch’s backward prefetch:
communication for layer N+1 (or the next forward pass) overlaps
with compute for layer N. The caller should insert local
compute (e.g., the previous layer’s backward, or input
preprocessing) between prefetch_forward_params and forward
to realize the overlap.
Only valid for ShardingStrategy::FullShard — the other
strategies keep parameters replicated on every rank so there’s
nothing to all-gather. Calling this on a non-FullShard FSDP
returns an InvalidArgument error.
§Invariant
Exactly one prefetch_forward_params → forward pair should be
in flight at any time on a given FSDP instance. Calling
prefetch_forward_params twice in a row (without an intervening
forward) returns an InvalidArgument error.
CL-373.
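The invariant amounts to a one-slot state machine. The sketch below models it with a hypothetical `PrefetchState` struct and a hypothetical `PrefetchError` enum; the real method reports both failure cases as an InvalidArgument error, and the two variants here exist only to make the cases distinguishable in the example.

```rust
/// Hypothetical error type for illustration; the real API returns an
/// InvalidArgument error in both cases.
#[derive(Debug, PartialEq)]
enum PrefetchError {
    NotFullShard,
    AlreadyPending,
}

/// Hypothetical model of the prefetch bookkeeping on one FSDP instance.
struct PrefetchState {
    full_shard: bool,
    pending: bool,
}

impl PrefetchState {
    /// Start a prefetch; fails for non-FullShard strategies or when a
    /// prefetch is already in flight.
    fn prefetch(&mut self) -> Result<(), PrefetchError> {
        if !self.full_shard {
            return Err(PrefetchError::NotFullShard);
        }
        if self.pending {
            return Err(PrefetchError::AlreadyPending);
        }
        self.pending = true;
        Ok(())
    }

    /// Run forward; returns true if a pending prefetch was consumed,
    /// false if a synchronous all-gather would be needed instead.
    fn forward(&mut self) -> bool {
        std::mem::take(&mut self.pending)
    }
}
```

A second `prefetch` without an intervening `forward` yields `AlreadyPending`; after `forward` consumes the pending handles, the next `prefetch` succeeds again.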
pub fn has_pending_prefetch(&self) -> bool
True if a prefetch is currently pending. Primarily useful for tests and diagnostics.
pub fn module_mut(&mut self) -> &mut M
Mutable access to the inner module.
pub fn into_inner(self) -> M
Consume the wrapper and return the inner module.
pub fn forward(&mut self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>
Reconstruct full parameters from shards across all ranks and run the inner module’s forward pass.
The all-gathered full-parameter tensors are stored in self.full_params
so their gradients can be read after backward.
For ShardGradOp and NoShard strategies, parameters are already
full on every rank, so no all-gather happens and full_params is
populated from the current parameter tensors directly.
If prefetch_forward_params was
called earlier, the pending async all-gather handles are consumed
here instead of running the synchronous all_gather — this is how
FSDP hides all-gather latency behind whatever local compute
happened between prefetch_forward_params and forward.
pub fn sync_gradients(&mut self) -> FerrotorchResult<()>
Reduce-scatter gradients from the full-parameter tensors stored during forward, then set each shard parameter’s gradient.
Call this after backward() and before optimizer.step().
§How it works
- For each parameter, read the gradient from the full-param tensor that was used during forward (stored in self.full_params).
- Reduce-scatter the full gradient across ranks (mean reduction) so each rank gets only its shard portion.
- Set the shard parameter’s .grad() to the reduce-scattered result.
Using reduce-scatter (not allreduce) is correct for FSDP because each rank only needs its own shard of the gradient to update its shard of the parameter.
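The reduce-scatter with mean reduction can be simulated in a single process. In this sketch `reduce_scatter_mean` is a hypothetical helper: `grads[r]` is the full-length gradient computed on rank `r`, and the result holds, for each rank, the cross-rank mean restricted to that rank's chunk. Contiguous chunks in rank order are an assumption.

```rust
/// Simulate reduce-scatter with mean reduction across ranks.
/// `grads[r]` is rank r's full-length local gradient; the return value's
/// element `r` is the averaged chunk that rank r would receive.
/// (Hypothetical helper for illustration; not part of the FSDP API.)
fn reduce_scatter_mean(grads: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let world_size = grads.len();
    assert!(world_size > 0, "need at least one rank");
    let numel = grads[0].len();
    assert!(grads.iter().all(|g| g.len() == numel), "gradient length mismatch");
    assert!(numel % world_size == 0, "numel must be divisible by world_size");
    let chunk = numel / world_size;
    (0..world_size)
        .map(|rank| {
            // Average element i over all ranks, but only for this rank's chunk.
            (rank * chunk..(rank + 1) * chunk)
                .map(|i| grads.iter().map(|g| g[i]).sum::<f32>() / world_size as f32)
                .collect()
        })
        .collect()
}
```

With rank gradients `[1, 2, 3, 4]` and `[3, 4, 5, 6]`, the mean is `[2, 3, 4, 5]`; rank 0 receives `[2, 3]` and rank 1 receives `[4, 5]`. An allreduce would deliver all four values to both ranks, half of which each rank would never use.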
pub fn broadcast_updated_params(&mut self) -> FerrotorchResult<()>
For ShardGradOp: after optimizer.step(), each rank has
applied the update only to its own shard of the full parameter
(because sync_gradients zeroed the non-shard positions of the
gradient). This method re-syncs the parameter tensors so that every
rank holds the fully updated parameter. Conceptually this is a sum
via allreduce: each rank contributes its updated shard and zeros
elsewhere, so the sum across ranks is the full updated parameter.
More precisely, this method reconstructs the full parameter as
an allgather of per-rank shards. It is a no-op for the FullShard
and NoShard strategies (they already have consistent
parameters after step).
Call this AFTER optimizer.step() and BEFORE the next
forward(). CL-372.
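The equivalence between summing zero-padded shards and gathering them can be checked with a small single-process sketch. Both helpers are hypothetical and use `f32` in place of the generic `T`; they illustrate the arithmetic only and say nothing about which collective the method actually issues.

```rust
/// Place `shard` at this rank's offset inside a zero vector of full length.
/// (Hypothetical helper for illustration; not part of the FSDP API.)
fn zero_padded(shard: &[f32], rank: usize, world_size: usize) -> Vec<f32> {
    assert!(rank < world_size, "rank out of range");
    let mut full = vec![0.0; shard.len() * world_size];
    let start = rank * shard.len();
    full[start..start + shard.len()].copy_from_slice(shard);
    full
}

/// Element-wise sum across ranks: the value an allreduce(sum) would leave
/// on every rank.
fn all_reduce_sum(contribs: &[Vec<f32>]) -> Vec<f32> {
    assert!(!contribs.is_empty(), "need at least one rank");
    let mut out = vec![0.0; contribs[0].len()];
    for c in contribs {
        for (o, x) in out.iter_mut().zip(c) {
            *o += *x;
        }
    }
    out
}
```

With updated shards `[1, 2]` on rank 0 and `[3, 4]` on rank 1, the padded contributions are `[1, 2, 0, 0]` and `[0, 0, 3, 4]`, and their sum `[1, 2, 3, 4]` equals the rank-order concatenation of the shards.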
pub fn update_shards(&mut self, flat_data: &[T]) -> FerrotorchResult<()>
Update shard parameters from a flat data slice.
This is used by optimizers that produce a flat parameter buffer. The slice must have exactly the number of elements expected for this rank’s shards.
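One way to picture the length requirement is a helper that carves a flat buffer into per-parameter shard slices. `split_flat` is hypothetical, and two things in it are assumptions rather than documented behavior: that the flat buffer lays out shards back to back in parameter order, and that a length mismatch is reported as an error rather than a panic.

```rust
/// Split `flat` into consecutive slices of the given shard lengths.
/// Returns an error when the total length does not match exactly.
/// (Hypothetical helper for illustration; layout and error type are assumed.)
fn split_flat<'a>(flat: &'a [f32], shard_lens: &[usize]) -> Result<Vec<&'a [f32]>, String> {
    let expected: usize = shard_lens.iter().sum();
    if flat.len() != expected {
        return Err(format!("expected {} elements, got {}", expected, flat.len()));
    }
    let mut rest = flat;
    let mut out = Vec::with_capacity(shard_lens.len());
    for &len in shard_lens {
        // Safe to split: the total was checked above, so `len <= rest.len()`.
        let (head, tail) = rest.split_at(len);
        out.push(head);
        rest = tail;
    }
    Ok(out)
}
```

A five-element buffer split with shard lengths `[2, 3]` yields `[1, 2]` and `[3, 4, 5]`; the same buffer with lengths `[2, 2]` is rejected because one element would be left over.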
Auto Trait Implementations§
impl<M, T> Freeze for FSDP<M, T> where
M: Freeze,
impl<M, T> !RefUnwindSafe for FSDP<M, T>
impl<M, T> Send for FSDP<M, T>
impl<M, T> !Sync for FSDP<M, T>
impl<M, T> Unpin for FSDP<M, T>
impl<M, T> UnsafeUnpin for FSDP<M, T> where
M: UnsafeUnpin,
impl<M, T> !UnwindSafe for FSDP<M, T>
Blanket Implementations§
impl<T> BorrowMut<T> for T where
T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> DistributionExt for T where
T: ?Sized,
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.