use crate::device_mesh::{DType, MeshComm, ReduceOp};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Bookkeeping shared by every [`MockComm`] endpoint of one mock group.
///
/// The state lives behind an `Arc<RwLock<..>>` owned by the endpoints; this
/// type itself performs no locking.
#[derive(Debug, Default)]
pub struct MockCommState {
    /// Byte buffers posted by the endpoints.
    ///
    /// Collective operations store a rank's contribution under that rank's
    /// index. Point-to-point messages are kept in the same map under a key
    /// derived from the (source, destination) pair.
    pub rank_buffers: HashMap<usize, Vec<u8>>,
    /// Number of ranks the group was created with.
    pub world_size: usize,
    /// Running total of `barrier` calls made against this state.
    barrier_count: usize,
}

impl MockCommState {
    /// Creates empty state for a group of `world_size` ranks: no buffers
    /// posted and a barrier counter of zero.
    pub fn new(world_size: usize) -> Self {
        // Every field other than `world_size` starts at its `Default` value
        // (empty map, zero counter).
        Self {
            world_size,
            ..Self::default()
        }
    }
}
/// In-process stand-in for a mesh communicator.
///
/// Each endpoint represents one rank. Endpoints that share the same
/// [`MockCommState`] exchange data through it; nothing leaves the process.
#[derive(Debug)]
pub struct MockComm {
    /// Rank of this endpoint within its group.
    pub rank: usize,
    /// Number of ranks in the group.
    pub world_size: usize,
    /// State shared with the other endpoints of the group.
    state: Arc<RwLock<MockCommState>>,
}

impl MockComm {
    /// Creates an endpoint backed by its own freshly allocated state.
    ///
    /// Endpoints built this way do not share buffers with each other; use
    /// [`MockComm::create_group`] or [`MockComm::with_shared_state`] when
    /// several ranks must see each other's data.
    pub fn new(rank: usize, world_size: usize) -> Self {
        Self {
            rank,
            world_size,
            state: Arc::new(RwLock::new(MockCommState::new(world_size))),
        }
    }

    /// Creates an endpoint for `rank` on top of existing shared `state`.
    ///
    /// The endpoint's `world_size` is copied from the state. A poisoned lock
    /// is tolerated here: only a plain `usize` is read, which cannot be left
    /// half-updated by a panicking writer, so the constructor does not panic.
    pub fn with_shared_state(rank: usize, state: Arc<RwLock<MockCommState>>) -> Self {
        let world_size = state
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .world_size;
        Self {
            rank,
            world_size,
            state,
        }
    }

    /// Creates one endpoint per rank in `0..world_size`, all sharing a single
    /// state. The returned vector is indexed by rank.
    pub fn create_group(world_size: usize) -> Vec<Self> {
        let state = Arc::new(RwLock::new(MockCommState::new(world_size)));
        (0..world_size)
            .map(|rank| Self::with_shared_state(rank, Arc::clone(&state)))
            .collect()
    }

    /// Combines `values` into a single `f32` according to `op`.
    ///
    /// For an empty slice: `Sum` is `0.0`, `Product` is `1.0`, `Min` is
    /// `+inf`, `Max` is `-inf`, and `Average` is `0.0` (rather than the NaN a
    /// bare `0.0 / 0.0` would produce), matching [`MockComm::reduce_i32`].
    fn reduce_f32(values: &[f32], op: ReduceOp) -> f32 {
        match op {
            ReduceOp::Sum => values.iter().sum(),
            ReduceOp::Product => values.iter().product(),
            ReduceOp::Min => values.iter().copied().fold(f32::INFINITY, f32::min),
            ReduceOp::Max => values.iter().copied().fold(f32::NEG_INFINITY, f32::max),
            ReduceOp::Average => {
                if values.is_empty() {
                    0.0
                } else {
                    values.iter().sum::<f32>() / values.len() as f32
                }
            }
        }
    }

    /// Combines `values` into a single `i32` according to `op`.
    ///
    /// For an empty slice: `Sum` is `0`, `Product` is `1`, and `Min`, `Max`
    /// and `Average` are `0`. `Average` truncates toward zero.
    fn reduce_i32(values: &[i32], op: ReduceOp) -> i32 {
        match op {
            ReduceOp::Sum => values.iter().sum(),
            ReduceOp::Product => values.iter().product(),
            ReduceOp::Min => values.iter().copied().min().unwrap_or(0),
            ReduceOp::Max => values.iter().copied().max().unwrap_or(0),
            ReduceOp::Average => {
                if values.is_empty() {
                    0
                } else {
                    // Accumulate in i64 so the intermediate sum cannot
                    // overflow even when every input is near the i32 limits.
                    let total: i64 = values.iter().map(|&v| i64::from(v)).sum();
                    // The mean lies between the smallest and largest input,
                    // so narrowing back to i32 never truncates.
                    (total / values.len() as i64) as i32
                }
            }
        }
    }
}
/// Key under which a point-to-point message from `src` to `dest` is stored
/// in `rank_buffers`.
///
/// Collectives store their contributions under the plain rank index
/// (`0..world_size`). Because `dest < world_size`, this key is always
/// `>= world_size`, so messages can never overwrite, or be mistaken for, a
/// collective contribution, and every (src, dest) pair maps to its own key.
fn mock_p2p_key(src: usize, dest: usize, world_size: usize) -> usize {
    (src + 1) * world_size + dest
}

/// Element-wise reduction over 4-byte little-endian elements.
///
/// `peers[r]` must be the buffer posted by rank `r`. Element `i` of `buf` is
/// replaced by `combine` applied to element `i` of every peer buffer.
/// Trailing bytes of `buf` that do not form a whole element are left as-is.
///
/// Returns an error, without touching `buf`, if any peer buffer is too short
/// to supply every element.
fn mock_reduce_le4<T>(
    buf: &mut [u8],
    peers: &[Vec<u8>],
    decode: impl Fn([u8; 4]) -> T,
    encode: impl Fn(T) -> [u8; 4],
    combine: impl Fn(&[T]) -> T,
) -> Result<(), String> {
    let needed = buf.len() / 4 * 4;
    if let Some((rank, peer)) = peers
        .iter()
        .enumerate()
        .find(|(_, peer)| peer.len() < needed)
    {
        return Err(format!(
            "all_reduce: rank {} posted {} bytes but {} bytes are required",
            rank,
            peer.len(),
            needed
        ));
    }

    // One scratch vector reused for every element.
    let mut values: Vec<T> = Vec::with_capacity(peers.len());
    for (i, slot) in buf.chunks_exact_mut(4).enumerate() {
        let start = i * 4;
        values.clear();
        values.extend(peers.iter().map(|peer| {
            let mut word = [0u8; 4];
            word.copy_from_slice(&peer[start..start + 4]);
            decode(word)
        }));
        slot.copy_from_slice(&encode(combine(values.as_slice())));
    }
    Ok(())
}

/// [`MeshComm`] implemented on top of the shared in-memory state.
///
/// No operation blocks waiting for peers: each call works with whatever the
/// other ranks have posted so far. Posted buffers are never cleared, so data
/// from an earlier call stays visible to later ones. Lock poisoning is
/// reported as `Err`. The `group` argument is ignored throughout.
impl MeshComm for MockComm {
    /// Posts `buf` as this rank's contribution and, once every rank in
    /// `0..world_size` has posted, overwrites `buf` with the element-wise
    /// reduction of all contributions.
    ///
    /// * Groups with fewer than two ranks: returns immediately, `buf` untouched.
    /// * Not all ranks have posted yet: `buf` is left untouched and `Ok` is
    ///   returned (the contribution stays posted).
    /// * Only `Float32` and `Int32` are reduced; for any other `dtype` the
    ///   buffer is left unchanged.
    ///
    /// # Errors
    /// Fails if the state lock is poisoned or a peer posted a buffer shorter
    /// than this rank's.
    fn all_reduce(
        &self,
        buf: &mut [u8],
        dtype: DType,
        op: ReduceOp,
        _group: &str,
    ) -> Result<(), String> {
        if self.world_size <= 1 {
            return Ok(());
        }
        {
            let mut state = self.state.write().map_err(|e| e.to_string())?;
            state.rank_buffers.insert(self.rank, buf.to_vec());
        }
        let peers: Vec<Vec<u8>> = {
            let state = self.state.read().map_err(|e| e.to_string())?;
            let posted: Vec<Vec<u8>> = (0..self.world_size)
                .filter_map(|r| state.rank_buffers.get(&r).cloned())
                .collect();
            posted
        };
        // Count actual rank entries rather than the map size: the map also
        // holds point-to-point messages, which must not be mistaken for
        // contributions.
        if peers.len() < self.world_size {
            return Ok(());
        }
        match dtype {
            DType::Float32 => mock_reduce_le4(
                buf,
                &peers,
                f32::from_le_bytes,
                f32::to_le_bytes,
                |vals: &[f32]| Self::reduce_f32(vals, op),
            ),
            DType::Int32 => mock_reduce_le4(
                buf,
                &peers,
                i32::from_le_bytes,
                i32::to_le_bytes,
                |vals: &[i32]| Self::reduce_i32(vals, op),
            ),
            // Other element types are not reduced by this mock.
            _ => Ok(()),
        }
    }

    /// Posts `local` and copies every contribution posted so far into `out`,
    /// rank `r` landing at byte offset `r * local.len()`.
    ///
    /// Slots of ranks that have not posted are left untouched. A contribution
    /// longer than `local.len()` is truncated to that length.
    ///
    /// # Errors
    /// Fails if `out` is shorter than `world_size * local.len()` bytes (nothing
    /// is posted in that case) or if the state lock is poisoned.
    fn all_gather(
        &self,
        local: &[u8],
        out: &mut [u8],
        _dtype: DType,
        _group: &str,
    ) -> Result<(), String> {
        let chunk_size = local.len();
        let required = self
            .world_size
            .checked_mul(chunk_size)
            .ok_or_else(|| "all_gather: required output size overflows usize".to_string())?;
        if out.len() < required {
            return Err(format!(
                "all_gather: output holds {} bytes but {} ranks x {} bytes are required",
                out.len(),
                self.world_size,
                chunk_size
            ));
        }
        if self.world_size == 1 {
            out[..chunk_size].copy_from_slice(local);
            return Ok(());
        }
        {
            let mut state = self.state.write().map_err(|e| e.to_string())?;
            state.rank_buffers.insert(self.rank, local.to_vec());
        }
        let state = self.state.read().map_err(|e| e.to_string())?;
        for rank in 0..self.world_size {
            if let Some(data) = state.rank_buffers.get(&rank) {
                let offset = rank * chunk_size;
                let copy_len = data.len().min(chunk_size);
                out[offset..offset + copy_len].copy_from_slice(&data[..copy_len]);
            }
        }
        Ok(())
    }

    /// On the root, posts `buf`; on any other rank, copies the root's posted
    /// bytes into `buf` (as many as fit).
    ///
    /// If the root has not posted yet, a non-root call leaves `buf` untouched
    /// and still returns `Ok`.
    ///
    /// # Errors
    /// Fails if `root_rank` is not below `world_size` or the state lock is
    /// poisoned.
    fn broadcast(&self, buf: &mut [u8], root_rank: usize, _group: &str) -> Result<(), String> {
        if root_rank >= self.world_size {
            return Err(format!("Invalid root rank {}", root_rank));
        }
        if self.rank == root_rank {
            let mut state = self.state.write().map_err(|e| e.to_string())?;
            state.rank_buffers.insert(root_rank, buf.to_vec());
        } else {
            let state = self.state.read().map_err(|e| e.to_string())?;
            if let Some(root_data) = state.rank_buffers.get(&root_rank) {
                let copy_len = root_data.len().min(buf.len());
                buf[..copy_len].copy_from_slice(&root_data[..copy_len]);
            }
        }
        Ok(())
    }

    /// Runs [`all_reduce`](MeshComm::all_reduce) on `buf`, then copies this
    /// rank's chunk of the result (`buf.len() / world_size` bytes starting at
    /// `rank * chunk`) into `out`, truncated to `out.len()`.
    ///
    /// The signature carries no dtype, so the reduction always treats `buf`
    /// as `Float32`.
    ///
    /// # Errors
    /// Fails if the group is empty, if the reduction fails, or if this rank's
    /// chunk lies outside `buf`.
    fn reduce_scatter(
        &self,
        buf: &mut [u8],
        out: &mut [u8],
        op: ReduceOp,
        group: &str,
    ) -> Result<(), String> {
        if self.world_size == 0 {
            return Err("reduce_scatter: world size is 0".to_string());
        }
        self.all_reduce(buf, DType::Float32, op, group)?;
        let chunk_size = buf.len() / self.world_size;
        let offset = self.rank * chunk_size;
        let copy_len = chunk_size.min(out.len());
        let chunk = buf.get(offset..offset + copy_len).ok_or_else(|| {
            format!(
                "reduce_scatter: chunk for rank {} lies outside the {}-byte buffer",
                self.rank,
                buf.len()
            )
        })?;
        out[..copy_len].copy_from_slice(chunk);
        Ok(())
    }

    /// Records the call in the shared barrier counter and returns at once;
    /// it does not wait for the other ranks.
    fn barrier(&self, _group: &str) -> Result<(), String> {
        let mut state = self.state.write().map_err(|e| e.to_string())?;
        state.barrier_count += 1;
        Ok(())
    }

    /// Stores a copy of `buf` as the pending message from this rank to
    /// `dest_rank`, replacing any earlier message for the same pair.
    ///
    /// # Errors
    /// Fails if `dest_rank` is not below `world_size` or the state lock is
    /// poisoned.
    fn send(&self, buf: &[u8], dest_rank: usize) -> Result<(), String> {
        if dest_rank >= self.world_size {
            return Err(format!("Invalid dest rank {}", dest_rank));
        }
        let key = mock_p2p_key(self.rank, dest_rank, self.world_size);
        let mut state = self.state.write().map_err(|e| e.to_string())?;
        state.rank_buffers.insert(key, buf.to_vec());
        Ok(())
    }

    /// Copies the message sent from `src_rank` to this rank into `buf`
    /// (as many bytes as fit). The message is not consumed, so a repeated
    /// call sees the same data.
    ///
    /// # Errors
    /// Fails if `src_rank` is not below `world_size`, if no message from that
    /// rank is stored, or if the state lock is poisoned.
    fn recv(&self, buf: &mut [u8], src_rank: usize) -> Result<(), String> {
        if src_rank >= self.world_size {
            return Err(format!("Invalid src rank {}", src_rank));
        }
        let key = mock_p2p_key(src_rank, self.rank, self.world_size);
        let state = self.state.read().map_err(|e| e.to_string())?;
        match state.rank_buffers.get(&key) {
            Some(data) => {
                let copy_len = data.len().min(buf.len());
                buf[..copy_len].copy_from_slice(&data[..copy_len]);
                Ok(())
            }
            None => Err(format!("No data from rank {}", src_rank)),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Little-endian byte image of a slice of `f32` values.
    fn f32_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    /// With a single rank there is nothing to combine, so `all_reduce` must
    /// hand the buffer back unchanged.
    #[test]
    fn test_mock_comm_single_rank() {
        let comm = MockComm::new(0, 1);
        let mut byte_buf = f32_bytes(&[1.0, 2.0, 3.0, 4.0]);

        comm.all_reduce(&mut byte_buf, DType::Float32, ReduceOp::Sum, "world")
            .unwrap();

        let mut result: Vec<f32> = Vec::new();
        for chunk in byte_buf.chunks(4) {
            result.push(f32::from_le_bytes(chunk.try_into().unwrap()));
        }
        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// The root posts first; a second rank then receives the same bytes.
    #[test]
    fn test_mock_comm_broadcast() {
        let comms = MockComm::create_group(2);

        let mut root_buf = f32_bytes(&[42.0, 24.0]);
        comms[0].broadcast(&mut root_buf, 0, "world").unwrap();

        let mut recv_buf = vec![0u8; 8];
        comms[1].broadcast(&mut recv_buf, 0, "world").unwrap();

        assert_eq!(root_buf, recv_buf);
    }

    /// A message sent from rank 0 to rank 1 is readable by rank 1.
    #[test]
    fn test_mock_comm_send_recv() {
        let comms = MockComm::create_group(2);
        let send_data = vec![1u8, 2, 3, 4];
        let mut recv_buf = vec![0u8; 4];

        comms[0].send(&send_data, 1).unwrap();
        comms[1].recv(&mut recv_buf, 0).unwrap();

        assert_eq!(recv_buf, send_data);
    }

    /// `barrier` succeeds even when no other rank takes part.
    #[test]
    fn test_mock_comm_barrier() {
        let comm = MockComm::new(0, 4);
        let outcome = comm.barrier("world");
        assert!(outcome.is_ok());
    }

    /// Each rank finds its own contribution in its own slot of the output.
    #[test]
    fn test_mock_comm_all_gather() {
        let comms = MockComm::create_group(2);
        let (local0, local1) = (vec![1u8, 2], vec![3u8, 4]);
        let (mut out0, mut out1) = (vec![0u8; 4], vec![0u8; 4]);

        comms[0]
            .all_gather(&local0, &mut out0, DType::UInt8, "world")
            .unwrap();
        comms[1]
            .all_gather(&local1, &mut out1, DType::UInt8, "world")
            .unwrap();

        assert_eq!(&out0[0..2], &local0[..]);
        assert_eq!(&out1[2..4], &local1[..]);
    }
}