pub struct FullyShardedDataParallel<M: Module> { /* private fields */ }
Fully Sharded Data Parallel wrapper for memory-efficient distributed training.
FSDP shards model parameters across devices, gathering them only when needed for computation and sharding them again afterward.
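A minimal construction sketch, assuming `MyModel` is some user type implementing `Module` and that a `ProcessGroup` named `process_group` has already been set up (how that happens is outside this page):

    // `MyModel::new()` is a hypothetical constructor used only for illustration.
    let model = MyModel::new();

    // Wrap the model; the builder methods documented below refine the
    // configuration before training starts.
    let mut fsdp = FullyShardedDataParallel::new(model, process_group);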
Implementations
impl<M: Module> FullyShardedDataParallel<M>
pub fn new(module: M, process_group: ProcessGroup) -> Self
Creates a new FSDP wrapper.
pub fn sharding_strategy(self, strategy: ShardingStrategy) -> Self
Builder: set sharding strategy.
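A sketch of picking a strategy at construction time. `ShardingStrategy::FullShard` is an assumed variant name used only for illustration; check the `ShardingStrategy` docs for the actual variants:

    let fsdp = FullyShardedDataParallel::new(model, process_group)
        // Hypothetical variant name; the real enum may differ.
        .sharding_strategy(ShardingStrategy::FullShard);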
pub fn cpu_offload(self, offload: CPUOffload) -> Self
Builder: set CPU offload configuration.
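A sketch of enabling CPU offload. How `CPUOffload` is constructed is not shown on this page; `CPUOffload::default()` is an assumption made purely for illustration:

    let fsdp = FullyShardedDataParallel::new(model, process_group)
        // Assumes CPUOffload implements Default; actual construction may differ.
        .cpu_offload(CPUOffload::default());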
pub fn mixed_precision(self, enabled: bool) -> Self
Builder: enable mixed precision.
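Since every builder method takes and returns `Self`, a fully configured wrapper can be assembled in one chain (variant and constructor names as assumed in the sketches above):

    let mut fsdp = FullyShardedDataParallel::new(model, process_group)
        .sharding_strategy(ShardingStrategy::FullShard) // hypothetical variant
        .cpu_offload(CPUOffload::default())             // assumed constructor
        .mixed_precision(true);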
pub fn module_mut(&mut self) -> &mut M
Returns a mutable reference to the wrapped module.
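A small sketch of reaching into the wrapped module; here the inner module's `Module::set_training` is called directly, although the wrapper also exposes `set_training` itself (see the trait implementation below):

    // Flip the inner module into evaluation mode through the wrapper.
    fsdp.module_mut().set_training(false);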
pub fn process_group(&self) -> &ProcessGroup
Returns the process group.
pub fn strategy(&self) -> ShardingStrategy
Returns the sharding strategy.
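The getters can be used for logging or assertions; `ProcessGroup`'s own accessors are not documented on this page, so only the returned references and values are shown:

    let pg: &ProcessGroup = fsdp.process_group();
    let strategy: ShardingStrategy = fsdp.strategy(); // returned by value
    let _ = (pg, strategy);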
pub fn gather_parameters(&mut self)
Gathers all parameter shards before the forward pass.
pub fn reshard_parameters(&mut self)
Reshards parameters after the forward/backward pass.
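A sketch of bracketing a manual forward/backward step with explicit gather and reshard calls; how the forward pass itself is invoked is not shown on this page, so that step is elided:

    fsdp.gather_parameters();   // every rank now holds the full, unsharded parameters
    // ... run the forward and backward pass here ...
    fsdp.reshard_parameters();  // drop the full parameters, keep only the local shard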
pub fn sync_gradients(&self)
Synchronizes gradients across all ranks.
pub fn clip_grad_norm(&self, max_norm: f32) -> f32
Clips gradients by global norm.
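A sketch of the post-backward bookkeeping, combining gradient synchronization with clipping. Whether the returned `f32` is the pre-clip or post-clip norm is not stated on this page:

    fsdp.sync_gradients();                    // make gradients identical on every rank
    let grad_norm = fsdp.clip_grad_norm(1.0); // clip to a global norm of 1.0
    println!("global grad norm: {grad_norm}");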
pub fn memory_estimate(&self) -> FSDPMemoryStats
Estimates memory usage with different sharding strategies.
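A sketch of querying the estimate. The fields of `FSDPMemoryStats` are not shown on this page, so the example only prints the value via `Debug`, which assumes the type derives `Debug`:

    let stats: FSDPMemoryStats = fsdp.memory_estimate();
    println!("{stats:?}"); // assumes FSDPMemoryStats: Debug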
Trait Implementations
impl<M: Module> Module for FullyShardedDataParallel<M>
fn is_training(&self) -> bool
Returns whether the module is in training mode.
fn named_parameters(&self) -> HashMap<String, Parameter>
Returns named parameters of this module.
fn num_parameters(&self) -> usize
Returns the number of trainable parameters.
fn set_training(&mut self, _training: bool)
Sets the training mode.
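Because the wrapper itself implements `Module`, it can be passed anywhere a `Module` is expected; a minimal sketch:

    fn count_params<M: Module>(m: &M) -> usize {
        m.num_parameters()
    }

    let trainable = count_params(&fsdp);
    let names: Vec<String> = fsdp.named_parameters().keys().cloned().collect();
    println!("{trainable} trainable parameters across {} named tensors", names.len());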
Auto Trait Implementations
impl<M> Freeze for FullyShardedDataParallel<M> where M: Freeze
impl<M> !RefUnwindSafe for FullyShardedDataParallel<M>
impl<M> Send for FullyShardedDataParallel<M>
impl<M> Sync for FullyShardedDataParallel<M>
impl<M> Unpin for FullyShardedDataParallel<M> where M: Unpin
impl<M> !UnwindSafe for FullyShardedDataParallel<M>
Blanket Implementations
impl<T> BorrowMut<T> for T where T: ?Sized
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.