Skip to main content

rlx_wgpu/
buffer.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Buffer arena for the wgpu backend. Mirrors the rlx-metal arena
17//! shape: pre-plan one big storage buffer at compile time, sub-allocate
18//! per-node offsets at known positions, treat I/O as `write_buffer` /
19//! `read_buffer` against those offsets.
20//!
21//! wgpu's storage buffers are fine for both reads and writes from
22//! compute shaders; there's no shared-memory requirement at the API
23//! level (unlike Metal where `StorageModeShared` matters). On Apple
24//! Silicon wgpu's Metal backend gives us unified memory automatically.
25
26use rlx_ir::{Graph, NodeId};
27use rlx_opt::memory::MemoryPlan;
28use std::collections::HashMap;
29
/// Exclusive end byte, inside the f16 shadow buffer, of the upload that
/// mirrors an f32 slot starting at `f32_byte_offset` and holding
/// `f32_byte_len` bytes of f32 payload.
///
/// The mirror starts at half the f32 offset and stores two bytes per
/// f32 element. wgpu requires `queue.write_buffer` sizes to be 4-byte
/// aligned, so the f16 byte count is rounded up to a multiple of four
/// (odd element counts are zero-padded by two bytes in `write_f32`).
fn f16_shadow_write_end(f32_byte_offset: usize, f32_byte_len: usize) -> usize {
    let elems = f32_byte_len / 4;
    let payload = elems * 2;
    // Round the f16 payload up to the next multiple of four bytes.
    let upload_len = match payload % 4 {
        0 => payload,
        rem => payload + (4 - rem),
    };
    f32_byte_offset / 2 + upload_len
}
40
41/// Size the f16 side buffer so every planned slot's padded upload fits.
42fn f16_shadow_arena_size(plan: &MemoryPlan) -> usize {
43    plan.assignments
44        .values()
45        .map(|a| f16_shadow_write_end(a.offset, a.size))
46        .max()
47        .unwrap_or(0)
48        .max(1)
49}
50
/// One contiguous arena buffer + per-node byte offsets. Lives for the
/// entire executable graph's lifetime.
pub struct Arena {
    /// Underlying GPU buffer. Bound as a single STORAGE_READ_WRITE
    /// resource for every kernel; offsets disambiguate per-node access.
    pub buffer: wgpu::Buffer,
    /// Optional shadow buffer holding f16 versions of every value
    /// written via `write_f32`. Each f32 element pairs with an f16
    /// element at the same logical index — i.e. f16_off = f32_off / 2.
    /// The buffer is sized so every planned slot's 4-byte-padded f16
    /// upload fits, or capped at `max_storage_buffer_binding_size`
    /// when the arena is larger than one binding window. Created only
    /// when the device exposes the `SHADER_F16` feature and the
    /// `RLX_WGPU_NO_F16_SHADOW` flag is not set; matmul kernels with
    /// f16-typed B input bind both `buffer` (for f32 activations) and
    /// `f16_buffer` (for f16 weights). Halves global memory traffic
    /// on the dominant matmul reads.
    pub f16_buffer: Option<wgpu::Buffer>,
    /// Per-node byte offset into `buffer`.
    pub offsets: HashMap<NodeId, usize>,
    /// Per-node byte length. Initialised to the planned slot size and
    /// overridable per node through `set_actual_len`.
    pub lens: HashMap<NodeId, usize>,
    /// Total arena size in bytes.
    pub size: usize,
    /// Byte offset of the tail scratch zone (also `size - scratch_bytes`).
    /// Set when callers request scratch via `from_plan_with_scratch`;
    /// 0 otherwise. Reusable across ops since scratch is temporary —
    /// only one op writes to it at a time within a schedule.
    pub scratch_off: usize,
    /// Size in bytes of the tail scratch zone (0 when not used).
    pub scratch_bytes: usize,
}
80
81/// Plan memory using f32-sized slots regardless of declared IR dtype,
82/// with liveness-aware slot reuse (see `rlx_compile::memory::plan_memory_f32_uniform`).
83pub fn plan_f32_uniform(graph: &Graph, align: usize) -> MemoryPlan {
84    rlx_compile::memory::plan_memory_f32_uniform(graph, align)
85}
86
87impl Arena {
88    /// Build an arena from a memory plan with an extra tail scratch zone
89    /// of `scratch_bytes` reserved past the plan's arena_size. Useful for
90    /// ops that need throwaway temp storage that doesn't fit in a
91    /// workgroup-shared variable.
92    pub fn from_plan_with_scratch(
93        device: &wgpu::Device,
94        plan: &MemoryPlan,
95        scratch_bytes: usize,
96    ) -> Self {
97        let mut arena = Self::from_plan(device, plan);
98        if scratch_bytes == 0 {
99            return arena;
100        }
101        // Round up to 16 for storage-binding alignment.
102        let scratch_aligned = scratch_bytes.div_ceil(16) * 16;
103        let new_size = plan.arena_size + scratch_aligned;
104        let max_buf = device.limits().max_buffer_size;
105        if (new_size as u64) > max_buf {
106            panic!(
107                "rlx-wgpu: arena+scratch {} bytes exceeds max_buffer_size {}",
108                new_size, max_buf
109            );
110        }
111        let buffer = device.create_buffer(&wgpu::BufferDescriptor {
112            label: Some("rlx-wgpu arena+scratch"),
113            size: new_size as u64,
114            usage: wgpu::BufferUsages::STORAGE
115                | wgpu::BufferUsages::COPY_SRC
116                | wgpu::BufferUsages::COPY_DST,
117            mapped_at_creation: false,
118        });
119        // Drop the placeholder buffer (the smaller one from from_plan)
120        // by replacing it; the wgpu Buffer destructor frees the old one.
121        arena.buffer = buffer;
122        arena.size = new_size;
123        arena.scratch_off = plan.arena_size;
124        arena.scratch_bytes = scratch_aligned;
125        arena
126    }
127
128    /// Build an arena from a memory plan. Allocates one big buffer
129    /// sized to fit every node's offset+length.
130    pub fn from_plan(device: &wgpu::Device, plan: &MemoryPlan) -> Self {
131        let size = plan.arena_size.max(1); // wgpu hates zero-sized allocs
132        // Note: WebGPU caps each *binding range* at `max_storage_buffer_binding_size` (often 4 GiB
133        // on native backends). We handle that at dispatch time by binding a per-kernel window.
134        let max_buf = device.limits().max_buffer_size;
135        if (size as u64) > max_buf {
136            panic!(
137                "rlx-wgpu: planned arena size {} bytes ({:.3} GiB) exceeds max_buffer_size {} bytes ({:.3} GiB)",
138                size,
139                size as f64 / (1u64 << 30) as f64,
140                max_buf,
141                max_buf as f64 / (1u64 << 30) as f64
142            );
143        }
144        let buffer = device.create_buffer(&wgpu::BufferDescriptor {
145            label: Some("rlx-wgpu arena"),
146            size: size as u64,
147            usage: wgpu::BufferUsages::STORAGE
148                | wgpu::BufferUsages::COPY_SRC
149                | wgpu::BufferUsages::COPY_DST,
150            mapped_at_creation: false,
151        });
152        // Mirror f16 shadow buffer: half the byte size since each f32
153        // slot maps to an f16 slot at the same logical element index.
154        // On arenas larger than one bind window, allocate a capped f16
155        // buffer (≤ max_storage_buffer_binding_size) so matmul can use
156        // f16_weight_bind_range instead of staging multi‑GiB weights.
157        let max_binding = device.limits().max_storage_buffer_binding_size as usize;
158        let f16_buffer = if device.features().contains(wgpu::Features::SHADER_F16)
159            && !rlx_ir::env::flag("RLX_WGPU_NO_F16_SHADOW")
160        {
161            let f16_size = if size <= max_binding {
162                f16_shadow_arena_size(plan)
163            } else {
164                max_binding
165            };
166            Some(device.create_buffer(&wgpu::BufferDescriptor {
167                label: Some("rlx-wgpu arena f16"),
168                size: f16_size as u64,
169                usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
170                mapped_at_creation: false,
171            }))
172        } else {
173            None
174        };
175        // `offsets` map to slot start (16-byte aligned). `lens` map to
176        // ACTUAL data length (elems * 4) — distinct from the slot size,
177        // which may include alignment padding. Readback uses lens so a
178        // [5] f32 returns 5 elements, not the 8 that fit in a 32-byte
179        // padded slot.
180        let mut offsets = HashMap::with_capacity(plan.assignments.len());
181        let mut lens = HashMap::with_capacity(plan.assignments.len());
182        for (id, a) in &plan.assignments {
183            offsets.insert(*id, a.offset);
184            // Default to the slot size; backends may override via
185            // set_actual_len for nodes whose elem count differs.
186            lens.insert(*id, a.size);
187        }
188        Self {
189            buffer,
190            f16_buffer,
191            offsets,
192            lens,
193            size,
194            scratch_off: 0,
195            scratch_bytes: 0,
196        }
197    }
198
    /// Whether `id` has a recorded byte offset in this arena.
    pub fn has(&self, id: NodeId) -> bool {
        self.offsets.contains_key(&id)
    }
    /// Byte offset of `id`'s slot within `buffer`.
    ///
    /// # Panics
    /// Panics if `id` has no entry in `offsets`; check with `has` first
    /// when the node may be absent.
    pub fn offset(&self, id: NodeId) -> usize {
        self.offsets[&id]
    }
    /// Recorded byte length for `id` (the planned slot size unless it
    /// was overridden through `set_actual_len`).
    ///
    /// # Panics
    /// Panics if `id` has no entry in `lens`.
    pub fn len_of(&self, id: NodeId) -> usize {
        self.lens[&id]
    }
208
209    /// Whether this node's f16 mirror fits in the capped f16 shadow buffer.
210    pub fn param_fits_f16_mirror(&self, id: NodeId) -> bool {
211        let Some(f16) = &self.f16_buffer else {
212            return false;
213        };
214        let f16_off = self.offset(id) / 2;
215        let f16_bytes = self.len_of(id) / 2;
216        f16_off.saturating_add(f16_bytes) <= f16.size() as usize
217    }
218
    /// Override the actual data length (in bytes) for a node. The
    /// backend calls this after planning to record true elem*4 sizes
    /// instead of the alignment-padded slot sizes. Inserts the entry if
    /// `id` had none; any previous value is discarded.
    pub fn set_actual_len(&mut self, id: NodeId, bytes: usize) {
        self.lens.insert(id, bytes);
    }
225
226    /// Write f32 data into the node's slot. The queue performs an
227    /// async transfer; subsequent kernel dispatches on the same queue
228    /// see the new bytes. When the device supports SHADER_F16, also
229    /// downcasts and writes the same data into the f16 shadow buffer
230    /// at offset `f32_offset / 2` — so matmul kernels with f16 weight
231    /// bindings can read directly from there at half the bandwidth.
232    pub fn write_f32(&self, queue: &wgpu::Queue, id: NodeId, data: &[f32]) {
233        let off = self.offset(id);
234        let bytes: &[u8] = bytemuck::cast_slice(data);
235        queue.write_buffer(&self.buffer, off as u64, bytes);
236        self.write_f16_shadow_at(queue, off, data);
237    }
238
239    /// Downcast host f32 data into the f16 shadow buffer at `id`'s slot.
240    /// Used when skipping redundant f32 `write_buffer` but CoopF16Vk still
241    /// needs a fresh f16 mirror (e.g. input upload hash cache hits).
242    pub fn write_f16_shadow(&self, queue: &wgpu::Queue, id: NodeId, data: &[f32]) {
243        self.write_f16_shadow_at(queue, self.offset(id), data);
244    }
245
246    fn write_f16_shadow_at(&self, queue: &wgpu::Queue, off: usize, data: &[f32]) {
247        if let Some(f16_buf) = &self.f16_buffer {
248            let f16_off = off / 2;
249            let mut f16_data: Vec<half::f16> =
250                data.iter().map(|&v| half::f16::from_f32(v)).collect();
251            if !f16_data.len().is_multiple_of(2) {
252                f16_data.push(half::f16::from_f32(0.0));
253            }
254            let f16_byte_len = f16_data.len() * 2;
255            if f16_off.saturating_add(f16_byte_len) > f16_buf.size() as usize {
256                return;
257            }
258            let f16_bytes: &[u8] =
259                unsafe { std::slice::from_raw_parts(f16_data.as_ptr() as *const u8, f16_byte_len) };
260            queue.write_buffer(f16_buf, f16_off as u64, f16_bytes);
261        }
262    }
263
264    /// Read a node's bytes back to host f32. Uses a fresh staging buffer;
265    /// hot paths should call [`read_f32_pooled`] with a reused [`ReadbackStaging`].
266    pub fn read_f32(&self, device: &wgpu::Device, queue: &wgpu::Queue, id: NodeId) -> Vec<f32> {
267        read_f32_pooled(self, device, queue, id, &mut None)
268    }
269
270    /// Read a byte range from the arena (used for packed GGUF weights).
271    pub fn read_bytes_range(
272        &self,
273        device: &wgpu::Device,
274        queue: &wgpu::Queue,
275        byte_off: usize,
276        len: usize,
277    ) -> Vec<u8> {
278        if len == 0 {
279            return Vec::new();
280        }
281        let staging = device.create_buffer(&wgpu::BufferDescriptor {
282            label: Some("rlx-wgpu readback bytes"),
283            size: len as u64,
284            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
285            mapped_at_creation: false,
286        });
287        let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
288            label: Some("rlx-wgpu readback bytes enc"),
289        });
290        enc.copy_buffer_to_buffer(&self.buffer, byte_off as u64, &staging, 0, len as u64);
291        queue.submit(std::iter::once(enc.finish()));
292
293        let slice = staging.slice(..);
294        let (sender, receiver) = std::sync::mpsc::channel();
295        slice.map_async(wgpu::MapMode::Read, move |r| {
296            let _ = sender.send(r);
297        });
298        let _ = device.poll(wgpu::PollType::wait_indefinitely());
299        receiver.recv().unwrap().unwrap();
300
301        let view = slice.get_mapped_range();
302        let out = view.to_vec();
303        drop(view);
304        staging.unmap();
305        out
306    }
307
308    /// Write raw bytes into the arena at `byte_off`.
309    pub fn write_bytes_range(&self, queue: &wgpu::Queue, byte_off: usize, data: &[u8]) {
310        if data.is_empty() {
311            return;
312        }
313        queue.write_buffer(&self.buffer, byte_off as u64, data);
314    }
315}
316
/// Reusable MAP_READ staging buffer for output readback. Created or
/// grown through [`ReadbackStaging::prepare`].
pub struct ReadbackStaging {
    /// Staging buffer created with `COPY_DST | MAP_READ` usage.
    buffer: wgpu::Buffer,
    /// Size of `buffer` in bytes.
    capacity: usize,
}
322
/// Fixed 256 B MAP_READ staging for scalar (≤16 B) readback — avoids
/// `map_buffer_on_submit` + full-layout decode on MoltenVK hot paths.
pub struct TinyReadbackStaging {
    /// Staging buffer of `TinyReadbackStaging::CAPACITY` bytes, created
    /// with `COPY_DST | MAP_READ` usage.
    buffer: wgpu::Buffer,
}
328
329impl TinyReadbackStaging {
330    const CAPACITY: u64 = 256;
331
332    pub fn new(device: &wgpu::Device) -> Self {
333        Self {
334            buffer: device.create_buffer(&wgpu::BufferDescriptor {
335                label: Some("rlx-wgpu tiny readback"),
336                size: Self::CAPACITY,
337                usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
338                mapped_at_creation: false,
339            }),
340        }
341    }
342
343    pub fn buffer(&self) -> &wgpu::Buffer {
344        &self.buffer
345    }
346}
347
348/// True when fused readback can use the tiny scalar fast path.
349pub fn use_tiny_readback(layout: &ReadbackLayout, num_outputs: usize) -> bool {
350    num_outputs == 1 && layout.total_bytes <= 16
351}
352
353/// After submit: decode one f32 vector from an already-mapped tiny staging buffer.
354pub fn decode_tiny_mapped_f32(staging: &wgpu::Buffer, len: usize) -> Vec<f32> {
355    let len = len.max(4);
356    let slice = staging.slice(..len as u64);
357    let view = slice.get_mapped_range();
358    let out = bytemuck::cast_slice::<u8, f32>(&view[..len]).to_vec();
359    drop(view);
360    staging.unmap();
361    out
362}
363
364/// After submit: map only `len` bytes and decode one f32 vector.
365pub fn read_tiny_f32_after_submit(
366    device: &wgpu::Device,
367    staging: &wgpu::Buffer,
368    len: usize,
369) -> Vec<f32> {
370    let len = len.max(4);
371    let slice = staging.slice(..len as u64);
372    let (sender, receiver) = std::sync::mpsc::channel();
373    slice.map_async(wgpu::MapMode::Read, move |r| {
374        let _ = sender.send(r);
375    });
376    wait_readback_map(device, &receiver, len);
377    receiver.recv().unwrap().unwrap();
378    decode_tiny_mapped_f32(staging, len)
379}
380
381impl ReadbackStaging {
382    pub(crate) fn buffer(&self) -> &wgpu::Buffer {
383        &self.buffer
384    }
385
386    fn ensure(&mut self, device: &wgpu::Device, min_bytes: usize) {
387        let need = min_bytes.max(256);
388        if self.capacity >= need {
389            return;
390        }
391        let cap = need.next_power_of_two().max(256);
392        self.buffer = device.create_buffer(&wgpu::BufferDescriptor {
393            label: Some("rlx-wgpu readback staging"),
394            size: cap as u64,
395            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
396            mapped_at_creation: false,
397        });
398        self.capacity = cap;
399    }
400
401    /// Grow-or-create staging for at least `min_bytes`.
402    pub fn prepare(device: &wgpu::Device, staging: &mut Option<Self>, min_bytes: usize) {
403        match staging {
404            Some(s) => s.ensure(device, min_bytes),
405            None => {
406                let cap = min_bytes.max(256).next_power_of_two();
407                *staging = Some(Self {
408                    buffer: device.create_buffer(&wgpu::BufferDescriptor {
409                        label: Some("rlx-wgpu readback staging"),
410                        size: cap as u64,
411                        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
412                        mapped_at_creation: false,
413                    }),
414                    capacity: cap,
415                });
416            }
417        }
418    }
419}
420
/// Round `n` up to the next multiple of four.
fn align4(n: usize) -> usize {
    match n % 4 {
        0 => n,
        rem => n + (4 - rem),
    }
}
424
/// Layout for batched output readback into a staging buffer.
#[derive(Debug, Clone)]
pub struct ReadbackLayout {
    /// One `(start, len)` pair per output, in request order: the byte
    /// offset of the output inside the staging buffer and its byte length.
    pub regions: Vec<(usize, usize)>,
    /// Number of staging bytes the layout spans (the range that gets
    /// mapped for decoding).
    pub total_bytes: usize,
}
431
432impl ReadbackLayout {
433    pub fn for_nodes(arena: &Arena, ids: &[NodeId]) -> Self {
434        if ids.is_empty() {
435            return Self {
436                regions: Vec::new(),
437                total_bytes: 0,
438            };
439        }
440        if ids.len() == 1 {
441            let len = arena.len_of(ids[0]);
442            return Self {
443                regions: vec![(0, len)],
444                total_bytes: len,
445            };
446        }
447        let mut regions = Vec::with_capacity(ids.len());
448        let mut total = 0usize;
449        for &id in ids {
450            let len = arena.len_of(id);
451            let start = total;
452            total = align4(start + len);
453            regions.push((start, len));
454        }
455        Self {
456            regions,
457            total_bytes: total,
458        }
459    }
460}
461
462/// Append arena→staging copies to an encoder (no submit).
463pub fn encode_readback_copies(
464    enc: &mut wgpu::CommandEncoder,
465    arena: &Arena,
466    staging: &wgpu::Buffer,
467    ids: &[NodeId],
468    layout: &ReadbackLayout,
469) {
470    for (&id, &(dst_off, len)) in ids.iter().zip(layout.regions.iter()) {
471        enc.copy_buffer_to_buffer(
472            &arena.buffer,
473            arena.offset(id) as u64,
474            staging,
475            dst_off as u64,
476            len as u64,
477        );
478    }
479}
480
/// Map staging after submit and decode f32 outputs (one poll).
///
/// Public entry point that forwards to the private
/// `map_readback_f32_after_submit`; returns one `Vec<f32>` per region
/// of `layout`, in order.
pub fn map_readback_f32(
    device: &wgpu::Device,
    staging: &wgpu::Buffer,
    layout: &ReadbackLayout,
) -> Vec<Vec<f32>> {
    map_readback_f32_after_submit(device, staging, layout)
}
489
490/// Poll until a readback map callback completes (fast path for tiny outputs).
491pub fn wait_readback_map(
492    device: &wgpu::Device,
493    _map_rx: &std::sync::mpsc::Receiver<Result<(), wgpu::BufferAsyncError>>,
494    total_bytes: usize,
495) {
496    let spins = if total_bytes <= 16 { 256 } else { 64 };
497    for _ in 0..spins {
498        let _ = device.poll(wgpu::PollType::Poll);
499    }
500    let _ = device.poll(wgpu::PollType::wait_indefinitely());
501}
502
503/// Schedule `map_async` on the encoder so mapping starts with submit (wgpu 29+).
504pub fn schedule_readback_map(
505    encoder: &mut wgpu::CommandEncoder,
506    staging: &wgpu::Buffer,
507    layout: &ReadbackLayout,
508) -> std::sync::mpsc::Receiver<Result<(), wgpu::BufferAsyncError>> {
509    let total = layout.total_bytes;
510    let (sender, receiver) = std::sync::mpsc::channel();
511    encoder.map_buffer_on_submit(staging, wgpu::MapMode::Read, 0..total as u64, move |r| {
512        let _ = sender.send(r);
513    });
514    receiver
515}
516
517fn map_readback_f32_after_submit(
518    device: &wgpu::Device,
519    staging: &wgpu::Buffer,
520    layout: &ReadbackLayout,
521) -> Vec<Vec<f32>> {
522    if layout.regions.is_empty() {
523        return Vec::new();
524    }
525    let total = layout.total_bytes;
526    let slice = staging.slice(..total as u64);
527    let (sender, receiver) = std::sync::mpsc::channel();
528    slice.map_async(wgpu::MapMode::Read, move |r| {
529        let _ = sender.send(r);
530    });
531    wait_readback_map(device, &receiver, total);
532    receiver.recv().unwrap().unwrap();
533
534    let view = slice.get_mapped_range();
535    let bytes = &view[..];
536    let mut outs = Vec::with_capacity(layout.regions.len());
537    for &(start, len) in &layout.regions {
538        let chunk = &bytes[start..start + len];
539        outs.push(bytemuck::cast_slice::<u8, f32>(chunk).to_vec());
540    }
541    drop(view);
542    staging.unmap();
543    outs
544}
545
546/// Decode f32 outputs after submit + map callback (used with [`schedule_readback_map`]).
547pub fn decode_mapped_readback_f32(
548    staging: &wgpu::Buffer,
549    layout: &ReadbackLayout,
550) -> Vec<Vec<f32>> {
551    if layout.regions.is_empty() {
552        return Vec::new();
553    }
554    let total = layout.total_bytes;
555    let slice = staging.slice(..total as u64);
556    let view = slice.get_mapped_range();
557    let bytes = &view[..];
558    let mut outs = Vec::with_capacity(layout.regions.len());
559    for &(start, len) in &layout.regions {
560        let chunk = &bytes[start..start + len];
561        outs.push(bytemuck::cast_slice::<u8, f32>(chunk).to_vec());
562    }
563    drop(view);
564    staging.unmap();
565    outs
566}
567
568/// Read one node via a reused staging buffer (one submit + one poll).
569pub fn read_f32_pooled(
570    arena: &Arena,
571    device: &wgpu::Device,
572    queue: &wgpu::Queue,
573    id: NodeId,
574    staging: &mut Option<ReadbackStaging>,
575) -> Vec<f32> {
576    let off = arena.offset(id);
577    let len = arena.len_of(id);
578    let n_elems = len / 4;
579    if n_elems == 0 {
580        return Vec::new();
581    }
582    ReadbackStaging::prepare(device, staging, len);
583    let staging = staging.as_ref().expect("staging");
584
585    let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
586        label: Some("rlx-wgpu readback enc"),
587    });
588    enc.copy_buffer_to_buffer(&arena.buffer, off as u64, &staging.buffer, 0, len as u64);
589    queue.submit(std::iter::once(enc.finish()));
590
591    let slice = staging.buffer.slice(..len as u64);
592    let (sender, receiver) = std::sync::mpsc::channel();
593    slice.map_async(wgpu::MapMode::Read, move |r| {
594        let _ = sender.send(r);
595    });
596    wait_readback_map(device, &receiver, len);
597    receiver.recv().unwrap().unwrap();
598
599    let view = slice.get_mapped_range();
600    let out: Vec<f32> = bytemuck::cast_slice::<u8, f32>(&view).to_vec();
601    drop(view);
602    staging.buffer.unmap();
603    out
604}
605
606/// Read several nodes with one submit + one poll (contiguous staging layout).
607pub fn read_f32_many_pooled(
608    arena: &Arena,
609    device: &wgpu::Device,
610    queue: &wgpu::Queue,
611    ids: &[NodeId],
612    staging: &mut Option<ReadbackStaging>,
613) -> Vec<Vec<f32>> {
614    if ids.is_empty() {
615        return Vec::new();
616    }
617    let layout = ReadbackLayout::for_nodes(arena, ids);
618    ReadbackStaging::prepare(device, staging, layout.total_bytes);
619    let staging_buf = staging.as_ref().expect("staging").buffer().clone();
620
621    let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
622        label: Some("rlx-wgpu readback batch enc"),
623    });
624    encode_readback_copies(&mut enc, arena, &staging_buf, ids, &layout);
625    queue.submit(std::iter::once(enc.finish()));
626    map_readback_f32(device, &staging_buf, &layout)
627}
628
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::NodeId;
    use rlx_opt::memory::{BufferSlot, MemoryPlan};
    use std::collections::HashMap;

    #[test]
    fn f16_shadow_arena_accounts_for_copy_alignment_padding() {
        // A slot of three f32 elements at byte 32 mirrors to six f16
        // bytes at byte 16, padded to eight for wgpu
        // COPY_BUFFER_ALIGNMENT — so the shadow must span 24 bytes. The
        // old `arena_size / 2` sizing (22) was two bytes short here.
        let slot = BufferSlot {
            offset: 32,
            size: 12,
        };
        let plan = MemoryPlan {
            arena_size: 44,
            assignments: HashMap::from([(NodeId(0), slot)]),
            schedule: vec![],
        };
        assert_eq!(f16_shadow_arena_size(&plan), 24);
    }
}