Skip to main content

rlx_ir/
async_copy.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Async tile-copy + double-buffer primitives (plan #22).
17//!
18//! Borrowed from MAX's
19//! `layout/{tma_async, tensor_core_async}.mojo` +
20//! `structured_kernels/{pipeline, pipeline_storage, barriers}.mojo`.
21//! On NVIDIA the equivalent is TMA (Tensor Memory Accelerator);
22//! on Apple Silicon there's no direct analog because the GPU and
23//! CPU share a unified memory pool — but the *pipelining* idea
24//! still pays off: while shader N runs on tile N, you issue an
25//! async copy / blit for tile N+1 and let the two overlap.
26//!
27//! The shape this module exposes:
28//!
//!   - [`DoubleBuffer<T>`] — owns two `T` instances plus an index
//!     marking the active one, which `swap()` flips; `current()` is
//!     what compute reads, `next_mut()` is where the async copy lands.
32//!   - [`AsyncCopy`] trait — `issue()` schedules a copy and returns
33//!     a [`BarrierToken`]; `wait()` blocks until the matching
34//!     issue has completed.
35//!   - [`SyncCopy`] — the CPU implementation: every issue is a
36//!     memcpy + a fresh token; `wait()` is a no-op (the copy
37//!     already completed). Sufficient for unit tests and for
38//!     bench harnesses that run the pipeline pattern with no
39//!     real overlap.
40//!
41//! A future Metal impl plugs in via the same trait. The Metal
42//! version would issue a `MTLBlitCommandEncoder.copy(...)` on a
43//! distinct command queue and signal an `MTLEvent` for `wait()`.
44
45use std::sync::atomic::{AtomicU64, Ordering};
46
/// Opaque ticket identifying one copy scheduled through [`AsyncCopy::issue`].
///
/// Hand the token back to [`AsyncCopy::wait`] to block until that copy is
/// done. A token is only meaningful to the engine that produced it — don't
/// pass it to a different engine.
///
/// `Hash` is derived so in-flight copies can be keyed by token (for example
/// in a `HashMap` or `HashSet`). The type is `#[must_use]` because silently
/// dropping a token means nobody ever waits for the copy it stands for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[must_use = "a copy is only known to be complete once its token has been passed to `AsyncCopy::wait`"]
pub struct BarrierToken(pub u64);
52
53/// Pluggable async-copy engine. Backends (`SyncCopy` for CPU,
54/// future `MetalBlitCopy` for GPU) implement this.
55pub trait AsyncCopy {
56    /// Schedule a `bytes`-byte copy from `src` to `dst`. Returns a
57    /// token usable with [`Self::wait`].
58    /// # Safety
59    /// `src` valid for read, `dst` valid for write, `bytes` doesn't
60    /// overflow either region. Caller ensures `src` and `dst` don't
61    /// alias unless that's intentional.
62    unsafe fn issue(&mut self, src: *const u8, dst: *mut u8, bytes: usize) -> BarrierToken;
63
64    /// Block until the copy referred to by `token` has completed.
65    fn wait(&mut self, token: BarrierToken);
66}
67
/// CPU "async" copy engine — actually synchronous.
///
/// `issue()` performs the `memcpy` immediately and returns a fresh token, so
/// `wait()` has nothing left to do. Useful as the test fixture and for code
/// paths that follow the pipeline pattern without needing real overlap.
#[derive(Debug)]
pub struct SyncCopy {
    /// Value of the next token to hand out; advanced by one per issued copy.
    counter: AtomicU64,
}
75
76impl SyncCopy {
77    pub const fn new() -> Self {
78        Self {
79            counter: AtomicU64::new(0),
80        }
81    }
82}
83
84impl Default for SyncCopy {
85    fn default() -> Self {
86        Self::new()
87    }
88}
89
90impl AsyncCopy for SyncCopy {
91    unsafe fn issue(&mut self, src: *const u8, dst: *mut u8, bytes: usize) -> BarrierToken {
92        unsafe {
93            std::ptr::copy_nonoverlapping(src, dst, bytes);
94        }
95        BarrierToken(self.counter.fetch_add(1, Ordering::Relaxed))
96    }
97
98    fn wait(&mut self, _token: BarrierToken) {
99        // Sync copy: already done at issue() time.
100    }
101}
102
/// Ring of exactly two buffers.
///
/// During a step, compute reads from `current()` while the *next* async copy
/// is aimed at `next_mut()`. After waiting on the current copy, call `swap()`
/// to advance to the other buffer.
#[derive(Debug, Clone)]
pub struct DoubleBuffer<T> {
    /// The two slots, in the order they were passed to `new`.
    buffers: [T; 2],
    /// Index (0 or 1) of the slot that `current()` exposes.
    active: usize,
}
111
112impl<T> DoubleBuffer<T> {
113    pub fn new(a: T, b: T) -> Self {
114        Self {
115            buffers: [a, b],
116            active: 0,
117        }
118    }
119
120    pub fn current(&self) -> &T {
121        &self.buffers[self.active]
122    }
123    pub fn current_mut(&mut self) -> &mut T {
124        &mut self.buffers[self.active]
125    }
126
127    pub fn next(&self) -> &T {
128        &self.buffers[1 - self.active]
129    }
130    pub fn next_mut(&mut self) -> &mut T {
131        &mut self.buffers[1 - self.active]
132    }
133
134    /// Flip which buffer is current. Typical pattern:
135    /// ```text
136    /// // At step k:
137    /// engine.wait(prev_token);          // copy of tile-k done
138    /// let token_for_kp1 = engine.issue(src_kp1, double.next_mut(), bytes);
139    /// compute(double.current());        // shader runs on tile-k
140    /// double.swap();                    // tile-(k+1) becomes current
141    /// // → at step k+1, wait(token_for_kp1) etc.
142    /// ```
143    pub fn swap(&mut self) {
144        self.active = 1 - self.active;
145    }
146
147    /// Both buffers' shared length, when `T = Vec<u8>` / `Vec<f32>`.
148    /// Exposed for symmetry; many callers don't need it.
149    pub fn pair(&self) -> (&T, &T) {
150        (&self.buffers[0], &self.buffers[1])
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Two swaps bring the ring back to the buffer it started on.
    #[test]
    fn double_buffer_swap_round_trip() {
        let (ones, twos) = (vec![1u8; 4], vec![2u8; 4]);
        let mut ring = DoubleBuffer::new(ones.clone(), twos.clone());
        assert_eq!(ring.current(), &ones);
        ring.swap();
        assert_eq!(ring.current(), &twos);
        ring.swap();
        assert_eq!(ring.current(), &ones);
    }

    /// A single issue/wait pair moves the bytes across unchanged.
    #[test]
    fn sync_copy_round_trips_data() {
        let payload = [1u8, 2, 3, 4];
        let mut sink = [0u8; 4];
        let mut copier = SyncCopy::new();
        // SAFETY: `payload` and `sink` are distinct, live 4-byte arrays.
        let ticket = unsafe { copier.issue(payload.as_ptr(), sink.as_mut_ptr(), 4) };
        copier.wait(ticket);
        assert_eq!(sink, payload);
    }

    /// Runs the canonical compute-overlap-copy loop: 16 source bytes are
    /// streamed as four 4-byte tiles, and the "compute" step sums whichever
    /// tile is current.
    #[test]
    fn pipelined_pattern_through_double_buffer() {
        fn tile_sum(tile: &[u8]) -> u64 {
            tile.iter().map(|&byte| byte as u64).sum()
        }

        let source: Vec<u8> = (0..16u8).collect();
        let tile_bytes = 4;
        let mut ring = DoubleBuffer::new(vec![0u8; tile_bytes], vec![0u8; tile_bytes]);
        let mut copier = SyncCopy::new();

        // Prime: load tile 0 into the *current* slot.
        // SAFETY: `source` holds at least `tile_bytes` bytes and the current
        // slot was allocated with exactly `tile_bytes` bytes.
        let primed = unsafe {
            copier.issue(
                source.as_ptr(),
                ring.current_mut().as_mut_ptr(),
                tile_bytes,
            )
        };
        copier.wait(primed);

        let mut total = 0u64;
        // Byte offsets of tiles 1, 2, 3, ... inside `source`.
        for offset in (tile_bytes..source.len()).step_by(tile_bytes) {
            // Issue the copy for the upcoming tile into the inactive slot.
            // SAFETY: `offset + tile_bytes <= source.len()` because the
            // source length is a multiple of `tile_bytes`; the inactive slot
            // is `tile_bytes` long.
            let in_flight = unsafe {
                copier.issue(
                    source.as_ptr().add(offset),
                    ring.next_mut().as_mut_ptr(),
                    tile_bytes,
                )
            };
            // Compute on the current tile.
            total += tile_sum(ring.current());
            // Step boundary: the upcoming tile must have landed before it
            // becomes current.
            copier.wait(in_flight);
            ring.swap();
        }
        // Drain the last tile.
        total += tile_sum(ring.current());

        // Sum of 0..16 = 120.
        let expected: u64 = (0..16u64).sum();
        assert_eq!(total, expected);
    }
}