axonml_distributed/
process_group.rs1use crate::backend::{Backend, MockBackend, ReduceOp};
19use axonml_tensor::Tensor;
20use std::sync::Arc;
21
/// A group of ranks that communicate through a shared [`Backend`].
///
/// The default group spans every rank in the world; subgroups created via
/// [`World::new_group`] share the same backend but track a subset of ranks.
pub struct ProcessGroup {
    // Communication backend; shared (Arc) so groups can be cloned cheaply.
    backend: Arc<dyn Backend>,
    // Global rank ids that are members of this group.
    ranks: Vec<usize>,
}
31
32impl ProcessGroup {
33 pub fn new(backend: Arc<dyn Backend>) -> Self {
35 let world_size = backend.world_size();
36 Self {
37 backend,
38 ranks: (0..world_size).collect(),
39 }
40 }
41
42 pub fn with_ranks(backend: Arc<dyn Backend>, ranks: Vec<usize>) -> Self {
44 Self { backend, ranks }
45 }
46
47 #[must_use]
49 pub fn mock() -> Self {
50 Self::new(Arc::new(MockBackend::single()))
51 }
52
53 #[must_use]
55 pub fn backend(&self) -> &dyn Backend {
56 self.backend.as_ref()
57 }
58
59 #[must_use]
61 pub fn rank(&self) -> usize {
62 self.backend.rank()
63 }
64
65 #[must_use]
67 pub fn world_size(&self) -> usize {
68 self.backend.world_size()
69 }
70
71 #[must_use]
73 pub fn size(&self) -> usize {
74 self.ranks.len()
75 }
76
77 #[must_use]
79 pub fn ranks(&self) -> &[usize] {
80 &self.ranks
81 }
82
83 #[must_use]
85 pub fn contains(&self, rank: usize) -> bool {
86 self.ranks.contains(&rank)
87 }
88
89 pub fn barrier(&self) {
91 self.backend.barrier();
92 }
93
94 pub fn all_reduce_tensor(&self, tensor: &mut Tensor<f32>, op: ReduceOp) {
96 let mut data = tensor.to_vec();
97 self.backend.all_reduce(&mut data, op);
98 *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
99 }
100
101 pub fn broadcast_tensor(&self, tensor: &mut Tensor<f32>, src: usize) {
103 let mut data = tensor.to_vec();
104 self.backend.broadcast(&mut data, src);
105 *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
106 }
107
108 #[must_use]
110 pub fn all_gather_tensor(&self, send_tensor: &Tensor<f32>) -> Tensor<f32> {
111 let send_data = send_tensor.to_vec();
112 let mut recv_data = vec![0.0; send_data.len() * self.world_size()];
113 self.backend.all_gather(&send_data, &mut recv_data);
114
115 let mut new_shape = vec![self.world_size()];
117 new_shape.extend(send_tensor.shape());
118 Tensor::from_vec(recv_data, &new_shape).unwrap()
119 }
120
121 #[must_use]
123 pub fn reduce_scatter_tensor(&self, send_tensor: &Tensor<f32>, op: ReduceOp) -> Tensor<f32> {
124 let send_data = send_tensor.to_vec();
125 let chunk_size = send_data.len() / self.world_size();
126 let mut recv_data = vec![0.0; chunk_size];
127 self.backend.reduce_scatter(&send_data, &mut recv_data, op);
128
129 let original_shape = send_tensor.shape();
131 let mut new_shape = original_shape.to_vec();
132 if !new_shape.is_empty() {
133 new_shape[0] /= self.world_size();
134 }
135 Tensor::from_vec(recv_data, &new_shape).unwrap()
136 }
137
138 pub fn send_tensor(&self, tensor: &mut Tensor<f32>, dst: usize) {
140 let data = tensor.to_vec();
141 self.backend.send(&data, dst, 0);
142 }
143
144 #[must_use]
146 pub fn recv_tensor(&self, src: usize, shape: &[usize]) -> Tensor<f32> {
147 let size: usize = shape.iter().product();
148 let mut data = vec![0.0; size];
149 self.backend.recv(&mut data, src, 0);
150 Tensor::from_vec(data, shape).unwrap()
151 }
152}
153
/// Entry point for distributed execution: owns the default process group
/// that spans every rank, and serves as a factory for subgroups.
pub struct World {
    // Group covering all ranks; subgroups are derived from its backend.
    default_group: ProcessGroup,
}
162
163impl World {
164 pub fn init(backend: Arc<dyn Backend>) -> Self {
166 Self {
167 default_group: ProcessGroup::new(backend),
168 }
169 }
170
171 #[must_use]
173 pub fn mock() -> Self {
174 Self {
175 default_group: ProcessGroup::mock(),
176 }
177 }
178
179 #[must_use]
181 pub fn default_group(&self) -> &ProcessGroup {
182 &self.default_group
183 }
184
185 #[must_use]
187 pub fn rank(&self) -> usize {
188 self.default_group.rank()
189 }
190
191 #[must_use]
193 pub fn world_size(&self) -> usize {
194 self.default_group.world_size()
195 }
196
197 #[must_use]
199 pub fn is_main(&self) -> bool {
200 self.rank() == 0
201 }
202
203 pub fn barrier(&self) {
205 self.default_group.barrier();
206 }
207
208 #[must_use]
210 pub fn new_group(&self, ranks: Vec<usize>) -> ProcessGroup {
211 ProcessGroup::with_ranks(Arc::clone(&self.default_group.backend), ranks)
212 }
213}
214
215impl Clone for ProcessGroup {
216 fn clone(&self) -> Self {
217 Self {
218 backend: Arc::clone(&self.backend),
219 ranks: self.ranks.clone(),
220 }
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_process_group_mock() {
        let group = ProcessGroup::mock();
        assert_eq!(group.rank(), 0);
        assert_eq!(group.world_size(), 1);
        assert_eq!(group.size(), 1);
    }

    #[test]
    fn test_process_group_contains() {
        let group = ProcessGroup::mock();
        assert!(group.contains(0));
        assert!(!group.contains(1));
    }

    #[test]
    fn test_world_mock() {
        let world = World::mock();
        assert_eq!(world.rank(), 0);
        assert_eq!(world.world_size(), 1);
        assert!(world.is_main());
    }

    #[test]
    fn test_world_new_group() {
        let world = World::mock();
        let subgroup = world.new_group(vec![0]);
        assert_eq!(subgroup.size(), 1);
    }

    #[test]
    fn test_process_group_all_reduce_tensor() {
        // Rank 0 of a two-process mock world; the peer never participates,
        // so the values on rank 0 are expected to come back unchanged.
        let mut backends = MockBackend::create_world(2).into_iter();
        let rank0 = ProcessGroup::new(Arc::new(backends.next().unwrap()));

        let mut t = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        rank0.all_reduce_tensor(&mut t, ReduceOp::Sum);

        assert_eq!(t.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_process_group_broadcast_tensor() {
        let group = ProcessGroup::mock();

        let mut t = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        group.broadcast_tensor(&mut t, 0);

        assert_eq!(t.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_process_group_all_gather_tensor() {
        let group = ProcessGroup::mock();

        let t = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let gathered = group.all_gather_tensor(&t);

        // Gather stacks along a new leading axis of size world_size (1).
        assert_eq!(gathered.shape(), &[1, 2]);
    }

    #[test]
    fn test_process_group_barrier() {
        // Smoke test: a single-rank barrier must not block.
        ProcessGroup::mock().barrier();
    }

    #[test]
    fn test_world_barrier() {
        // Smoke test: a single-rank world barrier must not block.
        World::mock().barrier();
    }

    #[test]
    fn test_process_group_clone() {
        let original = ProcessGroup::mock();
        let copy = original.clone();
        assert_eq!(original.rank(), copy.rank());
        assert_eq!(original.world_size(), copy.world_size());
    }
}