Skip to main content

axonml_distributed/
process_group.rs

//! `ProcessGroup` - Process Group Abstraction
//!
//! Provides a high-level abstraction for managing groups of processes
//! in distributed training.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team
8
9use crate::backend::{Backend, MockBackend, ReduceOp};
10use axonml_tensor::Tensor;
11use std::sync::Arc;
12
// =============================================================================
// ProcessGroup
// =============================================================================
16
/// A group of processes that can communicate with each other.
///
/// Holds a shared handle to the communication backend together with the list
/// of ranks that belong to the group.
pub struct ProcessGroup {
    /// Shared communication backend; cloning the group clones this `Arc`
    /// handle rather than the backend itself.
    backend: Arc<dyn Backend>,
    /// Member ranks, stored in the order they were supplied.
    ranks: Vec<usize>,
}
22
23impl ProcessGroup {
24    /// Creates a new process group with all ranks.
25    pub fn new(backend: Arc<dyn Backend>) -> Self {
26        let world_size = backend.world_size();
27        Self {
28            backend,
29            ranks: (0..world_size).collect(),
30        }
31    }
32
33    /// Creates a process group with specific ranks.
34    pub fn with_ranks(backend: Arc<dyn Backend>, ranks: Vec<usize>) -> Self {
35        Self { backend, ranks }
36    }
37
38    /// Creates a mock process group for testing.
39    #[must_use]
40    pub fn mock() -> Self {
41        Self::new(Arc::new(MockBackend::single()))
42    }
43
44    /// Returns the backend.
45    #[must_use]
46    pub fn backend(&self) -> &dyn Backend {
47        self.backend.as_ref()
48    }
49
50    /// Returns the rank of this process.
51    #[must_use]
52    pub fn rank(&self) -> usize {
53        self.backend.rank()
54    }
55
56    /// Returns the world size.
57    #[must_use]
58    pub fn world_size(&self) -> usize {
59        self.backend.world_size()
60    }
61
62    /// Returns the number of processes in this group.
63    #[must_use]
64    pub fn size(&self) -> usize {
65        self.ranks.len()
66    }
67
68    /// Returns the ranks in this group.
69    #[must_use]
70    pub fn ranks(&self) -> &[usize] {
71        &self.ranks
72    }
73
74    /// Checks if this process is part of the group.
75    #[must_use]
76    pub fn contains(&self, rank: usize) -> bool {
77        self.ranks.contains(&rank)
78    }
79
80    /// Synchronizes all processes in the group.
81    pub fn barrier(&self) {
82        self.backend.barrier();
83    }
84
85    /// Performs all-reduce on a tensor.
86    pub fn all_reduce_tensor(&self, tensor: &mut Tensor<f32>, op: ReduceOp) {
87        let mut data = tensor.to_vec();
88        self.backend.all_reduce(&mut data, op);
89        *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
90    }
91
92    /// Broadcasts a tensor from a source rank.
93    pub fn broadcast_tensor(&self, tensor: &mut Tensor<f32>, src: usize) {
94        let mut data = tensor.to_vec();
95        self.backend.broadcast(&mut data, src);
96        *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
97    }
98
99    /// Performs all-gather on tensors.
100    #[must_use]
101    pub fn all_gather_tensor(&self, send_tensor: &Tensor<f32>) -> Tensor<f32> {
102        let send_data = send_tensor.to_vec();
103        let mut recv_data = vec![0.0; send_data.len() * self.world_size()];
104        self.backend.all_gather(&send_data, &mut recv_data);
105
106        // Output shape: [world_size, ...original_shape]
107        let mut new_shape = vec![self.world_size()];
108        new_shape.extend(send_tensor.shape());
109        Tensor::from_vec(recv_data, &new_shape).unwrap()
110    }
111
112    /// Performs reduce-scatter on a tensor.
113    #[must_use]
114    pub fn reduce_scatter_tensor(&self, send_tensor: &Tensor<f32>, op: ReduceOp) -> Tensor<f32> {
115        let send_data = send_tensor.to_vec();
116        let chunk_size = send_data.len() / self.world_size();
117        let mut recv_data = vec![0.0; chunk_size];
118        self.backend.reduce_scatter(&send_data, &mut recv_data, op);
119
120        // Output shape: reduced original shape
121        let original_shape = send_tensor.shape();
122        let mut new_shape = original_shape.to_vec();
123        if !new_shape.is_empty() {
124            new_shape[0] /= self.world_size();
125        }
126        Tensor::from_vec(recv_data, &new_shape).unwrap()
127    }
128
129    /// Sends a tensor to a destination rank.
130    pub fn send_tensor(&self, tensor: &mut Tensor<f32>, dst: usize) {
131        let data = tensor.to_vec();
132        self.backend.send(&data, dst, 0);
133    }
134
135    /// Receives a tensor from a source rank.
136    #[must_use]
137    pub fn recv_tensor(&self, src: usize, shape: &[usize]) -> Tensor<f32> {
138        let size: usize = shape.iter().product();
139        let mut data = vec![0.0; size];
140        self.backend.recv(&mut data, src, 0);
141        Tensor::from_vec(data, shape).unwrap()
142    }
143}
144
// =============================================================================
// World
// =============================================================================
148
/// Global distributed world.
///
/// Wraps a default [`ProcessGroup`] that spans every rank reported by the
/// backend it was built from.
pub struct World {
    /// Group covering all ranks; its backend handle is also shared with any
    /// group created through `new_group`.
    default_group: ProcessGroup,
}
153
154impl World {
155    /// Initializes the distributed world.
156    pub fn init(backend: Arc<dyn Backend>) -> Self {
157        Self {
158            default_group: ProcessGroup::new(backend),
159        }
160    }
161
162    /// Creates a mock world for testing.
163    #[must_use]
164    pub fn mock() -> Self {
165        Self {
166            default_group: ProcessGroup::mock(),
167        }
168    }
169
170    /// Returns the default process group.
171    #[must_use]
172    pub fn default_group(&self) -> &ProcessGroup {
173        &self.default_group
174    }
175
176    /// Returns the rank of this process.
177    #[must_use]
178    pub fn rank(&self) -> usize {
179        self.default_group.rank()
180    }
181
182    /// Returns the world size.
183    #[must_use]
184    pub fn world_size(&self) -> usize {
185        self.default_group.world_size()
186    }
187
188    /// Checks if this is the main process (rank 0).
189    #[must_use]
190    pub fn is_main(&self) -> bool {
191        self.rank() == 0
192    }
193
194    /// Synchronizes all processes.
195    pub fn barrier(&self) {
196        self.default_group.barrier();
197    }
198
199    /// Creates a new process group with specific ranks.
200    #[must_use]
201    pub fn new_group(&self, ranks: Vec<usize>) -> ProcessGroup {
202        ProcessGroup::with_ranks(Arc::clone(&self.default_group.backend), ranks)
203    }
204}
205
206impl Clone for ProcessGroup {
207    fn clone(&self) -> Self {
208        Self {
209            backend: Arc::clone(&self.backend),
210            ranks: self.ranks.clone(),
211        }
212    }
213}
214
// =============================================================================
// Tests
// =============================================================================
218
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_process_group_mock() {
        let pg = ProcessGroup::mock();
        assert_eq!(pg.rank(), 0);
        assert_eq!(pg.world_size(), 1);
        assert_eq!(pg.size(), 1);
    }

    #[test]
    fn test_process_group_contains() {
        let pg = ProcessGroup::mock();
        assert!(pg.contains(0));
        assert!(!pg.contains(1));
    }

    #[test]
    fn test_process_group_with_ranks() {
        let pg = ProcessGroup::with_ranks(Arc::new(MockBackend::single()), vec![0]);
        assert_eq!(pg.ranks(), &[0]);
        assert_eq!(pg.size(), 1);
        assert!(pg.contains(0));
    }

    #[test]
    fn test_world_mock() {
        let world = World::mock();
        assert_eq!(world.rank(), 0);
        assert_eq!(world.world_size(), 1);
        assert!(world.is_main());
    }

    #[test]
    fn test_world_new_group() {
        let world = World::mock();
        let group = world.new_group(vec![0]);
        assert_eq!(group.size(), 1);
        assert_eq!(group.ranks(), &[0]);
    }

    #[test]
    fn test_process_group_all_reduce_tensor() {
        // Build a two-rank mock world but drive only rank 0; the second
        // backend is dropped immediately.
        let backends = MockBackend::create_world(2);
        let pg0 = ProcessGroup::new(Arc::new(backends.into_iter().next().unwrap()));

        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        pg0.all_reduce_tensor(&mut tensor, ReduceOp::Sum);

        // Only rank 0 takes part, so its values are expected back unchanged.
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_process_group_broadcast_tensor() {
        let pg = ProcessGroup::mock();

        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        pg.broadcast_tensor(&mut tensor, 0);

        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_process_group_all_gather_tensor() {
        let pg = ProcessGroup::mock();

        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let gathered = pg.all_gather_tensor(&tensor);

        assert_eq!(gathered.shape(), &[1, 2]);
    }

    #[test]
    fn test_process_group_barrier() {
        let pg = ProcessGroup::mock();
        pg.barrier(); // Should not deadlock
    }

    #[test]
    fn test_world_barrier() {
        let world = World::mock();
        world.barrier(); // Should not deadlock
    }

    #[test]
    fn test_process_group_clone() {
        let pg = ProcessGroup::mock();
        let pg2 = pg.clone();
        assert_eq!(pg.rank(), pg2.rank());
        assert_eq!(pg.world_size(), pg2.world_size());
        assert_eq!(pg.ranks(), pg2.ranks());
    }
}