1use std::collections::HashMap;
18use std::sync::{Arc, Mutex};
19
/// Element-wise reduction operator applied across ranks in a collective.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    /// Element-wise addition.
    Sum,
    /// Element-wise multiplication.
    Product,
    /// Element-wise minimum.
    Min,
    /// Element-wise maximum.
    Max,
    /// Arithmetic mean across all contributing slices.
    Average,
}
38
39impl ReduceOp {
40 #[must_use]
42 pub fn apply_f32(&self, a: f32, b: f32) -> f32 {
43 match self {
44 ReduceOp::Sum => a + b,
45 ReduceOp::Product => a * b,
46 ReduceOp::Min => a.min(b),
47 ReduceOp::Max => a.max(b),
48 ReduceOp::Average => f32::midpoint(a, b),
49 }
50 }
51
52 #[must_use]
54 pub fn reduce_slices(&self, slices: &[Vec<f32>]) -> Vec<f32> {
55 if slices.is_empty() {
56 return Vec::new();
57 }
58
59 let len = slices[0].len();
60
61 if *self == ReduceOp::Average {
63 let mut result = vec![0.0f32; len];
64 for slice in slices {
65 for (i, &val) in slice.iter().enumerate() {
66 if i < len {
67 result[i] += val;
68 }
69 }
70 }
71 let count = slices.len() as f32;
72 for val in &mut result {
73 *val /= count;
74 }
75 return result;
76 }
77
78 let mut result = slices[0].clone();
80 for slice in slices.iter().skip(1) {
81 for (i, &val) in slice.iter().enumerate() {
82 if i < len {
83 result[i] = self.apply_f32(result[i], val);
84 }
85 }
86 }
87
88 result
89 }
90}
91
/// Communication backend exposing collective operations over `f32` buffers.
///
/// Each implementor represents one rank inside a fixed-size world.
/// `Send + Sync` so a backend instance can be shared across threads.
pub trait Backend: Send + Sync {
    /// Human-readable backend identifier (e.g. `"mock"`).
    fn name(&self) -> &str;

    /// This instance's rank, expected in `[0, world_size)`.
    fn rank(&self) -> usize;

    /// Total number of participating ranks.
    fn world_size(&self) -> usize;

    /// Reduces `data` element-wise across all ranks with `op`, leaving the
    /// reduced result in `data` on every rank.
    fn all_reduce(&self, data: &mut [f32], op: ReduceOp);

    /// Copies rank `src`'s `data` into every rank's `data`.
    fn broadcast(&self, data: &mut [f32], src: usize);

    /// Concatenates each rank's `send_data` into `recv_data`; rank `r`'s
    /// chunk lands at offset `r * send_data.len()`.
    fn all_gather(&self, send_data: &[f32], recv_data: &mut [f32]);

    /// Reduces element-wise across ranks, then scatters the result: each
    /// rank receives its own `recv_data.len()`-sized chunk.
    fn reduce_scatter(&self, send_data: &[f32], recv_data: &mut [f32], op: ReduceOp);

    /// Collects every rank's `send_data` into `recv_data` on rank `dst`.
    fn gather(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize);

    /// Distributes chunks of rank `src`'s `send_data`; each rank receives
    /// its `recv_data.len()`-sized chunk.
    fn scatter(&self, send_data: &[f32], recv_data: &mut [f32], src: usize);

    /// Reduces element-wise across ranks with `op`; only rank `dst`
    /// receives the result in `recv_data`.
    fn reduce(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize, op: ReduceOp);

    /// Synchronization point across all ranks.
    fn barrier(&self);

    /// Point-to-point send of `data` to rank `dst`, matched by `tag`.
    fn send(&self, data: &[f32], dst: usize, tag: usize);

    /// Point-to-point receive into `data` from rank `src`, matched by `tag`.
    fn recv(&self, data: &mut [f32], src: usize, tag: usize);
}
137
/// State shared by every `MockBackend` rank of one mock world, guarded by
/// a single `Mutex`.
#[derive(Debug)]
struct SharedState {
    // Per-rank staging buffers used by the collective operations, keyed by rank.
    buffers: HashMap<usize, Vec<f32>>,
    // Number of ranks that have reached the current barrier.
    barrier_count: usize,
    // In-flight point-to-point messages keyed by `(src, dst, tag)`.
    messages: HashMap<(usize, usize, usize), Vec<f32>>,
}
152
/// In-process mock implementation of [`Backend`] for testing.
///
/// All ranks produced by [`MockBackend::create_world`] share a single
/// [`SharedState`]; collectives are simulated via that shared state, with
/// no real inter-process communication.
pub struct MockBackend {
    // This instance's rank in `[0, world_size)`.
    rank: usize,
    // Total number of ranks in the mock world.
    world_size: usize,
    // State shared with every other rank of this world.
    state: Arc<Mutex<SharedState>>,
}
164
165impl MockBackend {
166 #[must_use]
168 pub fn create_world(world_size: usize) -> Vec<Self> {
169 let state = Arc::new(Mutex::new(SharedState {
170 buffers: HashMap::new(),
171 barrier_count: 0,
172 messages: HashMap::new(),
173 }));
174
175 (0..world_size)
176 .map(|rank| MockBackend {
177 rank,
178 world_size,
179 state: Arc::clone(&state),
180 })
181 .collect()
182 }
183
184 #[must_use]
186 pub fn single() -> Self {
187 MockBackend::create_world(1).pop().unwrap()
188 }
189}
190
191impl Backend for MockBackend {
192 fn name(&self) -> &'static str {
193 "mock"
194 }
195
196 fn rank(&self) -> usize {
197 self.rank
198 }
199
200 fn world_size(&self) -> usize {
201 self.world_size
202 }
203
204 fn all_reduce(&self, data: &mut [f32], op: ReduceOp) {
205 let mut state = self.state.lock().unwrap();
206
207 state.buffers.insert(self.rank, data.to_vec());
209
210 if state.buffers.len() == self.world_size {
212 let all_data: Vec<Vec<f32>> = (0..self.world_size)
214 .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
215 .collect();
216
217 let reduced = op.reduce_slices(&all_data);
218
219 for r in 0..self.world_size {
221 state.buffers.insert(r, reduced.clone());
222 }
223 }
224
225 if let Some(result) = state.buffers.get(&self.rank) {
227 for (i, &val) in result.iter().enumerate() {
228 if i < data.len() {
229 data[i] = val;
230 }
231 }
232 }
233
234 if state.buffers.len() == self.world_size {
236 state.buffers.clear();
237 }
238 }
239
240 fn broadcast(&self, data: &mut [f32], src: usize) {
241 let mut state = self.state.lock().unwrap();
242
243 if self.rank == src {
244 state.buffers.insert(0, data.to_vec());
246 }
247
248 if let Some(src_data) = state.buffers.get(&0) {
250 for (i, &val) in src_data.iter().enumerate() {
251 if i < data.len() {
252 data[i] = val;
253 }
254 }
255 }
256 }
257
258 fn all_gather(&self, send_data: &[f32], recv_data: &mut [f32]) {
259 let mut state = self.state.lock().unwrap();
260
261 state.buffers.insert(self.rank, send_data.to_vec());
263
264 if state.buffers.len() == self.world_size {
266 let chunk_size = send_data.len();
268 for r in 0..self.world_size {
269 if let Some(data) = state.buffers.get(&r) {
270 let start = r * chunk_size;
271 for (i, &val) in data.iter().enumerate() {
272 if start + i < recv_data.len() {
273 recv_data[start + i] = val;
274 }
275 }
276 }
277 }
278 }
279 }
280
281 fn reduce_scatter(&self, send_data: &[f32], recv_data: &mut [f32], op: ReduceOp) {
282 let mut state = self.state.lock().unwrap();
283
284 state.buffers.insert(self.rank, send_data.to_vec());
286
287 if state.buffers.len() == self.world_size {
289 let all_data: Vec<Vec<f32>> = (0..self.world_size)
291 .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
292 .collect();
293
294 let reduced = op.reduce_slices(&all_data);
295
296 let chunk_size = recv_data.len();
298 let start = self.rank * chunk_size;
299 let end = (start + chunk_size).min(reduced.len());
300
301 for (i, &val) in reduced[start..end].iter().enumerate() {
302 if i < recv_data.len() {
303 recv_data[i] = val;
304 }
305 }
306 }
307 }
308
309 fn gather(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize) {
310 let mut state = self.state.lock().unwrap();
311
312 state.buffers.insert(self.rank, send_data.to_vec());
314
315 if self.rank == dst && state.buffers.len() == self.world_size {
317 let chunk_size = send_data.len();
318 for r in 0..self.world_size {
319 if let Some(data) = state.buffers.get(&r) {
320 let start = r * chunk_size;
321 for (i, &val) in data.iter().enumerate() {
322 if start + i < recv_data.len() {
323 recv_data[start + i] = val;
324 }
325 }
326 }
327 }
328 }
329 }
330
331 fn scatter(&self, send_data: &[f32], recv_data: &mut [f32], src: usize) {
332 let state = self.state.lock().unwrap();
333
334 if self.rank == src {
336 let chunk_size = recv_data.len();
338 let start = self.rank * chunk_size;
339 let end = (start + chunk_size).min(send_data.len());
340
341 for (i, &val) in send_data[start..end].iter().enumerate() {
342 recv_data[i] = val;
343 }
344 }
345 drop(state);
346
347 }
349
350 fn reduce(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize, op: ReduceOp) {
351 let mut state = self.state.lock().unwrap();
352
353 state.buffers.insert(self.rank, send_data.to_vec());
355
356 if self.rank == dst && state.buffers.len() == self.world_size {
358 let all_data: Vec<Vec<f32>> = (0..self.world_size)
359 .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
360 .collect();
361
362 let reduced = op.reduce_slices(&all_data);
363
364 for (i, &val) in reduced.iter().enumerate() {
365 if i < recv_data.len() {
366 recv_data[i] = val;
367 }
368 }
369 }
370 }
371
372 fn barrier(&self) {
373 let mut state = self.state.lock().unwrap();
374 state.barrier_count += 1;
375
376 if state.barrier_count == self.world_size {
378 state.barrier_count = 0;
379 }
380 }
381
382 fn send(&self, data: &[f32], dst: usize, tag: usize) {
383 let mut state = self.state.lock().unwrap();
384 state.messages.insert((self.rank, dst, tag), data.to_vec());
385 }
386
387 fn recv(&self, data: &mut [f32], src: usize, tag: usize) {
388 let mut state = self.state.lock().unwrap();
389 if let Some(msg) = state.messages.remove(&(src, self.rank, tag)) {
390 for (i, &val) in msg.iter().enumerate() {
391 if i < data.len() {
392 data[i] = val;
393 }
394 }
395 }
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reduce_op_sum() {
        assert_eq!(ReduceOp::Sum.apply_f32(1.0, 2.0), 3.0);
    }

    #[test]
    fn test_reduce_op_product() {
        assert_eq!(ReduceOp::Product.apply_f32(2.0, 3.0), 6.0);
    }

    #[test]
    fn test_reduce_op_min() {
        assert_eq!(ReduceOp::Min.apply_f32(2.0, 3.0), 2.0);
    }

    #[test]
    fn test_reduce_op_max() {
        assert_eq!(ReduceOp::Max.apply_f32(2.0, 3.0), 3.0);
    }

    #[test]
    fn test_reduce_slices_sum() {
        let inputs = [vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        assert_eq!(ReduceOp::Sum.reduce_slices(&inputs), vec![9.0, 12.0]);
    }

    #[test]
    fn test_reduce_slices_average() {
        let inputs = [vec![1.0, 2.0], vec![3.0, 4.0]];
        assert_eq!(ReduceOp::Average.reduce_slices(&inputs), vec![2.0, 3.0]);
    }

    #[test]
    fn test_mock_backend_single() {
        let backend = MockBackend::single();
        assert_eq!(backend.rank(), 0);
        assert_eq!(backend.world_size(), 1);
        assert_eq!(backend.name(), "mock");
    }

    #[test]
    fn test_mock_backend_world() {
        let world = MockBackend::create_world(4);
        assert_eq!(world.len(), 4);

        for (expected_rank, backend) in world.iter().enumerate() {
            assert_eq!(backend.rank(), expected_rank);
            assert_eq!(backend.world_size(), 4);
        }
    }

    #[test]
    fn test_mock_all_reduce() {
        let backend = MockBackend::single();

        let mut data = vec![1.0, 2.0];
        backend.all_reduce(&mut data, ReduceOp::Sum);

        // With a single rank, the reduction is the identity.
        assert_eq!(data, vec![1.0, 2.0]);
    }

    #[test]
    fn test_mock_broadcast() {
        let world = MockBackend::create_world(2);

        let mut on_src = vec![1.0, 2.0, 3.0];
        let mut on_other = vec![0.0; 3];

        // Source must broadcast first; the other rank then receives.
        world[0].broadcast(&mut on_src, 0);
        world[1].broadcast(&mut on_other, 0);

        assert_eq!(on_src, vec![1.0, 2.0, 3.0]);
        assert_eq!(on_other, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mock_send_recv() {
        let world = MockBackend::create_world(2);

        let payload = vec![1.0, 2.0, 3.0];
        world[0].send(&payload, 1, 0);

        let mut received = vec![0.0; 3];
        world[1].recv(&mut received, 0, 0);

        assert_eq!(received, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mock_barrier() {
        // Smoke test: a complete round of barrier calls must not panic.
        let world = MockBackend::create_world(2);
        world[0].barrier();
        world[1].barrier();
    }
}