// axonml_distributed/process_group.rs
//! `ProcessGroup` - Process Group Abstraction
//!
//! # File
//! `crates/axonml-distributed/src/process_group.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use crate::backend::{Backend, MockBackend, ReduceOp};
use axonml_tensor::Tensor;
use std::sync::Arc;

// =============================================================================
// ProcessGroup
// =============================================================================

/// A group of processes that can communicate with each other.
///
/// Pairs a communication [`Backend`] with the subset of ranks that are
/// members of this group. The backend is shared via `Arc`, so groups are
/// cheap to clone and several groups can reuse a single backend.
pub struct ProcessGroup {
    // Shared communication backend (mock or real transport).
    backend: Arc<dyn Backend>,
    // Ranks that belong to this group; `new` fills in 0..world_size.
    ranks: Vec<usize>,
}

32impl ProcessGroup {
33    /// Creates a new process group with all ranks.
34    pub fn new(backend: Arc<dyn Backend>) -> Self {
35        let world_size = backend.world_size();
36        Self {
37            backend,
38            ranks: (0..world_size).collect(),
39        }
40    }
41
42    /// Creates a process group with specific ranks.
43    pub fn with_ranks(backend: Arc<dyn Backend>, ranks: Vec<usize>) -> Self {
44        Self { backend, ranks }
45    }
46
47    /// Creates a mock process group for testing.
48    #[must_use]
49    pub fn mock() -> Self {
50        Self::new(Arc::new(MockBackend::single()))
51    }
52
53    /// Returns the backend.
54    #[must_use]
55    pub fn backend(&self) -> &dyn Backend {
56        self.backend.as_ref()
57    }
58
59    /// Returns the rank of this process.
60    #[must_use]
61    pub fn rank(&self) -> usize {
62        self.backend.rank()
63    }
64
65    /// Returns the world size.
66    #[must_use]
67    pub fn world_size(&self) -> usize {
68        self.backend.world_size()
69    }
70
71    /// Returns the number of processes in this group.
72    #[must_use]
73    pub fn size(&self) -> usize {
74        self.ranks.len()
75    }
76
77    /// Returns the ranks in this group.
78    #[must_use]
79    pub fn ranks(&self) -> &[usize] {
80        &self.ranks
81    }
82
83    /// Checks if this process is part of the group.
84    #[must_use]
85    pub fn contains(&self, rank: usize) -> bool {
86        self.ranks.contains(&rank)
87    }
88
89    /// Synchronizes all processes in the group.
90    pub fn barrier(&self) {
91        self.backend.barrier();
92    }
93
94    /// Performs all-reduce on a tensor.
95    pub fn all_reduce_tensor(&self, tensor: &mut Tensor<f32>, op: ReduceOp) {
96        let mut data = tensor.to_vec();
97        self.backend.all_reduce(&mut data, op);
98        *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
99    }
100
101    /// Broadcasts a tensor from a source rank.
102    pub fn broadcast_tensor(&self, tensor: &mut Tensor<f32>, src: usize) {
103        let mut data = tensor.to_vec();
104        self.backend.broadcast(&mut data, src);
105        *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
106    }
107
108    /// Performs all-gather on tensors.
109    #[must_use]
110    pub fn all_gather_tensor(&self, send_tensor: &Tensor<f32>) -> Tensor<f32> {
111        let send_data = send_tensor.to_vec();
112        let mut recv_data = vec![0.0; send_data.len() * self.world_size()];
113        self.backend.all_gather(&send_data, &mut recv_data);
114
115        // Output shape: [world_size, ...original_shape]
116        let mut new_shape = vec![self.world_size()];
117        new_shape.extend(send_tensor.shape());
118        Tensor::from_vec(recv_data, &new_shape).unwrap()
119    }
120
121    /// Performs reduce-scatter on a tensor.
122    #[must_use]
123    pub fn reduce_scatter_tensor(&self, send_tensor: &Tensor<f32>, op: ReduceOp) -> Tensor<f32> {
124        let send_data = send_tensor.to_vec();
125        let chunk_size = send_data.len() / self.world_size();
126        let mut recv_data = vec![0.0; chunk_size];
127        self.backend.reduce_scatter(&send_data, &mut recv_data, op);
128
129        // Output shape: reduced original shape
130        let original_shape = send_tensor.shape();
131        let mut new_shape = original_shape.to_vec();
132        if !new_shape.is_empty() {
133            new_shape[0] /= self.world_size();
134        }
135        Tensor::from_vec(recv_data, &new_shape).unwrap()
136    }
137
138    /// Sends a tensor to a destination rank.
139    pub fn send_tensor(&self, tensor: &mut Tensor<f32>, dst: usize) {
140        let data = tensor.to_vec();
141        self.backend.send(&data, dst, 0);
142    }
143
144    /// Receives a tensor from a source rank.
145    #[must_use]
146    pub fn recv_tensor(&self, src: usize, shape: &[usize]) -> Tensor<f32> {
147        let size: usize = shape.iter().product();
148        let mut data = vec![0.0; size];
149        self.backend.recv(&mut data, src, 0);
150        Tensor::from_vec(data, shape).unwrap()
151    }
152}
153
// =============================================================================
// World
// =============================================================================

/// Global distributed world.
///
/// Owns the default [`ProcessGroup`] spanning every rank; smaller groups
/// can be derived from it via [`World::new_group`].
pub struct World {
    // Group containing all ranks of the backend's world.
    default_group: ProcessGroup,
}

163impl World {
164    /// Initializes the distributed world.
165    pub fn init(backend: Arc<dyn Backend>) -> Self {
166        Self {
167            default_group: ProcessGroup::new(backend),
168        }
169    }
170
171    /// Creates a mock world for testing.
172    #[must_use]
173    pub fn mock() -> Self {
174        Self {
175            default_group: ProcessGroup::mock(),
176        }
177    }
178
179    /// Returns the default process group.
180    #[must_use]
181    pub fn default_group(&self) -> &ProcessGroup {
182        &self.default_group
183    }
184
185    /// Returns the rank of this process.
186    #[must_use]
187    pub fn rank(&self) -> usize {
188        self.default_group.rank()
189    }
190
191    /// Returns the world size.
192    #[must_use]
193    pub fn world_size(&self) -> usize {
194        self.default_group.world_size()
195    }
196
197    /// Checks if this is the main process (rank 0).
198    #[must_use]
199    pub fn is_main(&self) -> bool {
200        self.rank() == 0
201    }
202
203    /// Synchronizes all processes.
204    pub fn barrier(&self) {
205        self.default_group.barrier();
206    }
207
208    /// Creates a new process group with specific ranks.
209    #[must_use]
210    pub fn new_group(&self, ranks: Vec<usize>) -> ProcessGroup {
211        ProcessGroup::with_ranks(Arc::clone(&self.default_group.backend), ranks)
212    }
213}
214
215impl Clone for ProcessGroup {
216    fn clone(&self) -> Self {
217        Self {
218            backend: Arc::clone(&self.backend),
219            ranks: self.ranks.clone(),
220        }
221    }
222}
223
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_process_group_mock() {
        let group = ProcessGroup::mock();
        assert_eq!(group.world_size(), 1);
        assert_eq!(group.size(), 1);
        assert_eq!(group.rank(), 0);
    }

    #[test]
    fn test_process_group_contains() {
        let group = ProcessGroup::mock();
        assert!(group.contains(0));
        assert!(!group.contains(1));
    }

    #[test]
    fn test_world_mock() {
        let world = World::mock();
        assert!(world.is_main());
        assert_eq!(world.world_size(), 1);
        assert_eq!(world.rank(), 0);
    }

    #[test]
    fn test_world_new_group() {
        let world = World::mock();
        let subgroup = world.new_group(vec![0]);
        assert_eq!(subgroup.size(), 1);
    }

    #[test]
    fn test_process_group_all_reduce_tensor() {
        let mut backends = MockBackend::create_world(2).into_iter();
        let group = ProcessGroup::new(Arc::new(backends.next().unwrap()));

        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        group.all_reduce_tensor(&mut tensor, ReduceOp::Sum);

        // Only this rank participates, so the values come back unchanged.
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_process_group_broadcast_tensor() {
        let group = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        group.broadcast_tensor(&mut tensor, 0);

        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_process_group_all_gather_tensor() {
        let group = ProcessGroup::mock();
        let local = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let stacked = group.all_gather_tensor(&local);

        // Gathering across one rank prepends a leading axis of size 1.
        assert_eq!(stacked.shape(), &[1, 2]);
    }

    #[test]
    fn test_process_group_barrier() {
        // A single-rank barrier must return immediately, not deadlock.
        ProcessGroup::mock().barrier();
    }

    #[test]
    fn test_world_barrier() {
        // A single-rank barrier must return immediately, not deadlock.
        World::mock().barrier();
    }

    #[test]
    fn test_process_group_clone() {
        let original = ProcessGroup::mock();
        let copy = original.clone();
        assert_eq!(original.rank(), copy.rank());
        assert_eq!(original.world_size(), copy.world_size());
    }
}