Function reduce_scatter 

pub fn reduce_scatter<T>(
    transport: &T,
    buf: SymmetricBuffer,
    local: &[f32],
    output: &mut [f32],
    op: ReduceKind,
) -> Result<(), CollectiveError>
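ReduceScatter: equivalent to an AllReduce followed by a partition. The inputs from all ranks are reduced element-wise with op, and every rank ends up with one chunk_size-element slice of the reduced result. Rank r receives the elements at indices [r*chunk_size, (r+1)*chunk_size) of the reduced vector.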

local is this rank's full input vector, so local.len() is num_ranks * chunk_size; output receives this rank's slice of the result, so output.len() is chunk_size.
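
The semantics described above can be illustrated with a small single-process reference model. This is only a sketch, not this crate's implementation: the function name reference_reduce_scatter_sum is hypothetical, the inputs argument stands in for each rank's local slice, and a sum reduction is assumed because the variants of ReduceKind are not shown here. It does not use the transport or the SymmetricBuffer at all.

```rust
/// Single-process reference model of reduce-scatter (hypothetical helper,
/// not part of this crate). `inputs[i]` plays the role of rank i's `local`.
/// A sum reduction is assumed for illustration.
fn reference_reduce_scatter_sum(
    inputs: &[Vec<f32>],
    rank: usize,
    chunk_size: usize,
) -> Vec<f32> {
    let num_ranks = inputs.len();
    let full_len = num_ranks * chunk_size;
    assert!(rank < num_ranks, "rank out of range");
    assert!(
        inputs.iter().all(|v| v.len() == full_len),
        "every input must have num_ranks * chunk_size elements"
    );

    // Step 1 (the AllReduce part): element-wise sum across all ranks.
    let mut reduced = vec![0.0f32; full_len];
    for input in inputs {
        for (acc, x) in reduced.iter_mut().zip(input) {
            *acc += *x;
        }
    }

    // Step 2 (the partition part): rank r keeps
    // [r * chunk_size, (r + 1) * chunk_size).
    reduced[rank * chunk_size..(rank + 1) * chunk_size].to_vec()
}
```

For example, with two ranks holding [1, 2, 3, 4] and [10, 20, 30, 40] and chunk_size = 2, the summed vector is [11, 22, 33, 44], so rank 0 ends up with [11, 22] and rank 1 with [33, 44].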