1use crate::device_mesh::{DType, MeshComm, ReduceOp};
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9
/// Shared state behind a group of [`MockComm`] ranks.
///
/// Collective operations publish each rank's contribution under its rank id in
/// `rank_buffers`; messages written by `send` are kept in the same map.
#[derive(Debug, Default)]
pub struct MockCommState {
    /// Latest buffer stored per key (rank id for collectives, plus `send` messages).
    pub rank_buffers: HashMap<usize, Vec<u8>>,
    /// Number of ranks in the group.
    pub world_size: usize,
    /// Number of `barrier` calls observed so far on this shared state.
    barrier_count: usize,
}

impl MockCommState {
    /// Creates empty state (no buffers, zero barriers) for a group of `world_size` ranks.
    pub fn new(world_size: usize) -> Self {
        Self {
            world_size,
            ..Self::default()
        }
    }
}
34
35#[derive(Debug)]
40pub struct MockComm {
41 pub rank: usize,
43 pub world_size: usize,
45 state: Arc<RwLock<MockCommState>>,
47}
48
49impl MockComm {
50 pub fn new(rank: usize, world_size: usize) -> Self {
52 Self {
53 rank,
54 world_size,
55 state: Arc::new(RwLock::new(MockCommState::new(world_size))),
56 }
57 }
58
59 pub fn with_shared_state(rank: usize, state: Arc<RwLock<MockCommState>>) -> Self {
61 let world_size = state.read().unwrap().world_size;
62 Self {
63 rank,
64 world_size,
65 state,
66 }
67 }
68
69 pub fn create_group(world_size: usize) -> Vec<Self> {
71 let state = Arc::new(RwLock::new(MockCommState::new(world_size)));
72 (0..world_size)
73 .map(|rank| Self::with_shared_state(rank, state.clone()))
74 .collect()
75 }
76
77 fn reduce_f32(values: &[f32], op: ReduceOp) -> f32 {
79 match op {
80 ReduceOp::Sum => values.iter().sum(),
81 ReduceOp::Product => values.iter().product(),
82 ReduceOp::Min => values.iter().cloned().fold(f32::INFINITY, f32::min),
83 ReduceOp::Max => values.iter().cloned().fold(f32::NEG_INFINITY, f32::max),
84 ReduceOp::Average => values.iter().sum::<f32>() / values.len() as f32,
85 }
86 }
87
88 fn reduce_i32(values: &[i32], op: ReduceOp) -> i32 {
90 match op {
91 ReduceOp::Sum => values.iter().sum(),
92 ReduceOp::Product => values.iter().product(),
93 ReduceOp::Min => values.iter().cloned().min().unwrap_or(0),
94 ReduceOp::Max => values.iter().cloned().max().unwrap_or(0),
95 ReduceOp::Average => values.iter().sum::<i32>() / values.len().max(1) as i32,
96 }
97 }
98}
99
100impl MeshComm for MockComm {
101 fn all_reduce(
102 &self,
103 buf: &mut [u8],
104 dtype: DType,
105 op: ReduceOp,
106 _group: &str,
107 ) -> Result<(), String> {
108 if self.world_size == 1 {
110 return Ok(());
111 }
112
113 {
115 let mut state = self.state.write().map_err(|e| e.to_string())?;
116 state.rank_buffers.insert(self.rank, buf.to_vec());
117 }
118
119 let all_buffers: Vec<Vec<u8>> = {
122 let state = self.state.read().map_err(|e| e.to_string())?;
123 if state.rank_buffers.len() < self.world_size {
124 return Ok(());
127 }
128 (0..self.world_size)
129 .filter_map(|r| state.rank_buffers.get(&r).cloned())
130 .collect()
131 };
132
133 match dtype {
135 DType::Float32 => {
136 let elem_count = buf.len() / 4;
137 for i in 0..elem_count {
138 let values: Vec<f32> = all_buffers
139 .iter()
140 .map(|b| {
141 let bytes: [u8; 4] = b[i * 4..(i + 1) * 4].try_into().unwrap();
142 f32::from_le_bytes(bytes)
143 })
144 .collect();
145 let result = Self::reduce_f32(&values, op);
146 buf[i * 4..(i + 1) * 4].copy_from_slice(&result.to_le_bytes());
147 }
148 }
149 DType::Int32 => {
150 let elem_count = buf.len() / 4;
151 for i in 0..elem_count {
152 let values: Vec<i32> = all_buffers
153 .iter()
154 .map(|b| {
155 let bytes: [u8; 4] = b[i * 4..(i + 1) * 4].try_into().unwrap();
156 i32::from_le_bytes(bytes)
157 })
158 .collect();
159 let result = Self::reduce_i32(&values, op);
160 buf[i * 4..(i + 1) * 4].copy_from_slice(&result.to_le_bytes());
161 }
162 }
163 _ => {
164 }
166 }
167
168 Ok(())
169 }
170
171 fn all_gather(
172 &self,
173 local: &[u8],
174 out: &mut [u8],
175 _dtype: DType,
176 _group: &str,
177 ) -> Result<(), String> {
178 if self.world_size == 1 {
180 out[..local.len()].copy_from_slice(local);
181 return Ok(());
182 }
183
184 {
186 let mut state = self.state.write().map_err(|e| e.to_string())?;
187 state.rank_buffers.insert(self.rank, local.to_vec());
188 }
189
190 let chunk_size = local.len();
192 let state = self.state.read().map_err(|e| e.to_string())?;
193
194 for rank in 0..self.world_size {
195 let offset = rank * chunk_size;
196 if let Some(data) = state.rank_buffers.get(&rank) {
197 let copy_len = data.len().min(chunk_size);
198 out[offset..offset + copy_len].copy_from_slice(&data[..copy_len]);
199 }
200 }
201
202 Ok(())
203 }
204
205 fn broadcast(&self, buf: &mut [u8], root_rank: usize, _group: &str) -> Result<(), String> {
206 if self.rank == root_rank {
207 let mut state = self.state.write().map_err(|e| e.to_string())?;
209 state.rank_buffers.insert(root_rank, buf.to_vec());
210 } else {
211 let state = self.state.read().map_err(|e| e.to_string())?;
213 if let Some(root_data) = state.rank_buffers.get(&root_rank) {
214 let copy_len = root_data.len().min(buf.len());
215 buf[..copy_len].copy_from_slice(&root_data[..copy_len]);
216 }
217 }
218 Ok(())
219 }
220
221 fn reduce_scatter(
222 &self,
223 buf: &mut [u8],
224 out: &mut [u8],
225 op: ReduceOp,
226 group: &str,
227 ) -> Result<(), String> {
228 self.all_reduce(buf, DType::Float32, op, group)?;
230
231 let chunk_size = buf.len() / self.world_size;
233 let offset = self.rank * chunk_size;
234 let copy_len = chunk_size.min(out.len());
235 out[..copy_len].copy_from_slice(&buf[offset..offset + copy_len]);
236
237 Ok(())
238 }
239
240 fn barrier(&self, _group: &str) -> Result<(), String> {
241 let mut state = self.state.write().map_err(|e| e.to_string())?;
243 state.barrier_count += 1;
244 Ok(())
246 }
247
248 fn send(&self, buf: &[u8], dest_rank: usize) -> Result<(), String> {
249 if dest_rank >= self.world_size {
250 return Err(format!("Invalid dest rank {}", dest_rank));
251 }
252 let key = self.rank * 1000 + dest_rank; let mut state = self.state.write().map_err(|e| e.to_string())?;
255 state.rank_buffers.insert(key, buf.to_vec());
256 Ok(())
257 }
258
259 fn recv(&self, buf: &mut [u8], src_rank: usize) -> Result<(), String> {
260 if src_rank >= self.world_size {
261 return Err(format!("Invalid src rank {}", src_rank));
262 }
263 let key = src_rank * 1000 + self.rank;
264 let state = self.state.read().map_err(|e| e.to_string())?;
265 if let Some(data) = state.rank_buffers.get(&key) {
266 let copy_len = data.len().min(buf.len());
267 buf[..copy_len].copy_from_slice(&data[..copy_len]);
268 Ok(())
269 } else {
270 Err(format!("No data from rank {}", src_rank))
271 }
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `values` as consecutive little-endian `f32` words.
    fn f32_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    /// Decodes consecutive little-endian `f32` words.
    fn bytes_f32(bytes: &[u8]) -> Vec<f32> {
        bytes
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes(c.try_into().unwrap()))
            .collect()
    }

    /// Encodes `values` as consecutive little-endian `i32` words.
    fn i32_bytes(values: &[i32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    /// Decodes consecutive little-endian `i32` words.
    fn bytes_i32(bytes: &[u8]) -> Vec<i32> {
        bytes
            .chunks_exact(4)
            .map(|c| i32::from_le_bytes(c.try_into().unwrap()))
            .collect()
    }

    #[test]
    fn test_mock_comm_single_rank() {
        // all_reduce on a lone rank leaves the buffer unchanged.
        let comm = MockComm::new(0, 1);
        let mut byte_buf = f32_bytes(&[1.0, 2.0, 3.0, 4.0]);

        comm.all_reduce(&mut byte_buf, DType::Float32, ReduceOp::Sum, "world")
            .unwrap();

        assert_eq!(bytes_f32(&byte_buf), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_mock_comm_all_reduce_two_ranks_f32_sum() {
        let comms = MockComm::create_group(2);
        let mut buf0 = f32_bytes(&[1.0, 2.0]);
        let mut buf1 = f32_bytes(&[3.0, 4.0]);

        comms[0]
            .all_reduce(&mut buf0, DType::Float32, ReduceOp::Sum, "world")
            .unwrap();
        // The last rank to arrive sees every contribution and receives the result.
        comms[1]
            .all_reduce(&mut buf1, DType::Float32, ReduceOp::Sum, "world")
            .unwrap();

        assert_eq!(bytes_f32(&buf1), vec![4.0, 6.0]);
    }

    #[test]
    fn test_mock_comm_all_reduce_two_ranks_i32_max() {
        let comms = MockComm::create_group(2);
        let mut buf0 = i32_bytes(&[5, -2]);
        let mut buf1 = i32_bytes(&[3, 7]);

        comms[0]
            .all_reduce(&mut buf0, DType::Int32, ReduceOp::Max, "world")
            .unwrap();
        comms[1]
            .all_reduce(&mut buf1, DType::Int32, ReduceOp::Max, "world")
            .unwrap();

        assert_eq!(bytes_i32(&buf1), vec![5, 7]);
    }

    #[test]
    fn test_mock_comm_broadcast() {
        let comms = MockComm::create_group(2);

        let mut root_buf = f32_bytes(&[42.0, 24.0]);
        comms[0].broadcast(&mut root_buf, 0, "world").unwrap();

        let mut recv_buf = vec![0u8; 8];
        comms[1].broadcast(&mut recv_buf, 0, "world").unwrap();

        assert_eq!(root_buf, recv_buf);
    }

    #[test]
    fn test_mock_comm_send_recv() {
        let comms = MockComm::create_group(2);

        let send_data = vec![1u8, 2, 3, 4];
        comms[0].send(&send_data, 1).unwrap();

        let mut recv_buf = vec![0u8; 4];
        comms[1].recv(&mut recv_buf, 0).unwrap();

        assert_eq!(recv_buf, send_data);
    }

    #[test]
    fn test_mock_comm_barrier() {
        let comm = MockComm::new(0, 4);
        assert!(comm.barrier("world").is_ok());
    }

    #[test]
    fn test_mock_comm_all_gather() {
        let comms = MockComm::create_group(2);

        let local0 = vec![1u8, 2];
        let local1 = vec![3u8, 4];

        let mut out0 = vec![0u8; 4];
        let mut out1 = vec![0u8; 4];

        comms[0]
            .all_gather(&local0, &mut out0, DType::UInt8, "world")
            .unwrap();
        comms[1]
            .all_gather(&local1, &mut out1, DType::UInt8, "world")
            .unwrap();

        // Rank 0 only sees its own chunk; rank 1 arrives last and sees both.
        assert_eq!(&out0[0..2], &local0[..]);
        assert_eq!(out1, vec![1u8, 2, 3, 4]);
    }
}
356}