1use rlx_ir::{Graph, NodeId};
27use rlx_opt::memory::MemoryPlan;
28use std::collections::HashMap;
29
/// Returns the exclusive end offset, in bytes, of the f16 shadow region that
/// mirrors one f32 slot.
///
/// The slot starts at `f32_byte_offset` and spans `f32_byte_len` bytes in the
/// f32 arena. Its shadow starts at half that offset and holds two bytes per
/// whole f32 element, rounded up to a multiple of four bytes.
fn f16_shadow_write_end(f32_byte_offset: usize, f32_byte_len: usize) -> usize {
    let element_count = f32_byte_len / 4;
    let shadow_len = (element_count * 2).next_multiple_of(4);
    f32_byte_offset / 2 + shadow_len
}
40
41fn f16_shadow_arena_size(plan: &MemoryPlan) -> usize {
43 plan.assignments
44 .values()
45 .map(|a| f16_shadow_write_end(a.offset, a.size))
46 .max()
47 .unwrap_or(0)
48 .max(1)
49}
50
51pub struct Arena {
54 pub buffer: wgpu::Buffer,
57 pub f16_buffer: Option<wgpu::Buffer>,
66 pub offsets: HashMap<NodeId, usize>,
68 pub lens: HashMap<NodeId, usize>,
70 pub size: usize,
72 pub scratch_off: usize,
77 pub scratch_bytes: usize,
79}
80
81pub fn plan_f32_uniform(graph: &Graph, align: usize) -> MemoryPlan {
84 rlx_compile::memory::plan_memory_f32_uniform(graph, align)
85}
86
87impl Arena {
88 pub fn from_plan_with_scratch(
93 device: &wgpu::Device,
94 plan: &MemoryPlan,
95 scratch_bytes: usize,
96 ) -> Self {
97 let mut arena = Self::from_plan(device, plan);
98 if scratch_bytes == 0 {
99 return arena;
100 }
101 let scratch_aligned = scratch_bytes.div_ceil(16) * 16;
103 let new_size = plan.arena_size + scratch_aligned;
104 let max_buf = device.limits().max_buffer_size;
105 if (new_size as u64) > max_buf {
106 panic!(
107 "rlx-wgpu: arena+scratch {} bytes exceeds max_buffer_size {}",
108 new_size, max_buf
109 );
110 }
111 let buffer = device.create_buffer(&wgpu::BufferDescriptor {
112 label: Some("rlx-wgpu arena+scratch"),
113 size: new_size as u64,
114 usage: wgpu::BufferUsages::STORAGE
115 | wgpu::BufferUsages::COPY_SRC
116 | wgpu::BufferUsages::COPY_DST,
117 mapped_at_creation: false,
118 });
119 arena.buffer = buffer;
122 arena.size = new_size;
123 arena.scratch_off = plan.arena_size;
124 arena.scratch_bytes = scratch_aligned;
125 arena
126 }
127
128 pub fn from_plan(device: &wgpu::Device, plan: &MemoryPlan) -> Self {
131 let size = plan.arena_size.max(1); let max_buf = device.limits().max_buffer_size;
135 if (size as u64) > max_buf {
136 panic!(
137 "rlx-wgpu: planned arena size {} bytes ({:.3} GiB) exceeds max_buffer_size {} bytes ({:.3} GiB)",
138 size,
139 size as f64 / (1u64 << 30) as f64,
140 max_buf,
141 max_buf as f64 / (1u64 << 30) as f64
142 );
143 }
144 let buffer = device.create_buffer(&wgpu::BufferDescriptor {
145 label: Some("rlx-wgpu arena"),
146 size: size as u64,
147 usage: wgpu::BufferUsages::STORAGE
148 | wgpu::BufferUsages::COPY_SRC
149 | wgpu::BufferUsages::COPY_DST,
150 mapped_at_creation: false,
151 });
152 let max_binding = device.limits().max_storage_buffer_binding_size as usize;
158 let f16_buffer = if device.features().contains(wgpu::Features::SHADER_F16)
159 && !rlx_ir::env::flag("RLX_WGPU_NO_F16_SHADOW")
160 {
161 let f16_size = if size <= max_binding {
162 f16_shadow_arena_size(plan)
163 } else {
164 max_binding
165 };
166 Some(device.create_buffer(&wgpu::BufferDescriptor {
167 label: Some("rlx-wgpu arena f16"),
168 size: f16_size as u64,
169 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
170 mapped_at_creation: false,
171 }))
172 } else {
173 None
174 };
175 let mut offsets = HashMap::with_capacity(plan.assignments.len());
181 let mut lens = HashMap::with_capacity(plan.assignments.len());
182 for (id, a) in &plan.assignments {
183 offsets.insert(*id, a.offset);
184 lens.insert(*id, a.size);
187 }
188 Self {
189 buffer,
190 f16_buffer,
191 offsets,
192 lens,
193 size,
194 scratch_off: 0,
195 scratch_bytes: 0,
196 }
197 }
198
199 pub fn has(&self, id: NodeId) -> bool {
200 self.offsets.contains_key(&id)
201 }
202 pub fn offset(&self, id: NodeId) -> usize {
203 self.offsets[&id]
204 }
205 pub fn len_of(&self, id: NodeId) -> usize {
206 self.lens[&id]
207 }
208
209 pub fn param_fits_f16_mirror(&self, id: NodeId) -> bool {
211 let Some(f16) = &self.f16_buffer else {
212 return false;
213 };
214 let f16_off = self.offset(id) / 2;
215 let f16_bytes = self.len_of(id) / 2;
216 f16_off.saturating_add(f16_bytes) <= f16.size() as usize
217 }
218
219 pub fn set_actual_len(&mut self, id: NodeId, bytes: usize) {
223 self.lens.insert(id, bytes);
224 }
225
226 pub fn write_f32(&self, queue: &wgpu::Queue, id: NodeId, data: &[f32]) {
233 let off = self.offset(id);
234 let bytes: &[u8] = bytemuck::cast_slice(data);
235 queue.write_buffer(&self.buffer, off as u64, bytes);
236 self.write_f16_shadow_at(queue, off, data);
237 }
238
239 pub fn write_f16_shadow(&self, queue: &wgpu::Queue, id: NodeId, data: &[f32]) {
243 self.write_f16_shadow_at(queue, self.offset(id), data);
244 }
245
246 fn write_f16_shadow_at(&self, queue: &wgpu::Queue, off: usize, data: &[f32]) {
247 if let Some(f16_buf) = &self.f16_buffer {
248 let f16_off = off / 2;
249 let mut f16_data: Vec<half::f16> =
250 data.iter().map(|&v| half::f16::from_f32(v)).collect();
251 if !f16_data.len().is_multiple_of(2) {
252 f16_data.push(half::f16::from_f32(0.0));
253 }
254 let f16_byte_len = f16_data.len() * 2;
255 if f16_off.saturating_add(f16_byte_len) > f16_buf.size() as usize {
256 return;
257 }
258 let f16_bytes: &[u8] =
259 unsafe { std::slice::from_raw_parts(f16_data.as_ptr() as *const u8, f16_byte_len) };
260 queue.write_buffer(f16_buf, f16_off as u64, f16_bytes);
261 }
262 }
263
264 pub fn read_f32(&self, device: &wgpu::Device, queue: &wgpu::Queue, id: NodeId) -> Vec<f32> {
267 read_f32_pooled(self, device, queue, id, &mut None)
268 }
269
270 pub fn read_bytes_range(
272 &self,
273 device: &wgpu::Device,
274 queue: &wgpu::Queue,
275 byte_off: usize,
276 len: usize,
277 ) -> Vec<u8> {
278 if len == 0 {
279 return Vec::new();
280 }
281 let staging = device.create_buffer(&wgpu::BufferDescriptor {
282 label: Some("rlx-wgpu readback bytes"),
283 size: len as u64,
284 usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
285 mapped_at_creation: false,
286 });
287 let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
288 label: Some("rlx-wgpu readback bytes enc"),
289 });
290 enc.copy_buffer_to_buffer(&self.buffer, byte_off as u64, &staging, 0, len as u64);
291 queue.submit(std::iter::once(enc.finish()));
292
293 let slice = staging.slice(..);
294 let (sender, receiver) = std::sync::mpsc::channel();
295 slice.map_async(wgpu::MapMode::Read, move |r| {
296 let _ = sender.send(r);
297 });
298 let _ = device.poll(wgpu::PollType::wait_indefinitely());
299 receiver.recv().unwrap().unwrap();
300
301 let view = slice.get_mapped_range();
302 let out = view.to_vec();
303 drop(view);
304 staging.unmap();
305 out
306 }
307
308 pub fn write_bytes_range(&self, queue: &wgpu::Queue, byte_off: usize, data: &[u8]) {
310 if data.is_empty() {
311 return;
312 }
313 queue.write_buffer(&self.buffer, byte_off as u64, data);
314 }
315}
316
317pub struct ReadbackStaging {
319 buffer: wgpu::Buffer,
320 capacity: usize,
321}
322
323pub struct TinyReadbackStaging {
326 buffer: wgpu::Buffer,
327}
328
329impl TinyReadbackStaging {
330 const CAPACITY: u64 = 256;
331
332 pub fn new(device: &wgpu::Device) -> Self {
333 Self {
334 buffer: device.create_buffer(&wgpu::BufferDescriptor {
335 label: Some("rlx-wgpu tiny readback"),
336 size: Self::CAPACITY,
337 usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
338 mapped_at_creation: false,
339 }),
340 }
341 }
342
343 pub fn buffer(&self) -> &wgpu::Buffer {
344 &self.buffer
345 }
346}
347
348pub fn use_tiny_readback(layout: &ReadbackLayout, num_outputs: usize) -> bool {
350 num_outputs == 1 && layout.total_bytes <= 16
351}
352
353pub fn decode_tiny_mapped_f32(staging: &wgpu::Buffer, len: usize) -> Vec<f32> {
355 let len = len.max(4);
356 let slice = staging.slice(..len as u64);
357 let view = slice.get_mapped_range();
358 let out = bytemuck::cast_slice::<u8, f32>(&view[..len]).to_vec();
359 drop(view);
360 staging.unmap();
361 out
362}
363
364pub fn read_tiny_f32_after_submit(
366 device: &wgpu::Device,
367 staging: &wgpu::Buffer,
368 len: usize,
369) -> Vec<f32> {
370 let len = len.max(4);
371 let slice = staging.slice(..len as u64);
372 let (sender, receiver) = std::sync::mpsc::channel();
373 slice.map_async(wgpu::MapMode::Read, move |r| {
374 let _ = sender.send(r);
375 });
376 wait_readback_map(device, &receiver, len);
377 receiver.recv().unwrap().unwrap();
378 decode_tiny_mapped_f32(staging, len)
379}
380
381impl ReadbackStaging {
382 pub(crate) fn buffer(&self) -> &wgpu::Buffer {
383 &self.buffer
384 }
385
386 fn ensure(&mut self, device: &wgpu::Device, min_bytes: usize) {
387 let need = min_bytes.max(256);
388 if self.capacity >= need {
389 return;
390 }
391 let cap = need.next_power_of_two().max(256);
392 self.buffer = device.create_buffer(&wgpu::BufferDescriptor {
393 label: Some("rlx-wgpu readback staging"),
394 size: cap as u64,
395 usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
396 mapped_at_creation: false,
397 });
398 self.capacity = cap;
399 }
400
401 pub fn prepare(device: &wgpu::Device, staging: &mut Option<Self>, min_bytes: usize) {
403 match staging {
404 Some(s) => s.ensure(device, min_bytes),
405 None => {
406 let cap = min_bytes.max(256).next_power_of_two();
407 *staging = Some(Self {
408 buffer: device.create_buffer(&wgpu::BufferDescriptor {
409 label: Some("rlx-wgpu readback staging"),
410 size: cap as u64,
411 usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
412 mapped_at_creation: false,
413 }),
414 capacity: cap,
415 });
416 }
417 }
418 }
419}
420
/// Rounds `n` up to the nearest multiple of four.
fn align4(n: usize) -> usize {
    n.div_ceil(4) * 4
}
424
/// Describes how a set of arena slots is packed into one staging buffer.
#[derive(Debug, Clone)]
pub struct ReadbackLayout {
    /// One `(start, len)` pair per output, in bytes, relative to the start of
    /// the staging buffer.
    pub regions: Vec<(usize, usize)>,
    /// Number of staging bytes the layout occupies.
    pub total_bytes: usize,
}
431
432impl ReadbackLayout {
433 pub fn for_nodes(arena: &Arena, ids: &[NodeId]) -> Self {
434 if ids.is_empty() {
435 return Self {
436 regions: Vec::new(),
437 total_bytes: 0,
438 };
439 }
440 if ids.len() == 1 {
441 let len = arena.len_of(ids[0]);
442 return Self {
443 regions: vec![(0, len)],
444 total_bytes: len,
445 };
446 }
447 let mut regions = Vec::with_capacity(ids.len());
448 let mut total = 0usize;
449 for &id in ids {
450 let len = arena.len_of(id);
451 let start = total;
452 total = align4(start + len);
453 regions.push((start, len));
454 }
455 Self {
456 regions,
457 total_bytes: total,
458 }
459 }
460}
461
462pub fn encode_readback_copies(
464 enc: &mut wgpu::CommandEncoder,
465 arena: &Arena,
466 staging: &wgpu::Buffer,
467 ids: &[NodeId],
468 layout: &ReadbackLayout,
469) {
470 for (&id, &(dst_off, len)) in ids.iter().zip(layout.regions.iter()) {
471 enc.copy_buffer_to_buffer(
472 &arena.buffer,
473 arena.offset(id) as u64,
474 staging,
475 dst_off as u64,
476 len as u64,
477 );
478 }
479}
480
481pub fn map_readback_f32(
483 device: &wgpu::Device,
484 staging: &wgpu::Buffer,
485 layout: &ReadbackLayout,
486) -> Vec<Vec<f32>> {
487 map_readback_f32_after_submit(device, staging, layout)
488}
489
490pub fn wait_readback_map(
492 device: &wgpu::Device,
493 _map_rx: &std::sync::mpsc::Receiver<Result<(), wgpu::BufferAsyncError>>,
494 total_bytes: usize,
495) {
496 let spins = if total_bytes <= 16 { 256 } else { 64 };
497 for _ in 0..spins {
498 let _ = device.poll(wgpu::PollType::Poll);
499 }
500 let _ = device.poll(wgpu::PollType::wait_indefinitely());
501}
502
503pub fn schedule_readback_map(
505 encoder: &mut wgpu::CommandEncoder,
506 staging: &wgpu::Buffer,
507 layout: &ReadbackLayout,
508) -> std::sync::mpsc::Receiver<Result<(), wgpu::BufferAsyncError>> {
509 let total = layout.total_bytes;
510 let (sender, receiver) = std::sync::mpsc::channel();
511 encoder.map_buffer_on_submit(staging, wgpu::MapMode::Read, 0..total as u64, move |r| {
512 let _ = sender.send(r);
513 });
514 receiver
515}
516
517fn map_readback_f32_after_submit(
518 device: &wgpu::Device,
519 staging: &wgpu::Buffer,
520 layout: &ReadbackLayout,
521) -> Vec<Vec<f32>> {
522 if layout.regions.is_empty() {
523 return Vec::new();
524 }
525 let total = layout.total_bytes;
526 let slice = staging.slice(..total as u64);
527 let (sender, receiver) = std::sync::mpsc::channel();
528 slice.map_async(wgpu::MapMode::Read, move |r| {
529 let _ = sender.send(r);
530 });
531 wait_readback_map(device, &receiver, total);
532 receiver.recv().unwrap().unwrap();
533
534 let view = slice.get_mapped_range();
535 let bytes = &view[..];
536 let mut outs = Vec::with_capacity(layout.regions.len());
537 for &(start, len) in &layout.regions {
538 let chunk = &bytes[start..start + len];
539 outs.push(bytemuck::cast_slice::<u8, f32>(chunk).to_vec());
540 }
541 drop(view);
542 staging.unmap();
543 outs
544}
545
546pub fn decode_mapped_readback_f32(
548 staging: &wgpu::Buffer,
549 layout: &ReadbackLayout,
550) -> Vec<Vec<f32>> {
551 if layout.regions.is_empty() {
552 return Vec::new();
553 }
554 let total = layout.total_bytes;
555 let slice = staging.slice(..total as u64);
556 let view = slice.get_mapped_range();
557 let bytes = &view[..];
558 let mut outs = Vec::with_capacity(layout.regions.len());
559 for &(start, len) in &layout.regions {
560 let chunk = &bytes[start..start + len];
561 outs.push(bytemuck::cast_slice::<u8, f32>(chunk).to_vec());
562 }
563 drop(view);
564 staging.unmap();
565 outs
566}
567
568pub fn read_f32_pooled(
570 arena: &Arena,
571 device: &wgpu::Device,
572 queue: &wgpu::Queue,
573 id: NodeId,
574 staging: &mut Option<ReadbackStaging>,
575) -> Vec<f32> {
576 let off = arena.offset(id);
577 let len = arena.len_of(id);
578 let n_elems = len / 4;
579 if n_elems == 0 {
580 return Vec::new();
581 }
582 ReadbackStaging::prepare(device, staging, len);
583 let staging = staging.as_ref().expect("staging");
584
585 let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
586 label: Some("rlx-wgpu readback enc"),
587 });
588 enc.copy_buffer_to_buffer(&arena.buffer, off as u64, &staging.buffer, 0, len as u64);
589 queue.submit(std::iter::once(enc.finish()));
590
591 let slice = staging.buffer.slice(..len as u64);
592 let (sender, receiver) = std::sync::mpsc::channel();
593 slice.map_async(wgpu::MapMode::Read, move |r| {
594 let _ = sender.send(r);
595 });
596 wait_readback_map(device, &receiver, len);
597 receiver.recv().unwrap().unwrap();
598
599 let view = slice.get_mapped_range();
600 let out: Vec<f32> = bytemuck::cast_slice::<u8, f32>(&view).to_vec();
601 drop(view);
602 staging.buffer.unmap();
603 out
604}
605
606pub fn read_f32_many_pooled(
608 arena: &Arena,
609 device: &wgpu::Device,
610 queue: &wgpu::Queue,
611 ids: &[NodeId],
612 staging: &mut Option<ReadbackStaging>,
613) -> Vec<Vec<f32>> {
614 if ids.is_empty() {
615 return Vec::new();
616 }
617 let layout = ReadbackLayout::for_nodes(arena, ids);
618 ReadbackStaging::prepare(device, staging, layout.total_bytes);
619 let staging_buf = staging.as_ref().expect("staging").buffer().clone();
620
621 let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
622 label: Some("rlx-wgpu readback batch enc"),
623 });
624 encode_readback_copies(&mut enc, arena, &staging_buf, ids, &layout);
625 queue.submit(std::iter::once(enc.finish()));
626 map_readback_f32(device, &staging_buf, &layout)
627}
628
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::NodeId;
    use rlx_opt::memory::{BufferSlot, MemoryPlan};
    use std::collections::HashMap;

    /// A three-element f32 slot (12 bytes) at f32 offset 32 has its shadow
    /// starting at byte 16. Its six bytes of f16 data are padded to eight, so
    /// the shadow arena must be 24 bytes long.
    #[test]
    fn f16_shadow_arena_accounts_for_copy_alignment_padding() {
        let slot = BufferSlot {
            offset: 32,
            size: 12,
        };
        let plan = MemoryPlan {
            arena_size: 44,
            assignments: HashMap::from([(NodeId(0), slot)]),
            schedule: vec![],
        };
        assert_eq!(f16_shadow_arena_size(&plan), 24);
    }
}