axonml_distributed/
process_group.rs1use crate::backend::{Backend, MockBackend, ReduceOp};
10use axonml_tensor::Tensor;
11use std::sync::Arc;
12
/// A group of ranks that participate in collective communication.
///
/// Wraps a shared communication backend together with the subset of ranks
/// that belong to this group. The default group spans all ranks.
pub struct ProcessGroup {
    // Shared, reference-counted handle to the communication backend.
    backend: Arc<dyn Backend>,
    // Ranks that are members of this group (all ranks for the default group).
    // NOTE(review): the collective methods delegate to `backend` and appear to
    // span the backend's full world regardless of this list — confirm the
    // intended subgroup semantics.
    ranks: Vec<usize>,
}
22
23impl ProcessGroup {
24 pub fn new(backend: Arc<dyn Backend>) -> Self {
26 let world_size = backend.world_size();
27 Self {
28 backend,
29 ranks: (0..world_size).collect(),
30 }
31 }
32
33 pub fn with_ranks(backend: Arc<dyn Backend>, ranks: Vec<usize>) -> Self {
35 Self { backend, ranks }
36 }
37
38 #[must_use] pub fn mock() -> Self {
40 Self::new(Arc::new(MockBackend::single()))
41 }
42
43 #[must_use] pub fn backend(&self) -> &dyn Backend {
45 self.backend.as_ref()
46 }
47
48 #[must_use] pub fn rank(&self) -> usize {
50 self.backend.rank()
51 }
52
53 #[must_use] pub fn world_size(&self) -> usize {
55 self.backend.world_size()
56 }
57
58 #[must_use] pub fn size(&self) -> usize {
60 self.ranks.len()
61 }
62
63 #[must_use] pub fn ranks(&self) -> &[usize] {
65 &self.ranks
66 }
67
68 #[must_use] pub fn contains(&self, rank: usize) -> bool {
70 self.ranks.contains(&rank)
71 }
72
73 pub fn barrier(&self) {
75 self.backend.barrier();
76 }
77
78 pub fn all_reduce_tensor(&self, tensor: &mut Tensor<f32>, op: ReduceOp) {
80 let mut data = tensor.to_vec();
81 self.backend.all_reduce(&mut data, op);
82 *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
83 }
84
85 pub fn broadcast_tensor(&self, tensor: &mut Tensor<f32>, src: usize) {
87 let mut data = tensor.to_vec();
88 self.backend.broadcast(&mut data, src);
89 *tensor = Tensor::from_vec(data, tensor.shape()).unwrap();
90 }
91
92 #[must_use] pub fn all_gather_tensor(&self, send_tensor: &Tensor<f32>) -> Tensor<f32> {
94 let send_data = send_tensor.to_vec();
95 let mut recv_data = vec![0.0; send_data.len() * self.world_size()];
96 self.backend.all_gather(&send_data, &mut recv_data);
97
98 let mut new_shape = vec![self.world_size()];
100 new_shape.extend(send_tensor.shape());
101 Tensor::from_vec(recv_data, &new_shape).unwrap()
102 }
103
104 #[must_use] pub fn reduce_scatter_tensor(&self, send_tensor: &Tensor<f32>, op: ReduceOp) -> Tensor<f32> {
106 let send_data = send_tensor.to_vec();
107 let chunk_size = send_data.len() / self.world_size();
108 let mut recv_data = vec![0.0; chunk_size];
109 self.backend.reduce_scatter(&send_data, &mut recv_data, op);
110
111 let original_shape = send_tensor.shape();
113 let mut new_shape = original_shape.to_vec();
114 if !new_shape.is_empty() {
115 new_shape[0] /= self.world_size();
116 }
117 Tensor::from_vec(recv_data, &new_shape).unwrap()
118 }
119}
120
/// Global view of the distributed environment.
///
/// Owns the default process group, which spans every rank of the backend.
pub struct World {
    // Group containing all ranks of the backend's world.
    default_group: ProcessGroup,
}
129
130impl World {
131 pub fn init(backend: Arc<dyn Backend>) -> Self {
133 Self {
134 default_group: ProcessGroup::new(backend),
135 }
136 }
137
138 #[must_use] pub fn mock() -> Self {
140 Self {
141 default_group: ProcessGroup::mock(),
142 }
143 }
144
145 #[must_use] pub fn default_group(&self) -> &ProcessGroup {
147 &self.default_group
148 }
149
150 #[must_use] pub fn rank(&self) -> usize {
152 self.default_group.rank()
153 }
154
155 #[must_use] pub fn world_size(&self) -> usize {
157 self.default_group.world_size()
158 }
159
160 #[must_use] pub fn is_main(&self) -> bool {
162 self.rank() == 0
163 }
164
165 pub fn barrier(&self) {
167 self.default_group.barrier();
168 }
169
170 #[must_use] pub fn new_group(&self, ranks: Vec<usize>) -> ProcessGroup {
172 ProcessGroup::with_ranks(Arc::clone(&self.default_group.backend), ranks)
173 }
174}
175
176impl Clone for ProcessGroup {
177 fn clone(&self) -> Self {
178 Self {
179 backend: Arc::clone(&self.backend),
180 ranks: self.ranks.clone(),
181 }
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_process_group_mock() {
        // A mock group is a single-rank world.
        let group = ProcessGroup::mock();
        assert_eq!(group.rank(), 0);
        assert_eq!(group.world_size(), 1);
        assert_eq!(group.size(), 1);
    }

    #[test]
    fn test_process_group_contains() {
        let group = ProcessGroup::mock();
        assert!(group.contains(0));
        assert!(!group.contains(1));
    }

    #[test]
    fn test_world_mock() {
        let world = World::mock();
        assert_eq!(world.rank(), 0);
        assert_eq!(world.world_size(), 1);
        assert!(world.is_main());
    }

    #[test]
    fn test_world_new_group() {
        let world = World::mock();
        let subgroup = world.new_group(vec![0]);
        assert_eq!(subgroup.size(), 1);
    }

    #[test]
    fn test_process_group_all_reduce_tensor() {
        // Take rank 0 out of a two-rank mock world.
        let mut backends = MockBackend::create_world(2).into_iter();
        let pg0 = ProcessGroup::new(Arc::new(backends.next().unwrap()));

        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        pg0.all_reduce_tensor(&mut tensor, ReduceOp::Sum);

        // With only this rank participating, the values are unchanged.
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_process_group_broadcast_tensor() {
        let group = ProcessGroup::mock();

        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        group.broadcast_tensor(&mut tensor, 0);

        // Broadcasting from self leaves the data untouched.
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_process_group_all_gather_tensor() {
        let group = ProcessGroup::mock();

        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let gathered = group.all_gather_tensor(&tensor);

        // World size 1 stacks a single copy along the new leading axis.
        assert_eq!(gathered.shape(), &[1, 2]);
    }

    #[test]
    fn test_process_group_barrier() {
        // Must not deadlock in a single-rank world.
        ProcessGroup::mock().barrier();
    }

    #[test]
    fn test_world_barrier() {
        // Must not deadlock in a single-rank world.
        World::mock().barrier();
    }

    #[test]
    fn test_process_group_clone() {
        let original = ProcessGroup::mock();
        let copy = original.clone();
        assert_eq!(original.rank(), copy.rank());
        assert_eq!(original.world_size(), copy.world_size());
    }
}