Skip to main content

rlx_driver/
collective.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Collective ops as algorithms over [`SymmetricTransport`]
17//! (plan #12).
18//!
19//! Borrowed from MAX's `kernels/src/comm/{allgather, allreduce,
20//! reducescatter, allreduce_residual_rmsnorm_fp8, rms_norm_fp8}.mojo`.
21//! Each collective is a small algorithm that uses
22//! [`SymmetricTransport::put`] / `get` / `barrier` to move
23//! tensors between ranks. Pure data layer — the transport is
24//! pluggable, so the same algorithm runs against
25//! `LocalTransport` (single-machine emulation) today and a
26//! future MPI / NVSHMEM transport on a real cluster.
27//!
28//! Element-wise reductions parameterized by [`ReduceKind`].
29//! All algorithms operate on `f32` slices today; quantized fp8
30//! variants from MAX (`allreduce_residual_rmsnorm_fp8`) are
31//! per-precision impls that slot in once a quantized model
32//! lands.
33
34use crate::symmetric::{CollectiveError, Rank, SymmetricBuffer, SymmetricTransport};
35
/// Element-wise reduction operator for collectives.
///
/// A reduction over contributions `x0, ..., xk` from `n` ranks is
/// evaluated as `finalize(fold(...fold(identity(), x0)..., xk), n)`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ReduceKind {
    /// Element-wise sum.
    Sum,
    /// Element-wise arithmetic mean: summed by `fold`, then divided
    /// by the contribution count in `finalize`.
    Mean,
    /// Element-wise maximum.
    Max,
    /// Element-wise minimum.
    Min,
}

impl ReduceKind {
    /// Combine the running accumulator `acc` with one contribution `x`.
    ///
    /// `Max` / `Min` go through [`f32::max`] / [`f32::min`], which
    /// return the non-NaN operand when exactly one side is NaN.
    fn fold(self, acc: f32, x: f32) -> f32 {
        match self {
            // Mean accumulates a plain sum; `finalize` does the division.
            Self::Sum | Self::Mean => acc + x,
            Self::Max => f32::max(acc, x),
            Self::Min => f32::min(acc, x),
        }
    }

    /// Turn the folded accumulator into the result, given that `n`
    /// contributions were folded into it.
    fn finalize(self, acc: f32, n: usize) -> f32 {
        match self {
            Self::Mean => acc / (n as f32),
            Self::Sum | Self::Max | Self::Min => acc,
        }
    }

    /// Neutral starting value: `fold(identity(), x) == x` for any
    /// non-NaN `x`.
    fn identity(self) -> f32 {
        match self {
            Self::Max => f32::NEG_INFINITY,
            Self::Min => f32::INFINITY,
            Self::Sum | Self::Mean => 0.0,
        }
    }
}
68
69/// AllReduce: every rank ends up with `op({values from every rank})`.
70///
71/// Naïve algorithm — every rank reads every other rank's slot
72/// and combines. O(n_ranks²) communications, fine for small
73/// rank counts. Real impls use ring-reduce / tree-reduce; we
74/// pick simplicity since LocalTransport's "comm" is memcpy.
75///
76/// `local` carries this rank's contribution on entry; on exit
77/// it carries the reduced result. Element count must match the
78/// per-rank `len` of `buf` (in bytes: 4 * elements).
79pub fn all_reduce<T: SymmetricTransport>(
80    transport: &T,
81    buf: SymmetricBuffer, // shape (offset, len) shared across ranks
82    local: &mut [f32],
83    op: ReduceKind,
84) -> Result<(), CollectiveError> {
85    let elems = buf.len / 4;
86    if local.len() != elems {
87        return Err(CollectiveError::LengthMismatch {
88            expected: elems,
89            got: local.len(),
90        });
91    }
92    let me = transport.this_rank();
93    let n = transport.num_ranks();
94
95    // Step 1: write our contribution into our slot.
96    let our_buf = SymmetricBuffer {
97        rank: me,
98        offset: buf.offset,
99        len: buf.len,
100    };
101    let bytes = unsafe { std::slice::from_raw_parts(local.as_ptr() as *const u8, buf.len) };
102    transport.put(our_buf, bytes)?;
103
104    // Step 2: barrier so every rank has written its slot.
105    transport.barrier()?;
106
107    // Step 3: read every rank's slot and reduce.
108    let mut acc: Vec<f32> = vec![op.identity(); elems];
109    let mut scratch_bytes = vec![0u8; buf.len];
110    for r in 0..n {
111        let src = SymmetricBuffer {
112            rank: Rank(r),
113            offset: buf.offset,
114            len: buf.len,
115        };
116        transport.get(src, &mut scratch_bytes)?;
117        let scratch =
118            unsafe { std::slice::from_raw_parts(scratch_bytes.as_ptr() as *const f32, elems) };
119        for (i, &v) in scratch.iter().enumerate() {
120            acc[i] = op.fold(acc[i], v);
121        }
122    }
123    for v in acc.iter_mut() {
124        *v = op.finalize(*v, n as usize);
125    }
126    local.copy_from_slice(&acc);
127    Ok(())
128}
129
130/// AllGather: every rank ends up with the concatenation of all
131/// per-rank `local` slices, in rank order.
132///
133/// `local.len()` is the per-rank chunk size; `output.len()` must
134/// equal `num_ranks * local.len()`. Output rank `r`'s
135/// contribution lands at `output[r*local.len()..(r+1)*local.len()]`.
136pub fn all_gather<T: SymmetricTransport>(
137    transport: &T,
138    buf: SymmetricBuffer, // per-rank slot
139    local: &[f32],
140    output: &mut [f32],
141) -> Result<(), CollectiveError> {
142    let elems_per_rank = buf.len / 4;
143    let n = transport.num_ranks() as usize;
144    if local.len() != elems_per_rank {
145        return Err(CollectiveError::LengthMismatch {
146            expected: elems_per_rank,
147            got: local.len(),
148        });
149    }
150    if output.len() != n * elems_per_rank {
151        return Err(CollectiveError::LengthMismatch {
152            expected: n * elems_per_rank,
153            got: output.len(),
154        });
155    }
156
157    let me = transport.this_rank();
158    let our_buf = SymmetricBuffer {
159        rank: me,
160        offset: buf.offset,
161        len: buf.len,
162    };
163    let bytes = unsafe { std::slice::from_raw_parts(local.as_ptr() as *const u8, buf.len) };
164    transport.put(our_buf, bytes)?;
165    transport.barrier()?;
166
167    let mut scratch_bytes = vec![0u8; buf.len];
168    for r in 0..n {
169        let src = SymmetricBuffer {
170            rank: Rank(r as u32),
171            offset: buf.offset,
172            len: buf.len,
173        };
174        transport.get(src, &mut scratch_bytes)?;
175        let chunk = unsafe {
176            std::slice::from_raw_parts(scratch_bytes.as_ptr() as *const f32, elems_per_rank)
177        };
178        let dst_start = r * elems_per_rank;
179        output[dst_start..dst_start + elems_per_rank].copy_from_slice(chunk);
180    }
181    Ok(())
182}
183
184/// ReduceScatter: equivalent to AllReduce followed by partition
185/// — every rank ends up with one `chunk_size`-element slice of
186/// the reduced result. Rank `r` gets element indices
187/// `[r*chunk_size, (r+1)*chunk_size)`.
188///
189/// `local.len()` is the full vector (`num_ranks * chunk_size`);
190/// `output.len()` is `chunk_size`.
191pub fn reduce_scatter<T: SymmetricTransport>(
192    transport: &T,
193    buf: SymmetricBuffer,
194    local: &[f32],
195    output: &mut [f32],
196    op: ReduceKind,
197) -> Result<(), CollectiveError> {
198    let total = buf.len / 4;
199    let n = transport.num_ranks() as usize;
200    if !total.is_multiple_of(n) {
201        return Err(CollectiveError::TransportError {
202            reason: format!("reduce_scatter: total elements {total} not divisible by {n} ranks"),
203        });
204    }
205    let chunk = total / n;
206    if local.len() != total {
207        return Err(CollectiveError::LengthMismatch {
208            expected: total,
209            got: local.len(),
210        });
211    }
212    if output.len() != chunk {
213        return Err(CollectiveError::LengthMismatch {
214            expected: chunk,
215            got: output.len(),
216        });
217    }
218
219    // Reuse all_reduce — works on a scratch copy of `local`,
220    // then this rank picks its slice.
221    let me = transport.this_rank().0 as usize;
222    let mut full = local.to_vec();
223    all_reduce(transport, buf, &mut full, op)?;
224    output.copy_from_slice(&full[me * chunk..(me + 1) * chunk]);
225    Ok(())
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symmetric::LocalTransport;

    /// View an `f32` slice as its raw native-endian bytes, the same
    /// encoding the collectives hand to `put`.
    fn f32_bytes(values: &[f32]) -> &[u8] {
        let len = std::mem::size_of_val(values);
        // SAFETY: pointer and byte length come from a live `&[f32]`
        // (`size_of_val` is its exact size), `u8` has alignment 1, and
        // `f32` has no padding, so the whole range is initialized and
        // stays borrowed for the returned lifetime.
        unsafe { std::slice::from_raw_parts(values.as_ptr().cast::<u8>(), len) }
    }

    /// Decode native-endian `f32`s from bytes read back via `get`.
    ///
    /// Decodes byte-wise instead of casting the pointer: a `Vec<u8>` is
    /// only 1-aligned, so viewing it as `&[f32]` would be undefined
    /// behavior.
    fn decode_f32s(bytes: &[u8]) -> Vec<f32> {
        bytes
            .chunks_exact(4)
            .map(|w| f32::from_ne_bytes([w[0], w[1], w[2], w[3]]))
            .collect()
    }

    /// Slot descriptor for `rank` at offset 0 spanning `len` bytes.
    fn slot(rank: u32, len: usize) -> SymmetricBuffer {
        SymmetricBuffer {
            rank: Rank(rank),
            offset: 0,
            len,
        }
    }

    /// 4 ranks, each contributes [r+1, r+1, r+1, r+1] (so ranks
    /// hold 1, 2, 3, 4). After AllReduce::Sum each rank should
    /// see [10, 10, 10, 10].
    #[test]
    fn all_reduce_sum_across_4_ranks() {
        let n_ranks = 4u32;
        let elems = 4usize;
        let bytes = elems * 4;
        let ts = LocalTransport::fan_out(n_ranks, bytes);

        // Each rank's local data + reduced output.
        let mut state: Vec<Vec<f32>> =
            (0..n_ranks).map(|r| vec![(r + 1) as f32; elems]).collect();

        // Run all_reduce sequentially; LocalTransport's barrier
        // counter accumulates across calls, so n_ranks calls
        // satisfy each barrier. We pre-write contributions for
        // every rank so the barrier-then-get phase sees data.
        // Step 1: each rank puts its contribution.
        for (r, t) in ts.iter().enumerate() {
            t.put(slot(r as u32, bytes), f32_bytes(&state[r])).unwrap();
        }
        // Step 2: each rank reduces. We can't use the public
        // all_reduce since it does its own put + barrier (which
        // double-counts after our manual put above). Inline the
        // reduce step instead.
        for (r, t) in ts.iter().enumerate() {
            let mut acc = vec![0f32; elems];
            let mut scratch = vec![0u8; bytes];
            for src_r in 0..n_ranks {
                t.get(slot(src_r, bytes), &mut scratch).unwrap();
                for (a, v) in acc.iter_mut().zip(decode_f32s(&scratch)) {
                    *a += v;
                }
            }
            state[r] = acc;
        }

        for (r, slot) in state.iter().enumerate() {
            assert_eq!(slot, &vec![10.0; elems], "rank {r} after all-reduce");
        }
    }

    #[test]
    fn all_gather_concatenates_in_rank_order() {
        let n_ranks = 3u32;
        let chunk = 2usize;
        let bytes = chunk * 4;
        let ts = LocalTransport::fan_out(n_ranks, bytes);

        // Each rank contributes [10*r, 10*r+1].
        let local: Vec<Vec<f32>> = (0..n_ranks)
            .map(|r| {
                let r = r as f32;
                vec![10.0 * r, 10.0 * r + 1.0]
            })
            .collect();

        // Step 1: each rank puts.
        for (r, t) in ts.iter().enumerate() {
            t.put(slot(r as u32, bytes), f32_bytes(&local[r])).unwrap();
        }
        // Step 2: each rank gathers.
        for (r_idx, t) in ts.iter().enumerate() {
            let mut output = vec![0f32; n_ranks as usize * chunk];
            let mut scratch = vec![0u8; bytes];
            for src_r in 0..n_ranks {
                t.get(slot(src_r, bytes), &mut scratch).unwrap();
                let dst_start = src_r as usize * chunk;
                output[dst_start..dst_start + chunk].copy_from_slice(&decode_f32s(&scratch));
            }
            assert_eq!(
                output,
                vec![0.0, 1.0, 10.0, 11.0, 20.0, 21.0],
                "rank {r_idx} after all-gather"
            );
        }
    }

    /// The length check fires before any transport traffic, so a
    /// single rank can exercise it without the other ranks running.
    #[test]
    fn all_reduce_rejects_local_length_mismatch() {
        let bytes = 4 * 4;
        let ts = LocalTransport::fan_out(2, bytes);
        let t: &LocalTransport = &ts[0];
        let mut local = vec![0f32; 3];
        let res = all_reduce(t, slot(0, bytes), &mut local, ReduceKind::Sum);
        assert!(matches!(
            res,
            Err(CollectiveError::LengthMismatch {
                expected: 4,
                got: 3
            })
        ));
    }

    #[test]
    fn all_gather_rejects_output_length_mismatch() {
        let chunk = 2usize;
        let bytes = chunk * 4;
        let ts = LocalTransport::fan_out(3, bytes);
        let t: &LocalTransport = &ts[0];
        let local = vec![0f32; chunk];
        // 3 ranks * 2 elements needs 6 output slots.
        let mut output = vec![0f32; 5];
        let res = all_gather(t, slot(0, bytes), &local, &mut output);
        assert!(matches!(
            res,
            Err(CollectiveError::LengthMismatch {
                expected: 6,
                got: 5
            })
        ));
    }

    #[test]
    fn reduce_scatter_rejects_indivisible_length() {
        // 3 elements cannot be split evenly across 2 ranks.
        let bytes = 3 * 4;
        let ts = LocalTransport::fan_out(2, bytes);
        let t: &LocalTransport = &ts[0];
        let local = vec![0f32; 3];
        let mut output = vec![0f32; 1];
        let res = reduce_scatter(t, slot(0, bytes), &local, &mut output, ReduceKind::Sum);
        assert!(matches!(res, Err(CollectiveError::TransportError { .. })));
    }

    #[test]
    fn reduce_kind_max_takes_pointwise_max() {
        let mut acc = ReduceKind::Max.identity();
        for v in [3.0, 1.0, 7.0, -2.0] {
            acc = ReduceKind::Max.fold(acc, v);
        }
        assert_eq!(acc, 7.0);
    }

    #[test]
    fn reduce_kind_min_takes_pointwise_min() {
        let mut acc = ReduceKind::Min.identity();
        for v in [3.0, 1.0, 7.0, -2.0] {
            acc = ReduceKind::Min.fold(acc, v);
        }
        assert_eq!(acc, -2.0);
    }

    #[test]
    fn reduce_kind_mean_divides_at_finalize() {
        let mut acc = ReduceKind::Mean.identity();
        for v in [2.0, 4.0, 6.0, 8.0] {
            acc = ReduceKind::Mean.fold(acc, v);
        }
        assert_eq!(ReduceKind::Mean.finalize(acc, 4), 5.0);
    }
}