use std::ffi::c_void;
use ferrotorch_core::error::{FerrotorchError, FerrotorchResult};
use ferrotorch_core::gpu_dispatch::{self, GpuBufferHandle};
use crate::collective::ReduceOp;
use crate::nccl_backend::{NcclBackend, reduce_op_to_nccl};
use crate::nccl_sys::NcclDataType;
fn infer_dtype(handle: &GpuBufferHandle) -> FerrotorchResult<NcclDataType> {
let backend = gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
match backend.buffer_elem_size(handle) {
4 => Ok(NcclDataType::Float32),
8 => Ok(NcclDataType::Float64),
0 => Err(FerrotorchError::InvalidArgument {
message: "NCCL collective: unrecognized buffer type".into(),
}),
other => Err(FerrotorchError::InvalidArgument {
message: format!("NCCL collective: unsupported element size {other}"),
}),
}
}
fn get_ptr(handle: &GpuBufferHandle) -> FerrotorchResult<*const c_void> {
let backend = gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
let ptr = backend.raw_device_ptr(handle);
if ptr.is_null() {
return Err(FerrotorchError::InvalidArgument {
message: "NCCL collective: buffer has no valid device pointer".into(),
});
}
Ok(ptr)
}
fn get_ptr_mut(handle: &mut GpuBufferHandle) -> FerrotorchResult<*mut c_void> {
let backend = gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
let ptr = backend.raw_device_ptr_mut(handle);
if ptr.is_null() {
return Err(FerrotorchError::InvalidArgument {
message: "NCCL collective: buffer has no valid device pointer".into(),
});
}
Ok(ptr)
}
pub fn nccl_allreduce(
buffer: &mut GpuBufferHandle,
backend: &NcclBackend,
op: &ReduceOp,
) -> FerrotorchResult<()> {
let dtype = infer_dtype(buffer)?;
nccl_allreduce_dtype(buffer, backend, op, dtype)
}
pub fn nccl_allreduce_dtype(
buffer: &mut GpuBufferHandle,
backend: &NcclBackend,
op: &ReduceOp,
dtype: NcclDataType,
) -> FerrotorchResult<()> {
let count = buffer.len();
let nccl_op = reduce_op_to_nccl(op);
let ptr = get_ptr_mut(buffer)?;
unsafe { backend.allreduce_raw(ptr.cast_const(), ptr, count, dtype, nccl_op) }
}
pub fn nccl_broadcast(
buffer: &mut GpuBufferHandle,
backend: &NcclBackend,
root: usize,
) -> FerrotorchResult<()> {
let dtype = infer_dtype(buffer)?;
nccl_broadcast_dtype(buffer, backend, root, dtype)
}
pub fn nccl_broadcast_dtype(
buffer: &mut GpuBufferHandle,
backend: &NcclBackend,
root: usize,
dtype: NcclDataType,
) -> FerrotorchResult<()> {
let count = buffer.len();
let ptr = get_ptr_mut(buffer)?;
unsafe { backend.broadcast_raw(ptr.cast_const(), ptr, count, dtype, root as i32) }
}
pub fn nccl_all_gather(
send_buf: &GpuBufferHandle,
recv_buf: &mut GpuBufferHandle,
backend: &NcclBackend,
) -> FerrotorchResult<()> {
let dtype = infer_dtype(send_buf)?;
nccl_all_gather_dtype(send_buf, recv_buf, backend, dtype)
}
pub fn nccl_all_gather_dtype(
send_buf: &GpuBufferHandle,
recv_buf: &mut GpuBufferHandle,
backend: &NcclBackend,
dtype: NcclDataType,
) -> FerrotorchResult<()> {
let sendcount = send_buf.len();
let send_ptr = get_ptr(send_buf)?;
let recv_ptr = get_ptr_mut(recv_buf)?;
unsafe { backend.all_gather_raw(send_ptr, recv_ptr, sendcount, dtype) }
}
pub fn nccl_reduce_scatter(
send_buf: &GpuBufferHandle,
recv_buf: &mut GpuBufferHandle,
backend: &NcclBackend,
op: &ReduceOp,
) -> FerrotorchResult<()> {
let dtype = infer_dtype(recv_buf)?;
nccl_reduce_scatter_dtype(send_buf, recv_buf, backend, op, dtype)
}
pub fn nccl_reduce_scatter_dtype(
send_buf: &GpuBufferHandle,
recv_buf: &mut GpuBufferHandle,
backend: &NcclBackend,
op: &ReduceOp,
dtype: NcclDataType,
) -> FerrotorchResult<()> {
let recvcount = recv_buf.len();
let nccl_op = reduce_op_to_nccl(op);
let send_ptr = get_ptr(send_buf)?;
let recv_ptr = get_ptr_mut(recv_buf)?;
unsafe { backend.reduce_scatter_raw(send_ptr, recv_ptr, recvcount, dtype, nccl_op) }
}