//! Collective-communication helpers layered over a [`ProcessGroup`]:
//! all-reduce, broadcast, gather/scatter, and gradient synchronization
//! for data-parallel training.

use crate::backend::ReduceOp;
use crate::process_group::ProcessGroup;
use axonml_tensor::Tensor;

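/// Sums `tensor` element-wise across all ranks; every rank ends up with the
/// same reduced result.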
pub fn all_reduce_sum(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    pg.all_reduce_tensor(tensor, ReduceOp::Sum);
}

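/// Averages `tensor` element-wise across all ranks.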
pub fn all_reduce_mean(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    pg.all_reduce_tensor(tensor, ReduceOp::Average);
}

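/// Takes the element-wise minimum of `tensor` across all ranks.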
pub fn all_reduce_min(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    pg.all_reduce_tensor(tensor, ReduceOp::Min);
}

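/// Takes the element-wise maximum of `tensor` across all ranks.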
pub fn all_reduce_max(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    pg.all_reduce_tensor(tensor, ReduceOp::Max);
}

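/// Takes the element-wise product of `tensor` across all ranks.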
pub fn all_reduce_product(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    pg.all_reduce_tensor(tensor, ReduceOp::Product);
}

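/// Broadcasts `tensor` from rank 0 to every other rank.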
pub fn broadcast(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
    broadcast_from(tensor, 0, pg);
}

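/// Broadcasts `tensor` from the source rank `src` to every other rank.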
pub fn broadcast_from(tensor: &mut Tensor<f32>, src: usize, pg: &ProcessGroup) {
    pg.broadcast_tensor(tensor, src);
}

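/// Gathers `tensor` from every rank, returning the results stacked along a
/// new leading rank dimension.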
#[must_use]
pub fn all_gather(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
    pg.all_gather_tensor(tensor)
}

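/// Sums `tensor` across all ranks, then leaves each rank with its shard of
/// the result.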
#[must_use]
pub fn reduce_scatter_sum(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
    pg.reduce_scatter_tensor(tensor, ReduceOp::Sum)
}

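/// Averages `tensor` across all ranks, then leaves each rank with its shard
/// of the result.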
#[must_use]
pub fn reduce_scatter_mean(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
    pg.reduce_scatter_tensor(tensor, ReduceOp::Average)
}

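/// Blocks until every rank in the process group reaches this point.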
pub fn barrier(pg: &ProcessGroup) {
    pg.barrier();
}

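/// Returns `true` on rank 0, the conventional main process.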
#[must_use]
pub fn is_main_process(pg: &ProcessGroup) -> bool {
    pg.rank() == 0
}

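/// Returns the number of ranks in the process group.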
#[must_use]
pub fn world_size(pg: &ProcessGroup) -> usize {
    pg.world_size()
}

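/// Returns this process's rank within the group.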
#[must_use]
pub fn rank(pg: &ProcessGroup) -> usize {
    pg.rank()
}

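/// Splits `tensor` along `dim` into `world_size` equal chunks and returns
/// the chunk belonging to this rank.
///
/// Returns a clone of the full tensor when `dim` is out of range, when the
/// dimension does not divide evenly across ranks, or for shapes beyond the
/// 1-D and 2-D (`dim == 0`) cases handled here.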
#[must_use]
pub fn scatter_tensor(tensor: &Tensor<f32>, dim: usize, pg: &ProcessGroup) -> Tensor<f32> {
    let shape = tensor.shape();
    if dim >= shape.len() {
        return tensor.clone();
    }

    let world_size = pg.world_size();
    let rank = pg.rank();
    let dim_size = shape[dim];

    // Only split dimensions that divide evenly across ranks.
    if dim_size % world_size != 0 {
        return tensor.clone();
    }

    let chunk_size = dim_size / world_size;
    let start = rank * chunk_size;
    let end = start + chunk_size;

    // 1-D case: slice this rank's contiguous chunk directly.
    if shape.len() == 1 && dim == 0 {
        let data = tensor.to_vec();
        let chunk = data[start..end].to_vec();
        return Tensor::from_vec(chunk, &[chunk_size]).unwrap();
    }

    // 2-D case, split along rows: copy this rank's rows one at a time.
    if shape.len() == 2 && dim == 0 {
        let data = tensor.to_vec();
        let cols = shape[1];
        let mut chunk = Vec::with_capacity(chunk_size * cols);
        for row in start..end {
            let row_start = row * cols;
            let row_end = row_start + cols;
            chunk.extend_from_slice(&data[row_start..row_end]);
        }
        return Tensor::from_vec(chunk, &[chunk_size, cols]).unwrap();
    }

    // Unsupported rank/dim combination: fall back to the full tensor.
    tensor.clone()
}

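/// Inverse of [`scatter_tensor`]: gathers every rank's shard and reassembles
/// the full tensor. For 1-D input the stacked result is flattened back to a
/// single dimension; other shapes are returned as gathered.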
#[must_use]
pub fn gather_tensor(tensor: &Tensor<f32>, dim: usize, pg: &ProcessGroup) -> Tensor<f32> {
    let gathered = pg.all_gather_tensor(tensor);

    let world_size = pg.world_size();
    let shape = tensor.shape();

    // 1-D input: flatten the stacked [world_size, n] result back to [n * world_size].
    if shape.len() == 1 && dim == 0 {
        let data = gathered.to_vec();
        return Tensor::from_vec(data, &[shape[0] * world_size]).unwrap();
    }

    gathered
}

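/// Averages every gradient tensor across all ranks, the standard gradient
/// synchronization step in data-parallel training.
///
/// # Example
///
/// ```ignore
/// // Sketch mirroring the tests below; assumes the single-process mock
/// // group, under which the averaged gradients come back unchanged.
/// let pg = ProcessGroup::mock();
/// let mut grads = vec![Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap()];
/// sync_gradients(&mut grads, &pg);
/// ```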
pub fn sync_gradients(gradients: &mut [Tensor<f32>], pg: &ProcessGroup) {
    for grad in gradients.iter_mut() {
        all_reduce_mean(grad, pg);
    }
}

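/// Averages a single gradient tensor across all ranks.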
pub fn sync_gradient(gradient: &mut Tensor<f32>, pg: &ProcessGroup) {
    all_reduce_mean(gradient, pg);
}

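/// All-reduces a raw `f32` slice in place. Despite the name, this currently
/// delegates to the backend's all-reduce rather than implementing the ring
/// algorithm itself; a single-process group returns immediately.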
pub fn ring_all_reduce(data: &mut [f32], pg: &ProcessGroup, op: ReduceOp) {
    let world_size = pg.world_size();
    if world_size == 1 {
        // Nothing to reduce against in a single-process group.
        return;
    }

    pg.backend().all_reduce(data, op);
}

#[cfg(test)]
mod tests {
    use super::*;

    // The mock process group is single-process, so every collective is an
    // identity operation and values should pass through unchanged.

    #[test]
    fn test_all_reduce_sum() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_sum(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_mean() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_mean(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_min() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_min(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_max() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_max(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_broadcast() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        broadcast(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_broadcast_from() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        broadcast_from(&mut tensor, 0, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_all_gather() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let gathered = all_gather(&tensor, &pg);
        assert_eq!(gathered.shape(), &[1, 2]);
    }

    #[test]
    fn test_reduce_scatter_sum() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let scattered = reduce_scatter_sum(&tensor, &pg);
        assert_eq!(scattered.shape(), &[2]);
    }

    #[test]
    fn test_barrier() {
        let pg = ProcessGroup::mock();
        barrier(&pg);
    }

    #[test]
    fn test_is_main_process() {
        let pg = ProcessGroup::mock();
        assert!(is_main_process(&pg));
    }

    #[test]
    fn test_world_size() {
        let pg = ProcessGroup::mock();
        assert_eq!(world_size(&pg), 1);
    }

    #[test]
    fn test_rank() {
        let pg = ProcessGroup::mock();
        assert_eq!(rank(&pg), 0);
    }

    #[test]
    fn test_scatter_tensor_1d() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();

        let scattered = scatter_tensor(&tensor, 0, &pg);
        assert_eq!(scattered.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_gather_tensor() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let gathered = gather_tensor(&tensor, 0, &pg);
        assert_eq!(gathered.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_sync_gradients() {
        let pg = ProcessGroup::mock();
        let mut grads = vec![
            Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(),
            Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(),
        ];

        sync_gradients(&mut grads, &pg);

        assert_eq!(grads[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(grads[1].to_vec(), vec![3.0, 4.0]);
    }

    #[test]
    fn test_sync_gradient() {
        let pg = ProcessGroup::mock();
        let mut grad = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        sync_gradient(&mut grad, &pg);
        assert_eq!(grad.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_ring_all_reduce() {
        let pg = ProcessGroup::mock();
        let mut data = vec![1.0, 2.0, 3.0];

        ring_all_reduce(&mut data, &pg, ReduceOp::Sum);
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }
}