
Struct DDP 

pub struct DDP<M: Module<T>, T: Float> { /* private fields */ }

Distributed Data Parallel module wrapper.

Wraps an inner Module and provides sync_gradients to allreduce parameter gradients across all ranks. Parameters are grouped into buckets (25 MB by default), and each bucket is allreduced as one flat buffer instead of many small tensors.

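// `model`, `backend`, `criterion`, `optimizer`, `input` and `target` are set up elsewhere.
// Build `optimizer` only after broadcast_parameters (see the warning on that method).
let mut ddp = DDP::new(model, backend);
ddp.broadcast_parameters(0)?; // start every rank from rank 0's weights

loop {
    let output = ddp.module().forward(&input)?;
    let loss = criterion.forward(&output, &target)?;
    ferrotorch_core::backward(&loss)?;
    ddp.sync_gradients()?; // allreduce gradients: after backward, before step
    optimizer.step()?;
    optimizer.zero_grad()?;
}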

Implementations

impl<M: Module<T>, T: Float> DDP<M, T>

pub fn new(module: M, backend: Arc<dyn Backend>) -> Self

Wrap a module for distributed data-parallel training.

Parameters are assigned to gradient buckets of about 25 MB each, in reverse order. This matches PyTorch’s convention: backward computes gradients in reverse parameter order, so the first bucket is the first to fill.
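The reverse-order bucketing can be sketched as a greedy pass over parameter sizes. This is an illustrative sketch, not ferrotorch's code: ParamInfo and assign_buckets are hypothetical names, and the cut-off policy (close a bucket when the next parameter would exceed the limit; give an oversized parameter a bucket of its own) is an assumption. For f32 gradients, a parameter with n elements occupies 4n bytes.

```rust
/// Hypothetical description of one parameter: its index in registration
/// order and the size of its gradient in bytes.
#[derive(Clone, Copy, Debug)]
pub struct ParamInfo {
    pub index: usize,
    pub bytes: usize,
}

/// Greedily pack parameters into buckets, walking them in reverse
/// registration order. Assumed policy: a bucket is closed when adding the
/// next parameter would push it past `bucket_size_bytes`; a parameter larger
/// than the limit ends up alone in its own bucket.
pub fn assign_buckets(params: &[ParamInfo], bucket_size_bytes: usize) -> Vec<Vec<usize>> {
    let mut buckets: Vec<Vec<usize>> = Vec::new();
    let mut current: Vec<usize> = Vec::new();
    let mut current_bytes = 0usize;

    for p in params.iter().rev() {
        // Close the current bucket if this parameter would overflow it.
        if !current.is_empty() && current_bytes + p.bytes > bucket_size_bytes {
            buckets.push(std::mem::take(&mut current));
            current_bytes = 0;
        }
        current.push(p.index);
        current_bytes += p.bytes;
    }
    // Flush the last, possibly partially filled, bucket.
    if !current.is_empty() {
        buckets.push(current);
    }
    buckets
}
```

With parameter sizes of 10, 20, 30 and 40 bytes and a 50-byte limit, the reverse walk yields the buckets [3], [2, 1] and [0].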

pub fn with_bucket_size(
    module: M,
    backend: Arc<dyn Backend>,
    bucket_size_bytes: usize,
) -> Self

Wrap a module with a custom bucket size (in bytes).

pub fn module(&self) -> &M

Immutable access to the inner module (for forward pass, etc.).

pub fn module_mut(&mut self) -> &mut M

Mutable access to the inner module (for train/eval mode, etc.).

pub fn into_inner(self) -> M

Consume the DDP wrapper and return the inner module.

pub fn backend(&self) -> &Arc<dyn Backend>

The backend used for communication.

pub fn sync_gradients(&self) -> FerrotorchResult<()>

Allreduce parameter gradients across ranks using gradient bucketing.

Parameters are grouped into buckets of about 25 MB by default (or the size passed to with_bucket_size). Each bucket is allreduced independently as a single flat buffer. This layout prepares for future overlapped communication, in which the first bucket can start transferring while backward is still computing later gradients.

Call this after backward() and before optimizer.step().
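The flatten, allreduce, unflatten cycle behind bucketed synchronization can be simulated in a single process. This is a hypothetical sketch, not ferrotorch's code: ranks are modeled as entries of a Vec, bucketed_allreduce is an illustrative name, and the reduction is assumed to average (element-wise sum divided by the world size, as PyTorch's DDP does); whether ferrotorch averages or sums is not stated here.

```rust
/// Simulated bucketed allreduce across ranks, all in one process.
///
/// `grads[r][p]` is rank r's gradient for parameter p (a flat Vec<f32>).
/// `buckets` lists the parameter indices in each bucket. All ranks are
/// assumed to hold gradients of identical shapes.
pub fn bucketed_allreduce(grads: &mut [Vec<Vec<f32>>], buckets: &[Vec<usize>]) {
    let world = grads.len();
    if world == 0 {
        return;
    }
    for bucket in buckets {
        // Flatten: concatenate this bucket's gradients on each rank.
        let flats: Vec<Vec<f32>> = grads
            .iter()
            .map(|rank| {
                bucket
                    .iter()
                    .flat_map(move |&p| rank[p].iter().copied())
                    .collect::<Vec<f32>>()
            })
            .collect();

        // "Allreduce": element-wise sum across ranks, then average (assumed).
        let mut reduced = vec![0.0f32; flats[0].len()];
        for flat in &flats {
            for (acc, &g) in reduced.iter_mut().zip(flat) {
                *acc += g;
            }
        }
        for x in reduced.iter_mut() {
            *x /= world as f32;
        }

        // Unflatten: copy each parameter's slice back on every rank.
        for rank in grads.iter_mut() {
            let mut offset = 0;
            for &p in bucket {
                let n = rank[p].len();
                rank[p].copy_from_slice(&reduced[offset..offset + n]);
                offset += n;
            }
        }
    }
}
```

One reduction per bucket replaces one per parameter, which is the point of flattening: fewer, larger messages.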

pub fn overlapped_sync_gradients(&self) -> FerrotorchResult<()>

Allreduce parameter gradients with bucket-level parallelism.

Like sync_gradients, but processes all buckets concurrently using std::thread::scope. Each bucket’s allreduce runs in its own thread, overlapping communication across buckets. All threads complete before this method returns.

This provides communication/computation overlap when backward and sync run on different threads, and communication overlap across buckets even in the synchronous case.
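The concurrency pattern described here, one scoped thread per bucket with every thread joined before returning, looks roughly like this with std::thread::scope. The buffer layout and the name overlapped_allreduce are assumptions for illustration; a real backend would issue a network collective inside each thread rather than summing local buffers, and averaging is again assumed.

```rust
use std::thread;

/// Reduce independent buckets concurrently, one scoped thread per bucket.
/// `rank_buffers[b][r]` is rank r's flattened buffer for bucket b; afterwards
/// every rank's buffer for bucket b holds the element-wise average across
/// ranks (averaging rather than summing is an assumption).
pub fn overlapped_allreduce(rank_buffers: &mut [Vec<Vec<f32>>]) {
    thread::scope(|s| {
        for bucket in rank_buffers.iter_mut() {
            // Each bucket is borrowed mutably by exactly one thread,
            // so no two threads touch the same data.
            s.spawn(move || {
                let world = bucket.len();
                if world == 0 {
                    return;
                }
                let mut sum = vec![0.0f32; bucket[0].len()];
                for buf in bucket.iter() {
                    for (acc, &x) in sum.iter_mut().zip(buf) {
                        *acc += x;
                    }
                }
                for buf in bucket.iter_mut() {
                    for (dst, &total) in buf.iter_mut().zip(&sum) {
                        *dst = total / world as f32;
                    }
                }
            });
        }
        // Leaving the scope joins every spawned thread, so the function
        // returns only after all bucket reductions have finished.
    });
}
```

Scoped threads let each worker borrow its bucket directly, without Arc or cloning, because the scope guarantees the borrows end before the function returns.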

pub fn broadcast_parameters(&mut self, root: usize) -> FerrotorchResult<()>

Broadcast model parameters from root rank to all other ranks.

Ensures all ranks start with identical weights. Call once before the training loop begins.

Warning

This replaces the Parameter objects in the module. Any optimizer that holds references to the old parameters must be re-initialized after calling this method, otherwise optimizer state (momentum, adaptive learning rates, etc.) will refer to stale parameters.
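The stale-reference hazard can be reproduced with plain Arc handles. In this hypothetical sketch (ToyModel and ToyOptimizer are not ferrotorch types), the optimizer keeps handles to the parameter objects it was built with. Once the model swaps in new objects, as broadcast_parameters does, Arc::ptr_eq shows the optimizer no longer points at the model's parameters until it is rebuilt.

```rust
use std::sync::Arc;

/// Hypothetical optimizer that keeps handles to the parameter objects it was
/// constructed with (where per-parameter state such as momentum would live).
pub struct ToyOptimizer {
    pub params: Vec<Arc<Vec<f32>>>,
}

/// Hypothetical model whose parameter objects can be replaced wholesale.
pub struct ToyModel {
    pub params: Vec<Arc<Vec<f32>>>,
}

impl ToyModel {
    /// Swap in fresh parameter objects holding the root rank's values,
    /// mimicking how broadcast_parameters replaces Parameter objects.
    pub fn replace_with(&mut self, root_values: &[Vec<f32>]) {
        self.params = root_values.iter().cloned().map(Arc::new).collect();
    }
}

/// True if the optimizer still refers to the very objects the model owns.
pub fn optimizer_in_sync(model: &ToyModel, opt: &ToyOptimizer) -> bool {
    model.params.len() == opt.params.len()
        && model
            .params
            .iter()
            .zip(&opt.params)
            .all(|(m, o)| Arc::ptr_eq(m, o))
}
```

Rebuilding the optimizer from the model's current parameters after the broadcast restores the link.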

Trait Implementations

impl<M: Module<T>, T: Float> Module<T> for DDP<M, T>

fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>

Forward pass. Takes input tensor, returns output tensor.

fn parameters(&self) -> Vec<&Parameter<T>>

Iterate over all learnable parameters.

fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>>

Iterate over all learnable parameters mutably.

fn named_parameters(&self) -> Vec<(String, &Parameter<T>)>

Named parameters for state dict serialization.

fn train(&mut self)

Set training mode. Affects dropout, batchnorm, etc.

fn eval(&mut self)

Set evaluation mode.

fn is_training(&self) -> bool

Whether the module is in training mode.

fn to_device(&mut self, device: Device) -> Result<(), FerrotorchError>

Move all parameters and buffers to a device.

fn state_dict(&self) -> HashMap<String, Tensor<T>>

Export parameters and buffers as a state dict (torch parity).

fn buffers(&self) -> Vec<&Buffer<T>>

Iterate over all non-trainable buffers (e.g. running mean / variance in BatchNorm). Default returns empty; concrete modules with buffers override it.

fn buffers_mut(&mut self) -> Vec<&mut Buffer<T>>

Mutable iteration over all buffers. Default returns empty.

fn named_buffers(&self) -> Vec<(String, &Buffer<T>)>

Named buffers (dot-separated paths for nested modules). Default returns empty.

fn as_any(&self) -> Option<&(dyn Any + 'static)>

Downcast hook for type-erased buffer-loader dispatch. (#984)

fn children(&self) -> Vec<&dyn Module<T>>

Direct child modules. Default returns empty (leaf module).

fn named_children(&self) -> Vec<(String, &dyn Module<T>)>

Direct child modules with their attribute names. Default returns empty.

fn modules(&self) -> Vec<&dyn Module<T>>
where Self: Sized,

All modules in this subtree, depth-first (self first, then each child’s descendants in order).

fn descendants_dyn(&self) -> Vec<&dyn Module<T>>

All strict descendants of self in depth-first order. Object-safe.

fn named_modules(&self) -> Vec<(String, &dyn Module<T>)>
where Self: Sized,

All modules in this subtree with dot-separated path names. The root is named "" (the empty string); child paths are joined with ".".
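The naming rule can be sketched as a pre-order walk over a toy tree. Node and named_paths are hypothetical names, not ferrotorch's types, and the sketch assumes that direct children of the root get no leading dot, as in PyTorch's named_modules.

```rust
/// Hypothetical module-tree node: each child is stored with its attribute name.
pub struct Node {
    pub children: Vec<(String, Node)>,
}

/// Depth-first, pre-order list of dot-separated paths. The root is "", and a
/// child's path is its parent's path plus "." plus its attribute name
/// (assumed: no leading dot for children of the root).
pub fn named_paths(root: &Node) -> Vec<String> {
    fn walk(node: &Node, prefix: &str, out: &mut Vec<String>) {
        out.push(prefix.to_string()); // visit self before any child
        for (name, child) in &node.children {
            let path = if prefix.is_empty() {
                name.clone()
            } else {
                format!("{prefix}.{name}")
            };
            walk(child, &path, out);
        }
    }
    let mut out = Vec::new();
    walk(root, "", &mut out);
    out
}
```

For a root with children encoder (itself holding 0 and 1) and head, the walk yields "", "encoder", "encoder.0", "encoder.1", "head".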

fn named_descendants_dyn(&self) -> Vec<(String, &dyn Module<T>)>

Strict descendants with dot-paths. Object-safe.
fn with_forward_hook(
    self,
    hook: Box<dyn Fn(&Tensor<T>, &Tensor<T>) + Sync + Send>,
) -> (HookedModule<Self, T>, HookHandle)
where Self: Sized,

Wrap this module in a HookedModule and register a forward hook. Returns the wrapper paired with a HookHandle that can be used to remove the hook later. The wrapper implements Module<T> itself, so it slots into any place the original module did. Mirrors torch.nn.Module.register_forward_hook.
fn with_forward_pre_hook(
    self,
    hook: Box<dyn Fn(&Tensor<T>) -> Result<Tensor<T>, FerrotorchError> + Sync + Send>,
) -> (HookedModule<Self, T>, HookHandle)
where Self: Sized,

Wrap this module in a HookedModule and register a forward pre-hook. See Self::with_forward_hook. Mirrors torch.nn.Module.register_forward_pre_hook.
fn with_backward_hook(
    self,
    hook: Box<dyn Fn(&Tensor<T>, &Tensor<T>) + Sync + Send>,
) -> (HookedModule<Self, T>, HookHandle)
where Self: Sized,

Wrap this module in a HookedModule and register a backward hook. See Self::with_forward_hook. Mirrors torch.nn.Module.register_backward_hook.

fn zero_grad(&self) -> Result<(), FerrotorchError>

Set the gradient of every parameter to None.

fn requires_grad_(&mut self, requires_grad: bool)

Toggle requires_grad on every parameter (freeze / unfreeze the module). Mirrors torch.nn.Module.requires_grad_.

fn apply_to_parameters(&mut self, f: &mut dyn FnMut(&mut Parameter<T>))

Apply a function to every parameter in this module. Mirrors torch.nn.Module.apply for the parameter case (the true apply recurses over all submodules; the recursive form would require &mut dyn Module, which conflicts with this trait’s &mut self borrow).
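Because the callback is a &mut dyn FnMut, one closure can both mutate each parameter and carry state across calls. The sketch below uses a hypothetical Param type and ToyModule trait that only mimic the shape of this method; they are not ferrotorch's Parameter<T> or Module<T>.

```rust
/// Hypothetical stand-in for Parameter<T>: just a flat weight vector.
pub struct Param {
    pub data: Vec<f32>,
}

/// Hypothetical trait with the same callback shape as apply_to_parameters.
pub trait ToyModule {
    fn apply_to_parameters(&mut self, f: &mut dyn FnMut(&mut Param));
}

/// A leaf layer that hands each of its parameters to the callback in turn.
pub struct Linear {
    pub weight: Param,
    pub bias: Param,
}

impl ToyModule for Linear {
    fn apply_to_parameters(&mut self, f: &mut dyn FnMut(&mut Param)) {
        f(&mut self.weight);
        f(&mut self.bias);
    }
}

/// Scale every parameter in place and return the total number of elements.
pub fn scale_and_count(model: &mut dyn ToyModule, factor: f32) -> usize {
    let mut count = 0;
    model.apply_to_parameters(&mut |p: &mut Param| {
        count += p.data.len(); // state carried across calls
        for x in p.data.iter_mut() {
            *x *= factor; // in-place mutation of the parameter
        }
    });
    count
}
```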
fn load_state_dict(
    &mut self,
    state: &HashMap<String, Tensor<T>>,
    strict: bool,
) -> Result<(), FerrotorchError>

Load parameters from a state dict.

Auto Trait Implementations

impl<M, T> Freeze for DDP<M, T>
where M: Freeze,

impl<M, T> !RefUnwindSafe for DDP<M, T>

impl<M, T> Send for DDP<M, T>

impl<M, T> Sync for DDP<M, T>

impl<M, T> Unpin for DDP<M, T>
where M: Unpin, T: Unpin,

impl<M, T> UnsafeUnpin for DDP<M, T>
where M: UnsafeUnpin,

impl<M, T> !UnwindSafe for DDP<M, T>
