1use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10
/// Reduction operator applied element-wise by collective operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    /// Element-wise addition.
    Sum,
    /// Element-wise multiplication.
    Product,
    /// Element-wise minimum.
    Min,
    /// Element-wise maximum.
    Max,
    /// Arithmetic mean across all participating slices.
    Average,
}

impl ReduceOp {
    /// Combines two scalars with this operation.
    ///
    /// Note: for `Average` this is the pairwise mean `(a + b) / 2.0`, which
    /// is only the true mean of exactly two values. `reduce_slices` handles
    /// the n-way mean correctly by summing first and dividing once.
    #[must_use]
    pub fn apply_f32(&self, a: f32, b: f32) -> f32 {
        match self {
            ReduceOp::Sum => a + b,
            ReduceOp::Product => a * b,
            ReduceOp::Min => a.min(b),
            ReduceOp::Max => a.max(b),
            ReduceOp::Average => (a + b) / 2.0,
        }
    }

    /// Element-wise reduction of `slices` into a single vector.
    ///
    /// The output length equals the first slice's length; longer slices are
    /// truncated to it and positions past a shorter slice's end keep their
    /// accumulated value. Returns an empty vector when `slices` is empty.
    #[must_use]
    pub fn reduce_slices(&self, slices: &[Vec<f32>]) -> Vec<f32> {
        if slices.is_empty() {
            return Vec::new();
        }

        let len = slices[0].len();

        // `Average` cannot be expressed as a pairwise `apply_f32` fold, so it
        // gets a dedicated sum-then-divide pass. (The previous version first
        // ran the pairwise fold anyway and then threw that work away.)
        if *self == ReduceOp::Average {
            let mut result = vec![0.0; len];
            for slice in slices {
                // zip truncates to the shorter side, matching the old
                // `i < len` bounds guard.
                for (acc, &val) in result.iter_mut().zip(slice.iter()) {
                    *acc += val;
                }
            }
            let count = slices.len() as f32;
            for val in &mut result {
                *val /= count;
            }
            return result;
        }

        // Fold every remaining slice into a copy of the first.
        let mut result = slices[0].clone();
        for slice in slices.iter().skip(1) {
            for (i, &val) in slice.iter().enumerate() {
                if i < len {
                    result[i] = self.apply_f32(result[i], val);
                }
            }
        }

        result
    }
}
81
/// Abstraction over a communication backend providing MPI-style collective
/// and point-to-point operations on `f32` buffers.
pub trait Backend: Send + Sync {
    /// Human-readable backend identifier (e.g. for logging).
    fn name(&self) -> &str;

    /// This process's rank, expected to lie in `0..world_size`.
    fn rank(&self) -> usize;

    /// Total number of ranks participating in the world.
    fn world_size(&self) -> usize;

    /// Reduces `data` element-wise across all ranks with `op`, leaving the
    /// result in `data` on every rank.
    fn all_reduce(&self, data: &mut [f32], op: ReduceOp);

    /// Copies rank `src`'s `data` into every rank's `data`.
    fn broadcast(&self, data: &mut [f32], src: usize);

    /// Concatenates every rank's `send_data` (in rank order) into each
    /// rank's `recv_data`.
    fn all_gather(&self, send_data: &[f32], recv_data: &mut [f32]);

    /// Reduces `send_data` across ranks with `op`, then scatters one chunk
    /// of the result (this rank's share) into `recv_data`.
    fn reduce_scatter(&self, send_data: &[f32], recv_data: &mut [f32], op: ReduceOp);

    /// Collects every rank's `send_data` into `recv_data` on rank `dst` only.
    fn gather(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize);

    /// Distributes chunks of rank `src`'s `send_data` to each rank's
    /// `recv_data`.
    fn scatter(&self, send_data: &[f32], recv_data: &mut [f32], src: usize);

    /// Reduces `send_data` across ranks with `op`, leaving the result in
    /// `recv_data` on rank `dst` only.
    fn reduce(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize, op: ReduceOp);

    /// Synchronization point across all ranks.
    fn barrier(&self);

    /// Point-to-point send of `data` to rank `dst`, matched by `tag`.
    fn send(&self, data: &[f32], dst: usize, tag: usize);

    /// Point-to-point receive into `data` from rank `src`, matched by `tag`.
    fn recv(&self, data: &mut [f32], src: usize, tag: usize);
}
127
/// State shared by every rank of a mock world; always accessed under the
/// single `Mutex` in `MockBackend::state`.
#[derive(Debug)]
struct SharedState {
    // Per-rank staging buffers for collective operations, keyed by rank.
    // NOTE(review): `broadcast` also stages under the fixed key 0, which can
    // collide with rank 0's slot — confirm collectives are never interleaved.
    buffers: HashMap<usize, Vec<f32>>,
    // Number of ranks that have reached the current barrier.
    barrier_count: usize,
    // Point-to-point payloads keyed by (source rank, destination rank, tag).
    messages: HashMap<(usize, usize, usize), Vec<f32>>, }
142
/// In-process mock implementation of [`Backend`] whose ranks exchange data
/// through a shared, mutex-guarded `SharedState` instead of a network.
pub struct MockBackend {
    // This instance's rank within the mock world.
    rank: usize,
    // Total number of ranks in the mock world.
    world_size: usize,
    // State shared with every other rank produced by `create_world`.
    state: Arc<Mutex<SharedState>>,
}
154
155impl MockBackend {
156 #[must_use]
158 pub fn create_world(world_size: usize) -> Vec<Self> {
159 let state = Arc::new(Mutex::new(SharedState {
160 buffers: HashMap::new(),
161 barrier_count: 0,
162 messages: HashMap::new(),
163 }));
164
165 (0..world_size)
166 .map(|rank| MockBackend {
167 rank,
168 world_size,
169 state: Arc::clone(&state),
170 })
171 .collect()
172 }
173
174 #[must_use]
176 pub fn single() -> Self {
177 MockBackend::create_world(1).pop().unwrap()
178 }
179}
180
impl Backend for MockBackend {
    // Returning `&'static str` is a valid (lifetime-subtyped) impl of the
    // trait's `&str` return.
    fn name(&self) -> &'static str {
        "mock"
    }

    fn rank(&self) -> usize {
        self.rank
    }

    fn world_size(&self) -> usize {
        self.world_size
    }

    // Lock-step mock all-reduce: each caller deposits its buffer keyed by
    // rank; whichever caller completes the set reduces all buffers, fans the
    // result back out to every slot, reads its own copy, and clears the
    // staging area. NOTE(review): ranks that call before the set is complete
    // read back their own un-reduced data — only the final caller observes
    // the reduced result. Confirm callers rely on this only for
    // world_size == 1 or tolerate the mock semantics.
    fn all_reduce(&self, data: &mut [f32], op: ReduceOp) {
        let mut state = self.state.lock().unwrap();

        state.buffers.insert(self.rank, data.to_vec());

        if state.buffers.len() == self.world_size {
            // Collect every rank's contribution in rank order; a missing
            // rank contributes an empty vector.
            let all_data: Vec<Vec<f32>> = (0..self.world_size)
                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
                .collect();

            let reduced = op.reduce_slices(&all_data);

            // Fan the reduced vector back out to every rank's slot.
            for r in 0..self.world_size {
                state.buffers.insert(r, reduced.clone());
            }
        }

        // Copy whatever is currently in our slot into the caller's buffer.
        if let Some(result) = state.buffers.get(&self.rank) {
            for (i, &val) in result.iter().enumerate() {
                if i < data.len() {
                    data[i] = val;
                }
            }
        }

        // Reset the staging area once the collective has completed.
        if state.buffers.len() == self.world_size {
            state.buffers.clear();
        }
    }

    // NOTE(review): the source's data is staged under the fixed key 0 (not
    // `src`), so only one broadcast can be in flight at a time, the `src`
    // rank must call before any receiver, and the staged entry is never
    // cleared afterwards — verify against callers' ordering.
    fn broadcast(&self, data: &mut [f32], src: usize) {
        let mut state = self.state.lock().unwrap();

        if self.rank == src {
            state.buffers.insert(0, data.to_vec());
        }

        // Every rank (including src) copies the staged data, length-clamped.
        if let Some(src_data) = state.buffers.get(&0) {
            for (i, &val) in src_data.iter().enumerate() {
                if i < data.len() {
                    data[i] = val;
                }
            }
        }
    }

    // Each rank deposits `send_data`; the caller that completes the set
    // writes every rank's chunk into its own `recv_data` at offset
    // rank * chunk_size. NOTE(review): earlier callers' `recv_data` is left
    // untouched, and the staging buffers are never cleared here.
    fn all_gather(&self, send_data: &[f32], recv_data: &mut [f32]) {
        let mut state = self.state.lock().unwrap();

        state.buffers.insert(self.rank, send_data.to_vec());

        if state.buffers.len() == self.world_size {
            // Chunk size is assumed uniform across ranks (this rank's len).
            let chunk_size = send_data.len();
            for r in 0..self.world_size {
                if let Some(data) = state.buffers.get(&r) {
                    let start = r * chunk_size;
                    for (i, &val) in data.iter().enumerate() {
                        if start + i < recv_data.len() {
                            recv_data[start + i] = val;
                        }
                    }
                }
            }
        }
    }

    // Reduce across all deposited buffers, then copy only this rank's chunk
    // of the result (at offset rank * recv_data.len()) into `recv_data`.
    // Same lock-step caveat as `all_reduce`: only the completing caller
    // receives output.
    fn reduce_scatter(&self, send_data: &[f32], recv_data: &mut [f32], op: ReduceOp) {
        let mut state = self.state.lock().unwrap();

        state.buffers.insert(self.rank, send_data.to_vec());

        if state.buffers.len() == self.world_size {
            let all_data: Vec<Vec<f32>> = (0..self.world_size)
                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
                .collect();

            let reduced = op.reduce_slices(&all_data);

            // This rank's share of the reduced vector, clamped to bounds.
            let chunk_size = recv_data.len();
            let start = self.rank * chunk_size;
            let end = (start + chunk_size).min(reduced.len());

            for (i, &val) in reduced[start..end].iter().enumerate() {
                if i < recv_data.len() {
                    recv_data[i] = val;
                }
            }
        }
    }

    // Like `all_gather`, but only rank `dst` assembles the result — and only
    // when it is the caller that completes the set.
    fn gather(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize) {
        let mut state = self.state.lock().unwrap();

        state.buffers.insert(self.rank, send_data.to_vec());

        if self.rank == dst && state.buffers.len() == self.world_size {
            let chunk_size = send_data.len();
            for r in 0..self.world_size {
                if let Some(data) = state.buffers.get(&r) {
                    let start = r * chunk_size;
                    for (i, &val) in data.iter().enumerate() {
                        if start + i < recv_data.len() {
                            recv_data[start + i] = val;
                        }
                    }
                }
            }
        }
    }

    // NOTE(review): only the `src` rank receives its own chunk here; no data
    // is staged for other ranks, so non-src callers get nothing. The lock is
    // held for the whole copy (explicit `drop`) purely for serialization.
    fn scatter(&self, send_data: &[f32], recv_data: &mut [f32], src: usize) {
        let state = self.state.lock().unwrap();

        if self.rank == src {
            // src copies its own chunk: send_data[src*chunk .. +chunk].
            let chunk_size = recv_data.len();
            let start = self.rank * chunk_size;
            let end = (start + chunk_size).min(send_data.len());

            for (i, &val) in send_data[start..end].iter().enumerate() {
                recv_data[i] = val;
            }
        }
        drop(state);

    }

    // Like `all_reduce`, but the full reduced vector is written only into
    // rank `dst`'s `recv_data`, and only when `dst` completes the set.
    fn reduce(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize, op: ReduceOp) {
        let mut state = self.state.lock().unwrap();

        state.buffers.insert(self.rank, send_data.to_vec());

        if self.rank == dst && state.buffers.len() == self.world_size {
            let all_data: Vec<Vec<f32>> = (0..self.world_size)
                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
                .collect();

            let reduced = op.reduce_slices(&all_data);

            for (i, &val) in reduced.iter().enumerate() {
                if i < recv_data.len() {
                    recv_data[i] = val;
                }
            }
        }
    }

    // Counts arrivals and resets the counter once all ranks have checked in.
    // NOTE(review): this does not actually block the caller — it is a
    // bookkeeping-only barrier suitable for single-threaded tests.
    fn barrier(&self) {
        let mut state = self.state.lock().unwrap();
        state.barrier_count += 1;

        if state.barrier_count == self.world_size {
            state.barrier_count = 0;
        }
    }

    // Stores the payload under (sender, receiver, tag). A second send with
    // the same key overwrites the first (no queueing).
    fn send(&self, data: &[f32], dst: usize, tag: usize) {
        let mut state = self.state.lock().unwrap();
        state.messages.insert((self.rank, dst, tag), data.to_vec());
    }

    // Removes and copies the matching message if one was sent; silently a
    // no-op (recv does not block) when no message is pending.
    fn recv(&self, data: &mut [f32], src: usize, tag: usize) {
        let mut state = self.state.lock().unwrap();
        if let Some(msg) = state.messages.remove(&(src, self.rank, tag)) {
            for (i, &val) in msg.iter().enumerate() {
                if i < data.len() {
                    data[i] = val;
                }
            }
        }
    }
}
388
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reduce_op_sum() {
        assert_eq!(ReduceOp::Sum.apply_f32(1.0, 2.0), 3.0);
    }

    #[test]
    fn test_reduce_op_product() {
        assert_eq!(ReduceOp::Product.apply_f32(2.0, 3.0), 6.0);
    }

    #[test]
    fn test_reduce_op_min() {
        assert_eq!(ReduceOp::Min.apply_f32(2.0, 3.0), 2.0);
    }

    #[test]
    fn test_reduce_op_max() {
        assert_eq!(ReduceOp::Max.apply_f32(2.0, 3.0), 3.0);
    }

    #[test]
    fn test_reduce_slices_sum() {
        let input = [vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        assert_eq!(ReduceOp::Sum.reduce_slices(&input), vec![9.0, 12.0]);
    }

    #[test]
    fn test_reduce_slices_average() {
        let input = [vec![1.0, 2.0], vec![3.0, 4.0]];
        assert_eq!(ReduceOp::Average.reduce_slices(&input), vec![2.0, 3.0]);
    }

    #[test]
    fn test_mock_backend_single() {
        let backend = MockBackend::single();
        assert_eq!(backend.rank(), 0);
        assert_eq!(backend.world_size(), 1);
        assert_eq!(backend.name(), "mock");
    }

    #[test]
    fn test_mock_backend_world() {
        let world = MockBackend::create_world(4);
        assert_eq!(world.len(), 4);

        for (expected_rank, backend) in world.iter().enumerate() {
            assert_eq!(backend.rank(), expected_rank);
            assert_eq!(backend.world_size(), 4);
        }
    }

    #[test]
    fn test_mock_all_reduce() {
        // With a single rank, all-reduce is the identity.
        let backend = MockBackend::single();

        let mut values = vec![1.0, 2.0];
        backend.all_reduce(&mut values, ReduceOp::Sum);

        assert_eq!(values, vec![1.0, 2.0]);
    }

    #[test]
    fn test_mock_broadcast() {
        let world = MockBackend::create_world(2);

        let mut root_buf = vec![1.0, 2.0, 3.0];
        let mut leaf_buf = vec![0.0, 0.0, 0.0];

        // The source rank must call first so its data is staged.
        world[0].broadcast(&mut root_buf, 0);
        world[1].broadcast(&mut leaf_buf, 0);

        assert_eq!(root_buf, vec![1.0, 2.0, 3.0]);
        assert_eq!(leaf_buf, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mock_send_recv() {
        let world = MockBackend::create_world(2);

        let payload = vec![1.0, 2.0, 3.0];
        world[0].send(&payload, 1, 0);

        let mut received = vec![0.0, 0.0, 0.0];
        world[1].recv(&mut received, 0, 0);

        assert_eq!(received, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mock_barrier() {
        // Smoke test: the mock barrier must neither panic nor block.
        let world = MockBackend::create_world(2);

        world[0].barrier();
        world[1].barrier();
    }
}