Skip to main content

rlx_driver/
symmetric.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Symmetric-memory primitives for collective ops (plan #49).
17//!
18//! Borrowed from MAX's `kernels/src/shmem/{_nvshmem, _rocshmem,
19//! _mpi, shmem_buffer, ep_comm}.mojo`. Symmetric heaps are the
20//! standard abstraction for multi-device collective comm: every
21//! rank allocates the same logical region; addresses are
22//! identical across ranks (the physical backing differs). One
23//! rank can `put` directly into another's slot at the same
24//! offset.
25//!
26//! The Rust spelling here ships:
27//!
28//!   - [`SymmetricHeap`] — owns the per-rank physical storage.
29//!     Single-machine emulation today (one `Vec<u8>` per rank);
30//!     the trait surface ([`SymmetricTransport`]) is what a
31//!     future MPI / NVSHMEM-equivalent / process-shared-memory
32//!     impl plugs into.
33//!   - [`SymmetricBuffer`] — a `(rank, offset, len)` view.
34//!   - [`SymmetricTransport`] — the trait every transport
35//!     impl satisfies. `LocalTransport` is the in-process,
36//!     single-machine impl used by tests + collective-algo
37//!     correctness checks (plan #12).
38//!
39//! This is the foundation; #12 builds AllReduce / AllGather /
40//! ReduceScatter on top.
41
42use std::sync::{Arc, RwLock};
43
/// A participant's index within a collective.
///
/// Valid ranks run over `0..num_ranks` and do not change for as
/// long as the owning transport lives.
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Rank(pub u32);
49
50/// `(rank, offset, len)` view into a symmetric heap. The same
51/// `(offset, len)` pair is valid on every rank — that's what
52/// "symmetric" means.
53#[derive(Debug, Clone, Copy, PartialEq, Eq)]
54pub struct SymmetricBuffer {
55    pub rank: Rank,
56    pub offset: usize,
57    pub len: usize,
58}
59
60/// One-sided operation surface. `put(buf, src)` writes `src`
61/// into `buf.rank`'s memory at `buf.offset`; `get(buf, dst)`
62/// reads from `buf.rank`'s memory into `dst`. Both calls block
63/// until completion (a future async impl can return a future).
64pub trait SymmetricTransport {
65    /// How many ranks participate.
66    fn num_ranks(&self) -> u32;
67    /// This process's rank.
68    fn this_rank(&self) -> Rank;
69
70    /// Write `src` into `buf`. Errors on length mismatch.
71    fn put(&self, buf: SymmetricBuffer, src: &[u8]) -> Result<(), CollectiveError>;
72    /// Read from `buf` into `dst`. Errors on length mismatch.
73    fn get(&self, buf: SymmetricBuffer, dst: &mut [u8]) -> Result<(), CollectiveError>;
74
75    /// Block until every rank has reached this barrier. Local
76    /// emulation is a memory fence + a counter; real transports
77    /// implement their own.
78    fn barrier(&self) -> Result<(), CollectiveError>;
79}
80
81#[derive(Debug, Clone)]
82pub enum CollectiveError {
83    /// `(rank, offset, len)` walks past the heap.
84    OutOfBounds {
85        rank: Rank,
86        offset: usize,
87        len: usize,
88        heap_size: usize,
89    },
90    /// `src.len() != buf.len`.
91    LengthMismatch { expected: usize, got: usize },
92    /// Unknown rank id.
93    UnknownRank { rank: Rank, num_ranks: u32 },
94    /// Underlying transport failed (network, mmap, etc.).
95    TransportError { reason: String },
96}
97
98impl std::fmt::Display for CollectiveError {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        match self {
101            Self::OutOfBounds {
102                rank,
103                offset,
104                len,
105                heap_size,
106            } => write!(
107                f,
108                "OOB on rank {}: offset {offset} + len {len} > heap_size {heap_size}",
109                rank.0
110            ),
111            Self::LengthMismatch { expected, got } => {
112                write!(f, "length mismatch: expected {expected}, got {got}")
113            }
114            Self::UnknownRank { rank, num_ranks } => {
115                write!(f, "unknown rank {} (have {num_ranks})", rank.0)
116            }
117            Self::TransportError { reason } => write!(f, "transport: {reason}"),
118        }
119    }
120}
121
122impl std::error::Error for CollectiveError {}
123
124/// Per-rank symmetric memory: a `Vec<u8>` per rank, all the same
125/// size. Owned by the [`LocalTransport`].
126#[derive(Debug)]
127pub struct SymmetricHeap {
128    storage: Vec<Arc<RwLock<Vec<u8>>>>,
129    pub heap_size: usize,
130}
131
132impl SymmetricHeap {
133    pub fn new(num_ranks: u32, heap_size: usize) -> Self {
134        let storage = (0..num_ranks)
135            .map(|_| Arc::new(RwLock::new(vec![0u8; heap_size])))
136            .collect();
137        Self { storage, heap_size }
138    }
139
140    pub fn num_ranks(&self) -> u32 {
141        self.storage.len() as u32
142    }
143
144    pub fn rank_view(&self, rank: Rank) -> Option<Arc<RwLock<Vec<u8>>>> {
145        self.storage.get(rank.0 as usize).cloned()
146    }
147}
148
149/// Single-machine in-process transport. All `num_ranks`
150/// "ranks" share one [`SymmetricHeap`] instance, so put / get
151/// are just locks + memcpy. Useful for unit tests and for
152/// algorithm-correctness checking of collective ops without a
153/// real cluster.
154#[derive(Debug, Clone)]
155pub struct LocalTransport {
156    heap: Arc<SymmetricHeap>,
157    me: Rank,
158    barrier_count: Arc<std::sync::atomic::AtomicU32>,
159    barrier_target: u32,
160}
161
162impl LocalTransport {
163    pub fn new(num_ranks: u32, heap_size: usize, this_rank: Rank) -> Self {
164        let heap = Arc::new(SymmetricHeap::new(num_ranks, heap_size));
165        Self::with_heap(heap, this_rank)
166    }
167
168    /// Construct multiple `LocalTransport`s sharing one heap —
169    /// `Vec` of length `num_ranks`, each with its own `me`.
170    /// Tests typically iterate this list to drive each rank.
171    pub fn fan_out(num_ranks: u32, heap_size: usize) -> Vec<Self> {
172        let heap = Arc::new(SymmetricHeap::new(num_ranks, heap_size));
173        (0..num_ranks)
174            .map(|i| Self::with_heap(heap.clone(), Rank(i)))
175            .collect()
176    }
177
178    fn with_heap(heap: Arc<SymmetricHeap>, me: Rank) -> Self {
179        let n = heap.num_ranks();
180        Self {
181            heap,
182            me,
183            barrier_count: Arc::new(std::sync::atomic::AtomicU32::new(0)),
184            barrier_target: n,
185        }
186    }
187
188    fn check_buf(&self, buf: SymmetricBuffer) -> Result<(), CollectiveError> {
189        if buf.rank.0 >= self.heap.num_ranks() {
190            return Err(CollectiveError::UnknownRank {
191                rank: buf.rank,
192                num_ranks: self.heap.num_ranks(),
193            });
194        }
195        if buf.offset + buf.len > self.heap.heap_size {
196            return Err(CollectiveError::OutOfBounds {
197                rank: buf.rank,
198                offset: buf.offset,
199                len: buf.len,
200                heap_size: self.heap.heap_size,
201            });
202        }
203        Ok(())
204    }
205}
206
207impl SymmetricTransport for LocalTransport {
208    fn num_ranks(&self) -> u32 {
209        self.heap.num_ranks()
210    }
211    fn this_rank(&self) -> Rank {
212        self.me
213    }
214
215    fn put(&self, buf: SymmetricBuffer, src: &[u8]) -> Result<(), CollectiveError> {
216        self.check_buf(buf)?;
217        if src.len() != buf.len {
218            return Err(CollectiveError::LengthMismatch {
219                expected: buf.len,
220                got: src.len(),
221            });
222        }
223        let view = self.heap.rank_view(buf.rank).expect("checked above");
224        let mut guard = view.write().unwrap();
225        guard[buf.offset..buf.offset + buf.len].copy_from_slice(src);
226        Ok(())
227    }
228
229    fn get(&self, buf: SymmetricBuffer, dst: &mut [u8]) -> Result<(), CollectiveError> {
230        self.check_buf(buf)?;
231        if dst.len() != buf.len {
232            return Err(CollectiveError::LengthMismatch {
233                expected: buf.len,
234                got: dst.len(),
235            });
236        }
237        let view = self.heap.rank_view(buf.rank).expect("checked above");
238        let guard = view.read().unwrap();
239        dst.copy_from_slice(&guard[buf.offset..buf.offset + buf.len]);
240        Ok(())
241    }
242
243    fn barrier(&self) -> Result<(), CollectiveError> {
244        // Each rank bumps the counter; spin until we observe
245        // num_ranks bumps, then move on. This isn't a "real"
246        // barrier (no rendezvous) — it's an arrival counter,
247        // sufficient for single-thread tests where each rank
248        // calls barrier in turn.
249        use std::sync::atomic::Ordering;
250        self.barrier_count.fetch_add(1, Ordering::AcqRel);
251        // For LocalTransport in single-thread tests this returns
252        // immediately; concurrent multi-thread tests can spin.
253        while self.barrier_count.load(Ordering::Acquire) < self.barrier_target {
254            std::hint::spin_loop();
255        }
256        Ok(())
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building the `SymmetricBuffer`s used below.
    fn view(rank: u32, offset: usize, len: usize) -> SymmetricBuffer {
        SymmetricBuffer {
            rank: Rank(rank),
            offset,
            len,
        }
    }

    #[test]
    fn put_then_get_round_trips() {
        let transport = LocalTransport::new(4, 1024, Rank(0));
        let region = view(2, 16, 8);
        let payload: [u8; 8] = [1, 2, 3, 4, 5, 6, 7, 8];
        transport.put(region, &payload).unwrap();
        let mut readback = [0u8; 8];
        transport.get(region, &mut readback).unwrap();
        assert_eq!(&readback, &[1, 2, 3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn fan_out_yields_one_transport_per_rank() {
        let transports = LocalTransport::fan_out(3, 64);
        assert_eq!(transports.len(), 3);
        for (expected, transport) in (0u32..).zip(transports.iter()) {
            assert_eq!(transport.this_rank(), Rank(expected));
            assert_eq!(transport.num_ranks(), 3);
        }
    }

    #[test]
    fn put_visible_to_other_rank_via_shared_heap() {
        // Rank 0 writes into rank 2's slot; rank 2 then reads its own
        // slot back and must observe that write.
        let transports = LocalTransport::fan_out(3, 32);
        let payload = [9u8; 4];
        let slot = view(2, 0, 4);
        transports[0].put(slot, &payload).unwrap();
        let mut seen = [0u8; 4];
        transports[2].get(slot, &mut seen).unwrap();
        assert_eq!(seen, payload);
    }

    #[test]
    fn oob_offset_errors() {
        let transport = LocalTransport::new(2, 8, Rank(0));
        // 4 + 8 > 8: the range spills past the end of the heap.
        let result = transport.put(view(1, 4, 8), &[0u8; 8]);
        assert!(matches!(result, Err(CollectiveError::OutOfBounds { .. })));
    }

    #[test]
    fn unknown_rank_errors() {
        let transport = LocalTransport::new(2, 8, Rank(0));
        let result = transport.get(view(99, 0, 4), &mut [0u8; 4]);
        assert!(matches!(result, Err(CollectiveError::UnknownRank { .. })));
    }

    #[test]
    fn length_mismatch_errors() {
        let transport = LocalTransport::new(2, 32, Rank(0));
        // Buffer describes 4 bytes but 8 are supplied.
        let result = transport.put(view(1, 0, 4), &[0u8; 8]);
        assert!(matches!(result, Err(CollectiveError::LengthMismatch { .. })));
    }
}