axonml_distributed/
process_group.rs

//! `ProcessGroup` - Process Group Abstraction
//!
//! Provides `ProcessGroup` and `World`, high-level abstractions for managing
//! groups of processes in distributed training.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team
8
9use crate::backend::{Backend, MockBackend, ReduceOp};
10use axonml_tensor::Tensor;
11use std::sync::Arc;
12
13// =============================================================================
14// ProcessGroup
15// =============================================================================
16
17/// A group of processes that can communicate with each other.
18pub struct ProcessGroup {
19    backend: Arc<dyn Backend>,
20    ranks: Vec<usize>,
21}
22
23impl ProcessGroup {
24    /// Creates a new process group with all ranks.
25    pub fn new(backend: Arc<dyn Backend>) -> Self {
26        let world_size = backend.world_size();
27        Self {
28            backend,
29            ranks: (0..world_size).collect(),
30        }
31    }
32
33    /// Creates a process group with specific ranks.
34    pub fn with_ranks(backend: Arc<dyn Backend>, ranks: Vec<usize>) -> Self {
35        Self { backend, ranks }
36    }
37
38    /// Creates a mock process group for testing.
39    #[must_use] pub fn mock() -> Self {
40        Self::new(Arc::new(MockBackend::single()))
41    }
42
43    /// Returns the backend.
44    #[must_use] pub fn backend(&self) -> &dyn Backend {
45        self.backend.as_ref()
46    }
47
48    /// Returns the rank of this process.
49    #[must_use] pub fn rank(&self) -> usize {
50        self.backend.rank()
51    }
52
53    /// Returns the world size.
54    #[must_use] pub fn world_size(&self) -> usize {
55        self.backend.world_size()
56    }
57
58    /// Returns the number of processes in this group.
59    #[must_use] pub fn size(&self) -> usize {
60        self.ranks.len()
61    }
62
63    /// Returns the ranks in this group.
64    #[must_use] pub fn ranks(&self) -> &[usize] {
65        &self.ranks
66    }
67
68    /// Checks if this process is part of the group.
69    #[must_use] pub fn contains(&self, rank: usize) -> bool {
70        self.ranks.contains(&rank)
71    }
72
73    /// Synchronizes all processes in the group.
74    pub fn barrier(&self) {
75        self.backend.barrier();
76    }
77
78    /// Performs all-reduce on a tensor.
79    pub fn all_reduce_tensor(&self, tensor: &mut Tensor<f32>, op: ReduceOp) {
80        let mut data = tensor.to_vec();
81        self.backend.all_reduce(&mut data, op);
82        *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
83    }
84
85    /// Broadcasts a tensor from a source rank.
86    pub fn broadcast_tensor(&self, tensor: &mut Tensor<f32>, src: usize) {
87        let mut data = tensor.to_vec();
88        self.backend.broadcast(&mut data, src);
89        *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
90    }
91
92    /// Performs all-gather on tensors.
93    #[must_use] pub fn all_gather_tensor(&self, send_tensor: &Tensor<f32>) -> Tensor<f32> {
94        let send_data = send_tensor.to_vec();
95        let mut recv_data = vec![0.0; send_data.len() * self.world_size()];
96        self.backend.all_gather(&send_data, &mut recv_data);
97
98        // Output shape: [world_size, ...original_shape]
99        let mut new_shape = vec![self.world_size()];
100        new_shape.extend(send_tensor.shape());
101        Tensor::from_vec(recv_data, &new_shape).unwrap()
102    }
103
104    /// Performs reduce-scatter on a tensor.
105    #[must_use] pub fn reduce_scatter_tensor(&self, send_tensor: &Tensor<f32>, op: ReduceOp) -> Tensor<f32> {
106        let send_data = send_tensor.to_vec();
107        let chunk_size = send_data.len() / self.world_size();
108        let mut recv_data = vec![0.0; chunk_size];
109        self.backend.reduce_scatter(&send_data, &mut recv_data, op);
110
111        // Output shape: reduced original shape
112        let original_shape = send_tensor.shape();
113        let mut new_shape = original_shape.to_vec();
114        if !new_shape.is_empty() {
115            new_shape[0] /= self.world_size();
116        }
117        Tensor::from_vec(recv_data, &new_shape).unwrap()
118    }
119}
120
121// =============================================================================
122// World
123// =============================================================================
124
125/// Global distributed world.
126pub struct World {
127    default_group: ProcessGroup,
128}
129
130impl World {
131    /// Initializes the distributed world.
132    pub fn init(backend: Arc<dyn Backend>) -> Self {
133        Self {
134            default_group: ProcessGroup::new(backend),
135        }
136    }
137
138    /// Creates a mock world for testing.
139    #[must_use] pub fn mock() -> Self {
140        Self {
141            default_group: ProcessGroup::mock(),
142        }
143    }
144
145    /// Returns the default process group.
146    #[must_use] pub fn default_group(&self) -> &ProcessGroup {
147        &self.default_group
148    }
149
150    /// Returns the rank of this process.
151    #[must_use] pub fn rank(&self) -> usize {
152        self.default_group.rank()
153    }
154
155    /// Returns the world size.
156    #[must_use] pub fn world_size(&self) -> usize {
157        self.default_group.world_size()
158    }
159
160    /// Checks if this is the main process (rank 0).
161    #[must_use] pub fn is_main(&self) -> bool {
162        self.rank() == 0
163    }
164
165    /// Synchronizes all processes.
166    pub fn barrier(&self) {
167        self.default_group.barrier();
168    }
169
170    /// Creates a new process group with specific ranks.
171    #[must_use] pub fn new_group(&self, ranks: Vec<usize>) -> ProcessGroup {
172        ProcessGroup::with_ranks(Arc::clone(&self.default_group.backend), ranks)
173    }
174}
175
176impl Clone for ProcessGroup {
177    fn clone(&self) -> Self {
178        Self {
179            backend: Arc::clone(&self.backend),
180            ranks: self.ranks.clone(),
181        }
182    }
183}
184
185// =============================================================================
186// Tests
187// =============================================================================
188
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_process_group_mock() {
        let pg = ProcessGroup::mock();
        assert_eq!(pg.rank(), 0);
        assert_eq!(pg.world_size(), 1);
        assert_eq!(pg.size(), 1);
    }

    #[test]
    fn test_process_group_contains() {
        let pg = ProcessGroup::mock();
        assert!(pg.contains(0));
        assert!(!pg.contains(1));
    }

    #[test]
    fn test_process_group_with_ranks() {
        let pg = ProcessGroup::with_ranks(Arc::new(MockBackend::single()), vec![0]);
        assert_eq!(pg.ranks(), &[0]);
        assert_eq!(pg.size(), 1);
    }

    #[test]
    fn test_world_mock() {
        let world = World::mock();
        assert_eq!(world.rank(), 0);
        assert_eq!(world.world_size(), 1);
        assert!(world.is_main());
    }

    #[test]
    fn test_world_new_group() {
        let world = World::mock();
        let group = world.new_group(vec![0]);
        assert_eq!(group.size(), 1);
        assert!(group.contains(0));
    }

    #[test]
    fn test_process_group_all_reduce_tensor() {
        let backends = MockBackend::create_world(2);
        let pg0 = ProcessGroup::new(Arc::new(backends.into_iter().next().unwrap()));

        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        pg0.all_reduce_tensor(&mut tensor, ReduceOp::Sum);

        // Only rank 0 of the two-rank mock world takes part here; its values
        // are expected to come back unchanged.
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_process_group_broadcast_tensor() {
        let pg = ProcessGroup::mock();

        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        pg.broadcast_tensor(&mut tensor, 0);

        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_process_group_all_gather_tensor() {
        let pg = ProcessGroup::mock();

        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let gathered = pg.all_gather_tensor(&tensor);

        assert_eq!(gathered.shape(), &[1, 2]);
    }

    #[test]
    fn test_process_group_reduce_scatter_tensor_shape() {
        let pg = ProcessGroup::mock();

        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let scattered = pg.reduce_scatter_tensor(&tensor, ReduceOp::Sum);

        // With a world size of 1 the chunk is the whole tensor, so the shape is kept.
        assert_eq!(scattered.shape(), &[3, 2]);
    }

    #[test]
    fn test_process_group_barrier() {
        let pg = ProcessGroup::mock();
        pg.barrier(); // Should not deadlock
    }

    #[test]
    fn test_world_barrier() {
        let world = World::mock();
        world.barrier(); // Should not deadlock
    }

    #[test]
    fn test_process_group_clone() {
        let pg = ProcessGroup::mock();
        let pg2 = pg.clone();
        assert_eq!(pg.rank(), pg2.rank());
        assert_eq!(pg.world_size(), pg2.world_size());
        assert_eq!(pg.ranks(), pg2.ranks());
    }
}