axonml_distributed/
comm.rs

//! Communication - High-level Communication Utilities
//!
//! Provides high-level functions for common distributed communication patterns.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team
7
8use crate::backend::ReduceOp;
9use crate::process_group::ProcessGroup;
10use axonml_tensor::Tensor;
11
12// =============================================================================
13// All-Reduce Operations
14// =============================================================================
15
16/// Performs all-reduce sum on a tensor.
17pub fn all_reduce_sum(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
18    pg.all_reduce_tensor(tensor, ReduceOp::Sum);
19}
20
21/// Performs all-reduce mean on a tensor.
22pub fn all_reduce_mean(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
23    pg.all_reduce_tensor(tensor, ReduceOp::Average);
24}
25
26/// Performs all-reduce min on a tensor.
27pub fn all_reduce_min(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
28    pg.all_reduce_tensor(tensor, ReduceOp::Min);
29}
30
31/// Performs all-reduce max on a tensor.
32pub fn all_reduce_max(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
33    pg.all_reduce_tensor(tensor, ReduceOp::Max);
34}
35
36/// Performs all-reduce product on a tensor.
37pub fn all_reduce_product(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
38    pg.all_reduce_tensor(tensor, ReduceOp::Product);
39}
40
41// =============================================================================
42// Broadcast Operations
43// =============================================================================
44
45/// Broadcasts a tensor from the root rank (0).
46pub fn broadcast(tensor: &mut Tensor<f32>, pg: &ProcessGroup) {
47    broadcast_from(tensor, 0, pg);
48}
49
50/// Broadcasts a tensor from a specific rank.
51pub fn broadcast_from(tensor: &mut Tensor<f32>, src: usize, pg: &ProcessGroup) {
52    pg.broadcast_tensor(tensor, src);
53}
54
55// =============================================================================
56// Gather Operations
57// =============================================================================
58
59/// All-gathers a tensor across all ranks.
60#[must_use] pub fn all_gather(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
61    pg.all_gather_tensor(tensor)
62}
63
64// =============================================================================
65// Reduce-Scatter Operations
66// =============================================================================
67
68/// Reduce-scatters a tensor with sum.
69#[must_use] pub fn reduce_scatter_sum(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
70    pg.reduce_scatter_tensor(tensor, ReduceOp::Sum)
71}
72
73/// Reduce-scatters a tensor with mean.
74#[must_use] pub fn reduce_scatter_mean(tensor: &Tensor<f32>, pg: &ProcessGroup) -> Tensor<f32> {
75    pg.reduce_scatter_tensor(tensor, ReduceOp::Average)
76}
77
78// =============================================================================
79// Utility Functions
80// =============================================================================
81
82/// Synchronizes all processes.
83pub fn barrier(pg: &ProcessGroup) {
84    pg.barrier();
85}
86
87/// Checks if this is the main process (rank 0).
88#[must_use] pub fn is_main_process(pg: &ProcessGroup) -> bool {
89    pg.rank() == 0
90}
91
92/// Returns the world size.
93#[must_use] pub fn world_size(pg: &ProcessGroup) -> usize {
94    pg.world_size()
95}
96
97/// Returns the current rank.
98#[must_use] pub fn rank(pg: &ProcessGroup) -> usize {
99    pg.rank()
100}
101
102// =============================================================================
103// Model Parallel Utilities
104// =============================================================================
105
106/// Splits a tensor along a dimension for model parallelism.
107#[must_use] pub fn scatter_tensor(tensor: &Tensor<f32>, dim: usize, pg: &ProcessGroup) -> Tensor<f32> {
108    let shape = tensor.shape();
109    if dim >= shape.len() {
110        return tensor.clone();
111    }
112
113    let world_size = pg.world_size();
114    let rank = pg.rank();
115    let dim_size = shape[dim];
116
117    if dim_size % world_size != 0 {
118        return tensor.clone();
119    }
120
121    let chunk_size = dim_size / world_size;
122    let start = rank * chunk_size;
123    let end = start + chunk_size;
124
125    // For 1D tensors along dim 0
126    if shape.len() == 1 && dim == 0 {
127        let data = tensor.to_vec();
128        let chunk = data[start..end].to_vec();
129        return Tensor::from_vec(chunk, &[chunk_size]).unwrap();
130    }
131
132    // For 2D tensors along dim 0
133    if shape.len() == 2 && dim == 0 {
134        let data = tensor.to_vec();
135        let cols = shape[1];
136        let mut chunk = Vec::with_capacity(chunk_size * cols);
137        for row in start..end {
138            let row_start = row * cols;
139            let row_end = row_start + cols;
140            chunk.extend_from_slice(&data[row_start..row_end]);
141        }
142        return Tensor::from_vec(chunk, &[chunk_size, cols]).unwrap();
143    }
144
145    tensor.clone()
146}
147
148/// Gathers scattered tensor chunks back together.
149#[must_use] pub fn gather_tensor(tensor: &Tensor<f32>, dim: usize, pg: &ProcessGroup) -> Tensor<f32> {
150    let gathered = pg.all_gather_tensor(tensor);
151
152    // Reshape gathered tensor
153    let world_size = pg.world_size();
154    let shape = tensor.shape();
155
156    if shape.len() == 1 && dim == 0 {
157        // Flatten [world_size, chunk_size] to [total_size]
158        let data = gathered.to_vec();
159        return Tensor::from_vec(data, &[shape[0] * world_size]).unwrap();
160    }
161
162    gathered
163}
164
165// =============================================================================
166// Gradient Synchronization
167// =============================================================================
168
169/// Synchronizes gradients by averaging across all processes.
170pub fn sync_gradients(gradients: &mut [Tensor<f32>], pg: &ProcessGroup) {
171    for grad in gradients.iter_mut() {
172        all_reduce_mean(grad, pg);
173    }
174}
175
176/// Synchronizes a single gradient tensor.
177pub fn sync_gradient(gradient: &mut Tensor<f32>, pg: &ProcessGroup) {
178    all_reduce_mean(gradient, pg);
179}
180
181// =============================================================================
182// Ring Communication Pattern
183// =============================================================================
184
185/// Ring all-reduce implementation (educational).
186/// In practice, the backend handles this more efficiently.
187pub fn ring_all_reduce(data: &mut [f32], pg: &ProcessGroup, op: ReduceOp) {
188    let world_size = pg.world_size();
189    if world_size == 1 {
190        return;
191    }
192
193    // Use backend's all-reduce
194    pg.backend().all_reduce(data, op);
195}
196
197// =============================================================================
198// Tests
199// =============================================================================
200
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_all_reduce_sum() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_sum(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_mean() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_mean(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_min() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_min(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_max() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        all_reduce_max(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_all_reduce_product() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        // With a single process the product reduction is the identity.
        all_reduce_product(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_broadcast() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        broadcast(&mut tensor, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_broadcast_from() {
        let pg = ProcessGroup::mock();
        let mut tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        broadcast_from(&mut tensor, 0, &pg);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_all_gather() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let gathered = all_gather(&tensor, &pg);
        assert_eq!(gathered.shape(), &[1, 2]);
    }

    #[test]
    fn test_reduce_scatter_sum() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let scattered = reduce_scatter_sum(&tensor, &pg);
        assert_eq!(scattered.shape(), &[2]);
    }

    #[test]
    fn test_reduce_scatter_mean() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let scattered = reduce_scatter_mean(&tensor, &pg);
        assert_eq!(scattered.shape(), &[2]);
    }

    #[test]
    fn test_barrier() {
        let pg = ProcessGroup::mock();
        barrier(&pg); // Should not deadlock
    }

    #[test]
    fn test_is_main_process() {
        let pg = ProcessGroup::mock();
        assert!(is_main_process(&pg));
    }

    #[test]
    fn test_world_size() {
        let pg = ProcessGroup::mock();
        assert_eq!(world_size(&pg), 1);
    }

    #[test]
    fn test_rank() {
        let pg = ProcessGroup::mock();
        assert_eq!(rank(&pg), 0);
    }

    #[test]
    fn test_scatter_tensor_1d() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();

        let scattered = scatter_tensor(&tensor, 0, &pg);
        // With world_size=1, should return full tensor
        assert_eq!(scattered.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_scatter_tensor_2d() {
        let pg = ProcessGroup::mock();
        let tensor =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();

        // With world_size=1 the single rank owns every row.
        let scattered = scatter_tensor(&tensor, 0, &pg);
        assert_eq!(scattered.shape(), &[2, 3]);
        assert_eq!(scattered.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_scatter_tensor_out_of_range_dim() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        // An out-of-range dimension leaves the tensor untouched.
        let scattered = scatter_tensor(&tensor, 5, &pg);
        assert_eq!(scattered.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_gather_tensor() {
        let pg = ProcessGroup::mock();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let gathered = gather_tensor(&tensor, 0, &pg);
        assert_eq!(gathered.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_sync_gradients() {
        let pg = ProcessGroup::mock();
        let mut grads = vec![
            Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(),
            Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(),
        ];

        sync_gradients(&mut grads, &pg);

        assert_eq!(grads[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(grads[1].to_vec(), vec![3.0, 4.0]);
    }

    #[test]
    fn test_sync_gradient() {
        let pg = ProcessGroup::mock();
        let mut grad = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        sync_gradient(&mut grad, &pg);
        assert_eq!(grad.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_ring_all_reduce() {
        let pg = ProcessGroup::mock();
        let mut data = vec![1.0, 2.0, 3.0];

        ring_all_reduce(&mut data, &pg, ReduceOp::Sum);
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }
}