axonml_distributed/process_group.rs

use crate::backend::{Backend, MockBackend, ReduceOp};
use axonml_tensor::Tensor;
use std::sync::Arc;

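/// A set of ranks that take part in collective communication, sharing a
/// single [`Backend`] for the actual data movement.
///
/// Note that the collective methods delegate directly to the backend, so
/// they currently run over the backend's full world rather than being
/// restricted to this group's `ranks`.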
pub struct ProcessGroup {
    backend: Arc<dyn Backend>,
    ranks: Vec<usize>,
}

impl ProcessGroup {
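    /// Creates a group spanning every rank in the backend's world.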
    pub fn new(backend: Arc<dyn Backend>) -> Self {
        let world_size = backend.world_size();
        Self {
            backend,
            ranks: (0..world_size).collect(),
        }
    }

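    /// Creates a group restricted to an explicit subset of ranks.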
    pub fn with_ranks(backend: Arc<dyn Backend>, ranks: Vec<usize>) -> Self {
        Self { backend, ranks }
    }

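    /// Creates a single-process group backed by [`MockBackend`], useful in tests.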
    #[must_use]
    pub fn mock() -> Self {
        Self::new(Arc::new(MockBackend::single()))
    }

    /// Returns the underlying communication backend.
    #[must_use]
    pub fn backend(&self) -> &dyn Backend {
        self.backend.as_ref()
    }

    /// Returns this process's global rank.
    #[must_use]
    pub fn rank(&self) -> usize {
        self.backend.rank()
    }

    /// Returns the total number of ranks in the backend's world.
    #[must_use]
    pub fn world_size(&self) -> usize {
        self.backend.world_size()
    }

    /// Returns the number of ranks in this group, which may be smaller
    /// than the world size.
    #[must_use]
    pub fn size(&self) -> usize {
        self.ranks.len()
    }

    /// Returns the ranks that belong to this group.
    #[must_use]
    pub fn ranks(&self) -> &[usize] {
        &self.ranks
    }

    /// Returns `true` if `rank` is a member of this group.
    #[must_use]
    pub fn contains(&self, rank: usize) -> bool {
        self.ranks.contains(&rank)
    }

    /// Blocks until every rank has reached the barrier.
    pub fn barrier(&self) {
        self.backend.barrier();
    }

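    /// Reduces `tensor` element-wise across all ranks with `op`, leaving
    /// every rank holding the same result.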
    pub fn all_reduce_tensor(&self, tensor: &mut Tensor<f32>, op: ReduceOp) {
        let mut data = tensor.to_vec();
        self.backend.all_reduce(&mut data, op);
        *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
    }

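    /// Overwrites `tensor` on every rank with the values held by rank `src`.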
    pub fn broadcast_tensor(&self, tensor: &mut Tensor<f32>, src: usize) {
        let mut data = tensor.to_vec();
        self.backend.broadcast(&mut data, src);
        *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
    }

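    /// Gathers `send_tensor` from every rank and stacks the results along a
    /// new leading dimension of length `world_size`.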
    #[must_use]
    pub fn all_gather_tensor(&self, send_tensor: &Tensor<f32>) -> Tensor<f32> {
        let send_data = send_tensor.to_vec();
        let mut recv_data = vec![0.0; send_data.len() * self.world_size()];
        self.backend.all_gather(&send_data, &mut recv_data);

        // Prepend the world size as a new leading dimension.
        let mut new_shape = vec![self.world_size()];
        new_shape.extend(send_tensor.shape());
        Tensor::from_vec(recv_data, &new_shape).unwrap()
    }

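    /// Reduces `send_tensor` across all ranks with `op` and scatters the
    /// result, returning the chunk owned by this rank.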
    #[must_use]
    pub fn reduce_scatter_tensor(&self, send_tensor: &Tensor<f32>, op: ReduceOp) -> Tensor<f32> {
        let send_data = send_tensor.to_vec();
        let chunk_size = send_data.len() / self.world_size();
        let mut recv_data = vec![0.0; chunk_size];
        self.backend.reduce_scatter(&send_data, &mut recv_data, op);

        // The output keeps the input's shape except along the leading
        // dimension, which is assumed to be divisible by the world size.
        let original_shape = send_tensor.shape();
        let mut new_shape = original_shape.to_vec();
        if !new_shape.is_empty() {
            new_shape[0] /= self.world_size();
        }
        Tensor::from_vec(recv_data, &new_shape).unwrap()
    }

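    /// Sends `tensor` to rank `dst` (tag 0).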
    pub fn send_tensor(&self, tensor: &Tensor<f32>, dst: usize) {
        let data = tensor.to_vec();
        self.backend.send(&data, dst, 0);
    }

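    /// Receives a tensor of the given `shape` from rank `src` (tag 0).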
    #[must_use]
    pub fn recv_tensor(&self, src: usize, shape: &[usize]) -> Tensor<f32> {
        let size: usize = shape.iter().product();
        let mut data = vec![0.0; size];
        self.backend.recv(&mut data, src, 0);
        Tensor::from_vec(data, shape).unwrap()
    }
}

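/// The global distributed context, owning the default [`ProcessGroup`]
/// that spans every rank.
///
/// A minimal sketch of the intended flow (module paths are assumed here,
/// hence the ignored doc test):
///
/// ```ignore
/// use std::sync::Arc;
/// use axonml_distributed::backend::MockBackend;
/// use axonml_distributed::process_group::World;
///
/// let world = World::init(Arc::new(MockBackend::single()));
/// if world.is_main() {
///     println!("running on {} rank(s)", world.world_size());
/// }
/// world.barrier();
/// ```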
pub struct World {
    default_group: ProcessGroup,
}

impl World {
    /// Initializes the world from a backend, building the default group
    /// over all ranks.
    pub fn init(backend: Arc<dyn Backend>) -> Self {
        Self {
            default_group: ProcessGroup::new(backend),
        }
    }

    /// Creates a single-process world for tests.
    #[must_use]
    pub fn mock() -> Self {
        Self {
            default_group: ProcessGroup::mock(),
        }
    }

    /// Returns the group containing every rank.
    #[must_use]
    pub fn default_group(&self) -> &ProcessGroup {
        &self.default_group
    }

    /// Returns this process's global rank.
    #[must_use]
    pub fn rank(&self) -> usize {
        self.default_group.rank()
    }

    /// Returns the total number of ranks.
    #[must_use]
    pub fn world_size(&self) -> usize {
        self.default_group.world_size()
    }

    /// Returns `true` on rank 0, conventionally the main process.
    #[must_use]
    pub fn is_main(&self) -> bool {
        self.rank() == 0
    }

    /// Blocks until every rank has reached the barrier.
    pub fn barrier(&self) {
        self.default_group.barrier();
    }

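    /// Creates a new group over a subset of ranks, sharing this world's backend.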
    #[must_use]
    pub fn new_group(&self, ranks: Vec<usize>) -> ProcessGroup {
        ProcessGroup::with_ranks(Arc::clone(&self.default_group.backend), ranks)
    }
}

// Cloning a group is cheap: the backend is shared through the `Arc`.
impl Clone for ProcessGroup {
    fn clone(&self) -> Self {
        Self {
            backend: Arc::clone(&self.backend),
            ranks: self.ranks.clone(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_process_group_mock() {
        let pg = ProcessGroup::mock();
        assert_eq!(pg.rank(), 0);
        assert_eq!(pg.world_size(), 1);
        assert_eq!(pg.size(), 1);
    }

    #[test]
    fn test_process_group_contains() {
        let pg = ProcessGroup::mock();
        assert!(pg.contains(0));
        assert!(!pg.contains(1));
    }

    #[test]
    fn test_world_mock() {
        let world = World::mock();
        assert_eq!(world.rank(), 0);
        assert_eq!(world.world_size(), 1);
        assert!(world.is_main());
    }

    #[test]
    fn test_world_new_group() {
        let world = World::mock();
        let group = world.new_group(vec![0]);
        assert_eq!(group.size(), 1);
    }

    #[test]
    fn test_process_group_all_reduce_tensor() {
        let backends = MockBackend::create_world(2);
        let pg0 = ProcessGroup::new(Arc::new(backends.into_iter().next().unwrap()));

        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        pg0.all_reduce_tensor(&mut tensor, ReduceOp::Sum);

        // Only rank 0 participates here, so the sum is just this rank's contribution.
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_process_group_broadcast_tensor() {
        let pg = ProcessGroup::mock();

        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        pg.broadcast_tensor(&mut tensor, 0);

        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_process_group_all_gather_tensor() {
        let pg = ProcessGroup::mock();

        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let gathered = pg.all_gather_tensor(&tensor);

        assert_eq!(gathered.shape(), &[1, 2]);
    }

    #[test]
    fn test_process_group_barrier() {
        let pg = ProcessGroup::mock();
        pg.barrier();
    }

    #[test]
    fn test_world_barrier() {
        let world = World::mock();
        world.barrier();
    }

    #[test]
    fn test_process_group_clone() {
        let pg = ProcessGroup::mock();
        let pg2 = pg.clone();
        assert_eq!(pg.rank(), pg2.rank());
        assert_eq!(pg.world_size(), pg2.world_size());
    }
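
    // A sketch of a reduce-scatter round trip on the single-rank mock
    // backend. It assumes MockBackend's reduce_scatter, like the broadcast
    // and all_gather mocks exercised above, hands the lone rank its own
    // contribution back unchanged.
    #[test]
    fn test_process_group_reduce_scatter_tensor() {
        let pg = ProcessGroup::mock();

        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let scattered = pg.reduce_scatter_tensor(&tensor, ReduceOp::Sum);

        // With world_size == 1 the chunk is the whole tensor.
        assert_eq!(scattered.shape(), &[2]);
        assert_eq!(scattered.to_vec(), vec![1.0, 2.0]);
    }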
}