1use std::collections::HashMap;
19use std::sync::{Arc, Mutex};
20
/// Reduction operation applied when combining buffers across ranks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    /// Element-wise addition.
    Sum,
    /// Element-wise multiplication.
    Product,
    /// Element-wise minimum.
    Min,
    /// Element-wise maximum.
    Max,
    /// Element-wise arithmetic mean.
    Average,
}

impl ReduceOp {
    /// Combines two scalars according to this operation.
    ///
    /// For `Average` this is the midpoint of the pair; averaging more than
    /// two values should go through [`ReduceOp::reduce_slices`] instead.
    #[must_use]
    pub fn apply_f32(&self, a: f32, b: f32) -> f32 {
        match self {
            ReduceOp::Sum => a + b,
            ReduceOp::Product => a * b,
            ReduceOp::Min => a.min(b),
            ReduceOp::Max => a.max(b),
            ReduceOp::Average => f32::midpoint(a, b),
        }
    }

    /// Reduces several slices element-wise into a single vector.
    ///
    /// The output length is taken from the first slice; longer inputs are
    /// truncated to that length and shorter ones contribute only the
    /// elements they have. An empty input yields an empty vector.
    #[must_use]
    pub fn reduce_slices(&self, slices: &[Vec<f32>]) -> Vec<f32> {
        let Some(first) = slices.first() else {
            return Vec::new();
        };
        let len = first.len();

        // `Average` is not expressible as a pairwise fold, so accumulate a
        // running sum and divide by the number of slices at the end.
        if *self == ReduceOp::Average {
            let mut acc = vec![0.0f32; len];
            for slice in slices {
                for (dst, &src) in acc.iter_mut().zip(slice.iter()) {
                    *dst += src;
                }
            }
            let count = slices.len() as f32;
            for v in &mut acc {
                *v /= count;
            }
            return acc;
        }

        // All remaining ops fold pairwise over the tail of the input.
        let mut acc = first.clone();
        for slice in &slices[1..] {
            for (dst, &src) in acc.iter_mut().zip(slice.iter()) {
                *dst = self.apply_f32(*dst, src);
            }
        }
        acc
    }
}
92
/// Abstraction over a collective-communication layer (implemented here by
/// [`MockBackend`]). All operations work on flat `f32` buffers.
pub trait Backend: Send + Sync {
    /// Short human-readable backend identifier.
    fn name(&self) -> &str;

    /// This process's rank, in `[0, world_size)`.
    fn rank(&self) -> usize;

    /// Total number of participating ranks.
    fn world_size(&self) -> usize;

    /// Reduces `data` element-wise across all ranks with `op`, in place.
    fn all_reduce(&self, data: &mut [f32], op: ReduceOp);

    /// Copies rank `src`'s `data` into every rank's buffer, in place.
    fn broadcast(&self, data: &mut [f32], src: usize);

    /// Concatenates every rank's `send_data` into `recv_data`, ordered by rank.
    fn all_gather(&self, send_data: &[f32], recv_data: &mut [f32]);

    /// Reduces `send_data` across ranks with `op`, then gives each rank one
    /// `recv_data.len()`-sized chunk of the result.
    fn reduce_scatter(&self, send_data: &[f32], recv_data: &mut [f32], op: ReduceOp);

    /// Concatenates every rank's `send_data` into `recv_data` on rank `dst` only.
    fn gather(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize);

    /// Distributes rank-sized chunks of rank `src`'s `send_data`; each rank
    /// receives its own chunk in `recv_data`.
    fn scatter(&self, send_data: &[f32], recv_data: &mut [f32], src: usize);

    /// Reduces `send_data` across ranks with `op`; only rank `dst` receives
    /// the result in `recv_data`.
    fn reduce(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize, op: ReduceOp);

    /// Synchronization point across ranks.
    fn barrier(&self);

    /// Point-to-point send of `data` to rank `dst`, matched to a receive by `tag`.
    fn send(&self, data: &[f32], dst: usize, tag: usize);

    /// Point-to-point receive into `data` from rank `src`, matched by `tag`.
    fn recv(&self, data: &mut [f32], src: usize, tag: usize);
}
138
/// State shared by every rank of a [`MockBackend`] world.
#[derive(Debug)]
struct SharedState {
    // Per-rank staging buffers used by the collective operations.
    buffers: HashMap<usize, Vec<f32>>,
    // Number of ranks that have reached the barrier so far.
    barrier_count: usize,
    // In-flight point-to-point messages keyed by (src, dst, tag).
    messages: HashMap<(usize, usize, usize), Vec<f32>>,
}

/// In-process stand-in for a real communication backend.
///
/// Every rank in a world built by [`MockBackend::create_world`] holds a
/// handle to the same mutex-guarded [`SharedState`], which the collective
/// operations use as a staging area.
pub struct MockBackend {
    rank: usize,
    world_size: usize,
    state: Arc<Mutex<SharedState>>,
}

impl MockBackend {
    /// Builds `world_size` backends (ranks `0..world_size`) that all share
    /// one state, so collectives between them can be simulated.
    #[must_use]
    pub fn create_world(world_size: usize) -> Vec<Self> {
        let shared = Arc::new(Mutex::new(SharedState {
            buffers: HashMap::new(),
            barrier_count: 0,
            messages: HashMap::new(),
        }));

        let mut ranks = Vec::with_capacity(world_size);
        for rank in 0..world_size {
            ranks.push(MockBackend {
                rank,
                world_size,
                state: Arc::clone(&shared),
            });
        }
        ranks
    }

    /// Convenience constructor for a one-rank world.
    #[must_use]
    pub fn single() -> Self {
        MockBackend::create_world(1).pop().unwrap()
    }
}
191
192impl Backend for MockBackend {
193 fn name(&self) -> &'static str {
194 "mock"
195 }
196
197 fn rank(&self) -> usize {
198 self.rank
199 }
200
201 fn world_size(&self) -> usize {
202 self.world_size
203 }
204
205 fn all_reduce(&self, data: &mut [f32], op: ReduceOp) {
206 let mut state = self.state.lock().unwrap();
207
208 state.buffers.insert(self.rank, data.to_vec());
210
211 if state.buffers.len() == self.world_size {
213 let all_data: Vec<Vec<f32>> = (0..self.world_size)
215 .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
216 .collect();
217
218 let reduced = op.reduce_slices(&all_data);
219
220 for r in 0..self.world_size {
222 state.buffers.insert(r, reduced.clone());
223 }
224 }
225
226 if let Some(result) = state.buffers.get(&self.rank) {
228 for (i, &val) in result.iter().enumerate() {
229 if i < data.len() {
230 data[i] = val;
231 }
232 }
233 }
234
235 if state.buffers.len() == self.world_size {
237 state.buffers.clear();
238 }
239 }
240
241 fn broadcast(&self, data: &mut [f32], src: usize) {
242 let mut state = self.state.lock().unwrap();
243
244 if self.rank == src {
245 state.buffers.insert(0, data.to_vec());
247 }
248
249 if let Some(src_data) = state.buffers.get(&0) {
251 for (i, &val) in src_data.iter().enumerate() {
252 if i < data.len() {
253 data[i] = val;
254 }
255 }
256 }
257 }
258
259 fn all_gather(&self, send_data: &[f32], recv_data: &mut [f32]) {
260 let mut state = self.state.lock().unwrap();
261
262 state.buffers.insert(self.rank, send_data.to_vec());
264
265 if state.buffers.len() == self.world_size {
267 let chunk_size = send_data.len();
269 for r in 0..self.world_size {
270 if let Some(data) = state.buffers.get(&r) {
271 let start = r * chunk_size;
272 for (i, &val) in data.iter().enumerate() {
273 if start + i < recv_data.len() {
274 recv_data[start + i] = val;
275 }
276 }
277 }
278 }
279 }
280 }
281
282 fn reduce_scatter(&self, send_data: &[f32], recv_data: &mut [f32], op: ReduceOp) {
283 let mut state = self.state.lock().unwrap();
284
285 state.buffers.insert(self.rank, send_data.to_vec());
287
288 if state.buffers.len() == self.world_size {
290 let all_data: Vec<Vec<f32>> = (0..self.world_size)
292 .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
293 .collect();
294
295 let reduced = op.reduce_slices(&all_data);
296
297 let chunk_size = recv_data.len();
299 let start = self.rank * chunk_size;
300 let end = (start + chunk_size).min(reduced.len());
301
302 for (i, &val) in reduced[start..end].iter().enumerate() {
303 if i < recv_data.len() {
304 recv_data[i] = val;
305 }
306 }
307 }
308 }
309
310 fn gather(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize) {
311 let mut state = self.state.lock().unwrap();
312
313 state.buffers.insert(self.rank, send_data.to_vec());
315
316 if self.rank == dst && state.buffers.len() == self.world_size {
318 let chunk_size = send_data.len();
319 for r in 0..self.world_size {
320 if let Some(data) = state.buffers.get(&r) {
321 let start = r * chunk_size;
322 for (i, &val) in data.iter().enumerate() {
323 if start + i < recv_data.len() {
324 recv_data[start + i] = val;
325 }
326 }
327 }
328 }
329 }
330 }
331
332 fn scatter(&self, send_data: &[f32], recv_data: &mut [f32], src: usize) {
333 let state = self.state.lock().unwrap();
334
335 if self.rank == src {
337 let chunk_size = recv_data.len();
339 let start = self.rank * chunk_size;
340 let end = (start + chunk_size).min(send_data.len());
341
342 for (i, &val) in send_data[start..end].iter().enumerate() {
343 recv_data[i] = val;
344 }
345 }
346 drop(state);
347
348 }
350
351 fn reduce(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize, op: ReduceOp) {
352 let mut state = self.state.lock().unwrap();
353
354 state.buffers.insert(self.rank, send_data.to_vec());
356
357 if self.rank == dst && state.buffers.len() == self.world_size {
359 let all_data: Vec<Vec<f32>> = (0..self.world_size)
360 .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
361 .collect();
362
363 let reduced = op.reduce_slices(&all_data);
364
365 for (i, &val) in reduced.iter().enumerate() {
366 if i < recv_data.len() {
367 recv_data[i] = val;
368 }
369 }
370 }
371 }
372
373 fn barrier(&self) {
374 let mut state = self.state.lock().unwrap();
375 state.barrier_count += 1;
376
377 if state.barrier_count == self.world_size {
379 state.barrier_count = 0;
380 }
381 }
382
383 fn send(&self, data: &[f32], dst: usize, tag: usize) {
384 let mut state = self.state.lock().unwrap();
385 state.messages.insert((self.rank, dst, tag), data.to_vec());
386 }
387
388 fn recv(&self, data: &mut [f32], src: usize, tag: usize) {
389 let mut state = self.state.lock().unwrap();
390 if let Some(msg) = state.messages.remove(&(src, self.rank, tag)) {
391 for (i, &val) in msg.iter().enumerate() {
392 if i < data.len() {
393 data[i] = val;
394 }
395 }
396 }
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reduce_op_sum() {
        assert_eq!(ReduceOp::Sum.apply_f32(1.0, 2.0), 3.0);
    }

    #[test]
    fn test_reduce_op_product() {
        assert_eq!(ReduceOp::Product.apply_f32(2.0, 3.0), 6.0);
    }

    #[test]
    fn test_reduce_op_min() {
        assert_eq!(ReduceOp::Min.apply_f32(2.0, 3.0), 2.0);
    }

    #[test]
    fn test_reduce_op_max() {
        assert_eq!(ReduceOp::Max.apply_f32(2.0, 3.0), 3.0);
    }

    #[test]
    fn test_reduce_slices_sum() {
        let inputs = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        assert_eq!(ReduceOp::Sum.reduce_slices(&inputs), vec![9.0, 12.0]);
    }

    #[test]
    fn test_reduce_slices_average() {
        let inputs = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        assert_eq!(ReduceOp::Average.reduce_slices(&inputs), vec![2.0, 3.0]);
    }

    #[test]
    fn test_mock_backend_single() {
        let solo = MockBackend::single();
        assert_eq!(solo.rank(), 0);
        assert_eq!(solo.world_size(), 1);
        assert_eq!(solo.name(), "mock");
    }

    #[test]
    fn test_mock_backend_world() {
        let world = MockBackend::create_world(4);
        assert_eq!(world.len(), 4);

        for (expected_rank, backend) in world.iter().enumerate() {
            assert_eq!(backend.rank(), expected_rank);
            assert_eq!(backend.world_size(), 4);
        }
    }

    #[test]
    fn test_mock_all_reduce() {
        let backend = MockBackend::single();

        let mut values = vec![1.0, 2.0];
        backend.all_reduce(&mut values, ReduceOp::Sum);

        // With a single rank the reduction is the identity.
        assert_eq!(values, vec![1.0, 2.0]);
    }

    #[test]
    fn test_mock_broadcast() {
        let world = MockBackend::create_world(2);

        let mut source = vec![1.0, 2.0, 3.0];
        let mut sink = vec![0.0, 0.0, 0.0];

        // Source must publish before the other rank reads.
        world[0].broadcast(&mut source, 0);
        world[1].broadcast(&mut sink, 0);

        assert_eq!(source, vec![1.0, 2.0, 3.0]);
        assert_eq!(sink, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mock_send_recv() {
        let world = MockBackend::create_world(2);

        let outgoing = vec![1.0, 2.0, 3.0];
        world[0].send(&outgoing, 1, 0);

        let mut inbox = vec![0.0, 0.0, 0.0];
        world[1].recv(&mut inbox, 0, 0);

        assert_eq!(inbox, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mock_barrier() {
        let world = MockBackend::create_world(2);

        // Must complete without panicking or blocking.
        world[0].barrier();
        world[1].barrier();
    }
}