// axonml_distributed/backend.rs
1//! Backend - Communication Backend Abstractions
2//!
3//! # File
4//! `crates/axonml-distributed/src/backend.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr. — AutomataNexus LLC
8//! ORCID: 0009-0005-2158-7060
9//!
10//! # Updated
11//! April 14, 2026 11:15 PM EST
12//!
13//! # Disclaimer
14//! Use at own risk. This software is provided "as is", without warranty of any
15//! kind, express or implied. The author and AutomataNexus shall not be held
16//! liable for any damages arising from the use of this software.
17
18use std::collections::HashMap;
19use std::sync::{Arc, Mutex};
20
21// =============================================================================
22// Reduce Operations
23// =============================================================================
24
/// Reduction operation for collective communication.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    /// Sum all values.
    Sum,
    /// Compute product of all values.
    Product,
    /// Find minimum value.
    Min,
    /// Find maximum value.
    Max,
    /// Compute average of all values.
    ///
    /// Note: [`ReduceOp::apply_f32`] computes a pairwise midpoint, which is
    /// only the true average for exactly two operands; [`ReduceOp::reduce_slices`]
    /// handles the general n-operand case via sum-then-divide.
    Average,
}

impl ReduceOp {
    /// Applies the reduction operation to two f32 values.
    ///
    /// For [`ReduceOp::Average`] this is the midpoint of `a` and `b`, which
    /// is only correct for a two-element reduction; use
    /// [`ReduceOp::reduce_slices`] to average more than two participants.
    #[must_use]
    pub fn apply_f32(&self, a: f32, b: f32) -> f32 {
        match self {
            ReduceOp::Sum => a + b,
            ReduceOp::Product => a * b,
            ReduceOp::Min => a.min(b),
            ReduceOp::Max => a.max(b),
            ReduceOp::Average => f32::midpoint(a, b),
        }
    }

    /// Applies the reduction operation element-wise across slices.
    ///
    /// The output length equals the length of the first slice; shorter later
    /// slices contribute fewer elements and longer ones are truncated (same
    /// behavior as the previous explicit bounds checks). Returns an empty
    /// `Vec` when `slices` is empty.
    #[must_use]
    pub fn reduce_slices(&self, slices: &[Vec<f32>]) -> Vec<f32> {
        let Some(first) = slices.first() else {
            return Vec::new();
        };
        let len = first.len();

        // Average gets its own path: pairwise midpoint would weight early
        // participants incorrectly, so sum everything and divide once.
        if *self == ReduceOp::Average {
            let mut result = vec![0.0f32; len];
            for slice in slices {
                // `zip` truncates to the shorter side, replacing the manual
                // `i < len` bounds check.
                for (acc, &val) in result.iter_mut().zip(slice) {
                    *acc += val;
                }
            }
            let count = slices.len() as f32;
            for val in &mut result {
                *val /= count;
            }
            return result;
        }

        // All other ops: element-wise pairwise reduction into the first slice.
        let mut result = first.clone();
        for slice in &slices[1..] {
            for (acc, &val) in result.iter_mut().zip(slice) {
                *acc = self.apply_f32(*acc, val);
            }
        }

        result
    }
}
92
93// =============================================================================
94// Backend Trait
95// =============================================================================
96
/// Trait for distributed communication backends.
///
/// Implementations provide MPI-style collective and point-to-point
/// primitives over `f32` buffers. All methods operate in place on the
/// provided slices.
pub trait Backend: Send + Sync {
    /// Returns the name of the backend (e.g. `"mock"`).
    fn name(&self) -> &str;

    /// Returns the rank of this process (0-based).
    fn rank(&self) -> usize;

    /// Returns the total world size (number of participating ranks).
    fn world_size(&self) -> usize;

    /// Performs all-reduce: combines `data` across all ranks with `op`,
    /// leaving the result in `data` on every rank.
    fn all_reduce(&self, data: &mut [f32], op: ReduceOp);

    /// Broadcasts `data` from rank `src`, overwriting `data` on the
    /// non-source ranks.
    fn broadcast(&self, data: &mut [f32], src: usize);

    /// Performs all-gather: concatenates every rank's `send_data` in rank
    /// order into `recv_data` on every rank. `recv_data` should hold
    /// `world_size * send_data.len()` elements.
    fn all_gather(&self, send_data: &[f32], recv_data: &mut [f32]);

    /// Performs reduce-scatter: reduces `send_data` across ranks with `op`,
    /// then each rank receives its `recv_data.len()`-sized chunk of the result.
    fn reduce_scatter(&self, send_data: &[f32], recv_data: &mut [f32], op: ReduceOp);

    /// Performs gather: collects every rank's `send_data` into `recv_data`
    /// on rank `dst` only.
    fn gather(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize);

    /// Performs scatter: distributes chunks of rank `src`'s `send_data`
    /// into each rank's `recv_data`.
    fn scatter(&self, send_data: &[f32], recv_data: &mut [f32], src: usize);

    /// Performs reduce: combines `send_data` across ranks with `op`; the
    /// result is written to `recv_data` on rank `dst` only.
    fn reduce(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize, op: ReduceOp);

    /// Synchronizes all processes.
    fn barrier(&self);

    /// Sends `data` to rank `dst`; `tag` disambiguates messages between the
    /// same (src, dst) pair.
    fn send(&self, data: &[f32], dst: usize, tag: usize);

    /// Receives into `data` from rank `src`, matching on `tag`.
    fn recv(&self, data: &mut [f32], src: usize, tag: usize);
}
138
139// =============================================================================
140// SharedState for Mock Backend
141// =============================================================================
142
/// Shared state for mock distributed communication.
///
/// One instance is shared (behind `Arc<Mutex<..>>`) by every [`MockBackend`]
/// belonging to the same world created via `MockBackend::create_world`.
#[derive(Debug)]
struct SharedState {
    /// Data buffers for each rank, used by the collective operations.
    buffers: HashMap<usize, Vec<f32>>,
    /// Counts ranks that have reached the (non-blocking) barrier.
    barrier_count: usize,
    /// Pending point-to-point messages for send/recv operations.
    messages: HashMap<(usize, usize, usize), Vec<f32>>, // (src, dst, tag) -> data
}
153
154// =============================================================================
155// Mock Backend
156// =============================================================================
157
/// A mock backend for testing distributed operations in a single process.
/// Simulates distributed communication without actual network operations.
///
/// All ranks of one world share a single [`SharedState`]; collectives only
/// produce complete results once every rank has participated (see the tests
/// for the intended call patterns).
pub struct MockBackend {
    /// Rank of this handle within the world (0-based).
    rank: usize,
    /// Total number of ranks in the world.
    world_size: usize,
    /// State shared by all ranks of the same world.
    state: Arc<Mutex<SharedState>>,
}
165
166impl MockBackend {
167    /// Creates a collection of mock backends for testing.
168    #[must_use]
169    pub fn create_world(world_size: usize) -> Vec<Self> {
170        let state = Arc::new(Mutex::new(SharedState {
171            buffers: HashMap::new(),
172            barrier_count: 0,
173            messages: HashMap::new(),
174        }));
175
176        (0..world_size)
177            .map(|rank| MockBackend {
178                rank,
179                world_size,
180                state: Arc::clone(&state),
181            })
182            .collect()
183    }
184
185    /// Creates a single mock backend (rank 0, world size 1).
186    #[must_use]
187    pub fn single() -> Self {
188        MockBackend::create_world(1).pop().unwrap()
189    }
190}
191
// NOTE(review): these collectives cooperate through `SharedState::buffers`.
// When ranks call sequentially from one thread (as the unit tests do), only
// the *last* caller of a collective observes the complete result; earlier
// callers read whatever is present at the time. `barrier` counts arrivals
// but never blocks.
impl Backend for MockBackend {
    // `&'static str` is a valid lifetime narrowing of the trait's `&str`.
    fn name(&self) -> &'static str {
        "mock"
    }

    fn rank(&self) -> usize {
        self.rank
    }

    fn world_size(&self) -> usize {
        self.world_size
    }

    fn all_reduce(&self, data: &mut [f32], op: ReduceOp) {
        let mut state = self.state.lock().unwrap();

        // Store this rank's data
        state.buffers.insert(self.rank, data.to_vec());

        // Check if all ranks have submitted
        if state.buffers.len() == self.world_size {
            // Perform reduction
            let all_data: Vec<Vec<f32>> = (0..self.world_size)
                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
                .collect();

            let reduced = op.reduce_slices(&all_data);

            // Update all buffers with result
            for r in 0..self.world_size {
                state.buffers.insert(r, reduced.clone());
            }
        }

        // Get result for this rank; if the reduction above has not run yet,
        // this just copies the rank's own data back unchanged.
        if let Some(result) = state.buffers.get(&self.rank) {
            for (i, &val) in result.iter().enumerate() {
                if i < data.len() {
                    data[i] = val;
                }
            }
        }

        // Clear buffers if this is the last rank to read
        // NOTE(review): the clear fires for the same caller that performed
        // the reduction, so later sequential callers start a fresh round —
        // confirm this matches the intended multi-threaded usage.
        if state.buffers.len() == self.world_size {
            state.buffers.clear();
        }
    }

    fn broadcast(&self, data: &mut [f32], src: usize) {
        let mut state = self.state.lock().unwrap();

        if self.rank == src {
            // Source rank stores its data under the fixed key 0, whatever
            // its actual rank.
            // NOTE(review): key 0 can collide with a rank-0 entry left by
            // another collective, and this entry is never removed — verify
            // operations are not interleaved on the same world.
            state.buffers.insert(0, data.to_vec());
        }

        // Get broadcast data; non-source ranks that call before the source
        // has published leave `data` untouched.
        if let Some(src_data) = state.buffers.get(&0) {
            for (i, &val) in src_data.iter().enumerate() {
                if i < data.len() {
                    data[i] = val;
                }
            }
        }
    }

    fn all_gather(&self, send_data: &[f32], recv_data: &mut [f32]) {
        let mut state = self.state.lock().unwrap();

        // Store this rank's data
        state.buffers.insert(self.rank, send_data.to_vec());

        // Check if all ranks have submitted; only then (i.e. only for the
        // last sequential caller) is `recv_data` filled in.
        if state.buffers.len() == self.world_size {
            // Concatenate all data in rank order; writes beyond
            // `recv_data.len()` are silently dropped.
            let chunk_size = send_data.len();
            for r in 0..self.world_size {
                if let Some(data) = state.buffers.get(&r) {
                    let start = r * chunk_size;
                    for (i, &val) in data.iter().enumerate() {
                        if start + i < recv_data.len() {
                            recv_data[start + i] = val;
                        }
                    }
                }
            }
        }
    }

    fn reduce_scatter(&self, send_data: &[f32], recv_data: &mut [f32], op: ReduceOp) {
        let mut state = self.state.lock().unwrap();

        // Store this rank's data
        state.buffers.insert(self.rank, send_data.to_vec());

        // Check if all ranks have submitted
        if state.buffers.len() == self.world_size {
            // First reduce all data
            let all_data: Vec<Vec<f32>> = (0..self.world_size)
                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
                .collect();

            let reduced = op.reduce_slices(&all_data);

            // Scatter to each rank: this rank keeps the `recv_data.len()`-sized
            // chunk at offset `rank * chunk_size`, clamped to the reduced length.
            let chunk_size = recv_data.len();
            let start = self.rank * chunk_size;
            let end = (start + chunk_size).min(reduced.len());

            for (i, &val) in reduced[start..end].iter().enumerate() {
                if i < recv_data.len() {
                    recv_data[i] = val;
                }
            }
        }
    }

    fn gather(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize) {
        let mut state = self.state.lock().unwrap();

        // Store this rank's data
        state.buffers.insert(self.rank, send_data.to_vec());

        // Only destination rank collects, and only once all ranks have
        // submitted (so `dst` must call last in sequential use).
        if self.rank == dst && state.buffers.len() == self.world_size {
            let chunk_size = send_data.len();
            for r in 0..self.world_size {
                if let Some(data) = state.buffers.get(&r) {
                    let start = r * chunk_size;
                    for (i, &val) in data.iter().enumerate() {
                        if start + i < recv_data.len() {
                            recv_data[start + i] = val;
                        }
                    }
                }
            }
        }
    }

    fn scatter(&self, send_data: &[f32], recv_data: &mut [f32], src: usize) {
        // Lock taken for symmetry with the other ops; nothing is written to
        // shared state here.
        let state = self.state.lock().unwrap();

        // Only source rank has full data
        if self.rank == src {
            // Scatter to all (in mock, we copy our portion): the source keeps
            // the chunk at offset `rank * chunk_size` of its own `send_data`.
            let chunk_size = recv_data.len();
            let start = self.rank * chunk_size;
            let end = (start + chunk_size).min(send_data.len());

            for (i, &val) in send_data[start..end].iter().enumerate() {
                recv_data[i] = val;
            }
        }
        drop(state);

        // Others would receive via message passing in real impl; here their
        // `recv_data` is left unchanged.
    }

    fn reduce(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize, op: ReduceOp) {
        let mut state = self.state.lock().unwrap();

        // Store this rank's data
        state.buffers.insert(self.rank, send_data.to_vec());

        // Only destination rank reduces, and only once every rank has
        // submitted (so `dst` must call last in sequential use).
        if self.rank == dst && state.buffers.len() == self.world_size {
            let all_data: Vec<Vec<f32>> = (0..self.world_size)
                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
                .collect();

            let reduced = op.reduce_slices(&all_data);

            for (i, &val) in reduced.iter().enumerate() {
                if i < recv_data.len() {
                    recv_data[i] = val;
                }
            }
        }
    }

    // Non-blocking barrier: counts arrivals and wraps the counter to zero
    // once every rank has checked in. It never waits, so it only simulates
    // synchronization for sequential single-threaded use.
    fn barrier(&self) {
        let mut state = self.state.lock().unwrap();
        state.barrier_count += 1;

        // Reset when all have arrived
        if state.barrier_count == self.world_size {
            state.barrier_count = 0;
        }
    }

    // A later send with the same (src, dst, tag) overwrites an unconsumed
    // earlier message.
    fn send(&self, data: &[f32], dst: usize, tag: usize) {
        let mut state = self.state.lock().unwrap();
        state.messages.insert((self.rank, dst, tag), data.to_vec());
    }

    // Consumes the matching message if present; leaves `data` unchanged when
    // nothing has been sent yet (no blocking, no error).
    fn recv(&self, data: &mut [f32], src: usize, tag: usize) {
        let mut state = self.state.lock().unwrap();
        if let Some(msg) = state.messages.remove(&(src, self.rank, tag)) {
            for (i, &val) in msg.iter().enumerate() {
                if i < data.len() {
                    data[i] = val;
                }
            }
        }
    }
}
399
400// =============================================================================
401// Tests
402// =============================================================================
403
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reduce_op_sum() {
        assert_eq!(ReduceOp::Sum.apply_f32(1.0, 2.0), 3.0);
    }

    #[test]
    fn test_reduce_op_product() {
        assert_eq!(ReduceOp::Product.apply_f32(2.0, 3.0), 6.0);
    }

    #[test]
    fn test_reduce_op_min() {
        assert_eq!(ReduceOp::Min.apply_f32(2.0, 3.0), 2.0);
    }

    #[test]
    fn test_reduce_op_max() {
        assert_eq!(ReduceOp::Max.apply_f32(2.0, 3.0), 3.0);
    }

    #[test]
    fn test_reduce_slices_sum() {
        let inputs = [vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        assert_eq!(ReduceOp::Sum.reduce_slices(&inputs), [9.0, 12.0]);
    }

    #[test]
    fn test_reduce_slices_average() {
        let inputs = [vec![1.0, 2.0], vec![3.0, 4.0]];
        assert_eq!(ReduceOp::Average.reduce_slices(&inputs), [2.0, 3.0]);
    }

    #[test]
    fn test_mock_backend_single() {
        let solo = MockBackend::single();
        assert_eq!(solo.rank(), 0);
        assert_eq!(solo.world_size(), 1);
        assert_eq!(solo.name(), "mock");
    }

    #[test]
    fn test_mock_backend_world() {
        let world = MockBackend::create_world(4);
        assert_eq!(world.len(), 4);

        for (expected_rank, backend) in world.iter().enumerate() {
            assert_eq!(backend.rank(), expected_rank);
            assert_eq!(backend.world_size(), 4);
        }
    }

    #[test]
    fn test_mock_all_reduce() {
        // A lone rank reducing with itself simply gets its own values back;
        // real multi-rank behavior would need concurrent callers.
        let solo = MockBackend::single();

        let mut values = vec![1.0, 2.0];
        solo.all_reduce(&mut values, ReduceOp::Sum);

        assert_eq!(values, [1.0, 2.0]);
    }

    #[test]
    fn test_mock_broadcast() {
        let world = MockBackend::create_world(2);

        let mut root_buf = vec![1.0, 2.0, 3.0];
        let mut peer_buf = vec![0.0, 0.0, 0.0];

        // Root publishes first; the peer then reads the shared copy.
        world[0].broadcast(&mut root_buf, 0);
        world[1].broadcast(&mut peer_buf, 0);

        assert_eq!(root_buf, [1.0, 2.0, 3.0]);
        assert_eq!(peer_buf, [1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mock_send_recv() {
        let world = MockBackend::create_world(2);

        // Rank 0 posts a message for rank 1 with tag 0.
        let payload = vec![1.0, 2.0, 3.0];
        world[0].send(&payload, 1, 0);

        // Rank 1 consumes it.
        let mut inbox = vec![0.0, 0.0, 0.0];
        world[1].recv(&mut inbox, 0, 0);

        assert_eq!(inbox, payload);
    }

    #[test]
    fn test_mock_barrier() {
        let world = MockBackend::create_world(2);

        // Both ranks check in; must complete without deadlock.
        world[0].barrier();
        world[1].barrier();
    }
}