use crate::compression::Compressor;
use crate::device::DeviceAdapter;
use crate::error::Result;
use crate::memory::{BufferRef, Device, Host};
use crate::rpc::registry::RpcHandler;
use crate::types::{DataType, Rank, ReduceOp};
use std::sync::Arc;
/// Blocking facade over the async `NexarClient`.
///
/// Every method on this type builds the corresponding future on `inner` and
/// drives it to completion with `rt.block_on(..)`, so callers never need an
/// async context of their own.
///
/// Field order is significant: Rust drops struct fields in declaration order,
/// so `inner` is dropped while `rt` is still alive. Do not reorder the fields
/// without checking that the client can be torn down without its runtime.
pub struct SyncClient {
/// The wrapped async client that performs the actual work.
inner: super::NexarClient,
/// Tokio runtime owned by this wrapper and used to block on `inner`'s futures.
rt: tokio::runtime::Runtime,
}
/// Generates a blocking `unsafe fn` that forwards to the async method of the
/// same name on `self.inner`.
///
/// Input is a signature of the form
/// `vis fn name(&self, arg: Ty, ...) -> Result<()>`. The expansion is
/// `vis unsafe fn name(&self, arg: Ty, ...) -> Result<()>` whose body creates
/// the inner future inside an `unsafe` block and drives it to completion with
/// `self.rt.block_on`. Attributes (including doc comments) written on the
/// input signature are copied onto the generated function, followed by a
/// generated `# Safety` section so that every emitted `unsafe fn` documents
/// its contract.
macro_rules! sync_unsafe {
    (
        $(#[$meta:meta])*
        $vis:vis fn $name:ident(&self $(, $arg:ident: $ty:ty)*) -> Result<()>
    ) => {
        $(#[$meta])*
        ///
        /// # Safety
        ///
        /// This is a blocking wrapper that forwards its arguments unchanged to
        /// the `unsafe` async method of the same name on the wrapped client.
        /// The caller must uphold every requirement of that method for the
        /// whole duration of this call.
        $vis unsafe fn $name(&self $(, $arg: $ty)*) -> Result<()> {
            // SAFETY: the generated function is itself `unsafe`, so the caller
            // has already promised to satisfy the inner method's contract, and
            // the arguments are passed through untouched.
            let pending = unsafe { self.inner.$name($($arg),*) };
            self.rt.block_on(pending)
        }
    };
}
/// Generates a blocking, safe method that forwards to the async method of the
/// same name on `self.inner`, for the `BufferRef`-based API.
///
/// Two signature shapes are accepted:
///
/// * one in-place buffer: `fn name(&self, buf: &mut BufferRef<L>, extra...)`
/// * a send/receive pair:
///   `fn name(&self, send_buf: &BufferRef<L>, recv_buf: &mut BufferRef<L2>, extra...)`
///
/// In both cases the generated body creates the inner future with the same
/// arguments and drives it to completion with `self.rt.block_on`. Attributes
/// (including doc comments) on the input signature are copied to the output.
macro_rules! sync_typed {
    // Shape 1: a single buffer that is passed mutably.
    (
        $(#[$attr:meta])*
        $vis:vis fn $method:ident(
            &self,
            buf: &mut BufferRef<$space:ty>
            $(, $param:ident: $param_ty:ty)*
        ) -> Result<()>
    ) => {
        $(#[$attr])*
        $vis fn $method(
            &self,
            buf: &mut BufferRef<$space>
            $(, $param: $param_ty)*
        ) -> Result<()> {
            let pending = self.inner.$method(buf $(, $param)*);
            self.rt.block_on(pending)
        }
    };
    // Shape 2: a shared send buffer plus a mutable receive buffer.
    (
        $(#[$attr:meta])*
        $vis:vis fn $method:ident(
            &self,
            send_buf: &BufferRef<$send_space:ty>,
            recv_buf: &mut BufferRef<$recv_space:ty>
            $(, $param:ident: $param_ty:ty)*
        ) -> Result<()>
    ) => {
        $(#[$attr])*
        $vis fn $method(
            &self,
            send_buf: &BufferRef<$send_space>,
            recv_buf: &mut BufferRef<$recv_space>
            $(, $param: $param_ty)*
        ) -> Result<()> {
            let pending = self.inner.$method(send_buf, recv_buf $(, $param)*);
            self.rt.block_on(pending)
        }
    };
}
impl SyncClient {
pub fn bootstrap_local(world_size: u32, adapter: Arc<dyn DeviceAdapter>) -> Result<Vec<Self>> {
let rt = tokio::runtime::Runtime::new()
.map_err(|e| crate::error::NexarError::transport_with_source("tokio runtime", e))?;
let clients = rt.block_on(super::NexarClient::bootstrap_local(world_size, adapter))?;
let mut sync_clients = Vec::new();
let mut iter = clients.into_iter();
if let Some(first) = iter.next() {
sync_clients.push(SyncClient { inner: first, rt });
}
for client in iter {
let rt = tokio::runtime::Runtime::new()
.map_err(|e| crate::error::NexarError::transport_with_source("tokio runtime", e))?;
sync_clients.push(SyncClient { inner: client, rt });
}
Ok(sync_clients)
}
pub fn from_async(inner: super::NexarClient) -> Result<Self> {
let rt = tokio::runtime::Runtime::new()
.map_err(|e| crate::error::NexarError::transport_with_source("tokio runtime", e))?;
Ok(Self { inner, rt })
}
pub fn rank(&self) -> Rank {
self.inner.rank()
}
pub fn world_size(&self) -> u32 {
self.inner.world_size()
}
sync_unsafe! {
pub fn all_reduce(&self, ptr: u64, count: usize, dtype: DataType, op: ReduceOp) -> Result<()>
}
sync_unsafe! {
pub fn all_reduce_bucketed(&self, entries: &[(u64, usize)], dtype: DataType, op: ReduceOp) -> Result<()>
}
sync_unsafe! {
pub fn all_reduce_rs_ag(&self, ptr: u64, count: usize, dtype: DataType, op: ReduceOp) -> Result<()>
}
sync_unsafe! {
pub fn broadcast(&self, ptr: u64, count: usize, dtype: DataType, root: Rank) -> Result<()>
}
sync_unsafe! {
pub fn all_gather(&self, send_ptr: u64, recv_ptr: u64, count: usize, dtype: DataType) -> Result<()>
}
sync_unsafe! {
pub fn reduce_scatter(&self, send_ptr: u64, recv_ptr: u64, count: usize, dtype: DataType, op: ReduceOp) -> Result<()>
}
sync_unsafe! {
pub fn reduce(&self, ptr: u64, count: usize, dtype: DataType, op: ReduceOp, root: Rank) -> Result<()>
}
sync_unsafe! {
pub fn all_to_all(&self, send_ptr: u64, recv_ptr: u64, count: usize, dtype: DataType) -> Result<()>
}
sync_unsafe! {
pub fn gather(&self, send_ptr: u64, recv_ptr: u64, count: usize, dtype: DataType, root: Rank) -> Result<()>
}
sync_unsafe! {
pub fn scatter(&self, send_ptr: u64, recv_ptr: u64, count: usize, dtype: DataType, root: Rank) -> Result<()>
}
sync_unsafe! {
pub fn exclusive_scan(&self, ptr: u64, count: usize, dtype: DataType, op: ReduceOp) -> Result<()>
}
sync_unsafe! {
pub fn scan(&self, ptr: u64, count: usize, dtype: DataType, op: ReduceOp) -> Result<()>
}
sync_unsafe! {
pub fn send(&self, ptr: u64, size: usize, dest: Rank, tag: u32) -> Result<()>
}
sync_unsafe! {
pub fn recv(&self, ptr: u64, size: usize, src: Rank, tag: u32) -> Result<()>
}
pub fn barrier(&self) -> Result<()> {
self.rt.block_on(self.inner.barrier())
}
pub fn split(&self, color: u32, key: u32) -> Result<SyncClient> {
let inner = self.rt.block_on(self.inner.split(color, key))?;
SyncClient::from_async(inner)
}
sync_typed! {
pub fn all_reduce_host(&self, buf: &mut BufferRef<Host>, count: usize, dtype: DataType, op: ReduceOp) -> Result<()>
}
sync_typed! {
pub fn broadcast_host(&self, buf: &mut BufferRef<Host>, count: usize, dtype: DataType, root: Rank) -> Result<()>
}
sync_typed! {
pub fn all_gather_host(&self, send_buf: &BufferRef<Host>, recv_buf: &mut BufferRef<Host>, count: usize, dtype: DataType) -> Result<()>
}
sync_typed! {
pub fn reduce_scatter_host(&self, send_buf: &BufferRef<Host>, recv_buf: &mut BufferRef<Host>, count: usize, dtype: DataType, op: ReduceOp) -> Result<()>
}
sync_typed! {
pub fn reduce_host(&self, buf: &mut BufferRef<Host>, count: usize, dtype: DataType, op: ReduceOp, root: Rank) -> Result<()>
}
sync_typed! {
pub fn all_to_all_host(&self, send_buf: &BufferRef<Host>, recv_buf: &mut BufferRef<Host>, count: usize, dtype: DataType) -> Result<()>
}
sync_typed! {
pub fn gather_host(&self, send_buf: &BufferRef<Host>, recv_buf: &mut BufferRef<Host>, count: usize, dtype: DataType, root: Rank) -> Result<()>
}
sync_typed! {
pub fn scatter_host(&self, send_buf: &BufferRef<Host>, recv_buf: &mut BufferRef<Host>, count: usize, dtype: DataType, root: Rank) -> Result<()>
}
sync_typed! {
pub fn scan_host(&self, buf: &mut BufferRef<Host>, count: usize, dtype: DataType, op: ReduceOp) -> Result<()>
}
sync_typed! {
pub fn exclusive_scan_host(&self, buf: &mut BufferRef<Host>, count: usize, dtype: DataType, op: ReduceOp) -> Result<()>
}
pub fn all_reduce_compressed_host(
&self,
buf: &mut BufferRef<Host>,
count: usize,
dtype: DataType,
op: ReduceOp,
compressor: &dyn Compressor,
residual: &mut [u8],
) -> Result<()> {
self.rt.block_on(
self.inner
.all_reduce_compressed_host(buf, count, dtype, op, compressor, residual),
)
}
sync_typed! {
pub fn all_reduce_device(&self, buf: &mut BufferRef<Device>, count: usize, dtype: DataType, op: ReduceOp) -> Result<()>
}
sync_typed! {
pub fn broadcast_device(&self, buf: &mut BufferRef<Device>, count: usize, dtype: DataType, root: Rank) -> Result<()>
}
sync_typed! {
pub fn all_gather_device(&self, send_buf: &BufferRef<Device>, recv_buf: &mut BufferRef<Device>, count: usize, dtype: DataType) -> Result<()>
}
sync_typed! {
pub fn reduce_scatter_device(&self, send_buf: &BufferRef<Device>, recv_buf: &mut BufferRef<Device>, count: usize, dtype: DataType, op: ReduceOp) -> Result<()>
}
sync_typed! {
pub fn reduce_device(&self, buf: &mut BufferRef<Device>, count: usize, dtype: DataType, op: ReduceOp, root: Rank) -> Result<()>
}
sync_typed! {
pub fn all_to_all_device(&self, send_buf: &BufferRef<Device>, recv_buf: &mut BufferRef<Device>, count: usize, dtype: DataType) -> Result<()>
}
sync_typed! {
pub fn gather_device(&self, send_buf: &BufferRef<Device>, recv_buf: &mut BufferRef<Device>, count: usize, dtype: DataType, root: Rank) -> Result<()>
}
sync_typed! {
pub fn scatter_device(&self, send_buf: &BufferRef<Device>, recv_buf: &mut BufferRef<Device>, count: usize, dtype: DataType, root: Rank) -> Result<()>
}
sync_typed! {
pub fn scan_device(&self, buf: &mut BufferRef<Device>, count: usize, dtype: DataType, op: ReduceOp) -> Result<()>
}
sync_typed! {
pub fn exclusive_scan_device(&self, buf: &mut BufferRef<Device>, count: usize, dtype: DataType, op: ReduceOp) -> Result<()>
}
pub fn all_reduce_compressed_device(
&self,
buf: &mut BufferRef<Device>,
count: usize,
dtype: DataType,
op: ReduceOp,
compressor: &dyn Compressor,
residual: &mut [u8],
) -> Result<()> {
self.rt.block_on(
self.inner
.all_reduce_compressed_device(buf, count, dtype, op, compressor, residual),
)
}
pub fn register_rpc(&self, fn_id: u16, handler: RpcHandler) {
self.rt.block_on(self.inner.register_rpc(fn_id, handler))
}
pub fn rpc(&self, target: Rank, fn_id: u16, args: &[u8]) -> Result<Vec<u8>> {
self.rt.block_on(self.inner.rpc(target, fn_id, args))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::CpuAdapter;

    /// Bootstrapping a world of size one on the CPU adapter yields exactly one
    /// client, which reports rank 0 and world size 1.
    #[test]
    fn test_sync_client_single_node() {
        let cpu = Arc::new(CpuAdapter::new());
        let group = SyncClient::bootstrap_local(1, cpu).unwrap();
        assert_eq!(group.len(), 1);

        let only = &group[0];
        assert_eq!(only.rank(), 0);
        assert_eq!(only.world_size(), 1);
    }
}