Skip to main content

rlx_wgpu/
buffer.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Buffer arena for the wgpu backend. Mirrors the rlx-metal arena
17//! shape: pre-plan one big storage buffer at compile time, sub-allocate
18//! per-node offsets at known positions, treat I/O as `write_buffer` /
19//! `read_buffer` against those offsets.
20//!
21//! wgpu's storage buffers are fine for both reads and writes from
22//! compute shaders; there's no shared-memory requirement at the API
23//! level (unlike Metal where `StorageModeShared` matters). On Apple
24//! Silicon wgpu's Metal backend gives us unified memory automatically.
25
26use rlx_ir::{Graph, NodeId};
27use rlx_opt::memory::MemoryPlan;
28use std::collections::HashMap;
29
/// Exclusive byte end of the f16 shadow upload for one arena slot.
///
/// The slot begins at `f32_byte_offset` in the f32 arena and carries
/// `f32_byte_len` bytes of f32 payload. Its f16 mirror starts at half that
/// offset and holds one 2-byte half per f32 element. `queue.write_buffer`
/// only accepts 4-byte-aligned sizes, so `write_f32` zero-pads odd element
/// counts by two bytes; the returned end includes that padding.
fn f16_shadow_write_end(f32_byte_offset: usize, f32_byte_len: usize) -> usize {
    const COPY_ALIGN: usize = 4;
    let element_count = f32_byte_len / std::mem::size_of::<f32>();
    let shadow_start = f32_byte_offset / 2;
    let shadow_payload = element_count * 2;
    shadow_start + shadow_payload.next_multiple_of(COPY_ALIGN)
}
40
41/// Size the f16 side buffer so every planned slot's padded upload fits.
42fn f16_shadow_arena_size(plan: &MemoryPlan) -> usize {
43    plan.assignments
44        .values()
45        .map(|a| f16_shadow_write_end(a.offset, a.size))
46        .max()
47        .unwrap_or(0)
48        .max(1)
49}
50
51/// One contiguous arena buffer + per-node byte offsets. Lives for the
52/// entire executable graph's lifetime.
53pub struct Arena {
54    /// Underlying GPU buffer. Bound as a single STORAGE_READ_WRITE
55    /// resource for every kernel; offsets disambiguate per-node access.
56    pub buffer: wgpu::Buffer,
57    /// Optional shadow buffer holding f16 versions of every value
58    /// written via `write_f32`. Sized at half the arena byte budget
59    /// (each f32 element pairs with an f16 element at the same logical
60    /// index — i.e. f16_off = f32_off / 2). Created only when the
61    /// device exposes the `SHADER_F16` feature; matmul kernels with
62    /// f16-typed B input bind both `buffer` (for f32 activations) and
63    /// `f16_buffer` (for f16 weights). Halves global memory traffic
64    /// on the dominant matmul reads.
65    pub f16_buffer: Option<wgpu::Buffer>,
66    /// Per-node byte offset into `buffer`.
67    pub offsets: HashMap<NodeId, usize>,
68    /// Per-node byte length.
69    pub lens: HashMap<NodeId, usize>,
70    /// Total arena size in bytes.
71    pub size: usize,
72}
73
74/// Plan memory using f32-sized slots regardless of declared IR dtype,
75/// with liveness-aware slot reuse (see `rlx_compile::memory::plan_memory_f32_uniform`).
76pub fn plan_f32_uniform(graph: &Graph, align: usize) -> MemoryPlan {
77    rlx_compile::memory::plan_memory_f32_uniform(graph, align)
78}
79
80impl Arena {
81    /// Build an arena from a memory plan. Allocates one big buffer
82    /// sized to fit every node's offset+length.
83    pub fn from_plan(device: &wgpu::Device, plan: &MemoryPlan) -> Self {
84        let size = plan.arena_size.max(1); // wgpu hates zero-sized allocs
85        // WebGPU caps each storage binding at `max_storage_buffer_binding_size`
86        // (128 MiB on most adapters). Liveness-aware `plan_memory_f32_uniform`
87        // keeps typical UMAP `[n,n]` graphs under this; if not, fail early.
88        let max_binding = device.limits().max_storage_buffer_binding_size;
89        if (size as u64) > max_binding {
90            panic!(
91                "rlx-wgpu: planned arena size {} bytes ({:.3} GiB) exceeds \
92                    max_storage_buffer_binding_size {} bytes ({:.3} GiB). \
93                    Reduce batch/sequence size or split the graph.",
94                size,
95                size as f64 / (1u64 << 30) as f64,
96                max_binding,
97                max_binding as f64 / (1u64 << 30) as f64
98            );
99        }
100        let buffer = device.create_buffer(&wgpu::BufferDescriptor {
101            label: Some("rlx-wgpu arena"),
102            size: size as u64,
103            usage: wgpu::BufferUsages::STORAGE
104                | wgpu::BufferUsages::COPY_SRC
105                | wgpu::BufferUsages::COPY_DST,
106            mapped_at_creation: false,
107        });
108        // Mirror f16 shadow buffer: half the byte size since each f32
109        // slot maps to an f16 slot at the same logical element index.
110        // Add per-slot COPY_BUFFER_ALIGNMENT padding (see
111        // `f16_shadow_write_end`).
112        let f16_buffer = if device.features().contains(wgpu::Features::SHADER_F16) {
113            let f16_size = f16_shadow_arena_size(plan);
114            Some(device.create_buffer(&wgpu::BufferDescriptor {
115                label: Some("rlx-wgpu arena f16"),
116                size: f16_size as u64,
117                usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
118                mapped_at_creation: false,
119            }))
120        } else {
121            None
122        };
123        // `offsets` map to slot start (16-byte aligned). `lens` map to
124        // ACTUAL data length (elems * 4) — distinct from the slot size,
125        // which may include alignment padding. Readback uses lens so a
126        // [5] f32 returns 5 elements, not the 8 that fit in a 32-byte
127        // padded slot.
128        let mut offsets = HashMap::with_capacity(plan.assignments.len());
129        let mut lens = HashMap::with_capacity(plan.assignments.len());
130        for (id, a) in &plan.assignments {
131            offsets.insert(*id, a.offset);
132            // Default to the slot size; backends may override via
133            // set_actual_len for nodes whose elem count differs.
134            lens.insert(*id, a.size);
135        }
136        Self {
137            buffer,
138            f16_buffer,
139            offsets,
140            lens,
141            size,
142        }
143    }
144
145    pub fn has(&self, id: NodeId) -> bool {
146        self.offsets.contains_key(&id)
147    }
148    pub fn offset(&self, id: NodeId) -> usize {
149        self.offsets[&id]
150    }
151    pub fn len_of(&self, id: NodeId) -> usize {
152        self.lens[&id]
153    }
154
155    /// Override the actual data length (in bytes) for a node. The
156    /// backend calls this after planning to record true elem*4 sizes
157    /// instead of the alignment-padded slot sizes.
158    pub fn set_actual_len(&mut self, id: NodeId, bytes: usize) {
159        self.lens.insert(id, bytes);
160    }
161
162    /// Write f32 data into the node's slot. The queue performs an
163    /// async transfer; subsequent kernel dispatches on the same queue
164    /// see the new bytes. When the device supports SHADER_F16, also
165    /// downcasts and writes the same data into the f16 shadow buffer
166    /// at offset `f32_offset / 2` — so matmul kernels with f16 weight
167    /// bindings can read directly from there at half the bandwidth.
168    pub fn write_f32(&self, queue: &wgpu::Queue, id: NodeId, data: &[f32]) {
169        let off = self.offset(id);
170        let bytes: &[u8] = bytemuck::cast_slice(data);
171        queue.write_buffer(&self.buffer, off as u64, bytes);
172        if let Some(f16_buf) = &self.f16_buffer {
173            // wgpu requires queue.write_buffer to use 4-byte-aligned
174            // sizes (`COPY_BUFFER_ALIGNMENT`). f16 is 2 bytes; an odd
175            // element count yields a non-aligned byte length. Pad with
176            // a zero half so the byte count is always even.
177            let mut f16_data: Vec<half::f16> =
178                data.iter().map(|&v| half::f16::from_f32(v)).collect();
179            if !f16_data.len().is_multiple_of(2) {
180                f16_data.push(half::f16::from_f32(0.0));
181            }
182            let f16_bytes: &[u8] = unsafe {
183                std::slice::from_raw_parts(f16_data.as_ptr() as *const u8, f16_data.len() * 2)
184            };
185            queue.write_buffer(f16_buf, (off / 2) as u64, f16_bytes);
186        }
187    }
188
189    /// Read a node's bytes back to host f32 via a staging buffer +
190    /// blocking map. Used by `run()` for output extraction.
191    pub fn read_f32(&self, device: &wgpu::Device, queue: &wgpu::Queue, id: NodeId) -> Vec<f32> {
192        let off = self.offset(id);
193        let len = self.len_of(id);
194        let n_elems = len / 4;
195        if n_elems == 0 {
196            return Vec::new();
197        }
198
199        let staging = device.create_buffer(&wgpu::BufferDescriptor {
200            label: Some("rlx-wgpu readback"),
201            size: len as u64,
202            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
203            mapped_at_creation: false,
204        });
205        let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
206            label: Some("rlx-wgpu readback enc"),
207        });
208        enc.copy_buffer_to_buffer(&self.buffer, off as u64, &staging, 0, len as u64);
209        queue.submit(std::iter::once(enc.finish()));
210
211        let slice = staging.slice(..);
212        let (sender, receiver) = std::sync::mpsc::channel();
213        slice.map_async(wgpu::MapMode::Read, move |r| {
214            let _ = sender.send(r);
215        });
216        let _ = device.poll(wgpu::PollType::wait_indefinitely());
217        receiver.recv().unwrap().unwrap();
218
219        let view = slice.get_mapped_range();
220        let out: Vec<f32> = bytemuck::cast_slice::<u8, f32>(&view).to_vec();
221        drop(view);
222        staging.unmap();
223        out
224    }
225
226    /// Read a byte range from the arena (used for packed GGUF weights).
227    pub fn read_bytes_range(
228        &self,
229        device: &wgpu::Device,
230        queue: &wgpu::Queue,
231        byte_off: usize,
232        len: usize,
233    ) -> Vec<u8> {
234        if len == 0 {
235            return Vec::new();
236        }
237        let staging = device.create_buffer(&wgpu::BufferDescriptor {
238            label: Some("rlx-wgpu readback bytes"),
239            size: len as u64,
240            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
241            mapped_at_creation: false,
242        });
243        let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
244            label: Some("rlx-wgpu readback bytes enc"),
245        });
246        enc.copy_buffer_to_buffer(&self.buffer, byte_off as u64, &staging, 0, len as u64);
247        queue.submit(std::iter::once(enc.finish()));
248
249        let slice = staging.slice(..);
250        let (sender, receiver) = std::sync::mpsc::channel();
251        slice.map_async(wgpu::MapMode::Read, move |r| {
252            let _ = sender.send(r);
253        });
254        let _ = device.poll(wgpu::PollType::wait_indefinitely());
255        receiver.recv().unwrap().unwrap();
256
257        let view = slice.get_mapped_range();
258        let out = view.to_vec();
259        drop(view);
260        staging.unmap();
261        out
262    }
263
264    /// Write raw bytes into the arena at `byte_off`.
265    pub fn write_bytes_range(&self, queue: &wgpu::Queue, byte_off: usize, data: &[u8]) {
266        if data.is_empty() {
267            return;
268        }
269        queue.write_buffer(&self.buffer, byte_off as u64, data);
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::NodeId;
    use rlx_opt::memory::{BufferSlot, MemoryPlan};
    use std::collections::HashMap;

    /// Build a slot at `offset` with `size` bytes of f32 payload.
    fn slot(offset: usize, size: usize) -> BufferSlot {
        BufferSlot { offset, size }
    }

    /// Build a plan with the given arena size and slot assignments.
    fn plan(arena_size: usize, assignments: HashMap<NodeId, BufferSlot>) -> MemoryPlan {
        MemoryPlan {
            arena_size,
            assignments,
            schedule: vec![],
        }
    }

    #[test]
    fn f16_shadow_arena_accounts_for_copy_alignment_padding() {
        // Three f32 elements → six f16 bytes, padded to eight for wgpu
        // COPY_BUFFER_ALIGNMENT. The old `arena_size / 2` sizing was two
        // bytes short at this slot boundary.
        let p = plan(44, HashMap::from([(NodeId(0), slot(32, 12))]));
        assert_eq!(f16_shadow_arena_size(&p), 24);
    }

    #[test]
    fn f16_shadow_arena_is_never_zero_sized() {
        // No assignments → floored at one byte for wgpu.
        let p = plan(0, HashMap::new());
        assert_eq!(f16_shadow_arena_size(&p), 1);
    }

    #[test]
    fn f16_shadow_arena_even_counts_need_no_padding() {
        // Four f32 elements at offset 16 → f16 bytes [8, 16).
        let p = plan(32, HashMap::from([(NodeId(0), slot(16, 16))]));
        assert_eq!(f16_shadow_arena_size(&p), 16);
    }

    #[test]
    fn f16_shadow_arena_takes_furthest_slot_end() {
        let p = plan(
            96,
            HashMap::from([
                (NodeId(0), slot(0, 16)),  // ends at 8
                (NodeId(1), slot(64, 20)), // 32 + pad(10) = 44
                (NodeId(2), slot(32, 8)),  // 16 + 4 = 20
            ]),
        );
        assert_eq!(f16_shadow_arena_size(&p), 44);
    }

    #[test]
    fn f16_shadow_write_end_pads_to_four_bytes() {
        assert_eq!(f16_shadow_write_end(0, 0), 0);
        assert_eq!(f16_shadow_write_end(0, 4), 4);
        assert_eq!(f16_shadow_write_end(0, 8), 4);
        assert_eq!(f16_shadow_write_end(32, 12), 24);
    }
}
300}