use crate::error::{Error, Result};
use numr::dtype::DType;
use numr::runtime::{Communicator, ReduceOp, Runtime};
use numr::tensor::Tensor;
pub fn all_reduce_tensor<R: Runtime<DType = DType>>(
comm: &dyn Communicator,
tensor: &Tensor<R>,
op: ReduceOp,
) -> Result<()> {
if !tensor.is_contiguous() {
return Err(Error::DistributedError {
reason: "all_reduce requires contiguous tensor".to_string(),
});
}
let ptr = tensor.ptr();
let count = tensor.numel();
let dtype = tensor.dtype();
unsafe {
comm.all_reduce(ptr, count, dtype, op)
.map_err(|e| Error::DistributedError {
reason: format!("all_reduce failed: {e}"),
})?;
}
Ok(())
}
pub fn send_tensor<R: Runtime<DType = DType>>(
comm: &dyn Communicator,
tensor: &Tensor<R>,
dest: usize,
tag: u32,
) -> Result<()> {
if !tensor.is_contiguous() {
return Err(Error::DistributedError {
reason: "send requires contiguous tensor".to_string(),
});
}
let ptr = tensor.ptr();
let count = tensor.numel();
let dtype = tensor.dtype();
unsafe {
comm.send(ptr, count, dtype, dest, tag)
.map_err(|e| Error::DistributedError {
reason: format!("send to rank {dest} failed: {e}"),
})?;
}
Ok(())
}
pub fn recv_into_tensor<R: Runtime<DType = DType>>(
comm: &dyn Communicator,
buffer: &Tensor<R>,
src: usize,
tag: u32,
) -> Result<()> {
if !buffer.is_contiguous() {
return Err(Error::DistributedError {
reason: "recv buffer must be contiguous".to_string(),
});
}
let ptr = buffer.ptr();
let count = buffer.numel();
let dtype = buffer.dtype();
unsafe {
comm.recv(ptr, count, dtype, src, tag)
.map_err(|e| Error::DistributedError {
reason: format!("recv from rank {src} failed: {e}"),
})?;
}
Ok(())
}
pub fn broadcast_tensor<R: Runtime<DType = DType>>(
comm: &dyn Communicator,
tensor: &Tensor<R>,
root: usize,
) -> Result<()> {
if !tensor.is_contiguous() {
return Err(Error::DistributedError {
reason: "broadcast requires contiguous tensor".to_string(),
});
}
let ptr = tensor.ptr();
let count = tensor.numel();
let dtype = tensor.dtype();
unsafe {
comm.broadcast(ptr, count, dtype, root)
.map_err(|e| Error::DistributedError {
reason: format!("broadcast failed: {e}"),
})?;
}
Ok(())
}
/// Reduce `send` across ranks with `op` and scatter the result, writing this
/// rank's shard into `recv`.
///
/// The element count passed to the communicator is `recv.numel()` (the
/// per-rank shard size) and the dtype is taken from `send`.
/// NOTE(review): presumably `send.numel()` must equal
/// `world_size * recv.numel()` — confirm against the communicator contract.
///
/// # Errors
///
/// Returns [`Error::DistributedError`] if either tensor is not contiguous or
/// the communicator reports a failure.
pub fn reduce_scatter_tensor<R: Runtime<DType = DType>>(
    comm: &dyn Communicator,
    send: &Tensor<R>,
    recv: &Tensor<R>,
    op: ReduceOp,
) -> Result<()> {
    if !send.is_contiguous() {
        return Err(Error::DistributedError {
            reason: "reduce_scatter send tensor must be contiguous".to_string(),
        });
    }
    if !recv.is_contiguous() {
        return Err(Error::DistributedError {
            reason: "reduce_scatter recv tensor must be contiguous".to_string(),
        });
    }
    // SAFETY: both tensors are contiguous, so their pointers address their
    // full buffers; `recv.numel()` elements are written into `recv`.
    unsafe { comm.reduce_scatter(send.ptr(), recv.ptr(), recv.numel(), send.dtype(), op) }
        .map_err(|e| Error::DistributedError {
            reason: format!("reduce_scatter failed: {e}"),
        })?;
    Ok(())
}
/// Send `tensor` to rank `dest` preceded by a metadata header describing its
/// shape and dtype, so the receiver can allocate before receiving the data.
///
/// Wire protocol (mirrored by `recv_tensor_with_metadata`):
/// 1. A fixed-size header of 10 u64 words on `tag`:
///    `[ndim, dim_0 .. dim_{ndim-1}, dtype discriminant, zero padding]`.
/// 2. The flat tensor data on `tag + 1`.
///
/// BUG FIX: the header is now always exactly 10 words. Previously only
/// `ndim + 2` words were sent while the receiver posts a fixed 10-word recv,
/// so the two sides' counts disagreed for any tensor with fewer than 8 dims.
///
/// # Errors
///
/// Returns [`Error::DistributedError`] if the tensor is not contiguous, has
/// more than 8 dimensions (header capacity), or the communicator fails.
pub fn send_tensor_with_metadata<R: Runtime<DType = DType>>(
    comm: &dyn Communicator,
    tensor: &Tensor<R>,
    dest: usize,
    tag: u32,
) -> Result<()> {
    // Must match MAX_HEADER in recv_tensor_with_metadata: the receiver posts
    // a fixed-size recv, so the sender must transmit exactly this many words.
    const HEADER_LEN: usize = 10;
    if !tensor.is_contiguous() {
        return Err(Error::DistributedError {
            reason: "send requires contiguous tensor".to_string(),
        });
    }
    let shape = tensor.shape();
    let ndim = shape.len();
    let dtype = tensor.dtype();
    // Fail fast instead of sending a header the receiver would reject.
    if ndim == 0 || ndim > HEADER_LEN - 2 {
        return Err(Error::DistributedError {
            reason: format!("cannot send tensor with {ndim} dims (max 8 dims)"),
        });
    }
    // Fixed-size, zero-padded header: [ndim, dims..., dtype, 0...].
    let mut header = [0u64; HEADER_LEN];
    header[0] = ndim as u64;
    for (i, &d) in shape.iter().enumerate() {
        header[i + 1] = d as u64;
    }
    header[ndim + 1] = dtype_to_u64(dtype);
    // SAFETY: `header` is a live stack array of exactly HEADER_LEN u64s, and
    // it outlives the sync() below that completes the transfer.
    unsafe {
        comm.send(header.as_ptr() as u64, HEADER_LEN, DType::U64, dest, tag)
            .map_err(|e| Error::DistributedError {
                reason: format!("send header to rank {dest} failed: {e}"),
            })?;
    }
    comm.sync().map_err(|e| Error::DistributedError {
        reason: format!("sync after header send failed: {e}"),
    })?;
    let ptr = tensor.ptr();
    let count = tensor.numel();
    // SAFETY: contiguity was checked above, so `ptr` spans `count` elements
    // of `dtype`.
    unsafe {
        comm.send(ptr, count, dtype, dest, tag + 1)
            .map_err(|e| Error::DistributedError {
                reason: format!("send data to rank {dest} failed: {e}"),
            })?;
    }
    comm.sync().map_err(|e| Error::DistributedError {
        reason: format!("sync after data send failed: {e}"),
    })?;
    Ok(())
}
/// Receive a tensor of unknown shape and dtype from rank `src`, allocating
/// the buffer on `device` from the metadata header.
///
/// Wire protocol (mirrors `send_tensor_with_metadata`): a fixed 10-word u64
/// header `[ndim, dims..., dtype discriminant, padding]` on `tag`, followed
/// by the flat tensor data on `tag + 1`.
///
/// # Errors
///
/// Returns [`Error::DistributedError`] if the header is malformed (ndim of 0
/// or more than 8, unknown dtype discriminant) or the communicator fails.
pub fn recv_tensor_with_metadata<R: Runtime<DType = DType>>(
    comm: &dyn Communicator,
    src: usize,
    tag: u32,
    device: &R::Device,
) -> Result<Tensor<R>> {
    // Fixed header size; allows up to MAX_HEADER - 2 = 8 dimensions.
    const MAX_HEADER: usize = 10;
    let mut header_buf = [0u64; MAX_HEADER];
    // SAFETY: `header_buf` is a live stack array of exactly MAX_HEADER u64s.
    unsafe {
        comm.recv(
            header_buf.as_mut_ptr() as u64,
            MAX_HEADER,
            DType::U64,
            src,
            tag,
        )
        .map_err(|e| Error::DistributedError {
            reason: format!("recv header from rank {src} failed: {e}"),
        })?;
    }
    comm.sync().map_err(|e| Error::DistributedError {
        reason: format!("sync after header recv failed: {e}"),
    })?;
    let ndim = header_buf[0] as usize;
    // BUG FIX: validate with the subtraction on the constant side. The old
    // check `ndim + 2 > MAX_HEADER` could overflow on a garbage header (huge
    // ndim): in release builds the addition wraps, the check passes, and the
    // slice below panics instead of returning this error.
    if ndim == 0 || ndim > MAX_HEADER - 2 {
        return Err(Error::DistributedError {
            reason: format!("invalid ndim {ndim} in recv header (max 8 dims)"),
        });
    }
    let shape: Vec<usize> = header_buf[1..=ndim].iter().map(|&d| d as usize).collect();
    let dtype = u64_to_dtype(header_buf[ndim + 1])?;
    let buffer = Tensor::<R>::zeros(&shape, dtype, device);
    let ptr = buffer.ptr();
    let count = buffer.numel();
    // SAFETY: `buffer` was just allocated with `count` contiguous elements
    // of `dtype`, and it outlives the sync() that completes the transfer.
    unsafe {
        comm.recv(ptr, count, dtype, src, tag + 1)
            .map_err(|e| Error::DistributedError {
                reason: format!("recv data from rank {src} failed: {e}"),
            })?;
    }
    comm.sync().map_err(|e| Error::DistributedError {
        reason: format!("sync after data recv failed: {e}"),
    })?;
    Ok(buffer)
}
/// Encode a [`DType`] as a u64 wire discriminant for metadata headers.
///
/// NOTE(review): relies on `DType`'s enum discriminants matching the decode
/// table in `u64_to_dtype` — the `test_dtype_roundtrip` test covers this for
/// the common dtypes.
fn dtype_to_u64(dtype: DType) -> u64 {
    u64::from(dtype as u8)
}
fn u64_to_dtype(val: u64) -> Result<DType> {
match val {
0 => Ok(DType::F64),
1 => Ok(DType::F32),
2 => Ok(DType::F16),
3 => Ok(DType::BF16),
4 => Ok(DType::FP8E4M3),
5 => Ok(DType::FP8E5M2),
10 => Ok(DType::I64),
11 => Ok(DType::I32),
12 => Ok(DType::I16),
13 => Ok(DType::I8),
20 => Ok(DType::U64),
21 => Ok(DType::U32),
22 => Ok(DType::U16),
23 => Ok(DType::U8),
30 => Ok(DType::Bool),
40 => Ok(DType::Complex64),
41 => Ok(DType::Complex128),
_ => Err(Error::DistributedError {
reason: format!("unknown dtype discriminant {val} in recv header"),
}),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::cpu_setup;
    use numr::runtime::NoOpCommunicator;
    use numr::runtime::cpu::CpuRuntime;

    // A no-op all-reduce must leave the tensor contents untouched.
    #[test]
    fn test_all_reduce_tensor_noop() {
        let (_client, device) = cpu_setup();
        let comm = NoOpCommunicator;
        let tensor = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
        all_reduce_tensor(&comm, &tensor, ReduceOp::Sum).unwrap();
        assert_eq!(tensor.to_vec::<f32>(), vec![1.0, 2.0, 3.0]);
    }

    // Send then recv through the no-op communicator; both must succeed.
    #[test]
    fn test_send_recv_tensor_noop() {
        let (_client, device) = cpu_setup();
        let comm = NoOpCommunicator;
        let outgoing = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
        send_tensor(&comm, &outgoing, 0, 0).unwrap();
        let incoming = Tensor::<CpuRuntime>::zeros(&[2], DType::F32, &device);
        recv_into_tensor(&comm, &incoming, 0, 0).unwrap();
    }

    // A no-op broadcast must leave the tensor contents untouched.
    #[test]
    fn test_broadcast_tensor_noop() {
        let (_client, device) = cpu_setup();
        let comm = NoOpCommunicator;
        let tensor = Tensor::<CpuRuntime>::from_slice(&[5.0f32, 10.0], &[2], &device);
        broadcast_tensor(&comm, &tensor, 0).unwrap();
        assert_eq!(tensor.to_vec::<f32>(), vec![5.0, 10.0]);
    }

    // Encoding then decoding a dtype must yield the original value.
    #[test]
    fn test_dtype_roundtrip() {
        for &dt in &[
            DType::F32,
            DType::F64,
            DType::I32,
            DType::I64,
            DType::U8,
            DType::U32,
            DType::U64,
            DType::F16,
            DType::BF16,
            DType::Bool,
        ] {
            assert_eq!(u64_to_dtype(dtype_to_u64(dt)).unwrap(), dt);
        }
    }

    // An unmapped discriminant must produce an error, not a panic.
    #[test]
    fn test_u64_to_dtype_invalid() {
        assert!(u64_to_dtype(99).is_err());
    }

    // The header + data send sequence must succeed through the no-op comm.
    #[test]
    fn test_send_with_metadata_noop() {
        let (_client, device) = cpu_setup();
        let comm = NoOpCommunicator;
        let tensor = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
        send_tensor_with_metadata(&comm, &tensor, 0, 0).unwrap();
    }

    // Reduce-scatter with matching contiguous tensors must succeed.
    #[test]
    fn test_reduce_scatter_tensor_noop() {
        let (_client, device) = cpu_setup();
        let comm = NoOpCommunicator;
        let send = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
        let recv = Tensor::<CpuRuntime>::zeros(&[3], DType::F32, &device);
        reduce_scatter_tensor(&comm, &send, &recv, ReduceOp::Sum).unwrap();
    }

    // The no-op comm leaves the header zeroed, so ndim == 0 must be rejected.
    #[test]
    fn test_recv_with_metadata_noop_returns_error() {
        let (_client, device) = cpu_setup();
        let comm = NoOpCommunicator;
        let result = recv_tensor_with_metadata::<CpuRuntime>(&comm, 0, 0, &device);
        assert!(result.is_err());
    }
}