use crate::backend::ReduceOp;
use crate::process_group::ProcessGroup;
use axonml_tensor::Tensor;

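/// Sums `tensor` element-wise across all ranks in the process group, leaving
/// every rank with the same summed result.
///
/// A minimal sketch of typical usage. It assumes a `ProcessGroup` has already
/// been initialized; `ProcessGroup::mock()` (a single-rank group) is used
/// here only to keep the snippet self-contained.
///
/// ```ignore
/// let pg = ProcessGroup::mock();
/// let mut grad_sum = Tensor::from_vec(vec![0.5, 1.5], &[2]).unwrap();
/// // After the call, every rank holds the element-wise sum of all ranks' tensors.
/// all_reduce_sum(&mut grad_sum, &pg);
/// ```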
pub fn all_reduce_sum(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    pg.all_reduce_tensor(tensor, ReduceOp::Sum);
}

/// Averages `tensor` element-wise across all ranks.
pub fn all_reduce_mean(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    pg.all_reduce_tensor(tensor, ReduceOp::Average);
}

/// Takes the element-wise minimum of `tensor` across all ranks.
pub fn all_reduce_min(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    pg.all_reduce_tensor(tensor, ReduceOp::Min);
}

/// Takes the element-wise maximum of `tensor` across all ranks.
pub fn all_reduce_max(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    pg.all_reduce_tensor(tensor, ReduceOp::Max);
}

/// Multiplies `tensor` element-wise across all ranks.
pub fn all_reduce_product(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    pg.all_reduce_tensor(tensor, ReduceOp::Product);
}

/// Broadcasts `tensor` from rank 0 to all other ranks.
pub fn broadcast(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    broadcast_from(tensor, 0, pg);
}

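/// Broadcasts `tensor` from rank `src` to every other rank, overwriting each
/// rank's local copy.
///
/// A common pattern is synchronizing model parameters from rank 0 before
/// data-parallel training. A hedged sketch (the parameter list here stands in
/// for whatever tensors the caller owns):
///
/// ```ignore
/// let pg = ProcessGroup::mock();
/// let mut params = vec![Tensor::from_vec(vec![0.1, 0.2], &[2]).unwrap()];
/// for p in params.iter_mut() {
///     // Every rank ends up with rank 0's values.
///     broadcast_from(p, 0, &pg);
/// }
/// ```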
pub fn broadcast_from(tensor: &mut Tensor<f32>, src: usize, pg: &ProcessGroup) {
    pg.broadcast_tensor(tensor, src);
}

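/// Gathers `tensor` from every rank and returns the combined result on all
/// ranks.
///
/// A minimal sketch: with the single-rank mock group, the gathered tensor
/// holds this rank's values behind a leading world-size dimension, as the
/// tests below exercise.
///
/// ```ignore
/// let pg = ProcessGroup::mock();
/// let local = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
/// let gathered = all_gather(&local, &pg);
/// assert_eq!(gathered.shape(), &[1, 2]);
/// ```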
#[must_use]
pub fn all_gather(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
    pg.all_gather_tensor(tensor)
}

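/// Reduces `tensor` with a sum across ranks, then scatters one shard of the
/// result to each rank.
///
/// A minimal sketch (with a single-rank mock group the "shard" is the whole
/// tensor, as the tests below confirm):
///
/// ```ignore
/// let pg = ProcessGroup::mock();
/// let t = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
/// let shard = reduce_scatter_sum(&t, &pg);
/// assert_eq!(shard.shape(), &[2]);
/// ```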
#[must_use]
pub fn reduce_scatter_sum(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
    pg.reduce_scatter_tensor(tensor, ReduceOp::Sum)
}

/// Reduces `tensor` with an average across ranks, then scatters one shard of
/// the result to each rank.
#[must_use]
pub fn reduce_scatter_mean(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
    pg.reduce_scatter_tensor(tensor, ReduceOp::Average)
}

/// Blocks until every rank in the process group reaches this point.
pub fn barrier(pg: &ProcessGroup) {
    pg.barrier();
}

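/// Returns `true` on rank 0, the conventional "main" process.
///
/// Useful for gating side effects such as logging or checkpointing to a
/// single rank. A minimal sketch:
///
/// ```ignore
/// let pg = ProcessGroup::mock();
/// if is_main_process(&pg) {
///     println!("only rank 0 prints this");
/// }
/// ```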
#[must_use]
pub fn is_main_process(pg: &ProcessGroup) -> bool {
    pg.rank() == 0
}

/// Returns the number of ranks in the process group.
#[must_use]
pub fn world_size(pg: &ProcessGroup) -> usize {
    pg.world_size()
}

/// Returns this process's rank within the group.
#[must_use]
pub fn rank(pg: &ProcessGroup) -> usize {
    pg.rank()
}

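/// Returns this rank's chunk of `tensor`, split evenly along dimension `dim`.
///
/// Only 1-D and 2-D tensors split along `dim == 0` are currently chunked; in
/// every other case (including an out-of-range `dim` or a dimension size not
/// divisible by the world size) the full tensor is returned unchanged.
///
/// A sketch of splitting a batch across ranks (with the single-rank mock
/// group, each rank keeps the whole batch):
///
/// ```ignore
/// let pg = ProcessGroup::mock();
/// let batch = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
/// let local = scatter_tensor(&batch, 0, &pg);
/// assert_eq!(local.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
/// ```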
#[must_use]
pub fn scatter_tensor(tensor: &Tensor<f32>, dim: usize, pg: &ProcessGroup) -> Tensor<f32> {
    let shape = tensor.shape();
    if dim >= shape.len() {
        // Out-of-range dimension: fall back to returning the full tensor.
        return tensor.clone();
    }

    let world_size = pg.world_size();
    let rank = pg.rank();
    let dim_size = shape[dim];

    if dim_size % world_size != 0 {
        // Uneven split: fall back to returning the full tensor.
        return tensor.clone();
    }

    let chunk_size = dim_size / world_size;
    let start = rank * chunk_size;
    let end = start + chunk_size;

    if shape.len() == 1 && dim == 0 {
        // 1-D case: slice out this rank's contiguous chunk.
        let data = tensor.to_vec();
        let chunk = data[start..end].to_vec();
        return Tensor::from_vec(chunk, &[chunk_size]).unwrap();
    }

    if shape.len() == 2 && dim == 0 {
        // 2-D case, split along rows: copy this rank's rows out of the
        // row-major buffer.
        let data = tensor.to_vec();
        let cols = shape[1];
        let mut chunk = Vec::with_capacity(chunk_size * cols);
        for row in start..end {
            let row_start = row * cols;
            let row_end = row_start + cols;
            chunk.extend_from_slice(&data[row_start..row_end]);
        }
        return Tensor::from_vec(chunk, &[chunk_size, cols]).unwrap();
    }

    // Higher-rank tensors and non-zero split dimensions are not chunked.
    tensor.clone()
}

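/// Gathers `tensor` from every rank and, for the 1-D `dim == 0` case,
/// flattens the result into a single concatenated vector; other shapes are
/// returned as produced by the underlying all-gather.
///
/// A minimal sketch (single-rank mock group, so the gather is an identity):
///
/// ```ignore
/// let pg = ProcessGroup::mock();
/// let local = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
/// let full = gather_tensor(&local, 0, &pg);
/// assert_eq!(full.to_vec(), vec![1.0, 2.0]);
/// ```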
#[must_use]
pub fn gather_tensor(tensor: &Tensor<f32>, dim: usize, pg: &ProcessGroup) -> Tensor<f32> {
    let gathered = pg.all_gather_tensor(tensor);

    let world_size = pg.world_size();
    let shape = tensor.shape();

    if shape.len() == 1 && dim == 0 {
        // 1-D case: flatten the gathered chunks into one concatenated vector.
        let data = gathered.to_vec();
        return Tensor::from_vec(data, &[shape[0] * world_size]).unwrap();
    }

    gathered
}

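/// Averages every gradient in `gradients` across ranks, the standard
/// data-parallel synchronization step.
///
/// A hedged sketch of where this sits in a training loop; the backward pass
/// and optimizer step are hypothetical placeholders, not part of this module:
///
/// ```ignore
/// // let mut grads = compute_gradients(&model, &batch); // hypothetical
/// let pg = ProcessGroup::mock();
/// let mut grads = vec![Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap()];
/// sync_gradients(&mut grads, &pg);
/// // optimizer.step(&grads);                            // hypothetical
/// ```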
pub fn sync_gradients(gradients: &mut [Tensor<f32>], pg: &ProcessGroup) {
    for grad in gradients.iter_mut() {
        all_reduce_mean(grad, pg);
    }
}

/// Averages a single gradient tensor across ranks.
pub fn sync_gradient(gradient: &mut Tensor<f32>, pg: &ProcessGroup) {
    all_reduce_mean(gradient, pg);
}

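/// All-reduces a raw `f32` slice across ranks.
///
/// Despite the name, this wrapper does not implement the ring algorithm
/// itself (a reduce-scatter followed by an all-gather around a ring of
/// ranks); it delegates to the backend, which is free to choose the
/// algorithm. A minimal sketch:
///
/// ```ignore
/// let pg = ProcessGroup::mock();
/// let mut data = vec![1.0, 2.0, 3.0];
/// // With a single-rank group this is a no-op; otherwise the backend
/// // reduces `data` in place.
/// ring_all_reduce(&mut data, &pg, ReduceOp::Sum);
/// ```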
pub fn ring_all_reduce(data: &mut [f32], pg: &ProcessGroup, op: ReduceOp) {
    let world_size = pg.world_size();
    if world_size == 1 {
        // Nothing to reduce with a single rank.
        return;
    }

    // Delegate to the backend's all-reduce implementation.
    pg.backend().all_reduce(data, op);
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_all_reduce_sum() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_sum(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_mean() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_mean(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_min() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_min(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_max() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_max(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_broadcast() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        broadcast(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_broadcast_from() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        broadcast_from(&mut tensor, 0, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_all_gather() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let gathered = all_gather(&tensor, &pg);
        assert_eq!(gathered.shape(), &[1, 2]);
    }

    #[test]
    fn test_reduce_scatter_sum() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let scattered = reduce_scatter_sum(&tensor, &pg);
        assert_eq!(scattered.shape(), &[2]);
    }

    #[test]
    fn test_barrier() {
        let pg = ProcessGroup::mock();
        barrier(&pg);
    }

    #[test]
    fn test_is_main_process() {
        let pg = ProcessGroup::mock();
        assert!(is_main_process(&pg));
    }

    #[test]
    fn test_world_size() {
        let pg = ProcessGroup::mock();
        assert_eq!(world_size(&pg), 1);
    }

    #[test]
    fn test_rank() {
        let pg = ProcessGroup::mock();
        assert_eq!(rank(&pg), 0);
    }

    #[test]
    fn test_scatter_tensor_1d() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();

        let scattered = scatter_tensor(&tensor, 0, &pg);
        assert_eq!(scattered.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_gather_tensor() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let gathered = gather_tensor(&tensor, 0, &pg);
        assert_eq!(gathered.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_sync_gradients() {
        let pg = ProcessGroup::mock();
        let mut grads = vec![
            Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(),
            Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(),
        ];

        sync_gradients(&mut grads, &pg);

        assert_eq!(grads[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(grads[1].to_vec(), vec![3.0, 4.0]);
    }

    #[test]
    fn test_sync_gradient() {
        let pg = ProcessGroup::mock();
        let mut grad = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        sync_gradient(&mut grad, &pg);
        assert_eq!(grad.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_ring_all_reduce() {
        let pg = ProcessGroup::mock();
        let mut data = vec![1.0, 2.0, 3.0];

        ring_all_reduce(&mut data, &pg, ReduceOp::Sum);
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }
}