rlx_ir/async_copy.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Async tile-copy + double-buffer primitives (plan #22).
17//!
18//! Borrowed from MAX's
19//! `layout/{tma_async, tensor_core_async}.mojo` +
20//! `structured_kernels/{pipeline, pipeline_storage, barriers}.mojo`.
21//! On NVIDIA the equivalent is TMA (Tensor Memory Accelerator);
22//! on Apple Silicon there's no direct analog because the GPU and
23//! CPU share a unified memory pool — but the *pipelining* idea
24//! still pays off: while shader N runs on tile N, you issue an
25//! async copy / blit for tile N+1 and let the two overlap.
26//!
27//! The shape this module exposes:
28//!
//! - [`DoubleBuffer<T>`] — owns two `T` instances plus an index
//!   marking the active one (flipped by `swap()`); `current()` is
//!   what compute reads, `next_mut()` is where the async copy lands.
32//! - [`AsyncCopy`] trait — `issue()` schedules a copy and returns
33//! a [`BarrierToken`]; `wait()` blocks until the matching
34//! issue has completed.
35//! - [`SyncCopy`] — the CPU implementation: every issue is a
36//! memcpy + a fresh token; `wait()` is a no-op (the copy
37//! already completed). Sufficient for unit tests and for
38//! bench harnesses that run the pipeline pattern with no
39//! real overlap.
40//!
41//! A future Metal impl plugs in via the same trait. The Metal
42//! version would issue a `MTLBlitCommandEncoder.copy(...)` on a
43//! distinct command queue and signal an `MTLEvent` for `wait()`.
44
45use std::sync::atomic::{AtomicU64, Ordering};
46
/// Opaque ticket returned by [`AsyncCopy::issue`]. Pass it back to
/// [`AsyncCopy::wait`] to block until the corresponding copy is done.
///
/// Tokens are scoped to the engine that issued them — don't pass them
/// to a different engine. The inner value is public so backends can
/// mint tokens; callers should treat it as opaque.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BarrierToken(pub u64);

/// Pluggable async-copy engine. Backends (`SyncCopy` for CPU,
/// future `MetalBlitCopy` for GPU) implement this.
pub trait AsyncCopy {
    /// Schedule a `bytes`-byte copy from `src` to `dst` and return a
    /// token usable with [`Self::wait`].
    ///
    /// The copy may still be in flight when this returns; `dst` is only
    /// guaranteed to hold the copied bytes once [`Self::wait`] has
    /// returned for the returned token.
    ///
    /// # Safety
    ///
    /// - `src` must be valid for reads of `bytes` bytes and `dst` must be
    ///   valid for writes of `bytes` bytes, and both must stay valid until
    ///   [`Self::wait`] returns for the returned token.
    /// - The `src` and `dst` regions must not overlap. Backends are free
    ///   to use non-overlapping copy primitives (`memcpy`, blits), for
    ///   which overlap is undefined behavior.
    /// - Neither region may be accessed through other references while
    ///   the copy is in flight.
    unsafe fn issue(&mut self, src: *const u8, dst: *mut u8, bytes: usize) -> BarrierToken;

    /// Block until the copy referred to by `token` has completed.
    ///
    /// `token` should come from `issue` on this same engine; passing a
    /// token from another engine is a logic error.
    fn wait(&mut self, token: BarrierToken);
}
67
/// CPU "async" copy — actually synchronous. `issue()` performs the
/// copy immediately and returns a fresh token; `wait()` is a no-op.
/// Useful as the test fixture and for code paths that don't actually
/// need overlap.
#[derive(Debug)]
pub struct SyncCopy {
    /// Source of token ids; bumped once per `issue()`.
    counter: AtomicU64,
}

impl SyncCopy {
    /// Create an engine whose token counter starts at zero, so the
    /// first issued token is `BarrierToken(0)`.
    pub const fn new() -> Self {
        Self {
            counter: AtomicU64::new(0),
        }
    }
}
83
84impl Default for SyncCopy {
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90impl AsyncCopy for SyncCopy {
91 unsafe fn issue(&mut self, src: *const u8, dst: *mut u8, bytes: usize) -> BarrierToken {
92 unsafe {
93 std::ptr::copy_nonoverlapping(src, dst, bytes);
94 }
95 BarrierToken(self.counter.fetch_add(1, Ordering::Relaxed))
96 }
97
98 fn wait(&mut self, _token: BarrierToken) {
99 // Sync copy: already done at issue() time.
100 }
101}
102
/// Two-slot ring used to overlap copies with compute.
///
/// `current()` is what compute reads this step; `next_mut()` is where
/// the *next* async copy should land. Call [`DoubleBuffer::swap`] after
/// waiting on that copy to advance to the next step.
#[derive(Debug, Clone)]
pub struct DoubleBuffer<T> {
    buffers: [T; 2],
    /// Index of the current slot; always 0 or 1.
    active: usize,
}

impl<T> DoubleBuffer<T> {
    /// Build a ring in which `a` is current and `b` is next.
    pub fn new(a: T, b: T) -> Self {
        Self {
            buffers: [a, b],
            active: 0,
        }
    }

    /// Index of the slot that is *not* current.
    fn inactive(&self) -> usize {
        1 - self.active
    }

    /// Buffer compute should read this step.
    pub fn current(&self) -> &T {
        &self.buffers[self.active]
    }

    /// Mutable access to the current buffer (e.g. to prime it with the
    /// first tile before the loop starts).
    pub fn current_mut(&mut self) -> &mut T {
        &mut self.buffers[self.active]
    }

    /// Buffer that becomes current after the next [`Self::swap`].
    pub fn next(&self) -> &T {
        &self.buffers[self.inactive()]
    }

    /// Destination for the next async copy.
    pub fn next_mut(&mut self) -> &mut T {
        let idx = self.inactive();
        &mut self.buffers[idx]
    }

    /// Flip which buffer is current. Typical pattern (shown for
    /// `T = Vec<u8>`):
    /// ```text
    /// // At step k (`prev_token` is the copy that filled tile k):
    /// engine.wait(prev_token);                       // tile k resident in current()
    /// let token = unsafe {
    ///     engine.issue(src_kp1, double.next_mut().as_mut_ptr(), bytes)
    /// };
    /// compute(double.current());                     // tile k, overlapping the copy
    /// double.swap();                                 // tile-(k+1) slot becomes current
    /// // → at step k+1, engine.wait(token) etc.
    /// ```
    pub fn swap(&mut self) {
        self.active = self.inactive();
    }

    /// Borrow both buffers in storage order `(slot 0, slot 1)`,
    /// regardless of which one is current.
    pub fn pair(&self) -> (&T, &T) {
        (&self.buffers[0], &self.buffers[1])
    }
}
153
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn double_buffer_swap_round_trip() {
        let mut db = DoubleBuffer::new(vec![1u8; 4], vec![2u8; 4]);
        assert_eq!(db.current(), &vec![1u8; 4]);
        db.swap();
        assert_eq!(db.current(), &vec![2u8; 4]);
        db.swap();
        assert_eq!(db.current(), &vec![1u8; 4]);
    }

    #[test]
    fn next_tracks_inactive_slot() {
        let mut db = DoubleBuffer::new('a', 'b');
        assert_eq!(*db.next(), 'b');
        *db.next_mut() = 'c';
        db.swap();
        assert_eq!(*db.current(), 'c');
        assert_eq!(*db.next(), 'a');
        // `pair` reports storage order, not current/next order.
        assert_eq!(db.pair(), (&'a', &'c'));
    }

    #[test]
    fn sync_copy_round_trips_data() {
        let src = [1u8, 2, 3, 4];
        let mut dst = [0u8; 4];
        let mut engine = SyncCopy::new();
        let token = unsafe { engine.issue(src.as_ptr(), dst.as_mut_ptr(), 4) };
        engine.wait(token);
        assert_eq!(dst, src);
    }

    #[test]
    fn sync_copy_tokens_are_distinct() {
        let src = [7u8; 2];
        let mut dst = [0u8; 2];
        let mut engine = SyncCopy::new();
        let first = unsafe { engine.issue(src.as_ptr(), dst.as_mut_ptr(), 2) };
        // Zero-length copy with valid pointers still yields a fresh token.
        let second = unsafe { engine.issue(src.as_ptr(), dst.as_mut_ptr(), 0) };
        assert_ne!(first, second);
        engine.wait(first);
        engine.wait(second);
        assert_eq!(dst, src);
    }

    #[test]
    fn pipelined_pattern_through_double_buffer() {
        // Simulate the canonical compute-overlap-copy loop: a 16-byte
        // source is split into four 4-byte tiles. While "compute" reads
        // the current tile, the next tile is copied into the inactive
        // slot; the step boundary waits on that copy and swaps.
        let source: Vec<u8> = (0..16u8).collect();
        let tile_bytes = 4;
        let mut db = DoubleBuffer::new(vec![0u8; tile_bytes], vec![0u8; tile_bytes]);
        let mut engine = SyncCopy::new();

        // Prime: load tile 0 into the *current* slot.
        let t0 =
            unsafe { engine.issue(source.as_ptr(), db.current_mut().as_mut_ptr(), tile_bytes) };
        engine.wait(t0);

        // Record every tile the compute step observes, in order.
        let mut seen: Vec<Vec<u8>> = Vec::new();
        let mut tile_idx = 1usize;
        while tile_idx * tile_bytes < source.len() {
            // Issue copy for the next tile into the inactive slot.
            let t = unsafe {
                engine.issue(
                    source.as_ptr().add(tile_idx * tile_bytes),
                    db.next_mut().as_mut_ptr(),
                    tile_bytes,
                )
            };
            // Compute on the current tile.
            seen.push(db.current().clone());
            // Step boundary.
            engine.wait(t);
            db.swap();
            tile_idx += 1;
        }
        // Drain the last tile.
        seen.push(db.current().clone());

        let expected: Vec<Vec<u8>> = source.chunks(tile_bytes).map(<[u8]>::to_vec).collect();
        assert_eq!(seen, expected);

        // Sum of 0..16 = 120.
        let total: u64 = seen.iter().flatten().map(|&b| u64::from(b)).sum();
        assert_eq!(total, (0..16u64).sum::<u64>());
    }
}