use crate::device_mesh::{DType, DeviceMesh, MeshComm, ReduceOp};
use std::sync::Arc;
/// Result type returned by the distributed operations in this file.
///
/// Errors are carried as plain `String` messages.
pub type DistResult<T> = Result<T, String>;
/// Typed front-end for collective and point-to-point operations on a
/// [`DeviceMesh`].
///
/// The handle only borrows the mesh; every operation is forwarded to the
/// mesh's optional `comm` backend and fails with an error when none is set.
pub struct DistributedOps<'a> {
    /// Mesh whose `comm` backend and `rank` are used by the operations.
    mesh: &'a DeviceMesh,
}
impl<'a> DistributedOps<'a> {
    /// Creates an operations handle that borrows `mesh`.
    pub fn new(mesh: &'a DeviceMesh) -> Self {
        Self { mesh }
    }

    /// Returns the communication backend configured on the mesh.
    ///
    /// # Errors
    /// Returns `"No communication backend configured"` when `mesh.comm` is `None`.
    fn comm(&self) -> DistResult<&Arc<dyn MeshComm + Send + Sync>> {
        self.mesh
            .comm
            .as_ref()
            .ok_or_else(|| "No communication backend configured".to_string())
    }

    /// Views an `f32` slice as its underlying native-endian bytes.
    fn f32_bytes(data: &[f32]) -> &[u8] {
        // SAFETY: the pointer and byte length both come from a live `&[f32]`, so
        // the region is valid for reads of `size_of_val(data)` bytes. `u8` has
        // alignment 1, `f32` has no padding bytes, and the returned borrow keeps
        // `data` borrowed for the same lifetime.
        unsafe {
            std::slice::from_raw_parts(data.as_ptr().cast::<u8>(), std::mem::size_of_val(data))
        }
    }

    /// Views a mutable `f32` slice as its underlying native-endian bytes.
    fn f32_bytes_mut(data: &mut [f32]) -> &mut [u8] {
        let byte_len = std::mem::size_of_val(&*data);
        // SAFETY: the pointer and byte length both come from a live, exclusively
        // borrowed `&mut [f32]`, so the region is valid for reads and writes and
        // is not aliased while the returned slice exists. `u8` has alignment 1,
        // `f32` has no padding, and every bit pattern is a valid `f32`, so
        // arbitrary byte writes cannot produce an invalid value.
        unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr().cast::<u8>(), byte_len) }
    }

    /// Views a mutable `i32` slice as its underlying native-endian bytes.
    fn i32_bytes_mut(data: &mut [i32]) -> &mut [u8] {
        let byte_len = std::mem::size_of_val(&*data);
        // SAFETY: same reasoning as `f32_bytes_mut`; `i32` has no padding and
        // every bit pattern is a valid `i32`.
        unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr().cast::<u8>(), byte_len) }
    }

    /// Runs the backend's `all_reduce` in place over `data` as `DType::Float32`.
    ///
    /// # Errors
    /// Fails when no backend is configured or when the backend reports an error.
    pub fn all_reduce_f32(&self, data: &mut [f32], op: ReduceOp, group: &str) -> DistResult<()> {
        let comm = self.comm()?;
        comm.all_reduce(Self::f32_bytes_mut(data), DType::Float32, op, group)
    }

    /// Runs the backend's `all_reduce` in place over `data` as `DType::Int32`.
    ///
    /// # Errors
    /// Fails when no backend is configured or when the backend reports an error.
    pub fn all_reduce_i32(&self, data: &mut [i32], op: ReduceOp, group: &str) -> DistResult<()> {
        let comm = self.comm()?;
        comm.all_reduce(Self::i32_bytes_mut(data), DType::Int32, op, group)
    }

    /// Runs the backend's `broadcast` from `root_rank` over `data` within `group`.
    ///
    /// # Errors
    /// Fails when no backend is configured or when the backend reports an error.
    pub fn broadcast_f32(&self, data: &mut [f32], root_rank: usize, group: &str) -> DistResult<()> {
        let comm = self.comm()?;
        comm.broadcast(Self::f32_bytes_mut(data), root_rank, group)
    }

    /// Runs the backend's `all_gather`, passing `local` as the input and
    /// `output` as the destination, both as `DType::Float32`.
    ///
    /// # Errors
    /// Fails when no backend is configured or when the backend reports an error.
    pub fn all_gather_f32(&self, local: &[f32], output: &mut [f32], group: &str) -> DistResult<()> {
        let comm = self.comm()?;
        comm.all_gather(
            Self::f32_bytes(local),
            Self::f32_bytes_mut(output),
            DType::Float32,
            group,
        )
    }

    /// Distributes consecutive `chunk.len()`-sized pieces of the root's `data`.
    ///
    /// This is implemented as a broadcast of the whole buffer on the `"world"`
    /// group, after which each rank copies the elements starting at
    /// `rank * chunk.len()` into `chunk`.
    ///
    /// Non-root ranks size their receive buffer from their own `data.len()`, so
    /// they must pass a `data` slice of the same length as the root's; its
    /// contents are not read. Elements of `chunk` that fall beyond the end of
    /// the buffer are left untouched.
    ///
    /// # Errors
    /// Fails when no backend is configured or when the broadcast fails.
    pub fn scatter_f32(&self, data: &[f32], chunk: &mut [f32], root_rank: usize) -> DistResult<()> {
        let comm = self.comm()?;
        if self.mesh.rank == root_rank {
            // `broadcast` needs a mutable buffer, so the root sends a copy.
            let mut broadcast_buf = Self::f32_bytes(data).to_vec();
            comm.broadcast(&mut broadcast_buf, root_rank, "world")?;
            let offset = self.mesh.rank * chunk.len();
            // The root already owns the values; copy its piece straight from `data`.
            for (dst, src) in chunk.iter_mut().zip(data.iter().skip(offset)) {
                *dst = *src;
            }
        } else {
            let mut broadcast_buf = vec![0u8; std::mem::size_of_val(data)];
            comm.broadcast(&mut broadcast_buf, root_rank, "world")?;
            let offset = self.mesh.rank * chunk.len();
            let received = broadcast_buf
                .chunks_exact(std::mem::size_of::<f32>())
                .skip(offset);
            for (dst, raw) in chunk.iter_mut().zip(received) {
                let bytes: [u8; 4] = raw
                    .try_into()
                    .expect("chunks_exact yields slices of exactly four bytes");
                // The root sends the in-memory (native-endian) representation,
                // so it must be decoded as native-endian as well.
                *dst = f32::from_ne_bytes(bytes);
            }
        }
        Ok(())
    }

    /// Gathers every rank's `local` values into `output`.
    ///
    /// This currently delegates to [`Self::all_gather_f32`] on the `"world"`
    /// group, so `output` is handed to the backend on every rank and
    /// `root_rank` does not influence the result.
    ///
    /// # Errors
    /// Fails when no backend is configured or when the all-gather fails.
    pub fn gather_f32(
        &self,
        local: &[f32],
        output: &mut [f32],
        root_rank: usize,
    ) -> DistResult<()> {
        // Kept in the signature for API compatibility; there is no root-specific work.
        let _ = root_rank;
        self.all_gather_f32(local, output, "world")
    }

    /// Runs the backend's `reduce_scatter`, passing `data` as the input buffer
    /// and `output` as the destination.
    ///
    /// # Errors
    /// Fails when no backend is configured or when the backend reports an error.
    pub fn reduce_scatter_f32(
        &self,
        data: &mut [f32],
        output: &mut [f32],
        op: ReduceOp,
        group: &str,
    ) -> DistResult<()> {
        let comm = self.comm()?;
        comm.reduce_scatter(
            Self::f32_bytes_mut(data),
            Self::f32_bytes_mut(output),
            op,
            group,
        )
    }

    /// Runs the backend's `barrier` for `group`.
    ///
    /// # Errors
    /// Fails when no backend is configured or when the backend reports an error.
    pub fn barrier(&self, group: &str) -> DistResult<()> {
        let comm = self.comm()?;
        comm.barrier(group)
    }

    /// Sends the bytes of `data` to `dest_rank` through the backend's `send`.
    ///
    /// # Errors
    /// Fails when no backend is configured or when the backend reports an error.
    pub fn send_f32(&self, data: &[f32], dest_rank: usize) -> DistResult<()> {
        let comm = self.comm()?;
        comm.send(Self::f32_bytes(data), dest_rank)
    }

    /// Receives from `src_rank` into `data` through the backend's `recv`.
    ///
    /// # Errors
    /// Fails when no backend is configured or when the backend reports an error.
    pub fn recv_f32(&self, data: &mut [f32], src_rank: usize) -> DistResult<()> {
        let comm = self.comm()?;
        comm.recv(Self::f32_bytes_mut(data), src_rank)
    }
}
pub fn dist_ops(mesh: &DeviceMesh) -> DistributedOps<'_> {
DistributedOps::new(mesh)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::{Device, DeviceBackend};
    use crate::device_mesh::DeviceMesh;

    /// Builds `count` identical CUDA test devices with ids `0..count`.
    fn create_test_devices(count: usize) -> Vec<Device> {
        let mut devices = Vec::with_capacity(count);
        for id in 0..count {
            devices.push(Device {
                id,
                name: format!("GPU_{}", id),
                backend: DeviceBackend::Cuda,
                memory_mb: 16000,
                compute_units: 80,
                pci_bus_id: None,
                partition_id: None,
                driver_version: None,
                compute_capability: None,
                utilization_gpu_pct: None,
                temperature_c: None,
                supports_fp16: true,
                supports_int8: true,
                cuda_version: Some("12.0".to_string()),
            });
        }
        devices
    }

    /// A sum all-reduce on a single-rank mock mesh leaves the values as they were.
    #[test]
    fn test_all_reduce_f32_single_rank() {
        let mesh = DeviceMesh::new_with_mock_comm(create_test_devices(1), 0);
        let mut values = vec![1.0f32, 2.0, 3.0, 4.0];
        let outcome = dist_ops(&mesh).all_reduce_f32(&mut values, ReduceOp::Sum, "world");
        assert!(outcome.is_ok());
        assert_eq!(values, [1.0f32, 2.0, 3.0, 4.0]);
    }

    /// Broadcasting from rank 0 on a mock mesh succeeds.
    #[test]
    fn test_broadcast_f32() {
        let mesh = DeviceMesh::new_with_mock_comm(create_test_devices(2), 0);
        let mut payload = vec![42.0f32, 24.0];
        let outcome = dist_ops(&mesh).broadcast_f32(&mut payload, 0, "world");
        assert!(outcome.is_ok());
    }

    /// A barrier on the mock backend succeeds.
    #[test]
    fn test_barrier() {
        let mesh = DeviceMesh::new_with_mock_comm(create_test_devices(4), 0);
        assert!(dist_ops(&mesh).barrier("world").is_ok());
    }

    /// Data sent by rank 0 is received unchanged by rank 1 over shared mock state.
    #[test]
    fn test_send_recv_f32() {
        use crate::mock_comm::MockComm;
        use std::sync::{Arc, RwLock};

        let shared = Arc::new(RwLock::new(crate::mock_comm::MockCommState::new(2)));
        let devices = create_test_devices(2);

        let mut sender_mesh = DeviceMesh::new(devices.clone());
        sender_mesh.rank = 0;
        sender_mesh.comm = Some(Arc::new(MockComm::with_shared_state(0, Arc::clone(&shared))));

        let mut receiver_mesh = DeviceMesh::new(devices);
        receiver_mesh.rank = 1;
        receiver_mesh.comm = Some(Arc::new(MockComm::with_shared_state(1, shared)));

        let outgoing = vec![1.0f32, 2.0, 3.0];
        dist_ops(&sender_mesh).send_f32(&outgoing, 1).unwrap();

        let mut incoming = vec![0.0f32; 3];
        dist_ops(&receiver_mesh).recv_f32(&mut incoming, 0).unwrap();

        assert_eq!(incoming, outgoing);
    }

    /// Operations on a mesh without a backend report a descriptive error.
    #[test]
    fn test_no_comm_backend_error() {
        let mesh = DeviceMesh::new(create_test_devices(2));
        let mut values = vec![1.0f32];
        let message = dist_ops(&mesh)
            .all_reduce_f32(&mut values, ReduceOp::Sum, "world")
            .expect_err("all_reduce must fail without a communication backend");
        assert!(message.contains("No communication backend"));
    }
}