axonml_distributed/
backend.rs

//! Backend - Communication Backend Abstractions
//!
//! Provides backend trait and implementations for distributed communication.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

// =============================================================================
// Reduce Operations
// =============================================================================

/// Reduction operation for collective communication.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    /// Sum all values.
    Sum,
    /// Compute product of all values.
    Product,
    /// Find minimum value.
    Min,
    /// Find maximum value.
    Max,
    /// Compute average of all values.
    Average,
}

impl ReduceOp {
    /// Applies the reduction operation to two f32 values.
    ///
    /// Note: for `Average` this computes a pairwise mean, which equals the
    /// true mean only for exactly two values; `reduce_slices` corrects for
    /// this when reducing more than two slices.
    #[must_use]
    pub fn apply_f32(&self, a: f32, b: f32) -> f32 {
        match self {
            ReduceOp::Sum => a + b,
            ReduceOp::Product => a * b,
            ReduceOp::Min => a.min(b),
            ReduceOp::Max => a.max(b),
            ReduceOp::Average => (a + b) / 2.0,
        }
    }

    /// Applies the reduction operation element-wise across slices.
    #[must_use]
    pub fn reduce_slices(&self, slices: &[Vec<f32>]) -> Vec<f32> {
        if slices.is_empty() {
            return Vec::new();
        }

        let len = slices[0].len();
        let mut result = slices[0].clone();

        for slice in slices.iter().skip(1) {
            for (i, &val) in slice.iter().enumerate() {
                if i < len {
                    result[i] = self.apply_f32(result[i], val);
                }
            }
        }

        // The pairwise fold above is wrong for `Average` with more than two
        // slices, so recompute the true mean: sum everything, divide by count.
        if *self == ReduceOp::Average && slices.len() > 1 {
            result = vec![0.0; len];
            for slice in slices {
                for (i, &val) in slice.iter().enumerate() {
                    if i < len {
                        result[i] += val;
                    }
                }
            }
            let count = slices.len() as f32;
            for val in &mut result {
                *val /= count;
            }
        }

        result
    }
}
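
// Illustrative example (added for clarity, not part of the original suite):
// the pairwise `apply_f32` average drifts from the true mean for three or
// more values, while `reduce_slices` recomputes the exact mean.
#[cfg(test)]
mod reduce_op_average_example {
    use super::*;

    #[test]
    fn pairwise_average_differs_from_true_mean() {
        let op = ReduceOp::Average;

        // Folding 1, 2, 3 pairwise: ((1 + 2) / 2 + 3) / 2 = 2.25.
        let folded = op.apply_f32(op.apply_f32(1.0, 2.0), 3.0);
        assert_eq!(folded, 2.25);

        // `reduce_slices` recomputes the true mean: (1 + 2 + 3) / 3 = 2.0.
        let slices = vec![vec![1.0], vec![2.0], vec![3.0]];
        assert_eq!(op.reduce_slices(&slices), vec![2.0]);
    }
}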

// =============================================================================
// Backend Trait
// =============================================================================

/// Trait for distributed communication backends.
pub trait Backend: Send + Sync {
    /// Returns the name of the backend.
    fn name(&self) -> &str;

    /// Returns the rank of this process.
    fn rank(&self) -> usize;

    /// Returns the total world size.
    fn world_size(&self) -> usize;

    /// Performs an all-reduce operation.
    fn all_reduce(&self, data: &mut [f32], op: ReduceOp);

    /// Broadcasts data from a source rank.
    fn broadcast(&self, data: &mut [f32], src: usize);

    /// Performs an all-gather operation.
    fn all_gather(&self, send_data: &[f32], recv_data: &mut [f32]);

    /// Performs a reduce-scatter operation.
    fn reduce_scatter(&self, send_data: &[f32], recv_data: &mut [f32], op: ReduceOp);

    /// Performs a gather operation.
    fn gather(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize);

    /// Performs a scatter operation.
    fn scatter(&self, send_data: &[f32], recv_data: &mut [f32], src: usize);

    /// Performs a reduce operation (result only on the dst rank).
    fn reduce(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize, op: ReduceOp);

    /// Synchronizes all processes.
    fn barrier(&self);

    /// Sends data to a specific rank.
    fn send(&self, data: &[f32], dst: usize, tag: usize);

    /// Receives data from a specific rank.
    fn recv(&self, data: &mut [f32], src: usize, tag: usize);
}
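
// Sketch of intended use (illustrative only; `average_gradients` is a
// hypothetical helper, not part of this module): a data-parallel training
// step can average gradients across ranks with any `Backend` by summing
// them and then dividing by the world size.
#[allow(dead_code)]
fn average_gradients(backend: &dyn Backend, grads: &mut [f32]) {
    // Element-wise sum of every rank's gradients, in place.
    backend.all_reduce(grads, ReduceOp::Sum);

    // Divide by the number of ranks to turn the sum into a mean.
    let world = backend.world_size() as f32;
    for g in grads.iter_mut() {
        *g /= world;
    }
}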

// =============================================================================
// SharedState for Mock Backend
// =============================================================================

/// Shared state for mock distributed communication.
#[derive(Debug)]
struct SharedState {
    /// Data buffers for each rank.
    buffers: HashMap<usize, Vec<f32>>,
    /// Barrier counter.
    barrier_count: usize,
    /// Message queue for send/recv operations: (src, dst, tag) -> data.
    messages: HashMap<(usize, usize, usize), Vec<f32>>,
}

// =============================================================================
// Mock Backend
// =============================================================================

/// A mock backend for testing distributed operations in a single process.
/// Simulates distributed communication without actual network operations.
///
/// Because ranks run sequentially rather than concurrently, a collective
/// only completes once the final rank has called it; earlier callers do not
/// block and may observe pre-reduction data.
pub struct MockBackend {
    rank: usize,
    world_size: usize,
    state: Arc<Mutex<SharedState>>,
}

impl MockBackend {
    /// Creates a collection of mock backends, one per rank, sharing one state.
    #[must_use]
    pub fn create_world(world_size: usize) -> Vec<Self> {
        let state = Arc::new(Mutex::new(SharedState {
            buffers: HashMap::new(),
            barrier_count: 0,
            messages: HashMap::new(),
        }));

        (0..world_size)
            .map(|rank| MockBackend {
                rank,
                world_size,
                state: Arc::clone(&state),
            })
            .collect()
    }

    /// Creates a single mock backend (rank 0, world size 1).
    #[must_use]
    pub fn single() -> Self {
        MockBackend::create_world(1).pop().unwrap()
    }
}
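
// Illustrative usage (hypothetical example, added for clarity): every handle
// from `create_world` holds a clone of the shared state, so ranks can run on
// separate threads. Self-addressed messages keep the assertions
// deterministic regardless of scheduling order.
#[cfg(test)]
mod mock_world_example {
    use super::*;
    use std::thread;

    #[test]
    fn ranks_can_run_on_threads() {
        let handles: Vec<_> = MockBackend::create_world(2)
            .into_iter()
            .map(|backend| {
                thread::spawn(move || {
                    // Each rank sends a message to itself, then receives it.
                    let rank = backend.rank();
                    backend.send(&[rank as f32], rank, 0);

                    let mut buf = [0.0_f32];
                    backend.recv(&mut buf, rank, 0);
                    assert_eq!(buf[0], rank as f32);
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }
    }
}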

impl Backend for MockBackend {
    fn name(&self) -> &'static str {
        "mock"
    }

    fn rank(&self) -> usize {
        self.rank
    }

    fn world_size(&self) -> usize {
        self.world_size
    }

    fn all_reduce(&self, data: &mut [f32], op: ReduceOp) {
        let mut state = self.state.lock().unwrap();

        // Store this rank's data
        state.buffers.insert(self.rank, data.to_vec());

        // Once every rank has submitted, perform the reduction
        if state.buffers.len() == self.world_size {
            let all_data: Vec<Vec<f32>> = (0..self.world_size)
                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
                .collect();

            let reduced = op.reduce_slices(&all_data);

            // Update all buffers with the result
            for r in 0..self.world_size {
                state.buffers.insert(r, reduced.clone());
            }
        }

        // Copy this rank's buffer back out (still the pre-reduction data if
        // other ranks have not submitted yet)
        if let Some(result) = state.buffers.get(&self.rank) {
            for (i, &val) in result.iter().enumerate() {
                if i < data.len() {
                    data[i] = val;
                }
            }
        }

        // Tear down the round once the reduction has been applied; the final
        // caller is the only one guaranteed to read the reduced result
        if state.buffers.len() == self.world_size {
            state.buffers.clear();
        }
    }

    fn broadcast(&self, data: &mut [f32], src: usize) {
        let mut state = self.state.lock().unwrap();

        if self.rank == src {
            // Source rank stores its data under its own rank key
            state.buffers.insert(src, data.to_vec());
        }

        // Receivers copy the broadcast data; in this mock the source must
        // call broadcast before the receivers do
        if let Some(src_data) = state.buffers.get(&src) {
            for (i, &val) in src_data.iter().enumerate() {
                if i < data.len() {
                    data[i] = val;
                }
            }
        }
    }

    fn all_gather(&self, send_data: &[f32], recv_data: &mut [f32]) {
        let mut state = self.state.lock().unwrap();

        // Store this rank's data
        state.buffers.insert(self.rank, send_data.to_vec());

        // Once every rank has submitted, concatenate all data in rank order
        if state.buffers.len() == self.world_size {
            let chunk_size = send_data.len();
            for r in 0..self.world_size {
                if let Some(data) = state.buffers.get(&r) {
                    let start = r * chunk_size;
                    for (i, &val) in data.iter().enumerate() {
                        if start + i < recv_data.len() {
                            recv_data[start + i] = val;
                        }
                    }
                }
            }
        }
    }

    fn reduce_scatter(&self, send_data: &[f32], recv_data: &mut [f32], op: ReduceOp) {
        let mut state = self.state.lock().unwrap();

        // Store this rank's data
        state.buffers.insert(self.rank, send_data.to_vec());

        // Once every rank has submitted, reduce all data
        if state.buffers.len() == self.world_size {
            let all_data: Vec<Vec<f32>> = (0..self.world_size)
                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
                .collect();

            let reduced = op.reduce_slices(&all_data);

            // Scatter: this rank keeps the chunk at its own offset
            let chunk_size = recv_data.len();
            let start = self.rank * chunk_size;
            let end = (start + chunk_size).min(reduced.len());

            for (i, &val) in reduced[start..end].iter().enumerate() {
                if i < recv_data.len() {
                    recv_data[i] = val;
                }
            }
        }
    }

    fn gather(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize) {
        let mut state = self.state.lock().unwrap();

        // Store this rank's data
        state.buffers.insert(self.rank, send_data.to_vec());

        // Only the destination rank collects, once every rank has submitted
        if self.rank == dst && state.buffers.len() == self.world_size {
            let chunk_size = send_data.len();
            for r in 0..self.world_size {
                if let Some(data) = state.buffers.get(&r) {
                    let start = r * chunk_size;
                    for (i, &val) in data.iter().enumerate() {
                        if start + i < recv_data.len() {
                            recv_data[start + i] = val;
                        }
                    }
                }
            }
        }
    }

    fn scatter(&self, send_data: &[f32], recv_data: &mut [f32], src: usize) {
        // Only the source rank has the full data; it copies out its own
        // portion. Other ranks would receive theirs via message passing in a
        // real implementation, so no shared state is touched here.
        if self.rank == src {
            let chunk_size = recv_data.len();
            let start = self.rank * chunk_size;
            let end = (start + chunk_size).min(send_data.len());

            for (i, &val) in send_data[start..end].iter().enumerate() {
                recv_data[i] = val;
            }
        }
    }

    fn reduce(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize, op: ReduceOp) {
        let mut state = self.state.lock().unwrap();

        // Store this rank's data
        state.buffers.insert(self.rank, send_data.to_vec());

        // Only the destination rank reduces, once every rank has submitted
        if self.rank == dst && state.buffers.len() == self.world_size {
            let all_data: Vec<Vec<f32>> = (0..self.world_size)
                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
                .collect();

            let reduced = op.reduce_slices(&all_data);

            for (i, &val) in reduced.iter().enumerate() {
                if i < recv_data.len() {
                    recv_data[i] = val;
                }
            }
        }
    }

    fn barrier(&self) {
        let mut state = self.state.lock().unwrap();
        state.barrier_count += 1;

        // A real barrier would block until every rank arrives; the mock only
        // counts arrivals and resets once all ranks have checked in
        if state.barrier_count == self.world_size {
            state.barrier_count = 0;
        }
    }

    fn send(&self, data: &[f32], dst: usize, tag: usize) {
        let mut state = self.state.lock().unwrap();
        state.messages.insert((self.rank, dst, tag), data.to_vec());
    }

    fn recv(&self, data: &mut [f32], src: usize, tag: usize) {
        let mut state = self.state.lock().unwrap();
        if let Some(msg) = state.messages.remove(&(src, self.rank, tag)) {
            for (i, &val) in msg.iter().enumerate() {
                if i < data.len() {
                    data[i] = val;
                }
            }
        }
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reduce_op_sum() {
        let op = ReduceOp::Sum;
        assert_eq!(op.apply_f32(1.0, 2.0), 3.0);
    }

    #[test]
    fn test_reduce_op_product() {
        let op = ReduceOp::Product;
        assert_eq!(op.apply_f32(2.0, 3.0), 6.0);
    }

    #[test]
    fn test_reduce_op_min() {
        let op = ReduceOp::Min;
        assert_eq!(op.apply_f32(2.0, 3.0), 2.0);
    }

    #[test]
    fn test_reduce_op_max() {
        let op = ReduceOp::Max;
        assert_eq!(op.apply_f32(2.0, 3.0), 3.0);
    }

    #[test]
    fn test_reduce_slices_sum() {
        let op = ReduceOp::Sum;
        let slices = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let result = op.reduce_slices(&slices);
        assert_eq!(result, vec![9.0, 12.0]);
    }

    #[test]
    fn test_reduce_slices_average() {
        let op = ReduceOp::Average;
        let slices = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let result = op.reduce_slices(&slices);
        assert_eq!(result, vec![2.0, 3.0]);
    }

    #[test]
    fn test_mock_backend_single() {
        let backend = MockBackend::single();
        assert_eq!(backend.rank(), 0);
        assert_eq!(backend.world_size(), 1);
        assert_eq!(backend.name(), "mock");
    }

    #[test]
    fn test_mock_backend_world() {
        let backends = MockBackend::create_world(4);
        assert_eq!(backends.len(), 4);

        for (i, b) in backends.iter().enumerate() {
            assert_eq!(b.rank(), i);
            assert_eq!(b.world_size(), 4);
        }
    }

    #[test]
    fn test_mock_all_reduce() {
        // With a world size of 1, the reduction runs over a single slice, so
        // all_reduce is the identity. (In a real distributed system, every
        // rank would call all_reduce concurrently; the mock runs ranks
        // sequentially, so with more ranks only the final caller would see
        // the reduced result.)
        let backend = MockBackend::single();

        let mut data = vec![1.0, 2.0];
        backend.all_reduce(&mut data, ReduceOp::Sum);

        // Single rank: values are unchanged
        assert_eq!(data, vec![1.0, 2.0]);
    }

    #[test]
    fn test_mock_broadcast() {
        let backends = MockBackend::create_world(2);

        let mut data0 = vec![1.0, 2.0, 3.0];
        let mut data1 = vec![0.0, 0.0, 0.0];

        // Broadcast from rank 0 (in the mock, the source must call first)
        backends[0].broadcast(&mut data0, 0);
        backends[1].broadcast(&mut data1, 0);

        assert_eq!(data0, vec![1.0, 2.0, 3.0]);
        assert_eq!(data1, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mock_send_recv() {
        let backends = MockBackend::create_world(2);

        // Send from rank 0 to rank 1
        let send_data = vec![1.0, 2.0, 3.0];
        backends[0].send(&send_data, 1, 0);

        // Receive on rank 1
        let mut recv_data = vec![0.0, 0.0, 0.0];
        backends[1].recv(&mut recv_data, 0, 0);

        assert_eq!(recv_data, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mock_barrier() {
        let backends = MockBackend::create_world(2);

        // Both ranks call barrier; the mock counts arrivals and must not
        // deadlock
        backends[0].barrier();
        backends[1].barrier();
    }
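
    // Illustrative addition (not in the original suite): gather collects
    // every rank's chunk on the destination, in rank order. The destination
    // rank calls last here so that all contributions are already stored.
    #[test]
    fn test_mock_gather_rank_order() {
        let backends = MockBackend::create_world(2);

        // Rank 1 contributes its chunk first; its recv buffer goes unused.
        let mut scratch = vec![0.0_f32; 4];
        backends[1].gather(&[3.0, 4.0], &mut scratch, 0);

        // Rank 0 (the destination) then gathers everything in rank order.
        let mut gathered = vec![0.0_f32; 4];
        backends[0].gather(&[1.0, 2.0], &mut gathered, 0);

        assert_eq!(gathered, vec![1.0, 2.0, 3.0, 4.0]);
    }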
}