Skip to main content

axonml_distributed/
backend.rs

1//! Backend - Communication Backend Abstractions
2//!
3//! # File
4//! `crates/axonml-distributed/src/backend.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use std::collections::HashMap;
18use std::sync::{Arc, Mutex};
19
20// =============================================================================
21// Reduce Operations
22// =============================================================================
23
/// Reduction operation for collective communication.
///
/// Applied element-wise when combining buffers across ranks
/// (e.g. in `all_reduce`, `reduce_scatter`, and `reduce`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    /// Sum all values.
    Sum,
    /// Compute product of all values.
    Product,
    /// Find minimum value.
    Min,
    /// Find maximum value.
    Max,
    /// Compute average of all values.
    ///
    /// Not pairwise-associative: chaining two-value midpoints does not
    /// yield the mean of 3+ inputs, so `reduce_slices` gives this op a
    /// dedicated sum-then-divide path.
    Average,
}
38
39impl ReduceOp {
40    /// Applies the reduction operation to two f32 values.
41    #[must_use]
42    pub fn apply_f32(&self, a: f32, b: f32) -> f32 {
43        match self {
44            ReduceOp::Sum => a + b,
45            ReduceOp::Product => a * b,
46            ReduceOp::Min => a.min(b),
47            ReduceOp::Max => a.max(b),
48            ReduceOp::Average => f32::midpoint(a, b),
49        }
50    }
51
52    /// Applies the reduction operation to slices.
53    #[must_use]
54    pub fn reduce_slices(&self, slices: &[Vec<f32>]) -> Vec<f32> {
55        if slices.is_empty() {
56            return Vec::new();
57        }
58
59        let len = slices[0].len();
60
61        // Average gets its own path to avoid incorrect pairwise midpoint
62        if *self == ReduceOp::Average {
63            let mut result = vec![0.0f32; len];
64            for slice in slices {
65                for (i, &val) in slice.iter().enumerate() {
66                    if i < len {
67                        result[i] += val;
68                    }
69                }
70            }
71            let count = slices.len() as f32;
72            for val in &mut result {
73                *val /= count;
74            }
75            return result;
76        }
77
78        // All other ops: pairwise reduction
79        let mut result = slices[0].clone();
80        for slice in slices.iter().skip(1) {
81            for (i, &val) in slice.iter().enumerate() {
82                if i < len {
83                    result[i] = self.apply_f32(result[i], val);
84                }
85            }
86        }
87
88        result
89    }
90}
91
92// =============================================================================
93// Backend Trait
94// =============================================================================
95
/// Trait for distributed communication backends.
///
/// Implementations provide collective operations (all-reduce, broadcast,
/// gather, ...) and point-to-point send/recv over `f32` buffers. All
/// operations work in-place on the caller-provided slices. Backends must
/// be `Send + Sync` so one instance can be shared across threads.
pub trait Backend: Send + Sync {
    /// Returns the name of the backend (e.g. `"mock"`).
    fn name(&self) -> &str;

    /// Returns the 0-based rank of this process.
    fn rank(&self) -> usize;

    /// Returns the total number of processes in the group.
    fn world_size(&self) -> usize;

    /// Performs all-reduce: combines `data` across all ranks with `op`,
    /// leaving the combined result in every rank's `data`.
    fn all_reduce(&self, data: &mut [f32], op: ReduceOp);

    /// Broadcasts `data` from rank `src` into every rank's buffer.
    fn broadcast(&self, data: &mut [f32], src: usize);

    /// Performs all-gather: concatenates every rank's `send_data` in rank
    /// order into `recv_data` (expected to hold
    /// `world_size * send_data.len()` elements).
    fn all_gather(&self, send_data: &[f32], recv_data: &mut [f32]);

    /// Performs reduce-scatter: reduces `send_data` across ranks with
    /// `op`, then each rank receives its chunk of the result in
    /// `recv_data`.
    fn reduce_scatter(&self, send_data: &[f32], recv_data: &mut [f32], op: ReduceOp);

    /// Performs gather: rank `dst` receives every rank's `send_data`
    /// concatenated in rank order; other ranks' `recv_data` is untouched.
    fn gather(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize);

    /// Performs scatter: rank `src` distributes chunks of `send_data`;
    /// each rank receives its chunk in `recv_data`.
    fn scatter(&self, send_data: &[f32], recv_data: &mut [f32], src: usize);

    /// Performs reduce: combines `send_data` across ranks with `op`; the
    /// result is written to `recv_data` on rank `dst` only.
    fn reduce(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize, op: ReduceOp);

    /// Synchronizes all processes.
    fn barrier(&self);

    /// Sends `data` to rank `dst`, matched on `tag`.
    fn send(&self, data: &[f32], dst: usize, tag: usize);

    /// Receives into `data` from rank `src`, matched on `tag`.
    fn recv(&self, data: &mut [f32], src: usize, tag: usize);
}
137
138// =============================================================================
139// SharedState for Mock Backend
140// =============================================================================
141
/// Shared state for mock distributed communication.
///
/// A single instance is shared behind `Arc<Mutex<..>>` by every rank of
/// a world created via `MockBackend::create_world`.
#[derive(Debug)]
struct SharedState {
    /// Per-rank data buffers used by the collective operations, keyed by
    /// rank index.
    buffers: HashMap<usize, Vec<f32>>,
    /// Number of ranks that have reached the current barrier; reset to 0
    /// once all have arrived.
    barrier_count: usize,
    /// Pending point-to-point messages for send/recv operations.
    messages: HashMap<(usize, usize, usize), Vec<f32>>, // (src, dst, tag) -> data
}
152
153// =============================================================================
154// Mock Backend
155// =============================================================================
156
/// A mock backend for testing distributed operations in a single process.
/// Simulates distributed communication without actual network operations.
///
/// Collectives are sequential and non-blocking: the combined result is
/// only available once every rank in the world has made the matching call.
pub struct MockBackend {
    // Rank of this handle within the mock world.
    rank: usize,
    // Total number of ranks in the mock world.
    world_size: usize,
    // State shared by all ranks of the same world.
    state: Arc<Mutex<SharedState>>,
}
164
165impl MockBackend {
166    /// Creates a collection of mock backends for testing.
167    #[must_use]
168    pub fn create_world(world_size: usize) -> Vec<Self> {
169        let state = Arc::new(Mutex::new(SharedState {
170            buffers: HashMap::new(),
171            barrier_count: 0,
172            messages: HashMap::new(),
173        }));
174
175        (0..world_size)
176            .map(|rank| MockBackend {
177                rank,
178                world_size,
179                state: Arc::clone(&state),
180            })
181            .collect()
182    }
183
184    /// Creates a single mock backend (rank 0, world size 1).
185    #[must_use]
186    pub fn single() -> Self {
187        MockBackend::create_world(1).pop().unwrap()
188    }
189}
190
191impl Backend for MockBackend {
192    fn name(&self) -> &'static str {
193        "mock"
194    }
195
196    fn rank(&self) -> usize {
197        self.rank
198    }
199
200    fn world_size(&self) -> usize {
201        self.world_size
202    }
203
204    fn all_reduce(&self, data: &mut [f32], op: ReduceOp) {
205        let mut state = self.state.lock().unwrap();
206
207        // Store this rank's data
208        state.buffers.insert(self.rank, data.to_vec());
209
210        // Check if all ranks have submitted
211        if state.buffers.len() == self.world_size {
212            // Perform reduction
213            let all_data: Vec<Vec<f32>> = (0..self.world_size)
214                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
215                .collect();
216
217            let reduced = op.reduce_slices(&all_data);
218
219            // Update all buffers with result
220            for r in 0..self.world_size {
221                state.buffers.insert(r, reduced.clone());
222            }
223        }
224
225        // Get result for this rank
226        if let Some(result) = state.buffers.get(&self.rank) {
227            for (i, &val) in result.iter().enumerate() {
228                if i < data.len() {
229                    data[i] = val;
230                }
231            }
232        }
233
234        // Clear buffers if this is the last rank to read
235        if state.buffers.len() == self.world_size {
236            state.buffers.clear();
237        }
238    }
239
240    fn broadcast(&self, data: &mut [f32], src: usize) {
241        let mut state = self.state.lock().unwrap();
242
243        if self.rank == src {
244            // Source rank stores its data
245            state.buffers.insert(0, data.to_vec());
246        }
247
248        // Get broadcast data
249        if let Some(src_data) = state.buffers.get(&0) {
250            for (i, &val) in src_data.iter().enumerate() {
251                if i < data.len() {
252                    data[i] = val;
253                }
254            }
255        }
256    }
257
258    fn all_gather(&self, send_data: &[f32], recv_data: &mut [f32]) {
259        let mut state = self.state.lock().unwrap();
260
261        // Store this rank's data
262        state.buffers.insert(self.rank, send_data.to_vec());
263
264        // Check if all ranks have submitted
265        if state.buffers.len() == self.world_size {
266            // Concatenate all data in rank order
267            let chunk_size = send_data.len();
268            for r in 0..self.world_size {
269                if let Some(data) = state.buffers.get(&r) {
270                    let start = r * chunk_size;
271                    for (i, &val) in data.iter().enumerate() {
272                        if start + i < recv_data.len() {
273                            recv_data[start + i] = val;
274                        }
275                    }
276                }
277            }
278        }
279    }
280
281    fn reduce_scatter(&self, send_data: &[f32], recv_data: &mut [f32], op: ReduceOp) {
282        let mut state = self.state.lock().unwrap();
283
284        // Store this rank's data
285        state.buffers.insert(self.rank, send_data.to_vec());
286
287        // Check if all ranks have submitted
288        if state.buffers.len() == self.world_size {
289            // First reduce all data
290            let all_data: Vec<Vec<f32>> = (0..self.world_size)
291                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
292                .collect();
293
294            let reduced = op.reduce_slices(&all_data);
295
296            // Scatter to each rank
297            let chunk_size = recv_data.len();
298            let start = self.rank * chunk_size;
299            let end = (start + chunk_size).min(reduced.len());
300
301            for (i, &val) in reduced[start..end].iter().enumerate() {
302                if i < recv_data.len() {
303                    recv_data[i] = val;
304                }
305            }
306        }
307    }
308
309    fn gather(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize) {
310        let mut state = self.state.lock().unwrap();
311
312        // Store this rank's data
313        state.buffers.insert(self.rank, send_data.to_vec());
314
315        // Only destination rank collects
316        if self.rank == dst && state.buffers.len() == self.world_size {
317            let chunk_size = send_data.len();
318            for r in 0..self.world_size {
319                if let Some(data) = state.buffers.get(&r) {
320                    let start = r * chunk_size;
321                    for (i, &val) in data.iter().enumerate() {
322                        if start + i < recv_data.len() {
323                            recv_data[start + i] = val;
324                        }
325                    }
326                }
327            }
328        }
329    }
330
331    fn scatter(&self, send_data: &[f32], recv_data: &mut [f32], src: usize) {
332        let state = self.state.lock().unwrap();
333
334        // Only source rank has full data
335        if self.rank == src {
336            // Scatter to all (in mock, we copy our portion)
337            let chunk_size = recv_data.len();
338            let start = self.rank * chunk_size;
339            let end = (start + chunk_size).min(send_data.len());
340
341            for (i, &val) in send_data[start..end].iter().enumerate() {
342                recv_data[i] = val;
343            }
344        }
345        drop(state);
346
347        // Others would receive via message passing in real impl
348    }
349
350    fn reduce(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize, op: ReduceOp) {
351        let mut state = self.state.lock().unwrap();
352
353        // Store this rank's data
354        state.buffers.insert(self.rank, send_data.to_vec());
355
356        // Only destination rank reduces
357        if self.rank == dst && state.buffers.len() == self.world_size {
358            let all_data: Vec<Vec<f32>> = (0..self.world_size)
359                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
360                .collect();
361
362            let reduced = op.reduce_slices(&all_data);
363
364            for (i, &val) in reduced.iter().enumerate() {
365                if i < recv_data.len() {
366                    recv_data[i] = val;
367                }
368            }
369        }
370    }
371
372    fn barrier(&self) {
373        let mut state = self.state.lock().unwrap();
374        state.barrier_count += 1;
375
376        // Reset when all have arrived
377        if state.barrier_count == self.world_size {
378            state.barrier_count = 0;
379        }
380    }
381
382    fn send(&self, data: &[f32], dst: usize, tag: usize) {
383        let mut state = self.state.lock().unwrap();
384        state.messages.insert((self.rank, dst, tag), data.to_vec());
385    }
386
387    fn recv(&self, data: &mut [f32], src: usize, tag: usize) {
388        let mut state = self.state.lock().unwrap();
389        if let Some(msg) = state.messages.remove(&(src, self.rank, tag)) {
390            for (i, &val) in msg.iter().enumerate() {
391                if i < data.len() {
392                    data[i] = val;
393                }
394            }
395        }
396    }
397}
398
399// =============================================================================
400// Tests
401// =============================================================================
402
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reduce_op_sum() {
        assert_eq!(ReduceOp::Sum.apply_f32(1.0, 2.0), 3.0);
    }

    #[test]
    fn test_reduce_op_product() {
        assert_eq!(ReduceOp::Product.apply_f32(2.0, 3.0), 6.0);
    }

    #[test]
    fn test_reduce_op_min() {
        assert_eq!(ReduceOp::Min.apply_f32(2.0, 3.0), 2.0);
    }

    #[test]
    fn test_reduce_op_max() {
        assert_eq!(ReduceOp::Max.apply_f32(2.0, 3.0), 3.0);
    }

    #[test]
    fn test_reduce_slices_sum() {
        let inputs = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        assert_eq!(ReduceOp::Sum.reduce_slices(&inputs), vec![9.0, 12.0]);
    }

    #[test]
    fn test_reduce_slices_average() {
        let inputs = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        assert_eq!(ReduceOp::Average.reduce_slices(&inputs), vec![2.0, 3.0]);
    }

    #[test]
    fn test_mock_backend_single() {
        let backend = MockBackend::single();
        assert_eq!(backend.rank(), 0);
        assert_eq!(backend.world_size(), 1);
        assert_eq!(backend.name(), "mock");
    }

    #[test]
    fn test_mock_backend_world() {
        let world = MockBackend::create_world(4);
        assert_eq!(world.len(), 4);

        for (expected_rank, backend) in world.iter().enumerate() {
            assert_eq!(backend.rank(), expected_rank);
            assert_eq!(backend.world_size(), 4);
        }
    }

    #[test]
    fn test_mock_all_reduce() {
        // With a single rank there is nothing to combine, so the buffer
        // stays unchanged. (A real backend is called concurrently from
        // every rank; the mock is sequential.)
        let backend = MockBackend::single();

        let mut values = vec![1.0, 2.0];
        backend.all_reduce(&mut values, ReduceOp::Sum);

        assert_eq!(values, vec![1.0, 2.0]);
    }

    #[test]
    fn test_mock_broadcast() {
        let world = MockBackend::create_world(2);

        let mut src_buf = vec![1.0, 2.0, 3.0];
        let mut dst_buf = vec![0.0, 0.0, 0.0];

        // Rank 0 publishes; rank 1 picks the data up.
        world[0].broadcast(&mut src_buf, 0);
        world[1].broadcast(&mut dst_buf, 0);

        assert_eq!(src_buf, vec![1.0, 2.0, 3.0]);
        assert_eq!(dst_buf, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mock_send_recv() {
        let world = MockBackend::create_world(2);

        // Rank 0 -> rank 1 with tag 0.
        let payload = vec![1.0, 2.0, 3.0];
        world[0].send(&payload, 1, 0);

        let mut inbox = vec![0.0, 0.0, 0.0];
        world[1].recv(&mut inbox, 0, 0);

        assert_eq!(inbox, payload);
    }

    #[test]
    fn test_mock_barrier() {
        let world = MockBackend::create_world(2);

        // Both ranks reach the barrier; must not deadlock.
        world[0].barrier();
        world[1].barrier();
    }
}
520}