use super::sys::{self};
use std::mem::MaybeUninit;
/// Error type wrapping the raw [`sys::ncclResult_t`] code returned by a
/// failed NCCL call (any code not mapped to [`NcclStatus`] by `result()`).
#[derive(Clone, PartialEq, Eq)]
pub struct NcclError(pub sys::ncclResult_t);
impl std::fmt::Debug for NcclError {
/// Formats the error for diagnostics.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Include the wrapped NCCL result code: the previous output was the bare
// string "NcclError", which discarded the only diagnostic information
// this type carries, making `unwrap()`/`expect()` messages useless.
write!(f, "NcclError({:?})", self.0)
}
}
/// Non-error statuses an NCCL call can report (see `ncclResult_t::result`).
///
/// `Debug` is derived so the type is printable in assertions and logs, as
/// public types conventionally should be; the other derives are unchanged.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NcclStatus {
/// Mapped from `ncclSuccess`.
Success,
/// Mapped from `ncclInProgress` (only on newer NCCL versions).
InProgress,
/// Mapped from `ncclNumResults` — the enum's count sentinel, treated as non-error.
NumResults,
}
impl sys::ncclResult_t {
/// Converts a raw NCCL status code into `Ok(NcclStatus)` for the non-error
/// codes, or `Err(NcclError)` carrying the raw code for everything else.
pub fn result(self) -> Result<NcclStatus, NcclError> {
match self {
sys::ncclResult_t::ncclSuccess => Ok(NcclStatus::Success),
// Gated out on older toolkits — presumably `ncclInProgress` does not
// exist in the NCCL bundled with CUDA 11.4-11.7 (TODO confirm).
#[cfg(not(any(
feature = "cuda-11040",
feature = "cuda-11050",
feature = "cuda-11060",
feature = "cuda-11070"
)))]
sys::ncclResult_t::ncclInProgress => Ok(NcclStatus::InProgress),
sys::ncclResult_t::ncclNumResults => Ok(NcclStatus::NumResults),
// Every remaining code is reported as an error.
_ => Err(NcclError(self)),
}
}
}
#[cfg(not(any(
feature = "cuda-11040",
feature = "cuda-11050",
feature = "cuda-11060",
feature = "cuda-11070"
)))]
/// Thin wrapper over [`sys::ncclCommFinalize`] (only available for NCCL
/// versions newer than the ones bundled with CUDA 11.4-11.7).
///
/// # Safety
/// `comm` must be a valid communicator handle obtained from one of the
/// `comm_init_*` functions.
pub unsafe fn comm_finalize(comm: sys::ncclComm_t) -> Result<NcclStatus, NcclError> {
sys::ncclCommFinalize(comm).result()
}
/// Thin wrapper over [`sys::ncclCommDestroy`]; releases the communicator.
///
/// # Safety
/// `comm` must be a valid communicator handle, and it must not be used
/// again after this call returns.
pub unsafe fn comm_destroy(comm: sys::ncclComm_t) -> Result<NcclStatus, NcclError> {
sys::ncclCommDestroy(comm).result()
}
/// Thin wrapper over [`sys::ncclCommAbort`] — per its name this aborts the
/// communicator's outstanding work and frees it (see NCCL docs for the
/// exact semantics vs. `comm_destroy`).
///
/// # Safety
/// `comm` must be a valid communicator handle, and it must not be used
/// again after this call returns.
pub unsafe fn comm_abort(comm: sys::ncclComm_t) -> Result<NcclStatus, NcclError> {
sys::ncclCommAbort(comm).result()
}
/// Queries the linked NCCL library's version integer via
/// [`sys::ncclGetVersion`].
pub fn get_nccl_version() -> Result<::core::ffi::c_int, NcclError> {
let mut version: ::core::ffi::c_int = 0;
// SAFETY: `version` is a live, writable integer for the whole call.
unsafe { sys::ncclGetVersion(&mut version) }.result().map(|_| version)
}
/// Asks NCCL for a fresh unique id ([`sys::ncclGetUniqueId`]) to be shared
/// with all ranks that will join the same communicator.
pub fn get_uniqueid() -> Result<sys::ncclUniqueId, NcclError> {
let mut uniqueid = MaybeUninit::<sys::ncclUniqueId>::uninit();
unsafe {
sys::ncclGetUniqueId(uniqueid.as_mut_ptr()).result()?;
// SAFETY: the call succeeded, so NCCL has fully initialized the id.
Ok(uniqueid.assume_init())
}
}
#[cfg(not(any(
feature = "cuda-11040",
feature = "cuda-11050",
feature = "cuda-11060",
feature = "cuda-11070"
)))]
/// Thin wrapper over [`sys::ncclCommInitRankConfig`]: initializes this
/// rank's communicator with an explicit `config` (only available for NCCL
/// versions newer than the ones bundled with CUDA 11.4-11.7).
///
/// # Safety
/// `comm` must be a valid pointer to writable storage for a communicator
/// handle, `config` must be a valid config pointer per NCCL's contract, and
/// `comm_id` must be the id shared by all `nranks` participants.
pub unsafe fn comm_init_rank_config(
comm: *mut sys::ncclComm_t,
nranks: ::core::ffi::c_int,
comm_id: sys::ncclUniqueId,
rank: ::core::ffi::c_int,
config: *mut sys::ncclConfig_t,
) -> Result<NcclStatus, NcclError> {
sys::ncclCommInitRankConfig(comm, nranks, comm_id, rank, config).result()
}
/// Thin wrapper over [`sys::ncclCommInitRank`]: initializes the communicator
/// for one rank out of `nranks`, rendezvousing on the shared `comm_id`.
///
/// # Safety
/// `comm` must be a valid pointer to writable storage for a communicator
/// handle, and `comm_id` must be the id shared by all `nranks` participants.
pub unsafe fn comm_init_rank(
comm: *mut sys::ncclComm_t,
nranks: ::core::ffi::c_int,
comm_id: sys::ncclUniqueId,
rank: ::core::ffi::c_int,
) -> Result<NcclStatus, NcclError> {
sys::ncclCommInitRank(comm, nranks, comm_id, rank).result()
}
/// Thin wrapper over [`sys::ncclCommInitAll`]: initializes `ndev`
/// communicators (one per device in `devlist`) from a single host thread.
///
/// # Safety
/// `comm` must point to writable storage for at least `ndev` communicator
/// handles, and `devlist` must point to at least `ndev` device ordinals.
pub unsafe fn comm_init_all(
comm: *mut sys::ncclComm_t,
ndev: ::core::ffi::c_int,
devlist: *const ::core::ffi::c_int,
) -> Result<NcclStatus, NcclError> {
sys::ncclCommInitAll(comm, ndev, devlist).result()
}
#[cfg(any(
feature = "cuda-12020",
feature = "cuda-12030",
feature = "cuda-12040",
feature = "cuda-12050"
))]
/// Thin wrapper over [`sys::ncclCommSplit`] (only compiled for the NCCL
/// versions bundled with CUDA 12.2-12.5 per the cfg above).
///
/// # Safety
/// `comm` must be a valid communicator handle, `newcomm` must point to
/// writable storage for the new handle, and `config` must satisfy NCCL's
/// config-pointer contract.
pub unsafe fn comm_split(
comm: sys::ncclComm_t,
color: ::core::ffi::c_int,
key: ::core::ffi::c_int,
newcomm: *mut sys::ncclComm_t,
config: *mut sys::ncclConfig_t,
) -> Result<NcclStatus, NcclError> {
sys::ncclCommSplit(comm, color, key, newcomm, config).result()
}
/// Returns the number of ranks in the communicator
/// ([`sys::ncclCommCount`]).
///
/// # Safety
/// `comm` must be a valid, live communicator handle.
pub unsafe fn comm_count(comm: sys::ncclComm_t) -> Result<::core::ffi::c_int, NcclError> {
let mut count: ::core::ffi::c_int = 0;
sys::ncclCommCount(comm, &mut count).result().map(|_| count)
}
/// Returns the CUDA device ordinal associated with the communicator
/// ([`sys::ncclCommCuDevice`]).
///
/// # Safety
/// `comm` must be a valid, live communicator handle.
pub unsafe fn comm_cu_device(comm: sys::ncclComm_t) -> Result<::core::ffi::c_int, NcclError> {
let mut device: ::core::ffi::c_int = 0;
sys::ncclCommCuDevice(comm, &mut device).result().map(|_| device)
}
/// Returns this process's rank within the communicator
/// ([`sys::ncclCommUserRank`]).
///
/// # Safety
/// `comm` must be a valid, live communicator handle.
pub unsafe fn comm_user_rank(comm: sys::ncclComm_t) -> Result<::core::ffi::c_int, NcclError> {
let mut rank: ::core::ffi::c_int = 0;
sys::ncclCommUserRank(comm, &mut rank).result().map(|_| rank)
}
/// Thin wrapper over [`sys::ncclRedOpCreatePreMulSum`]: creates a custom
/// reduction op, written into `*op` (see NCCL docs for the pre-mul-sum
/// semantics and `residence` meaning).
///
/// # Safety
/// `op` must point to writable storage for the new op, `scalar` must be a
/// valid pointer per NCCL's contract for `datatype`/`residence`, and `comm`
/// must be a valid communicator handle.
pub unsafe fn reduce_op_create_pre_mul_sum(
op: *mut sys::ncclRedOp_t,
scalar: *mut ::core::ffi::c_void,
datatype: sys::ncclDataType_t,
residence: sys::ncclScalarResidence_t,
comm: sys::ncclComm_t,
) -> Result<NcclStatus, NcclError> {
sys::ncclRedOpCreatePreMulSum(op, scalar, datatype, residence, comm).result()
}
/// Thin wrapper over [`sys::ncclRedOpDestroy`]: frees a custom reduction op
/// previously created on `comm`.
///
/// # Safety
/// `op` must have been created on this `comm` (e.g. via
/// [`reduce_op_create_pre_mul_sum`]) and must not be used afterwards.
pub unsafe fn reduce_op_destroy(
op: sys::ncclRedOp_t,
comm: sys::ncclComm_t,
) -> Result<NcclStatus, NcclError> {
sys::ncclRedOpDestroy(op, comm).result()
}
#[allow(clippy::too_many_arguments)]
/// Thin wrapper over [`sys::ncclReduce`]; see NCCL's documentation for the
/// collective's semantics (result lands on `root`).
///
/// # Safety
/// `sendbuff`/`recvbuff` must be valid device pointers sized for `count`
/// elements of `datatype`, and `comm`/`stream` must be live handles.
pub unsafe fn reduce(
sendbuff: *const ::core::ffi::c_void,
recvbuff: *mut ::core::ffi::c_void,
count: usize,
datatype: sys::ncclDataType_t,
op: sys::ncclRedOp_t,
root: ::core::ffi::c_int,
comm: sys::ncclComm_t,
stream: sys::cudaStream_t,
) -> Result<NcclStatus, NcclError> {
sys::ncclReduce(sendbuff, recvbuff, count, datatype, op, root, comm, stream).result()
}
/// Thin wrapper over [`sys::ncclBroadcast`]; see NCCL's documentation for
/// the collective's semantics (data originates at `root`).
///
/// # Safety
/// `sendbuff`/`recvbuff` must be valid device pointers sized for `count`
/// elements of `datatype`, and `comm`/`stream` must be live handles.
pub unsafe fn broadcast(
sendbuff: *const ::core::ffi::c_void,
recvbuff: *mut ::core::ffi::c_void,
count: usize,
datatype: sys::ncclDataType_t,
root: ::core::ffi::c_int,
comm: sys::ncclComm_t,
stream: sys::cudaStream_t,
) -> Result<NcclStatus, NcclError> {
sys::ncclBroadcast(sendbuff, recvbuff, count, datatype, root, comm, stream).result()
}
/// Thin wrapper over [`sys::ncclAllReduce`]; see NCCL's documentation for
/// the collective's semantics (every rank receives the reduced result).
///
/// # Safety
/// `sendbuff`/`recvbuff` must be valid device pointers sized for `count`
/// elements of `datatype`, and `comm`/`stream` must be live handles.
pub unsafe fn all_reduce(
sendbuff: *const ::core::ffi::c_void,
recvbuff: *mut ::core::ffi::c_void,
count: usize,
datatype: sys::ncclDataType_t,
op: sys::ncclRedOp_t,
comm: sys::ncclComm_t,
stream: sys::cudaStream_t,
) -> Result<NcclStatus, NcclError> {
sys::ncclAllReduce(sendbuff, recvbuff, count, datatype, op, comm, stream).result()
}
/// Thin wrapper over [`sys::ncclReduceScatter`]; see NCCL's documentation
/// for the collective's semantics. Note `recvcount` is the per-rank receive
/// count, mirroring the underlying C API.
///
/// # Safety
/// `sendbuff`/`recvbuff` must be valid device pointers sized per NCCL's
/// reduce-scatter contract, and `comm`/`stream` must be live handles.
pub unsafe fn reduce_scatter(
sendbuff: *const ::core::ffi::c_void,
recvbuff: *mut ::core::ffi::c_void,
recvcount: usize,
datatype: sys::ncclDataType_t,
op: sys::ncclRedOp_t,
comm: sys::ncclComm_t,
stream: sys::cudaStream_t,
) -> Result<NcclStatus, NcclError> {
sys::ncclReduceScatter(sendbuff, recvbuff, recvcount, datatype, op, comm, stream).result()
}
/// Thin wrapper over [`sys::ncclAllGather`]; see NCCL's documentation for
/// the collective's semantics. Note `sendcount` is the per-rank send count,
/// mirroring the underlying C API.
///
/// # Safety
/// `sendbuff`/`recvbuff` must be valid device pointers sized per NCCL's
/// all-gather contract, and `comm`/`stream` must be live handles.
pub unsafe fn all_gather(
sendbuff: *const ::core::ffi::c_void,
recvbuff: *mut ::core::ffi::c_void,
sendcount: usize,
datatype: sys::ncclDataType_t,
comm: sys::ncclComm_t,
stream: sys::cudaStream_t,
) -> Result<NcclStatus, NcclError> {
sys::ncclAllGather(sendbuff, recvbuff, sendcount, datatype, comm, stream).result()
}
/// Thin wrapper over [`sys::ncclSend`]: point-to-point send to rank `peer`
/// (must be matched by a corresponding receive; see NCCL docs).
///
/// # Safety
/// `sendbuff` must be a valid device pointer sized for `count` elements of
/// `datatype`, and `comm`/`stream` must be live handles.
pub unsafe fn send(
sendbuff: *const ::core::ffi::c_void,
count: usize,
datatype: sys::ncclDataType_t,
peer: ::core::ffi::c_int,
comm: sys::ncclComm_t,
stream: sys::cudaStream_t,
) -> Result<NcclStatus, NcclError> {
sys::ncclSend(sendbuff, count, datatype, peer, comm, stream).result()
}
/// Thin wrapper over [`sys::ncclRecv`]: point-to-point receive from rank
/// `peer` (must be matched by a corresponding send; see NCCL docs).
///
/// # Safety
/// `recvbuff` must be a valid device pointer sized for `count` elements of
/// `datatype`, and `comm`/`stream` must be live handles.
pub unsafe fn recv(
recvbuff: *mut ::core::ffi::c_void,
count: usize,
datatype: sys::ncclDataType_t,
peer: ::core::ffi::c_int,
comm: sys::ncclComm_t,
stream: sys::cudaStream_t,
) -> Result<NcclStatus, NcclError> {
sys::ncclRecv(recvbuff, count, datatype, peer, comm, stream).result()
}
/// Thin wrapper over [`sys::ncclGroupEnd`]: closes the group opened by
/// [`group_start`], letting NCCL launch the batched operations.
pub fn group_end() -> Result<NcclStatus, NcclError> {
unsafe { sys::ncclGroupEnd().result() }
}
/// Thin wrapper over [`sys::ncclGroupStart`]: opens a group so subsequent
/// NCCL calls are batched until [`group_end`].
pub fn group_start() -> Result<NcclStatus, NcclError> {
unsafe { sys::ncclGroupStart().result() }
}
#[cfg(test)]
mod tests {
use crate::driver::CudaContext;
use super::*;
use std::{ffi::c_void, vec, vec::Vec};
/// All-reduce across every visible device, driven from a single host
/// thread using `comm_init_all` (requires a multi-GPU machine to be
/// meaningful).
#[test]
fn single_thread() {
let n_devices = CudaContext::device_count().unwrap() as usize;
let n = 2;
let mut streams = vec![];
let mut sendslices = vec![];
let mut recvslices = vec![];
// Device i contributes a buffer of `n` floats all equal to (i + 1).
for i in 0..n_devices {
let ctx = CudaContext::new(i).unwrap();
let stream = ctx.default_stream();
let slice = stream.clone_htod(&vec![(i + 1) as f32 * 1.0; n]).unwrap();
sendslices.push(slice);
let slice = stream.alloc_zeros::<f32>(n).unwrap();
recvslices.push(slice);
streams.push(stream);
}
// One communicator handle per device, filled in by comm_init_all.
let mut comms = vec![std::ptr::null_mut(); n_devices];
let ordinals: Vec<_> = streams
.iter()
.map(|d| d.context().ordinal() as i32)
.collect();
unsafe {
comm_init_all(comms.as_mut_ptr(), n_devices as i32, ordinals.as_ptr()).unwrap();
// Batch the per-device all_reduce calls in one group so NCCL can
// launch them together from this single thread.
group_start().unwrap();
for i in 0..n_devices {
let ctx = CudaContext::new(i).unwrap();
let stream = ctx.default_stream();
all_reduce(
sendslices[i].cu_device_ptr as *const c_void,
recvslices[i].cu_device_ptr as *mut c_void,
n,
sys::ncclDataType_t::ncclFloat32,
sys::ncclRedOp_t::ncclSum,
comms[i],
stream.cu_stream as sys::cudaStream_t,
)
.unwrap();
}
group_end().unwrap();
}
// NOTE(review): the handles in `comms` are never released via
// `comm_destroy` — acceptable in a test, but it leaks NCCL resources.
// Each slot should now hold sum(1..=n_devices) = n_devices*(n_devices+1)/2.
for (i, recv) in recvslices.iter().enumerate() {
let ctx = CudaContext::new(i).unwrap();
let stream = ctx.default_stream();
let out = stream.clone_dtoh(recv).unwrap();
assert_eq!(out, vec![(n_devices * (n_devices + 1)) as f32 / 2.0; n]);
}
}
/// Same all-reduce, but with one host thread per device, rendezvousing
/// through a shared unique id via `comm_init_rank`.
#[test]
fn multi_thread() {
let n_devices = CudaContext::device_count().unwrap() as usize;
let n = 2;
// `comm_id` is `Copy` here, so each `move` closure captures its own copy.
let comm_id = get_uniqueid().unwrap();
let threads: Vec<_> = (0..n_devices)
.map(|i| {
std::thread::spawn(move || {
let ctx = CudaContext::new(i).unwrap();
let stream = ctx.default_stream();
let sendslice = stream.clone_htod(&vec![(i + 1) as f32 * 1.0; n]).unwrap();
let recvslice = stream.alloc_zeros::<f32>(n).unwrap();
let mut comm = MaybeUninit::uninit();
unsafe {
// Blocks until all n_devices ranks have joined with the same id.
comm_init_rank(comm.as_mut_ptr(), n_devices as i32, comm_id, i as i32)
.unwrap();
let comm = comm.assume_init();
use std::ffi::c_void;
all_reduce(
sendslice.cu_device_ptr as *const c_void,
recvslice.cu_device_ptr as *mut c_void,
n,
sys::ncclDataType_t::ncclFloat32,
sys::ncclRedOp_t::ncclSum,
comm,
stream.cu_stream as sys::cudaStream_t,
)
.unwrap();
}
})
})
.collect();
// NOTE(review): this test only checks that the calls succeed — the
// reduced values are never copied back and verified, and the
// communicators are never destroyed.
for t in threads {
t.join().unwrap();
}
}
}