//! Backend - Communication Backend Abstractions
//!
//! # File
//! `crates/axonml-distributed/src/backend.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

// =============================================================================
// Reduce Operations
// =============================================================================

/// Reduction operation for collective communication.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    /// Sum all values.
    Sum,
    /// Compute product of all values.
    Product,
    /// Find minimum value.
    Min,
    /// Find maximum value.
    Max,
    /// Compute average of all values.
    Average,
}

impl ReduceOp {
    /// Applies the reduction operation to two f32 values.
    ///
    /// For [`ReduceOp::Average`] this is the midpoint of the pair. Note that
    /// folding more than two values with it does NOT yield the mean — use
    /// [`ReduceOp::reduce_slices`] for multi-way averages.
    #[must_use]
    pub fn apply_f32(&self, a: f32, b: f32) -> f32 {
        match self {
            ReduceOp::Sum => a + b,
            ReduceOp::Product => a * b,
            ReduceOp::Min => a.min(b),
            ReduceOp::Max => a.max(b),
            ReduceOp::Average => f32::midpoint(a, b),
        }
    }

    /// Element-wise reduction of `slices` into a single vector.
    ///
    /// The output length is that of the first slice; shorter subsequent
    /// slices simply contribute fewer elements. Returns an empty vector for
    /// empty input.
    ///
    /// Fix: `Average` is now computed in a single sum-then-divide pass. The
    /// previous implementation first folded all slices pairwise with
    /// `apply_f32` (mathematically wrong for averages of more than two
    /// slices), then discarded that work and recomputed the true mean.
    #[must_use]
    pub fn reduce_slices(&self, slices: &[Vec<f32>]) -> Vec<f32> {
        let Some(first) = slices.first() else {
            return Vec::new();
        };
        let len = first.len();

        if *self == ReduceOp::Average {
            // Sum element-wise, then divide by the number of slices.
            let mut result = vec![0.0f32; len];
            for slice in slices {
                for (acc, &val) in result.iter_mut().zip(slice) {
                    *acc += val;
                }
            }
            let count = slices.len() as f32;
            for val in &mut result {
                *val /= count;
            }
            return result;
        }

        // All other ops are associative, so a pairwise fold is correct.
        let mut result = first.clone();
        for slice in &slices[1..] {
            for (acc, &val) in result.iter_mut().zip(slice) {
                *acc = self.apply_f32(*acc, val);
            }
        }
        result
    }
}

// =============================================================================
// Backend Trait
// =============================================================================

95/// Trait for distributed communication backends.
96pub trait Backend: Send + Sync {
97    /// Returns the name of the backend.
98    fn name(&self) -> &str;
99
100    /// Returns the rank of this process.
101    fn rank(&self) -> usize;
102
103    /// Returns the total world size.
104    fn world_size(&self) -> usize;
105
106    /// Performs all-reduce operation.
107    fn all_reduce(&self, data: &mut [f32], op: ReduceOp);
108
109    /// Broadcasts data from a source rank.
110    fn broadcast(&self, data: &mut [f32], src: usize);
111
112    /// Performs all-gather operation.
113    fn all_gather(&self, send_data: &[f32], recv_data: &mut [f32]);
114
115    /// Performs reduce-scatter operation.
116    fn reduce_scatter(&self, send_data: &[f32], recv_data: &mut [f32], op: ReduceOp);
117
118    /// Performs gather operation.
119    fn gather(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize);
120
121    /// Performs scatter operation.
122    fn scatter(&self, send_data: &[f32], recv_data: &mut [f32], src: usize);
123
124    /// Performs reduce operation (result only on dst rank).
125    fn reduce(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize, op: ReduceOp);
126
127    /// Synchronizes all processes.
128    fn barrier(&self);
129
130    /// Sends data to a specific rank.
131    fn send(&self, data: &[f32], dst: usize, tag: usize);
132
133    /// Receives data from a specific rank.
134    fn recv(&self, data: &mut [f32], src: usize, tag: usize);
135}

// =============================================================================
// SharedState for Mock Backend
// =============================================================================

/// Process-shared storage that simulates the "network" for the mock backend.
#[derive(Debug)]
struct SharedState {
    /// Per-rank staging buffers used by the collective operations.
    buffers: HashMap<usize, Vec<f32>>,
    /// Number of ranks currently checked in at the barrier.
    barrier_count: usize,
    /// Pending point-to-point messages, keyed by `(src, dst, tag)`.
    messages: HashMap<(usize, usize, usize), Vec<f32>>,
}

// =============================================================================
// Mock Backend
// =============================================================================

156/// A mock backend for testing distributed operations in a single process.
157/// Simulates distributed communication without actual network operations.
158pub struct MockBackend {
159    rank: usize,
160    world_size: usize,
161    state: Arc<Mutex<SharedState>>,
162}
163
164impl MockBackend {
165    /// Creates a collection of mock backends for testing.
166    #[must_use]
167    pub fn create_world(world_size: usize) -> Vec<Self> {
168        let state = Arc::new(Mutex::new(SharedState {
169            buffers: HashMap::new(),
170            barrier_count: 0,
171            messages: HashMap::new(),
172        }));
173
174        (0..world_size)
175            .map(|rank| MockBackend {
176                rank,
177                world_size,
178                state: Arc::clone(&state),
179            })
180            .collect()
181    }
182
183    /// Creates a single mock backend (rank 0, world size 1).
184    #[must_use]
185    pub fn single() -> Self {
186        MockBackend::create_world(1).pop().unwrap()
187    }
188}
189
190impl Backend for MockBackend {
191    fn name(&self) -> &'static str {
192        "mock"
193    }
194
195    fn rank(&self) -> usize {
196        self.rank
197    }
198
199    fn world_size(&self) -> usize {
200        self.world_size
201    }
202
203    fn all_reduce(&self, data: &mut [f32], op: ReduceOp) {
204        let mut state = self.state.lock().unwrap();
205
206        // Store this rank's data
207        state.buffers.insert(self.rank, data.to_vec());
208
209        // Check if all ranks have submitted
210        if state.buffers.len() == self.world_size {
211            // Perform reduction
212            let all_data: Vec<Vec<f32>> = (0..self.world_size)
213                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
214                .collect();
215
216            let reduced = op.reduce_slices(&all_data);
217
218            // Update all buffers with result
219            for r in 0..self.world_size {
220                state.buffers.insert(r, reduced.clone());
221            }
222        }
223
224        // Get result for this rank
225        if let Some(result) = state.buffers.get(&self.rank) {
226            for (i, &val) in result.iter().enumerate() {
227                if i < data.len() {
228                    data[i] = val;
229                }
230            }
231        }
232
233        // Clear buffers if this is the last rank to read
234        if state.buffers.len() == self.world_size {
235            state.buffers.clear();
236        }
237    }
238
239    fn broadcast(&self, data: &mut [f32], src: usize) {
240        let mut state = self.state.lock().unwrap();
241
242        if self.rank == src {
243            // Source rank stores its data
244            state.buffers.insert(0, data.to_vec());
245        }
246
247        // Get broadcast data
248        if let Some(src_data) = state.buffers.get(&0) {
249            for (i, &val) in src_data.iter().enumerate() {
250                if i < data.len() {
251                    data[i] = val;
252                }
253            }
254        }
255    }
256
257    fn all_gather(&self, send_data: &[f32], recv_data: &mut [f32]) {
258        let mut state = self.state.lock().unwrap();
259
260        // Store this rank's data
261        state.buffers.insert(self.rank, send_data.to_vec());
262
263        // Check if all ranks have submitted
264        if state.buffers.len() == self.world_size {
265            // Concatenate all data in rank order
266            let chunk_size = send_data.len();
267            for r in 0..self.world_size {
268                if let Some(data) = state.buffers.get(&r) {
269                    let start = r * chunk_size;
270                    for (i, &val) in data.iter().enumerate() {
271                        if start + i < recv_data.len() {
272                            recv_data[start + i] = val;
273                        }
274                    }
275                }
276            }
277        }
278    }
279
280    fn reduce_scatter(&self, send_data: &[f32], recv_data: &mut [f32], op: ReduceOp) {
281        let mut state = self.state.lock().unwrap();
282
283        // Store this rank's data
284        state.buffers.insert(self.rank, send_data.to_vec());
285
286        // Check if all ranks have submitted
287        if state.buffers.len() == self.world_size {
288            // First reduce all data
289            let all_data: Vec<Vec<f32>> = (0..self.world_size)
290                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
291                .collect();
292
293            let reduced = op.reduce_slices(&all_data);
294
295            // Scatter to each rank
296            let chunk_size = recv_data.len();
297            let start = self.rank * chunk_size;
298            let end = (start + chunk_size).min(reduced.len());
299
300            for (i, &val) in reduced[start..end].iter().enumerate() {
301                if i < recv_data.len() {
302                    recv_data[i] = val;
303                }
304            }
305        }
306    }
307
308    fn gather(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize) {
309        let mut state = self.state.lock().unwrap();
310
311        // Store this rank's data
312        state.buffers.insert(self.rank, send_data.to_vec());
313
314        // Only destination rank collects
315        if self.rank == dst && state.buffers.len() == self.world_size {
316            let chunk_size = send_data.len();
317            for r in 0..self.world_size {
318                if let Some(data) = state.buffers.get(&r) {
319                    let start = r * chunk_size;
320                    for (i, &val) in data.iter().enumerate() {
321                        if start + i < recv_data.len() {
322                            recv_data[start + i] = val;
323                        }
324                    }
325                }
326            }
327        }
328    }
329
330    fn scatter(&self, send_data: &[f32], recv_data: &mut [f32], src: usize) {
331        let state = self.state.lock().unwrap();
332
333        // Only source rank has full data
334        if self.rank == src {
335            // Scatter to all (in mock, we copy our portion)
336            let chunk_size = recv_data.len();
337            let start = self.rank * chunk_size;
338            let end = (start + chunk_size).min(send_data.len());
339
340            for (i, &val) in send_data[start..end].iter().enumerate() {
341                recv_data[i] = val;
342            }
343        }
344        drop(state);
345
346        // Others would receive via message passing in real impl
347    }
348
349    fn reduce(&self, send_data: &[f32], recv_data: &mut [f32], dst: usize, op: ReduceOp) {
350        let mut state = self.state.lock().unwrap();
351
352        // Store this rank's data
353        state.buffers.insert(self.rank, send_data.to_vec());
354
355        // Only destination rank reduces
356        if self.rank == dst && state.buffers.len() == self.world_size {
357            let all_data: Vec<Vec<f32>> = (0..self.world_size)
358                .map(|r| state.buffers.get(&r).cloned().unwrap_or_default())
359                .collect();
360
361            let reduced = op.reduce_slices(&all_data);
362
363            for (i, &val) in reduced.iter().enumerate() {
364                if i < recv_data.len() {
365                    recv_data[i] = val;
366                }
367            }
368        }
369    }
370
371    fn barrier(&self) {
372        let mut state = self.state.lock().unwrap();
373        state.barrier_count += 1;
374
375        // Reset when all have arrived
376        if state.barrier_count == self.world_size {
377            state.barrier_count = 0;
378        }
379    }
380
381    fn send(&self, data: &[f32], dst: usize, tag: usize) {
382        let mut state = self.state.lock().unwrap();
383        state.messages.insert((self.rank, dst, tag), data.to_vec());
384    }
385
386    fn recv(&self, data: &mut [f32], src: usize, tag: usize) {
387        let mut state = self.state.lock().unwrap();
388        if let Some(msg) = state.messages.remove(&(src, self.rank, tag)) {
389            for (i, &val) in msg.iter().enumerate() {
390                if i < data.len() {
391                    data[i] = val;
392                }
393            }
394        }
395    }
396}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reduce_op_sum() {
        assert_eq!(ReduceOp::Sum.apply_f32(1.0, 2.0), 3.0);
    }

    #[test]
    fn test_reduce_op_product() {
        assert_eq!(ReduceOp::Product.apply_f32(2.0, 3.0), 6.0);
    }

    #[test]
    fn test_reduce_op_min() {
        assert_eq!(ReduceOp::Min.apply_f32(2.0, 3.0), 2.0);
    }

    #[test]
    fn test_reduce_op_max() {
        assert_eq!(ReduceOp::Max.apply_f32(2.0, 3.0), 3.0);
    }

    #[test]
    fn test_reduce_slices_sum() {
        let slices = [vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        assert_eq!(ReduceOp::Sum.reduce_slices(&slices), vec![9.0, 12.0]);
    }

    #[test]
    fn test_reduce_slices_average() {
        let slices = [vec![1.0, 2.0], vec![3.0, 4.0]];
        assert_eq!(ReduceOp::Average.reduce_slices(&slices), vec![2.0, 3.0]);
    }

    #[test]
    fn test_mock_backend_single() {
        let backend = MockBackend::single();
        assert_eq!(backend.rank(), 0);
        assert_eq!(backend.world_size(), 1);
        assert_eq!(backend.name(), "mock");
    }

    #[test]
    fn test_mock_backend_world() {
        let backends = MockBackend::create_world(4);
        assert_eq!(backends.len(), 4);
        for (expected_rank, backend) in backends.iter().enumerate() {
            assert_eq!(backend.rank(), expected_rank);
            assert_eq!(backend.world_size(), 4);
        }
    }

    #[test]
    fn test_mock_all_reduce() {
        // With a single rank, the reduction of one buffer is that buffer.
        let backend = MockBackend::single();
        let mut data = vec![1.0, 2.0];
        backend.all_reduce(&mut data, ReduceOp::Sum);
        assert_eq!(data, vec![1.0, 2.0]);
    }

    #[test]
    fn test_mock_broadcast() {
        let backends = MockBackend::create_world(2);
        let mut data0 = vec![1.0, 2.0, 3.0];
        let mut data1 = vec![0.0; 3];

        // Rank 0 publishes; rank 1 picks the values up.
        backends[0].broadcast(&mut data0, 0);
        backends[1].broadcast(&mut data1, 0);

        assert_eq!(data0, vec![1.0, 2.0, 3.0]);
        assert_eq!(data1, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mock_send_recv() {
        let backends = MockBackend::create_world(2);

        // Rank 0 -> rank 1 with tag 0.
        backends[0].send(&[1.0, 2.0, 3.0], 1, 0);

        let mut received = vec![0.0; 3];
        backends[1].recv(&mut received, 0, 0);
        assert_eq!(received, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mock_barrier() {
        // Both ranks check in; the mock barrier must not deadlock.
        let backends = MockBackend::create_world(2);
        backends[0].barrier();
        backends[1].barrier();
    }
}