Skip to main content

axonml_distributed/
comm.rs

//! Communication - High-level Communication Utilities
//!
//! # File
//! `crates/axonml-distributed/src/comm.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

17use crate::backend::ReduceOp;
18use crate::process_group::ProcessGroup;
19use axonml_tensor::Tensor;
20
21// =============================================================================
22// All-Reduce Operations
23// =============================================================================
24
25/// Performs all-reduce sum on a tensor.
26pub fn all_reduce_sum(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
27    pg.all_reduce_tensor(tensor, ReduceOp::Sum);
28}
29
30/// Performs all-reduce mean on a tensor.
31pub fn all_reduce_mean(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
32    pg.all_reduce_tensor(tensor, ReduceOp::Average);
33}
34
35/// Performs all-reduce min on a tensor.
36pub fn all_reduce_min(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
37    pg.all_reduce_tensor(tensor, ReduceOp::Min);
38}
39
40/// Performs all-reduce max on a tensor.
41pub fn all_reduce_max(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
42    pg.all_reduce_tensor(tensor, ReduceOp::Max);
43}
44
45/// Performs all-reduce product on a tensor.
46pub fn all_reduce_product(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
47    pg.all_reduce_tensor(tensor, ReduceOp::Product);
48}
49
50// =============================================================================
51// Broadcast Operations
52// =============================================================================
53
54/// Broadcasts a tensor from the root rank (0).
55pub fn broadcast(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
56    broadcast_from(tensor, 0, pg);
57}
58
59/// Broadcasts a tensor from a specific rank.
60pub fn broadcast_from(tensor: &mut Tensor<f32>, src: usize, pg: &ProcessGroup) {
61    pg.broadcast_tensor(tensor, src);
62}
63
64// =============================================================================
65// Gather Operations
66// =============================================================================
67
68/// All-gathers a tensor across all ranks.
69#[must_use]
70pub fn all_gather(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
71    pg.all_gather_tensor(tensor)
72}
73
74// =============================================================================
75// Reduce-Scatter Operations
76// =============================================================================
77
78/// Reduce-scatters a tensor with sum.
79#[must_use]
80pub fn reduce_scatter_sum(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
81    pg.reduce_scatter_tensor(tensor, ReduceOp::Sum)
82}
83
84/// Reduce-scatters a tensor with mean.
85#[must_use]
86pub fn reduce_scatter_mean(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
87    pg.reduce_scatter_tensor(tensor, ReduceOp::Average)
88}
89
90// =============================================================================
91// Utility Functions
92// =============================================================================
93
94/// Synchronizes all processes.
95pub fn barrier(pg: &ProcessGroup) {
96    pg.barrier();
97}
98
99/// Checks if this is the main process (rank 0).
100#[must_use]
101pub fn is_main_process(pg: &ProcessGroup) -> bool {
102    pg.rank() == 0
103}
104
105/// Returns the world size.
106#[must_use]
107pub fn world_size(pg: &ProcessGroup) -> usize {
108    pg.world_size()
109}
110
111/// Returns the current rank.
112#[must_use]
113pub fn rank(pg: &ProcessGroup) -> usize {
114    pg.rank()
115}
116
117// =============================================================================
118// Model Parallel Utilities
119// =============================================================================
120
121/// Splits a tensor along a dimension for model parallelism.
122#[must_use]
123pub fn scatter_tensor(tensor: &Tensor<f32>, dim: usize, pg: &ProcessGroup) -> Tensor<f32> {
124    let shape = tensor.shape();
125    if dim >= shape.len() {
126        return tensor.clone();
127    }
128
129    let world_size = pg.world_size();
130    let rank = pg.rank();
131    let dim_size = shape[dim];
132
133    if dim_size % world_size != 0 {
134        return tensor.clone();
135    }
136
137    let chunk_size = dim_size / world_size;
138    let start = rank * chunk_size;
139    let end = start + chunk_size;
140
141    // For 1D tensors along dim 0
142    if shape.len() == 1 && dim == 0 {
143        let data = tensor.to_vec();
144        let chunk = data[start..end].to_vec();
145        return Tensor::from_vec(chunk, &[chunk_size]).unwrap();
146    }
147
148    // For 2D tensors along dim 0
149    if shape.len() == 2 && dim == 0 {
150        let data = tensor.to_vec();
151        let cols = shape[1];
152        let mut chunk = Vec::with_capacity(chunk_size * cols);
153        for row in start..end {
154            let row_start = row * cols;
155            let row_end = row_start + cols;
156            chunk.extend_from_slice(&data[row_start..row_end]);
157        }
158        return Tensor::from_vec(chunk, &[chunk_size, cols]).unwrap();
159    }
160
161    tensor.clone()
162}
163
164/// Gathers scattered tensor chunks back together.
165#[must_use]
166pub fn gather_tensor(tensor: &Tensor<f32>, dim: usize, pg: &ProcessGroup) -> Tensor<f32> {
167    let gathered = pg.all_gather_tensor(tensor);
168
169    // Reshape gathered tensor
170    let world_size = pg.world_size();
171    let shape = tensor.shape();
172
173    if shape.len() == 1 && dim == 0 {
174        // Flatten [world_size, chunk_size] to [total_size]
175        let data = gathered.to_vec();
176        return Tensor::from_vec(data, &[shape[0] * world_size]).unwrap();
177    }
178
179    gathered
180}
181
182// =============================================================================
183// Gradient Synchronization
184// =============================================================================
185
186/// Synchronizes gradients by averaging across all processes.
187pub fn sync_gradients(gradients: &mut [Tensor<f32>], pg: &ProcessGroup) {
188    for grad in gradients.iter_mut() {
189        all_reduce_mean(grad, pg);
190    }
191}
192
193/// Synchronizes a single gradient tensor.
194pub fn sync_gradient(gradient: &mut Tensor<f32>, pg: &ProcessGroup) {
195    all_reduce_mean(gradient, pg);
196}
197
198// =============================================================================
199// Ring Communication Pattern
200// =============================================================================
201
202/// Ring all-reduce implementation (educational).
203/// In practice, the backend handles this more efficiently.
204pub fn ring_all_reduce(data: &mut [f32], pg: &ProcessGroup, op: ReduceOp) {
205    let world_size = pg.world_size();
206    if world_size == 1 {
207        return;
208    }
209
210    // Use backend's all-reduce
211    pg.backend().all_reduce(data, op);
212}
213
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_all_reduce_sum() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_sum(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_mean() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_mean(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_min() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_min(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_max() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_max(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_product() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        // With a single process the product reduction is the identity.
        all_reduce_product(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_broadcast() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        broadcast(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_broadcast_from() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        broadcast_from(&mut tensor, 0, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_all_gather() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let gathered = all_gather(&tensor, &pg);
        assert_eq!(gathered.shape(), &[1, 2]);
    }

    #[test]
    fn test_reduce_scatter_sum() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let scattered = reduce_scatter_sum(&tensor, &pg);
        assert_eq!(scattered.shape(), &[2]);
    }

    #[test]
    fn test_reduce_scatter_mean() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let scattered = reduce_scatter_mean(&tensor, &pg);
        assert_eq!(scattered.shape(), &[2]);
    }

    #[test]
    fn test_barrier() {
        let pg = ProcessGroup::mock();
        barrier(&pg); // Should not deadlock
    }

    #[test]
    fn test_is_main_process() {
        let pg = ProcessGroup::mock();
        assert!(is_main_process(&pg));
    }

    #[test]
    fn test_world_size() {
        let pg = ProcessGroup::mock();
        assert_eq!(world_size(&pg), 1);
    }

    #[test]
    fn test_rank() {
        let pg = ProcessGroup::mock();
        assert_eq!(rank(&pg), 0);
    }

    #[test]
    fn test_scatter_tensor_1d() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();

        let scattered = scatter_tensor(&tensor, 0, &pg);
        // With world_size=1, should return full tensor
        assert_eq!(scattered.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_scatter_tensor_2d() {
        let pg = ProcessGroup::mock();
        let tensor =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();

        let scattered = scatter_tensor(&tensor, 0, &pg);
        assert_eq!(scattered.shape(), &[2, 3]);
        assert_eq!(scattered.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_scatter_tensor_3d_inner_dim() {
        let pg = ProcessGroup::mock();
        let values: Vec<f32> = (0..8).map(|v| v as f32).collect();
        let tensor = Tensor::from_vec(values.clone(), &[2, 2, 2]).unwrap();

        // A single rank owns the whole extent of any dimension.
        let scattered = scatter_tensor(&tensor, 1, &pg);
        assert_eq!(scattered.shape(), &[2, 2, 2]);
        assert_eq!(scattered.to_vec(), values);
    }

    #[test]
    fn test_scatter_tensor_dim_out_of_range() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();

        let scattered = scatter_tensor(&tensor, 3, &pg);
        assert_eq!(scattered.shape(), &[4]);
        assert_eq!(scattered.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_gather_tensor() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let gathered = gather_tensor(&tensor, 0, &pg);
        assert_eq!(gathered.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_sync_gradients() {
        let pg = ProcessGroup::mock();
        let mut grads = vec![
            Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(),
            Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(),
        ];

        sync_gradients(&mut grads, &pg);

        assert_eq!(grads[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(grads[1].to_vec(), vec![3.0, 4.0]);
    }

    #[test]
    fn test_sync_gradient() {
        let pg = ProcessGroup::mock();
        let mut grad = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        sync_gradient(&mut grad, &pg);
        assert_eq!(grad.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_ring_all_reduce() {
        let pg = ProcessGroup::mock();
        let mut data = vec![1.0, 2.0, 3.0];

        ring_all_reduce(&mut data, &pg, ReduceOp::Sum);
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }
}
368}