use std::marker::PhantomData;
use std::sync::Arc;
use cudarc::nccl::sys;
use cudarc::nccl::Comm;
use super::{NcclReduceSupported, LIB};
use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
/// Shadow of the presumed in-memory layout of `cudarc::nccl::Comm`.
///
/// `raw_comm_ptr` casts a `&Comm` to a pointer to this struct and reads `raw_comm`
/// through it; the struct is never constructed in the code shown here.
///
/// NOTE(review): this only works if `Comm` really stores its raw handle first, followed
/// by fields of these types in this order. `Comm`'s definition is not visible from this
/// file, and unless it is declared `#[repr(C)]` Rust does not guarantee its field order.
/// Re-check against the pinned cudarc version on every upgrade.
#[repr(C)]
struct CommLayoutShadow {
// Raw NCCL communicator handle; the only field that `raw_comm_ptr` reads.
raw_comm: sys::ncclComm_t,
// Placeholder for the stream `Arc` presumably held by `Comm`. `ManuallyDrop` inhibits
// the `Arc` destructor, so a shadow value could never release a reference it does not own.
_stream: std::mem::ManuallyDrop<Arc<cudarc::driver::CudaStream>>,
// Placeholders presumably mirroring `Comm`'s rank / world-size fields; never read.
_rank: usize,
_world_size: usize,
}
/// Reads the raw NCCL communicator handle out of a `Comm`.
///
/// The handle is obtained by viewing `comm` through [`CommLayoutShadow`] and copying
/// out its `raw_comm` field. Debug builds additionally check that the shadow struct is
/// no larger than `Comm`; release builds perform no check at all.
fn raw_comm_ptr(comm: &Comm) -> sys::ncclComm_t {
    let shadow_size = std::mem::size_of::<CommLayoutShadow>();
    let comm_size = std::mem::size_of::<Comm>();
    debug_assert!(shadow_size <= comm_size);

    let shadow = (comm as *const Comm).cast::<CommLayoutShadow>();
    // SAFETY: `shadow` is derived from a live `&Comm`, so it is non-null and valid for
    // reads for the duration of this call. Soundness further relies on `Comm` placing
    // its raw handle exactly where `CommLayoutShadow::raw_comm` sits, which is an
    // assumption about cudarc's layout that is not verified here beyond the size check.
    unsafe { (*shadow).raw_comm }
}
/// A custom NCCL reduction operator created with `ncclRedOpCreatePreMulSum`.
///
/// The operator is bound to the communicator it was created on and is released by
/// calling `destroy`. No `Drop` impl is visible in this part of the file, so an
/// operator that is dropped without calling `destroy` is presumably never handed back
/// to NCCL.
pub struct PreMulSumOp<T: NcclReduceSupported> {
// Operator handle filled in by `ncclRedOpCreatePreMulSum`; exposed via `handle()`.
handle: sys::ncclRedOp_t,
// Never read after construction. The operator is created with `ncclScalarDevice`, i.e.
// NCCL is given the device address of this buffer, so it is presumably stored only to
// keep that allocation alive for as long as the operator exists.
#[allow(dead_code)]
scalar: GpuRef<T>,
// Raw communicator the operator was created on; passed back to `ncclRedOpDestroy`.
// NOTE(review): nothing here ties this pointer's validity to the lifetime of the `Comm`.
comm_ptr: sys::ncclComm_t,
// Marker for the scalar element type `T`.
_phantom: PhantomData<T>,
}
// SAFETY — NOTE(review): this asserts that a `PreMulSumOp` may be moved to another
// thread. The raw `comm_ptr` field is presumably what blocks the automatic `Send`
// impl. Nothing visible in this file establishes that using or destroying the operator
// from a different thread than the one that created it is sound; confirm against the
// NCCL threading rules and the `Send`-ness of `GpuRef<T>`.
unsafe impl<T: NcclReduceSupported> Send for PreMulSumOp<T> {}
impl<T: NcclReduceSupported> PreMulSumOp<T> {
    /// Creates a pre-multiplied-sum reduction operator on `comm`.
    ///
    /// The multiplier is passed to NCCL as a device address (`ncclScalarDevice`) taken
    /// from the buffer behind `scalar`. `scalar` is then moved into the returned value.
    ///
    /// # Errors
    ///
    /// - Propagates any error returned by `scalar.access()`.
    /// - Returns `GpuError::Unrecoverable` when the accessed buffer has length zero.
    /// - Returns `GpuError::LibraryError` when `ncclRedOpCreatePreMulSum` reports failure.
    pub fn new(comm: &Comm, scalar: GpuRef<T>) -> Result<Self, GpuError> {
        let comm_ptr = raw_comm_ptr(comm);
        // Placeholder value only; NCCL overwrites it through the out-pointer below.
        let mut handle = sys::ncclRedOp_t::ncclSum;

        // Scoped so the buffer access, the stream and the sync guard are all released
        // before `scalar` is moved into the result.
        {
            let device_buf = scalar.access()?;
            if device_buf.len() == 0 {
                return Err(GpuError::Unrecoverable(
                    "PreMulSumOp scalar buffer is empty".into(),
                ));
            }

            let nccl_stream = comm.stream();
            // `_sync_guard` is deliberately a named binding: it must stay alive until
            // the NCCL call below has returned, and is dropped at the end of this scope.
            let (scalar_addr, _sync_guard) = {
                use cudarc::driver::DevicePtr;
                device_buf.device_ptr(&nccl_stream)
            };

            // SAFETY: `handle` is a live local that NCCL may write to, `scalar_addr` was
            // just obtained from the non-empty buffer accessed above, and `comm_ptr` was
            // read from `comm`, which is borrowed for the whole call. This inherits the
            // layout assumption made by `raw_comm_ptr`.
            let status = unsafe {
                sys::ncclRedOpCreatePreMulSum(
                    &mut handle,
                    scalar_addr as *mut std::ffi::c_void,
                    <T as cudarc::nccl::NcclType>::as_nccl_type(),
                    sys::ncclScalarResidence_t::ncclScalarDevice,
                    comm_ptr,
                )
            };
            if let Err(e) = status.result() {
                return Err(GpuError::LibraryError {
                    lib: LIB,
                    msg: format!("ncclRedOpCreatePreMulSum: {e:?}"),
                });
            }
        }

        Ok(Self {
            handle,
            scalar,
            comm_ptr,
            _phantom: PhantomData,
        })
    }

    /// Returns the raw NCCL operator handle.
    pub fn handle(&self) -> sys::ncclRedOp_t {
        self.handle
    }

    /// Destroys the operator via `ncclRedOpDestroy`, consuming `self`.
    ///
    /// The stored scalar buffer is dropped when this function returns, whether or not
    /// the NCCL call succeeded.
    ///
    /// # Errors
    ///
    /// Returns `GpuError::LibraryError` when `ncclRedOpDestroy` reports failure.
    pub fn destroy(self) -> Result<(), GpuError> {
        // SAFETY: `handle` and `comm_ptr` are the pair produced in `new`. The caller is
        // assumed to still hold the communicator alive; nothing in this type enforces it.
        let status = unsafe { sys::ncclRedOpDestroy(self.handle, self.comm_ptr) };
        match status.result() {
            Ok(_) => Ok(()),
            Err(e) => Err(GpuError::LibraryError {
                lib: LIB,
                msg: format!("ncclRedOpDestroy: {e:?}"),
            }),
        }
    }
}