//! Thin convenience wrappers over [`ProcessGroup`] collectives: all-reduce,
//! broadcast, gather/scatter, and data-parallel gradient synchronization.

use crate::backend::ReduceOp;
use crate::process_group::ProcessGroup;
use axonml_tensor::Tensor;

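/// Sums `tensor` element-wise across all ranks of `pg`, in place.
///
/// A minimal usage sketch; `ProcessGroup::mock()` is the single-process
/// group used by this module's tests, so values are unchanged there
/// (imports elided):
///
/// ```ignore
/// let pg = ProcessGroup::mock();
/// let mut t = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
/// all_reduce_sum(&mut t, &pg);
/// // With N ranks, each element now holds the sum over all N ranks.
/// ```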
pub fn all_reduce_sum(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    pg.all_reduce_tensor(tensor, ReduceOp::Sum);
}

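/// Averages `tensor` element-wise across all ranks of `pg`, in place.
///
/// This is the usual reduction for data-parallel gradient averaging; a
/// sketch assuming an already-initialized `pg` (imports elided):
///
/// ```ignore
/// all_reduce_mean(&mut grad, &pg);
/// // Each element is now the mean of that element across all ranks.
/// ```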
pub fn all_reduce_mean(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    pg.all_reduce_tensor(tensor, ReduceOp::Average);
}

/// Takes the element-wise minimum of `tensor` across all ranks, in place.
pub fn all_reduce_min(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    pg.all_reduce_tensor(tensor, ReduceOp::Min);
}

/// Takes the element-wise maximum of `tensor` across all ranks, in place.
pub fn all_reduce_max(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    pg.all_reduce_tensor(tensor, ReduceOp::Max);
}

/// Takes the element-wise product of `tensor` across all ranks, in place.
pub fn all_reduce_product(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    pg.all_reduce_tensor(tensor, ReduceOp::Product);
}

/// Broadcasts `tensor` from rank 0 to all other ranks, in place.
pub fn broadcast(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    broadcast_from(tensor, 0, pg);
}

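/// Broadcasts `tensor` from rank `src` to all other ranks, in place.
///
/// A common use is synchronizing initial parameters; a sketch assuming
/// `pg` and `weights` already exist (imports elided):
///
/// ```ignore
/// // Every rank makes the same call; only rank 0's data survives.
/// broadcast_from(&mut weights, 0, &pg);
/// ```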
pub fn broadcast_from(tensor: &mut Tensor<f32>, src: usize, pg: &ProcessGroup) {
    pg.broadcast_tensor(tensor, src);
}

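/// Gathers `tensor` from every rank and returns the combined result, with
/// a leading world-size dimension (see the tests below).
///
/// Sketch assuming an initialized `pg` (imports elided):
///
/// ```ignore
/// let gathered = all_gather(&tensor, &pg); // shape [world_size, ...]
/// ```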
#[must_use]
pub fn all_gather(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
    pg.all_gather_tensor(tensor)
}

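/// Reduces `tensor` across ranks with a sum, then scatters one shard of
/// the result back to each rank.
///
/// Sketch assuming an initialized `pg`; with the mock single-process group
/// the shard equals the input (imports elided):
///
/// ```ignore
/// let shard = reduce_scatter_sum(&tensor, &pg);
/// ```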
#[must_use]
pub fn reduce_scatter_sum(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
    pg.reduce_scatter_tensor(tensor, ReduceOp::Sum)
}

/// Reduces `tensor` across ranks with an average, then scatters one shard
/// of the result back to each rank.
#[must_use]
pub fn reduce_scatter_mean(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
    pg.reduce_scatter_tensor(tensor, ReduceOp::Average)
}

/// Blocks until every rank in `pg` has reached this call.
pub fn barrier(pg: &ProcessGroup) {
    pg.barrier();
}

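/// Returns `true` on rank 0, the conventional "main" process.
///
/// Handy for work that should happen exactly once, such as logging or
/// saving checkpoints; a sketch assuming an initialized `pg`:
///
/// ```ignore
/// if is_main_process(&pg) {
///     println!("only rank 0 prints this");
/// }
/// ```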
#[must_use]
pub fn is_main_process(pg: &ProcessGroup) -> bool {
    pg.rank() == 0
}

/// Returns the number of ranks in `pg`.
#[must_use]
pub fn world_size(pg: &ProcessGroup) -> usize {
    pg.world_size()
}

/// Returns this process's rank within `pg`.
#[must_use]
pub fn rank(pg: &ProcessGroup) -> usize {
    pg.rank()
}

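/// Returns this rank's chunk of `tensor`, split evenly along `dim`.
///
/// Falls back to a full clone when `dim` is out of range, when the split
/// would be uneven, or for shapes other than 1-D or 2-D along `dim == 0`.
/// A sketch of the 1-D case with a hypothetical two-rank group (imports
/// elided):
///
/// ```ignore
/// // tensor = [1.0, 2.0, 3.0, 4.0], world_size = 2:
/// // rank 0 gets [1.0, 2.0]; rank 1 gets [3.0, 4.0].
/// let chunk = scatter_tensor(&tensor, 0, &pg);
/// ```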
#[must_use]
pub fn scatter_tensor(tensor: &Tensor<f32>, dim: usize, pg: &ProcessGroup) -> Tensor<f32> {
    let shape = tensor.shape();
    if dim >= shape.len() {
        return tensor.clone();
    }

    let world_size = pg.world_size();
    let rank = pg.rank();
    let dim_size = shape[dim];

    // Uneven splits are not supported; hand back the full tensor.
    if dim_size % world_size != 0 {
        return tensor.clone();
    }

    let chunk_size = dim_size / world_size;
    let start = rank * chunk_size;
    let end = start + chunk_size;

    // 1-D: slice out this rank's contiguous chunk.
    if shape.len() == 1 && dim == 0 {
        let data = tensor.to_vec();
        let chunk = data[start..end].to_vec();
        return Tensor::from_vec(chunk, &[chunk_size]).unwrap();
    }

    // 2-D along rows: copy this rank's block of rows.
    if shape.len() == 2 && dim == 0 {
        let data = tensor.to_vec();
        let cols = shape[1];
        let mut chunk = Vec::with_capacity(chunk_size * cols);
        for row in start..end {
            let row_start = row * cols;
            let row_end = row_start + cols;
            chunk.extend_from_slice(&data[row_start..row_end]);
        }
        return Tensor::from_vec(chunk, &[chunk_size, cols]).unwrap();
    }

    // Any other shape/dim combination is unsupported; return a clone.
    tensor.clone()
}

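/// Gathers `tensor` from every rank; for the 1-D case along `dim == 0`,
/// the result is flattened into a single vector, otherwise it is returned
/// as the all-gather produced it.
///
/// Sketch with a hypothetical two-rank group (imports elided):
///
/// ```ignore
/// // Rank r holds [r, r]; the gathered vector is [0.0, 0.0, 1.0, 1.0].
/// let full = gather_tensor(&tensor, 0, &pg);
/// ```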
#[must_use]
pub fn gather_tensor(tensor: &Tensor<f32>, dim: usize, pg: &ProcessGroup) -> Tensor<f32> {
    let gathered = pg.all_gather_tensor(tensor);

    let world_size = pg.world_size();
    let shape = tensor.shape();

    // 1-D: flatten the gathered [world_size, n] result into one vector.
    if shape.len() == 1 && dim == 0 {
        let data = gathered.to_vec();
        return Tensor::from_vec(data, &[shape[0] * world_size]).unwrap();
    }

    // Other shapes come back exactly as the all-gather produced them.
    gathered
}

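/// Averages every gradient in `gradients` across all ranks, in place: the
/// standard data-parallel synchronization step.
///
/// A hedged training-loop sketch; `backward` and `step` are hypothetical
/// stand-ins for your autograd and optimizer (imports elided):
///
/// ```ignore
/// let mut grads = backward(&model, &batch);
/// sync_gradients(&mut grads, &pg); // every rank now holds averaged grads
/// step(&mut model, &grads);
/// ```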
pub fn sync_gradients(gradients: &mut [Tensor<f32>], pg: &ProcessGroup) {
    for grad in gradients.iter_mut() {
        all_reduce_mean(grad, pg);
    }
}

/// Averages a single gradient tensor across all ranks, in place.
pub fn sync_gradient(gradient: &mut Tensor<f32>, pg: &ProcessGroup) {
    all_reduce_mean(gradient, pg);
}

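/// All-reduces a raw `f32` slice with `op`, in place.
///
/// Despite the name, this currently delegates to the backend's all-reduce
/// rather than hand-rolling the ring algorithm, and it is a no-op for a
/// single-process group. Sketch (imports elided):
///
/// ```ignore
/// let mut data = vec![1.0, 2.0, 3.0];
/// ring_all_reduce(&mut data, &pg, ReduceOp::Sum);
/// ```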
pub fn ring_all_reduce(data: &mut [f32], pg: &ProcessGroup, op: ReduceOp) {
    let world_size = pg.world_size();
    // A single process has nothing to exchange.
    if world_size == 1 {
        return;
    }

    // Delegate to the backend's all-reduce implementation.
    pg.backend().all_reduce(data, op);
}

#[cfg(test)]
mod tests {
    use super::*;

    // All tests run against the single-process mock group, so collectives
    // leave values unchanged.

    #[test]
    fn test_all_reduce_sum() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_sum(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_mean() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_mean(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_min() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_min(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_max() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_max(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_broadcast() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        broadcast(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_broadcast_from() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        broadcast_from(&mut tensor, 0, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_all_gather() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let gathered = all_gather(&tensor, &pg);
        assert_eq!(gathered.shape(), &[1, 2]);
    }

    #[test]
    fn test_reduce_scatter_sum() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let scattered = reduce_scatter_sum(&tensor, &pg);
        assert_eq!(scattered.shape(), &[2]);
    }

    #[test]
    fn test_barrier() {
        let pg = ProcessGroup::mock();
        barrier(&pg);
    }

    #[test]
    fn test_is_main_process() {
        let pg = ProcessGroup::mock();
        assert!(is_main_process(&pg));
    }

    #[test]
    fn test_world_size() {
        let pg = ProcessGroup::mock();
        assert_eq!(world_size(&pg), 1);
    }

    #[test]
    fn test_rank() {
        let pg = ProcessGroup::mock();
        assert_eq!(rank(&pg), 0);
    }

    #[test]
    fn test_scatter_tensor_1d() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();

        let scattered = scatter_tensor(&tensor, 0, &pg);
        assert_eq!(scattered.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_gather_tensor() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let gathered = gather_tensor(&tensor, 0, &pg);
        assert_eq!(gathered.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_sync_gradients() {
        let pg = ProcessGroup::mock();
        let mut grads = vec![
            Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(),
            Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(),
        ];

        sync_gradients(&mut grads, &pg);

        assert_eq!(grads[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(grads[1].to_vec(), vec![3.0, 4.0]);
    }

    #[test]
    fn test_sync_gradient() {
        let pg = ProcessGroup::mock();
        let mut grad = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        sync_gradient(&mut grad, &pg);
        assert_eq!(grad.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_ring_all_reduce() {
        let pg = ProcessGroup::mock();
        let mut data = vec![1.0, 2.0, 3.0];

        ring_all_reduce(&mut data, &pg, ReduceOp::Sum);
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }
}