axonml_distributed/
process_group.rs1use crate::backend::{Backend, MockBackend, ReduceOp};
10use axonml_tensor::Tensor;
11use std::sync::Arc;
12
/// A group of ranks that participate in collective communication.
///
/// Wraps a shared [`Backend`] together with the list of global ranks that
/// belong to this group. Cloning is cheap: the backend handle is
/// reference-counted.
pub struct ProcessGroup {
    // Shared communication backend (mock or real transport).
    backend: Arc<dyn Backend>,
    // Global ranks that are members of this group.
    ranks: Vec<usize>,
}
22
23impl ProcessGroup {
24 pub fn new(backend: Arc<dyn Backend>) -> Self {
26 let world_size = backend.world_size();
27 Self {
28 backend,
29 ranks: (0..world_size).collect(),
30 }
31 }
32
33 pub fn with_ranks(backend: Arc<dyn Backend>, ranks: Vec<usize>) -> Self {
35 Self { backend, ranks }
36 }
37
38 #[must_use]
40 pub fn mock() -> Self {
41 Self::new(Arc::new(MockBackend::single()))
42 }
43
44 #[must_use]
46 pub fn backend(&self) -> &dyn Backend {
47 self.backend.as_ref()
48 }
49
50 #[must_use]
52 pub fn rank(&self) -> usize {
53 self.backend.rank()
54 }
55
56 #[must_use]
58 pub fn world_size(&self) -> usize {
59 self.backend.world_size()
60 }
61
62 #[must_use]
64 pub fn size(&self) -> usize {
65 self.ranks.len()
66 }
67
68 #[must_use]
70 pub fn ranks(&self) -> &[usize] {
71 &self.ranks
72 }
73
74 #[must_use]
76 pub fn contains(&self, rank: usize) -> bool {
77 self.ranks.contains(&rank)
78 }
79
80 pub fn barrier(&self) {
82 self.backend.barrier();
83 }
84
85 pub fn all_reduce_tensor(&self, tensor: &mut Tensor<f32>, op: ReduceOp) {
87 let mut data = tensor.to_vec();
88 self.backend.all_reduce(&mut data, op);
89 *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
90 }
91
92 pub fn broadcast_tensor(&self, tensor: &mut Tensor<f32>, src: usize) {
94 let mut data = tensor.to_vec();
95 self.backend.broadcast(&mut data, src);
96 *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
97 }
98
99 #[must_use]
101 pub fn all_gather_tensor(&self, send_tensor: &Tensor<f32>) -> Tensor<f32> {
102 let send_data = send_tensor.to_vec();
103 let mut recv_data = vec![0.0; send_data.len() * self.world_size()];
104 self.backend.all_gather(&send_data, &mut recv_data);
105
106 let mut new_shape = vec![self.world_size()];
108 new_shape.extend(send_tensor.shape());
109 Tensor::from_vec(recv_data, &new_shape).unwrap()
110 }
111
112 #[must_use]
114 pub fn reduce_scatter_tensor(&self, send_tensor: &Tensor<f32>, op: ReduceOp) -> Tensor<f32> {
115 let send_data = send_tensor.to_vec();
116 let chunk_size = send_data.len() / self.world_size();
117 let mut recv_data = vec![0.0; chunk_size];
118 self.backend.reduce_scatter(&send_data, &mut recv_data, op);
119
120 let original_shape = send_tensor.shape();
122 let mut new_shape = original_shape.to_vec();
123 if !new_shape.is_empty() {
124 new_shape[0] /= self.world_size();
125 }
126 Tensor::from_vec(recv_data, &new_shape).unwrap()
127 }
128
129 pub fn send_tensor(&self, tensor: &mut Tensor<f32>, dst: usize) {
131 let data = tensor.to_vec();
132 self.backend.send(&data, dst, 0);
133 }
134
135 #[must_use]
137 pub fn recv_tensor(&self, src: usize, shape: &[usize]) -> Tensor<f32> {
138 let size: usize = shape.iter().product();
139 let mut data = vec![0.0; size];
140 self.backend.recv(&mut data, src, 0);
141 Tensor::from_vec(data, shape).unwrap()
142 }
143}
144
/// Top-level handle for the distributed environment.
///
/// Owns the default [`ProcessGroup`] spanning every rank; subgroups can be
/// created from it via `new_group`.
pub struct World {
    // Group containing all ranks in the world.
    default_group: ProcessGroup,
}
153
154impl World {
155 pub fn init(backend: Arc<dyn Backend>) -> Self {
157 Self {
158 default_group: ProcessGroup::new(backend),
159 }
160 }
161
162 #[must_use]
164 pub fn mock() -> Self {
165 Self {
166 default_group: ProcessGroup::mock(),
167 }
168 }
169
170 #[must_use]
172 pub fn default_group(&self) -> &ProcessGroup {
173 &self.default_group
174 }
175
176 #[must_use]
178 pub fn rank(&self) -> usize {
179 self.default_group.rank()
180 }
181
182 #[must_use]
184 pub fn world_size(&self) -> usize {
185 self.default_group.world_size()
186 }
187
188 #[must_use]
190 pub fn is_main(&self) -> bool {
191 self.rank() == 0
192 }
193
194 pub fn barrier(&self) {
196 self.default_group.barrier();
197 }
198
199 #[must_use]
201 pub fn new_group(&self, ranks: Vec<usize>) -> ProcessGroup {
202 ProcessGroup::with_ranks(Arc::clone(&self.default_group.backend), ranks)
203 }
204}
205
206impl Clone for ProcessGroup {
207 fn clone(&self) -> Self {
208 Self {
209 backend: Arc::clone(&self.backend),
210 ranks: self.ranks.clone(),
211 }
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_process_group_mock() {
        // A mock group is a single-rank world.
        let group = ProcessGroup::mock();
        assert_eq!(group.rank(), 0);
        assert_eq!(group.world_size(), 1);
        assert_eq!(group.size(), 1);
    }

    #[test]
    fn test_process_group_contains() {
        let group = ProcessGroup::mock();
        assert!(group.contains(0));
        assert!(!group.contains(1));
    }

    #[test]
    fn test_world_mock() {
        let world = World::mock();
        assert_eq!(world.rank(), 0);
        assert_eq!(world.world_size(), 1);
        assert!(world.is_main());
    }

    #[test]
    fn test_world_new_group() {
        let world = World::mock();
        let subgroup = world.new_group(vec![0]);
        assert_eq!(subgroup.size(), 1);
    }

    #[test]
    fn test_process_group_all_reduce_tensor() {
        // Only rank 0 of a two-rank mock world participates here; the mock
        // backend reduces locally, leaving the values unchanged.
        let backends = MockBackend::create_world(2);
        let backend0 = backends.into_iter().next().unwrap();
        let group = ProcessGroup::new(Arc::new(backend0));

        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        group.all_reduce_tensor(&mut tensor, ReduceOp::Sum);

        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_process_group_broadcast_tensor() {
        let group = ProcessGroup::mock();

        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        group.broadcast_tensor(&mut tensor, 0);

        // Broadcasting from self is the identity.
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_process_group_all_gather_tensor() {
        let group = ProcessGroup::mock();

        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let gathered = group.all_gather_tensor(&tensor);

        // One rank => a single stacked slice along the new leading axis.
        assert_eq!(gathered.shape(), &[1, 2]);
    }

    #[test]
    fn test_process_group_barrier() {
        // Must not deadlock with a single rank.
        ProcessGroup::mock().barrier();
    }

    #[test]
    fn test_world_barrier() {
        // Must not deadlock with a single rank.
        World::mock().barrier();
    }

    #[test]
    fn test_process_group_clone() {
        let original = ProcessGroup::mock();
        let copy = original.clone();
        assert_eq!(original.rank(), copy.rank());
        assert_eq!(original.world_size(), copy.world_size());
    }
}