use std::ffi::c_void;
use std::sync::OnceLock;
/// Opaque NCCL communicator handle (`ncclComm_t`).
///
/// Produced by [`comm_init_rank`] and released with [`comm_destroy`]; the
/// wrappers here only pass it through to NCCL. Being a raw pointer, it is
/// neither `Send` nor `Sync`.
pub type NcclComm = *mut c_void;
/// Opaque identifier used to bootstrap a communicator (`ncclUniqueId`).
///
/// Produced by [`get_unique_id`] and consumed by [`comm_init_rank`].
/// NOTE(review): presumably one rank generates it and shares it with every
/// other rank out of band before they all call `comm_init_rank`. Confirm
/// against the NCCL docs.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NcclUniqueId {
    /// Raw id bytes written by `ncclGetUniqueId`.
    ///
    /// The layout must stay exactly `[u8; 128]` under `repr(C)`, because the
    /// struct is passed by value to `ncclCommInitRank`.
    pub internal: [u8; 128],
}
/// Element type tag passed to NCCL operations (`ncclDataType_t`).
///
/// Values cross the FFI boundary as their explicit discriminants. The
/// numbering must therefore match the loaded library's header: do not
/// reorder or renumber these variants.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NcclDataType {
Int8 = 0,
Uint8 = 1,
Int32 = 2,
Uint32 = 3,
Int64 = 4,
Uint64 = 5,
Float16 = 6,
Float32 = 7,
Float64 = 8,
Bfloat16 = 9,
}
/// Reduction operator (`ncclRedOp_t`) used by [`all_reduce`] and
/// [`reduce_scatter`].
///
/// The discriminants are passed across FFI unchanged and must match the
/// NCCL header: do not reorder or renumber these variants.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NcclRedOp {
Sum = 0,
Prod = 1,
Max = 2,
Min = 3,
// NOTE(review): presumably only supported by newer NCCL releases. Confirm
// the minimum version before relying on it.
Avg = 4,
}
/// Status code returned by every NCCL entry point (`ncclResult_t`).
///
/// Convert it to a `Result` with [`NcclResult::ok`].
///
/// NOTE(review): this enum is the declared return type of the
/// `extern "C"` function pointers in `NcclFunctions`, so NCCL's raw return
/// value is reinterpreted directly as this type. Any value outside the
/// discriminants listed here (0..=8) would be undefined behaviour in Rust.
/// If a future NCCL adds status codes, consider receiving a `c_int` and
/// converting explicitly.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NcclResult {
Success = 0,
UnhandledCudaError = 1,
SystemError = 2,
InternalError = 3,
InvalidArgument = 4,
InvalidUsage = 5,
RemoteError = 6,
InProgress = 7,
// Presumably mirrors the C `ncclNumResults` count marker rather than a
// status NCCL actually returns. Confirm.
NumResults = 8,
}
impl NcclResult {
    /// Turns an NCCL status code into a `Result`.
    ///
    /// # Errors
    ///
    /// Every code other than [`NcclResult::Success`] is returned as
    /// [`NcclError::NcclStatus`] carrying that code.
    pub fn ok(self) -> Result<(), NcclError> {
        match self {
            NcclResult::Success => Ok(()),
            failure => Err(NcclError::NcclStatus(failure)),
        }
    }
}
/// Errors produced by the NCCL wrappers in this module.
///
/// Load failures (`LibraryNotFound`, `SymbolNotFound`) are cached by
/// `nccl()`. Once loading has failed, every later wrapper call returns the
/// same error again.
#[derive(Debug, thiserror::Error)]
pub enum NcclError {
/// None of the candidate library names tried by `load_nccl` could be opened.
#[error("NCCL library not found — install libnccl2 or set LD_LIBRARY_PATH")]
LibraryNotFound,
/// The library opened but does not export the named symbol.
#[error("NCCL symbol not found: {0}")]
SymbolNotFound(String),
/// An NCCL call returned a status other than [`NcclResult::Success`].
#[error("NCCL error: {0:?}")]
NcclStatus(NcclResult),
}
/// Table of NCCL entry points resolved at runtime by `load_nccl`.
///
/// Each field is named exactly after its exported C symbol, because
/// `load_sym!` looks the symbol up via `stringify!` of the field name. That
/// is also why `non_snake_case` is allowed. The signatures are not checked
/// at load time: the raw `dlsym` pointer is `transmute`d into the field type.
/// They must therefore match the NCCL C declarations exactly.
#[allow(non_snake_case)]
struct NcclFunctions {
/// Writes a new id through the pointer; see [`get_unique_id`].
ncclGetUniqueId: unsafe extern "C" fn(*mut NcclUniqueId) -> NcclResult,
/// (out communicator, world size, unique id by value, rank); see [`comm_init_rank`].
ncclCommInitRank: unsafe extern "C" fn(*mut NcclComm, i32, NcclUniqueId, i32) -> NcclResult,
/// Releases a communicator; see [`comm_destroy`].
ncclCommDestroy: unsafe extern "C" fn(NcclComm) -> NcclResult,
/// (sendbuf, recvbuf, count, datatype, op, comm, stream); see [`all_reduce`].
ncclAllReduce: unsafe extern "C" fn(
*const c_void,
*mut c_void,
usize,
NcclDataType,
NcclRedOp,
NcclComm,
*mut c_void, // cudaStream_t
) -> NcclResult,
/// (sendbuf, recvbuf, count, datatype, root, comm, stream); see [`broadcast`].
ncclBroadcast: unsafe extern "C" fn(
*const c_void,
*mut c_void,
usize,
NcclDataType,
i32, // root
NcclComm,
*mut c_void,
) -> NcclResult,
/// (sendbuf, recvbuf, sendcount, datatype, comm, stream); see [`all_gather`].
ncclAllGather: unsafe extern "C" fn(
*const c_void,
*mut c_void,
usize, // sendcount
NcclDataType,
NcclComm,
*mut c_void,
) -> NcclResult,
/// (sendbuf, recvbuf, recvcount, datatype, op, comm, stream); see [`reduce_scatter`].
ncclReduceScatter: unsafe extern "C" fn(
*const c_void,
*mut c_void,
usize, // recvcount
NcclDataType,
NcclRedOp,
NcclComm,
*mut c_void,
) -> NcclResult,
/// (sendbuf, count, datatype, peer, comm, stream); see [`send`].
ncclSend: unsafe extern "C" fn(
*const c_void,
usize,
NcclDataType,
i32, // peer
NcclComm,
*mut c_void,
) -> NcclResult,
/// (recvbuf, count, datatype, peer, comm, stream); see [`recv`].
ncclRecv: unsafe extern "C" fn(
*mut c_void,
usize,
NcclDataType,
i32, // peer
NcclComm,
*mut c_void,
) -> NcclResult,
/// Opens a group of calls; see [`group_start`].
ncclGroupStart: unsafe extern "C" fn() -> NcclResult,
/// Closes a group of calls; see [`group_end`].
ncclGroupEnd: unsafe extern "C" fn() -> NcclResult,
}
// `NcclFunctions` stores only `unsafe extern "C" fn` pointers. Function
// pointers are `Send + Sync` on their own, so the auto traits already apply
// and no `unsafe impl` is needed. This compile-time check replaces the former
// `unsafe impl`s: if a non-thread-safe field (e.g. a raw library handle) is
// ever added, the build fails here instead of the type being silently declared
// thread-safe.
const _: fn() = || {
    fn assert_send_sync<T: Send + Sync>() {}
    assert_send_sync::<NcclFunctions>();
};

/// Process-wide cache of the one-time [`load_nccl`] outcome, filled on first
/// use by `nccl()`.
///
/// Both success and failure are cached, so a failed load is never retried.
static NCCL_LIB: OnceLock<Result<NcclFunctions, NcclError>> = OnceLock::new();
/// Opens the NCCL shared library and resolves every entry point this module
/// calls.
///
/// Candidate library names are tried in order with `dlopen(RTLD_LAZY)`, and
/// the first one that opens is used. On success the handle is deliberately
/// never closed: the returned function pointers point into the library, and
/// the table is cached for the rest of the process.
///
/// # Errors
///
/// * [`NcclError::LibraryNotFound`] if no candidate can be opened.
/// * [`NcclError::SymbolNotFound`] if the library opened but lacks a required
///   symbol. The handle is closed before returning. The failure is cached,
///   so nothing would ever reference the handle again and it would otherwise
///   leak.
fn load_nccl() -> Result<NcclFunctions, NcclError> {
    /// Resolves each required NCCL symbol from `lib_handle`.
    ///
    /// # Safety
    ///
    /// `lib_handle` must be a live `dlopen` handle. Each exported symbol must
    /// have the C signature declared for the matching `NcclFunctions` field,
    /// because `transmute` cannot check this.
    unsafe fn resolve_symbols(lib_handle: *mut c_void) -> Result<NcclFunctions, NcclError> {
        macro_rules! load_sym {
            ($name:ident) => {{
                let c_name = std::ffi::CString::new(stringify!($name))
                    .expect("symbol names contain no NUL bytes");
                // SAFETY: `lib_handle` is live (caller contract) and `c_name`
                // is NUL-terminated and outlives the call.
                let ptr = unsafe { libc::dlsym(lib_handle, c_name.as_ptr()) };
                if ptr.is_null() {
                    return Err(NcclError::SymbolNotFound(stringify!($name).into()));
                }
                // SAFETY: non-null symbol whose C signature matches the field
                // type (caller contract).
                unsafe { std::mem::transmute(ptr) }
            }};
        }

        Ok(NcclFunctions {
            ncclGetUniqueId: load_sym!(ncclGetUniqueId),
            ncclCommInitRank: load_sym!(ncclCommInitRank),
            ncclCommDestroy: load_sym!(ncclCommDestroy),
            ncclAllReduce: load_sym!(ncclAllReduce),
            ncclBroadcast: load_sym!(ncclBroadcast),
            ncclAllGather: load_sym!(ncclAllGather),
            ncclReduceScatter: load_sym!(ncclReduceScatter),
            ncclSend: load_sym!(ncclSend),
            ncclRecv: load_sym!(ncclRecv),
            ncclGroupStart: load_sym!(ncclGroupStart),
            ncclGroupEnd: load_sym!(ncclGroupEnd),
        })
    }

    const LIB_NAMES: [&str; 4] = [
        "libnccl.so.2",
        "libnccl.so",
        "/usr/lib/x86_64-linux-gnu/libnccl.so.2",
        "/usr/local/cuda/lib64/libnccl.so.2",
    ];

    // `map` is lazy, so `find` stops at the first candidate that opens.
    let lib_handle = LIB_NAMES
        .iter()
        .map(|name| {
            let c_name =
                std::ffi::CString::new(*name).expect("library names contain no NUL bytes");
            // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
            unsafe { libc::dlopen(c_name.as_ptr(), libc::RTLD_LAZY) }
        })
        .find(|handle| !handle.is_null())
        .ok_or(NcclError::LibraryNotFound)?;

    // SAFETY: `lib_handle` was just returned non-null by `dlopen`.
    let resolved = unsafe { resolve_symbols(lib_handle) };
    if resolved.is_err() {
        // No function pointer escapes on this path, so the library can be
        // released instead of being leaked behind the cached error.
        // SAFETY: `lib_handle` is a live `dlopen` handle and is closed exactly once.
        unsafe { libc::dlclose(lib_handle) };
    }
    resolved
}
/// Returns the process-wide NCCL function table, loading the library on the
/// first call.
///
/// # Errors
///
/// Returns a copy of the cached load error. `NcclError` does not derive
/// `Clone`, so the copy is rebuilt variant by variant.
fn nccl() -> Result<&'static NcclFunctions, NcclError> {
    match NCCL_LIB.get_or_init(load_nccl) {
        Ok(functions) => Ok(functions),
        Err(NcclError::LibraryNotFound) => Err(NcclError::LibraryNotFound),
        Err(NcclError::SymbolNotFound(symbol)) => Err(NcclError::SymbolNotFound(symbol.clone())),
        Err(NcclError::NcclStatus(status)) => Err(NcclError::NcclStatus(*status)),
    }
}
/// Asks NCCL for a fresh [`NcclUniqueId`] used to bootstrap a communicator
/// with [`comm_init_rank`].
///
/// # Errors
///
/// Returns the library load error, or [`NcclError::NcclStatus`] if
/// `ncclGetUniqueId` fails.
pub fn get_unique_id() -> Result<NcclUniqueId, NcclError> {
    let functions = nccl()?;
    let mut unique_id = NcclUniqueId { internal: [0; 128] };
    // SAFETY: `unique_id` is a live local, so the pointer handed to NCCL is
    // valid for writing one `NcclUniqueId` for the whole call.
    let status = unsafe { (functions.ncclGetUniqueId)(&mut unique_id) };
    status.ok().map(|()| unique_id)
}
/// Creates a communicator for `rank` within a group of `world_size` ranks.
///
/// `world_size` is forwarded as NCCL's rank-count argument and `unique_id`
/// is passed by value. NOTE(review): presumably every rank must pass the same
/// id (from [`get_unique_id`] on one rank), and the call waits for all ranks
/// to join. Confirm against the NCCL docs. Release the returned handle with
/// [`comm_destroy`].
///
/// # Errors
///
/// Returns the library load error, or [`NcclError::NcclStatus`] if
/// `ncclCommInitRank` fails.
pub fn comm_init_rank(
    world_size: i32,
    rank: i32,
    unique_id: NcclUniqueId,
) -> Result<NcclComm, NcclError> {
    let functions = nccl()?;
    let mut communicator: NcclComm = std::ptr::null_mut();
    // SAFETY: `communicator` is a live local that NCCL writes the new handle
    // into; every other argument is passed by value.
    let status =
        unsafe { (functions.ncclCommInitRank)(&mut communicator, world_size, unique_id, rank) };
    status.ok().map(|()| communicator)
}
/// Destroys a communicator previously created by [`comm_init_rank`].
///
/// # Safety
///
/// `comm` must be a live communicator handle that has not already been
/// destroyed, and it must not be used again after this call.
///
/// # Errors
///
/// Returns the library load error, or [`NcclError::NcclStatus`] if
/// `ncclCommDestroy` fails.
pub unsafe fn comm_destroy(comm: NcclComm) -> Result<(), NcclError> {
    let destroy = nccl()?.ncclCommDestroy;
    // SAFETY: the validity of `comm` is the caller's obligation (see `# Safety`).
    unsafe { destroy(comm) }.ok()
}
/// Reduces `count` elements of `datatype` across every rank of `comm` with
/// `op`, writing the result into `recvbuf`.
///
/// This is a thin wrapper over `ncclAllReduce` that forwards all arguments
/// unchanged. NOTE(review): the operation is issued on `stream`, and `Ok`
/// presumably means "enqueued", not "finished". Confirm against the NCCL docs.
///
/// # Safety
///
/// The caller must meet every requirement of `ncclAllReduce`:
/// * `sendbuf` and `recvbuf` must be valid for `count` elements of
///   `datatype` (presumably device memory).
/// * `comm` must be a live communicator.
/// * `stream` must be a valid `cudaStream_t`.
///
/// # Errors
///
/// Returns the library load error, or [`NcclError::NcclStatus`] if NCCL
/// reports a failure.
pub unsafe fn all_reduce(
    sendbuf: *const c_void,
    recvbuf: *mut c_void,
    count: usize,
    datatype: NcclDataType,
    op: NcclRedOp,
    comm: NcclComm,
    stream: *mut c_void,
) -> Result<(), NcclError> {
    let all_reduce_fn = nccl()?.ncclAllReduce;
    // SAFETY: arguments are forwarded verbatim; their validity is the caller's contract.
    let status = unsafe { all_reduce_fn(sendbuf, recvbuf, count, datatype, op, comm, stream) };
    status.ok()
}
/// Broadcasts `count` elements of `datatype` from rank `root` to every rank
/// of `comm`.
///
/// This is a thin wrapper over `ncclBroadcast` that forwards all arguments
/// unchanged. NOTE(review): presumably `sendbuf` is only read on `root`.
/// Confirm against the NCCL docs.
///
/// # Safety
///
/// The caller must meet every requirement of `ncclBroadcast`:
/// * the buffers must be valid for `count` elements of `datatype`;
/// * `root` must be a rank of `comm`;
/// * `comm` must be a live communicator;
/// * `stream` must be a valid `cudaStream_t`.
///
/// # Errors
///
/// Returns the library load error, or [`NcclError::NcclStatus`] if NCCL
/// reports a failure.
pub unsafe fn broadcast(
    sendbuf: *const c_void,
    recvbuf: *mut c_void,
    count: usize,
    datatype: NcclDataType,
    root: i32,
    comm: NcclComm,
    stream: *mut c_void,
) -> Result<(), NcclError> {
    let broadcast_fn = nccl()?.ncclBroadcast;
    // SAFETY: arguments are forwarded verbatim; their validity is the caller's contract.
    let status = unsafe { broadcast_fn(sendbuf, recvbuf, count, datatype, root, comm, stream) };
    status.ok()
}
/// Gathers `sendcount` elements of `datatype` from every rank of `comm` into
/// `recvbuf` on every rank.
///
/// This is a thin wrapper over `ncclAllGather` that forwards all arguments
/// unchanged. Note that the count is per rank (sending side).
///
/// # Safety
///
/// The caller must meet every requirement of `ncclAllGather`:
/// * `sendbuf` must be valid for `sendcount` elements.
/// * `recvbuf` must be valid for the gathered output. NOTE(review): per the
///   NCCL docs this is presumably `sendcount * nranks` elements; confirm.
/// * `comm` must be a live communicator.
/// * `stream` must be a valid `cudaStream_t`.
///
/// # Errors
///
/// Returns the library load error, or [`NcclError::NcclStatus`] if NCCL
/// reports a failure.
pub unsafe fn all_gather(
    sendbuf: *const c_void,
    recvbuf: *mut c_void,
    sendcount: usize,
    datatype: NcclDataType,
    comm: NcclComm,
    stream: *mut c_void,
) -> Result<(), NcclError> {
    let all_gather_fn = nccl()?.ncclAllGather;
    // SAFETY: arguments are forwarded verbatim; their validity is the caller's contract.
    let status = unsafe { all_gather_fn(sendbuf, recvbuf, sendcount, datatype, comm, stream) };
    status.ok()
}
/// Reduces data across every rank of `comm` with `op` and scatters the
/// result, leaving `recvcount` elements of `datatype` in each rank's
/// `recvbuf`.
///
/// This is a thin wrapper over `ncclReduceScatter` that forwards all
/// arguments unchanged. Note that the count is per rank (receiving side).
///
/// # Safety
///
/// The caller must meet every requirement of `ncclReduceScatter`:
/// * `recvbuf` must be valid for `recvcount` elements.
/// * `sendbuf` must be valid for the full input. NOTE(review): per the NCCL
///   docs this is presumably `recvcount * nranks` elements; confirm.
/// * `comm` must be a live communicator.
/// * `stream` must be a valid `cudaStream_t`.
///
/// # Errors
///
/// Returns the library load error, or [`NcclError::NcclStatus`] if NCCL
/// reports a failure.
pub unsafe fn reduce_scatter(
    sendbuf: *const c_void,
    recvbuf: *mut c_void,
    recvcount: usize,
    datatype: NcclDataType,
    op: NcclRedOp,
    comm: NcclComm,
    stream: *mut c_void,
) -> Result<(), NcclError> {
    let reduce_scatter_fn = nccl()?.ncclReduceScatter;
    // SAFETY: arguments are forwarded verbatim; their validity is the caller's contract.
    let status =
        unsafe { reduce_scatter_fn(sendbuf, recvbuf, recvcount, datatype, op, comm, stream) };
    status.ok()
}
/// Sends `count` elements of `datatype` from `sendbuf` to rank `peer` of
/// `comm`.
///
/// This is a thin wrapper over `ncclSend` that forwards all arguments
/// unchanged. NOTE(review): presumably it must be matched by a [`recv`] on
/// `peer`, often inside [`group_start`]/[`group_end`]. Confirm against the
/// NCCL docs.
///
/// # Safety
///
/// The caller must meet every requirement of `ncclSend`:
/// * `sendbuf` must be valid for `count` elements of `datatype`;
/// * `comm` must be a live communicator;
/// * `stream` must be a valid `cudaStream_t`.
///
/// # Errors
///
/// Returns the library load error, or [`NcclError::NcclStatus`] if NCCL
/// reports a failure.
pub unsafe fn send(
    sendbuf: *const c_void,
    count: usize,
    datatype: NcclDataType,
    peer: i32,
    comm: NcclComm,
    stream: *mut c_void,
) -> Result<(), NcclError> {
    let send_fn = nccl()?.ncclSend;
    // SAFETY: arguments are forwarded verbatim; their validity is the caller's contract.
    let status = unsafe { send_fn(sendbuf, count, datatype, peer, comm, stream) };
    status.ok()
}
/// Receives `count` elements of `datatype` into `recvbuf` from rank `peer`
/// of `comm`.
///
/// This is a thin wrapper over `ncclRecv` that forwards all arguments
/// unchanged. NOTE(review): presumably it must be matched by a [`send`] on
/// `peer`. Confirm against the NCCL docs.
///
/// # Safety
///
/// The caller must meet every requirement of `ncclRecv`:
/// * `recvbuf` must be valid for `count` elements of `datatype`;
/// * `comm` must be a live communicator;
/// * `stream` must be a valid `cudaStream_t`.
///
/// # Errors
///
/// Returns the library load error, or [`NcclError::NcclStatus`] if NCCL
/// reports a failure.
pub unsafe fn recv(
    recvbuf: *mut c_void,
    count: usize,
    datatype: NcclDataType,
    peer: i32,
    comm: NcclComm,
    stream: *mut c_void,
) -> Result<(), NcclError> {
    let recv_fn = nccl()?.ncclRecv;
    // SAFETY: arguments are forwarded verbatim; their validity is the caller's contract.
    let status = unsafe { recv_fn(recvbuf, count, datatype, peer, comm, stream) };
    status.ok()
}
/// Opens an NCCL group of calls, closed later with [`group_end`].
///
/// # Errors
///
/// Returns the library load error, or [`NcclError::NcclStatus`] if
/// `ncclGroupStart` fails.
pub fn group_start() -> Result<(), NcclError> {
    let start = nccl()?.ncclGroupStart;
    // SAFETY: `ncclGroupStart` takes no arguments, so there is no pointer
    // precondition to uphold.
    let status = unsafe { start() };
    status.ok()
}
/// Closes an NCCL group of calls opened with [`group_start`].
///
/// # Errors
///
/// Returns the library load error, or [`NcclError::NcclStatus`] if
/// `ncclGroupEnd` fails.
pub fn group_end() -> Result<(), NcclError> {
    let end = nccl()?.ncclGroupEnd;
    // SAFETY: `ncclGroupEnd` takes no arguments, so there is no pointer
    // precondition to uphold.
    let status = unsafe { end() };
    status.ok()
}
/// Reports whether the NCCL library and every required symbol could be
/// loaded.
///
/// The first call triggers the one-time load; later calls reuse the cached
/// outcome.
pub fn is_available() -> bool {
    matches!(nccl(), Ok(_))
}
#[cfg(test)]
mod tests {
    use super::is_available;

    /// Probing for NCCL must not panic, whether or not the library is
    /// installed. The result is only printed, never asserted, so the test
    /// also passes on machines without NCCL.
    #[test]
    fn test_nccl_availability_doesnt_panic() {
        let available = is_available();
        eprintln!("NCCL available: {available}");
    }
}