
Struct DDP 

pub struct DDP<M: Module<T>, T: Float> { /* private fields */ }

Distributed Data Parallel module wrapper.

Wraps an inner Module and provides sync_gradients, which allreduces parameter gradients across all ranks. Parameters are grouped into buckets (25 MB by default) so that each allreduce moves one flat buffer per bucket.

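// Wrap the model; backend is the Arc<dyn Backend> used for communication.
let ddp = DDP::new(model, backend);

loop {
    // Forward and backward run locally on each rank.
    let output = ddp.module().forward(&input)?;
    let loss = criterion.forward(&output, &target)?;
    ferrotorch_core::backward(&loss)?;
    // Allreduce gradients after backward and before the optimizer step.
    ddp.sync_gradients()?;
    optimizer.step()?;
    optimizer.zero_grad()?;
}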

Implementations§

impl<M: Module<T>, T: Float> DDP<M, T>

pub fn new(module: M, backend: Arc<dyn Backend>) -> Self

Wrap a module for distributed data-parallel training.

Parameters are assigned to ~25 MB gradient buckets in reverse order. This matches PyTorch’s convention: backward computes gradients in reverse parameter order, so the first bucket fills first.
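The reverse-order assignment described above can be illustrated with a small standalone sketch. It is not the crate's implementation: the function name assign_buckets, the greedy fill policy, and the rule that an oversized parameter gets a bucket of its own are all assumptions made for illustration.

```rust
/// Hypothetical helper (not part of ferrotorch): group parameter indices into
/// buckets of at most `cap_bytes`, walking the parameter list in reverse.
/// A parameter larger than the cap ends up alone in its own bucket.
fn assign_buckets(param_bytes: &[usize], cap_bytes: usize) -> Vec<Vec<usize>> {
    let mut buckets: Vec<Vec<usize>> = Vec::new();
    let mut current: Vec<usize> = Vec::new();
    let mut current_bytes = 0usize;

    // Reverse order: the last parameter's gradient is expected to be ready first.
    for (idx, &bytes) in param_bytes.iter().enumerate().rev() {
        // Close the current bucket if adding this parameter would exceed the cap.
        if !current.is_empty() && current_bytes + bytes > cap_bytes {
            buckets.push(std::mem::take(&mut current));
            current_bytes = 0;
        }
        current.push(idx);
        current_bytes += bytes;
    }
    if !current.is_empty() {
        buckets.push(current);
    }
    buckets
}
```

With parameter sizes [10, 20, 30, 40] and a cap of 50, this sketch produces the buckets [3], [2, 1], [0]: the last parameter lands in the first bucket, which is the one expected to fill first during backward.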

pub fn with_bucket_size(module: M, backend: Arc<dyn Backend>, bucket_size_bytes: usize) -> Self

Wrap a module with a custom bucket size (in bytes).

pub fn module(&self) -> &M

Immutable access to the inner module (for forward pass, etc.).

pub fn module_mut(&mut self) -> &mut M

Mutable access to the inner module (for train/eval mode, etc.).

pub fn into_inner(self) -> M

Consume the DDP wrapper and return the inner module.

pub fn backend(&self) -> &Arc<dyn Backend>

The backend used for communication.

pub fn sync_gradients(&self) -> FerrotorchResult<()>

Allreduce parameter gradients across ranks using gradient bucketing.

Parameters are grouped into buckets (~25 MB by default; see with_bucket_size). Each bucket is allreduced independently as a single flat buffer. This layout leaves room for future overlapped communication, where the first bucket starts transferring while backward is still computing later gradients.

Call this after backward() and before optimizer.step().
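The per-bucket flow (flatten, allreduce once, unflatten) can be sketched without any networking. Everything below is a hypothetical stand-in: the helper names are invented, ranks are simulated as in-process vectors, and the reduction is assumed to be an element-wise mean, which this page does not confirm for sync_gradients.

```rust
/// Concatenate a bucket's per-parameter gradients into one flat buffer.
fn flatten(grads: &[Vec<f32>]) -> Vec<f32> {
    grads.iter().flat_map(|g| g.iter().copied()).collect()
}

/// Split a flat buffer back into per-parameter gradients of the given lengths.
fn unflatten(flat: &[f32], lens: &[usize]) -> Vec<Vec<f32>> {
    let mut out = Vec::with_capacity(lens.len());
    let mut offset = 0;
    for &len in lens {
        out.push(flat[offset..offset + len].to_vec());
        offset += len;
    }
    out
}

/// Stand-in for a real allreduce (assumption: element-wise mean over ranks).
/// Expects at least one rank and equally sized buffers.
fn allreduce_mean(per_rank: &[Vec<f32>]) -> Vec<f32> {
    let n = per_rank.len() as f32;
    let mut acc = vec![0.0f32; per_rank[0].len()];
    for buf in per_rank {
        for (a, x) in acc.iter_mut().zip(buf) {
            *a += *x;
        }
    }
    acc.iter().map(|a| a / n).collect()
}

/// Sync one bucket: flatten on each rank, reduce once, unflatten the result.
/// `ranks[r][p]` is the gradient of parameter `p` on simulated rank `r`.
fn sync_bucket(ranks: &[Vec<Vec<f32>>]) -> Vec<Vec<f32>> {
    let lens: Vec<usize> = ranks[0].iter().map(Vec::len).collect();
    let flats: Vec<Vec<f32>> = ranks.iter().map(|g| flatten(g)).collect();
    unflatten(&allreduce_mean(&flats), &lens)
}
```

The point of the flat buffer is that one collective call covers every parameter in the bucket, instead of one call per parameter.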

pub fn overlapped_sync_gradients(&self) -> FerrotorchResult<()>

Allreduce parameter gradients with bucket-level parallelism.

Like sync_gradients, but processes all buckets concurrently using std::thread::scope. Each bucket’s allreduce runs in its own thread, overlapping communication across buckets. All threads complete before this method returns.

When backward and sync run on different threads, this overlaps communication with computation. Even when called synchronously after backward, the allreduces for different buckets still overlap with one another.
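The one-thread-per-bucket pattern built on std::thread::scope might look roughly like the sketch below. The function name, the f32 buffers, the String error type, and the caller-supplied reduce closure are illustrative assumptions; the actual method works on the wrapped module's gradients through the backend.

```rust
use std::thread;

/// Hypothetical sketch: run `reduce` on every bucket concurrently, one scoped
/// thread per bucket. Returns the first error reported by any bucket.
fn reduce_buckets_concurrently<F>(buckets: &mut [Vec<f32>], reduce: F) -> Result<(), String>
where
    F: Fn(&mut [f32]) -> Result<(), String> + Sync,
{
    let reduce = &reduce;
    thread::scope(|s| {
        // Each thread mutably borrows a disjoint bucket, so no locking is needed.
        let handles: Vec<_> = buckets
            .iter_mut()
            .map(|bucket| s.spawn(move || reduce(bucket.as_mut_slice())))
            .collect();

        // Join the threads and surface the first error, if any. Threads that
        // are not joined explicitly are still joined when the scope ends.
        handles
            .into_iter()
            .map(|h| h.join().expect("bucket thread panicked"))
            .collect()
    })
}
```

Scoped threads fit this case because they may borrow the buckets directly, and the scope guarantees that every thread has finished before the function returns.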

pub fn broadcast_parameters(&mut self, root: usize) -> FerrotorchResult<()>

Broadcast model parameters from root rank to all other ranks.

Ensures all ranks start with identical weights. Call once before the training loop begins.

§Warning

This replaces the Parameter objects in the module. Any optimizer that holds references to the old parameters must be re-initialized after calling this method, otherwise optimizer state (momentum, adaptive learning rates, etc.) will refer to stale parameters.
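The stale-reference hazard in this warning can be shown with a toy analogy. None of the types below are ferrotorch types: Param, ToyModel, ToyOptimizer, and replace_weight are hypothetical, and modelling parameters as Rc handles is an assumption made only to show what happens when an object is replaced rather than updated in place.

```rust
use std::cell::RefCell;
use std::rc::Rc;

/// Hypothetical shared parameter handle (not a ferrotorch type).
type Param = Rc<RefCell<Vec<f32>>>;

struct ToyModel {
    weight: Param,
}

struct ToyOptimizer {
    params: Vec<Param>,
}

impl ToyOptimizer {
    /// Capture handles to the model's current parameter objects.
    fn new(model: &ToyModel) -> Self {
        Self { params: vec![Rc::clone(&model.weight)] }
    }

    /// True if the optimizer still points at the model's current weight object.
    fn tracks(&self, model: &ToyModel) -> bool {
        Rc::ptr_eq(&self.params[0], &model.weight)
    }
}

/// Mimics a broadcast that installs a fresh parameter object holding the
/// root rank's values, instead of writing into the existing object.
fn replace_weight(model: &mut ToyModel, root_values: Vec<f32>) {
    model.weight = Rc::new(RefCell::new(root_values));
}
```

In this analogy, an optimizer built before replace_weight no longer tracks the model afterwards, while one built after it does. That is the reason for constructing (or re-initializing) the optimizer after the broadcast.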

Trait Implementations§

impl<M: Module<T>, T: Float> Module<T> for DDP<M, T>

fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>

Forward pass. Takes input tensor, returns output tensor.

fn parameters(&self) -> Vec<&Parameter<T>>

Iterate over all learnable parameters.

fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>>

Iterate over all learnable parameters mutably.

fn named_parameters(&self) -> Vec<(String, &Parameter<T>)>

Named parameters for state dict serialization.

fn train(&mut self)

Set training mode. Affects dropout, batchnorm, etc.

fn eval(&mut self)

Set evaluation mode.

fn is_training(&self) -> bool

Whether the module is in training mode.

fn to_device(&mut self, device: Device) -> Result<(), FerrotorchError>

Move all parameters to a device.

fn state_dict(&self) -> HashMap<String, Tensor<T>>

Export parameters as a state dict.

fn load_state_dict(&mut self, state: &HashMap<String, Tensor<T>>, strict: bool) -> Result<(), FerrotorchError>

Load parameters from a state dict.

Auto Trait Implementations§

impl<M, T> Freeze for DDP<M, T>
where M: Freeze,

impl<M, T> !RefUnwindSafe for DDP<M, T>

impl<M, T> Send for DDP<M, T>

impl<M, T> Sync for DDP<M, T>

impl<M, T> Unpin for DDP<M, T>
where M: Unpin, T: Unpin,

impl<M, T> UnsafeUnpin for DDP<M, T>
where M: UnsafeUnpin,

impl<M, T> !UnwindSafe for DDP<M, T>
