use std::collections::HashMap;
use std::sync::{Arc, Mutex};

/// Element-wise reduction operations for collective communication.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    /// Element-wise sum.
    Sum,
    /// Element-wise product.
    Product,
    /// Element-wise minimum.
    Min,
    /// Element-wise maximum.
    Max,
    /// Element-wise arithmetic mean.
    Average,
}

impl ReduceOp {
    /// Combines two values with this operation.
    ///
    /// For `Average` this is the midpoint of the pair; averaging more than
    /// two values is handled separately in `reduce_slices`, since a
    /// midpoint of midpoints is not the overall mean.
    #[must_use]
    pub fn apply_f32(&self, a: f32, b: f32) -> f32 {
        match self {
            ReduceOp::Sum => a + b,
            ReduceOp::Product => a * b,
            ReduceOp::Min => a.min(b),
            ReduceOp::Max => a.max(b),
            ReduceOp::Average => f32::midpoint(a, b),
        }
    }

    /// Reduces several equal-length slices element-wise into one vector.
    ///
    /// The output length is taken from the first slice; elements past that
    /// length in later slices are ignored.
    #[must_use]
    pub fn reduce_slices(&self, slices: &[Vec<f32>]) -> Vec<f32> {
        if slices.is_empty() {
            return Vec::new();
        }

        let len = slices[0].len();

        // `Average` cannot be folded pairwise, so sum every slice and
        // divide by the slice count instead.
        if *self == ReduceOp::Average {
            let mut result = vec![0.0; len];
            for slice in slices {
                for (i, &val) in slice.iter().take(len).enumerate() {
                    result[i] += val;
                }
            }
            let count = slices.len() as f32;
            for val in &mut result {
                *val /= count;
            }
            return result;
        }

        let mut result = slices[0].clone();
        for slice in slices.iter().skip(1) {
            for (i, &val) in slice.iter().take(len).enumerate() {
                result[i] = self.apply_f32(result[i], val);
            }
        }
        result
    }
}

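// A small illustrative sketch (an addition, not original API): folding the
// same per-rank values with each reduction to show their semantics.
#[allow(dead_code)]
fn demo_reduce_ops() {
    let per_rank = vec![vec![1.0, 4.0], vec![3.0, 2.0]];
    // `Sum`/`Min`/`Max` fold pairwise via `apply_f32`; `Average` takes the
    // true mean across all slices.
    assert_eq!(ReduceOp::Sum.reduce_slices(&per_rank), vec![4.0, 6.0]);
    assert_eq!(ReduceOp::Min.reduce_slices(&per_rank), vec![1.0, 2.0]);
    assert_eq!(ReduceOp::Max.reduce_slices(&per_rank), vec![3.0, 4.0]);
    assert_eq!(ReduceOp::Average.reduce_slices(&per_rank), vec![2.0, 3.0]);
}
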
/// An abstraction over collective and point-to-point communication between
/// ranks. Implementations must be shareable across threads.
pub trait Backend: Send + Sync {
    /// Human-readable backend name.
    fn name(&self) -> &str;

    /// This process's rank, in `0..world_size`.
    fn rank(&self) -> usize;

    /// Total number of participating ranks.
    fn world_size(&self) -> usize;

    /// Reduces `data` element-wise across all ranks; every rank receives the result.
    fn all_reduce(&self, data: &mut [f32], op: ReduceOp);

    /// Copies `data` on rank `src` to every rank.
    fn broadcast(&self, data: &mut [f32], src: usize);

    /// Concatenates every rank's `send_data` into `recv_data` in rank order.
    fn all_gather(&self, send_data: &[f32], recv_data: &mut [f32]);

    /// Reduces across ranks, then leaves each rank with its own chunk of the result.
    fn reduce_scatter(&self, send_data: &[f32], recv_data: &mut [f32], op: ReduceOp);

    /// Collects every rank's `send_data` into `recv_data` on rank `dst` only.
    fn gather(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize);

    /// Distributes equal chunks of rank `src`'s `send_data` to every rank.
    fn scatter(&self, send_data: &[f32], recv_data: &mut [f32], src: usize);

    /// Reduces across ranks; only rank `dst` receives the result.
    fn reduce(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize, op: ReduceOp);

    /// Synchronizes all ranks.
    fn barrier(&self);

    /// Sends `data` to rank `dst`, matched on `tag`.
    fn send(&self, data: &[f32], dst: usize, tag: usize);

    /// Receives into `data` from rank `src`, matched on `tag`.
    fn recv(&self, data: &mut [f32], src: usize, tag: usize);
}

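// A minimal consumer sketch (an addition for illustration): data-parallel
// gradient averaging over any `Backend`. The helper name and its contract
// are assumptions, not part of the original trait.
#[allow(dead_code)]
fn average_gradients(backend: &dyn Backend, grad: &mut [f32]) {
    // With `ReduceOp::Average` the backend divides the element-wise sum by
    // the world size, so the caller does no manual scaling.
    backend.all_reduce(grad, ReduceOp::Average);
}
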
/// State shared by all ranks of a [`MockBackend`] world.
#[derive(Debug)]
struct SharedState {
    /// Per-rank staging buffers used by the collective operations.
    buffers: HashMap<usize, Vec<f32>>,
    /// Number of ranks that have arrived at the current barrier.
    barrier_count: usize,
    /// In-flight point-to-point messages keyed by `(src, dst, tag)`.
    messages: HashMap<(usize, usize, usize), Vec<f32>>,
}

/// An in-process mock backend for tests.
///
/// All ranks of a world share one state behind a mutex, so the collectives
/// are non-blocking, lock-step simulations rather than real transports.
pub struct MockBackend {
    rank: usize,
    world_size: usize,
    state: Arc<Mutex<SharedState>>,
}

impl MockBackend {
    /// Creates `world_size` backends that share a single state.
    #[must_use]
    pub fn create_world(world_size: usize) -> Vec<Self> {
        let state = Arc::new(Mutex::new(SharedState {
            buffers: HashMap::new(),
            barrier_count: 0,
            messages: HashMap::new(),
        }));

        (0..world_size)
            .map(|rank| MockBackend {
                rank,
                world_size,
                state: Arc::clone(&state),
            })
            .collect()
    }

    /// Creates a standalone single-rank backend.
    #[must_use]
    pub fn single() -> Self {
        MockBackend::create_world(1).pop().unwrap()
    }
}

impl Backend for MockBackend {
    fn name(&self) -> &str {
        "mock"
    }

    fn rank(&self) -> usize {
        self.rank
    }

    fn world_size(&self) -> usize {
        self.world_size
    }

    fn all_reduce(&self, data: &mut [f32], op: ReduceOp) {
        let mut state = self.state.lock().unwrap();

        // Stage this rank's contribution.
        state.buffers.insert(self.rank, data.to_vec());

        // Once every rank has contributed, reduce across all buffers.
        // Since the mock never blocks, only the final caller observes the
        // reduced values; earlier callers return with their input intact.
        if state.buffers.len() == self.world_size {
            let all_data: Vec<Vec<f32>> = (0..self.world_size)
                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
                .collect();

            let reduced = op.reduce_slices(&all_data);
            for (i, &val) in reduced.iter().enumerate() {
                if i < data.len() {
                    data[i] = val;
                }
            }

            // Reset so the next collective starts from a clean slate.
            state.buffers.clear();
        }
    }

    fn broadcast(&self, data: &mut [f32], src: usize) {
        let mut state = self.state.lock().unwrap();

        // The source rank must call first; it stages the data for the rest.
        // The staged entry is deliberately left in place so that late ranks
        // can still read it.
        if self.rank == src {
            state.buffers.insert(src, data.to_vec());
        }

        if let Some(src_data) = state.buffers.get(&src) {
            for (i, &val) in src_data.iter().enumerate() {
                if i < data.len() {
                    data[i] = val;
                }
            }
        }
    }

    fn all_gather(&self, send_data: &[f32], recv_data: &mut [f32]) {
        let mut state = self.state.lock().unwrap();

        state.buffers.insert(self.rank, send_data.to_vec());

        // The final caller concatenates the per-rank chunks in rank order;
        // earlier callers' `recv_data` is left untouched.
        if state.buffers.len() == self.world_size {
            let chunk_size = send_data.len();
            for r in 0..self.world_size {
                if let Some(chunk) = state.buffers.get(&r) {
                    let start = r * chunk_size;
                    for (i, &val) in chunk.iter().enumerate() {
                        if start + i < recv_data.len() {
                            recv_data[start + i] = val;
                        }
                    }
                }
            }
            state.buffers.clear();
        }
    }

    fn reduce_scatter(&self, send_data: &[f32], recv_data: &mut [f32], op: ReduceOp) {
        let mut state = self.state.lock().unwrap();

        state.buffers.insert(self.rank, send_data.to_vec());

        if state.buffers.len() == self.world_size {
            let all_data: Vec<Vec<f32>> = (0..self.world_size)
                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
                .collect();

            let reduced = op.reduce_slices(&all_data);

            // The final caller keeps its own chunk of the reduced vector.
            // Both bounds are clamped so a short input cannot panic.
            let chunk_size = recv_data.len();
            let start = (self.rank * chunk_size).min(reduced.len());
            let end = (start + chunk_size).min(reduced.len());
            for (i, &val) in reduced[start..end].iter().enumerate() {
                recv_data[i] = val;
            }

            state.buffers.clear();
        }
    }

    fn gather(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize) {
        let mut state = self.state.lock().unwrap();

        state.buffers.insert(self.rank, send_data.to_vec());

        // Works only when `dst` is the last rank to call, since the mock
        // cannot wait for the remaining contributions.
        if self.rank == dst && state.buffers.len() == self.world_size {
            let chunk_size = send_data.len();
            for r in 0..self.world_size {
                if let Some(chunk) = state.buffers.get(&r) {
                    let start = r * chunk_size;
                    for (i, &val) in chunk.iter().enumerate() {
                        if start + i < recv_data.len() {
                            recv_data[start + i] = val;
                        }
                    }
                }
            }
            state.buffers.clear();
        }
    }

    fn scatter(&self, send_data: &[f32], recv_data: &mut [f32], src: usize) {
        let mut state = self.state.lock().unwrap();

        // The source rank must call first; it stages the full buffer, and
        // every rank (the source included) then copies out its own chunk.
        if self.rank == src {
            state.buffers.insert(src, send_data.to_vec());
        }

        if let Some(full) = state.buffers.get(&src) {
            let chunk_size = recv_data.len();
            let start = (self.rank * chunk_size).min(full.len());
            let end = (start + chunk_size).min(full.len());
            for (i, &val) in full[start..end].iter().enumerate() {
                recv_data[i] = val;
            }
        }
    }

    fn reduce(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize, op: ReduceOp) {
        let mut state = self.state.lock().unwrap();

        state.buffers.insert(self.rank, send_data.to_vec());

        // As with `gather`, `dst` must be the last rank to call.
        if self.rank == dst && state.buffers.len() == self.world_size {
            let all_data: Vec<Vec<f32>> = (0..self.world_size)
                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
                .collect();

            let reduced = op.reduce_slices(&all_data);
            for (i, &val) in reduced.iter().enumerate() {
                if i < recv_data.len() {
                    recv_data[i] = val;
                }
            }

            state.buffers.clear();
        }
    }

    fn barrier(&self) {
        // Non-blocking: the mock only counts arrivals and resets the
        // counter once the whole world has checked in.
        let mut state = self.state.lock().unwrap();
        state.barrier_count += 1;

        if state.barrier_count == self.world_size {
            state.barrier_count = 0;
        }
    }

    fn send(&self, data: &[f32], dst: usize, tag: usize) {
        let mut state = self.state.lock().unwrap();
        state.messages.insert((self.rank, dst, tag), data.to_vec());
    }

    fn recv(&self, data: &mut [f32], src: usize, tag: usize) {
        // Non-blocking: if no matching message has been posted yet, the
        // buffer is left untouched.
        let mut state = self.state.lock().unwrap();
        if let Some(msg) = state.messages.remove(&(src, self.rank, tag)) {
            for (i, &val) in msg.iter().enumerate() {
                if i < data.len() {
                    data[i] = val;
                }
            }
        }
    }
}

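// A sketch of driving a mock world from one thread per rank (illustrative,
// not original API). Because this mock never blocks, only the rank that
// reaches the collective last observes the reduced result.
#[allow(dead_code)]
fn demo_threaded_all_reduce() {
    use std::thread;

    let handles: Vec<_> = MockBackend::create_world(2)
        .into_iter()
        .map(|backend| {
            thread::spawn(move || {
                let mut data = vec![1.0_f32, 2.0];
                backend.all_reduce(&mut data, ReduceOp::Sum);
                data
            })
        })
        .collect();

    for handle in handles {
        // One thread yields the summed vector; the other, its own input.
        let _ = handle.join().unwrap();
    }
}
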
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reduce_op_sum() {
        let op = ReduceOp::Sum;
        assert_eq!(op.apply_f32(1.0, 2.0), 3.0);
    }

    #[test]
    fn test_reduce_op_product() {
        let op = ReduceOp::Product;
        assert_eq!(op.apply_f32(2.0, 3.0), 6.0);
    }

    #[test]
    fn test_reduce_op_min() {
        let op = ReduceOp::Min;
        assert_eq!(op.apply_f32(2.0, 3.0), 2.0);
    }

    #[test]
    fn test_reduce_op_max() {
        let op = ReduceOp::Max;
        assert_eq!(op.apply_f32(2.0, 3.0), 3.0);
    }

    #[test]
    fn test_reduce_slices_sum() {
        let op = ReduceOp::Sum;
        let slices = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let result = op.reduce_slices(&slices);
        assert_eq!(result, vec![9.0, 12.0]);
    }

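    // An added sketch covering the pairwise fold path used by every op
    // other than `Average`.
    #[test]
    fn test_reduce_slices_product() {
        let op = ReduceOp::Product;
        let slices = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        assert_eq!(op.reduce_slices(&slices), vec![3.0, 8.0]);
    }
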
    #[test]
    fn test_reduce_slices_average() {
        let op = ReduceOp::Average;
        let slices = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let result = op.reduce_slices(&slices);
        assert_eq!(result, vec![2.0, 3.0]);
    }

    #[test]
    fn test_mock_backend_single() {
        let backend = MockBackend::single();
        assert_eq!(backend.rank(), 0);
        assert_eq!(backend.world_size(), 1);
        assert_eq!(backend.name(), "mock");
    }

    #[test]
    fn test_mock_backend_world() {
        let backends = MockBackend::create_world(4);
        assert_eq!(backends.len(), 4);

        for (i, b) in backends.iter().enumerate() {
            assert_eq!(b.rank(), i);
            assert_eq!(b.world_size(), 4);
        }
    }

    #[test]
    fn test_mock_all_reduce() {
        let backend = MockBackend::single();

        // With a single-rank world the collective completes immediately,
        // and the "sum" across one rank is the input itself.
        let mut data = vec![1.0, 2.0];
        backend.all_reduce(&mut data, ReduceOp::Sum);

        assert_eq!(data, vec![1.0, 2.0]);
    }

    #[test]
    fn test_mock_broadcast() {
        let backends = MockBackend::create_world(2);

        let mut data0 = vec![1.0, 2.0, 3.0];
        let mut data1 = vec![0.0, 0.0, 0.0];

        // The source rank must broadcast first in this mock.
        backends[0].broadcast(&mut data0, 0);
        backends[1].broadcast(&mut data1, 0);

        assert_eq!(data0, vec![1.0, 2.0, 3.0]);
        assert_eq!(data1, vec![1.0, 2.0, 3.0]);
    }

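    // An added sketch of the mock's lock-step semantics: only the rank
    // that calls `all_gather` last receives the concatenated result.
    #[test]
    fn test_mock_all_gather_last_caller() {
        let backends = MockBackend::create_world(2);

        let mut recv0 = vec![0.0; 4];
        let mut recv1 = vec![0.0; 4];

        backends[0].all_gather(&[1.0, 2.0], &mut recv0);
        backends[1].all_gather(&[3.0, 4.0], &mut recv1);

        assert_eq!(recv1, vec![1.0, 2.0, 3.0, 4.0]);
    }
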
    #[test]
    fn test_mock_send_recv() {
        let backends = MockBackend::create_world(2);

        // Rank 0 posts a message for rank 1 under tag 0...
        let send_data = vec![1.0, 2.0, 3.0];
        backends[0].send(&send_data, 1, 0);

        // ...and rank 1 picks it up by matching (src, dst, tag).
        let mut recv_data = vec![0.0, 0.0, 0.0];
        backends[1].recv(&mut recv_data, 0, 0);

        assert_eq!(recv_data, vec![1.0, 2.0, 3.0]);
    }

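    // An added sketch of `scatter`'s staging protocol: the source rank
    // must call first; every rank then copies out its own chunk.
    #[test]
    fn test_mock_scatter() {
        let backends = MockBackend::create_world(2);

        let full = vec![1.0, 2.0, 3.0, 4.0];
        let mut chunk0 = vec![0.0; 2];
        let mut chunk1 = vec![0.0; 2];

        backends[0].scatter(&full, &mut chunk0, 0);
        backends[1].scatter(&full, &mut chunk1, 0);

        assert_eq!(chunk0, vec![1.0, 2.0]);
        assert_eq!(chunk1, vec![3.0, 4.0]);
    }
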
    #[test]
    fn test_mock_barrier() {
        let backends = MockBackend::create_world(2);

        // The mock barrier never blocks; this only checks that the arrival
        // counter wraps back to zero once every rank has checked in.
        backends[0].barrier();
        backends[1].barrier();

        assert_eq!(backends[0].state.lock().unwrap().barrier_count, 0);
    }
}