Skip to main content

kapsl_hal/
mock_comm.rs

//! Mock Communication Backend for Testing
//!
//! Provides a single-process implementation of `MeshComm` that simulates
//! distributed operations for testing without requiring a real multi-GPU setup.
5
6use crate::device_mesh::{DType, MeshComm, ReduceOp};
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9
/// State shared by every simulated rank of one communication group.
///
/// All "ranks" live in the same process; they exchange data by writing to and
/// reading from the entries of this structure.
#[derive(Debug, Default)]
pub struct MockCommState {
    /// Byte buffers deposited by the simulated ranks, keyed by `usize`
    pub rank_buffers: HashMap<usize, Vec<u8>>,
    /// Number of ranks in this communication group
    pub world_size: usize,
    /// Counter kept for barrier synchronization
    barrier_count: usize,
}
23
24impl MockCommState {
25    /// Create new shared state for a given world size
26    pub fn new(world_size: usize) -> Self {
27        Self {
28            rank_buffers: HashMap::new(),
29            world_size,
30            barrier_count: 0,
31        }
32    }
33}
34
35/// Mock communication backend for testing distributed operations.
36///
37/// This implementation simulates collective operations in-memory,
38/// allowing testing of distributed code paths on a single machine.
39#[derive(Debug)]
40pub struct MockComm {
41    /// This rank's ID in the mesh
42    pub rank: usize,
43    /// Total number of ranks
44    pub world_size: usize,
45    /// Shared state for coordinating between simulated ranks
46    state: Arc<RwLock<MockCommState>>,
47}
48
49impl MockComm {
50    /// Create a new MockComm for a specific rank
51    pub fn new(rank: usize, world_size: usize) -> Self {
52        Self {
53            rank,
54            world_size,
55            state: Arc::new(RwLock::new(MockCommState::new(world_size))),
56        }
57    }
58
59    /// Create a new MockComm with shared state (for multi-rank simulation)
60    pub fn with_shared_state(rank: usize, state: Arc<RwLock<MockCommState>>) -> Self {
61        let world_size = state.read().unwrap().world_size;
62        Self {
63            rank,
64            world_size,
65            state,
66        }
67    }
68
69    /// Create a group of MockComm instances that share state
70    pub fn create_group(world_size: usize) -> Vec<Self> {
71        let state = Arc::new(RwLock::new(MockCommState::new(world_size)));
72        (0..world_size)
73            .map(|rank| Self::with_shared_state(rank, state.clone()))
74            .collect()
75    }
76
77    /// Helper to apply a reduction operation on f32 values
78    fn reduce_f32(values: &[f32], op: ReduceOp) -> f32 {
79        match op {
80            ReduceOp::Sum => values.iter().sum(),
81            ReduceOp::Product => values.iter().product(),
82            ReduceOp::Min => values.iter().cloned().fold(f32::INFINITY, f32::min),
83            ReduceOp::Max => values.iter().cloned().fold(f32::NEG_INFINITY, f32::max),
84            ReduceOp::Average => values.iter().sum::<f32>() / values.len() as f32,
85        }
86    }
87
88    /// Helper to apply a reduction operation on i32 values
89    fn reduce_i32(values: &[i32], op: ReduceOp) -> i32 {
90        match op {
91            ReduceOp::Sum => values.iter().sum(),
92            ReduceOp::Product => values.iter().product(),
93            ReduceOp::Min => values.iter().cloned().min().unwrap_or(0),
94            ReduceOp::Max => values.iter().cloned().max().unwrap_or(0),
95            ReduceOp::Average => values.iter().sum::<i32>() / values.len().max(1) as i32,
96        }
97    }
98}
99
100impl MeshComm for MockComm {
101    fn all_reduce(
102        &self,
103        buf: &mut [u8],
104        dtype: DType,
105        op: ReduceOp,
106        _group: &str,
107    ) -> Result<(), String> {
108        // In single-rank mode, the buffer is already the result
109        if self.world_size == 1 {
110            return Ok(());
111        }
112
113        // Store our data in shared state
114        {
115            let mut state = self.state.write().map_err(|e| e.to_string())?;
116            state.rank_buffers.insert(self.rank, buf.to_vec());
117        }
118
119        // Wait for all ranks (in real impl, this would be a barrier)
120        // For mock, we simulate by checking if all buffers are present
121        let all_buffers: Vec<Vec<u8>> = {
122            let state = self.state.read().map_err(|e| e.to_string())?;
123            if state.rank_buffers.len() < self.world_size {
124                // Not all ranks have contributed yet - in real impl we'd wait
125                // For mock, just use our own buffer
126                return Ok(());
127            }
128            (0..self.world_size)
129                .filter_map(|r| state.rank_buffers.get(&r).cloned())
130                .collect()
131        };
132
133        // Perform reduction based on dtype
134        match dtype {
135            DType::Float32 => {
136                let elem_count = buf.len() / 4;
137                for i in 0..elem_count {
138                    let values: Vec<f32> = all_buffers
139                        .iter()
140                        .map(|b| {
141                            let bytes: [u8; 4] = b[i * 4..(i + 1) * 4].try_into().unwrap();
142                            f32::from_le_bytes(bytes)
143                        })
144                        .collect();
145                    let result = Self::reduce_f32(&values, op);
146                    buf[i * 4..(i + 1) * 4].copy_from_slice(&result.to_le_bytes());
147                }
148            }
149            DType::Int32 => {
150                let elem_count = buf.len() / 4;
151                for i in 0..elem_count {
152                    let values: Vec<i32> = all_buffers
153                        .iter()
154                        .map(|b| {
155                            let bytes: [u8; 4] = b[i * 4..(i + 1) * 4].try_into().unwrap();
156                            i32::from_le_bytes(bytes)
157                        })
158                        .collect();
159                    let result = Self::reduce_i32(&values, op);
160                    buf[i * 4..(i + 1) * 4].copy_from_slice(&result.to_le_bytes());
161                }
162            }
163            _ => {
164                // For other dtypes, just keep our buffer (no-op reduction)
165            }
166        }
167
168        Ok(())
169    }
170
171    fn all_gather(
172        &self,
173        local: &[u8],
174        out: &mut [u8],
175        _dtype: DType,
176        _group: &str,
177    ) -> Result<(), String> {
178        // In single-rank mode, just copy local to output
179        if self.world_size == 1 {
180            out[..local.len()].copy_from_slice(local);
181            return Ok(());
182        }
183
184        // Store our local data
185        {
186            let mut state = self.state.write().map_err(|e| e.to_string())?;
187            state.rank_buffers.insert(self.rank, local.to_vec());
188        }
189
190        // Gather from all ranks
191        let chunk_size = local.len();
192        let state = self.state.read().map_err(|e| e.to_string())?;
193
194        for rank in 0..self.world_size {
195            let offset = rank * chunk_size;
196            if let Some(data) = state.rank_buffers.get(&rank) {
197                let copy_len = data.len().min(chunk_size);
198                out[offset..offset + copy_len].copy_from_slice(&data[..copy_len]);
199            }
200        }
201
202        Ok(())
203    }
204
205    fn broadcast(&self, buf: &mut [u8], root_rank: usize, _group: &str) -> Result<(), String> {
206        if self.rank == root_rank {
207            // Root stores its data
208            let mut state = self.state.write().map_err(|e| e.to_string())?;
209            state.rank_buffers.insert(root_rank, buf.to_vec());
210        } else {
211            // Non-root reads from root's buffer
212            let state = self.state.read().map_err(|e| e.to_string())?;
213            if let Some(root_data) = state.rank_buffers.get(&root_rank) {
214                let copy_len = root_data.len().min(buf.len());
215                buf[..copy_len].copy_from_slice(&root_data[..copy_len]);
216            }
217        }
218        Ok(())
219    }
220
221    fn reduce_scatter(
222        &self,
223        buf: &mut [u8],
224        out: &mut [u8],
225        op: ReduceOp,
226        group: &str,
227    ) -> Result<(), String> {
228        // First do all-reduce
229        self.all_reduce(buf, DType::Float32, op, group)?;
230
231        // Then scatter - each rank gets its chunk
232        let chunk_size = buf.len() / self.world_size;
233        let offset = self.rank * chunk_size;
234        let copy_len = chunk_size.min(out.len());
235        out[..copy_len].copy_from_slice(&buf[offset..offset + copy_len]);
236
237        Ok(())
238    }
239
240    fn barrier(&self, _group: &str) -> Result<(), String> {
241        // Increment barrier counter
242        let mut state = self.state.write().map_err(|e| e.to_string())?;
243        state.barrier_count += 1;
244        // In a real implementation, we'd wait for all ranks to reach the barrier
245        Ok(())
246    }
247
248    fn send(&self, buf: &[u8], dest_rank: usize) -> Result<(), String> {
249        if dest_rank >= self.world_size {
250            return Err(format!("Invalid dest rank {}", dest_rank));
251        }
252        // Store in shared state with a special key
253        let key = self.rank * 1000 + dest_rank; // Unique key for src->dst
254        let mut state = self.state.write().map_err(|e| e.to_string())?;
255        state.rank_buffers.insert(key, buf.to_vec());
256        Ok(())
257    }
258
259    fn recv(&self, buf: &mut [u8], src_rank: usize) -> Result<(), String> {
260        if src_rank >= self.world_size {
261            return Err(format!("Invalid src rank {}", src_rank));
262        }
263        let key = src_rank * 1000 + self.rank;
264        let state = self.state.read().map_err(|e| e.to_string())?;
265        if let Some(data) = state.rank_buffers.get(&key) {
266            let copy_len = data.len().min(buf.len());
267            buf[..copy_len].copy_from_slice(&data[..copy_len]);
268            Ok(())
269        } else {
270            Err(format!("No data from rank {}", src_rank))
271        }
272    }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `values` as consecutive little-endian `f32` words.
    fn f32_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    /// Decode consecutive little-endian `f32` words.
    fn f32_values(bytes: &[u8]) -> Vec<f32> {
        bytes
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect()
    }

    #[test]
    fn test_mock_comm_single_rank() {
        let comm = MockComm::new(0, 1);
        let mut byte_buf = f32_bytes(&[1.0, 2.0, 3.0, 4.0]);

        comm.all_reduce(&mut byte_buf, DType::Float32, ReduceOp::Sum, "world")
            .unwrap();

        // Single rank, buffer unchanged
        assert_eq!(f32_values(&byte_buf), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_mock_comm_all_reduce_two_ranks() {
        let comms = MockComm::create_group(2);

        let mut buf0 = f32_bytes(&[1.0, 2.0]);
        let mut buf1 = f32_bytes(&[3.0, 4.0]);

        // Rank 0 arrives first: rank 1 has not contributed yet, so the mock
        // leaves rank 0's buffer as it is.
        comms[0]
            .all_reduce(&mut buf0, DType::Float32, ReduceOp::Sum, "world")
            .unwrap();
        assert_eq!(f32_values(&buf0), vec![1.0, 2.0]);

        // Rank 1 arrives last and sees both contributions.
        comms[1]
            .all_reduce(&mut buf1, DType::Float32, ReduceOp::Sum, "world")
            .unwrap();
        assert_eq!(f32_values(&buf1), vec![4.0, 6.0]);
    }

    #[test]
    fn test_mock_comm_broadcast() {
        let comms = MockComm::create_group(2);

        // Root broadcasts
        let mut root_buf = f32_bytes(&[42.0, 24.0]);
        comms[0].broadcast(&mut root_buf, 0, "world").unwrap();

        // Non-root receives
        let mut recv_buf = vec![0u8; 8];
        comms[1].broadcast(&mut recv_buf, 0, "world").unwrap();

        assert_eq!(root_buf, recv_buf);
    }

    #[test]
    fn test_mock_comm_send_recv() {
        let comms = MockComm::create_group(2);

        let send_data = vec![1u8, 2, 3, 4];
        comms[0].send(&send_data, 1).unwrap();

        let mut recv_buf = vec![0u8; 4];
        comms[1].recv(&mut recv_buf, 0).unwrap();

        assert_eq!(recv_buf, send_data);
    }

    #[test]
    fn test_mock_comm_recv_without_send() {
        let comms = MockComm::create_group(2);

        // Nothing was sent from rank 0 to rank 1, so receiving must fail
        let mut recv_buf = vec![0u8; 4];
        assert!(comms[1].recv(&mut recv_buf, 0).is_err());
        assert_eq!(recv_buf, vec![0u8; 4]);
    }

    #[test]
    fn test_mock_comm_barrier() {
        let comm = MockComm::new(0, 4);
        assert!(comm.barrier("world").is_ok());
    }

    #[test]
    fn test_mock_comm_all_gather() {
        let comms = MockComm::create_group(2);

        // Each rank contributes its data
        let local0 = vec![1u8, 2];
        let local1 = vec![3u8, 4];

        let mut out0 = vec![0u8; 4];
        let mut out1 = vec![0u8; 4];

        comms[0]
            .all_gather(&local0, &mut out0, DType::UInt8, "world")
            .unwrap();
        comms[1]
            .all_gather(&local1, &mut out1, DType::UInt8, "world")
            .unwrap();

        // Rank 0 gathered before rank 1 had contributed: only its own slot is set
        assert_eq!(out0, vec![1u8, 2, 0, 0]);
        // Rank 1 gathered last and sees both contributions
        assert_eq!(out1, vec![1u8, 2, 3, 4]);
    }
}