use super::*;
use std::cell::Cell;
use std::ffi::c_void;
use std::ops::Range;
use anyhow::{Context, Result};
use cudarc::driver::sys::CUstream;
use cudarc::nccl::sys::{
ncclBcast, ncclComm_t, ncclDataType_t, ncclGroupEnd, ncclGroupStart, ncclResult_t,
};
fn check_nccl_result(result: ncclResult_t) -> Result<()> {
if result == ncclResult_t::ncclSuccess {
Ok(())
} else {
anyhow::bail!("NCCL error: {:?}", result)
}
}
/// Guard around an NCCL group call: `ncclGroupStart` is called on creation, and the matching
/// `ncclGroupEnd` is issued by `NcclGroup::end` or, if that was never called, on drop.
///
/// Prefer calling `end()` explicitly so a failing `ncclGroupEnd` surfaces as an error rather
/// than a panic in `Drop`.
///
/// NOTE(review): NCCL tracks group state per thread (per the NCCL group-call docs), yet this
/// type is `Send` because `Cell<bool>` is `Send`. It should presumably be ended and dropped
/// on the thread that created it. Confirm with callers.
#[must_use = "dropping an NcclGroup immediately ends the NCCL group it started"]
pub struct NcclGroup {
    /// True once the group has been ended; `Drop` skips `ncclGroupEnd` when this is set.
    ended: Cell<bool>,
}
impl NcclGroup {
    /// Starts an NCCL group by calling `ncclGroupStart`.
    ///
    /// The matching `ncclGroupEnd` is issued by [`NcclGroup::end`], or by `Drop` if `end`
    /// was never called.
    ///
    /// # Safety
    ///
    /// The caller must follow NCCL's group-call rules for every NCCL operation issued before
    /// the group ends. NCCL group state is per-thread, so the guard should be ended or
    /// dropped on the thread that created it (assumption from the NCCL docs; confirm).
    ///
    /// # Errors
    ///
    /// Returns an error if `ncclGroupStart` reports a failure.
    pub unsafe fn new() -> Result<Self> {
        let result = unsafe { ncclGroupStart() };
        check_nccl_result(result).context("ncclGroupStart failed")?;
        Ok(Self {
            ended: Cell::new(false),
        })
    }

    /// Ends the group by calling `ncclGroupEnd`, reporting any failure as an error.
    ///
    /// After this call, whether it succeeded or not, `Drop` will not call
    /// `ncclGroupEnd` again.
    ///
    /// # Errors
    ///
    /// Returns an error if `end` was already called, or if `ncclGroupEnd` fails.
    pub fn end(&self) -> Result<()> {
        if self.ended.get() {
            anyhow::bail!("NcclGroup::end called twice");
        }
        // Mark the group ended *before* calling into NCCL. `ncclGroupEnd` leaves the group
        // even when it reports an error from the grouped work. If the flag were only set on
        // success, a failed `end()` would make `Drop` issue a second, unmatched
        // `ncclGroupEnd` and panic.
        self.ended.set(true);
        let result = unsafe { ncclGroupEnd() };
        check_nccl_result(result).context("ncclGroupEnd failed")
    }
}
impl Drop for NcclGroup {
    /// Ends the group with `ncclGroupEnd` if [`NcclGroup::end`] has not already done so.
    ///
    /// # Panics
    ///
    /// Panics if `ncclGroupEnd` fails. The exception is when the thread is already
    /// unwinding from another panic: a second panic would abort the process and hide the
    /// original one, so the failure is written to stderr instead.
    fn drop(&mut self) {
        if self.ended.get() {
            return;
        }
        let result = unsafe { ncclGroupEnd() };
        if result == ncclResult_t::ncclSuccess {
            return;
        }
        if std::thread::panicking() {
            eprintln!(
                "ncclGroupEnd failed in NcclGroup drop while already panicking: {:?}",
                result
            );
            return;
        }
        panic!(
            "ncclGroupEnd failed in NcclGroup drop: {:?}. Call NcclGroup::end() before drop to handle errors.",
            result
        );
    }
}
/// Broadcasts the whole of `block` from rank `root` to every rank in `comm`.
///
/// A fully contiguous block goes out as one `ncclBcast` over its block view. Its
/// `view.size()` is treated as a count of `ncclChar` (1-byte) elements. Any other block is
/// handed to [`bcast_layer`] with all layers selected. Nothing here synchronizes `stream`;
/// the broadcast is only enqueued.
///
/// # Safety
///
/// `comm` and `stream` must be valid NCCL / CUDA handles. The device memory behind the
/// block must stay valid until the enqueued broadcast(s) finish on `stream`. This is an
/// assumption from NCCL's asynchronous model; confirm against callers.
///
/// # Errors
///
/// Fails if a block or layer view cannot be obtained, or if `ncclBcast` reports an error.
pub unsafe fn bcast_block<B>(block: &B, root: i32, comm: ncclComm_t, stream: CUstream) -> Result<()>
where
    B: BlockDataProvider,
{
    let data = block.block_data();
    if !data.is_fully_contiguous() {
        // Scattered storage: fall back to one broadcast per layer/outer-dim view.
        return unsafe { bcast_layer(block, None, root, comm, stream) };
    }

    let view = data.block_view().context("Failed to get block view")?;
    let buffer = unsafe { view.as_ptr() } as usize as *mut c_void;
    let byte_count = view.size();
    // SAFETY: the validity of `comm`, `stream` and the block's memory is the caller's
    // obligation under this function's `# Safety` contract.
    let status = unsafe {
        ncclBcast(
            buffer,
            byte_count,
            ncclDataType_t::ncclChar,
            root,
            comm,
            stream.cast(),
        )
    };
    check_nccl_result(status).context("ncclBcast failed")
}
/// Broadcasts a range of layers of `block` from rank `root` to every rank in `comm`.
///
/// Each `(layer, outer_dim)` view is sent with its own `ncclBcast`, and its `view.size()`
/// is treated as a count of `ncclChar` (1-byte) elements. `layer_range == None` selects
/// every layer, `0..num_layers()`. Nothing here synchronizes `stream`; the broadcasts are
/// only enqueued.
///
/// # Safety
///
/// `comm` and `stream` must be valid NCCL / CUDA handles. The device memory behind every
/// selected view must stay valid until the enqueued broadcasts finish on `stream`. This is
/// an assumption from NCCL's asynchronous model; confirm against callers.
///
/// # Errors
///
/// Fails if a layer view cannot be obtained, or if `ncclBcast` reports an error. Broadcasts
/// for views visited before the failure have already been enqueued.
pub unsafe fn bcast_layer<B>(
    block: &B,
    layer_range: Option<Range<usize>>,
    root: i32,
    comm: ncclComm_t,
    stream: CUstream,
) -> Result<()>
where
    B: BlockDataProvider,
{
    let data = block.block_data();
    let layers = match layer_range {
        Some(range) => range,
        None => 0..data.num_layers(),
    };

    for layer_idx in layers {
        (0..data.num_outer_dims()).try_for_each(|outer_idx| -> Result<()> {
            let view = data
                .layer_view(layer_idx, outer_idx)
                .context("Failed to get layer view")?;
            let buffer = unsafe { view.as_ptr() } as usize as *mut c_void;
            let byte_count = view.size();
            // SAFETY: the validity of `comm`, `stream` and the view's memory is the
            // caller's obligation under this function's `# Safety` contract.
            let status = unsafe {
                ncclBcast(
                    buffer,
                    byte_count,
                    ncclDataType_t::ncclChar,
                    root,
                    comm,
                    stream.cast(),
                )
            };
            check_nccl_result(status).context("ncclBcast failed in layer loop")
        })?;
    }
    Ok(())
}