Skip to main content

axonml_distributed/
process_group.rs

//! `ProcessGroup` - Process Group Abstraction
//!
//! Provides a high-level abstraction for managing groups of processes
//! in distributed training.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

9use crate::backend::{Backend, MockBackend, ReduceOp};
10use axonml_tensor::Tensor;
11use std::sync::Arc;
12
// =============================================================================
// ProcessGroup
// =============================================================================

17/// A group of processes that can communicate with each other.
18pub struct ProcessGroup {
19    backend: Arc<dyn Backend>,
20    ranks: Vec<usize>,
21}
22
23impl ProcessGroup {
24    /// Creates a new process group with all ranks.
25    pub fn new(backend: Arc<dyn Backend>) -> Self {
26        let world_size = backend.world_size();
27        Self {
28            backend,
29            ranks: (0..world_size).collect(),
30        }
31    }
32
33    /// Creates a process group with specific ranks.
34    pub fn with_ranks(backend: Arc<dyn Backend>, ranks: Vec<usize>) -> Self {
35        Self { backend, ranks }
36    }
37
38    /// Creates a mock process group for testing.
39    #[must_use] pub fn mock() -> Self {
40        Self::new(Arc::new(MockBackend::single()))
41    }
42
43    /// Returns the backend.
44    #[must_use] pub fn backend(&self) -> &dyn Backend {
45        self.backend.as_ref()
46    }
47
48    /// Returns the rank of this process.
49    #[must_use] pub fn rank(&self) -> usize {
50        self.backend.rank()
51    }
52
53    /// Returns the world size.
54    #[must_use] pub fn world_size(&self) -> usize {
55        self.backend.world_size()
56    }
57
58    /// Returns the number of processes in this group.
59    #[must_use] pub fn size(&self) -> usize {
60        self.ranks.len()
61    }
62
63    /// Returns the ranks in this group.
64    #[must_use] pub fn ranks(&self) -> &[usize] {
65        &self.ranks
66    }
67
68    /// Checks if this process is part of the group.
69    #[must_use] pub fn contains(&self, rank: usize) -> bool {
70        self.ranks.contains(&rank)
71    }
72
73    /// Synchronizes all processes in the group.
74    pub fn barrier(&self) {
75        self.backend.barrier();
76    }
77
78    /// Performs all-reduce on a tensor.
79    pub fn all_reduce_tensor(&self, tensor: &mut Tensor<f32>, op: ReduceOp) {
80        let mut data = tensor.to_vec();
81        self.backend.all_reduce(&mut data, op);
82        *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
83    }
84
85    /// Broadcasts a tensor from a source rank.
86    pub fn broadcast_tensor(&self, tensor: &mut Tensor<f32>, src: usize) {
87        let mut data = tensor.to_vec();
88        self.backend.broadcast(&mut data, src);
89        *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
90    }
91
92    /// Performs all-gather on tensors.
93    #[must_use] pub fn all_gather_tensor(&self, send_tensor: &Tensor<f32>) -> Tensor<f32> {
94        let send_data = send_tensor.to_vec();
95        let mut recv_data = vec![0.0; send_data.len() * self.world_size()];
96        self.backend.all_gather(&send_data, &mut recv_data);
97
98        // Output shape: [world_size, ...original_shape]
99        let mut new_shape = vec![self.world_size()];
100        new_shape.extend(send_tensor.shape());
101        Tensor::from_vec(recv_data, &new_shape).unwrap()
102    }
103
104    /// Performs reduce-scatter on a tensor.
105    #[must_use] pub fn reduce_scatter_tensor(&self, send_tensor: &Tensor<f32>, op: ReduceOp) -> Tensor<f32> {
106        let send_data = send_tensor.to_vec();
107        let chunk_size = send_data.len() / self.world_size();
108        let mut recv_data = vec![0.0; chunk_size];
109        self.backend.reduce_scatter(&send_data, &mut recv_data, op);
110
111        // Output shape: reduced original shape
112        let original_shape = send_tensor.shape();
113        let mut new_shape = original_shape.to_vec();
114        if !new_shape.is_empty() {
115            new_shape[0] /= self.world_size();
116        }
117        Tensor::from_vec(recv_data, &new_shape).unwrap()
118    }
119
120    /// Sends a tensor to a destination rank.
121    pub fn send_tensor(&self, tensor: &mut Tensor<f32>, dst: usize) {
122        let data = tensor.to_vec();
123        self.backend.send(&data, dst, 0);
124    }
125
126    /// Receives a tensor from a source rank.
127    #[must_use] pub fn recv_tensor(&self, src: usize, shape: &[usize]) -> Tensor<f32> {
128        let size: usize = shape.iter().product();
129        let mut data = vec![0.0; size];
130        self.backend.recv(&mut data, src, 0);
131        Tensor::from_vec(data, shape).unwrap()
132    }
133}
134
// =============================================================================
// World
// =============================================================================

139/// Global distributed world.
140pub struct World {
141    default_group: ProcessGroup,
142}
143
144impl World {
145    /// Initializes the distributed world.
146    pub fn init(backend: Arc<dyn Backend>) -> Self {
147        Self {
148            default_group: ProcessGroup::new(backend),
149        }
150    }
151
152    /// Creates a mock world for testing.
153    #[must_use] pub fn mock() -> Self {
154        Self {
155            default_group: ProcessGroup::mock(),
156        }
157    }
158
159    /// Returns the default process group.
160    #[must_use] pub fn default_group(&self) -> &ProcessGroup {
161        &self.default_group
162    }
163
164    /// Returns the rank of this process.
165    #[must_use] pub fn rank(&self) -> usize {
166        self.default_group.rank()
167    }
168
169    /// Returns the world size.
170    #[must_use] pub fn world_size(&self) -> usize {
171        self.default_group.world_size()
172    }
173
174    /// Checks if this is the main process (rank 0).
175    #[must_use] pub fn is_main(&self) -> bool {
176        self.rank() == 0
177    }
178
179    /// Synchronizes all processes.
180    pub fn barrier(&self) {
181        self.default_group.barrier();
182    }
183
184    /// Creates a new process group with specific ranks.
185    #[must_use] pub fn new_group(&self, ranks: Vec<usize>) -> ProcessGroup {
186        ProcessGroup::with_ranks(Arc::clone(&self.default_group.backend), ranks)
187    }
188}
189
190impl Clone for ProcessGroup {
191    fn clone(&self) -> Self {
192        Self {
193            backend: Arc::clone(&self.backend),
194            ranks: self.ranks.clone(),
195        }
196    }
197}
198
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_process_group_mock() {
        let pg = ProcessGroup::mock();
        assert_eq!(pg.rank(), 0);
        assert_eq!(pg.world_size(), 1);
        assert_eq!(pg.size(), 1);
    }

    #[test]
    fn test_process_group_contains() {
        let pg = ProcessGroup::mock();
        assert!(pg.contains(0));
        assert!(!pg.contains(1));
    }

    #[test]
    fn test_world_mock() {
        let world = World::mock();
        assert_eq!(world.rank(), 0);
        assert_eq!(world.world_size(), 1);
        assert!(world.is_main());
    }

    #[test]
    fn test_world_new_group() {
        let world = World::mock();
        let group = world.new_group(vec![0]);
        assert_eq!(group.size(), 1);
        assert_eq!(group.ranks(), &[0]);
        assert!(group.contains(0));
    }

    #[test]
    fn test_process_group_all_reduce_tensor() {
        let backends = MockBackend::create_world(2);
        let pg0 = ProcessGroup::new(Arc::new(backends.into_iter().next().unwrap()));

        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        pg0.all_reduce_tensor(&mut tensor, ReduceOp::Sum);

        // Only rank 0 of the two-rank mock world takes part here; the mock
        // backend is expected to hand its data back unchanged.
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_process_group_broadcast_tensor() {
        let pg = ProcessGroup::mock();

        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        pg.broadcast_tensor(&mut tensor, 0);

        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_process_group_all_gather_tensor() {
        let pg = ProcessGroup::mock();

        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let gathered = pg.all_gather_tensor(&tensor);

        assert_eq!(gathered.shape(), &[1, 2]);
    }

    #[test]
    fn test_process_group_reduce_scatter_tensor_single_rank() {
        let pg = ProcessGroup::mock();

        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let scattered = pg.reduce_scatter_tensor(&tensor, ReduceOp::Sum);

        // With a single rank, this rank's chunk is the whole tensor.
        assert_eq!(scattered.shape(), &[2, 3]);
    }

    #[test]
    fn test_process_group_barrier() {
        let pg = ProcessGroup::mock();
        pg.barrier(); // Should not deadlock
    }

    #[test]
    fn test_world_barrier() {
        let world = World::mock();
        world.barrier(); // Should not deadlock
    }

    #[test]
    fn test_process_group_clone() {
        let pg = ProcessGroup::mock();
        let pg2 = pg.clone();
        assert_eq!(pg.rank(), pg2.rank());
        assert_eq!(pg.world_size(), pg2.world_size());
    }
}