rlx_wgpu/buffer.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Buffer arena for the wgpu backend. Mirrors the rlx-metal arena
17//! shape: pre-plan one big storage buffer at compile time, sub-allocate
18//! per-node offsets at known positions, treat I/O as `write_buffer` /
19//! `read_buffer` against those offsets.
20//!
21//! wgpu's storage buffers are fine for both reads and writes from
22//! compute shaders; there's no shared-memory requirement at the API
23//! level (unlike Metal where `StorageModeShared` matters). On Apple
24//! Silicon wgpu's Metal backend gives us unified memory automatically.
25
26use rlx_ir::{Graph, NodeId};
27use rlx_opt::memory::MemoryPlan;
28use std::collections::HashMap;
29
/// Exclusive byte end, inside the f16 shadow buffer, of the upload for a
/// slot whose f32 payload starts at `f32_byte_offset` and spans
/// `f32_byte_len` bytes.
///
/// Every f32 element maps to one f16 element, so both the start offset and
/// the payload length halve. The payload is then rounded up to a multiple of
/// four bytes because `queue.write_buffer` only accepts 4-byte-aligned sizes;
/// `write_f32` zero-pads odd element counts by one half to match.
fn f16_shadow_write_end(f32_byte_offset: usize, f32_byte_len: usize) -> usize {
    let elem_count = f32_byte_len / 4;
    let upload_bytes = (elem_count * 2).next_multiple_of(4);
    f32_byte_offset / 2 + upload_bytes
}
40
41/// Size the f16 side buffer so every planned slot's padded upload fits.
42fn f16_shadow_arena_size(plan: &MemoryPlan) -> usize {
43 plan.assignments
44 .values()
45 .map(|a| f16_shadow_write_end(a.offset, a.size))
46 .max()
47 .unwrap_or(0)
48 .max(1)
49}
50
51/// One contiguous arena buffer + per-node byte offsets. Lives for the
52/// entire executable graph's lifetime.
53pub struct Arena {
54 /// Underlying GPU buffer. Bound as a single STORAGE_READ_WRITE
55 /// resource for every kernel; offsets disambiguate per-node access.
56 pub buffer: wgpu::Buffer,
57 /// Optional shadow buffer holding f16 versions of every value
58 /// written via `write_f32`. Sized at half the arena byte budget
59 /// (each f32 element pairs with an f16 element at the same logical
60 /// index — i.e. f16_off = f32_off / 2). Created only when the
61 /// device exposes the `SHADER_F16` feature; matmul kernels with
62 /// f16-typed B input bind both `buffer` (for f32 activations) and
63 /// `f16_buffer` (for f16 weights). Halves global memory traffic
64 /// on the dominant matmul reads.
65 pub f16_buffer: Option<wgpu::Buffer>,
66 /// Per-node byte offset into `buffer`.
67 pub offsets: HashMap<NodeId, usize>,
68 /// Per-node byte length.
69 pub lens: HashMap<NodeId, usize>,
70 /// Total arena size in bytes.
71 pub size: usize,
72}
73
74/// Plan memory using f32-sized slots regardless of declared IR dtype,
75/// with liveness-aware slot reuse (see `rlx_compile::memory::plan_memory_f32_uniform`).
76pub fn plan_f32_uniform(graph: &Graph, align: usize) -> MemoryPlan {
77 rlx_compile::memory::plan_memory_f32_uniform(graph, align)
78}
79
80impl Arena {
81 /// Build an arena from a memory plan. Allocates one big buffer
82 /// sized to fit every node's offset+length.
83 pub fn from_plan(device: &wgpu::Device, plan: &MemoryPlan) -> Self {
84 let size = plan.arena_size.max(1); // wgpu hates zero-sized allocs
85 // WebGPU caps each storage binding at `max_storage_buffer_binding_size`
86 // (128 MiB on most adapters). Liveness-aware `plan_memory_f32_uniform`
87 // keeps typical UMAP `[n,n]` graphs under this; if not, fail early.
88 let max_binding = device.limits().max_storage_buffer_binding_size;
89 if (size as u64) > max_binding {
90 panic!(
91 "rlx-wgpu: planned arena size {} bytes ({:.3} GiB) exceeds \
92 max_storage_buffer_binding_size {} bytes ({:.3} GiB). \
93 Reduce batch/sequence size or split the graph.",
94 size,
95 size as f64 / (1u64 << 30) as f64,
96 max_binding,
97 max_binding as f64 / (1u64 << 30) as f64
98 );
99 }
100 let buffer = device.create_buffer(&wgpu::BufferDescriptor {
101 label: Some("rlx-wgpu arena"),
102 size: size as u64,
103 usage: wgpu::BufferUsages::STORAGE
104 | wgpu::BufferUsages::COPY_SRC
105 | wgpu::BufferUsages::COPY_DST,
106 mapped_at_creation: false,
107 });
108 // Mirror f16 shadow buffer: half the byte size since each f32
109 // slot maps to an f16 slot at the same logical element index.
110 // Add per-slot COPY_BUFFER_ALIGNMENT padding (see
111 // `f16_shadow_write_end`).
112 let f16_buffer = if device.features().contains(wgpu::Features::SHADER_F16) {
113 let f16_size = f16_shadow_arena_size(plan);
114 Some(device.create_buffer(&wgpu::BufferDescriptor {
115 label: Some("rlx-wgpu arena f16"),
116 size: f16_size as u64,
117 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
118 mapped_at_creation: false,
119 }))
120 } else {
121 None
122 };
123 // `offsets` map to slot start (16-byte aligned). `lens` map to
124 // ACTUAL data length (elems * 4) — distinct from the slot size,
125 // which may include alignment padding. Readback uses lens so a
126 // [5] f32 returns 5 elements, not the 8 that fit in a 32-byte
127 // padded slot.
128 let mut offsets = HashMap::with_capacity(plan.assignments.len());
129 let mut lens = HashMap::with_capacity(plan.assignments.len());
130 for (id, a) in &plan.assignments {
131 offsets.insert(*id, a.offset);
132 // Default to the slot size; backends may override via
133 // set_actual_len for nodes whose elem count differs.
134 lens.insert(*id, a.size);
135 }
136 Self {
137 buffer,
138 f16_buffer,
139 offsets,
140 lens,
141 size,
142 }
143 }
144
145 pub fn has(&self, id: NodeId) -> bool {
146 self.offsets.contains_key(&id)
147 }
148 pub fn offset(&self, id: NodeId) -> usize {
149 self.offsets[&id]
150 }
151 pub fn len_of(&self, id: NodeId) -> usize {
152 self.lens[&id]
153 }
154
155 /// Override the actual data length (in bytes) for a node. The
156 /// backend calls this after planning to record true elem*4 sizes
157 /// instead of the alignment-padded slot sizes.
158 pub fn set_actual_len(&mut self, id: NodeId, bytes: usize) {
159 self.lens.insert(id, bytes);
160 }
161
162 /// Write f32 data into the node's slot. The queue performs an
163 /// async transfer; subsequent kernel dispatches on the same queue
164 /// see the new bytes. When the device supports SHADER_F16, also
165 /// downcasts and writes the same data into the f16 shadow buffer
166 /// at offset `f32_offset / 2` — so matmul kernels with f16 weight
167 /// bindings can read directly from there at half the bandwidth.
168 pub fn write_f32(&self, queue: &wgpu::Queue, id: NodeId, data: &[f32]) {
169 let off = self.offset(id);
170 let bytes: &[u8] = bytemuck::cast_slice(data);
171 queue.write_buffer(&self.buffer, off as u64, bytes);
172 if let Some(f16_buf) = &self.f16_buffer {
173 // wgpu requires queue.write_buffer to use 4-byte-aligned
174 // sizes (`COPY_BUFFER_ALIGNMENT`). f16 is 2 bytes; an odd
175 // element count yields a non-aligned byte length. Pad with
176 // a zero half so the byte count is always even.
177 let mut f16_data: Vec<half::f16> =
178 data.iter().map(|&v| half::f16::from_f32(v)).collect();
179 if !f16_data.len().is_multiple_of(2) {
180 f16_data.push(half::f16::from_f32(0.0));
181 }
182 let f16_bytes: &[u8] = unsafe {
183 std::slice::from_raw_parts(f16_data.as_ptr() as *const u8, f16_data.len() * 2)
184 };
185 queue.write_buffer(f16_buf, (off / 2) as u64, f16_bytes);
186 }
187 }
188
189 /// Read a node's bytes back to host f32 via a staging buffer +
190 /// blocking map. Used by `run()` for output extraction.
191 pub fn read_f32(&self, device: &wgpu::Device, queue: &wgpu::Queue, id: NodeId) -> Vec<f32> {
192 let off = self.offset(id);
193 let len = self.len_of(id);
194 let n_elems = len / 4;
195 if n_elems == 0 {
196 return Vec::new();
197 }
198
199 let staging = device.create_buffer(&wgpu::BufferDescriptor {
200 label: Some("rlx-wgpu readback"),
201 size: len as u64,
202 usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
203 mapped_at_creation: false,
204 });
205 let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
206 label: Some("rlx-wgpu readback enc"),
207 });
208 enc.copy_buffer_to_buffer(&self.buffer, off as u64, &staging, 0, len as u64);
209 queue.submit(std::iter::once(enc.finish()));
210
211 let slice = staging.slice(..);
212 let (sender, receiver) = std::sync::mpsc::channel();
213 slice.map_async(wgpu::MapMode::Read, move |r| {
214 let _ = sender.send(r);
215 });
216 let _ = device.poll(wgpu::PollType::wait_indefinitely());
217 receiver.recv().unwrap().unwrap();
218
219 let view = slice.get_mapped_range();
220 let out: Vec<f32> = bytemuck::cast_slice::<u8, f32>(&view).to_vec();
221 drop(view);
222 staging.unmap();
223 out
224 }
225
226 /// Read a byte range from the arena (used for packed GGUF weights).
227 pub fn read_bytes_range(
228 &self,
229 device: &wgpu::Device,
230 queue: &wgpu::Queue,
231 byte_off: usize,
232 len: usize,
233 ) -> Vec<u8> {
234 if len == 0 {
235 return Vec::new();
236 }
237 let staging = device.create_buffer(&wgpu::BufferDescriptor {
238 label: Some("rlx-wgpu readback bytes"),
239 size: len as u64,
240 usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
241 mapped_at_creation: false,
242 });
243 let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
244 label: Some("rlx-wgpu readback bytes enc"),
245 });
246 enc.copy_buffer_to_buffer(&self.buffer, byte_off as u64, &staging, 0, len as u64);
247 queue.submit(std::iter::once(enc.finish()));
248
249 let slice = staging.slice(..);
250 let (sender, receiver) = std::sync::mpsc::channel();
251 slice.map_async(wgpu::MapMode::Read, move |r| {
252 let _ = sender.send(r);
253 });
254 let _ = device.poll(wgpu::PollType::wait_indefinitely());
255 receiver.recv().unwrap().unwrap();
256
257 let view = slice.get_mapped_range();
258 let out = view.to_vec();
259 drop(view);
260 staging.unmap();
261 out
262 }
263
264 /// Write raw bytes into the arena at `byte_off`.
265 pub fn write_bytes_range(&self, queue: &wgpu::Queue, byte_off: usize, data: &[u8]) {
266 if data.is_empty() {
267 return;
268 }
269 queue.write_buffer(&self.buffer, byte_off as u64, data);
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::NodeId;
    use rlx_opt::memory::{BufferSlot, MemoryPlan};
    use std::collections::HashMap;

    #[test]
    fn f16_shadow_arena_accounts_for_copy_alignment_padding() {
        // A 12-byte f32 slot at offset 32 holds three elements. Its f16
        // shadow starts at byte 16 and carries six payload bytes. Those are
        // padded to eight for wgpu COPY_BUFFER_ALIGNMENT, ending at byte 24.
        // Sizing the shadow as `arena_size / 2` (44 / 2 = 22) would fall
        // two bytes short.
        let slot = BufferSlot {
            offset: 32,
            size: 12,
        };
        let plan = MemoryPlan {
            arena_size: 44,
            assignments: HashMap::from([(NodeId(0), slot)]),
            schedule: vec![],
        };
        let expected_end = 24;
        assert_eq!(f16_shadow_arena_size(&plan), expected_end);
    }
}