use crate::compression::Compressor;
use crate::error::Result;
use crate::memory::{BufferRef, Device, Host};
use crate::types::{DataType, Rank, ReduceOp};
use super::NexarClient;
impl NexarClient {
pub async fn all_reduce_host(
&self,
buf: &mut BufferRef<Host>,
count: usize,
dtype: DataType,
op: ReduceOp,
) -> Result<()> {
unsafe { self.all_reduce(buf.as_u64(), count, dtype, op).await }
}
pub async fn broadcast_host(
&self,
buf: &mut BufferRef<Host>,
count: usize,
dtype: DataType,
root: Rank,
) -> Result<()> {
unsafe { self.broadcast(buf.as_u64(), count, dtype, root).await }
}
pub async fn all_gather_host(
&self,
send_buf: &BufferRef<Host>,
recv_buf: &mut BufferRef<Host>,
count: usize,
dtype: DataType,
) -> Result<()> {
unsafe {
self.all_gather(send_buf.as_u64(), recv_buf.as_u64(), count, dtype)
.await
}
}
pub async fn reduce_scatter_host(
&self,
send_buf: &BufferRef<Host>,
recv_buf: &mut BufferRef<Host>,
count: usize,
dtype: DataType,
op: ReduceOp,
) -> Result<()> {
unsafe {
self.reduce_scatter(send_buf.as_u64(), recv_buf.as_u64(), count, dtype, op)
.await
}
}
pub async fn reduce_host(
&self,
buf: &mut BufferRef<Host>,
count: usize,
dtype: DataType,
op: ReduceOp,
root: Rank,
) -> Result<()> {
unsafe { self.reduce(buf.as_u64(), count, dtype, op, root).await }
}
pub async fn all_to_all_host(
&self,
send_buf: &BufferRef<Host>,
recv_buf: &mut BufferRef<Host>,
count: usize,
dtype: DataType,
) -> Result<()> {
unsafe {
self.all_to_all(send_buf.as_u64(), recv_buf.as_u64(), count, dtype)
.await
}
}
pub async fn gather_host(
&self,
send_buf: &BufferRef<Host>,
recv_buf: &mut BufferRef<Host>,
count: usize,
dtype: DataType,
root: Rank,
) -> Result<()> {
unsafe {
self.gather(send_buf.as_u64(), recv_buf.as_u64(), count, dtype, root)
.await
}
}
pub async fn scatter_host(
&self,
send_buf: &BufferRef<Host>,
recv_buf: &mut BufferRef<Host>,
count: usize,
dtype: DataType,
root: Rank,
) -> Result<()> {
unsafe {
self.scatter(send_buf.as_u64(), recv_buf.as_u64(), count, dtype, root)
.await
}
}
pub async fn scan_host(
&self,
buf: &mut BufferRef<Host>,
count: usize,
dtype: DataType,
op: ReduceOp,
) -> Result<()> {
unsafe { self.scan(buf.as_u64(), count, dtype, op).await }
}
pub async fn exclusive_scan_host(
&self,
buf: &mut BufferRef<Host>,
count: usize,
dtype: DataType,
op: ReduceOp,
) -> Result<()> {
unsafe { self.exclusive_scan(buf.as_u64(), count, dtype, op).await }
}
pub async fn all_reduce_bucketed_host(
&self,
entries: &[(BufferRef<Host>, usize)],
dtype: DataType,
op: ReduceOp,
) -> Result<()> {
let raw: Vec<(u64, usize)> = entries
.iter()
.map(|(b, count)| (b.as_u64(), *count))
.collect();
unsafe { self.all_reduce_bucketed(&raw, dtype, op).await }
}
pub async fn all_reduce_compressed_host(
&self,
buf: &mut BufferRef<Host>,
count: usize,
dtype: DataType,
op: ReduceOp,
compressor: &dyn Compressor,
residual: &mut [u8],
) -> Result<()> {
unsafe {
self.all_reduce_compressed(buf.as_u64(), count, dtype, op, compressor, residual)
.await
}
}
}
impl NexarClient {
    /// In-place all-reduce over a device buffer: delegates to the raw,
    /// address-based `all_reduce` with `count` elements of `dtype`, combined
    /// with `op` across ranks.
    pub async fn all_reduce_device(
        &self,
        buf: &mut BufferRef<Device>,
        count: usize,
        dtype: DataType,
        op: ReduceOp,
    ) -> Result<()> {
        let addr = buf.as_u64();
        // SAFETY: `addr` comes from a live, exclusively borrowed
        // BufferRef<Device>; presumed valid for `count` elements of `dtype` —
        // TODO(review): confirm BufferRef's sizing contract.
        unsafe { self.all_reduce(addr, count, dtype, op).await }
    }

    /// Broadcast of a device buffer from rank `root` via the raw,
    /// address-based `broadcast`.
    pub async fn broadcast_device(
        &self,
        buf: &mut BufferRef<Device>,
        count: usize,
        dtype: DataType,
        root: Rank,
    ) -> Result<()> {
        let addr = buf.as_u64();
        // SAFETY: `addr` comes from a live, exclusively borrowed
        // BufferRef<Device>; presumed valid for `count` elements of `dtype`.
        unsafe { self.broadcast(addr, count, dtype, root).await }
    }

    /// All-gather between device buffers: reads from `send_buf`, writes the
    /// gathered result into `recv_buf` via the raw `all_gather`.
    pub async fn all_gather_device(
        &self,
        send_buf: &BufferRef<Device>,
        recv_buf: &mut BufferRef<Device>,
        count: usize,
        dtype: DataType,
    ) -> Result<()> {
        let send = send_buf.as_u64();
        let recv = recv_buf.as_u64();
        // SAFETY: both addresses come from live BufferRef<Device> values,
        // the receive side exclusively borrowed; presumed adequately sized
        // for the raw collective — TODO(review): confirm recv capacity.
        unsafe { self.all_gather(send, recv, count, dtype).await }
    }

    /// Reduce-scatter between device buffers via the raw `reduce_scatter`,
    /// combining with `op`.
    pub async fn reduce_scatter_device(
        &self,
        send_buf: &BufferRef<Device>,
        recv_buf: &mut BufferRef<Device>,
        count: usize,
        dtype: DataType,
        op: ReduceOp,
    ) -> Result<()> {
        let send = send_buf.as_u64();
        let recv = recv_buf.as_u64();
        // SAFETY: both addresses come from live BufferRef<Device> values,
        // the receive side exclusively borrowed; presumed adequately sized.
        unsafe { self.reduce_scatter(send, recv, count, dtype, op).await }
    }

    /// In-place reduce of a device buffer toward rank `root` via the raw
    /// `reduce`, combining with `op`.
    pub async fn reduce_device(
        &self,
        buf: &mut BufferRef<Device>,
        count: usize,
        dtype: DataType,
        op: ReduceOp,
        root: Rank,
    ) -> Result<()> {
        let addr = buf.as_u64();
        // SAFETY: `addr` comes from a live, exclusively borrowed
        // BufferRef<Device>; presumed valid for `count` elements of `dtype`.
        unsafe { self.reduce(addr, count, dtype, op, root).await }
    }

    /// All-to-all exchange between device buffers via the raw `all_to_all`.
    pub async fn all_to_all_device(
        &self,
        send_buf: &BufferRef<Device>,
        recv_buf: &mut BufferRef<Device>,
        count: usize,
        dtype: DataType,
    ) -> Result<()> {
        let send = send_buf.as_u64();
        let recv = recv_buf.as_u64();
        // SAFETY: both addresses come from live BufferRef<Device> values,
        // the receive side exclusively borrowed; presumed adequately sized.
        unsafe { self.all_to_all(send, recv, count, dtype).await }
    }

    /// Gather device buffers to rank `root` via the raw `gather`.
    pub async fn gather_device(
        &self,
        send_buf: &BufferRef<Device>,
        recv_buf: &mut BufferRef<Device>,
        count: usize,
        dtype: DataType,
        root: Rank,
    ) -> Result<()> {
        let send = send_buf.as_u64();
        let recv = recv_buf.as_u64();
        // SAFETY: both addresses come from live BufferRef<Device> values,
        // the receive side exclusively borrowed; presumed adequately sized.
        unsafe { self.gather(send, recv, count, dtype, root).await }
    }

    /// Scatter device data from rank `root` via the raw `scatter`.
    pub async fn scatter_device(
        &self,
        send_buf: &BufferRef<Device>,
        recv_buf: &mut BufferRef<Device>,
        count: usize,
        dtype: DataType,
        root: Rank,
    ) -> Result<()> {
        let send = send_buf.as_u64();
        let recv = recv_buf.as_u64();
        // SAFETY: both addresses come from live BufferRef<Device> values,
        // the receive side exclusively borrowed; presumed adequately sized.
        unsafe { self.scatter(send, recv, count, dtype, root).await }
    }

    /// In-place (inclusive) scan over a device buffer via the raw `scan`,
    /// combining with `op`.
    pub async fn scan_device(
        &self,
        buf: &mut BufferRef<Device>,
        count: usize,
        dtype: DataType,
        op: ReduceOp,
    ) -> Result<()> {
        let addr = buf.as_u64();
        // SAFETY: `addr` comes from a live, exclusively borrowed
        // BufferRef<Device>; presumed valid for `count` elements of `dtype`.
        unsafe { self.scan(addr, count, dtype, op).await }
    }

    /// In-place exclusive scan over a device buffer via the raw
    /// `exclusive_scan`, combining with `op`.
    pub async fn exclusive_scan_device(
        &self,
        buf: &mut BufferRef<Device>,
        count: usize,
        dtype: DataType,
        op: ReduceOp,
    ) -> Result<()> {
        let addr = buf.as_u64();
        // SAFETY: `addr` comes from a live, exclusively borrowed
        // BufferRef<Device>; presumed valid for `count` elements of `dtype`.
        unsafe { self.exclusive_scan(addr, count, dtype, op).await }
    }

    /// Compressed in-place all-reduce over a device buffer via the raw
    /// `all_reduce_compressed`. `residual` is presumably error-feedback state
    /// maintained across calls for `compressor` — TODO(review): confirm.
    /// NOTE(review): unlike the host API, there is no bucketed device
    /// variant here — confirm whether that asymmetry is intentional.
    pub async fn all_reduce_compressed_device(
        &self,
        buf: &mut BufferRef<Device>,
        count: usize,
        dtype: DataType,
        op: ReduceOp,
        compressor: &dyn Compressor,
        residual: &mut [u8],
    ) -> Result<()> {
        let addr = buf.as_u64();
        // SAFETY: `addr` comes from a live, exclusively borrowed
        // BufferRef<Device>; presumed valid for `count` elements of `dtype`.
        unsafe {
            self.all_reduce_compressed(addr, count, dtype, op, compressor, residual)
                .await
        }
    }
}