use crate::symmetric::{CollectiveError, Rank, SymmetricBuffer, SymmetricTransport};
/// Element-wise reduction combined across ranks by the collectives in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceKind {
    /// Sum of all contributions.
    Sum,
    /// Accumulated like `Sum`, then divided by the contribution count in `finalize`.
    Mean,
    /// Largest contribution, combined with `f32::max`.
    Max,
    /// Smallest contribution, combined with `f32::min`.
    Min,
}

impl ReduceKind {
    /// Merges one more contribution `x` into the running accumulator `acc`.
    fn fold(self, acc: f32, x: f32) -> f32 {
        match self {
            Self::Max => acc.max(x),
            Self::Min => acc.min(x),
            // Mean keeps a plain running sum; the division is deferred to `finalize`.
            Self::Sum | Self::Mean => acc + x,
        }
    }

    /// Turns a fully accumulated value into the result for `n` contributions.
    ///
    /// Only `Mean` changes the value (dividing by `n`); every other kind is returned as-is.
    fn finalize(self, acc: f32, n: usize) -> f32 {
        match self {
            Self::Mean => acc / n as f32,
            Self::Sum | Self::Max | Self::Min => acc,
        }
    }

    /// Starting value for the accumulator: the neutral element of `fold` for this kind.
    fn identity(self) -> f32 {
        match self {
            Self::Max => f32::NEG_INFINITY,
            Self::Min => f32::INFINITY,
            Self::Sum | Self::Mean => 0.0,
        }
    }
}
/// Reduces `local` element-wise across all ranks with `op`, in place.
///
/// This rank publishes `local` into its own copy of the symmetric region described
/// by `buf` (only `buf.offset` and `buf.len` are used; the rank field is replaced on
/// every access), waits on `transport.barrier()`, then reads every rank's region, its
/// own included. It folds them in rank order starting from `op.identity()` and
/// writes `op.finalize(acc, num_ranks)` back into `local`.
///
/// # Errors
///
/// * [`CollectiveError::TransportError`] if `buf.len` is not a whole number of
///   `f32` values.
/// * [`CollectiveError::LengthMismatch`] if `local.len() != buf.len / 4`.
/// * Any error returned by the transport's `put`, `barrier`, or `get`.
pub fn all_reduce<T: SymmetricTransport>(
    transport: &T,
    buf: SymmetricBuffer,
    local: &mut [f32],
    op: ReduceKind,
) -> Result<(), CollectiveError> {
    const F32_BYTES: usize = std::mem::size_of::<f32>();
    let len = buf.len;
    // A trailing partial f32 would make the region disagree with `local`'s byte size,
    // so reject such buffers before touching the transport.
    if !len.is_multiple_of(F32_BYTES) {
        return Err(CollectiveError::TransportError {
            reason: format!(
                "all_reduce: buffer length {len} is not a multiple of {F32_BYTES} bytes"
            ),
        });
    }
    let elems = len / F32_BYTES;
    if local.len() != elems {
        return Err(CollectiveError::LengthMismatch {
            expected: elems,
            got: local.len(),
        });
    }
    let me = transport.this_rank();
    let n = transport.num_ranks();
    let our_buf = SymmetricBuffer {
        rank: me,
        offset: buf.offset,
        len,
    };
    // SAFETY: pointer and length are both derived from `local`, so the view covers
    // exactly its initialized memory; `u8` has alignment 1 and `f32` has no padding.
    let bytes = unsafe {
        std::slice::from_raw_parts(local.as_ptr().cast::<u8>(), std::mem::size_of_val(local))
    };
    transport.put(our_buf, bytes)?;
    transport.barrier()?;

    let mut acc = vec![op.identity(); elems];
    let mut scratch_bytes = vec![0u8; len];
    for r in 0..n {
        let src = SymmetricBuffer {
            rank: Rank(r),
            offset: buf.offset,
            len,
        };
        transport.get(src, &mut scratch_bytes)?;
        // Decode word by word: `Vec<u8>` only guarantees 1-byte alignment, so
        // reinterpreting it as `&[f32]` would be undefined behaviour.
        for (slot, word) in acc.iter_mut().zip(scratch_bytes.chunks_exact(F32_BYTES)) {
            let v = f32::from_ne_bytes([word[0], word[1], word[2], word[3]]);
            *slot = op.fold(*slot, v);
        }
    }
    for (dst, total) in local.iter_mut().zip(acc) {
        *dst = op.finalize(total, n as usize);
    }
    // NOTE(review): there is no trailing barrier, so a rank that returns early and
    // reuses this region for another collective could overwrite its slot while peers
    // are still reading it. Confirm whether callers or the transport prevent that.
    Ok(())
}
/// Gathers every rank's `local` values into `output`, concatenated in rank order.
///
/// This rank publishes `local` into its own copy of the symmetric region described
/// by `buf` (only `buf.offset` and `buf.len` are used; the rank field is replaced on
/// every access), waits on `transport.barrier()`, then reads each rank's region, its
/// own included, into `output[r * k..(r + 1) * k]`, where `k = buf.len / 4`.
///
/// # Errors
///
/// * [`CollectiveError::TransportError`] if `buf.len` is not a whole number of
///   `f32` values.
/// * [`CollectiveError::LengthMismatch`] if `local.len() != k` or
///   `output.len() != num_ranks * k`.
/// * Any error returned by the transport's `put`, `barrier`, or `get`.
pub fn all_gather<T: SymmetricTransport>(
    transport: &T,
    buf: SymmetricBuffer,
    local: &[f32],
    output: &mut [f32],
) -> Result<(), CollectiveError> {
    const F32_BYTES: usize = std::mem::size_of::<f32>();
    let len = buf.len;
    // A trailing partial f32 would make the region disagree with `local`'s byte size,
    // so reject such buffers before touching the transport.
    if !len.is_multiple_of(F32_BYTES) {
        return Err(CollectiveError::TransportError {
            reason: format!(
                "all_gather: buffer length {len} is not a multiple of {F32_BYTES} bytes"
            ),
        });
    }
    let elems_per_rank = len / F32_BYTES;
    let ranks = transport.num_ranks();
    let n = ranks as usize;
    if local.len() != elems_per_rank {
        return Err(CollectiveError::LengthMismatch {
            expected: elems_per_rank,
            got: local.len(),
        });
    }
    if output.len() != n * elems_per_rank {
        return Err(CollectiveError::LengthMismatch {
            expected: n * elems_per_rank,
            got: output.len(),
        });
    }
    let our_buf = SymmetricBuffer {
        rank: transport.this_rank(),
        offset: buf.offset,
        len,
    };
    // SAFETY: pointer and length are both derived from `local`, so the view covers
    // exactly its initialized memory; `u8` has alignment 1 and `f32` has no padding.
    let bytes = unsafe {
        std::slice::from_raw_parts(local.as_ptr().cast::<u8>(), std::mem::size_of_val(local))
    };
    transport.put(our_buf, bytes)?;
    transport.barrier()?;

    let mut scratch_bytes = vec![0u8; len];
    for r in 0..ranks {
        let src = SymmetricBuffer {
            rank: Rank(r),
            offset: buf.offset,
            len,
        };
        transport.get(src, &mut scratch_bytes)?;
        let dst_start = r as usize * elems_per_rank;
        let dst = &mut output[dst_start..dst_start + elems_per_rank];
        // Decode word by word: `Vec<u8>` only guarantees 1-byte alignment, so
        // reinterpreting it as `&[f32]` would be undefined behaviour.
        for (slot, word) in dst.iter_mut().zip(scratch_bytes.chunks_exact(F32_BYTES)) {
            *slot = f32::from_ne_bytes([word[0], word[1], word[2], word[3]]);
        }
    }
    // NOTE(review): there is no trailing barrier, so a rank that returns early and
    // reuses this region for another collective could overwrite its slot while peers
    // are still reading it. Confirm whether callers or the transport prevent that.
    Ok(())
}
/// Reduces `local` across all ranks with `op` and keeps only this rank's share.
///
/// The `total = buf.len / 4` elements are split into `num_ranks` equal chunks. After
/// the reduction, rank `r` receives elements `r * chunk..(r + 1) * chunk` in `output`.
/// Implemented as a full [`all_reduce`] followed by a local slice, so every rank
/// still reads every rank's whole region.
///
/// # Errors
///
/// * [`CollectiveError::TransportError`] in three cases:
///   - `buf.len` is not a whole number of `f32` values;
///   - there are no ranks, or `total` is not divisible by the rank count;
///   - this rank's index falls outside `0..num_ranks`. This is checked after the
///     collective, so this rank still joins its barrier.
/// * [`CollectiveError::LengthMismatch`] if `local.len() != total` or
///   `output.len() != chunk`.
/// * Any error returned by [`all_reduce`].
pub fn reduce_scatter<T: SymmetricTransport>(
    transport: &T,
    buf: SymmetricBuffer,
    local: &[f32],
    output: &mut [f32],
    op: ReduceKind,
) -> Result<(), CollectiveError> {
    const F32_BYTES: usize = std::mem::size_of::<f32>();
    let len = buf.len;
    if !len.is_multiple_of(F32_BYTES) {
        return Err(CollectiveError::TransportError {
            reason: format!(
                "reduce_scatter: buffer length {len} is not a multiple of {F32_BYTES} bytes"
            ),
        });
    }
    let total = len / F32_BYTES;
    let n = transport.num_ranks() as usize;
    // `n == 0` is checked explicitly so that `total / n` below can never divide by zero.
    if n == 0 || !total.is_multiple_of(n) {
        return Err(CollectiveError::TransportError {
            reason: format!("reduce_scatter: total elements {total} not divisible by {n} ranks"),
        });
    }
    let chunk = total / n;
    if local.len() != total {
        return Err(CollectiveError::LengthMismatch {
            expected: total,
            got: local.len(),
        });
    }
    if output.len() != chunk {
        return Err(CollectiveError::LengthMismatch {
            expected: chunk,
            got: output.len(),
        });
    }
    let me = transport.this_rank().0 as usize;
    let mut full = local.to_vec();
    all_reduce(transport, buf, &mut full, op)?;
    // Checked slicing: an out-of-range rank reports an error instead of panicking.
    let mine = full
        .get(me * chunk..(me + 1) * chunk)
        .ok_or_else(|| CollectiveError::TransportError {
            reason: format!("reduce_scatter: rank {me} is outside 0..{n}"),
        })?;
    output.copy_from_slice(mine);
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symmetric::LocalTransport;

    /// Views `values` as the raw native-endian bytes that `put` publishes.
    fn f32_bytes(values: &[f32]) -> &[u8] {
        // SAFETY: pointer and length are both derived from `values`, so the view
        // covers exactly its initialized memory; `u8` has alignment 1 and `f32` has
        // no padding bytes.
        unsafe {
            std::slice::from_raw_parts(values.as_ptr().cast::<u8>(), std::mem::size_of_val(values))
        }
    }

    /// Decodes native-endian `f32`s from `bytes`; a trailing partial word is ignored.
    ///
    /// Scratch buffers are `Vec<u8>`, which only guarantees 1-byte alignment, so
    /// reinterpreting them as `&[f32]` via `from_raw_parts` would be undefined behaviour.
    fn decode_f32s(bytes: &[u8]) -> Vec<f32> {
        bytes
            .chunks_exact(std::mem::size_of::<f32>())
            .map(|w| f32::from_ne_bytes([w[0], w[1], w[2], w[3]]))
            .collect()
    }

    /// Replays a 4-rank sum all-reduce step by step against `LocalTransport`.
    ///
    /// NOTE(review): this drives `put`/`get` by hand rather than calling `all_reduce`,
    /// presumably because each rank's `barrier` needs its peers running concurrently.
    /// Confirm this, and consider a threaded test of the real entry point.
    #[test]
    fn all_reduce_sum_across_4_ranks() {
        let n_ranks = 4u32;
        let elems = 4usize;
        let bytes = elems * 4;
        let ts = LocalTransport::fan_out(n_ranks, bytes);
        let mut state: Vec<Vec<f32>> = (0..n_ranks).map(|r| vec![(r + 1) as f32; elems]).collect();
        for (r, t) in ts.iter().enumerate() {
            let our_buf = SymmetricBuffer {
                rank: Rank(r as u32),
                offset: 0,
                len: bytes,
            };
            t.put(our_buf, f32_bytes(&state[r])).unwrap();
        }
        for (r, t) in ts.iter().enumerate() {
            let mut acc = vec![0f32; elems];
            let mut scratch = vec![0u8; bytes];
            for src_r in 0..n_ranks {
                let src = SymmetricBuffer {
                    rank: Rank(src_r),
                    offset: 0,
                    len: bytes,
                };
                t.get(src, &mut scratch).unwrap();
                for (slot, v) in acc.iter_mut().zip(decode_f32s(&scratch)) {
                    *slot += v;
                }
            }
            state[r] = acc;
        }
        for (r, slot) in state.iter().enumerate() {
            assert_eq!(slot, &vec![10.0; elems], "rank {r} after all-reduce");
        }
    }

    /// Replays a 3-rank all-gather step by step and checks the rank-ordered concatenation.
    #[test]
    fn all_gather_concatenates_in_rank_order() {
        let n_ranks = 3u32;
        let chunk = 2usize;
        let bytes = chunk * 4;
        let ts = LocalTransport::fan_out(n_ranks, bytes);
        let local: Vec<Vec<f32>> = (0..n_ranks)
            .map(|r| {
                let r = r as f32;
                vec![10.0 * r, 10.0 * r + 1.0]
            })
            .collect();
        for (r, t) in ts.iter().enumerate() {
            let our_buf = SymmetricBuffer {
                rank: Rank(r as u32),
                offset: 0,
                len: bytes,
            };
            t.put(our_buf, f32_bytes(&local[r])).unwrap();
        }
        for (r_idx, t) in ts.iter().enumerate() {
            let mut output = vec![0f32; n_ranks as usize * chunk];
            let mut scratch = vec![0u8; bytes];
            for src_r in 0..n_ranks {
                let src = SymmetricBuffer {
                    rank: Rank(src_r),
                    offset: 0,
                    len: bytes,
                };
                t.get(src, &mut scratch).unwrap();
                let view = decode_f32s(&scratch);
                let dst_start = src_r as usize * chunk;
                output[dst_start..dst_start + chunk].copy_from_slice(&view);
            }
            assert_eq!(
                output,
                vec![0.0, 1.0, 10.0, 11.0, 20.0, 21.0],
                "rank {r_idx} after all-gather"
            );
        }
    }

    #[test]
    fn reduce_kind_max_takes_pointwise_max() {
        let mut acc = ReduceKind::Max.identity();
        for v in [3.0, 1.0, 7.0, -2.0] {
            acc = ReduceKind::Max.fold(acc, v);
        }
        assert_eq!(acc, 7.0);
    }

    #[test]
    fn reduce_kind_mean_divides_at_finalize() {
        let mut acc = ReduceKind::Mean.identity();
        for v in [2.0, 4.0, 6.0, 8.0] {
            acc = ReduceKind::Mean.fold(acc, v);
        }
        assert_eq!(ReduceKind::Mean.finalize(acc, 4), 5.0);
    }
}