/// Trait for tensor parallel communication primitives.
///
/// Implementors must also be `Send + Sync` (supertrait bounds). Every method
/// takes `&self` and none is generic, and the accompanying documentation
/// states that the trait is dyn compatible.
pub trait TensorParallel: Send + Sync {
// Required methods
/// Number of devices (world size).
fn world_size(&self) -> usize;
/// NOTE(review): not described in the accompanying documentation; presumably
/// this device's index within the world — confirm against implementors.
fn rank(&self) -> usize;
/// All-reduce sum: sum `tensor` across all devices in-place.
///
/// Failure is reported through the returned `BackendResult`.
fn all_reduce_sum(&self, tensor: &mut Tensor) -> BackendResult<()>;
/// All-gather: gather the `local` tensors from all devices into `output`.
///
/// `output` is `world_size * local_size`.
fn all_gather(
&self,
local: &Tensor,
output: &mut Tensor,
) -> BackendResult<()>;
/// Scatter: split `input` across devices; each gets `1/world_size` of it,
/// written to `output`.
fn scatter(&self, input: &Tensor, output: &mut Tensor) -> BackendResult<()>;
/// Barrier: synchronize all devices.
fn barrier(&self) -> BackendResult<()>;
}Expand description
Trait for tensor parallel communication primitives
Required Methods§
Source
fn world_size(&self) -> usize
Number of devices (world size)
Source
fn all_reduce_sum(&self, tensor: &mut Tensor) -> BackendResult<()>
All-reduce sum: sum tensor across all devices in-place
Source
fn all_gather(&self, local: &Tensor, output: &mut Tensor) -> BackendResult<()>
All-gather: gather local tensors from all devices into `output`. `output` is world_size * local_size.
Source
fn scatter(&self, input: &Tensor, output: &mut Tensor) -> BackendResult<()>
Scatter: split input across devices, each gets 1/world_size
Source
fn barrier(&self) -> BackendResult<()>
Barrier: synchronize all devices
Dyn Compatibility§
This trait is dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety".