//! Process groups and world abstraction for `axonml_distributed`.

use crate::backend::{Backend, MockBackend, ReduceOp};
use axonml_tensor::Tensor;
use std::sync::Arc;
/// A group of ranks that participate in collective communication.
///
/// Wraps a shared [`Backend`] together with the list of ranks that belong
/// to this group. The default group (see [`ProcessGroup::new`]) contains
/// every rank in `0..world_size`.
pub struct ProcessGroup {
    // Shared communication backend; `Arc` so groups can be cloned cheaply.
    backend: Arc<dyn Backend>,
    // Ranks that are members of this group.
    ranks: Vec<usize>,
}
30
31impl ProcessGroup {
32 pub fn new(backend: Arc<dyn Backend>) -> Self {
34 let world_size = backend.world_size();
35 Self {
36 backend,
37 ranks: (0..world_size).collect(),
38 }
39 }
40
41 pub fn with_ranks(backend: Arc<dyn Backend>, ranks: Vec<usize>) -> Self {
43 Self { backend, ranks }
44 }
45
46 #[must_use]
48 pub fn mock() -> Self {
49 Self::new(Arc::new(MockBackend::single()))
50 }
51
52 #[must_use]
54 pub fn backend(&self) -> &dyn Backend {
55 self.backend.as_ref()
56 }
57
58 #[must_use]
60 pub fn rank(&self) -> usize {
61 self.backend.rank()
62 }
63
64 #[must_use]
66 pub fn world_size(&self) -> usize {
67 self.backend.world_size()
68 }
69
70 #[must_use]
72 pub fn size(&self) -> usize {
73 self.ranks.len()
74 }
75
76 #[must_use]
78 pub fn ranks(&self) -> &[usize] {
79 &self.ranks
80 }
81
82 #[must_use]
84 pub fn contains(&self, rank: usize) -> bool {
85 self.ranks.contains(&rank)
86 }
87
88 pub fn barrier(&self) {
90 self.backend.barrier();
91 }
92
93 pub fn all_reduce_tensor(&self, tensor: &mut Tensor<f32>, op: ReduceOp) {
95 let mut data = tensor.to_vec();
96 self.backend.all_reduce(&mut data, op);
97 *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
98 }
99
100 pub fn broadcast_tensor(&self, tensor: &mut Tensor<f32>, src: usize) {
102 let mut data = tensor.to_vec();
103 self.backend.broadcast(&mut data, src);
104 *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
105 }
106
107 #[must_use]
109 pub fn all_gather_tensor(&self, send_tensor: &Tensor<f32>) -> Tensor<f32> {
110 let send_data = send_tensor.to_vec();
111 let mut recv_data = vec![0.0; send_data.len() * self.world_size()];
112 self.backend.all_gather(&send_data, &mut recv_data);
113
114 let mut new_shape = vec![self.world_size()];
116 new_shape.extend(send_tensor.shape());
117 Tensor::from_vec(recv_data, &new_shape).unwrap()
118 }
119
120 #[must_use]
122 pub fn reduce_scatter_tensor(&self, send_tensor: &Tensor<f32>, op: ReduceOp) -> Tensor<f32> {
123 let send_data = send_tensor.to_vec();
124 let chunk_size = send_data.len() / self.world_size();
125 let mut recv_data = vec![0.0; chunk_size];
126 self.backend.reduce_scatter(&send_data, &mut recv_data, op);
127
128 let original_shape = send_tensor.shape();
130 let mut new_shape = original_shape.to_vec();
131 if !new_shape.is_empty() {
132 new_shape[0] /= self.world_size();
133 }
134 Tensor::from_vec(recv_data, &new_shape).unwrap()
135 }
136
137 pub fn send_tensor(&self, tensor: &mut Tensor<f32>, dst: usize) {
139 let data = tensor.to_vec();
140 self.backend.send(&data, dst, 0);
141 }
142
143 #[must_use]
145 pub fn recv_tensor(&self, src: usize, shape: &[usize]) -> Tensor<f32> {
146 let size: usize = shape.iter().product();
147 let mut data = vec![0.0; size];
148 self.backend.recv(&mut data, src, 0);
149 Tensor::from_vec(data, shape).unwrap()
150 }
151}
152
/// Top-level handle for the distributed environment.
///
/// Owns the default [`ProcessGroup`] spanning every rank; subgroups can be
/// derived from it via [`World::new_group`].
pub struct World {
    // Group containing all ranks in the world.
    default_group: ProcessGroup,
}
161
162impl World {
163 pub fn init(backend: Arc<dyn Backend>) -> Self {
165 Self {
166 default_group: ProcessGroup::new(backend),
167 }
168 }
169
170 #[must_use]
172 pub fn mock() -> Self {
173 Self {
174 default_group: ProcessGroup::mock(),
175 }
176 }
177
178 #[must_use]
180 pub fn default_group(&self) -> &ProcessGroup {
181 &self.default_group
182 }
183
184 #[must_use]
186 pub fn rank(&self) -> usize {
187 self.default_group.rank()
188 }
189
190 #[must_use]
192 pub fn world_size(&self) -> usize {
193 self.default_group.world_size()
194 }
195
196 #[must_use]
198 pub fn is_main(&self) -> bool {
199 self.rank() == 0
200 }
201
202 pub fn barrier(&self) {
204 self.default_group.barrier();
205 }
206
207 #[must_use]
209 pub fn new_group(&self, ranks: Vec<usize>) -> ProcessGroup {
210 ProcessGroup::with_ranks(Arc::clone(&self.default_group.backend), ranks)
211 }
212}
213
214impl Clone for ProcessGroup {
215 fn clone(&self) -> Self {
216 Self {
217 backend: Arc::clone(&self.backend),
218 ranks: self.ranks.clone(),
219 }
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_process_group_mock() {
        // A mock group is a single-process world: rank 0, size 1.
        let group = ProcessGroup::mock();
        assert_eq!(group.rank(), 0);
        assert_eq!(group.world_size(), 1);
        assert_eq!(group.size(), 1);
    }

    #[test]
    fn test_process_group_contains() {
        let group = ProcessGroup::mock();
        assert!(group.contains(0));
        assert!(!group.contains(1));
    }

    #[test]
    fn test_world_mock() {
        let w = World::mock();
        assert_eq!(w.rank(), 0);
        assert_eq!(w.world_size(), 1);
        assert!(w.is_main());
    }

    #[test]
    fn test_world_new_group() {
        let w = World::mock();
        let sub = w.new_group(vec![0]);
        assert_eq!(sub.size(), 1);
    }

    #[test]
    fn test_process_group_all_reduce_tensor() {
        // Build a two-rank mock world but drive only rank 0's backend.
        let backends = MockBackend::create_world(2);
        let first = backends.into_iter().next().unwrap();
        let group = ProcessGroup::new(Arc::new(first));

        let mut t = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        group.all_reduce_tensor(&mut t, ReduceOp::Sum);

        assert_eq!(t.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_process_group_broadcast_tensor() {
        let group = ProcessGroup::mock();

        // Broadcasting from self in a single-process world is a no-op.
        let mut t = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        group.broadcast_tensor(&mut t, 0);

        assert_eq!(t.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_process_group_all_gather_tensor() {
        let group = ProcessGroup::mock();

        let t = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let gathered = group.all_gather_tensor(&t);

        // A rank dimension of 1 is prepended to the original shape.
        assert_eq!(gathered.shape(), &[1, 2]);
    }

    #[test]
    fn test_process_group_barrier() {
        // Must not deadlock for a single-process group.
        let group = ProcessGroup::mock();
        group.barrier();
    }

    #[test]
    fn test_world_barrier() {
        let w = World::mock();
        w.barrier();
    }

    #[test]
    fn test_process_group_clone() {
        let original = ProcessGroup::mock();
        let copy = original.clone();
        assert_eq!(original.rank(), copy.rank());
        assert_eq!(original.world_size(), copy.world_size());
    }
}