use crate::backend::ReduceOp;
use crate::process_group::ProcessGroup;
use axonml_tensor::Tensor;

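/// Sums `tensor` element-wise across every rank in the process group, in place.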
pub fn all_reduce_sum(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    pg.all_reduce_tensor(tensor, ReduceOp::Sum);
}

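/// Averages `tensor` element-wise across every rank in the process group, in place.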
pub fn all_reduce_mean(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    pg.all_reduce_tensor(tensor, ReduceOp::Average);
}

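/// Takes the element-wise minimum of `tensor` across every rank, in place.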
pub fn all_reduce_min(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    pg.all_reduce_tensor(tensor, ReduceOp::Min);
}

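/// Takes the element-wise maximum of `tensor` across every rank, in place.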
pub fn all_reduce_max(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    pg.all_reduce_tensor(tensor, ReduceOp::Max);
}

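/// Multiplies `tensor` element-wise across every rank, in place.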
pub fn all_reduce_product(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    pg.all_reduce_tensor(tensor, ReduceOp::Product);
}

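/// Broadcasts `tensor` from rank 0 to every other rank, in place.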
pub fn broadcast(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    broadcast_from(tensor, 0, pg);
}

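/// Broadcasts `tensor` from rank `src` to every other rank, in place.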
pub fn broadcast_from(tensor: &mut Tensor<f32>, src: usize, pg: &ProcessGroup) {
    pg.broadcast_tensor(tensor, src);
}

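/// Gathers `tensor` from every rank, returning the results stacked along a new leading dimension.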
#[must_use]
pub fn all_gather(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
    pg.all_gather_tensor(tensor)
}

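/// Sum-reduces `tensor` across ranks and returns this rank's shard of the result.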
#[must_use]
pub fn reduce_scatter_sum(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
    pg.reduce_scatter_tensor(tensor, ReduceOp::Sum)
}

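/// Mean-reduces `tensor` across ranks and returns this rank's shard of the result.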
#[must_use]
pub fn reduce_scatter_mean(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
    pg.reduce_scatter_tensor(tensor, ReduceOp::Average)
}

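/// Blocks until every rank in the process group has reached this point.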
pub fn barrier(pg: &ProcessGroup) {
    pg.barrier();
}

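/// Returns `true` if this process is rank 0 (the main process).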
#[must_use]
pub fn is_main_process(pg: &ProcessGroup) -> bool {
    pg.rank() == 0
}

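/// Returns the number of processes in the group.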
#[must_use]
pub fn world_size(pg: &ProcessGroup) -> usize {
    pg.world_size()
}

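/// Returns this process's rank within the group.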
#[must_use]
pub fn rank(pg: &ProcessGroup) -> usize {
    pg.rank()
}

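/// Returns this rank's chunk of `tensor`, split evenly along `dim`.
///
/// Falls back to cloning the full tensor when `dim` is out of range, when the
/// dimension is not evenly divisible by the world size, or when the layout is
/// unsupported (only 1-D and 2-D tensors split along dimension 0 are handled).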
#[must_use]
pub fn scatter_tensor(tensor: &Tensor<f32>, dim: usize, pg: &ProcessGroup) -> Tensor<f32> {
    let shape = tensor.shape();
    if dim >= shape.len() {
        return tensor.clone();
    }

    let world_size = pg.world_size();
    let rank = pg.rank();
    let dim_size = shape[dim];

    // The split dimension must divide evenly across ranks; otherwise keep the full tensor.
    if dim_size % world_size != 0 {
        return tensor.clone();
    }

    let chunk_size = dim_size / world_size;
    let start = rank * chunk_size;
    let end = start + chunk_size;

    // 1-D split along dimension 0: take this rank's contiguous slice.
    if shape.len() == 1 && dim == 0 {
        let data = tensor.to_vec();
        let chunk = data[start..end].to_vec();
        return Tensor::from_vec(chunk, &[chunk_size]).unwrap();
    }

    // 2-D split along dimension 0: copy this rank's rows.
    if shape.len() == 2 && dim == 0 {
        let data = tensor.to_vec();
        let cols = shape[1];
        let mut chunk = Vec::with_capacity(chunk_size * cols);
        for row in start..end {
            let row_start = row * cols;
            let row_end = row_start + cols;
            chunk.extend_from_slice(&data[row_start..row_end]);
        }
        return Tensor::from_vec(chunk, &[chunk_size, cols]).unwrap();
    }

    // Unsupported layout: return the tensor unchanged.
    tensor.clone()
}

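/// All-gathers `tensor` across ranks; for a 1-D tensor gathered along dimension 0,
/// the stacked result is flattened back into a single dimension.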
#[must_use]
pub fn gather_tensor(tensor: &Tensor<f32>, dim: usize, pg: &ProcessGroup) -> Tensor<f32> {
    let gathered = pg.all_gather_tensor(tensor);

    let world_size = pg.world_size();
    let shape = tensor.shape();

    // For the 1-D case, flatten the stacked result back into a single dimension.
    if shape.len() == 1 && dim == 0 {
        let data = gathered.to_vec();
        return Tensor::from_vec(data, &[shape[0] * world_size]).unwrap();
    }

    gathered
}

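/// Averages every gradient tensor across ranks, e.g. for data-parallel training.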
pub fn sync_gradients(gradients: &mut [Tensor<f32>], pg: &ProcessGroup) {
    for grad in gradients.iter_mut() {
        all_reduce_mean(grad, pg);
    }
}

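/// Averages a single gradient tensor across ranks.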
pub fn sync_gradient(gradient: &mut Tensor<f32>, pg: &ProcessGroup) {
    all_reduce_mean(gradient, pg);
}

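/// All-reduces a raw `f32` slice across ranks using `op`; a no-op when the world
/// size is 1. Despite the name, this delegates to the backend's all-reduce rather
/// than implementing an explicit ring schedule.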
pub fn ring_all_reduce(data: &mut [f32], pg: &ProcessGroup, op: ReduceOp) {
    let world_size = pg.world_size();
    if world_size == 1 {
        // Nothing to reduce with a single process.
        return;
    }

    pg.backend().all_reduce(data, op);
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_all_reduce_sum() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_sum(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_mean() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_mean(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_min() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_min(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_max() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_max(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_broadcast() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        broadcast(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_broadcast_from() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        broadcast_from(&mut tensor, 0, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_all_gather() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let gathered = all_gather(&tensor, &pg);
        assert_eq!(gathered.shape(), &[1, 2]);
    }

    #[test]
    fn test_reduce_scatter_sum() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let scattered = reduce_scatter_sum(&tensor, &pg);
        assert_eq!(scattered.shape(), &[2]);
    }

    #[test]
    fn test_barrier() {
        let pg = ProcessGroup::mock();
        barrier(&pg);
    }

    #[test]
    fn test_is_main_process() {
        let pg = ProcessGroup::mock();
        assert!(is_main_process(&pg));
    }

    #[test]
    fn test_world_size() {
        let pg = ProcessGroup::mock();
        assert_eq!(world_size(&pg), 1);
    }

    #[test]
    fn test_rank() {
        let pg = ProcessGroup::mock();
        assert_eq!(rank(&pg), 0);
    }

    #[test]
    fn test_scatter_tensor_1d() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();

        let scattered = scatter_tensor(&tensor, 0, &pg);
        assert_eq!(scattered.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

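    // Illustrative companion to the 1-D case above, assuming the single-rank mock
    // group behaves as in the other tests: the 2-D branch of `scatter_tensor`
    // keeps every row, so the output matches the input.
    #[test]
    fn test_scatter_tensor_2d() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let scattered = scatter_tensor(&tensor, 0, &pg);
        assert_eq!(scattered.shape(), &[2, 2]);
        assert_eq!(scattered.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }
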
    #[test]
    fn test_gather_tensor() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let gathered = gather_tensor(&tensor, 0, &pg);
        assert_eq!(gathered.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_sync_gradients() {
        let pg = ProcessGroup::mock();
        let mut grads = vec![
            Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(),
            Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(),
        ];

        sync_gradients(&mut grads, &pg);

        assert_eq!(grads[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(grads[1].to_vec(), vec![3.0, 4.0]);
    }

    #[test]
    fn test_sync_gradient() {
        let pg = ProcessGroup::mock();
        let mut grad = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        sync_gradient(&mut grad, &pg);
        assert_eq!(grad.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_ring_all_reduce() {
        let pg = ProcessGroup::mock();
        let mut data = vec![1.0, 2.0, 3.0];

        ring_all_reduce(&mut data, &pg, ReduceOp::Sum);
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }
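
    // Illustrative sketches for the two wrappers not covered above, assuming the
    // single-rank mock group leaves values unchanged as in the other tests.
    #[test]
    fn test_all_reduce_product() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_product(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_reduce_scatter_mean() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let scattered = reduce_scatter_mean(&tensor, &pg);
        assert_eq!(scattered.shape(), &[2]);
    }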
}