pub struct DistributedOps<'a> { /* private fields */ }
Distributed tensor operations
Implementations
impl<'a> DistributedOps<'a>
pub fn new(mesh: &'a DeviceMesh) -> Self
Creates a new distributed operations context that borrows the given device mesh.
pub fn all_reduce_f32(&self, data: &mut [f32], op: ReduceOp, group: &str) -> DistResult<()>
All-reduces an f32 tensor in place across all ranks in the given group.
After this operation, data on every rank holds the same values: the element-wise reduction, under op, of all ranks’ inputs.
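A minimal single-process simulation can make these semantics concrete. Each inner Vec stands in for one rank’s data buffer, and the hypothetical enum `Op` (with `Sum` and `Max`) replaces `ReduceOp`, whose actual variants are not listed here. The sketch shows only what an all-reduce computes, not how `DistributedOps` communicates between ranks.

```rust
// Hypothetical stand-in for ReduceOp; the real variants are not documented here.
#[derive(Clone, Copy)]
enum Op {
    Sum,
    Max,
}

/// Simulates all-reduce in one process: `ranks[r]` plays the role of rank r's `data`.
/// Afterwards every buffer holds the element-wise reduction of all inputs.
fn simulate_all_reduce(ranks: &mut [Vec<f32>], op: Op) {
    assert!(!ranks.is_empty(), "need at least one rank");
    let len = ranks[0].len();
    assert!(ranks.iter().all(|r| r.len() == len), "all buffers must have equal length");

    // Reduce element by element, starting from rank 0's values.
    let mut reduced = ranks[0].clone();
    for r in ranks[1..].iter() {
        for (acc, &x) in reduced.iter_mut().zip(r.iter()) {
            *acc = match op {
                Op::Sum => *acc + x,
                Op::Max => (*acc).max(x),
            };
        }
    }

    // The "all" in all-reduce: every rank receives the same result.
    for r in ranks.iter_mut() {
        r.copy_from_slice(&reduced);
    }
}
```

With three ranks holding [1, 2], [3, 4] and [5, 6], a sum all-reduce leaves [9, 12] on every rank.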
pub fn all_reduce_i32(&self, data: &mut [i32], op: ReduceOp, group: &str) -> DistResult<()>
All-reduces an i32 tensor in place across all ranks in the given group; this is the i32 counterpart of all_reduce_f32.
pub fn broadcast_f32(&self, data: &mut [f32], root_rank: usize, group: &str) -> DistResult<()>
Broadcasts an f32 tensor from root_rank to all ranks in the given group.
After this operation, data on every rank holds the root rank’s values.
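As a sketch of the broadcast semantics, the hypothetical `simulate_broadcast` below treats each inner Vec as one rank’s data buffer and overwrites every buffer with the root’s contents. It assumes, as the `&mut [f32]` parameter suggests, that every rank passes a buffer of the same length as the root’s.

```rust
/// Simulates broadcast in one process: `ranks[r]` plays the role of rank r's `data`.
/// Afterwards every buffer is a copy of the buffer on `root_rank`.
fn simulate_broadcast(ranks: &mut [Vec<f32>], root_rank: usize) {
    assert!(root_rank < ranks.len(), "root_rank out of range");
    let root = ranks[root_rank].clone();
    for buf in ranks.iter_mut() {
        // Assumption: each rank's buffer already has the root's length.
        assert_eq!(buf.len(), root.len(), "buffer length mismatch");
        buf.copy_from_slice(&root);
    }
}
```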
pub fn all_gather_f32(&self, local: &[f32], output: &mut [f32], group: &str) -> DistResult<()>
All-gathers f32 tensors from all ranks in the given group.
Each rank contributes local.len() elements and receives
local.len() * world_size elements in output.
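The size relationship can be illustrated with a single-process simulation. The helper `simulate_all_gather` is hypothetical, and it assumes a rank-ordered layout (rank r’s contribution at offset r * local.len()), which the description above does not state.

```rust
/// Simulates all-gather in one process: `locals[r]` plays the role of rank r's `local`.
/// Returns one `output` buffer per rank, each of length local.len() * world_size.
/// ASSUMPTION: contributions are concatenated in rank order.
fn simulate_all_gather(locals: &[Vec<f32>]) -> Vec<Vec<f32>> {
    assert!(!locals.is_empty(), "need at least one rank");
    let world_size = locals.len();
    let n = locals[0].len();
    assert!(locals.iter().all(|l| l.len() == n), "every rank contributes n elements");

    let mut gathered = vec![0.0f32; n * world_size];
    for (rank, local) in locals.iter().enumerate() {
        gathered[rank * n..(rank + 1) * n].copy_from_slice(local);
    }

    // Every rank receives an identical copy of the concatenation.
    vec![gathered; world_size]
}
```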
pub fn scatter_f32(&self, data: &[f32], chunk: &mut [f32], root_rank: usize) -> DistResult<()>
Scatters a tensor by dividing the data held on root_rank among the ranks.
Only the root rank’s data is used as input. After this operation,
each rank’s chunk contains its portion of the data.
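A minimal sketch of the scatter semantics, assuming the root’s data splits into equal, contiguous chunks assigned in rank order (the description above does not fix the layout). The helper name `simulate_scatter` and its `world_size` parameter are hypothetical.

```rust
/// Simulates scatter in one process: `data` is the root's input and
/// `world_size` the number of ranks. Returns the `chunk` each rank receives.
/// ASSUMPTION: equal, contiguous chunks in rank order.
fn simulate_scatter(data: &[f32], world_size: usize) -> Vec<Vec<f32>> {
    assert!(world_size > 0, "need at least one rank");
    assert!(!data.is_empty() && data.len() % world_size == 0, "data must split evenly");
    let chunk_len = data.len() / world_size;
    // Rank r receives data[r * chunk_len .. (r + 1) * chunk_len].
    data.chunks(chunk_len).map(|c| c.to_vec()).collect()
}
```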
pub fn gather_f32(&self, local: &[f32], output: &mut [f32], root_rank: usize) -> DistResult<()>
Gathers f32 tensors from all ranks to root_rank.
Each rank’s local data is collected into the root rank’s output.
Only the root rank’s output contains the complete gathered data.
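The asymmetry between root and non-root ranks can be shown with a single-process simulation. `simulate_gather` is a hypothetical helper, and rank-ordered concatenation is an assumption; non-root output buffers are simply left untouched here.

```rust
/// Simulates gather in one process: `locals[r]` is rank r's `local` and
/// `outputs[r]` its `output`. Only `outputs[root_rank]` is written.
/// ASSUMPTION: contributions are concatenated in rank order.
fn simulate_gather(locals: &[Vec<f32>], outputs: &mut [Vec<f32>], root_rank: usize) {
    assert!(!locals.is_empty(), "need at least one rank");
    assert_eq!(locals.len(), outputs.len(), "one output buffer per rank");
    let n = locals[0].len();
    assert!(locals.iter().all(|l| l.len() == n), "equal-length contributions");

    let root_out = &mut outputs[root_rank];
    assert_eq!(root_out.len(), n * locals.len(), "root output must hold every contribution");
    for (rank, local) in locals.iter().enumerate() {
        root_out[rank * n..(rank + 1) * n].copy_from_slice(local);
    }
}
```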
pub fn reduce_scatter_f32(&self, data: &mut [f32], output: &mut [f32], op: ReduceOp, group: &str) -> DistResult<()>
Reduce-scatter: reduces data across all ranks in the given group using op, then distributes the result.
This combines a reduction and a scatter in one operation. After this operation, each rank’s output holds its portion of the reduced result.
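Reduce-scatter computes the same result as an all-reduce followed by each rank keeping only its own slice. The hypothetical `simulate_reduce_scatter_sum` below shows this in one process, assuming summation stands in for op and that equal-sized slices are assigned in rank order.

```rust
/// Simulates reduce-scatter with summation: `inputs[r]` plays the role of rank r's `data`.
/// Returns the `output` each rank receives.
/// ASSUMPTIONS: Sum stands in for `op`; equal chunks assigned in rank order.
fn simulate_reduce_scatter_sum(inputs: &[Vec<f32>]) -> Vec<Vec<f32>> {
    assert!(!inputs.is_empty(), "need at least one rank");
    let world_size = inputs.len();
    let len = inputs[0].len();
    assert!(inputs.iter().all(|d| d.len() == len), "equal-length inputs");
    assert!(len > 0 && len % world_size == 0, "data must split evenly across ranks");

    // Step 1: element-wise reduction, as an all-reduce would compute.
    let mut reduced = vec![0.0f32; len];
    for d in inputs {
        for (acc, &x) in reduced.iter_mut().zip(d.iter()) {
            *acc += x;
        }
    }

    // Step 2: scatter, so rank r keeps only chunk r of the reduced result.
    reduced.chunks(len / world_size).map(|c| c.to_vec()).collect()
}
```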
pub fn barrier(&self, group: &str) -> DistResult<()>
Barrier synchronization across all ranks in the given group: no rank returns from this call until every rank in the group has entered it.
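The barrier guarantee can be illustrated in-process with the standard library’s `std::sync::Barrier`, using threads as stand-in ranks. This is only an analogue of the semantics; it says nothing about how `DistributedOps::barrier` is implemented. The helper `simulate_barrier` is hypothetical.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Barrier};
use std::thread;

/// Runs `world_size` threads as stand-in ranks. Each increments a shared
/// counter, waits at the barrier, then reads the counter. Because no thread
/// passes `wait()` until all have called it, every read sees `world_size`.
fn simulate_barrier(world_size: usize) -> Vec<usize> {
    let barrier = Arc::new(Barrier::new(world_size));
    let arrived = Arc::new(AtomicUsize::new(0));

    let handles: Vec<_> = (0..world_size)
        .map(|_| {
            let barrier = Arc::clone(&barrier);
            let arrived = Arc::clone(&arrived);
            thread::spawn(move || {
                arrived.fetch_add(1, Ordering::SeqCst);
                barrier.wait(); // blocks until all world_size threads arrive
                arrived.load(Ordering::SeqCst)
            })
        })
        .collect();

    handles.into_iter().map(|h| h.join().unwrap()).collect()
}
```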