use super::nexar_compat::{to_nexar_dtype, to_nexar_op};
use super::{Communicator, ReduceOp};
use crate::dtype::DType;
use crate::error::{Error, Result};
fn map_err(e: nexar::NexarError) -> Error {
Error::Backend(format!("nexar: {e}"))
}
/// [`Communicator`] implementation that forwards its operations to a `nexar::SyncClient`.
pub struct NexarNetCommunicator {
/// Client handle that the communicator calls are delegated to.
client: nexar::SyncClient,
}
impl NexarNetCommunicator {
pub fn new(client: nexar::SyncClient) -> Self {
Self { client }
}
}
impl Communicator for NexarNetCommunicator {
fn world_size(&self) -> usize {
self.client.world_size() as usize
}
fn rank(&self) -> usize {
self.client.rank() as usize
}
unsafe fn all_reduce(&self, ptr: u64, count: usize, dtype: DType, op: ReduceOp) -> Result<()> {
let nd = to_nexar_dtype(dtype)?;
let no = to_nexar_op(op);
unsafe { self.client.all_reduce(ptr, count, nd, no).map_err(map_err) }
}
unsafe fn broadcast(&self, ptr: u64, count: usize, dtype: DType, root: usize) -> Result<()> {
let nd = to_nexar_dtype(dtype)?;
unsafe {
self.client
.broadcast(ptr, count, nd, root as u32)
.map_err(map_err)
}
}
unsafe fn all_gather(
&self,
send_ptr: u64,
recv_ptr: u64,
count: usize,
dtype: DType,
) -> Result<()> {
let nd = to_nexar_dtype(dtype)?;
unsafe {
self.client
.all_gather(send_ptr, recv_ptr, count, nd)
.map_err(map_err)
}
}
unsafe fn reduce_scatter(
&self,
send_ptr: u64,
recv_ptr: u64,
count: usize,
dtype: DType,
op: ReduceOp,
) -> Result<()> {
let nd = to_nexar_dtype(dtype)?;
let no = to_nexar_op(op);
unsafe {
self.client
.reduce_scatter(send_ptr, recv_ptr, count, nd, no)
.map_err(map_err)
}
}
unsafe fn send(
&self,
ptr: u64,
count: usize,
dtype: DType,
dest: usize,
tag: u32,
) -> Result<()> {
let nd = to_nexar_dtype(dtype)?;
let size = count * nd.size_in_bytes();
unsafe {
self.client
.send(ptr, size, dest as u32, tag)
.map_err(map_err)
}
}
unsafe fn recv(
&self,
ptr: u64,
count: usize,
dtype: DType,
src: usize,
tag: u32,
) -> Result<()> {
let nd = to_nexar_dtype(dtype)?;
let size = count * nd.size_in_bytes();
unsafe {
self.client
.recv(ptr, size, src as u32, tag)
.map_err(map_err)
}
}
fn sync(&self) -> Result<()> {
Ok(())
}
fn barrier(&self) -> Result<()> {
self.client.barrier().map_err(map_err)
}
fn split(&self, color: u32, key: u32) -> Result<Option<Box<dyn Communicator>>> {
let sub = self.client.split(color, key).map_err(map_err)?;
Ok(Some(Box::new(NexarNetCommunicator::new(sub))))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Bootstraps two in-process clients on the CPU adapter and wraps each in a communicator.
    fn two_local_communicators() -> Vec<NexarNetCommunicator> {
        let cpu = std::sync::Arc::new(nexar::CpuAdapter::new());
        let clients = nexar::SyncClient::bootstrap_local(2, cpu).unwrap();
        clients.into_iter().map(NexarNetCommunicator::new).collect()
    }

    /// World size and per-client ranks match the bootstrap order.
    #[test]
    fn test_nexar_communicator_metadata() {
        let comms = two_local_communicators();
        assert_eq!(comms[0].world_size(), 2);
        assert_eq!(comms[0].rank(), 0);
        assert_eq!(comms[1].rank(), 1);
    }

    /// Summing `[1.0; 4]` and `[2.0; 4]` across two ranks leaves `[3.0; 4]` on both.
    #[test]
    fn test_nexar_allreduce_f32() {
        let comms = two_local_communicators();
        std::thread::scope(|scope| {
            // Spawn every participant before joining any of them, since each
            // one calls the collective.
            let mut workers = Vec::with_capacity(comms.len());
            for (index, comm) in comms.iter().enumerate() {
                workers.push(scope.spawn(move || {
                    let mut buffer = vec![(index + 1) as f32; 4];
                    let addr = buffer.as_mut_ptr() as u64;
                    // SAFETY: `addr` points at `buffer`, which holds exactly the
                    // 4 `f32` elements named in the call and outlives it.
                    unsafe {
                        comm.all_reduce(addr, 4, DType::F32, ReduceOp::Sum).unwrap();
                    }
                    buffer
                }));
            }
            for worker in workers {
                let reduced = worker.join().unwrap();
                assert_eq!(reduced, vec![3.0f32; 4]);
            }
        });
    }
}