// crates/axonml-distributed/src/process_group.rs
//! `ProcessGroup` - Process Group Abstraction
//!
//! # File
//! `crates/axonml-distributed/src/process_group.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use crate::backend::{Backend, MockBackend, ReduceOp};
use axonml_tensor::Tensor;
use std::sync::Arc;

// =============================================================================
// ProcessGroup
// =============================================================================

/// A group of processes that can communicate with each other.
///
/// Wraps a communication [`Backend`] together with the set of global ranks
/// that belong to this group. Collective operations delegate directly to
/// the backend; `ranks` is used only for membership and size queries.
pub struct ProcessGroup {
    // Shared communication backend (e.g. `MockBackend` in tests).
    backend: Arc<dyn Backend>,
    // Global ranks that are members of this group.
    ranks: Vec<usize>,
}
31impl ProcessGroup {
32    /// Creates a new process group with all ranks.
33    pub fn new(backend: Arc<dyn Backend>) -> Self {
34        let world_size = backend.world_size();
35        Self {
36            backend,
37            ranks: (0..world_size).collect(),
38        }
39    }
40
41    /// Creates a process group with specific ranks.
42    pub fn with_ranks(backend: Arc<dyn Backend>, ranks: Vec<usize>) -> Self {
43        Self { backend, ranks }
44    }
45
46    /// Creates a mock process group for testing.
47    #[must_use]
48    pub fn mock() -> Self {
49        Self::new(Arc::new(MockBackend::single()))
50    }
51
52    /// Returns the backend.
53    #[must_use]
54    pub fn backend(&self) -> &dyn Backend {
55        self.backend.as_ref()
56    }
57
58    /// Returns the rank of this process.
59    #[must_use]
60    pub fn rank(&self) -> usize {
61        self.backend.rank()
62    }
63
64    /// Returns the world size.
65    #[must_use]
66    pub fn world_size(&self) -> usize {
67        self.backend.world_size()
68    }
69
70    /// Returns the number of processes in this group.
71    #[must_use]
72    pub fn size(&self) -> usize {
73        self.ranks.len()
74    }
75
76    /// Returns the ranks in this group.
77    #[must_use]
78    pub fn ranks(&self) -> &[usize] {
79        &self.ranks
80    }
81
82    /// Checks if this process is part of the group.
83    #[must_use]
84    pub fn contains(&self, rank: usize) -> bool {
85        self.ranks.contains(&rank)
86    }
87
88    /// Synchronizes all processes in the group.
89    pub fn barrier(&self) {
90        self.backend.barrier();
91    }
92
93    /// Performs all-reduce on a tensor.
94    pub fn all_reduce_tensor(&self, tensor: &mut Tensor<f32>, op: ReduceOp) {
95        let mut data = tensor.to_vec();
96        self.backend.all_reduce(&mut data, op);
97        *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
98    }
99
100    /// Broadcasts a tensor from a source rank.
101    pub fn broadcast_tensor(&self, tensor: &mut Tensor<f32>, src: usize) {
102        let mut data = tensor.to_vec();
103        self.backend.broadcast(&mut data, src);
104        *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
105    }
106
107    /// Performs all-gather on tensors.
108    #[must_use]
109    pub fn all_gather_tensor(&self, send_tensor: &Tensor<f32>) -> Tensor<f32> {
110        let send_data = send_tensor.to_vec();
111        let mut recv_data = vec![0.0; send_data.len() * self.world_size()];
112        self.backend.all_gather(&send_data, &mut recv_data);
113
114        // Output shape: [world_size, ...original_shape]
115        let mut new_shape = vec![self.world_size()];
116        new_shape.extend(send_tensor.shape());
117        Tensor::from_vec(recv_data, &new_shape).unwrap()
118    }
119
120    /// Performs reduce-scatter on a tensor.
121    #[must_use]
122    pub fn reduce_scatter_tensor(&self, send_tensor: &Tensor<f32>, op: ReduceOp) -> Tensor<f32> {
123        let send_data = send_tensor.to_vec();
124        let chunk_size = send_data.len() / self.world_size();
125        let mut recv_data = vec![0.0; chunk_size];
126        self.backend.reduce_scatter(&send_data, &mut recv_data, op);
127
128        // Output shape: reduced original shape
129        let original_shape = send_tensor.shape();
130        let mut new_shape = original_shape.to_vec();
131        if !new_shape.is_empty() {
132            new_shape[0] /= self.world_size();
133        }
134        Tensor::from_vec(recv_data, &new_shape).unwrap()
135    }
136
137    /// Sends a tensor to a destination rank.
138    pub fn send_tensor(&self, tensor: &mut Tensor<f32>, dst: usize) {
139        let data = tensor.to_vec();
140        self.backend.send(&data, dst, 0);
141    }
142
143    /// Receives a tensor from a source rank.
144    #[must_use]
145    pub fn recv_tensor(&self, src: usize, shape: &[usize]) -> Tensor<f32> {
146        let size: usize = shape.iter().product();
147        let mut data = vec![0.0; size];
148        self.backend.recv(&mut data, src, 0);
149        Tensor::from_vec(data, shape).unwrap()
150    }
151}

// =============================================================================
// World
// =============================================================================

/// Global distributed world.
///
/// Owns the default [`ProcessGroup`] spanning all ranks and provides
/// convenience accessors for rank/world-size queries; sub-groups are
/// derived via `new_group`.
pub struct World {
    // Group containing every rank in the backend's world.
    default_group: ProcessGroup,
}
162impl World {
163    /// Initializes the distributed world.
164    pub fn init(backend: Arc<dyn Backend>) -> Self {
165        Self {
166            default_group: ProcessGroup::new(backend),
167        }
168    }
169
170    /// Creates a mock world for testing.
171    #[must_use]
172    pub fn mock() -> Self {
173        Self {
174            default_group: ProcessGroup::mock(),
175        }
176    }
177
178    /// Returns the default process group.
179    #[must_use]
180    pub fn default_group(&self) -> &ProcessGroup {
181        &self.default_group
182    }
183
184    /// Returns the rank of this process.
185    #[must_use]
186    pub fn rank(&self) -> usize {
187        self.default_group.rank()
188    }
189
190    /// Returns the world size.
191    #[must_use]
192    pub fn world_size(&self) -> usize {
193        self.default_group.world_size()
194    }
195
196    /// Checks if this is the main process (rank 0).
197    #[must_use]
198    pub fn is_main(&self) -> bool {
199        self.rank() == 0
200    }
201
202    /// Synchronizes all processes.
203    pub fn barrier(&self) {
204        self.default_group.barrier();
205    }
206
207    /// Creates a new process group with specific ranks.
208    #[must_use]
209    pub fn new_group(&self, ranks: Vec<usize>) -> ProcessGroup {
210        ProcessGroup::with_ranks(Arc::clone(&self.default_group.backend), ranks)
211    }
212}
// Manual impl rather than `#[derive(Clone)]` (the derive would have to sit on
// the struct definition). Cloning is cheap: the backend is shared via `Arc`
// (refcount bump only) and just the rank list is copied.
impl Clone for ProcessGroup {
    fn clone(&self) -> Self {
        Self {
            backend: Arc::clone(&self.backend),
            ranks: self.ranks.clone(),
        }
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_process_group_mock() {
        // A mock group is a single-rank world with rank 0.
        let group = ProcessGroup::mock();
        assert_eq!(group.rank(), 0);
        assert_eq!(group.world_size(), 1);
        assert_eq!(group.size(), 1);
    }

    #[test]
    fn test_process_group_contains() {
        let group = ProcessGroup::mock();
        assert!(group.contains(0));
        assert!(!group.contains(1));
    }

    #[test]
    fn test_world_mock() {
        let w = World::mock();
        assert!(w.is_main());
        assert_eq!(w.rank(), 0);
        assert_eq!(w.world_size(), 1);
    }

    #[test]
    fn test_world_new_group() {
        let w = World::mock();
        let sub = w.new_group(vec![0]);
        assert_eq!(sub.size(), 1);
    }

    #[test]
    fn test_process_group_all_reduce_tensor() {
        let mut backends = MockBackend::create_world(2).into_iter();
        let rank0 = backends.next().unwrap();
        let group = ProcessGroup::new(Arc::new(rank0));

        let mut t = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        group.all_reduce_tensor(&mut t, ReduceOp::Sum);

        // Only one participant is driven, so the values stay unchanged.
        assert_eq!(t.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_process_group_broadcast_tensor() {
        let group = ProcessGroup::mock();

        let mut t = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        group.broadcast_tensor(&mut t, 0);

        assert_eq!(t.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_process_group_all_gather_tensor() {
        let group = ProcessGroup::mock();

        let t = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let stacked = group.all_gather_tensor(&t);

        // Gathered shape is [world_size, ...input_shape].
        assert_eq!(stacked.shape(), &[1, 2]);
    }

    #[test]
    fn test_process_group_barrier() {
        // A single-rank barrier must return immediately, not deadlock.
        ProcessGroup::mock().barrier();
    }

    #[test]
    fn test_world_barrier() {
        World::mock().barrier();
    }

    #[test]
    fn test_process_group_clone() {
        let original = ProcessGroup::mock();
        let copy = original.clone();
        assert_eq!(original.rank(), copy.rank());
        assert_eq!(original.world_size(), copy.world_size());
    }
}