pub struct DDP<M: Module<T>, T: Float> { /* private fields */ }
Distributed Data Parallel module wrapper.
Wraps an inner Module and provides sync_gradients to allreduce
parameter gradients across all ranks. Parameters are grouped into
buckets (default 25 MB) for efficient communication.
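    let ddp = DDP::new(model, backend);
    loop {
        // The forward pass goes through the wrapped module.
        let output = ddp.module().forward(&input)?;
        let loss = criterion.forward(&output, &target)?;
        ferrotorch_core::backward(&loss)?;
        // Allreduce gradients across ranks after backward and before the optimizer step.
        ddp.sync_gradients()?;
        optimizer.step()?;
        optimizer.zero_grad()?;
    }
Implementations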
impl<M: Module<T>, T: Float> DDP<M, T>
pub fn new(module: M, backend: Arc<dyn Backend>) -> Self
Wrap a module for distributed data-parallel training.
Parameters are assigned to ~25 MB gradient buckets in reverse order. This matches PyTorch’s convention: backward computes gradients in reverse parameter order, so the first bucket fills first.
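The reverse-order bucketing described above can be sketched as follows. This is a minimal illustration and not ferrotorch's implementation: the helper name assign_buckets, the greedy fill rule, and the choice to give an oversized parameter a bucket of its own are all assumptions.

```rust
/// Greedily assign parameter indices to buckets, walking the parameter
/// list back to front. `param_bytes[i]` is the gradient size of parameter `i`.
/// Hypothetical helper for illustration; not part of ferrotorch.
fn assign_buckets(param_bytes: &[usize], bucket_cap_bytes: usize) -> Vec<Vec<usize>> {
    let mut buckets: Vec<Vec<usize>> = Vec::new();
    let mut current: Vec<usize> = Vec::new();
    let mut current_bytes = 0usize;

    // Reverse order: the last parameter's gradient is expected to be ready first.
    for idx in (0..param_bytes.len()).rev() {
        let size = param_bytes[idx];
        // Close the current bucket if this parameter would push it past the cap.
        if !current.is_empty() && current_bytes + size > bucket_cap_bytes {
            buckets.push(std::mem::take(&mut current));
            current_bytes = 0;
        }
        current.push(idx);
        current_bytes += size;
    }
    // Flush the last, possibly partial, bucket.
    if !current.is_empty() {
        buckets.push(current);
    }
    buckets
}
```

With sizes [10, 20, 30, 40] and a cap of 50 bytes, this sketch places parameter 3 alone in the first bucket, parameters 2 and 1 in the second, and parameter 0 in the third.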
pub fn with_bucket_size(
    module: M,
    backend: Arc<dyn Backend>,
    bucket_size_bytes: usize,
) -> Self
Wrap a module with a custom bucket size (in bytes).
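Because bucket_size_bytes is a raw byte count, it can help to work out how many gradient elements a given size holds. The sketch below does that arithmetic; the constant MIB and the helper elements_per_bucket are hypothetical names, and whether the documented 25 MB default means 25 * 2^20 or 25 * 10^6 bytes is not stated here.

```rust
/// One mebibyte in bytes (assumed unit; the source only says "MB").
const MIB: usize = 1024 * 1024;

/// Number of elements of type `T` that fit in a bucket of the given size.
/// Hypothetical helper for illustration; not a ferrotorch API.
fn elements_per_bucket<T>(bucket_size_bytes: usize) -> usize {
    // `max(1)` avoids a division by zero for zero-sized types.
    bucket_size_bytes / std::mem::size_of::<T>().max(1)
}
```

Under these assumptions, an 8 MiB bucket would be requested as DDP::with_bucket_size(model, backend, 8 * MIB) and would hold 2,097,152 f32 gradient elements.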
pub fn module_mut(&mut self) -> &mut M
Mutable access to the inner module (for train/eval mode, etc.).
pub fn into_inner(self) -> M
Consume the DDP wrapper and return the inner module.
pub fn sync_gradients(&self) -> FerrotorchResult<()>
Allreduce parameter gradients across ranks using gradient bucketing.
Parameters are grouped into ~25 MB buckets. Each bucket is allreduced independently as a single flat buffer. This enables future overlapped communication where the first bucket can start transferring while backward is still computing later gradients.
Call this after backward() and before optimizer.step().
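The "single flat buffer" step can be pictured with the sketch below: pack a bucket's gradients into one contiguous buffer, run one collective over it, and copy the result back. Everything here is an assumption made for illustration: gradients are plain Vec<f32>, the allreduce closure stands in for the backend call, and the two-rank mean is only a simulated reduction (the text above does not say whether ferrotorch sums or averages).

```rust
/// Flatten one bucket of per-parameter gradients into a single buffer,
/// run a reduction over it, and copy the reduced values back.
/// `allreduce` stands in for the backend collective (hypothetical signature).
fn sync_bucket<F>(grads: &mut [Vec<f32>], allreduce: F)
where
    F: FnOnce(&mut [f32]),
{
    // Pack every gradient in the bucket into one contiguous buffer.
    let mut flat: Vec<f32> = grads.iter().flat_map(|g| g.iter().copied()).collect();

    // One collective call per bucket instead of one per parameter.
    allreduce(&mut flat);

    // Unpack the reduced values into the original gradient slots.
    let mut offset = 0;
    for g in grads.iter_mut() {
        let len = g.len();
        g.copy_from_slice(&flat[offset..offset + len]);
        offset += len;
    }
}

/// Simulated two-rank mean: averages the local buffer with a peer's buffer.
fn mean_with_peer(local: &mut [f32], peer: &[f32]) {
    for (l, p) in local.iter_mut().zip(peer) {
        *l = (*l + *p) / 2.0;
    }
}
```

A caller would invoke sync_bucket(&mut grads, |buf| mean_with_peer(buf, &peer)) once per bucket; the per-parameter layout is restored after the reduction.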
pub fn overlapped_sync_gradients(&self) -> FerrotorchResult<()>
Allreduce parameter gradients with bucket-level parallelism.
Like sync_gradients, but processes all buckets concurrently using
std::thread::scope. Each bucket’s allreduce runs in its own thread,
overlapping communication across buckets. All threads complete before
this method returns.
This provides communication/computation overlap when backward and sync run on different threads, and communication overlap across buckets even in the synchronous case.
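The one-thread-per-bucket pattern can be sketched with std::thread::scope as shown below. This is an illustration under assumptions, not the library's code: buckets are plain Vec<f32> buffers, and reduce is a hypothetical stand-in for the per-bucket allreduce.

```rust
/// Reduce every bucket on its own scoped thread.
/// `reduce` stands in for a per-bucket allreduce (hypothetical signature).
fn sync_buckets_concurrently<F>(buckets: &mut [Vec<f32>], reduce: F)
where
    F: Fn(&mut [f32]) + Sync,
{
    std::thread::scope(|s| {
        for bucket in buckets.iter_mut() {
            // Share the reduction by reference; `F: Sync` makes `&F` sendable.
            let reduce = &reduce;
            // Each spawned thread borrows exactly one bucket mutably.
            s.spawn(move || reduce(bucket.as_mut_slice()));
        }
        // Leaving the scope joins every spawned thread.
    });
}
```

Scoped threads may borrow from the caller's stack, which is why the buckets do not need to be wrapped in Arc, and the implicit join at the end of the scope corresponds to "all threads complete before this method returns".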
pub fn broadcast_parameters(&mut self, root: usize) -> FerrotorchResult<()>
Broadcast model parameters from root rank to all other ranks.
Ensures all ranks start with identical weights. Call once before the training loop begins.
Warning
This replaces the Parameter objects in the module. Any optimizer
that holds references to the old parameters must be re-initialized
after calling this method; otherwise optimizer state (momentum,
adaptive learning rates, etc.) will refer to stale parameters.
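The staleness problem in the warning can be shown with a toy model. None of the types below are ferrotorch's: Param, ToyOptimizer, and broadcast_replace are hypothetical, and tracking parameter identity through Arc pointers is an assumption about how an optimizer might hold its parameters.

```rust
use std::sync::Arc;

/// Toy stand-in for a parameter; not ferrotorch's Parameter type.
struct Param {
    data: Vec<f32>,
}

/// Toy optimizer that keeps its own handles to the parameters it updates.
struct ToyOptimizer {
    params: Vec<Arc<Param>>,
}

impl ToyOptimizer {
    fn new(params: &[Arc<Param>]) -> Self {
        Self { params: params.to_vec() }
    }

    /// True if the optimizer still points at the model's current parameter objects.
    fn tracks(&self, model: &[Arc<Param>]) -> bool {
        self.params.len() == model.len()
            && self.params.iter().zip(model).all(|(a, b)| Arc::ptr_eq(a, b))
    }
}

/// Simulated broadcast that replaces each parameter object with a fresh one
/// holding the root rank's values.
fn broadcast_replace(model: &mut Vec<Arc<Param>>, root_values: &[Vec<f32>]) {
    *model = root_values
        .iter()
        .map(|v| Arc::new(Param { data: v.clone() }))
        .collect();
}
```

In this toy, an optimizer built before broadcast_replace no longer tracks the model afterwards, while one built after it does. The practical ordering is therefore to broadcast first and construct the optimizer second.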
Trait Implementations
impl<M: Module<T>, T: Float> Module<T> for DDP<M, T>
fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>
fn parameters(&self) -> Vec<&Parameter<T>>
fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>>
fn named_parameters(&self) -> Vec<(String, &Parameter<T>)>
fn is_training(&self) -> bool
fn to_device(&mut self, device: Device) -> Result<(), FerrotorchError>
fn state_dict(&self) -> HashMap<String, Tensor<T>>
fn buffers(&self) -> Vec<&Buffer<T>>
fn buffers_mut(&mut self) -> Vec<&mut Buffer<T>>
fn named_buffers(&self) -> Vec<(String, &Buffer<T>)>
fn as_any(&self) -> Option<&(dyn Any + 'static)>
fn children(&self) -> Vec<&dyn Module<T>>
fn named_children(&self) -> Vec<(String, &dyn Module<T>)>
fn modules(&self) -> Vec<&dyn Module<T>>
where
    Self: Sized,
fn descendants_dyn(&self) -> Vec<&dyn Module<T>>
self in depth-first order. Object-safe.
fn named_modules(&self) -> Vec<(String, &dyn Module<T>)>
where
    Self: Sized,
""; children paths are joined with ".".
fn named_descendants_dyn(&self) -> Vec<(String, &dyn Module<T>)>
fn with_forward_hook(
    self,
    hook: Box<dyn Fn(&Tensor<T>, &Tensor<T>) + Sync + Send>,
) -> (HookedModule<Self, T>, HookHandle)
where
    Self: Sized,
Wrap self in a HookedModule and register a forward hook.
Returns the wrapper paired with a HookHandle that can be used to
remove the hook later. The wrapper implements Module<T> itself, so
it slots into any place the original module did. Mirrors
torch.nn.Module.register_forward_hook.
fn with_forward_pre_hook(
    self,
    hook: Box<dyn Fn(&Tensor<T>) -> Result<Tensor<T>, FerrotorchError> + Sync + Send>,
) -> (HookedModule<Self, T>, HookHandle)
where
    Self: Sized,
Wrap self in a HookedModule and register a forward
pre-hook. See Self::with_forward_hook. Mirrors
torch.nn.Module.register_forward_pre_hook.
fn with_backward_hook(
    self,
    hook: Box<dyn Fn(&Tensor<T>, &Tensor<T>) + Sync + Send>,
) -> (HookedModule<Self, T>, HookHandle)
where
    Self: Sized,
Wrap self in a HookedModule and register a backward hook.
See Self::with_forward_hook. Mirrors
torch.nn.Module.register_backward_hook.
fn zero_grad(&self) -> Result<(), FerrotorchError>
None.
fn requires_grad_(&mut self, requires_grad: bool)
Set requires_grad on every parameter (freeze / unfreeze the
module). Mirrors torch.nn.Module.requires_grad_.
fn apply_to_parameters(&mut self, f: &mut dyn FnMut(&mut Parameter<T>))
torch.nn.Module.apply for the parameter case (true apply recurses
over all submodules; the recursive form requires &mut dyn Module
which conflicts with this trait’s &mut self borrow).
Auto Trait Implementations
impl<M, T> Freeze for DDP<M, T>
where
    M: Freeze,
impl<M, T> !RefUnwindSafe for DDP<M, T>
impl<M, T> Send for DDP<M, T>
impl<M, T> Sync for DDP<M, T>
impl<M, T> Unpin for DDP<M, T>
impl<M, T> UnsafeUnpin for DDP<M, T>
where
    M: UnsafeUnpin,
impl<M, T> !UnwindSafe for DDP<M, T>
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> DistributionExt for T
where
    T: ?Sized,
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise.