use super::nexar_compat::{to_nexar_dtype, to_nexar_op};
use super::{Communicator, ReduceOp};
use crate::dtype::DType;
use crate::error::{Error, Result};
fn map_err(e: nexar_nccl::NcclCommError) -> Error {
Error::Backend(format!("hierarchical communicator: {e}"))
}
fn map_nexar_err(e: nexar::NexarError) -> Error {
Error::Backend(format!("hierarchical communicator (nexar): {e}"))
}
/// [`Communicator`] implementation backed by a `nexar_nccl::HierarchicalComm`, whose async
/// operations are driven to completion with `block_on` on an owned Tokio runtime.
pub struct HierarchicalCommunicator {
    /// Underlying hierarchical communicator that performs the collectives.
    comm: nexar_nccl::HierarchicalComm,
    /// Runtime used to `block_on` the communicator's async operations.
    // NOTE(review): fields drop in declaration order, so `comm` is dropped before `rt`;
    // presumably intentional if `comm` needs the runtime during teardown — confirm before
    // reordering these fields.
    rt: tokio::runtime::Runtime,
}
impl HierarchicalCommunicator {
pub fn new(comm: nexar_nccl::HierarchicalComm, rt: tokio::runtime::Runtime) -> Self {
Self { comm, rt }
}
pub fn inner(&self) -> &nexar_nccl::HierarchicalComm {
&self.comm
}
}
impl Communicator for HierarchicalCommunicator {
fn world_size(&self) -> usize {
self.comm.world_size() as usize
}
fn rank(&self) -> usize {
self.comm.rank() as usize
}
unsafe fn all_reduce(&self, ptr: u64, count: usize, dtype: DType, op: ReduceOp) -> Result<()> {
let nd = to_nexar_dtype(dtype)?;
let no = to_nexar_op(op);
self.rt
.block_on(unsafe { self.comm.allreduce(ptr, count, nd, no) })
.map_err(map_err)
}
unsafe fn broadcast(&self, ptr: u64, count: usize, dtype: DType, root: usize) -> Result<()> {
let nd = to_nexar_dtype(dtype)?;
self.rt
.block_on(unsafe { self.comm.broadcast(ptr, count, nd, root as u32) })
.map_err(map_err)
}
unsafe fn all_gather(
&self,
send_ptr: u64,
recv_ptr: u64,
count: usize,
dtype: DType,
) -> Result<()> {
let nd = to_nexar_dtype(dtype)?;
self.rt
.block_on(unsafe { self.comm.allgather(send_ptr, recv_ptr, count, nd) })
.map_err(map_err)
}
unsafe fn reduce_scatter(
&self,
send_ptr: u64,
recv_ptr: u64,
count: usize,
dtype: DType,
op: ReduceOp,
) -> Result<()> {
let nd = to_nexar_dtype(dtype)?;
let no = to_nexar_op(op);
let ws = self.comm.world_size() as usize;
let total = count * ws;
self.rt
.block_on(unsafe { self.comm.allreduce(send_ptr, total, nd, no) })
.map_err(map_err)?;
let elem_size = dtype.size_in_bytes();
let offset = self.comm.rank() as usize * count * elem_size;
let bytes = count * elem_size;
unsafe {
std::ptr::copy_nonoverlapping(
(send_ptr as *const u8).add(offset),
recv_ptr as *mut u8,
bytes,
);
}
Ok(())
}
unsafe fn send(
&self,
ptr: u64,
count: usize,
dtype: DType,
dest: usize,
tag: u32,
) -> Result<()> {
let nd = to_nexar_dtype(dtype)?;
let size = count * nd.size_in_bytes();
self.rt
.block_on(unsafe { self.comm.nexar().send(ptr, size, dest as u32, tag) })
.map_err(map_nexar_err)
}
unsafe fn recv(
&self,
ptr: u64,
count: usize,
dtype: DType,
src: usize,
tag: u32,
) -> Result<()> {
let nd = to_nexar_dtype(dtype)?;
let size = count * nd.size_in_bytes();
self.rt
.block_on(unsafe { self.comm.nexar().recv(ptr, size, src as u32, tag) })
.map_err(map_nexar_err)
}
fn sync(&self) -> Result<()> {
self.comm.synchronize().map_err(map_err)
}
fn barrier(&self) -> Result<()> {
self.rt.block_on(self.comm.barrier()).map_err(map_err)
}
}