Trait TensorParallel 

pub trait TensorParallel: Send + Sync {
    // Required methods
    fn world_size(&self) -> usize;
    fn rank(&self) -> usize;
    fn all_reduce_sum(&self, tensor: &mut Tensor) -> BackendResult<()>;
    fn all_gather(
        &self,
        local: &Tensor,
        output: &mut Tensor,
    ) -> BackendResult<()>;
    fn scatter(&self, input: &Tensor, output: &mut Tensor) -> BackendResult<()>;
    fn barrier(&self) -> BackendResult<()>;
}

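Trait for tensor-parallel communication primitives: the collective operations (all-reduce, all-gather, scatter, barrier) that devices use to exchange tensor data.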
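As a minimal sketch of what implementing this trait involves, the block below defines a hypothetical single-device backend, `SingleDevice` (world size 1), in which every collective reduces to a no-op or a copy. The crate's real `Tensor` and `BackendResult` types are not shown on this page, so the block substitutes hypothetical stand-ins (a `Tensor` wrapping a `Vec<f32>` and a `BackendError` type); only the method signatures are taken from the documentation.

```rust
// Hypothetical stand-ins: the crate's real `Tensor` and `BackendResult`
// are not documented on this page.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    pub data: Vec<f32>,
}

#[derive(Debug)]
pub struct BackendError(pub String);

pub type BackendResult<T> = Result<T, BackendError>;

// Signatures copied from the trait definition above.
pub trait TensorParallel: Send + Sync {
    fn world_size(&self) -> usize;
    fn rank(&self) -> usize;
    fn all_reduce_sum(&self, tensor: &mut Tensor) -> BackendResult<()>;
    fn all_gather(&self, local: &Tensor, output: &mut Tensor) -> BackendResult<()>;
    fn scatter(&self, input: &Tensor, output: &mut Tensor) -> BackendResult<()>;
    fn barrier(&self) -> BackendResult<()>;
}

/// Hypothetical backend for a single device: with world size 1,
/// every collective degenerates to a no-op or a copy.
pub struct SingleDevice;

impl TensorParallel for SingleDevice {
    fn world_size(&self) -> usize {
        1
    }

    fn rank(&self) -> usize {
        0
    }

    fn all_reduce_sum(&self, _tensor: &mut Tensor) -> BackendResult<()> {
        // The sum over a single device is the tensor itself.
        Ok(())
    }

    fn all_gather(&self, local: &Tensor, output: &mut Tensor) -> BackendResult<()> {
        // world_size * local_size == local_size, so output is a copy of local.
        output.data.clear();
        output.data.extend_from_slice(&local.data);
        Ok(())
    }

    fn scatter(&self, input: &Tensor, output: &mut Tensor) -> BackendResult<()> {
        // The only device receives the whole input.
        output.data.clear();
        output.data.extend_from_slice(&input.data);
        Ok(())
    }

    fn barrier(&self) -> BackendResult<()> {
        // There is no other device to wait for.
        Ok(())
    }
}
```

A backend of this shape is a common fallback for running tensor-parallel code on one device; a multi-device backend would replace each body with real inter-device communication.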

Required Methods

fn world_size(&self) -> usize

Number of participating devices (the world size).

fn rank(&self) -> usize

This device’s rank, 0-indexed, so it lies in `0..world_size()`.
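A typical use of `rank()` and `world_size()` is deciding which slice of a tensor this device owns. The sketch below uses a hypothetical helper, `shard_range`, that splits `len` elements into contiguous shards as evenly as possible. How the crate itself partitions tensors is not documented here, so the policy for uneven splits (the first `len % world_size` ranks get one extra element) is an assumption.

```rust
use std::ops::Range;

/// Hypothetical helper: the contiguous range of `len` elements owned by
/// `rank` when they are split as evenly as possible over `world_size`
/// devices. Assumption: the first `len % world_size` ranks get one extra.
fn shard_range(rank: usize, world_size: usize, len: usize) -> Range<usize> {
    assert!(world_size > 0 && rank < world_size, "rank must be in 0..world_size");
    let base = len / world_size;
    let extra = len % world_size;
    // Every earlier rank contributes `base` elements, plus one each for
    // the earlier ranks that received an extra element.
    let start = rank * base + rank.min(extra);
    let size = base + if rank < extra { 1 } else { 0 };
    start..start + size
}
```

Inside backend-aware code this would be called as `shard_range(tp.rank(), tp.world_size(), len)`; when `len` is divisible by the world size, every rank gets exactly `len / world_size` elements.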

fn all_reduce_sum(&self, tensor: &mut Tensor) -> BackendResult<()>

All-reduce sum: sums `tensor` element-wise across all devices, in place. On return, every device's `tensor` holds the same total.
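The sketch below simulates, sequentially in one process, what an all-reduce sum computes, and why tensor parallelism needs it: when the inner dimension of a matrix-vector product is split across devices, each device's partial result becomes correct only after the sum. Both functions (`simulate_all_reduce_sum`, `split_inner_matvec`) are hypothetical, and per-rank tensors are modeled as plain `Vec<f32>` buffers; a real backend exchanges data between devices instead of looping over buffers.

```rust
/// Sequential stand-in for an all-reduce sum: afterwards every rank's
/// buffer holds the element-wise sum of all ranks' buffers.
fn simulate_all_reduce_sum(buffers: &mut [Vec<f32>]) {
    let len = match buffers.first() {
        Some(first) => first.len(),
        None => return, // no ranks, nothing to reduce
    };
    let mut total = vec![0.0f32; len];
    for buf in buffers.iter() {
        assert_eq!(buf.len(), len, "all ranks must contribute equal-length tensors");
        for (t, x) in total.iter_mut().zip(buf) {
            *t += *x;
        }
    }
    // Every rank ends up with the same reduced values.
    for buf in buffers.iter_mut() {
        buf.copy_from_slice(&total);
    }
}

/// Hypothetical example: y = W x with the inner dimension (columns of W,
/// entries of x) split evenly across `world_size` ranks. Each rank's
/// partial product has full length but holds only part of each sum,
/// so an all-reduce is required before any rank holds the true y.
fn split_inner_matvec(w: &[Vec<f32>], x: &[f32], world_size: usize) -> Vec<Vec<f32>> {
    assert!(world_size > 0 && x.len() % world_size == 0, "sketch assumes an even split");
    let chunk = x.len() / world_size;
    let mut partials: Vec<Vec<f32>> = (0..world_size)
        .map(|rank| {
            let cols = rank * chunk..(rank + 1) * chunk;
            w.iter()
                .map(|row| {
                    row[cols.clone()]
                        .iter()
                        .zip(&x[cols.clone()])
                        .map(|(a, b)| a * b)
                        .sum::<f32>()
                })
                .collect()
        })
        .collect();
    simulate_all_reduce_sum(&mut partials);
    partials
}
```

With `W = [[1, 2, 3, 4], [5, 6, 7, 8]]`, `x = [1, 1, 1, 1]` and two ranks, the partials are `[3, 11]` and `[7, 15]`; after the all-reduce both ranks hold `[10, 26]`, which equals `W x`.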

fn all_gather(&self, local: &Tensor, output: &mut Tensor) -> BackendResult<()>

All-gather: gathers the `local` tensor from every device into `output`, so `output` holds `world_size * local_size` elements.
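A sequential sketch of the all-gather result, with per-rank tensors modeled as `Vec<f32>`. The function name `simulate_all_gather` is hypothetical, and concatenating in rank order is an assumption borrowed from common collective libraries; this page does not state the order.

```rust
/// Sequential stand-in for all-gather: every rank receives the
/// concatenation of all ranks' local tensors.
/// Assumption: the concatenation is in rank order (rank 0 first).
fn simulate_all_gather(locals: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let world_size = locals.len();
    let local_size = locals.first().map_or(0, Vec::len);
    let mut gathered: Vec<f32> = Vec::with_capacity(world_size * local_size);
    for local in locals {
        assert_eq!(local.len(), local_size, "all ranks must contribute equal-length tensors");
        gathered.extend_from_slice(local);
    }
    // Output length is world_size * local_size, and every rank gets a copy.
    vec![gathered; world_size]
}
```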

fn scatter(&self, input: &Tensor, output: &mut Tensor) -> BackendResult<()>

Scatter: splits `input` across devices; each device receives a 1/`world_size` share of it in `output`.
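A sequential sketch of scatter under two assumptions this page does not state: the shares are contiguous chunks assigned in rank order, and the input length is divisible by `world_size`. The helper name `simulate_scatter` is hypothetical.

```rust
/// Sequential stand-in for scatter: `input` is cut into `world_size`
/// equal, contiguous chunks, and rank `r` receives chunk `r`.
/// Assumptions: rank-ordered chunks and an evenly divisible input.
fn simulate_scatter(input: &[f32], world_size: usize) -> Vec<Vec<f32>> {
    assert!(world_size > 0, "world_size must be positive");
    assert_eq!(input.len() % world_size, 0, "input must split evenly");
    let chunk = input.len() / world_size;
    if chunk == 0 {
        // `chunks(0)` would panic, so hand every rank an empty share.
        return vec![Vec::new(); world_size];
    }
    input.chunks(chunk).map(|c| c.to_vec()).collect()
}
```

Under these assumptions, concatenating the chunks in rank order recovers `input`, so a rank-ordered all-gather undoes a scatter.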

fn barrier(&self) -> BackendResult<()>

Barrier: synchronizes all devices; no device returns from `barrier` until every device has called it.
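To illustrate barrier semantics without a multi-device backend, this sketch uses threads as stand-in devices and the standard library's `std::sync::Barrier` in place of `TensorParallel::barrier` (a substitution, not the crate's implementation). Each thread records its arrival, waits, and then reads the arrival count; since no thread passes the barrier until all have arrived, every thread observes the full count. The function name `barrier_demo` is hypothetical.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Barrier};
use std::thread;

/// Threads stand in for devices; `std::sync::Barrier` stands in for
/// `TensorParallel::barrier`. Returns the arrival count each thread saw
/// right after passing the barrier.
fn barrier_demo(world_size: usize) -> Vec<usize> {
    let arrived = Arc::new(AtomicUsize::new(0));
    let barrier = Arc::new(Barrier::new(world_size));
    let handles: Vec<_> = (0..world_size)
        .map(|_rank| {
            let arrived = Arc::clone(&arrived);
            let barrier = Arc::clone(&barrier);
            thread::spawn(move || {
                // Record arrival, then block until every thread has arrived.
                arrived.fetch_add(1, Ordering::SeqCst);
                barrier.wait();
                // All increments happened before any thread got past `wait`.
                arrived.load(Ordering::SeqCst)
            })
        })
        .collect();
    handles.into_iter().map(|h| h.join().unwrap()).collect()
}
```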

Dyn Compatibility

This trait is dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety".
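Because the trait is dyn compatible and requires `Send + Sync`, a backend can be chosen at runtime and shared across threads as `Arc<dyn TensorParallel>`. The sketch below shows this with hypothetical stand-ins for `Tensor` and `BackendResult` and a hypothetical no-op backend, `NoOp`; only the trait signature comes from this page.

```rust
use std::sync::Arc;
use std::thread;

// Hypothetical stand-ins for the crate's own types.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    pub data: Vec<f32>,
}

#[derive(Debug)]
pub struct BackendError(pub String);

pub type BackendResult<T> = Result<T, BackendError>;

pub trait TensorParallel: Send + Sync {
    fn world_size(&self) -> usize;
    fn rank(&self) -> usize;
    fn all_reduce_sum(&self, tensor: &mut Tensor) -> BackendResult<()>;
    fn all_gather(&self, local: &Tensor, output: &mut Tensor) -> BackendResult<()>;
    fn scatter(&self, input: &Tensor, output: &mut Tensor) -> BackendResult<()>;
    fn barrier(&self) -> BackendResult<()>;
}

/// Hypothetical world-size-1 backend, used only to have something to box.
struct NoOp;

impl TensorParallel for NoOp {
    fn world_size(&self) -> usize { 1 }
    fn rank(&self) -> usize { 0 }
    fn all_reduce_sum(&self, _tensor: &mut Tensor) -> BackendResult<()> { Ok(()) }
    fn all_gather(&self, local: &Tensor, output: &mut Tensor) -> BackendResult<()> {
        output.data = local.data.clone();
        Ok(())
    }
    fn scatter(&self, input: &Tensor, output: &mut Tensor) -> BackendResult<()> {
        output.data = input.data.clone();
        Ok(())
    }
    fn barrier(&self) -> BackendResult<()> { Ok(()) }
}

/// Code written against `&dyn TensorParallel` works with any backend.
fn describe(tp: &dyn TensorParallel) -> String {
    format!("rank {} of {}", tp.rank(), tp.world_size())
}

/// The `Send + Sync` supertraits make `Arc<dyn TensorParallel>` sendable
/// to another thread.
fn share_across_threads(tp: Arc<dyn TensorParallel>) -> String {
    let worker = thread::spawn(move || describe(&*tp));
    worker.join().unwrap()
}
```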

Implementors