use crate::device_mesh::{DType, MeshComm, ReduceOp};
use cudarc::driver::{CudaDevice, CudaSlice, DevicePtr, DevicePtrMut};
use cudarc::nccl::{Comm, Id, ReduceOp as NcclReduceOp};
use std::sync::Arc;
#[derive(Debug)]
pub struct NcclComm {
comm: Comm,
device: Arc<CudaDevice>,
pub rank: usize,
pub world_size: usize,
}
impl NcclComm {
pub fn new(
device: Arc<CudaDevice>,
nccl_id: &Id,
rank: usize,
world_size: usize,
) -> Result<Self, String> {
let comm = Comm::from_rank(device.clone(), rank, world_size, *nccl_id)
.map_err(|e| format!("Failed to create NCCL communicator: {:?}", e))?;
Ok(Self {
comm,
device,
rank,
world_size,
})
}
pub fn generate_id() -> Result<Id, String> {
Id::new().map_err(|e| format!("Failed to generate NCCL ID: {:?}", e))
}
fn dtype_size(dtype: DType) -> usize {
match dtype {
DType::Float16 => 2,
DType::Float32 => 4,
DType::Float64 => 8,
DType::Int8 => 1,
DType::Int32 => 4,
DType::Int64 => 8,
DType::UInt8 => 1,
DType::BFloat16 => 2,
}
}
fn to_nccl_op(op: ReduceOp) -> NcclReduceOp {
match op {
ReduceOp::Sum => NcclReduceOp::Sum,
ReduceOp::Product => NcclReduceOp::Prod,
ReduceOp::Min => NcclReduceOp::Min,
ReduceOp::Max => NcclReduceOp::Max,
ReduceOp::Average => NcclReduceOp::Avg,
}
}
fn host_to_device(&self, data: &[u8]) -> Result<CudaSlice<u8>, String> {
self.device
.htod_sync_copy(data)
.map_err(|e| format!("Failed to copy to device: {:?}", e))
}
fn device_to_host(
&self,
device_buf: &CudaSlice<u8>,
host_buf: &mut [u8],
) -> Result<(), String> {
self.device
.dtoh_sync_copy_into(device_buf, host_buf)
.map_err(|e| format!("Failed to copy to host: {:?}", e))
}
}
impl MeshComm for NcclComm {
fn all_reduce(
&self,
buf: &mut [u8],
dtype: DType,
op: ReduceOp,
_group: &str,
) -> Result<(), String> {
let elem_size = Self::dtype_size(dtype);
let count = buf.len() / elem_size;
let mut device_buf = self.host_to_device(buf)?;
match dtype {
DType::Float32 => {
let f32_count = buf.len() / 4;
unsafe {
let send_ptr = *device_buf.device_ptr() as *const f32;
let recv_ptr = *device_buf.device_ptr_mut() as *mut f32;
self.comm
.all_reduce(send_ptr, recv_ptr, f32_count, Self::to_nccl_op(op))
.map_err(|e| format!("NCCL all_reduce failed: {:?}", e))?;
}
}
DType::Float64 => {
let f64_count = buf.len() / 8;
unsafe {
let send_ptr = *device_buf.device_ptr() as *const f64;
let recv_ptr = *device_buf.device_ptr_mut() as *mut f64;
self.comm
.all_reduce(send_ptr, recv_ptr, f64_count, Self::to_nccl_op(op))
.map_err(|e| format!("NCCL all_reduce failed: {:?}", e))?;
}
}
DType::Int32 => {
let i32_count = buf.len() / 4;
unsafe {
let send_ptr = *device_buf.device_ptr() as *const i32;
let recv_ptr = *device_buf.device_ptr_mut() as *mut i32;
self.comm
.all_reduce(send_ptr, recv_ptr, i32_count, Self::to_nccl_op(op))
.map_err(|e| format!("NCCL all_reduce failed: {:?}", e))?;
}
}
_ => {
return Err(format!("Unsupported dtype for all_reduce: {:?}", dtype));
}
}
self.device
.synchronize()
.map_err(|e| format!("Sync failed: {:?}", e))?;
self.device_to_host(&device_buf, buf)?;
Ok(())
}
fn all_gather(
&self,
local: &[u8],
out: &mut [u8],
dtype: DType,
_group: &str,
) -> Result<(), String> {
let elem_size = Self::dtype_size(dtype);
let local_count = local.len() / elem_size;
let device_local = self.host_to_device(local)?;
let mut device_out = self.host_to_device(out)?;
match dtype {
DType::Float32 => unsafe {
let send_ptr = *device_local.device_ptr() as *const f32;
let recv_ptr = *device_out.device_ptr_mut() as *mut f32;
self.comm
.all_gather(send_ptr, recv_ptr, local_count)
.map_err(|e| format!("NCCL all_gather failed: {:?}", e))?;
},
_ => {
return Err(format!("Unsupported dtype for all_gather: {:?}", dtype));
}
}
self.device
.synchronize()
.map_err(|e| format!("Sync failed: {:?}", e))?;
self.device_to_host(&device_out, out)?;
Ok(())
}
fn broadcast(&self, buf: &mut [u8], root_rank: usize, _group: &str) -> Result<(), String> {
let mut device_buf = self.host_to_device(buf)?;
let count = buf.len();
unsafe {
let ptr = *device_buf.device_ptr_mut() as *mut u8;
self.comm
.broadcast(ptr, ptr, count, root_rank)
.map_err(|e| format!("NCCL broadcast failed: {:?}", e))?;
}
self.device
.synchronize()
.map_err(|e| format!("Sync failed: {:?}", e))?;
self.device_to_host(&device_buf, buf)?;
Ok(())
}
fn reduce_scatter(
&self,
buf: &mut [u8],
out: &mut [u8],
op: ReduceOp,
_group: &str,
) -> Result<(), String> {
let device_in = self.host_to_device(buf)?;
let mut device_out = self.host_to_device(out)?;
let out_count = out.len() / 4;
unsafe {
let send_ptr = *device_in.device_ptr() as *const f32;
let recv_ptr = *device_out.device_ptr_mut() as *mut f32;
self.comm
.reduce_scatter(send_ptr, recv_ptr, out_count, Self::to_nccl_op(op))
.map_err(|e| format!("NCCL reduce_scatter failed: {:?}", e))?;
}
self.device
.synchronize()
.map_err(|e| format!("Sync failed: {:?}", e))?;
self.device_to_host(&device_out, out)?;
Ok(())
}
fn barrier(&self, _group: &str) -> Result<(), String> {
let dummy: [f32; 1] = [0.0];
let device_dummy = self
.device
.htod_sync_copy(&dummy)
.map_err(|e| format!("Barrier failed: {:?}", e))?;
unsafe {
let ptr = *device_dummy.device_ptr() as *const f32;
let mut_ptr = *device_dummy.device_ptr() as *mut f32;
self.comm
.all_reduce(ptr, mut_ptr, 1, NcclReduceOp::Sum)
.map_err(|e| format!("Barrier all_reduce failed: {:?}", e))?;
}
self.device
.synchronize()
.map_err(|e| format!("Barrier sync failed: {:?}", e))?;
Ok(())
}
fn send(&self, buf: &[u8], dest_rank: usize) -> Result<(), String> {
let device_buf = self.host_to_device(buf)?;
unsafe {
let ptr = *device_buf.device_ptr() as *const u8;
self.comm
.send(ptr, buf.len(), dest_rank)
.map_err(|e| format!("NCCL send failed: {:?}", e))?;
}
self.device
.synchronize()
.map_err(|e| format!("Send sync failed: {:?}", e))?;
Ok(())
}
fn recv(&self, buf: &mut [u8], src_rank: usize) -> Result<(), String> {
let mut device_buf = self.host_to_device(buf)?;
unsafe {
let ptr = *device_buf.device_ptr_mut() as *mut u8;
self.comm
.recv(ptr, buf.len(), src_rank)
.map_err(|e| format!("NCCL recv failed: {:?}", e))?;
}
self.device
.synchronize()
.map_err(|e| format!("Recv sync failed: {:?}", e))?;
self.device_to_host(&device_buf, buf)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_dtype_size() {
assert_eq!(NcclComm::dtype_size(DType::Float32), 4);
assert_eq!(NcclComm::dtype_size(DType::Float64), 8);
assert_eq!(NcclComm::dtype_size(DType::Int32), 4);
assert_eq!(NcclComm::dtype_size(DType::Float16), 2);
assert_eq!(NcclComm::dtype_size(DType::UInt8), 1);
}
#[test]
fn test_reduce_op_conversion() {
let _ = NcclComm::to_nccl_op(ReduceOp::Sum);
let _ = NcclComm::to_nccl_op(ReduceOp::Product);
let _ = NcclComm::to_nccl_op(ReduceOp::Min);
let _ = NcclComm::to_nccl_op(ReduceOp::Max);
let _ = NcclComm::to_nccl_op(ReduceOp::Average);
}
}