use rlx_ir::{Graph, NodeId};
use rlx_opt::memory::MemoryPlan;
use std::collections::HashMap;
/// Exclusive end byte, in the f16 shadow buffer, of the mirror of one f32 slot.
///
/// An f32 slot at byte offset `f32_byte_offset` is mirrored at half that offset.
/// Each whole f32 element becomes a 2-byte f16. The mirrored length is rounded up
/// to a 4-byte word, because odd element counts are padded with one extra f16 on upload.
fn f16_shadow_write_end(f32_byte_offset: usize, f32_byte_len: usize) -> usize {
    const WORD_BYTES: usize = 4;
    let element_count = f32_byte_len / 4;
    let shadow_bytes = (element_count * 2).div_ceil(WORD_BYTES) * WORD_BYTES;
    f32_byte_offset / 2 + shadow_bytes
}
/// Byte size needed for the f16 shadow arena of `plan`.
///
/// This is the furthest shadow end over all slots in the plan. It is never less
/// than 1, so a plan without slots still yields a non-zero buffer size.
fn f16_shadow_arena_size(plan: &MemoryPlan) -> usize {
    let mut furthest_end = 1;
    for slot in plan.assignments.values() {
        furthest_end = furthest_end.max(f16_shadow_write_end(slot.offset, slot.size));
    }
    furthest_end
}
/// GPU storage holding every node slot described by a [`MemoryPlan`].
///
/// Built by [`Arena::from_plan`]. Node data is addressed by byte offsets into `buffer`.
pub struct Arena {
    /// Main arena buffer, created with usage `STORAGE | COPY_SRC | COPY_DST`.
    /// `write_f32` and `read_f32` access slots as f32 data; the byte-range methods
    /// access it as raw bytes.
    pub buffer: wgpu::Buffer,
    /// Half-precision mirror of `buffer`, allocated only when the device supports
    /// `SHADER_F16`.
    /// An f32 slot at byte offset `off` is mirrored as f16 values starting at byte
    /// offset `off / 2`. Within this file only `write_f32` writes to it.
    pub f16_buffer: Option<wgpu::Buffer>,
    /// Byte offset of each node's slot within `buffer`, copied from the plan.
    pub offsets: HashMap<NodeId, usize>,
    /// Byte length of each node's slot.
    /// It starts as the planned slot size and can be overridden via `set_actual_len`.
    pub lens: HashMap<NodeId, usize>,
    /// Size of `buffer` in bytes: the planned arena size, clamped to at least 1.
    pub size: usize,
}
/// Plans arena memory for `graph`, treating every buffer as f32 data and aligning
/// slots to `align` bytes.
///
/// This delegates directly to `rlx_compile::memory::plan_memory_f32_uniform`.
pub fn plan_f32_uniform(graph: &Graph, align: usize) -> MemoryPlan {
    use rlx_compile::memory::plan_memory_f32_uniform;
    plan_memory_f32_uniform(graph, align)
}
impl Arena {
pub fn from_plan(device: &wgpu::Device, plan: &MemoryPlan) -> Self {
let size = plan.arena_size.max(1); let max_binding = device.limits().max_storage_buffer_binding_size;
if (size as u64) > max_binding {
panic!(
"rlx-wgpu: planned arena size {} bytes ({:.3} GiB) exceeds \
max_storage_buffer_binding_size {} bytes ({:.3} GiB). \
Reduce batch/sequence size or split the graph.",
size,
size as f64 / (1u64 << 30) as f64,
max_binding,
max_binding as f64 / (1u64 << 30) as f64
);
}
let buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("rlx-wgpu arena"),
size: size as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let f16_buffer = if device.features().contains(wgpu::Features::SHADER_F16) {
let f16_size = f16_shadow_arena_size(plan);
Some(device.create_buffer(&wgpu::BufferDescriptor {
label: Some("rlx-wgpu arena f16"),
size: f16_size as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
}))
} else {
None
};
let mut offsets = HashMap::with_capacity(plan.assignments.len());
let mut lens = HashMap::with_capacity(plan.assignments.len());
for (id, a) in &plan.assignments {
offsets.insert(*id, a.offset);
lens.insert(*id, a.size);
}
Self {
buffer,
f16_buffer,
offsets,
lens,
size,
}
}
pub fn has(&self, id: NodeId) -> bool {
self.offsets.contains_key(&id)
}
pub fn offset(&self, id: NodeId) -> usize {
self.offsets[&id]
}
pub fn len_of(&self, id: NodeId) -> usize {
self.lens[&id]
}
pub fn set_actual_len(&mut self, id: NodeId, bytes: usize) {
self.lens.insert(id, bytes);
}
pub fn write_f32(&self, queue: &wgpu::Queue, id: NodeId, data: &[f32]) {
let off = self.offset(id);
let bytes: &[u8] = bytemuck::cast_slice(data);
queue.write_buffer(&self.buffer, off as u64, bytes);
if let Some(f16_buf) = &self.f16_buffer {
let mut f16_data: Vec<half::f16> =
data.iter().map(|&v| half::f16::from_f32(v)).collect();
if !f16_data.len().is_multiple_of(2) {
f16_data.push(half::f16::from_f32(0.0));
}
let f16_bytes: &[u8] = unsafe {
std::slice::from_raw_parts(f16_data.as_ptr() as *const u8, f16_data.len() * 2)
};
queue.write_buffer(f16_buf, (off / 2) as u64, f16_bytes);
}
}
pub fn read_f32(&self, device: &wgpu::Device, queue: &wgpu::Queue, id: NodeId) -> Vec<f32> {
let off = self.offset(id);
let len = self.len_of(id);
let n_elems = len / 4;
if n_elems == 0 {
return Vec::new();
}
let staging = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("rlx-wgpu readback"),
size: len as u64,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("rlx-wgpu readback enc"),
});
enc.copy_buffer_to_buffer(&self.buffer, off as u64, &staging, 0, len as u64);
queue.submit(std::iter::once(enc.finish()));
let slice = staging.slice(..);
let (sender, receiver) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |r| {
let _ = sender.send(r);
});
let _ = device.poll(wgpu::PollType::wait_indefinitely());
receiver.recv().unwrap().unwrap();
let view = slice.get_mapped_range();
let out: Vec<f32> = bytemuck::cast_slice::<u8, f32>(&view).to_vec();
drop(view);
staging.unmap();
out
}
pub fn read_bytes_range(
&self,
device: &wgpu::Device,
queue: &wgpu::Queue,
byte_off: usize,
len: usize,
) -> Vec<u8> {
if len == 0 {
return Vec::new();
}
let staging = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("rlx-wgpu readback bytes"),
size: len as u64,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("rlx-wgpu readback bytes enc"),
});
enc.copy_buffer_to_buffer(&self.buffer, byte_off as u64, &staging, 0, len as u64);
queue.submit(std::iter::once(enc.finish()));
let slice = staging.slice(..);
let (sender, receiver) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |r| {
let _ = sender.send(r);
});
let _ = device.poll(wgpu::PollType::wait_indefinitely());
receiver.recv().unwrap().unwrap();
let view = slice.get_mapped_range();
let out = view.to_vec();
drop(view);
staging.unmap();
out
}
pub fn write_bytes_range(&self, queue: &wgpu::Queue, byte_off: usize, data: &[u8]) {
if data.is_empty() {
return;
}
queue.write_buffer(&self.buffer, byte_off as u64, data);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::NodeId;
    use rlx_opt::memory::{BufferSlot, MemoryPlan};
    use std::collections::HashMap;

    /// A single 3-element slot at offset 32 is mirrored at 16.
    /// Its 6 f16 bytes pad to 8, giving an end of 24.
    #[test]
    fn f16_shadow_arena_accounts_for_copy_alignment_padding() {
        let mut assignments = HashMap::new();
        assignments.insert(
            NodeId(0),
            BufferSlot {
                offset: 32,
                size: 12,
            },
        );
        let plan = MemoryPlan {
            arena_size: 44,
            assignments,
            schedule: vec![],
        };
        assert_eq!(f16_shadow_arena_size(&plan), 24);
    }

    /// A plan without slots still yields a non-zero shadow buffer size.
    #[test]
    fn f16_shadow_arena_is_at_least_one_byte_for_empty_plan() {
        let plan = MemoryPlan {
            arena_size: 0,
            assignments: HashMap::new(),
            schedule: vec![],
        };
        assert_eq!(f16_shadow_arena_size(&plan), 1);
    }

    /// The shadow arena must reach the furthest slot end, not the last one inserted.
    #[test]
    fn f16_shadow_arena_takes_furthest_slot_end() {
        let mut assignments = HashMap::new();
        // 4 elements at offset 0 -> shadow [0, 8).
        assignments.insert(
            NodeId(0),
            BufferSlot {
                offset: 0,
                size: 16,
            },
        );
        // 1 element at offset 64 -> shadow [32, 36) after padding to a word.
        assignments.insert(
            NodeId(1),
            BufferSlot {
                offset: 64,
                size: 4,
            },
        );
        let plan = MemoryPlan {
            arena_size: 68,
            assignments,
            schedule: vec![],
        };
        assert_eq!(f16_shadow_arena_size(&plan), 36);
    }

    /// Odd element counts are padded by one f16 so the mirrored length is a whole word.
    #[test]
    fn f16_shadow_write_end_pads_odd_element_counts() {
        assert_eq!(f16_shadow_write_end(0, 0), 0);
        assert_eq!(f16_shadow_write_end(0, 4), 4);
        assert_eq!(f16_shadow_write_end(0, 12), 8);
        assert_eq!(f16_shadow_write_end(0, 16), 8);
        assert_eq!(f16_shadow_write_end(16, 20), 20);
    }
}