1use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10
/// Element-wise reduction operators supported by the collective backends.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    /// Element-wise sum.
    Sum,
    /// Element-wise product.
    Product,
    /// Element-wise minimum.
    Min,
    /// Element-wise maximum.
    Max,
    /// Element-wise arithmetic mean over all contributions.
    Average,
}

impl ReduceOp {
    /// Combine two values according to this operator.
    ///
    /// NOTE: for `Average` this is a *pairwise* mean, which is not
    /// associative — chaining it over N values does not yield the N-way
    /// mean. [`ReduceOp::reduce_slices`] therefore handles `Average` with
    /// a dedicated sum-then-divide pass instead of chaining this function.
    #[must_use] pub fn apply_f32(&self, a: f32, b: f32) -> f32 {
        match self {
            ReduceOp::Sum => a + b,
            ReduceOp::Product => a * b,
            ReduceOp::Min => a.min(b),
            ReduceOp::Max => a.max(b),
            ReduceOp::Average => (a + b) / 2.0,
        }
    }

    /// Reduce a set of equally-shaped slices element-wise into one vector.
    ///
    /// The output length is the length of the *first* slice; shorter later
    /// slices contribute only their overlapping prefix, longer ones are
    /// truncated. Returns an empty vector when `slices` is empty.
    #[must_use] pub fn reduce_slices(&self, slices: &[Vec<f32>]) -> Vec<f32> {
        if slices.is_empty() {
            return Vec::new();
        }

        let len = slices[0].len();

        // `Average` cannot be expressed by folding `apply_f32` (pairwise
        // averaging under-weights earlier slices), so sum in one pass and
        // divide by the contribution count. The previous implementation ran
        // the full pairwise pass first and then discarded it — that wasted
        // work is removed here.
        if *self == ReduceOp::Average {
            let mut result = vec![0.0; len];
            for slice in slices {
                for (acc, &val) in result.iter_mut().zip(slice.iter()) {
                    *acc += val;
                }
            }
            let count = slices.len() as f32;
            for val in &mut result {
                *val /= count;
            }
            return result;
        }

        // All other operators are associative: fold the remaining slices
        // into a copy of the first one. `zip` truncates to the shorter of
        // the two, matching the original `i < len` bounds guard.
        let mut result = slices[0].clone();
        for slice in slices.iter().skip(1) {
            for (acc, &val) in result.iter_mut().zip(slice.iter()) {
                *acc = self.apply_f32(*acc, val);
            }
        }
        result
    }
}
79
/// Abstraction over a collective-communication backend.
///
/// Implementors provide point-to-point and collective primitives over a
/// fixed-size group of ranks. All payloads are flat `f32` buffers.
/// `Send + Sync` is required so a backend handle can be shared across
/// threads.
pub trait Backend: Send + Sync {
    /// Human-readable backend name (e.g. `"mock"`).
    fn name(&self) -> &str;

    /// This process's rank within the group, in `[0, world_size)`.
    fn rank(&self) -> usize;

    /// Total number of ranks participating in the group.
    fn world_size(&self) -> usize;

    /// Reduce `data` element-wise across all ranks with `op`, leaving the
    /// result in `data` on every rank.
    fn all_reduce(&self, data: &mut [f32], op: ReduceOp);

    /// Copy `data` from rank `src` into every rank's `data`.
    fn broadcast(&self, data: &mut [f32], src: usize);

    /// Concatenate every rank's `send_data` (in rank order) into
    /// `recv_data` on all ranks.
    fn all_gather(&self, send_data: &[f32], recv_data: &mut [f32]);

    /// Reduce all ranks' `send_data` with `op`, then scatter the result so
    /// each rank receives its own `recv_data.len()`-sized chunk.
    fn reduce_scatter(&self, send_data: &[f32], recv_data: &mut [f32], op: ReduceOp);

    /// Collect every rank's `send_data` (in rank order) into `recv_data`
    /// on rank `dst` only.
    fn gather(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize);

    /// Split `send_data` on rank `src` into `recv_data.len()`-sized chunks
    /// and deliver chunk `r` to rank `r`.
    fn scatter(&self, send_data: &[f32], recv_data: &mut [f32], src: usize);

    /// Reduce all ranks' `send_data` with `op` into `recv_data` on rank
    /// `dst` only.
    fn reduce(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize, op: ReduceOp);

    /// Synchronize all ranks at a common point.
    fn barrier(&self);

    /// Point-to-point send of `data` to rank `dst`, matched by `tag`.
    fn send(&self, data: &[f32], dst: usize, tag: usize);

    /// Point-to-point receive into `data` from rank `src`, matched by `tag`.
    fn recv(&self, data: &mut [f32], src: usize, tag: usize);
}
125
/// State shared by every rank of a mock world, guarded by one mutex.
#[derive(Debug)]
struct SharedState {
    /// Per-rank staging buffers used by the collective operations.
    buffers: HashMap<usize, Vec<f32>>,
    /// Number of ranks that have arrived at the current barrier.
    barrier_count: usize,
    /// In-flight point-to-point messages, keyed by `(src, dst, tag)`.
    messages: HashMap<(usize, usize, usize), Vec<f32>>, }

/// In-process, single-threaded mock of a distributed backend.
///
/// Every rank handle produced by [`MockBackend::create_world`] shares the
/// same [`SharedState`], so collectives are simulated by staging data
/// through that shared map.
pub struct MockBackend {
    rank: usize,
    world_size: usize,
    state: Arc<Mutex<SharedState>>,
}

impl MockBackend {
    /// Build `world_size` rank handles that all share one state.
    #[must_use] pub fn create_world(world_size: usize) -> Vec<Self> {
        let shared = Arc::new(Mutex::new(SharedState {
            buffers: HashMap::new(),
            barrier_count: 0,
            messages: HashMap::new(),
        }));

        let mut world = Vec::with_capacity(world_size);
        for rank in 0..world_size {
            world.push(MockBackend {
                rank,
                world_size,
                state: Arc::clone(&shared),
            });
        }
        world
    }

    /// Convenience constructor for a one-rank world.
    #[must_use] pub fn single() -> Self {
        let mut world = MockBackend::create_world(1);
        world.remove(0)
    }
}
176
// NOTE(review): this mock simulates collectives by staging data in the shared
// `buffers` map and triggering completion when `buffers.len()` reaches
// `world_size`. It assumes lockstep, sequential calls (each rank calls the
// same operation exactly once, with no interleaving of different
// collectives). Several ops never clear `buffers` after completing
// (broadcast, all_gather, reduce_scatter, gather, reduce), so stale entries
// can leak into the *next* collective's completion check — only all_reduce
// tears its staging down. Flagged below per method; confirm intended usage
// before relying on back-to-back collectives.
impl Backend for MockBackend {
    fn name(&self) -> &'static str {
        "mock"
    }

    fn rank(&self) -> usize {
        self.rank
    }

    fn world_size(&self) -> usize {
        self.world_size
    }

    // NOTE(review): because calls are sequential and non-blocking, only the
    // LAST rank to call observes the reduced values; earlier callers copy
    // their own contribution back into `data` and return before the
    // reduction runs.
    fn all_reduce(&self, data: &mut [f32], op: ReduceOp) {
        let mut state = self.state.lock().unwrap();

        // Publish this rank's contribution, keyed by rank.
        state.buffers.insert(self.rank, data.to_vec());

        // Once every rank has contributed, reduce and fan the result back
        // out to every rank's slot.
        if state.buffers.len() == self.world_size {
            let all_data: Vec<Vec<f32>> = (0..self.world_size)
                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
                .collect();

            let reduced = op.reduce_slices(&all_data);

            for r in 0..self.world_size {
                state.buffers.insert(r, reduced.clone());
            }
        }

        // Copy whatever is currently stored under this rank back into
        // `data` — the reduced result only for the completing rank.
        if let Some(result) = state.buffers.get(&self.rank) {
            for (i, &val) in result.iter().enumerate() {
                if i < data.len() {
                    data[i] = val;
                }
            }
        }

        // The completing rank clears the staging area for the next op.
        if state.buffers.len() == self.world_size {
            state.buffers.clear();
        }
    }

    // NOTE(review): the source's payload is parked under the fixed key 0
    // and is never removed, so it lingers in `buffers` and can trip the
    // completion check of a later collective. A non-src rank that calls
    // before the source sees no entry and leaves its `data` untouched —
    // callers must invoke the source rank first (as the tests do).
    fn broadcast(&self, data: &mut [f32], src: usize) {
        let mut state = self.state.lock().unwrap();

        if self.rank == src {
            state.buffers.insert(0, data.to_vec());
        }

        if let Some(src_data) = state.buffers.get(&0) {
            for (i, &val) in src_data.iter().enumerate() {
                if i < data.len() {
                    data[i] = val;
                }
            }
        }
    }

    // NOTE(review): only the last caller assembles the gathered result into
    // its own `recv_data`; earlier callers' `recv_data` is left untouched,
    // and the staging buffers are never cleared afterwards.
    fn all_gather(&self, send_data: &[f32], recv_data: &mut [f32]) {
        let mut state = self.state.lock().unwrap();

        state.buffers.insert(self.rank, send_data.to_vec());

        if state.buffers.len() == self.world_size {
            // Rank r's chunk lands at offset r * chunk_size, assuming every
            // rank sends the same number of elements.
            let chunk_size = send_data.len();
            for r in 0..self.world_size {
                if let Some(data) = state.buffers.get(&r) {
                    let start = r * chunk_size;
                    for (i, &val) in data.iter().enumerate() {
                        if start + i < recv_data.len() {
                            recv_data[start + i] = val;
                        }
                    }
                }
            }
        }
    }

    // NOTE(review): same last-caller-only and no-cleanup caveats as
    // all_gather. Each rank's output chunk is addressed by
    // rank * recv_data.len() into the reduced vector.
    fn reduce_scatter(&self, send_data: &[f32], recv_data: &mut [f32], op: ReduceOp) {
        let mut state = self.state.lock().unwrap();

        state.buffers.insert(self.rank, send_data.to_vec());

        if state.buffers.len() == self.world_size {
            let all_data: Vec<Vec<f32>> = (0..self.world_size)
                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
                .collect();

            let reduced = op.reduce_slices(&all_data);

            // This rank keeps only its own chunk of the reduced vector.
            let chunk_size = recv_data.len();
            let start = self.rank * chunk_size;
            let end = (start + chunk_size).min(reduced.len());

            for (i, &val) in reduced[start..end].iter().enumerate() {
                if i < recv_data.len() {
                    recv_data[i] = val;
                }
            }
        }
    }

    // NOTE(review): the destination only receives the full result if it is
    // the last rank to call; staging buffers are not cleared afterwards.
    fn gather(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize) {
        let mut state = self.state.lock().unwrap();

        state.buffers.insert(self.rank, send_data.to_vec());

        if self.rank == dst && state.buffers.len() == self.world_size {
            // Assemble contributions in rank order, chunk per rank.
            let chunk_size = send_data.len();
            for r in 0..self.world_size {
                if let Some(data) = state.buffers.get(&r) {
                    let start = r * chunk_size;
                    for (i, &val) in data.iter().enumerate() {
                        if start + i < recv_data.len() {
                            recv_data[start + i] = val;
                        }
                    }
                }
            }
        }
    }

    // NOTE(review): only the source rank receives anything (its own chunk);
    // non-src ranks' `recv_data` is never filled, so this is not a full
    // scatter. The lock is taken but `state` is never read or written.
    fn scatter(&self, send_data: &[f32], recv_data: &mut [f32], src: usize) {
        let state = self.state.lock().unwrap();

        if self.rank == src {
            // Copy this rank's own chunk out of the send buffer.
            let chunk_size = recv_data.len();
            let start = self.rank * chunk_size;
            let end = (start + chunk_size).min(send_data.len());

            for (i, &val) in send_data[start..end].iter().enumerate() {
                recv_data[i] = val;
            }
        }
        drop(state);

    }

    // NOTE(review): `recv_data` is written only when `dst` is the last rank
    // to call; staging buffers are not cleared afterwards.
    fn reduce(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize, op: ReduceOp) {
        let mut state = self.state.lock().unwrap();

        state.buffers.insert(self.rank, send_data.to_vec());

        if self.rank == dst && state.buffers.len() == self.world_size {
            let all_data: Vec<Vec<f32>> = (0..self.world_size)
                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
                .collect();

            let reduced = op.reduce_slices(&all_data);

            for (i, &val) in reduced.iter().enumerate() {
                if i < recv_data.len() {
                    recv_data[i] = val;
                }
            }
        }
    }

    // Not a real barrier: it never blocks. It only counts arrivals and
    // resets the counter once every rank has checked in.
    fn barrier(&self) {
        let mut state = self.state.lock().unwrap();
        state.barrier_count += 1;

        if state.barrier_count == self.world_size {
            state.barrier_count = 0;
        }
    }

    // Deposits the message under (sender, dst, tag); a second send with the
    // same key before a matching recv overwrites the first message.
    fn send(&self, data: &[f32], dst: usize, tag: usize) {
        let mut state = self.state.lock().unwrap();
        state.messages.insert((self.rank, dst, tag), data.to_vec());
    }

    // Non-blocking: if no matching message is queued, `data` is left
    // untouched. A matching message is consumed (removed) on receipt.
    fn recv(&self, data: &mut [f32], src: usize, tag: usize) {
        let mut state = self.state.lock().unwrap();
        if let Some(msg) = state.messages.remove(&(src, self.rank, tag)) {
            for (i, &val) in msg.iter().enumerate() {
                if i < data.len() {
                    data[i] = val;
                }
            }
        }
    }
}
384
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reduce_op_sum() {
        let op = ReduceOp::Sum;
        assert_eq!(op.apply_f32(1.0, 2.0), 3.0);
    }

    #[test]
    fn test_reduce_op_product() {
        let op = ReduceOp::Product;
        assert_eq!(op.apply_f32(2.0, 3.0), 6.0);
    }

    #[test]
    fn test_reduce_op_min() {
        let op = ReduceOp::Min;
        assert_eq!(op.apply_f32(2.0, 3.0), 2.0);
    }

    #[test]
    fn test_reduce_op_max() {
        let op = ReduceOp::Max;
        assert_eq!(op.apply_f32(2.0, 3.0), 3.0);
    }

    #[test]
    fn test_reduce_slices_sum() {
        let op = ReduceOp::Sum;
        let slices = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let result = op.reduce_slices(&slices);
        assert_eq!(result, vec![9.0, 12.0]);
    }

    #[test]
    fn test_reduce_slices_average() {
        let op = ReduceOp::Average;
        let slices = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let result = op.reduce_slices(&slices);
        assert_eq!(result, vec![2.0, 3.0]);
    }

    #[test]
    fn test_reduce_slices_empty() {
        // Edge case: reducing zero contributions yields an empty vector.
        assert!(ReduceOp::Sum.reduce_slices(&[]).is_empty());
    }

    #[test]
    fn test_mock_backend_single() {
        let backend = MockBackend::single();
        assert_eq!(backend.rank(), 0);
        assert_eq!(backend.world_size(), 1);
        assert_eq!(backend.name(), "mock");
    }

    #[test]
    fn test_mock_backend_world() {
        let backends = MockBackend::create_world(4);
        assert_eq!(backends.len(), 4);

        for (i, b) in backends.iter().enumerate() {
            assert_eq!(b.rank(), i);
            assert_eq!(b.world_size(), 4);
        }
    }

    #[test]
    fn test_mock_all_reduce() {
        // With a single rank, all_reduce must be the identity.
        let backend = MockBackend::single();

        let mut data = vec![1.0, 2.0];
        backend.all_reduce(&mut data, ReduceOp::Sum);

        assert_eq!(data, vec![1.0, 2.0]);
    }

    #[test]
    fn test_mock_broadcast() {
        let backends = MockBackend::create_world(2);

        let mut data0 = vec![1.0, 2.0, 3.0];
        let mut data1 = vec![0.0, 0.0, 0.0];

        // The source rank must broadcast first (mock is non-blocking).
        backends[0].broadcast(&mut data0, 0);
        backends[1].broadcast(&mut data1, 0);

        assert_eq!(data0, vec![1.0, 2.0, 3.0]);
        assert_eq!(data1, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mock_send_recv() {
        let backends = MockBackend::create_world(2);

        let send_data = vec![1.0, 2.0, 3.0];
        backends[0].send(&send_data, 1, 0);

        let mut recv_data = vec![0.0, 0.0, 0.0];
        backends[1].recv(&mut recv_data, 0, 0);

        assert_eq!(recv_data, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mock_barrier() {
        let backends = MockBackend::create_world(2);

        backends[0].barrier();
        // One arrival so far: the counter must reflect it.
        assert_eq!(backends[0].state.lock().unwrap().barrier_count, 1);

        backends[1].barrier();
        // All ranks have arrived: the counter resets for the next barrier.
        assert_eq!(backends[1].state.lock().unwrap().barrier_count, 0);
    }
}