use std::ffi::c_void;
use std::sync::Mutex;
use std::time::Duration;
use ferrotorch_core::FerrotorchResult;
use crate::backend::Backend;
use crate::collective::ReduceOp;
use crate::error::DistributedError;
use crate::nccl_sys::{
self, NcclComm, NcclDataType, NcclRedOp, NcclUniqueId,
};
/// [`Backend`] implementation that runs collectives through NCCL.
///
/// Holds one NCCL communicator for a single rank together with the CUDA stream
/// handle that every NCCL call made by this backend is enqueued on.
pub struct NcclBackend {
    /// NCCL communicator handle, kept behind a mutex so that calls on it can be
    /// serialised across threads.
    comm: Mutex<NcclComm>,
    /// CUDA stream handle passed to every NCCL call. May be null, e.g. when
    /// stream creation failed in `new` or a null stream was supplied.
    stream: *mut c_void,
    /// `true` only when `stream` was created by this backend, in which case it
    /// is destroyed again on drop.
    owns_stream: bool,
    /// Rank of this process within the communicator.
    rank: usize,
    /// Number of ranks in the communicator.
    world_size: usize,
}

// SAFETY: the raw `stream` pointer is only handed to CUDA/NCCL calls and never
// dereferenced in Rust; the communicator handle lives inside a `Mutex`.
// NOTE(review): relies on the CUDA runtime accepting stream handles from any
// host thread — confirm for the CUDA versions targeted.
unsafe impl Send for NcclBackend {}

// SAFETY: shared access only reads the plain-data fields and the stream handle;
// communicator access is expected to go through the `comm` mutex.
// NOTE(review): soundness depends on every NCCL call holding that lock for the
// duration of the call — confirm at each use site.
unsafe impl Sync for NcclBackend {}
impl NcclBackend {
pub fn new(
rank: usize,
world_size: usize,
unique_id: NcclUniqueId,
) -> FerrotorchResult<Self> {
let comm = nccl_sys::comm_init_rank(
world_size as i32,
rank as i32,
unique_id,
)
.map_err(|e| DistributedError::Io {
message: format!("NCCL comm_init_rank failed: {e}"),
})?;
let stream = create_nccl_stream().unwrap_or(std::ptr::null_mut());
let owns_stream = !stream.is_null();
Ok(Self {
comm: Mutex::new(comm),
rank,
world_size,
stream,
owns_stream,
})
}
pub fn with_stream(
rank: usize,
world_size: usize,
unique_id: NcclUniqueId,
stream: *mut c_void,
) -> FerrotorchResult<Self> {
let comm = nccl_sys::comm_init_rank(
world_size as i32,
rank as i32,
unique_id,
)
.map_err(|e| DistributedError::Io {
message: format!("NCCL comm_init_rank failed: {e}"),
})?;
Ok(Self {
comm: Mutex::new(comm),
rank,
world_size,
stream,
owns_stream: false,
})
}
pub fn synchronize(&self) -> FerrotorchResult<()> {
if self.stream.is_null() {
return Ok(()); }
synchronize_stream(self.stream).map_err(|msg| {
DistributedError::Io { message: msg }.into()
})
}
pub fn comm(&self) -> &Mutex<NcclComm> {
&self.comm
}
fn lock_comm(&self) -> FerrotorchResult<std::sync::MutexGuard<'_, NcclComm>> {
self.comm.lock().map_err(|_| {
DistributedError::LockPoisoned {
message: "NCCL communicator mutex poisoned".into(),
}
.into()
})
}
pub unsafe fn allreduce_raw(
&self,
sendbuf: *const c_void,
recvbuf: *mut c_void,
count: usize,
datatype: NcclDataType,
op: NcclRedOp,
) -> FerrotorchResult<()> {
let comm = *self.lock_comm()?;
nccl_sys::all_reduce(sendbuf, recvbuf, count, datatype, op, comm, self.stream)
.map_err(|e| {
DistributedError::Io {
message: format!("NCCL allreduce failed: {e}"),
}
.into()
})
}
pub unsafe fn broadcast_raw(
&self,
sendbuf: *const c_void,
recvbuf: *mut c_void,
count: usize,
datatype: NcclDataType,
root: i32,
) -> FerrotorchResult<()> {
let comm = *self.lock_comm()?;
nccl_sys::broadcast(sendbuf, recvbuf, count, datatype, root, comm, self.stream)
.map_err(|e| {
DistributedError::Io {
message: format!("NCCL broadcast failed: {e}"),
}
.into()
})
}
pub unsafe fn all_gather_raw(
&self,
sendbuf: *const c_void,
recvbuf: *mut c_void,
sendcount: usize,
datatype: NcclDataType,
) -> FerrotorchResult<()> {
let comm = *self.lock_comm()?;
nccl_sys::all_gather(sendbuf, recvbuf, sendcount, datatype, comm, self.stream)
.map_err(|e| {
DistributedError::Io {
message: format!("NCCL all_gather failed: {e}"),
}
.into()
})
}
pub unsafe fn reduce_scatter_raw(
&self,
sendbuf: *const c_void,
recvbuf: *mut c_void,
recvcount: usize,
datatype: NcclDataType,
op: NcclRedOp,
) -> FerrotorchResult<()> {
let comm = *self.lock_comm()?;
nccl_sys::reduce_scatter(
sendbuf, recvbuf, recvcount, datatype, op, comm, self.stream,
)
.map_err(|e| {
DistributedError::Io {
message: format!("NCCL reduce_scatter failed: {e}"),
}
.into()
})
}
}
impl Backend for NcclBackend {
    fn rank(&self) -> usize {
        self.rank
    }

    fn world_size(&self) -> usize {
        self.world_size
    }

    /// Byte-level point-to-point send is not supported by this backend.
    fn send(&self, _data: &[u8], _dst_rank: usize) -> FerrotorchResult<()> {
        Err(DistributedError::UnsupportedOp {
            message: "NcclBackend does not support byte-level P2P send — use GPU-native collectives (nccl_allreduce, nccl_broadcast) or TcpBackend for P2P".into(),
        }.into())
    }

    /// Byte-level point-to-point receive is not supported by this backend.
    fn recv(&self, _dst: &mut [u8], _src_rank: usize) -> FerrotorchResult<()> {
        Err(DistributedError::UnsupportedOp {
            message: "NcclBackend does not support byte-level P2P recv — use GPU-native collectives or TcpBackend for P2P".into(),
        }.into())
    }

    /// Issues a zero-element all-reduce on the communicator, then waits for
    /// this backend's stream to drain.
    ///
    /// NOTE(review): recent NCCL releases may discard zero-element collectives
    /// without communicating with peers. In that case this only flushes local
    /// stream work and does not synchronise ranks. A true barrier would need a
    /// one-element device buffer. Confirm against the NCCL version in use.
    fn barrier(&self) -> FerrotorchResult<()> {
        {
            // Hold the communicator lock for the duration of the NCCL call, so no
            // other thread issues an operation on the same communicator.
            let guard = self.lock_comm()?;
            let comm = *guard;
            // SAFETY: with `count == 0` no elements are transferred, so the null
            // send/recv buffers are never dereferenced.
            unsafe {
                nccl_sys::all_reduce(
                    std::ptr::null(),
                    std::ptr::null_mut(),
                    0,
                    NcclDataType::Float32,
                    NcclRedOp::Sum,
                    comm,
                    self.stream,
                )
                .map_err(|e| DistributedError::Io {
                    message: format!("NCCL barrier (allreduce): {e}"),
                })?;
            }
        }
        // NCCL only enqueues the collective. Block the host until it (and
        // earlier work on the stream) completes, with the lock already released.
        self.synchronize()
    }
}
impl Drop for NcclBackend {
    /// Destroys the NCCL communicator, and the CUDA stream if this backend
    /// created it.
    fn drop(&mut self) {
        // `&mut self` guarantees exclusive access, so no locking is needed.
        // Recover the handle even from a poisoned mutex, so the communicator is
        // still destroyed instead of leaked after a panic elsewhere.
        let comm = match self.comm.get_mut() {
            Ok(comm) => *comm,
            Err(poisoned) => *poisoned.into_inner(),
        };
        if !comm.is_null() {
            // SAFETY: `comm` was produced by `comm_init_rank` and is destroyed
            // exactly once, here, as the backend is being dropped.
            unsafe {
                // Errors cannot be propagated out of `drop`; destruction is best-effort.
                let _ = nccl_sys::comm_destroy(comm);
            }
        }
        if self.owns_stream && !self.stream.is_null() {
            destroy_stream(self.stream);
        }
    }
}
/// Shared-object names tried, in order, when loading the CUDA runtime.
const CUDART_LIBRARY_NAMES: [&[u8]; 2] = [b"libcudart.so.12\0", b"libcudart.so\0"];

/// `cudaStreamNonBlocking` flag from the CUDA runtime API. The stream does not
/// implicitly synchronise with the legacy default stream.
const CUDA_STREAM_NON_BLOCKING: u32 = 0x01;

/// Opens the CUDA runtime, trying each entry of [`CUDART_LIBRARY_NAMES`] in
/// order and stopping at the first one that loads.
///
/// Returns `None` if none of them could be loaded.
fn open_cudart() -> Option<*mut c_void> {
    CUDART_LIBRARY_NAMES.iter().find_map(|name| {
        // SAFETY: every entry is a NUL-terminated byte-string literal.
        let lib = unsafe { libc::dlopen(name.as_ptr() as *const _, libc::RTLD_LAZY) };
        if lib.is_null() {
            None
        } else {
            Some(lib)
        }
    })
}

/// Looks up the NUL-terminated `symbol` in the library handle `lib`.
///
/// Returns `None` if the symbol is not exported.
fn cudart_symbol(lib: *mut c_void, symbol: &[u8]) -> Option<*mut c_void> {
    debug_assert!(symbol.ends_with(b"\0"), "symbol name must be NUL-terminated");
    // SAFETY: `lib` is a handle returned by `dlopen`, and `symbol` is NUL-terminated.
    let sym = unsafe { libc::dlsym(lib, symbol.as_ptr() as *const _) };
    if sym.is_null() {
        None
    } else {
        Some(sym)
    }
}

/// Creates a non-blocking CUDA stream for NCCL work.
///
/// Returns `None` if the CUDA runtime is unavailable or stream creation fails.
fn create_nccl_stream() -> Option<*mut c_void> {
    create_stream_from_lib(open_cudart()?)
}

/// Creates a non-blocking stream with `cudaStreamCreateWithFlags` from the
/// already-opened CUDA runtime `lib`.
///
/// Returns `None` if the symbol is missing or the call returns a non-zero status.
fn create_stream_from_lib(lib: *mut c_void) -> Option<*mut c_void> {
    type CudaStreamCreateFn = unsafe extern "C" fn(*mut *mut c_void, u32) -> i32;
    let sym = cudart_symbol(lib, b"cudaStreamCreateWithFlags\0")?;
    // SAFETY: the symbol is assumed to have the C signature of
    // `cudaStreamCreateWithFlags(cudaStream_t*, unsigned int)`.
    let create_fn: CudaStreamCreateFn = unsafe { std::mem::transmute(sym) };
    let mut stream: *mut c_void = std::ptr::null_mut();
    // SAFETY: `stream` is a valid out-pointer for the duration of the call.
    let status = unsafe { create_fn(&mut stream, CUDA_STREAM_NON_BLOCKING) };
    if status == 0 {
        Some(stream)
    } else {
        None
    }
}

/// Blocks until all work queued on `stream` has completed.
///
/// # Errors
///
/// Returns a description if the CUDA runtime or `cudaStreamSynchronize` cannot
/// be found, or if the call returns a non-zero status.
fn synchronize_stream(stream: *mut c_void) -> Result<(), String> {
    type SyncFn = unsafe extern "C" fn(*mut c_void) -> i32;
    let lib = open_cudart().ok_or_else(|| String::from("cudart not found"))?;
    let sym = cudart_symbol(lib, b"cudaStreamSynchronize\0")
        .ok_or_else(|| String::from("cudaStreamSynchronize not found"))?;
    // SAFETY: the symbol is assumed to have the C signature of
    // `cudaStreamSynchronize(cudaStream_t)`.
    let sync_fn: SyncFn = unsafe { std::mem::transmute(sym) };
    // SAFETY: `stream` is a stream handle supplied by the caller.
    match unsafe { sync_fn(stream) } {
        0 => Ok(()),
        code => Err(format!("cudaStreamSynchronize failed: error {code}")),
    }
}

/// Destroys `stream` with `cudaStreamDestroy`.
///
/// Best-effort: a missing runtime, a missing symbol, or an error status is
/// silently ignored, because this is called from `Drop`.
fn destroy_stream(stream: *mut c_void) {
    type DestroyFn = unsafe extern "C" fn(*mut c_void) -> i32;
    let sym = open_cudart().and_then(|lib| cudart_symbol(lib, b"cudaStreamDestroy\0"));
    if let Some(sym) = sym {
        // SAFETY: the symbol is assumed to have the C signature of
        // `cudaStreamDestroy(cudaStream_t)`.
        let destroy_fn: DestroyFn = unsafe { std::mem::transmute(sym) };
        // SAFETY: `stream` was created by this module and is destroyed once.
        unsafe { destroy_fn(stream) };
    }
}
/// Translates a framework-level [`ReduceOp`] into the matching NCCL reduction.
///
/// `ReduceOp::Mean` maps to NCCL's averaging reduction (`NcclRedOp::Avg`), and
/// `ReduceOp::Sum` maps to `NcclRedOp::Sum`.
pub fn reduce_op_to_nccl(op: &ReduceOp) -> NcclRedOp {
    match *op {
        ReduceOp::Mean => NcclRedOp::Avg,
        ReduceOp::Sum => NcclRedOp::Sum,
    }
}
pub fn is_nccl_available() -> bool {
nccl_sys::is_available()
}