use anyhow::{Context, Result};
use cudarc::nccl::sys::{
ncclComm_t, ncclCommDestroy, ncclCommInitRankConfig, ncclConfig_t, ncclGetUniqueId,
ncclGetVersion, ncclResult_t, ncclUniqueId,
};
/// Converts an NCCL status code into a `Result`.
///
/// Returns `Ok(())` when `result` is `ncclSuccess`.
///
/// # Errors
///
/// Any other status yields an error whose message contains `operation`, a symbolic name for
/// the status (`"Unknown"` for codes not listed below) and the `Debug` form of the status.
fn check_nccl_result(result: ncclResult_t, operation: &str) -> Result<()> {
    // Success is the common case; leave early so the rest only deals with failures.
    if result == ncclResult_t::ncclSuccess {
        return Ok(());
    }

    // Human-readable label for the failure; the wildcard covers any status not named here.
    let error_name = match result {
        ncclResult_t::ncclUnhandledCudaError => "ncclUnhandledCudaError",
        ncclResult_t::ncclSystemError => "ncclSystemError",
        ncclResult_t::ncclInternalError => "ncclInternalError",
        ncclResult_t::ncclInvalidArgument => "ncclInvalidArgument",
        ncclResult_t::ncclInvalidUsage => "ncclInvalidUsage",
        ncclResult_t::ncclRemoteError => "ncclRemoteError",
        ncclResult_t::ncclInProgress => "ncclInProgress",
        _ => "Unknown",
    };

    anyhow::bail!(
        "{} failed with error: {} ({:?}). Check NCCL_DEBUG=INFO for more details.",
        operation,
        error_name,
        result
    )
}
/// Data required to create an NCCL communicator: the NCCL unique ID plus the number of ranks.
///
/// A value is either produced locally with [`NcclBootstrap::generate`] or rebuilt from bytes
/// with [`NcclBootstrap::deserialize`]; [`NcclBootstrap::init_communicator`] then turns it into
/// a raw communicator for a given rank.
pub struct NcclBootstrap {
    /// Opaque identifier filled in by `ncclGetUniqueId` (or restored from serialized bytes).
    unique_id: ncclUniqueId,
    /// Number of ranks passed to `ncclCommInitRankConfig`.
    world_size: i32,
}

impl NcclBootstrap {
    /// Length of the little-endian `world_size` field at the start of the serialized form.
    const WORLD_SIZE_LEN: usize = 4;
    /// Offset of the unique ID: `world_size` (4 bytes) followed by 4 zero padding bytes.
    const HEADER_LEN: usize = 8;
    /// Number of bytes in the opaque NCCL unique ID.
    const UNIQUE_ID_LEN: usize = 128;
    /// Total length of the serialized form (136 bytes).
    const SERIALIZED_LEN: usize = Self::HEADER_LEN + Self::UNIQUE_ID_LEN;
    /// Environment variable that overrides the `maxCTAs` communicator setting.
    const MAX_CTAS_ENV: &'static str = "DYN_KVBM_NCCL_MAX_CTAS";
    /// `maxCTAs` value used when the environment variable is unset or unparsable.
    const DEFAULT_MAX_CTAS: i32 = 8;

    /// Creates bootstrap data with a fresh unique ID obtained from `ncclGetUniqueId`.
    ///
    /// `world_size` is stored as given; it is only checked against the rank later, in
    /// [`NcclBootstrap::init_communicator`].
    ///
    /// # Errors
    ///
    /// Returns an error if `ncclGetUniqueId` reports a non-success status.
    pub fn generate(world_size: i32) -> Result<Self> {
        let mut unique_id = ncclUniqueId {
            internal: [0; Self::UNIQUE_ID_LEN],
        };
        // SAFETY: `unique_id` is a valid, exclusively borrowed `ncclUniqueId` for NCCL to fill.
        let result = unsafe { ncclGetUniqueId(&mut unique_id) };
        check_nccl_result(result, "ncclGetUniqueId")?;
        Ok(Self {
            unique_id,
            world_size,
        })
    }

    /// Encodes the bootstrap data into a fixed 136-byte buffer.
    ///
    /// Layout:
    /// - bytes `0..4`: `world_size` as little-endian `i32`
    /// - bytes `4..8`: zero padding
    /// - bytes `8..136`: the 128 bytes of the unique ID
    pub fn serialize(&self) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(Self::SERIALIZED_LEN);
        bytes.extend_from_slice(&self.world_size.to_le_bytes());
        // Zero padding up to the start of the unique ID.
        bytes.resize(Self::HEADER_LEN, 0);
        // The ID is stored as C chars; `as u8` keeps each byte's bit pattern regardless of
        // whether the platform's C char is signed, so no pointer cast is needed.
        bytes.extend(self.unique_id.internal.iter().map(|&c| c as u8));
        debug_assert_eq!(bytes.len(), Self::SERIALIZED_LEN);
        bytes
    }

    /// Rebuilds bootstrap data from the byte layout produced by [`NcclBootstrap::serialize`].
    ///
    /// The four padding bytes are ignored, and `world_size` is not range-checked here.
    ///
    /// # Errors
    ///
    /// Returns an error if `bytes` is not exactly 136 bytes long.
    pub fn deserialize(bytes: &[u8]) -> Result<Self> {
        anyhow::ensure!(
            bytes.len() == Self::SERIALIZED_LEN,
            "Invalid bootstrap data length: expected {}, got {}",
            Self::SERIALIZED_LEN,
            bytes.len()
        );
        let world_size = i32::from_le_bytes(
            bytes[..Self::WORLD_SIZE_LEN]
                .try_into()
                .context("Failed to parse world_size")?,
        );
        let mut unique_id = ncclUniqueId {
            internal: [0; Self::UNIQUE_ID_LEN],
        };
        // The length check above guarantees exactly `UNIQUE_ID_LEN` bytes after the header,
        // so the zip covers every element of `internal`. `as _` converts each byte to the
        // platform's C char type without changing its bit pattern.
        for (dst, &src) in unique_id
            .internal
            .iter_mut()
            .zip(&bytes[Self::HEADER_LEN..])
        {
            *dst = src as _;
        }
        Ok(Self {
            unique_id,
            world_size,
        })
    }

    /// Creates the NCCL communicator for `rank` by calling `ncclCommInitRankConfig`.
    ///
    /// The communicator is configured as blocking with `minCTAs = 1` and `maxCTAs` taken from
    /// the `DYN_KVBM_NCCL_MAX_CTAS` environment variable (default 8). The returned handle is
    /// raw: the caller is responsible for destroying it, for example by wrapping it in
    /// [`NcclCommOwned`].
    ///
    /// # Errors
    ///
    /// Returns an error if `rank` is outside `[0, world_size)`, or if `ncclGetVersion` or
    /// `ncclCommInitRankConfig` reports a non-success status.
    pub fn init_communicator(&self, rank: i32) -> Result<ncclComm_t> {
        anyhow::ensure!(
            (0..self.world_size).contains(&rank),
            "Invalid rank {}: must be in range [0, {})",
            rank,
            self.world_size
        );

        // The config's `version` field is filled with the version reported by the loaded
        // NCCL library at runtime.
        let nccl_version = {
            let mut v: std::ffi::c_int = 0;
            // SAFETY: `v` is a valid, exclusively borrowed `c_int` for NCCL to write into.
            let result = unsafe { ncclGetVersion(&mut v) };
            check_nccl_result(result, "ncclGetVersion")?;
            tracing::debug!("NCCL runtime version: {v}");
            v as std::ffi::c_uint
        };

        // An unset variable silently selects the default; a value that is set but does not
        // parse as `i32` also falls back, but is reported so a typo does not go unnoticed.
        let max_ctas = match std::env::var(Self::MAX_CTAS_ENV) {
            Ok(raw) => raw.parse::<i32>().unwrap_or_else(|err| {
                tracing::warn!(
                    "Ignoring invalid {}={:?} ({}); using default {}",
                    Self::MAX_CTAS_ENV,
                    raw,
                    err,
                    Self::DEFAULT_MAX_CTAS
                );
                Self::DEFAULT_MAX_CTAS
            }),
            Err(_) => Self::DEFAULT_MAX_CTAS,
        };

        // `size`/`magic`/`version` form the config header NCCL validates. Fields set to
        // `i32::MIN` mirror NCCL's `NCCL_CONFIG_UNDEF_INT` ("use the library default").
        let mut config = ncclConfig_t {
            size: std::mem::size_of::<ncclConfig_t>(),
            magic: 0xcafebeef,
            version: nccl_version,
            blocking: 1,
            cgaClusterSize: i32::MIN,
            minCTAs: 1,
            maxCTAs: max_ctas,
            netName: std::ptr::null_mut(),
            splitShare: i32::MIN,
            trafficClass: i32::MIN,
            commName: std::ptr::null_mut(),
            collnetEnable: 0,
            CTAPolicy: i32::MIN,
            shrinkShare: i32::MIN,
            nvlsCTAs: i32::MIN,
            nChannelsPerNetPeer: i32::MIN,
            nvlinkCentricSched: i32::MIN,
        };

        let mut comm: ncclComm_t = std::ptr::null_mut();
        tracing::debug!(
            "Calling ncclCommInitRankConfig: rank={}, world_size={}",
            rank,
            self.world_size
        );
        // SAFETY: `comm` and `config` are valid, exclusively borrowed locals that outlive the
        // call; the unique ID is passed by value.
        let result = unsafe {
            ncclCommInitRankConfig(
                &mut comm,
                self.world_size,
                self.unique_id,
                rank,
                &mut config,
            )
        };
        check_nccl_result(
            result,
            &format!(
                "ncclCommInitRankConfig(rank={}, world_size={})",
                rank, self.world_size
            ),
        )?;
        tracing::info!(
            "NCCL communicator initialized successfully: rank={}, world_size={}",
            rank,
            self.world_size
        );
        Ok(comm)
    }

    /// Returns the number of ranks stored in this bootstrap data.
    pub fn world_size(&self) -> i32 {
        self.world_size
    }
}
pub struct NcclCommOwned {
comm: ncclComm_t,
}
unsafe impl Send for NcclCommOwned {}
unsafe impl Sync for NcclCommOwned {}
impl NcclCommOwned {
pub unsafe fn from_raw(comm: ncclComm_t) -> Self {
Self { comm }
}
pub fn as_raw(&self) -> ncclComm_t {
self.comm
}
}
impl Drop for NcclCommOwned {
fn drop(&mut self) {
if !self.comm.is_null() {
let result = unsafe { ncclCommDestroy(self.comm) };
if result != ncclResult_t::ncclSuccess {
tracing::error!("Failed to destroy NCCL communicator: {:?}", result);
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A hand-built bootstrap must survive a serialize/deserialize round trip unchanged.
    #[test]
    fn test_serialize_deserialize() {
        const FILL: u8 = 42;

        // Fill every byte of the ID with a recognisable pattern.
        let mut unique_id = ncclUniqueId { internal: [0; 128] };
        for slot in unique_id.internal.iter_mut() {
            *slot = FILL as _;
        }
        let original = NcclBootstrap {
            unique_id,
            world_size: 4,
        };

        let encoded = original.serialize();
        assert_eq!(encoded.len(), 136);

        let decoded = NcclBootstrap::deserialize(&encoded).unwrap();
        assert_eq!(decoded.world_size, 4);

        let payload: Vec<u8> = decoded
            .unique_id
            .internal
            .iter()
            .map(|&c| c as u8)
            .collect();
        assert_eq!(payload, vec![FILL; 128]);
    }

    /// Input that is not exactly 136 bytes long must be rejected.
    #[test]
    fn test_deserialize_invalid_length() {
        let too_short = [0u8; 100];
        let outcome = NcclBootstrap::deserialize(&too_short);
        assert!(outcome.is_err());
    }
}