#![warn(missing_debug_implementations)]
use baracuda_driver::{DeviceBuffer, Stream};
use baracuda_nccl_sys::{
nccl, ncclComm_t, ncclDataType_t, ncclRedOp_t, ncclResult_t, ncclUniqueId,
};
use baracuda_types::DeviceRepr;
/// Error returned by the fallible functions in this module: `baracuda_core::Error`
/// parameterised by NCCL's status type `ncclResult_t`.
pub type Error = baracuda_core::Error<ncclResult_t>;
/// `core::result::Result` whose error type defaults to this module's [`Error`].
pub type Result<T, E = Error> = core::result::Result<T, E>;
/// Converts a raw NCCL status into this module's [`Result`] by delegating to
/// `baracuda_core::Error::check`.
#[inline]
fn check(status: ncclResult_t) -> Result<()> {
    Error::check(status)
}
/// Reduction operator applied by the NCCL reduction collectives.
///
/// `Custom` carries the raw id of an operator created at runtime, e.g. by
/// [`Communicator::create_pre_mul_sum`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum RedOp {
    /// Element-wise sum; the default operator.
    #[default]
    Sum,
    /// Element-wise product.
    Prod,
    /// Element-wise maximum.
    Max,
    /// Element-wise minimum.
    Min,
    /// Element-wise average.
    Avg,
    /// Raw NCCL operator id, forwarded unchanged.
    Custom(i32),
}
impl RedOp {
    /// Translates this operator into the `ncclRedOp_t` value passed to NCCL.
    ///
    /// Built-in variants map to the matching `ncclRedOp_t` constant; `Custom(id)`
    /// is wrapped as-is.
    fn raw(self) -> ncclRedOp_t {
        match self {
            Self::Custom(id) => ncclRedOp_t(id),
            Self::Sum => ncclRedOp_t::Sum,
            Self::Prod => ncclRedOp_t::Prod,
            Self::Max => ncclRedOp_t::Max,
            Self::Min => ncclRedOp_t::Min,
            Self::Avg => ncclRedOp_t::Avg,
        }
    }
}
/// Where the scalar given to [`Communicator::create_pre_mul_sum`] lives.
///
/// The discriminants mirror NCCL's `ncclScalarResidence_t`:
/// `ncclScalarDevice = 0` and `ncclScalarHostImmediate = 1`.
/// The previous definition had them swapped, so `Host` told NCCL to
/// dereference a host pointer on the device.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ScalarResidence {
    /// Scalar is in host memory and is read before the create call returns
    /// (`ncclScalarHostImmediate`).
    Host = 1,
    /// Scalar is in device-visible memory and is read while collectives using the
    /// operator run (`ncclScalarDevice`).
    Device = 0,
}
/// Scalar element types that can be passed to the NCCL collectives in this module.
///
/// The trait is sealed through the private `sealed` module, so only this crate can
/// add implementors; each one reports its `ncclDataType_t` tag.
pub trait NcclScalar
where
    Self: DeviceRepr + sealed::Sealed,
{
    /// The NCCL datatype tag for `Self`.
    #[doc(hidden)]
    fn raw() -> ncclDataType_t;
}
/// Implements the `sealed::Sealed` marker and [`NcclScalar`] for `$rust_ty`,
/// reporting `ncclDataType_t::$nccl_dtype` as its NCCL datatype.
macro_rules! impl_nccl_scalar {
    ($rust_ty:ty, $nccl_dtype:ident) => {
        impl sealed::Sealed for $rust_ty {}

        impl NcclScalar for $rust_ty {
            fn raw() -> ncclDataType_t {
                ncclDataType_t::$nccl_dtype
            }
        }
    };
}
// Floating-point element types.
impl_nccl_scalar!(f32, Float32);
impl_nccl_scalar!(f64, Float64);

// 8-bit integers.
impl_nccl_scalar!(u8, Uint8);
impl_nccl_scalar!(i8, Int8);

// 32-bit integers.
impl_nccl_scalar!(u32, Uint32);
impl_nccl_scalar!(i32, Int32);

// 64-bit integers.
impl_nccl_scalar!(u64, Uint64);
impl_nccl_scalar!(i64, Int64);

// Half-precision types, only when the optional `half` crate is enabled.
#[cfg(feature = "half-crate")]
impl_nccl_scalar!(half::bf16, BFloat16);
#[cfg(feature = "half-crate")]
impl_nccl_scalar!(half::f16, Float16);
/// Private module holding the sealing marker, so code outside this crate cannot
/// name `Sealed` and therefore cannot implement [`NcclScalar`].
mod sealed {
    /// Marker supertrait of `NcclScalar`; implemented by `impl_nccl_scalar!`.
    pub trait Sealed {}
}
#[cfg(all(test, feature = "half-crate"))]
mod half_scalar_tests {
    use super::*;

    /// Fetches the NCCL datatype tag through the `NcclScalar` bound, which also
    /// proves at compile time that `T` implements the trait.
    fn dtype_of<T: NcclScalar>() -> ncclDataType_t {
        T::raw()
    }

    #[test]
    fn half_types_are_nccl_scalars() {
        let expectations = [
            (
                dtype_of::<half::f16>(),
                ncclDataType_t::Float16,
                "half::f16 must map to ncclFloat16",
            ),
            (
                dtype_of::<half::bf16>(),
                ncclDataType_t::BFloat16,
                "half::bf16 must map to ncclBfloat16",
            ),
        ];
        for (actual, expected, why) in expectations {
            assert_eq!(actual, expected, "{}", why);
        }
    }
}
/// Opaque NCCL unique id (`ncclUniqueId`) that ranks share to join one communicator
/// through [`Communicator::init_rank`].
#[derive(Debug, Clone, Copy)]
pub struct UniqueId(ncclUniqueId);
impl UniqueId {
    /// Generates a fresh id with `ncclGetUniqueId`.
    ///
    /// # Errors
    ///
    /// Fails if the NCCL library or symbol cannot be loaded, or if
    /// `ncclGetUniqueId` reports an error.
    pub fn new() -> Result<Self> {
        let n = nccl()?;
        let cu = n.nccl_get_unique_id()?;
        let mut id = ncclUniqueId::default();
        // SAFETY: `id` is a valid, writable `ncclUniqueId` for the duration of the call.
        check(unsafe { cu(&mut id) })?;
        Ok(Self(id))
    }

    /// Copies the id's raw bytes into a 128-byte array, e.g. to send it to other ranks.
    ///
    /// Pair with [`UniqueId::from_bytes`] on the receiving side.
    pub fn as_bytes(&self) -> [u8; 128] {
        let mut out = [0u8; 128];
        for (o, b) in out.iter_mut().zip(&self.0.internal) {
            // Bit-for-bit reinterpretation of each C `char` (signed or unsigned).
            *o = *b as u8;
        }
        out
    }

    /// Rebuilds an id from bytes produced by [`UniqueId::as_bytes`].
    pub fn from_bytes(bytes: [u8; 128]) -> Self {
        let mut id = ncclUniqueId::default();
        for (i, b) in id.internal.iter_mut().zip(&bytes) {
            // `as _` casts to whatever element type `internal` uses. It is presumably
            // `c_char`, which is `i8` on some targets and `u8` on others (e.g. aarch64
            // Linux). A hard-coded `as i8` fails to compile on the latter.
            *i = *b as _;
        }
        Self(id)
    }
}
/// Owning wrapper around an NCCL communicator handle (`ncclComm_t`).
///
/// When the value is dropped, the handle is passed to `ncclCommDestroy`.
pub struct Communicator {
    /// Raw handle produced by one of the `ncclCommInit*` calls or `ncclCommSplit`.
    handle: ncclComm_t,
}
// SAFETY: `Communicator` holds only the raw `ncclComm_t` pointer. Moving ownership
// of an NCCL communicator to another thread is assumed to be sound
// (NOTE(review): confirm against NCCL's threading guidance). No `Sync` impl appears
// in this file, so the communicator is never shared between threads through `&`.
unsafe impl Send for Communicator {}
impl core::fmt::Debug for Communicator {
    /// Formats the raw handle pointer only; no NCCL call is made.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("nccl::Communicator");
        builder.field("handle", &self.handle);
        builder.finish()
    }
}
impl Communicator {
pub fn init_all(devices: &[i32]) -> Result<Vec<Self>> {
let n = nccl()?;
let cu = n.nccl_comm_init_all()?;
let ndev = devices.len() as core::ffi::c_int;
let mut comms = vec![core::ptr::null_mut::<core::ffi::c_void>(); devices.len()];
check(unsafe { cu(comms.as_mut_ptr(), ndev, devices.as_ptr()) })?;
Ok(comms.into_iter().map(|handle| Self { handle }).collect())
}
pub fn init_rank(nranks: i32, id: UniqueId, rank: i32) -> Result<Self> {
let n = nccl()?;
let cu = n.nccl_comm_init_rank()?;
let mut handle: ncclComm_t = core::ptr::null_mut();
check(unsafe { cu(&mut handle, nranks, id.0, rank) })?;
Ok(Self { handle })
}
pub unsafe fn init_rank_config(
nranks: i32,
id: UniqueId,
rank: i32,
config: *mut core::ffi::c_void,
) -> Result<Self> {
let n = nccl()?;
let cu = n.nccl_comm_init_rank_config()?;
let mut handle: ncclComm_t = core::ptr::null_mut();
check(cu(&mut handle, nranks, id.0, rank, config))?;
Ok(Self { handle })
}
pub fn nranks(&self) -> Result<i32> {
let n = nccl()?;
let cu = n.nccl_comm_count()?;
let mut c: core::ffi::c_int = 0;
check(unsafe { cu(self.handle, &mut c) })?;
Ok(c)
}
pub fn rank(&self) -> Result<i32> {
let n = nccl()?;
let cu = n.nccl_comm_user_rank()?;
let mut r: core::ffi::c_int = 0;
check(unsafe { cu(self.handle, &mut r) })?;
Ok(r)
}
#[inline]
pub fn as_raw(&self) -> ncclComm_t {
self.handle
}
}
impl Drop for Communicator {
    /// Destroys the communicator with `ncclCommDestroy`.
    ///
    /// This is best-effort: if the library or symbol cannot be loaded, or the call
    /// fails, the error is discarded because `drop` has no way to report it.
    fn drop(&mut self) {
        let lib = match nccl() {
            Ok(lib) => lib,
            Err(_) => return,
        };
        let destroy = match lib.nccl_comm_destroy() {
            Ok(destroy) => destroy,
            Err(_) => return,
        };
        // SAFETY: `self.handle` is the handle this value owns; it is not used again.
        let _ = unsafe { destroy(self.handle) };
    }
}
/// Reduces the first `count` elements of `send` across every rank of `comm` with
/// `op` and writes the result into `recv` on all ranks (`ncclAllReduce`).
///
/// The operation is enqueued on `stream`.
///
/// # Panics
///
/// Panics if either buffer holds fewer than `count` elements.
///
/// # Errors
///
/// Fails if the NCCL library or symbol cannot be loaded, or if `ncclAllReduce`
/// reports an error.
pub fn all_reduce<T: NcclScalar>(
    send: &DeviceBuffer<T>,
    recv: &mut DeviceBuffer<T>,
    count: usize,
    op: RedOp,
    comm: &Communicator,
    stream: &Stream,
) -> Result<()> {
    assert!(send.len() >= count && recv.len() >= count);
    let lib = nccl()?;
    let all_reduce_fn = lib.nccl_all_reduce()?;
    let src = send.as_raw().0 as *const core::ffi::c_void;
    let dst = recv.as_raw().0 as *mut core::ffi::c_void;
    // SAFETY: `src`/`dst` come from live device buffers of at least `count`
    // elements (asserted above), and `comm.handle` is owned by `comm`.
    let status = unsafe {
        all_reduce_fn(
            src,
            dst,
            count,
            T::raw(),
            op.raw(),
            comm.handle,
            stream.as_raw() as _,
        )
    };
    check(status)
}
/// Broadcasts `count` elements of `send` on rank `root` into `recv` on every rank
/// of `comm` (`ncclBroadcast`), enqueued on `stream`.
///
/// NCCL only reads `send` on the root rank, so other ranks may pass a shorter send
/// buffer.
///
/// # Panics
///
/// Panics if `recv` holds fewer than `count` elements, or if this rank is `root`
/// and `send` holds fewer than `count` elements. Without these checks a too-large
/// `count` would let this safe function read or write past the device buffers.
///
/// # Errors
///
/// Fails if the NCCL library or symbol cannot be loaded, if querying this rank
/// fails, or if `ncclBroadcast` reports an error.
pub fn broadcast<T: NcclScalar>(
    send: &DeviceBuffer<T>,
    recv: &mut DeviceBuffer<T>,
    count: usize,
    root: i32,
    comm: &Communicator,
    stream: &Stream,
) -> Result<()> {
    assert!(
        recv.len() >= count,
        "nccl broadcast: receive buffer has {} elements, count is {}",
        recv.len(),
        count
    );
    if send.len() < count {
        // A short send buffer is only a problem on the root, the one rank that reads it.
        let me = comm.rank()?;
        assert!(
            me != root,
            "nccl broadcast: root send buffer has {} elements, count is {}",
            send.len(),
            count
        );
    }
    let n = nccl()?;
    let cu = n.nccl_broadcast()?;
    // SAFETY: `recv` holds at least `count` elements, `send` does too on the root
    // rank (checked above), and `comm.handle` is owned by `comm`.
    check(unsafe {
        cu(
            send.as_raw().0 as *const core::ffi::c_void,
            recv.as_raw().0 as *mut core::ffi::c_void,
            count,
            T::raw(),
            root,
            comm.handle,
            stream.as_raw() as _,
        )
    })
}
/// Opens an NCCL group with `ncclGroupStart`.
///
/// Per the NCCL docs, calls issued until the matching [`group_end`] are batched
/// and launched together.
///
/// # Errors
///
/// Fails if the NCCL library or symbol cannot be loaded, or if `ncclGroupStart`
/// reports an error.
pub fn group_start() -> Result<()> {
    let lib = nccl()?;
    let start = lib.nccl_group_start()?;
    // SAFETY: `ncclGroupStart` takes no arguments.
    let status = unsafe { start() };
    check(status)
}
/// Closes the NCCL group opened by [`group_start`] with `ncclGroupEnd`.
///
/// # Errors
///
/// Fails if the NCCL library or symbol cannot be loaded, or if `ncclGroupEnd`
/// reports an error.
pub fn group_end() -> Result<()> {
    let lib = nccl()?;
    let end = lib.nccl_group_end()?;
    // SAFETY: `ncclGroupEnd` takes no arguments.
    let status = unsafe { end() };
    check(status)
}
/// Returns the version code reported by the loaded library's `ncclGetVersion`.
///
/// NOTE(review): NCCL packs major/minor/patch into this integer. Confirm the exact
/// encoding for the NCCL versions you support before decoding it.
///
/// # Errors
///
/// Fails if the NCCL library or symbol cannot be loaded, or if the call fails.
pub fn version() -> Result<i32> {
    let lib = nccl()?;
    let get_version = lib.nccl_get_version()?;
    let mut code: core::ffi::c_int = 0;
    // SAFETY: `code` is a valid out-pointer for the duration of the call.
    let status = unsafe { get_version(&mut code) };
    check(status).map(|()| code)
}
/// Returns NCCL's description of `status` (`ncclGetErrorString`).
///
/// Yields `"unknown"` when NCCL returns a null pointer or text that is not valid
/// UTF-8.
///
/// # Errors
///
/// Fails only if the NCCL library or symbol cannot be loaded.
pub fn error_string(status: ncclResult_t) -> Result<&'static str> {
    let lib = nccl()?;
    let get_error_string = lib.nccl_get_error_string()?;
    // SAFETY: `ncclGetErrorString` takes the status by value and returns a pointer.
    let raw = unsafe { get_error_string(status) };
    if raw.is_null() {
        Ok("unknown")
    } else {
        // SAFETY: `raw` is non-null and assumed to be a NUL-terminated string that
        // NCCL keeps alive for the life of the process.
        let text: &'static core::ffi::CStr = unsafe { core::ffi::CStr::from_ptr(raw) };
        Ok(text.to_str().unwrap_or("unknown"))
    }
}
impl Communicator {
    /// Returns `per_rank * nranks()`: the element count of a buffer that holds one
    /// `per_rank`-element slice for every rank of this communicator.
    ///
    /// Used to validate the rank-scaled buffers of `all_gather` and `reduce_scatter`.
    fn elems_for_all_ranks(&self, per_rank: usize) -> Result<usize> {
        let nranks = usize::try_from(self.nranks()?)
            .expect("NCCL reported a negative communicator size");
        Ok(per_rank
            .checked_mul(nranks)
            .expect("per-rank element count times rank count overflows usize"))
    }

    /// Reduces `count` elements of `sendbuf` from every rank with `op` into
    /// `recvbuf` on rank `root` (`ncclReduce`), enqueued on `stream`.
    ///
    /// NCCL only writes `recvbuf` on the root rank, so other ranks may pass a
    /// shorter receive buffer.
    ///
    /// # Panics
    ///
    /// Panics if `sendbuf` holds fewer than `count` elements, or if this rank is
    /// `root` and `recvbuf` holds fewer than `count` elements.
    ///
    /// # Errors
    ///
    /// Fails if the NCCL library or symbol cannot be loaded, if querying this rank
    /// fails, or if `ncclReduce` reports an error.
    pub fn reduce<T: NcclScalar>(
        &self,
        sendbuf: &DeviceBuffer<T>,
        recvbuf: &mut DeviceBuffer<T>,
        count: usize,
        op: RedOp,
        root: i32,
        stream: &Stream,
    ) -> Result<()> {
        assert!(
            sendbuf.len() >= count,
            "nccl reduce: send buffer has {} elements, count is {}",
            sendbuf.len(),
            count
        );
        if recvbuf.len() < count {
            // A short receive buffer only matters on the root, the one rank NCCL writes.
            let me = self.rank()?;
            assert!(
                me != root,
                "nccl reduce: root receive buffer has {} elements, count is {}",
                recvbuf.len(),
                count
            );
        }
        let n = nccl()?;
        let cu = n.nccl_reduce()?;
        // SAFETY: buffer lengths were validated above for the ranks that access
        // them, and `self.handle` is owned by `self`.
        check(unsafe {
            cu(
                sendbuf.as_raw().0 as *const core::ffi::c_void,
                recvbuf.as_raw().0 as *mut core::ffi::c_void,
                count,
                T::raw(),
                op.raw(),
                root,
                self.handle,
                stream.as_raw(),
            )
        })
    }

    /// Gathers `sendcount` elements from every rank into `recvbuf` on all ranks
    /// (`ncclAllGather`), enqueued on `stream`.
    ///
    /// `recvbuf` must hold `sendcount * nranks()` elements.
    ///
    /// # Panics
    ///
    /// Panics if `sendbuf` holds fewer than `sendcount` elements, or if `recvbuf`
    /// holds fewer than `sendcount * nranks()` elements.
    ///
    /// # Errors
    ///
    /// Fails if the NCCL library or symbol cannot be loaded, if querying the rank
    /// count fails, or if `ncclAllGather` reports an error.
    pub fn all_gather<T: NcclScalar>(
        &self,
        sendbuf: &DeviceBuffer<T>,
        recvbuf: &mut DeviceBuffer<T>,
        sendcount: usize,
        stream: &Stream,
    ) -> Result<()> {
        assert!(
            sendbuf.len() >= sendcount,
            "nccl all_gather: send buffer has {} elements, sendcount is {}",
            sendbuf.len(),
            sendcount
        );
        let needed = self.elems_for_all_ranks(sendcount)?;
        assert!(
            recvbuf.len() >= needed,
            "nccl all_gather: receive buffer has {} elements, needs {}",
            recvbuf.len(),
            needed
        );
        let n = nccl()?;
        let cu = n.nccl_all_gather()?;
        // SAFETY: both buffers were checked against the sizes NCCL will touch, and
        // `self.handle` is owned by `self`.
        check(unsafe {
            cu(
                sendbuf.as_raw().0 as *const core::ffi::c_void,
                recvbuf.as_raw().0 as *mut core::ffi::c_void,
                sendcount,
                T::raw(),
                self.handle,
                stream.as_raw(),
            )
        })
    }

    /// Reduces `recvcount * nranks()` elements of `sendbuf` with `op` and scatters
    /// a `recvcount`-element slice of the result to each rank's `recvbuf`
    /// (`ncclReduceScatter`), enqueued on `stream`.
    ///
    /// # Panics
    ///
    /// Panics if `sendbuf` holds fewer than `recvcount * nranks()` elements, or if
    /// `recvbuf` holds fewer than `recvcount` elements.
    ///
    /// # Errors
    ///
    /// Fails if the NCCL library or symbol cannot be loaded, if querying the rank
    /// count fails, or if `ncclReduceScatter` reports an error.
    pub fn reduce_scatter<T: NcclScalar>(
        &self,
        sendbuf: &DeviceBuffer<T>,
        recvbuf: &mut DeviceBuffer<T>,
        recvcount: usize,
        op: RedOp,
        stream: &Stream,
    ) -> Result<()> {
        assert!(
            recvbuf.len() >= recvcount,
            "nccl reduce_scatter: receive buffer has {} elements, recvcount is {}",
            recvbuf.len(),
            recvcount
        );
        let needed = self.elems_for_all_ranks(recvcount)?;
        assert!(
            sendbuf.len() >= needed,
            "nccl reduce_scatter: send buffer has {} elements, needs {}",
            sendbuf.len(),
            needed
        );
        let n = nccl()?;
        let cu = n.nccl_reduce_scatter()?;
        // SAFETY: both buffers were checked against the sizes NCCL will touch, and
        // `self.handle` is owned by `self`.
        check(unsafe {
            cu(
                sendbuf.as_raw().0 as *const core::ffi::c_void,
                recvbuf.as_raw().0 as *mut core::ffi::c_void,
                recvcount,
                T::raw(),
                op.raw(),
                self.handle,
                stream.as_raw(),
            )
        })
    }

    /// Sends `count` elements of `sendbuf` to rank `peer` (`ncclSend`), enqueued on
    /// `stream`.
    ///
    /// # Panics
    ///
    /// Panics if `sendbuf` holds fewer than `count` elements.
    ///
    /// # Errors
    ///
    /// Fails if the NCCL library or symbol cannot be loaded, or if `ncclSend`
    /// reports an error.
    pub fn send<T: NcclScalar>(
        &self,
        sendbuf: &DeviceBuffer<T>,
        count: usize,
        peer: i32,
        stream: &Stream,
    ) -> Result<()> {
        assert!(
            sendbuf.len() >= count,
            "nccl send: send buffer has {} elements, count is {}",
            sendbuf.len(),
            count
        );
        let n = nccl()?;
        let cu = n.nccl_send()?;
        // SAFETY: `sendbuf` holds at least `count` elements (asserted above), and
        // `self.handle` is owned by `self`.
        check(unsafe {
            cu(
                sendbuf.as_raw().0 as *const core::ffi::c_void,
                count,
                T::raw(),
                peer,
                self.handle,
                stream.as_raw(),
            )
        })
    }

    /// Receives `count` elements from rank `peer` into `recvbuf` (`ncclRecv`),
    /// enqueued on `stream`.
    ///
    /// # Panics
    ///
    /// Panics if `recvbuf` holds fewer than `count` elements.
    ///
    /// # Errors
    ///
    /// Fails if the NCCL library or symbol cannot be loaded, or if `ncclRecv`
    /// reports an error.
    pub fn recv<T: NcclScalar>(
        &self,
        recvbuf: &mut DeviceBuffer<T>,
        count: usize,
        peer: i32,
        stream: &Stream,
    ) -> Result<()> {
        assert!(
            recvbuf.len() >= count,
            "nccl recv: receive buffer has {} elements, count is {}",
            recvbuf.len(),
            count
        );
        let n = nccl()?;
        let cu = n.nccl_recv()?;
        // SAFETY: `recvbuf` holds at least `count` elements (asserted above), and
        // `self.handle` is owned by `self`.
        check(unsafe {
            cu(
                recvbuf.as_raw().0 as *mut core::ffi::c_void,
                count,
                T::raw(),
                peer,
                self.handle,
                stream.as_raw(),
            )
        })
    }

    /// Aborts the communicator with `ncclCommAbort`.
    ///
    /// NOTE(review): NCCL documents `ncclCommAbort` as freeing the communicator's
    /// resources. This wrapper nevertheless keeps the handle, and `Drop` later
    /// passes it to `ncclCommDestroy`, which looks like a double free. Verify
    /// against the NCCL version in use. A consuming `abort(self)` that skips
    /// `Drop` (e.g. via `core::mem::forget`) would avoid it.
    ///
    /// # Errors
    ///
    /// Fails if the NCCL library or symbol cannot be loaded, or if the call fails.
    pub fn abort(&self) -> Result<()> {
        let n = nccl()?;
        let cu = n.nccl_comm_abort()?;
        // SAFETY: `self.handle` is owned by `self`.
        check(unsafe { cu(self.handle) })
    }

    /// Finalizes the communicator with `ncclCommFinalize`.
    ///
    /// The handle stays owned by this value, and `Drop` still calls
    /// `ncclCommDestroy` on it.
    ///
    /// # Errors
    ///
    /// Fails if the NCCL library or symbol cannot be loaded, or if the call fails.
    pub fn finalize(&self) -> Result<()> {
        let n = nccl()?;
        let cu = n.nccl_comm_finalize()?;
        // SAFETY: `self.handle` is owned by `self`.
        check(unsafe { cu(self.handle) })
    }

    /// Returns the asynchronous error status NCCL reports for this communicator
    /// (`ncclCommGetAsyncError`).
    ///
    /// `Success` means no asynchronous error is pending.
    ///
    /// # Errors
    ///
    /// Fails if the NCCL library or symbol cannot be loaded, or if the query fails.
    pub fn get_async_error(&self) -> Result<ncclResult_t> {
        let n = nccl()?;
        let cu = n.nccl_comm_get_async_error()?;
        let mut s = ncclResult_t::Success;
        // SAFETY: `self.handle` is owned by `self`; `s` is a valid out-pointer.
        check(unsafe { cu(self.handle, &mut s) })?;
        Ok(s)
    }

    /// CUDA device ordinal this communicator is bound to (`ncclCommCuDevice`).
    ///
    /// # Errors
    ///
    /// Fails if the NCCL library or symbol cannot be loaded, or if the query fails.
    pub fn cuda_device(&self) -> Result<i32> {
        let n = nccl()?;
        let cu = n.nccl_comm_cu_device()?;
        let mut d: core::ffi::c_int = 0;
        // SAFETY: `self.handle` is owned by `self`; `d` is a valid out-pointer.
        check(unsafe { cu(self.handle, &mut d) })?;
        Ok(d)
    }

    /// Splits this communicator with `ncclCommSplit`.
    ///
    /// Per the NCCL docs, ranks passing the same `color` join the same new
    /// communicator, ordered by `key`, and every rank of `self` must make the call.
    /// A null config pointer is passed.
    ///
    /// NOTE(review): per the NCCL docs, a rank passing `NCCL_SPLIT_NOCOLOR` gets a
    /// null communicator back. This wrapper would then return a `Communicator`
    /// wrapping a null handle; confirm callers never rely on that.
    ///
    /// # Errors
    ///
    /// Fails if the NCCL library or symbol cannot be loaded, or if the split fails.
    pub fn split(&self, color: i32, key: i32) -> Result<Communicator> {
        let n = nccl()?;
        let cu = n.nccl_comm_split()?;
        let mut new_comm: ncclComm_t = core::ptr::null_mut();
        // SAFETY: `self.handle` is owned by `self`; `new_comm` is a valid out-pointer.
        check(unsafe { cu(self.handle, color, key, &mut new_comm, core::ptr::null_mut()) })?;
        Ok(Communicator { handle: new_comm })
    }

    /// Registers `size` bytes at `dev_ptr` with this communicator
    /// (`ncclCommRegister`) and returns the opaque registration handle.
    ///
    /// # Safety
    ///
    /// `dev_ptr` must point to a device allocation of at least `size` bytes. That
    /// allocation must stay valid until the returned handle has been passed to
    /// [`Communicator::deregister`] on this same communicator.
    ///
    /// # Errors
    ///
    /// Fails if the NCCL library or symbol cannot be loaded, or if registration fails.
    pub unsafe fn register(
        &self,
        dev_ptr: *mut core::ffi::c_void,
        size: usize,
    ) -> Result<*mut core::ffi::c_void> {
        let n = nccl()?;
        let cu = n.nccl_comm_register()?;
        let mut handle: *mut core::ffi::c_void = core::ptr::null_mut();
        // SAFETY: `handle` is a valid out-pointer; `dev_ptr`/`size` are the
        // caller's obligation (see `# Safety`).
        check(unsafe { cu(self.handle, dev_ptr, size, &mut handle) })?;
        Ok(handle)
    }

    /// Releases a registration created by [`Communicator::register`]
    /// (`ncclCommDeregister`).
    ///
    /// # Safety
    ///
    /// `handle` must have been returned by `register` on this communicator and not
    /// yet deregistered.
    ///
    /// # Errors
    ///
    /// Fails if the NCCL library or symbol cannot be loaded, or if the call fails.
    pub unsafe fn deregister(&self, handle: *mut core::ffi::c_void) -> Result<()> {
        let n = nccl()?;
        let cu = n.nccl_comm_deregister()?;
        // SAFETY: validity of `handle` is the caller's obligation (see `# Safety`).
        check(unsafe { cu(self.handle, handle) })
    }

    /// Creates a "pre-multiply by `*scalar`, then sum" reduction operator with
    /// `ncclRedOpCreatePreMulSum` and returns it as [`RedOp::Custom`].
    ///
    /// Release the operator with [`Communicator::destroy_red_op`].
    ///
    /// # Safety
    ///
    /// `scalar` must point to one value of type `T`. With [`ScalarResidence::Host`]
    /// it must be host memory readable during this call. With
    /// [`ScalarResidence::Device`] it must be device memory that, per the NCCL docs,
    /// stays valid while any collective using the returned operator executes.
    ///
    /// # Errors
    ///
    /// Fails if the NCCL library or symbol cannot be loaded, or if creation fails.
    pub unsafe fn create_pre_mul_sum<T: NcclScalar>(
        &self,
        scalar: *mut core::ffi::c_void,
        residence: ScalarResidence,
    ) -> Result<RedOp> {
        let n = nccl()?;
        let cu = n.nccl_red_op_create_pre_mul_sum()?;
        // Map explicitly to NCCL's `ncclScalarResidence_t` values
        // (`ncclScalarDevice = 0`, `ncclScalarHostImmediate = 1`) instead of relying
        // on `ScalarResidence`'s discriminants, which have been inverted before.
        let residence_raw: i32 = match residence {
            ScalarResidence::Device => 0,
            ScalarResidence::Host => 1,
        };
        let mut op = ncclRedOp_t(0);
        // SAFETY: `op` is a valid out-pointer; `scalar` is the caller's obligation
        // (see `# Safety`).
        check(unsafe { cu(&mut op, scalar, T::raw(), residence_raw, self.handle) })?;
        Ok(RedOp::Custom(op.0))
    }

    /// Destroys a reduction operator created on this communicator
    /// (`ncclRedOpDestroy`).
    ///
    /// # Errors
    ///
    /// Fails if the NCCL library or symbol cannot be loaded, or if NCCL rejects `op`.
    pub fn destroy_red_op(&self, op: RedOp) -> Result<()> {
        let n = nccl()?;
        let cu = n.nccl_red_op_destroy()?;
        // SAFETY: arguments are passed by value; `self.handle` is owned by `self`.
        check(unsafe { cu(op.raw(), self.handle) })
    }

    /// Returns NCCL's last error message for this communicator (`ncclGetLastError`).
    ///
    /// Yields `"unknown"` when NCCL returns null or text that is not valid UTF-8.
    ///
    /// NOTE(review): the `'static` lifetime assumes NCCL never overwrites or frees
    /// the returned buffer. If NCCL reuses it for later errors, this reference can
    /// observe changing data. An owned `String` would be safer.
    ///
    /// # Errors
    ///
    /// Fails only if the NCCL library or symbol cannot be loaded.
    pub fn last_error(&self) -> Result<&'static str> {
        let n = nccl()?;
        let cu = n.nccl_get_last_error()?;
        // SAFETY: `self.handle` is owned by `self`.
        let p = unsafe { cu(self.handle) };
        if p.is_null() {
            return Ok("unknown");
        }
        // SAFETY: `p` is non-null and assumed to be a NUL-terminated NCCL string.
        Ok(unsafe { core::ffi::CStr::from_ptr(p) }
            .to_str()
            .unwrap_or("unknown"))
    }
}
/// Memory obtained from `ncclMemAlloc`, released with `ncclMemFree` when dropped.
#[derive(Debug)]
pub struct NcclMem {
    /// Pointer returned by `ncclMemAlloc`.
    ptr: *mut core::ffi::c_void,
}
impl NcclMem {
    /// Allocates `size` bytes with `ncclMemAlloc`.
    ///
    /// # Errors
    ///
    /// Fails if the NCCL library or symbol cannot be loaded, or if the allocation
    /// fails.
    pub fn new(size: usize) -> Result<Self> {
        let lib = nccl()?;
        let mem_alloc = lib.nccl_mem_alloc()?;
        let mut ptr: *mut core::ffi::c_void = core::ptr::null_mut();
        // SAFETY: `ptr` is a valid out-pointer for the duration of the call.
        let status = unsafe { mem_alloc(&mut ptr, size) };
        check(status).map(|()| Self { ptr })
    }

    /// Returns the raw allocation pointer; ownership stays with `self`.
    #[inline]
    pub fn as_raw(&self) -> *mut core::ffi::c_void {
        self.ptr
    }
}
impl Drop for NcclMem {
    /// Frees the allocation with `ncclMemFree`.
    ///
    /// This is best-effort: load or free failures are discarded because `drop`
    /// cannot report them.
    fn drop(&mut self) {
        let lib = match nccl() {
            Ok(lib) => lib,
            Err(_) => return,
        };
        let mem_free = match lib.nccl_mem_free() {
            Ok(mem_free) => mem_free,
            Err(_) => return,
        };
        // SAFETY: `self.ptr` came from `ncclMemAlloc` and is not used after this.
        let _ = unsafe { mem_free(self.ptr) };
    }
}