Skip to main content

axonml_distributed/
comm.rs

1//! Communication - High-level Communication Utilities
2//!
3//! Provides high-level functions for common distributed communication patterns.
4//!
5//! @version 0.1.0
6//! @author `AutomataNexus` Development Team
7
8use crate::backend::ReduceOp;
9use crate::process_group::ProcessGroup;
10use axonml_tensor::Tensor;
11
12// =============================================================================
13// All-Reduce Operations
14// =============================================================================
15
16/// Performs all-reduce sum on a tensor.
17pub fn all_reduce_sum(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
18    pg.all_reduce_tensor(tensor, ReduceOp::Sum);
19}
20
21/// Performs all-reduce mean on a tensor.
22pub fn all_reduce_mean(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
23    pg.all_reduce_tensor(tensor, ReduceOp::Average);
24}
25
26/// Performs all-reduce min on a tensor.
27pub fn all_reduce_min(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
28    pg.all_reduce_tensor(tensor, ReduceOp::Min);
29}
30
31/// Performs all-reduce max on a tensor.
32pub fn all_reduce_max(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
33    pg.all_reduce_tensor(tensor, ReduceOp::Max);
34}
35
36/// Performs all-reduce product on a tensor.
37pub fn all_reduce_product(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
38    pg.all_reduce_tensor(tensor, ReduceOp::Product);
39}
40
41// =============================================================================
42// Broadcast Operations
43// =============================================================================
44
45/// Broadcasts a tensor from the root rank (0).
46pub fn broadcast(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
47    broadcast_from(tensor, 0, pg);
48}
49
50/// Broadcasts a tensor from a specific rank.
51pub fn broadcast_from(tensor: &mut Tensor<f32>, src: usize, pg: &ProcessGroup) {
52    pg.broadcast_tensor(tensor, src);
53}
54
55// =============================================================================
56// Gather Operations
57// =============================================================================
58
59/// All-gathers a tensor across all ranks.
60#[must_use]
61pub fn all_gather(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
62    pg.all_gather_tensor(tensor)
63}
64
65// =============================================================================
66// Reduce-Scatter Operations
67// =============================================================================
68
69/// Reduce-scatters a tensor with sum.
70#[must_use]
71pub fn reduce_scatter_sum(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
72    pg.reduce_scatter_tensor(tensor, ReduceOp::Sum)
73}
74
75/// Reduce-scatters a tensor with mean.
76#[must_use]
77pub fn reduce_scatter_mean(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
78    pg.reduce_scatter_tensor(tensor, ReduceOp::Average)
79}
80
81// =============================================================================
82// Utility Functions
83// =============================================================================
84
85/// Synchronizes all processes.
86pub fn barrier(pg: &ProcessGroup) {
87    pg.barrier();
88}
89
90/// Checks if this is the main process (rank 0).
91#[must_use]
92pub fn is_main_process(pg: &ProcessGroup) -> bool {
93    pg.rank() == 0
94}
95
96/// Returns the world size.
97#[must_use]
98pub fn world_size(pg: &ProcessGroup) -> usize {
99    pg.world_size()
100}
101
102/// Returns the current rank.
103#[must_use]
104pub fn rank(pg: &ProcessGroup) -> usize {
105    pg.rank()
106}
107
108// =============================================================================
109// Model Parallel Utilities
110// =============================================================================
111
112/// Splits a tensor along a dimension for model parallelism.
113#[must_use]
114pub fn scatter_tensor(tensor: &Tensor<f32>, dim: usize, pg: &ProcessGroup) -> Tensor<f32> {
115    let shape = tensor.shape();
116    if dim >= shape.len() {
117        return tensor.clone();
118    }
119
120    let world_size = pg.world_size();
121    let rank = pg.rank();
122    let dim_size = shape[dim];
123
124    if dim_size % world_size != 0 {
125        return tensor.clone();
126    }
127
128    let chunk_size = dim_size / world_size;
129    let start = rank * chunk_size;
130    let end = start + chunk_size;
131
132    // For 1D tensors along dim 0
133    if shape.len() == 1 && dim == 0 {
134        let data = tensor.to_vec();
135        let chunk = data[start..end].to_vec();
136        return Tensor::from_vec(chunk, &[chunk_size]).unwrap();
137    }
138
139    // For 2D tensors along dim 0
140    if shape.len() == 2 && dim == 0 {
141        let data = tensor.to_vec();
142        let cols = shape[1];
143        let mut chunk = Vec::with_capacity(chunk_size * cols);
144        for row in start..end {
145            let row_start = row * cols;
146            let row_end = row_start + cols;
147            chunk.extend_from_slice(&data[row_start..row_end]);
148        }
149        return Tensor::from_vec(chunk, &[chunk_size, cols]).unwrap();
150    }
151
152    tensor.clone()
153}
154
155/// Gathers scattered tensor chunks back together.
156#[must_use]
157pub fn gather_tensor(tensor: &Tensor<f32>, dim: usize, pg: &ProcessGroup) -> Tensor<f32> {
158    let gathered = pg.all_gather_tensor(tensor);
159
160    // Reshape gathered tensor
161    let world_size = pg.world_size();
162    let shape = tensor.shape();
163
164    if shape.len() == 1 && dim == 0 {
165        // Flatten [world_size, chunk_size] to [total_size]
166        let data = gathered.to_vec();
167        return Tensor::from_vec(data, &[shape[0] * world_size]).unwrap();
168    }
169
170    gathered
171}
172
173// =============================================================================
174// Gradient Synchronization
175// =============================================================================
176
177/// Synchronizes gradients by averaging across all processes.
178pub fn sync_gradients(gradients: &mut [Tensor<f32>], pg: &ProcessGroup) {
179    for grad in gradients.iter_mut() {
180        all_reduce_mean(grad, pg);
181    }
182}
183
184/// Synchronizes a single gradient tensor.
185pub fn sync_gradient(gradient: &mut Tensor<f32>, pg: &ProcessGroup) {
186    all_reduce_mean(gradient, pg);
187}
188
189// =============================================================================
190// Ring Communication Pattern
191// =============================================================================
192
193/// Ring all-reduce implementation (educational).
194/// In practice, the backend handles this more efficiently.
195pub fn ring_all_reduce(data: &mut [f32], pg: &ProcessGroup, op: ReduceOp) {
196    let world_size = pg.world_size();
197    if world_size == 1 {
198        return;
199    }
200
201    // Use backend's all-reduce
202    pg.backend().all_reduce(data, op);
203}
204
205// =============================================================================
206// Tests
207// =============================================================================
208
#[cfg(test)]
mod tests {
    use super::*;

    // Every test runs against `ProcessGroup::mock()`. As `test_world_size` and
    // `test_rank` assert, that group has one process at rank 0, so collectives
    // are expected to leave their inputs unchanged.

    /// Builds an `f32` tensor from `data` and `shape`, panicking on mismatch.
    fn make_tensor(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        Tensor::from_vec(data.to_vec(), shape).unwrap()
    }

    #[test]
    fn test_all_reduce_sum() {
        let pg = ProcessGroup::mock();
        let mut t = make_tensor(&[1.0, 2.0, 3.0], &[3]);
        all_reduce_sum(&mut t, &pg);
        assert_eq!(t.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_mean() {
        let pg = ProcessGroup::mock();
        let mut t = make_tensor(&[1.0, 2.0, 3.0], &[3]);
        all_reduce_mean(&mut t, &pg);
        assert_eq!(t.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_min() {
        let pg = ProcessGroup::mock();
        let mut t = make_tensor(&[1.0, 2.0, 3.0], &[3]);
        all_reduce_min(&mut t, &pg);
        assert_eq!(t.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_max() {
        let pg = ProcessGroup::mock();
        let mut t = make_tensor(&[1.0, 2.0, 3.0], &[3]);
        all_reduce_max(&mut t, &pg);
        assert_eq!(t.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_broadcast() {
        let pg = ProcessGroup::mock();
        let mut t = make_tensor(&[1.0, 2.0], &[2]);
        broadcast(&mut t, &pg);
        assert_eq!(t.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_broadcast_from() {
        let pg = ProcessGroup::mock();
        let mut t = make_tensor(&[1.0, 2.0], &[2]);
        broadcast_from(&mut t, 0, &pg);
        assert_eq!(t.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_all_gather() {
        let pg = ProcessGroup::mock();
        let local = make_tensor(&[1.0, 2.0], &[2]);
        let gathered = all_gather(&local, &pg);
        // One rank contributes, stacked along a new leading axis.
        assert_eq!(gathered.shape(), &[1, 2]);
    }

    #[test]
    fn test_reduce_scatter_sum() {
        let pg = ProcessGroup::mock();
        let local = make_tensor(&[1.0, 2.0], &[2]);
        let scattered = reduce_scatter_sum(&local, &pg);
        assert_eq!(scattered.shape(), &[2]);
    }

    #[test]
    fn test_barrier() {
        let pg = ProcessGroup::mock();
        barrier(&pg); // Should not deadlock
    }

    #[test]
    fn test_is_main_process() {
        let pg = ProcessGroup::mock();
        assert!(is_main_process(&pg));
    }

    #[test]
    fn test_world_size() {
        let pg = ProcessGroup::mock();
        assert_eq!(world_size(&pg), 1);
    }

    #[test]
    fn test_rank() {
        let pg = ProcessGroup::mock();
        assert_eq!(rank(&pg), 0);
    }

    #[test]
    fn test_scatter_tensor_1d() {
        let pg = ProcessGroup::mock();
        let full = make_tensor(&[1.0, 2.0, 3.0, 4.0], &[4]);
        let scattered = scatter_tensor(&full, 0, &pg);
        // With world_size=1, should return full tensor
        assert_eq!(scattered.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_gather_tensor() {
        let pg = ProcessGroup::mock();
        let local = make_tensor(&[1.0, 2.0], &[2]);
        let gathered = gather_tensor(&local, 0, &pg);
        assert_eq!(gathered.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_sync_gradients() {
        let pg = ProcessGroup::mock();
        let mut grads = vec![
            make_tensor(&[1.0, 2.0], &[2]),
            make_tensor(&[3.0, 4.0], &[2]),
        ];

        sync_gradients(&mut grads, &pg);

        assert_eq!(grads[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(grads[1].to_vec(), vec![3.0, 4.0]);
    }

    #[test]
    fn test_sync_gradient() {
        let pg = ProcessGroup::mock();
        let mut grad = make_tensor(&[1.0, 2.0, 3.0], &[3]);
        sync_gradient(&mut grad, &pg);
        assert_eq!(grad.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_ring_all_reduce() {
        let pg = ProcessGroup::mock();
        let mut data: Vec<f32> = vec![1.0, 2.0, 3.0];
        ring_all_reduce(&mut data, &pg, ReduceOp::Sum);
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }
}