// axonml_distributed/backend.rs

1//! Backend - Communication Backend Abstractions
2//!
3//! Provides backend trait and implementations for distributed communication.
4//!
5//! @version 0.1.0
6//! @author `AutomataNexus` Development Team
7
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10
11// =============================================================================
12// Reduce Operations
13// =============================================================================
14
/// Reduction operation for collective communication.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    /// Sum all values.
    Sum,
    /// Compute product of all values.
    Product,
    /// Find minimum value.
    Min,
    /// Find maximum value.
    Max,
    /// Compute average of all values.
    Average,
}

impl ReduceOp {
    /// Applies the reduction operation to two f32 values.
    ///
    /// Note: for `Average` this is a *pairwise* average `(a + b) / 2.0`,
    /// which is not associative; averaging more than two operands must go
    /// through [`ReduceOp::reduce_slices`] instead of folding with this.
    #[must_use]
    pub fn apply_f32(&self, a: f32, b: f32) -> f32 {
        match self {
            ReduceOp::Sum => a + b,
            ReduceOp::Product => a * b,
            ReduceOp::Min => a.min(b),
            ReduceOp::Max => a.max(b),
            ReduceOp::Average => (a + b) / 2.0,
        }
    }

    /// Applies the reduction operation element-wise across slices.
    ///
    /// The output length equals the length of the first slice: longer later
    /// slices are truncated and shorter ones contribute only the elements
    /// they have. Returns an empty vector when `slices` is empty.
    #[must_use]
    pub fn reduce_slices(&self, slices: &[Vec<f32>]) -> Vec<f32> {
        if slices.is_empty() {
            return Vec::new();
        }

        let len = slices[0].len();

        // `Average` cannot be folded pairwise (earlier operands would be
        // repeatedly halved), so accumulate with `Sum` and divide once at
        // the end. All other ops fold directly.
        let fold_op = if *self == ReduceOp::Average {
            ReduceOp::Sum
        } else {
            *self
        };

        let mut result = slices[0].clone();
        for slice in slices.iter().skip(1) {
            // `take(len)` bounds `i` so `result[i]` cannot go out of range.
            for (i, &val) in slice.iter().take(len).enumerate() {
                result[i] = fold_op.apply_f32(result[i], val);
            }
        }

        if *self == ReduceOp::Average {
            let count = slices.len() as f32;
            for val in &mut result {
                *val /= count;
            }
        }

        result
    }
}
81
82// =============================================================================
83// Backend Trait
84// =============================================================================
85
/// Trait for distributed communication backends.
///
/// MPI-style collective API: each collective is expected to be invoked by
/// every rank in the world, and buffers are updated in place. Implementors
/// must be `Send + Sync` so a backend handle can be shared across threads.
pub trait Backend: Send + Sync {
    /// Returns the name of the backend (e.g. `"mock"`).
    fn name(&self) -> &str;

    /// Returns the 0-based rank of this process within the world.
    fn rank(&self) -> usize;

    /// Returns the total number of participating processes.
    fn world_size(&self) -> usize;

    /// Performs all-reduce: combines `data` element-wise across all ranks
    /// using `op`, leaving the result in `data` on every rank.
    fn all_reduce(&self, data: &mut [f32], op: ReduceOp);

    /// Broadcasts data from a source rank; on non-source ranks `data` is
    /// overwritten with the contents sent by rank `src`.
    fn broadcast(&self, data: &mut [f32], src: usize);

    /// Performs all-gather: concatenates every rank's `send_data` in rank
    /// order into `recv_data` on all ranks. `recv_data` should hold
    /// `world_size * send_data.len()` elements.
    fn all_gather(&self, send_data: &[f32], recv_data: &mut [f32]);

    /// Performs reduce-scatter: reduces `send_data` across ranks with `op`,
    /// then each rank receives its own chunk of the result in `recv_data`.
    fn reduce_scatter(&self, send_data: &[f32], recv_data: &mut [f32], op: ReduceOp);

    /// Performs gather: collects every rank's `send_data` into `recv_data`
    /// on rank `dst` only.
    fn gather(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize);

    /// Performs scatter: distributes chunks of rank `src`'s `send_data` so
    /// each rank receives its chunk in `recv_data`.
    fn scatter(&self, send_data: &[f32], recv_data: &mut [f32], src: usize);

    /// Performs reduce: combines `send_data` across ranks with `op`; the
    /// result is written to `recv_data` on rank `dst` only.
    fn reduce(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize, op: ReduceOp);

    /// Synchronizes all processes.
    fn barrier(&self);

    /// Sends `data` to rank `dst`, matched on the receiving side by `tag`.
    fn send(&self, data: &[f32], dst: usize, tag: usize);

    /// Receives into `data` from rank `src`, matched by `tag`.
    fn recv(&self, data: &mut [f32], src: usize, tag: usize);
}
127
128// =============================================================================
129// SharedState for Mock Backend
130// =============================================================================
131
/// Shared state for mock distributed communication.
///
/// One instance is shared behind `Arc<Mutex<..>>` by every [`MockBackend`]
/// in a world, acting as the in-memory "wire" between simulated ranks.
#[derive(Debug)]
struct SharedState {
    /// Data buffers for each rank, keyed by rank; used by the collectives.
    buffers: HashMap<usize, Vec<f32>>,
    /// Barrier counter: number of ranks that have reached the barrier.
    barrier_count: usize,
    /// Message queue for send/recv operations.
    messages: HashMap<(usize, usize, usize), Vec<f32>>, // (src, dst, tag) -> data
}
142
143// =============================================================================
144// Mock Backend
145// =============================================================================
146
/// A mock backend for testing distributed operations in a single process.
/// Simulates distributed communication without actual network operations.
///
/// All backends produced by one `create_world` call share a single
/// [`SharedState`]; operations record into / read from that state rather
/// than blocking on peers, so call ordering matters (see the tests).
pub struct MockBackend {
    /// 0-based rank of this simulated process.
    rank: usize,
    /// Total number of simulated processes in the world.
    world_size: usize,
    /// State shared by every rank of this world.
    state: Arc<Mutex<SharedState>>,
}
154
155impl MockBackend {
156    /// Creates a collection of mock backends for testing.
157    #[must_use]
158    pub fn create_world(world_size: usize) -> Vec<Self> {
159        let state = Arc::new(Mutex::new(SharedState {
160            buffers: HashMap::new(),
161            barrier_count: 0,
162            messages: HashMap::new(),
163        }));
164
165        (0..world_size)
166            .map(|rank| MockBackend {
167                rank,
168                world_size,
169                state: Arc::clone(&state),
170            })
171            .collect()
172    }
173
174    /// Creates a single mock backend (rank 0, world size 1).
175    #[must_use]
176    pub fn single() -> Self {
177        MockBackend::create_world(1).pop().unwrap()
178    }
179}
180
impl Backend for MockBackend {
    fn name(&self) -> &'static str {
        // `&'static str` is compatible with the trait's elided `&str` return.
        "mock"
    }

    fn rank(&self) -> usize {
        self.rank
    }

    fn world_size(&self) -> usize {
        self.world_size
    }

    // NOTE(review): only the rank whose call completes the set
    // (`buffers.len() == world_size`) observes the reduced result, and the
    // buffers are cleared in that same call — earlier sequential callers read
    // back their own un-reduced data. Intentional single-thread mock
    // limitation per the tests; confirm before relying on it elsewhere.
    fn all_reduce(&self, data: &mut [f32], op: ReduceOp) {
        let mut state = self.state.lock().unwrap();

        // Store this rank's data
        state.buffers.insert(self.rank, data.to_vec());

        // Check if all ranks have submitted
        if state.buffers.len() == self.world_size {
            // Perform reduction
            let all_data: Vec<Vec<f32>> = (0..self.world_size)
                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
                .collect();

            let reduced = op.reduce_slices(&all_data);

            // Update all buffers with result
            for r in 0..self.world_size {
                state.buffers.insert(r, reduced.clone());
            }
        }

        // Get result for this rank (length-guarded copy back into `data`)
        if let Some(result) = state.buffers.get(&self.rank) {
            for (i, &val) in result.iter().enumerate() {
                if i < data.len() {
                    data[i] = val;
                }
            }
        }

        // Clear buffers if this is the last rank to read
        if state.buffers.len() == self.world_size {
            state.buffers.clear();
        }
    }

    // NOTE(review): the payload is stored under the fixed buffer key 0
    // regardless of `src`, and is never cleared — this shares key space with
    // the rank-keyed buffers used by the other collectives, so interleaving
    // broadcast with them may cross-contaminate. TODO confirm intended.
    fn broadcast(&self, data: &mut [f32], src: usize) {
        let mut state = self.state.lock().unwrap();

        if self.rank == src {
            // Source rank stores its data
            state.buffers.insert(0, data.to_vec());
        }

        // Get broadcast data
        if let Some(src_data) = state.buffers.get(&0) {
            for (i, &val) in src_data.iter().enumerate() {
                if i < data.len() {
                    data[i] = val;
                }
            }
        }
    }

    // NOTE(review): `recv_data` is only filled for the rank whose call
    // completes the set; earlier callers return with it untouched.
    fn all_gather(&self, send_data: &[f32], recv_data: &mut [f32]) {
        let mut state = self.state.lock().unwrap();

        // Store this rank's data
        state.buffers.insert(self.rank, send_data.to_vec());

        // Check if all ranks have submitted
        if state.buffers.len() == self.world_size {
            // Concatenate all data in rank order
            let chunk_size = send_data.len();
            for r in 0..self.world_size {
                if let Some(data) = state.buffers.get(&r) {
                    let start = r * chunk_size;
                    for (i, &val) in data.iter().enumerate() {
                        if start + i < recv_data.len() {
                            recv_data[start + i] = val;
                        }
                    }
                }
            }
        }
    }

    // Chunk size is inferred from `recv_data.len()`; this rank's chunk starts
    // at `rank * chunk_size` within the reduced vector.
    fn reduce_scatter(&self, send_data: &[f32], recv_data: &mut [f32], op: ReduceOp) {
        let mut state = self.state.lock().unwrap();

        // Store this rank's data
        state.buffers.insert(self.rank, send_data.to_vec());

        // Check if all ranks have submitted
        if state.buffers.len() == self.world_size {
            // First reduce all data
            let all_data: Vec<Vec<f32>> = (0..self.world_size)
                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
                .collect();

            let reduced = op.reduce_slices(&all_data);

            // Scatter to each rank
            let chunk_size = recv_data.len();
            let start = self.rank * chunk_size;
            let end = (start + chunk_size).min(reduced.len());

            for (i, &val) in reduced[start..end].iter().enumerate() {
                if i < recv_data.len() {
                    recv_data[i] = val;
                }
            }
        }
    }

    fn gather(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize) {
        let mut state = self.state.lock().unwrap();

        // Store this rank's data
        state.buffers.insert(self.rank, send_data.to_vec());

        // Only destination rank collects — and only once all ranks submitted.
        if self.rank == dst && state.buffers.len() == self.world_size {
            let chunk_size = send_data.len();
            for r in 0..self.world_size {
                if let Some(data) = state.buffers.get(&r) {
                    let start = r * chunk_size;
                    for (i, &val) in data.iter().enumerate() {
                        if start + i < recv_data.len() {
                            recv_data[start + i] = val;
                        }
                    }
                }
            }
        }
    }

    // NOTE(review): only the source rank copies its own chunk; non-source
    // ranks return with `recv_data` untouched (the comment below says a real
    // implementation would deliver via message passing).
    fn scatter(&self, send_data: &[f32], recv_data: &mut [f32], src: usize) {
        // Lock is held only to serialize with other ops; shared state is
        // neither read nor written here.
        let state = self.state.lock().unwrap();

        // Only source rank has full data
        if self.rank == src {
            // Scatter to all (in mock, we copy our portion)
            let chunk_size = recv_data.len();
            let start = self.rank * chunk_size;
            let end = (start + chunk_size).min(send_data.len());

            for (i, &val) in send_data[start..end].iter().enumerate() {
                recv_data[i] = val;
            }
        }
        drop(state);

        // Others would receive via message passing in real impl
    }

    // NOTE(review): like `gather`, the result appears only if `dst` is the
    // last rank to call; buffers are never cleared here.
    fn reduce(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize, op: ReduceOp) {
        let mut state = self.state.lock().unwrap();

        // Store this rank's data
        state.buffers.insert(self.rank, send_data.to_vec());

        // Only destination rank reduces
        if self.rank == dst && state.buffers.len() == self.world_size {
            let all_data: Vec<Vec<f32>> = (0..self.world_size)
                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
                .collect();

            let reduced = op.reduce_slices(&all_data);

            for (i, &val) in reduced.iter().enumerate() {
                if i < recv_data.len() {
                    recv_data[i] = val;
                }
            }
        }
    }

    // NOTE(review): this barrier only counts arrivals; it never blocks, so it
    // is a no-op synchronization suitable only for single-threaded tests.
    fn barrier(&self) {
        let mut state = self.state.lock().unwrap();
        state.barrier_count += 1;

        // Reset when all have arrived
        if state.barrier_count == self.world_size {
            state.barrier_count = 0;
        }
    }

    // A second send with the same (src, dst, tag) overwrites the first
    // (HashMap insert), i.e. the "queue" holds at most one message per key.
    fn send(&self, data: &[f32], dst: usize, tag: usize) {
        let mut state = self.state.lock().unwrap();
        state.messages.insert((self.rank, dst, tag), data.to_vec());
    }

    // Non-blocking: if no matching message is queued, `data` is left as-is.
    fn recv(&self, data: &mut [f32], src: usize, tag: usize) {
        let mut state = self.state.lock().unwrap();
        if let Some(msg) = state.messages.remove(&(src, self.rank, tag)) {
            for (i, &val) in msg.iter().enumerate() {
                if i < data.len() {
                    data[i] = val;
                }
            }
        }
    }
}
388
389// =============================================================================
390// Tests
391// =============================================================================
392
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reduce_op_sum() {
        assert_eq!(ReduceOp::Sum.apply_f32(1.0, 2.0), 3.0);
    }

    #[test]
    fn test_reduce_op_product() {
        assert_eq!(ReduceOp::Product.apply_f32(2.0, 3.0), 6.0);
    }

    #[test]
    fn test_reduce_op_min() {
        assert_eq!(ReduceOp::Min.apply_f32(2.0, 3.0), 2.0);
    }

    #[test]
    fn test_reduce_op_max() {
        assert_eq!(ReduceOp::Max.apply_f32(2.0, 3.0), 3.0);
    }

    #[test]
    fn test_reduce_slices_sum() {
        let inputs = [vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        assert_eq!(ReduceOp::Sum.reduce_slices(&inputs), vec![9.0, 12.0]);
    }

    #[test]
    fn test_reduce_slices_average() {
        let inputs = [vec![1.0, 2.0], vec![3.0, 4.0]];
        assert_eq!(ReduceOp::Average.reduce_slices(&inputs), vec![2.0, 3.0]);
    }

    #[test]
    fn test_mock_backend_single() {
        let backend = MockBackend::single();
        assert_eq!(backend.rank(), 0);
        assert_eq!(backend.world_size(), 1);
        assert_eq!(backend.name(), "mock");
    }

    #[test]
    fn test_mock_backend_world() {
        let world = MockBackend::create_world(4);
        assert_eq!(world.len(), 4);

        for (expected_rank, backend) in world.iter().enumerate() {
            assert_eq!(backend.rank(), expected_rank);
            assert_eq!(backend.world_size(), 4);
        }
    }

    #[test]
    fn test_mock_all_reduce() {
        // Note: In a real distributed system, all_reduce would be called from
        // different processes simultaneously. The mock backend simulates a
        // single process, so values remain unchanged when called sequentially
        // from the same thread.
        let backend = MockBackend::single();

        let mut values = vec![1.0, 2.0];
        backend.all_reduce(&mut values, ReduceOp::Sum);

        // With a single rank the values stay the same.
        assert_eq!(values, vec![1.0, 2.0]);
    }

    #[test]
    fn test_mock_broadcast() {
        let world = MockBackend::create_world(2);

        let mut root_buf = vec![1.0, 2.0, 3.0];
        let mut peer_buf = vec![0.0; 3];

        // Broadcast from rank 0
        world[0].broadcast(&mut root_buf, 0);
        world[1].broadcast(&mut peer_buf, 0);

        assert_eq!(root_buf, vec![1.0, 2.0, 3.0]);
        assert_eq!(peer_buf, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mock_send_recv() {
        let world = MockBackend::create_world(2);

        // Send from rank 0 to rank 1
        let payload = vec![1.0, 2.0, 3.0];
        world[0].send(&payload, 1, 0);

        // Receive on rank 1
        let mut inbox = vec![0.0; 3];
        world[1].recv(&mut inbox, 0, 0);

        assert_eq!(inbox, payload);
    }

    #[test]
    fn test_mock_barrier() {
        let world = MockBackend::create_world(2);

        // Both ranks call barrier; the mock must not deadlock.
        world[0].barrier();
        world[1].barrier();
    }
}