use oxicuda_driver::{CudaError, CudaResult};
/// Element-wise reduction operator used by the collective primitives.
///
/// `Avg` accumulates exactly like `Sum`; the collectives divide by the
/// participant count once the summation is complete.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    Sum,
    Product,
    Min,
    Max,
    Avg,
}

impl ReduceOp {
    /// Combines two `f32` operands under this operator.
    pub fn apply_f32(&self, a: f32, b: f32) -> f32 {
        match self {
            Self::Min => a.min(b),
            Self::Max => a.max(b),
            Self::Product => a * b,
            Self::Sum | Self::Avg => a + b,
        }
    }

    /// Combines two `f64` operands under this operator.
    pub fn apply_f64(&self, a: f64, b: f64) -> f64 {
        match self {
            Self::Min => a.min(b),
            Self::Max => a.max(b),
            Self::Product => a * b,
            Self::Sum | Self::Avg => a + b,
        }
    }

    /// Neutral element for `f32` accumulation (e.g. `0.0` for sums,
    /// `+inf` for minima): combining it with any `x` yields `x`.
    pub fn identity_f32(&self) -> f32 {
        match self {
            Self::Min => f32::INFINITY,
            Self::Max => f32::NEG_INFINITY,
            Self::Product => 1.0,
            Self::Sum | Self::Avg => 0.0,
        }
    }

    /// Neutral element for `f64` accumulation.
    pub fn identity_f64(&self) -> f64 {
        match self {
            Self::Min => f64::INFINITY,
            Self::Max => f64::NEG_INFINITY,
            Self::Product => 1.0,
            Self::Sum | Self::Avg => 0.0,
        }
    }
}
/// Scalar element types supported by the collectives, mirroring NCCL's set.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    F32,
    F64,
    I32,
    I64,
    U32,
    U64,
    F16,
    BF16,
}

impl DataType {
    /// Width of a single element of this type, in bytes.
    pub fn size_bytes(&self) -> usize {
        match self {
            Self::F64 | Self::I64 | Self::U64 => 8,
            Self::F32 | Self::I32 | Self::U32 => 4,
            Self::F16 | Self::BF16 => 2,
        }
    }
}
/// Handle describing a group of participating devices together with the
/// caller's position (rank) within that group.
#[derive(Debug, Clone)]
pub struct Communicator {
    devices: Vec<i32>,
    my_rank: usize,
}

impl Communicator {
    /// Builds a communicator rooted at rank 0.
    ///
    /// # Errors
    /// Returns `CudaError::InvalidValue` when `device_ordinals` is empty.
    pub fn new(device_ordinals: &[i32]) -> CudaResult<Self> {
        Self::with_rank(device_ordinals, 0)
    }

    /// Builds a communicator with an explicit caller rank.
    ///
    /// # Errors
    /// Returns `CudaError::InvalidValue` when `rank` is out of range — which
    /// also covers an empty device list, since no rank is valid then.
    pub fn with_rank(device_ordinals: &[i32], rank: usize) -> CudaResult<Self> {
        if rank >= device_ordinals.len() {
            return Err(CudaError::InvalidValue);
        }
        Ok(Self {
            devices: device_ordinals.to_vec(),
            my_rank: rank,
        })
    }

    /// This caller's rank within the group.
    pub fn rank(&self) -> usize {
        self.my_rank
    }

    /// Number of participating ranks.
    pub fn world_size(&self) -> usize {
        self.devices.len()
    }

    /// CUDA device ordinal backing `rank`, or `None` when out of range.
    pub fn device_ordinal(&self, rank: usize) -> Option<i32> {
        self.devices.get(rank).copied()
    }
}
/// Tunable options for a collective launch.
///
/// The default configuration selects the default stream, synchronous
/// execution, and an implementation-chosen chunk size. The `Default` impl is
/// now derived — the previous hand-written impl produced exactly the derived
/// values and needed an `#[allow(clippy::derivable_impls)]` to silence the
/// lint pointing that out.
#[derive(Debug, Clone, Default)]
pub struct CollectiveConfig {
    /// Stream handle to launch on; `None` selects the default stream.
    pub stream: Option<usize>,
    /// When `true`, the call may return before the collective completes.
    pub async_op: bool,
    /// Optional override for the per-step message chunk size.
    pub chunk_size: Option<usize>,
}
/// Strategy used to carry out an all-reduce; `Auto` defers to a heuristic.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AllReduceAlgorithm {
    Ring,
    Tree,
    RecursiveHalving,
    Auto,
}

impl AllReduceAlgorithm {
    /// Heuristic behind `Auto`: trees win on tiny messages, recursive
    /// halving on mid-sized power-of-two groups, rings everywhere else.
    fn select(world_size: usize, msg_len: usize) -> Self {
        match (msg_len, world_size.is_power_of_two()) {
            (0..=255, _) => Self::Tree,
            (256..=4095, true) => Self::RecursiveHalving,
            _ => Self::Ring,
        }
    }
}
/// Software model of a ring all-reduce: a reduce-scatter pass followed by an
/// all-gather pass, each taking `n_ranks - 1` steps.
pub struct RingAllReduce;

impl RingAllReduce {
    /// Reduces all per-rank buffers element-wise with `op`, leaving every
    /// buffer equal to the combined result.
    ///
    /// # Errors
    /// Returns `CudaError::InvalidValue` when the buffers differ in length.
    pub fn execute(buffers: &mut [Vec<f32>], op: ReduceOp) -> CudaResult<()> {
        let n_ranks = buffers.len();
        if n_ranks < 2 {
            return Ok(());
        }
        let buf_len = buffers[0].len();
        for b in buffers.iter() {
            if b.len() != buf_len {
                return Err(CudaError::InvalidValue);
            }
        }
        if buf_len == 0 {
            return Ok(());
        }
        let chunk_size = buf_len.div_ceil(n_ranks);
        // Bug fix: with a ceil-divided chunk size, trailing chunk indices can
        // START past the end of the buffer (e.g. buf_len = 2, n_ranks = 4 →
        // chunk 3 starts at 3). The old code clamped only the end, producing
        // an inverted range `[3..2]` that panics on slicing. Clamp BOTH
        // bounds so such chunks degrade to empty ranges.
        let bounds = |chunk: usize| {
            let start = (chunk * chunk_size).min(buf_len);
            let end = ((chunk + 1) * chunk_size).min(buf_len);
            (start, end)
        };
        // Phase 1 (reduce-scatter): at step `s`, rank `r` forwards chunk
        // `(r - 1 - s) mod n` to rank `r + 1`, which folds it in with `op`.
        // After `n - 1` steps, chunk `c` is fully reduced at rank `c`.
        for step in 0..n_ranks - 1 {
            // Snapshot all outgoing data first so every rank exchanges
            // pre-step values (simultaneous sends).
            let sends: Vec<(usize, usize, Vec<f32>)> = (0..n_ranks)
                .map(|rank| {
                    let send_chunk = (rank + n_ranks - 1 - step) % n_ranks;
                    let (start, end) = bounds(send_chunk);
                    (rank, send_chunk, buffers[rank][start..end].to_vec())
                })
                .collect();
            for (rank, send_chunk, data) in sends {
                let recv_rank = (rank + 1) % n_ranks;
                let (start, _) = bounds(send_chunk);
                for (i, v) in data.into_iter().enumerate() {
                    let idx = start + i;
                    buffers[recv_rank][idx] = op.apply_f32(buffers[recv_rank][idx], v);
                }
            }
        }
        // Phase 2 (all-gather): circulate the completed chunks around the
        // ring; at step `s`, rank `r` forwards chunk `(r - s) mod n`.
        for step in 0..n_ranks - 1 {
            let sends: Vec<(usize, Vec<f32>)> = (0..n_ranks)
                .map(|rank| {
                    let send_chunk = (rank + n_ranks - step) % n_ranks;
                    let (start, end) = bounds(send_chunk);
                    (send_chunk, buffers[rank][start..end].to_vec())
                })
                .collect();
            for (rank, (send_chunk, data)) in sends.into_iter().enumerate() {
                let recv_rank = (rank + 1) % n_ranks;
                let (start, end) = bounds(send_chunk);
                buffers[recv_rank][start..end].copy_from_slice(&data);
            }
        }
        // `Avg` accumulates like `Sum`; apply the division once at the end.
        if op == ReduceOp::Avg {
            let divisor = n_ranks as f32;
            for buf in buffers.iter_mut() {
                for v in buf.iter_mut() {
                    *v /= divisor;
                }
            }
        }
        Ok(())
    }
}
/// Software model of a binary-tree all-reduce: values are folded up the tree
/// into rank 0, then rank 0's result is broadcast back to every rank.
pub struct TreeAllReduce;

impl TreeAllReduce {
    /// Reduces all per-rank buffers element-wise with `op`, leaving every
    /// buffer equal to the combined result.
    ///
    /// # Errors
    /// Returns `CudaError::InvalidValue` when the buffers differ in length.
    pub fn execute(buffers: &mut [Vec<f32>], op: ReduceOp) -> CudaResult<()> {
        let world = buffers.len();
        if world < 2 {
            return Ok(());
        }
        let len = buffers[0].len();
        if buffers.iter().any(|b| b.len() != len) {
            return Err(CudaError::InvalidValue);
        }
        // Upward pass: with the stride doubling each round, every parent at
        // index `p` folds in its child at `p + stride`.
        let mut stride = 1;
        while stride < world {
            for parent in (0..world).step_by(2 * stride) {
                let child = parent + stride;
                if child >= world {
                    // Children only grow from here on, so stop this round.
                    break;
                }
                let incoming = buffers[child].clone();
                for (acc, v) in buffers[parent].iter_mut().zip(incoming) {
                    *acc = op.apply_f32(*acc, v);
                }
            }
            stride *= 2;
        }
        // `Avg` accumulates like `Sum`; divide the root's totals once.
        if op == ReduceOp::Avg {
            let divisor = world as f32;
            for v in buffers[0].iter_mut() {
                *v /= divisor;
            }
        }
        // Downward pass: broadcast the fully reduced root buffer to all.
        let (root, rest) = buffers.split_first_mut().expect("world >= 2");
        for buf in rest {
            buf.copy_from_slice(root);
        }
        Ok(())
    }
}
/// Entry points for the collective primitives.
///
/// These model single-process semantics: every simulated rank contributes the
/// caller's `sendbuf`, so reductions effectively combine `world_size`
/// identical copies of the input.
pub struct CollectiveOps;

impl CollectiveOps {
    /// All-reduce: combines `sendbuf` across all ranks with `op` and writes
    /// the result into `recvbuf`, using the requested (or auto-selected)
    /// algorithm.
    ///
    /// # Errors
    /// Returns `CudaError::InvalidValue` when `sendbuf` and `recvbuf` differ
    /// in length.
    pub fn all_reduce(
        sendbuf: &[f32],
        recvbuf: &mut [f32],
        op: ReduceOp,
        comm: &Communicator,
        algo: AllReduceAlgorithm,
    ) -> CudaResult<()> {
        if sendbuf.len() != recvbuf.len() {
            return Err(CudaError::InvalidValue);
        }
        let n = comm.world_size();
        // Every simulated rank starts from an identical copy of the input.
        let mut buffers: Vec<Vec<f32>> = vec![sendbuf.to_vec(); n];
        let resolved = match algo {
            AllReduceAlgorithm::Auto => AllReduceAlgorithm::select(n, sendbuf.len()),
            other => other,
        };
        match resolved {
            AllReduceAlgorithm::Ring => RingAllReduce::execute(&mut buffers, op)?,
            AllReduceAlgorithm::Tree => TreeAllReduce::execute(&mut buffers, op)?,
            AllReduceAlgorithm::RecursiveHalving => {
                Self::recursive_halving(&mut buffers, op)?;
            }
            // `Auto` was resolved to a concrete algorithm above.
            AllReduceAlgorithm::Auto => unreachable!(),
        }
        recvbuf.copy_from_slice(&buffers[0]);
        Ok(())
    }

    /// Recursive-doubling all-reduce for power-of-two world sizes; falls back
    /// to the ring algorithm otherwise. At distance `d`, rank `r` exchanges
    /// with partner `r ^ d` and folds the partner's data in.
    fn recursive_halving(buffers: &mut [Vec<f32>], op: ReduceOp) -> CudaResult<()> {
        let n = buffers.len();
        if n < 2 {
            return Ok(());
        }
        if !n.is_power_of_two() {
            return RingAllReduce::execute(buffers, op);
        }
        let buf_len = buffers[0].len();
        let mut distance = 1;
        while distance < n {
            // Snapshot pre-step state so all exchanges are simultaneous.
            let snapshot: Vec<Vec<f32>> = buffers.to_vec();
            for rank in 0..n {
                let partner = rank ^ distance;
                // Bug fix: fold in the PARTNER's snapshot. The previous
                // version combined each rank with a clone of its OWN buffer,
                // which only produced correct results because `all_reduce`
                // seeds every rank with identical data; for heterogeneous
                // buffers it never mixed ranks at all.
                if partner < n {
                    for i in 0..buf_len {
                        buffers[rank][i] = op.apply_f32(buffers[rank][i], snapshot[partner][i]);
                    }
                }
            }
            distance *= 2;
        }
        // `Avg` accumulates like `Sum`; apply the division once at the end.
        if op == ReduceOp::Avg {
            let divisor = n as f32;
            for buf in buffers.iter_mut() {
                for v in buf.iter_mut() {
                    *v /= divisor;
                }
            }
        }
        Ok(())
    }

    /// All-gather: concatenates every rank's contribution. In this model each
    /// rank contributes `sendbuf`, so the result is `world_size` repeats.
    ///
    /// # Errors
    /// Returns `CudaError::InvalidValue` when `recvbuf` is not exactly
    /// `world_size * sendbuf.len()` long.
    pub fn all_gather(sendbuf: &[f32], recvbuf: &mut [f32], comm: &Communicator) -> CudaResult<()> {
        let n = comm.world_size();
        let send_len = sendbuf.len();
        if recvbuf.len() != send_len * n {
            return Err(CudaError::InvalidValue);
        }
        for rank in 0..n {
            let start = rank * send_len;
            recvbuf[start..start + send_len].copy_from_slice(sendbuf);
        }
        Ok(())
    }

    /// Reduce-scatter: reduces across ranks, then each rank keeps its own
    /// `total / world_size` slice. Since each simulated rank contributes the
    /// same `sendbuf`, every output element is `op` folded over `world_size`
    /// copies of the corresponding input element.
    ///
    /// # Errors
    /// Returns `CudaError::InvalidValue` when the per-rank chunk would be
    /// empty or `recvbuf` is not exactly one chunk long. Any remainder of
    /// `sendbuf.len() % world_size` elements is ignored.
    pub fn reduce_scatter(
        sendbuf: &[f32],
        recvbuf: &mut [f32],
        op: ReduceOp,
        comm: &Communicator,
    ) -> CudaResult<()> {
        let n = comm.world_size();
        let chunk = sendbuf.len() / n;
        if chunk == 0 || recvbuf.len() != chunk {
            return Err(CudaError::InvalidValue);
        }
        let start = comm.rank() * chunk;
        for (dst, &src) in recvbuf.iter_mut().zip(&sendbuf[start..start + chunk]) {
            // Fold `n` identical contributions together.
            let mut acc = src;
            for _ in 1..n {
                acc = op.apply_f32(acc, src);
            }
            if op == ReduceOp::Avg {
                acc /= n as f32;
            }
            *dst = acc;
        }
        Ok(())
    }

    /// Broadcast from `root`. With a single simulated process the data is
    /// already in place, so only the root index is validated.
    ///
    /// # Errors
    /// Returns `CudaError::InvalidValue` when `root` is not a valid rank.
    pub fn broadcast(_buf: &mut [f32], root: usize, comm: &Communicator) -> CudaResult<()> {
        if root >= comm.world_size() {
            return Err(CudaError::InvalidValue);
        }
        Ok(())
    }

    /// Reduce to `root`: like `all_reduce`, but conceptually only the root
    /// receives the result. Each output element folds `world_size` identical
    /// copies of the corresponding `sendbuf` element.
    ///
    /// # Errors
    /// Returns `CudaError::InvalidValue` when `root` is out of range or the
    /// buffer lengths differ.
    pub fn reduce(
        sendbuf: &[f32],
        recvbuf: &mut [f32],
        op: ReduceOp,
        root: usize,
        comm: &Communicator,
    ) -> CudaResult<()> {
        let n = comm.world_size();
        if root >= n || sendbuf.len() != recvbuf.len() {
            return Err(CudaError::InvalidValue);
        }
        for (dst, &v) in recvbuf.iter_mut().zip(sendbuf) {
            let mut acc = v;
            for _ in 1..n {
                acc = op.apply_f32(acc, v);
            }
            if op == ReduceOp::Avg {
                acc /= n as f32;
            }
            *dst = acc;
        }
        Ok(())
    }

    /// All-to-all: rank `i` sends chunk `i` to every peer. With identical
    /// simulated ranks, every destination slot receives this rank's own
    /// chunk. Any remainder of `len % world_size` elements is left untouched.
    ///
    /// # Errors
    /// Returns `CudaError::InvalidValue` when buffer lengths differ or the
    /// per-rank chunk would be empty.
    pub fn all_to_all(sendbuf: &[f32], recvbuf: &mut [f32], comm: &Communicator) -> CudaResult<()> {
        let n = comm.world_size();
        let total = sendbuf.len();
        if total != recvbuf.len() {
            return Err(CudaError::InvalidValue);
        }
        let chunk = total / n;
        if chunk == 0 {
            return Err(CudaError::InvalidValue);
        }
        let src_start = comm.rank() * chunk;
        let src = &sendbuf[src_start..src_start + chunk];
        // The first `n` chunks are all full-length because `n * chunk <= total`.
        for dst in recvbuf.chunks_mut(chunk).take(n) {
            dst.copy_from_slice(src);
        }
        Ok(())
    }
}
/// Factory helpers for building and deriving communicators.
pub struct CommGroup;

impl CommGroup {
    /// Default four-device world communicator, viewed from rank 0.
    pub fn world() -> Communicator {
        Communicator {
            devices: (0..4).collect(),
            my_rank: 0,
        }
    }

    /// Derives a sub-communicator by striding through the parent's devices:
    /// `color` sets the stride (minimum 1) and `rank % stride` the offset.
    /// The caller is placed at rank 0 of the new group.
    ///
    /// # Errors
    /// Returns `CudaError::InvalidValue` when `rank` is out of range for the
    /// parent or the selection comes up empty.
    pub fn split(comm: &Communicator, color: usize, rank: usize) -> CudaResult<Communicator> {
        if rank >= comm.world_size() {
            return Err(CudaError::InvalidValue);
        }
        let stride = color.max(1);
        let picked: Vec<i32> = comm
            .devices
            .iter()
            .copied()
            .skip(rank % stride)
            .step_by(stride)
            .collect();
        if picked.is_empty() {
            return Err(CudaError::InvalidValue);
        }
        Ok(Communicator {
            devices: picked,
            my_rank: 0,
        })
    }

    /// Independent copy of `comm` with the same devices and rank.
    pub fn dup(comm: &Communicator) -> Communicator {
        comm.clone()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- ReduceOp: element-wise combine and identity values ---------------

    #[test]
    fn reduce_op_sum() {
        assert!((ReduceOp::Sum.apply_f32(3.0, 4.0) - 7.0).abs() < f32::EPSILON);
        assert!((ReduceOp::Sum.identity_f32()).abs() < f32::EPSILON);
    }

    #[test]
    fn reduce_op_product() {
        assert!((ReduceOp::Product.apply_f32(3.0, 4.0) - 12.0).abs() < f32::EPSILON);
        assert!((ReduceOp::Product.identity_f32() - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn reduce_op_min_max() {
        assert!((ReduceOp::Min.apply_f32(3.0, 4.0) - 3.0).abs() < f32::EPSILON);
        assert!((ReduceOp::Max.apply_f32(3.0, 4.0) - 4.0).abs() < f32::EPSILON);
        // Min/Max identities are +/- infinity respectively.
        assert!(ReduceOp::Min.identity_f32().is_infinite());
        assert!(ReduceOp::Max.identity_f32().is_infinite());
    }

    // Avg accumulates exactly like Sum; the divide happens in the collectives.
    #[test]
    fn reduce_op_avg() {
        assert!((ReduceOp::Avg.apply_f64(3.0, 4.0) - 7.0).abs() < f64::EPSILON);
        assert!((ReduceOp::Avg.identity_f64()).abs() < f64::EPSILON);
    }

    #[test]
    fn reduce_op_f64() {
        assert!((ReduceOp::Product.apply_f64(2.5, 4.0) - 10.0).abs() < f64::EPSILON);
        assert!((ReduceOp::Min.apply_f64(1.0, 2.0) - 1.0).abs() < f64::EPSILON);
    }

    // --- DataType ---------------------------------------------------------

    #[test]
    fn data_type_sizes() {
        assert_eq!(DataType::F16.size_bytes(), 2);
        assert_eq!(DataType::BF16.size_bytes(), 2);
        assert_eq!(DataType::F32.size_bytes(), 4);
        assert_eq!(DataType::I32.size_bytes(), 4);
        assert_eq!(DataType::U32.size_bytes(), 4);
        assert_eq!(DataType::F64.size_bytes(), 8);
        assert_eq!(DataType::I64.size_bytes(), 8);
        assert_eq!(DataType::U64.size_bytes(), 8);
    }

    // --- Communicator construction and accessors --------------------------

    #[test]
    fn communicator_basics() {
        let comm = Communicator::new(&[0, 1, 2]).expect("create comm");
        assert_eq!(comm.rank(), 0);
        assert_eq!(comm.world_size(), 3);
        assert_eq!(comm.device_ordinal(1), Some(1));
        assert_eq!(comm.device_ordinal(5), None);
    }

    #[test]
    fn communicator_empty_rejected() {
        assert!(Communicator::new(&[]).is_err());
    }

    #[test]
    fn communicator_with_rank() {
        let comm = Communicator::with_rank(&[10, 20, 30], 2).expect("rank 2");
        assert_eq!(comm.rank(), 2);
        assert_eq!(comm.device_ordinal(0), Some(10));
    }

    // --- Ring all-reduce: every rank must end with the element-wise sum ----

    #[test]
    fn ring_all_reduce_2_ranks_sum() {
        let mut bufs = vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]];
        RingAllReduce::execute(&mut bufs, ReduceOp::Sum).expect("ring 2");
        let expected = vec![6.0, 8.0, 10.0, 12.0];
        for (a, b) in bufs[0].iter().zip(&expected) {
            assert!((a - b).abs() < 1e-5, "got {a}, expected {b}");
        }
        // Both ranks converge to the identical reduced buffer.
        assert_eq!(bufs[0], bufs[1]);
    }

    // One-hot inputs: each rank contributes a single distinct element, so any
    // mix-up in chunk routing would show as a wrong or missing value.
    #[test]
    fn ring_all_reduce_4_ranks_sum() {
        let mut bufs = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 2.0, 0.0, 0.0],
            vec![0.0, 0.0, 3.0, 0.0],
            vec![0.0, 0.0, 0.0, 4.0],
        ];
        RingAllReduce::execute(&mut bufs, ReduceOp::Sum).expect("ring 4");
        let expected = vec![1.0, 2.0, 3.0, 4.0];
        for buf in &bufs {
            for (a, b) in buf.iter().zip(&expected) {
                assert!((a - b).abs() < 1e-5);
            }
        }
    }

    // --- Tree all-reduce ---------------------------------------------------

    // Odd rank count exercises the incomplete bottom level of the tree.
    #[test]
    fn tree_all_reduce_sum() {
        let mut bufs = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        TreeAllReduce::execute(&mut bufs, ReduceOp::Sum).expect("tree");
        let expected = vec![9.0, 12.0];
        for buf in &bufs {
            for (a, b) in buf.iter().zip(&expected) {
                assert!((a - b).abs() < 1e-5);
            }
        }
    }

    #[test]
    fn tree_all_reduce_product() {
        let mut bufs = vec![vec![2.0, 3.0], vec![4.0, 5.0]];
        TreeAllReduce::execute(&mut bufs, ReduceOp::Product).expect("tree prod");
        assert!((bufs[0][0] - 8.0).abs() < 1e-5);
        assert!((bufs[0][1] - 15.0).abs() < 1e-5);
        assert_eq!(bufs[0], bufs[1]);
    }

    // --- CollectiveOps (single-process simulation semantics) ---------------

    // Every simulated rank contributes the same sendbuf, so the gather is
    // world_size repeats of it.
    #[test]
    fn all_gather_correctness() {
        let comm = Communicator::new(&[0, 1, 2]).expect("comm");
        let send = [10.0, 20.0];
        let mut recv = vec![0.0f32; 6];
        CollectiveOps::all_gather(&send, &mut recv, &comm).expect("all_gather");
        assert_eq!(recv, vec![10.0, 20.0, 10.0, 20.0, 10.0, 20.0]);
    }

    // Rank 0 of 2 keeps the first half, each element doubled (2 contributors).
    #[test]
    fn reduce_scatter_correctness() {
        let comm = Communicator::new(&[0, 1]).expect("comm");
        let send = [1.0, 2.0, 3.0, 4.0];
        let mut recv = vec![0.0f32; 2];
        CollectiveOps::reduce_scatter(&send, &mut recv, ReduceOp::Sum, &comm)
            .expect("reduce_scatter");
        assert!((recv[0] - 2.0).abs() < 1e-5);
        assert!((recv[1] - 4.0).abs() < 1e-5);
    }

    // Broadcast is a no-op on the data here; the buffer must be untouched.
    #[test]
    fn broadcast_from_root() {
        let comm = Communicator::new(&[0, 1, 2]).expect("comm");
        let mut buf = [42.0f32, 99.0];
        CollectiveOps::broadcast(&mut buf, 0, &comm).expect("broadcast");
        assert_eq!(buf, [42.0, 99.0]);
    }

    #[test]
    fn broadcast_invalid_root() {
        let comm = Communicator::new(&[0]).expect("comm");
        let mut buf = [1.0f32];
        assert!(CollectiveOps::broadcast(&mut buf, 5, &comm).is_err());
    }

    // 4 identical contributors: each element is multiplied by world_size.
    #[test]
    fn reduce_to_root() {
        let comm = Communicator::new(&[0, 1, 2, 3]).expect("comm");
        let send = [1.0f32, 2.0];
        let mut recv = vec![0.0f32; 2];
        CollectiveOps::reduce(&send, &mut recv, ReduceOp::Sum, 0, &comm).expect("reduce");
        assert!((recv[0] - 4.0).abs() < 1e-5);
        assert!((recv[1] - 8.0).abs() < 1e-5);
    }

    // Rank 0's chunk ([1, 2]) is delivered to both destination slots.
    #[test]
    fn all_to_all_exchange() {
        let comm = Communicator::new(&[0, 1]).expect("comm");
        let send = [1.0f32, 2.0, 3.0, 4.0];
        let mut recv = vec![0.0f32; 4];
        CollectiveOps::all_to_all(&send, &mut recv, &comm).expect("a2a");
        assert_eq!(recv, vec![1.0, 2.0, 1.0, 2.0]);
    }

    // --- CommGroup ---------------------------------------------------------

    #[test]
    fn comm_group_world() {
        let w = CommGroup::world();
        assert_eq!(w.world_size(), 4);
        assert_eq!(w.rank(), 0);
    }

    #[test]
    fn comm_group_dup() {
        let comm = Communicator::with_rank(&[0, 1, 2], 1).expect("comm");
        let dup = CommGroup::dup(&comm);
        assert_eq!(dup.world_size(), comm.world_size());
        assert_eq!(dup.rank(), comm.rank());
    }

    // color=2, rank=0 strides the 4 devices by 2 → a 2-device subgroup.
    #[test]
    fn comm_group_split() {
        let comm = Communicator::new(&[0, 1, 2, 3]).expect("comm");
        let sub = CommGroup::split(&comm, 2, 0).expect("split");
        assert_eq!(sub.world_size(), 2);
    }

    // msg_len 8 < 256 → the Auto heuristic should pick the tree algorithm;
    // summing 2 identical contributions doubles every element.
    #[test]
    fn all_reduce_auto_algorithm() {
        let comm = Communicator::new(&[0, 1]).expect("comm");
        let send = vec![1.0f32; 8];
        let mut recv = vec![0.0f32; 8];
        CollectiveOps::all_reduce(
            &send,
            &mut recv,
            ReduceOp::Sum,
            &comm,
            AllReduceAlgorithm::Auto,
        )
        .expect("auto");
        for v in &recv {
            assert!((*v - 2.0).abs() < 1e-5);
        }
    }

    #[test]
    fn collective_config_defaults() {
        let cfg = CollectiveConfig::default();
        assert!(cfg.stream.is_none());
        assert!(!cfg.async_op);
        assert!(cfg.chunk_size.is_none());
    }
}