Skip to main content

axonml_distributed/
comm.rs

1//! Communication - High-level Communication Utilities
2//!
3//! # File
4//! `crates/axonml-distributed/src/comm.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr. — AutomataNexus LLC
8//! ORCID: 0009-0005-2158-7060
9//!
10//! # Updated
11//! April 14, 2026 11:15 PM EST
12//!
13//! # Disclaimer
14//! Use at own risk. This software is provided "as is", without warranty of any
15//! kind, express or implied. The author and AutomataNexus shall not be held
16//! liable for any damages arising from the use of this software.
17
18use crate::backend::ReduceOp;
19use crate::process_group::ProcessGroup;
20use axonml_tensor::Tensor;
21
22// =============================================================================
23// All-Reduce Operations
24// =============================================================================
25
26/// Performs all-reduce sum on a tensor.
27pub fn all_reduce_sum(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
28    pg.all_reduce_tensor(tensor, ReduceOp::Sum);
29}
30
31/// Performs all-reduce mean on a tensor.
32pub fn all_reduce_mean(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
33    pg.all_reduce_tensor(tensor, ReduceOp::Average);
34}
35
36/// Performs all-reduce min on a tensor.
37pub fn all_reduce_min(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
38    pg.all_reduce_tensor(tensor, ReduceOp::Min);
39}
40
41/// Performs all-reduce max on a tensor.
42pub fn all_reduce_max(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
43    pg.all_reduce_tensor(tensor, ReduceOp::Max);
44}
45
46/// Performs all-reduce product on a tensor.
47pub fn all_reduce_product(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
48    pg.all_reduce_tensor(tensor, ReduceOp::Product);
49}
50
51// =============================================================================
52// Broadcast Operations
53// =============================================================================
54
55/// Broadcasts a tensor from the root rank (0).
56pub fn broadcast(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
57    broadcast_from(tensor, 0, pg);
58}
59
60/// Broadcasts a tensor from a specific rank.
61pub fn broadcast_from(tensor: &mut Tensor<f32>, src: usize, pg: &ProcessGroup) {
62    pg.broadcast_tensor(tensor, src);
63}
64
65// =============================================================================
66// Gather Operations
67// =============================================================================
68
69/// All-gathers a tensor across all ranks.
70#[must_use]
71pub fn all_gather(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
72    pg.all_gather_tensor(tensor)
73}
74
75// =============================================================================
76// Reduce-Scatter Operations
77// =============================================================================
78
79/// Reduce-scatters a tensor with sum.
80#[must_use]
81pub fn reduce_scatter_sum(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
82    pg.reduce_scatter_tensor(tensor, ReduceOp::Sum)
83}
84
85/// Reduce-scatters a tensor with mean.
86#[must_use]
87pub fn reduce_scatter_mean(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
88    pg.reduce_scatter_tensor(tensor, ReduceOp::Average)
89}
90
91// =============================================================================
92// Utility Functions
93// =============================================================================
94
95/// Synchronizes all processes.
96pub fn barrier(pg: &ProcessGroup) {
97    pg.barrier();
98}
99
100/// Checks if this is the main process (rank 0).
101#[must_use]
102pub fn is_main_process(pg: &ProcessGroup) -> bool {
103    pg.rank() == 0
104}
105
106/// Returns the world size.
107#[must_use]
108pub fn world_size(pg: &ProcessGroup) -> usize {
109    pg.world_size()
110}
111
112/// Returns the current rank.
113#[must_use]
114pub fn rank(pg: &ProcessGroup) -> usize {
115    pg.rank()
116}
117
118// =============================================================================
119// Model Parallel Utilities
120// =============================================================================
121
122/// Splits a tensor along a dimension for model parallelism.
123#[must_use]
124pub fn scatter_tensor(tensor: &Tensor<f32>, dim: usize, pg: &ProcessGroup) -> Tensor<f32> {
125    let shape = tensor.shape();
126    if dim >= shape.len() {
127        return tensor.clone();
128    }
129
130    let world_size = pg.world_size();
131    let rank = pg.rank();
132    let dim_size = shape[dim];
133
134    if dim_size % world_size != 0 {
135        return tensor.clone();
136    }
137
138    let chunk_size = dim_size / world_size;
139    let start = rank * chunk_size;
140    let end = start + chunk_size;
141
142    // For 1D tensors along dim 0
143    if shape.len() == 1 && dim == 0 {
144        let data = tensor.to_vec();
145        let chunk = data[start..end].to_vec();
146        return Tensor::from_vec(chunk, &[chunk_size]).unwrap();
147    }
148
149    // For 2D tensors along dim 0
150    if shape.len() == 2 && dim == 0 {
151        let data = tensor.to_vec();
152        let cols = shape[1];
153        let mut chunk = Vec::with_capacity(chunk_size * cols);
154        for row in start..end {
155            let row_start = row * cols;
156            let row_end = row_start + cols;
157            chunk.extend_from_slice(&data[row_start..row_end]);
158        }
159        return Tensor::from_vec(chunk, &[chunk_size, cols]).unwrap();
160    }
161
162    tensor.clone()
163}
164
165/// Gathers scattered tensor chunks back together.
166#[must_use]
167pub fn gather_tensor(tensor: &Tensor<f32>, dim: usize, pg: &ProcessGroup) -> Tensor<f32> {
168    let gathered = pg.all_gather_tensor(tensor);
169
170    // Reshape gathered tensor
171    let world_size = pg.world_size();
172    let shape = tensor.shape();
173
174    if shape.len() == 1 && dim == 0 {
175        // Flatten [world_size, chunk_size] to [total_size]
176        let data = gathered.to_vec();
177        return Tensor::from_vec(data, &[shape[0] * world_size]).unwrap();
178    }
179
180    gathered
181}
182
183// =============================================================================
184// Gradient Synchronization
185// =============================================================================
186
187/// Synchronizes gradients by averaging across all processes.
188pub fn sync_gradients(gradients: &mut [Tensor<f32>], pg: &ProcessGroup) {
189    for grad in gradients.iter_mut() {
190        all_reduce_mean(grad, pg);
191    }
192}
193
194/// Synchronizes a single gradient tensor.
195pub fn sync_gradient(gradient: &mut Tensor<f32>, pg: &ProcessGroup) {
196    all_reduce_mean(gradient, pg);
197}
198
199// =============================================================================
200// Ring Communication Pattern
201// =============================================================================
202
203/// Ring all-reduce implementation (educational).
204/// In practice, the backend handles this more efficiently.
205pub fn ring_all_reduce(data: &mut [f32], pg: &ProcessGroup, op: ReduceOp) {
206    let world_size = pg.world_size();
207    if world_size == 1 {
208        return;
209    }
210
211    // Use backend's all-reduce
212    pg.backend().all_reduce(data, op);
213}
214
215// =============================================================================
216// Tests
217// =============================================================================
218
#[cfg(test)]
mod tests {
    //! Single-process (`ProcessGroup::mock`) tests: every collective is expected
    //! to leave data unchanged because there is only one rank to combine.

    use super::*;

    #[test]
    fn test_all_reduce_sum() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_sum(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_mean() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_mean(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_min() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_min(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_max() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_max(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_product() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_product(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_broadcast() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        broadcast(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_broadcast_from() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        broadcast_from(&mut tensor, 0, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_all_gather() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let gathered = all_gather(&tensor, &pg);
        assert_eq!(gathered.shape(), &[1, 2]);
    }

    #[test]
    fn test_reduce_scatter_sum() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let scattered = reduce_scatter_sum(&tensor, &pg);
        assert_eq!(scattered.shape(), &[2]);
    }

    #[test]
    fn test_reduce_scatter_mean() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let scattered = reduce_scatter_mean(&tensor, &pg);
        assert_eq!(scattered.shape(), &[2]);
    }

    #[test]
    fn test_barrier() {
        let pg = ProcessGroup::mock();
        barrier(&pg); // Should not deadlock
    }

    #[test]
    fn test_is_main_process() {
        let pg = ProcessGroup::mock();
        assert!(is_main_process(&pg));
    }

    #[test]
    fn test_world_size() {
        let pg = ProcessGroup::mock();
        assert_eq!(world_size(&pg), 1);
    }

    #[test]
    fn test_rank() {
        let pg = ProcessGroup::mock();
        assert_eq!(rank(&pg), 0);
    }

    #[test]
    fn test_scatter_tensor_1d() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();

        let scattered = scatter_tensor(&tensor, 0, &pg);
        // With world_size=1, should return full tensor
        assert_eq!(scattered.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_scatter_tensor_2d_both_dims() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        // With world_size=1 the single shard is the whole tensor along either dim.
        for dim in 0..2 {
            let scattered = scatter_tensor(&tensor, dim, &pg);
            assert_eq!(scattered.shape(), &[2, 2]);
            assert_eq!(scattered.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
        }
    }

    #[test]
    fn test_scatter_tensor_dim_out_of_range() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let scattered = scatter_tensor(&tensor, 5, &pg);
        assert_eq!(scattered.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_gather_tensor() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let gathered = gather_tensor(&tensor, 0, &pg);
        assert_eq!(gathered.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_gather_tensor_2d() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let gathered = gather_tensor(&tensor, 0, &pg);
        assert_eq!(gathered.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_sync_gradients() {
        let pg = ProcessGroup::mock();
        let mut grads = vec![
            Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(),
            Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(),
        ];

        sync_gradients(&mut grads, &pg);

        assert_eq!(grads[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(grads[1].to_vec(), vec![3.0, 4.0]);
    }

    #[test]
    fn test_sync_gradient() {
        let pg = ProcessGroup::mock();
        let mut grad = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        sync_gradient(&mut grad, &pg);
        assert_eq!(grad.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_ring_all_reduce() {
        let pg = ProcessGroup::mock();
        let mut data = vec![1.0, 2.0, 3.0];

        ring_all_reduce(&mut data, &pg, ReduceOp::Sum);
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }
}