use rlx_ir::{Graph, NodeId, Op};
use rlx_opt::memory::MemoryPlan;
use std::collections::HashMap;
/// GPU-side arena: one large storage buffer plus per-node bookkeeping that
/// says where each node's data lives inside it.
///
/// Built by `Arena::from_plan` from a `MemoryPlan`.
pub struct Arena {
/// Main arena buffer. `from_plan` creates it with
/// `STORAGE | COPY_SRC | COPY_DST` usage.
pub buffer: wgpu::Buffer,
/// Optional f16 mirror of the arena. `from_plan` only creates it when the
/// device reports `Features::SHADER_F16`; `write_f32` writes an f16 copy of
/// the data into it at half the f32 byte offset.
pub f16_buffer: Option<wgpu::Buffer>,
/// Offset of each node's slot inside `buffer`, copied from the plan's
/// assignments (used as a byte offset by the read/write methods).
pub offsets: HashMap<NodeId, usize>,
/// Length of each node's slot, initially the planned slot size. Can be
/// overwritten per node through `set_actual_len`; `read_f32` treats it as a
/// byte count.
pub lens: HashMap<NodeId, usize>,
/// Arena size as computed in `from_plan`: `plan.arena_size.max(1)`.
pub size: usize,
}
/// Builds a simple bump-allocated [`MemoryPlan`] in which every node is sized
/// as if it held 4-byte elements.
///
/// Nodes are visited in `graph.nodes()` order, which also becomes the plan's
/// schedule. Each node gets a fresh slot of `num_elements * 4` bytes rounded
/// up to a multiple of `align`, placed directly after the previous slot.
/// Nodes whose element count is unknown (`num_elements()` returns `None`) get
/// a zero-sized slot.
///
/// `Reshape` and `Cast` nodes do not get their own storage when their first
/// input already has a slot: they reuse (alias) that input's slot instead.
///
/// An `align` of `0` is treated as `1` (no padding) rather than dividing by
/// zero.
///
/// # Panics
///
/// Panics if a slot size or the running arena size overflows `usize`; a
/// silently wrapped size would produce an arena smaller than its contents.
pub fn plan_f32_uniform(graph: &Graph, align: usize) -> MemoryPlan {
    use rlx_opt::memory::BufferSlot;

    // Guard against a zero alignment, which would otherwise be a
    // division-by-zero when rounding up.
    let align = align.max(1);

    let mut assignments: HashMap<NodeId, BufferSlot> = HashMap::new();
    let mut schedule = Vec::with_capacity(graph.nodes().len());
    let mut cursor = 0usize;

    for node in graph.nodes() {
        // View-like ops share the storage of their first input, provided that
        // input has already been placed.
        let is_view = matches!(node.op, Op::Reshape { .. } | Op::Cast { .. });
        let aliased = if is_view {
            node.inputs
                .first()
                .and_then(|src| assignments.get(src))
                .cloned()
        } else {
            None
        };

        let slot = match aliased {
            Some(slot) => slot,
            None => {
                let elems = node.shape.num_elements().unwrap_or(0);
                let aligned = elems
                    .checked_mul(4)
                    .and_then(|bytes| bytes.checked_next_multiple_of(align))
                    .expect("rlx-wgpu: tensor byte size overflows usize");
                let offset = cursor;
                cursor = cursor
                    .checked_add(aligned)
                    .expect("rlx-wgpu: arena size overflows usize");
                BufferSlot {
                    offset,
                    size: aligned,
                }
            }
        };

        assignments.insert(node.id, slot);
        schedule.push(node.id);
    }

    MemoryPlan {
        arena_size: cursor,
        assignments,
        schedule,
    }
}
impl Arena {
pub fn from_plan(device: &wgpu::Device, plan: &MemoryPlan) -> Self {
let size = plan.arena_size.max(1); let max_binding = device.limits().max_storage_buffer_binding_size;
if (size as u64) > max_binding {
panic!(
"rlx-wgpu: planned arena size {} bytes ({} GiB) exceeds the \
adapter's max_storage_buffer_binding_size of {} bytes \
({} GiB). This is the WebGPU 32-bit binding offset cap on \
Apple Metal / Vulkan; supporting larger arenas requires \
multi-bind-group partitioning. Workaround: reduce batch \
size or compile with a smaller (batch, seq) shape.",
size,
size as f64 / (1u64 << 30) as f64,
max_binding,
max_binding as f64 / (1u64 << 30) as f64
);
}
let buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("rlx-wgpu arena"),
size: size as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let f16_buffer = if device.features().contains(wgpu::Features::SHADER_F16) {
Some(device.create_buffer(&wgpu::BufferDescriptor {
label: Some("rlx-wgpu arena f16"),
size: size.div_ceil(2) as u64, usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
}))
} else {
None
};
let mut offsets = HashMap::with_capacity(plan.assignments.len());
let mut lens = HashMap::with_capacity(plan.assignments.len());
for (id, a) in &plan.assignments {
offsets.insert(*id, a.offset);
lens.insert(*id, a.size);
}
Self {
buffer,
f16_buffer,
offsets,
lens,
size,
}
}
pub fn has(&self, id: NodeId) -> bool {
self.offsets.contains_key(&id)
}
pub fn offset(&self, id: NodeId) -> usize {
self.offsets[&id]
}
pub fn len_of(&self, id: NodeId) -> usize {
self.lens[&id]
}
pub fn set_actual_len(&mut self, id: NodeId, bytes: usize) {
self.lens.insert(id, bytes);
}
pub fn write_f32(&self, queue: &wgpu::Queue, id: NodeId, data: &[f32]) {
let off = self.offset(id);
let bytes: &[u8] = bytemuck::cast_slice(data);
queue.write_buffer(&self.buffer, off as u64, bytes);
if let Some(f16_buf) = &self.f16_buffer {
let mut f16_data: Vec<half::f16> =
data.iter().map(|&v| half::f16::from_f32(v)).collect();
if !f16_data.len().is_multiple_of(2) {
f16_data.push(half::f16::from_f32(0.0));
}
let f16_bytes: &[u8] = unsafe {
std::slice::from_raw_parts(f16_data.as_ptr() as *const u8, f16_data.len() * 2)
};
queue.write_buffer(f16_buf, (off / 2) as u64, f16_bytes);
}
}
pub fn read_f32(&self, device: &wgpu::Device, queue: &wgpu::Queue, id: NodeId) -> Vec<f32> {
let off = self.offset(id);
let len = self.len_of(id);
let n_elems = len / 4;
if n_elems == 0 {
return Vec::new();
}
let staging = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("rlx-wgpu readback"),
size: len as u64,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("rlx-wgpu readback enc"),
});
enc.copy_buffer_to_buffer(&self.buffer, off as u64, &staging, 0, len as u64);
queue.submit(std::iter::once(enc.finish()));
let slice = staging.slice(..);
let (sender, receiver) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |r| {
let _ = sender.send(r);
});
let _ = device.poll(wgpu::PollType::wait_indefinitely());
receiver.recv().unwrap().unwrap();
let view = slice.get_mapped_range();
let out: Vec<f32> = bytemuck::cast_slice::<u8, f32>(&view).to_vec();
drop(view);
staging.unmap();
out
}
pub fn read_bytes_range(
&self,
device: &wgpu::Device,
queue: &wgpu::Queue,
byte_off: usize,
len: usize,
) -> Vec<u8> {
if len == 0 {
return Vec::new();
}
let staging = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("rlx-wgpu readback bytes"),
size: len as u64,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("rlx-wgpu readback bytes enc"),
});
enc.copy_buffer_to_buffer(&self.buffer, byte_off as u64, &staging, 0, len as u64);
queue.submit(std::iter::once(enc.finish()));
let slice = staging.slice(..);
let (sender, receiver) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |r| {
let _ = sender.send(r);
});
let _ = device.poll(wgpu::PollType::wait_indefinitely());
receiver.recv().unwrap().unwrap();
let view = slice.get_mapped_range();
let out = view.to_vec();
drop(view);
staging.unmap();
out
}
pub fn write_bytes_range(&self, queue: &wgpu::Queue, byte_off: usize, data: &[u8]) {
if data.is_empty() {
return;
}
queue.write_buffer(&self.buffer, byte_off as u64, data);
}
}