use std::collections::HashMap;
use std::mem::ManuallyDrop;
use std::sync::{Arc, Mutex, OnceLock};
use cudarc::driver::{CudaContext, CudaSlice};
use rlx_ir::{Graph, NodeId, Op};
use rlx_opt::memory::{BufferSlot, MemoryPlan};
/// Format tag for a half-precision parameter stored in an [`Arena`]'s 16-bit side buffer.
///
/// `Arena` keeps the raw 16-bit words in a `CudaSlice<u16>` and records this tag next to each
/// entry. The variant names suggest IEEE 754 binary16 (`F16`) and bfloat16 (`Bf16`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HalfDtype {
    /// Presumably IEEE 754 half precision (binary16).
    F16,
    /// Presumably bfloat16.
    Bf16,
}
/// Per-node device storage laid out according to a `MemoryPlan`.
///
/// f32 values share one pooled device buffer addressed through per-node offsets. Half-precision
/// parameters registered through `register_half_param` live in a separate `u16` device buffer.
pub struct Arena {
    /// Single f32 device buffer for all planned nodes. Wrapped in `ManuallyDrop` because the
    /// `Drop` impl moves it out and hands it back to the process-wide pool instead of freeing it.
    pub buffer: ManuallyDrop<CudaSlice<f32>>,
    /// Per-node offset copied from the plan's `BufferSlot::offset`. These are byte offsets for
    /// plans built by `plan_f32_uniform`.
    pub offsets: HashMap<NodeId, usize>,
    /// Per-node length, initially the plan's `BufferSlot::size`. For `plan_f32_uniform` plans
    /// this is bytes including alignment padding. `set_actual_len` overwrites it with a byte count.
    pub lens: HashMap<NodeId, usize>,
    /// Arena size in bytes, taken from `MemoryPlan::arena_size`.
    pub size: usize,
    /// Device buffer of raw 16-bit half-precision words. It is `None` until the first
    /// `register_half_param` call.
    pub half_buffer: Option<CudaSlice<u16>>,
    /// Node id -> (element offset into `half_buffer`, storage format).
    pub half_offsets: HashMap<NodeId, (usize, HalfDtype)>,
    /// Filled alongside `half_offsets` by `register_half_param`, but keyed by the caller-supplied
    /// `f32_off`. NOTE(review): presumably the node's offset in the f32 arena; confirm with callers.
    pub half_by_f32_off: HashMap<u32, (usize, HalfDtype)>,
    /// Number of u16 elements handed out from `half_buffer` so far.
    pub half_size: usize,
}
/// Maximum number of idle f32 buffers the process-wide arena pool retains.
const F32_ARENA_POOL_CAP: usize = 16;
/// Idle f32 device buffers waiting to be reused, each paired with its recorded capacity in f32
/// elements. This is a single static, so it is shared by every `Arena` in the process.
static F32_ARENA_POOL: OnceLock<Mutex<Vec<(usize, CudaSlice<f32>)>>> = OnceLock::new();
/// Returns the process-wide f32 arena pool, lazily creating it empty on first access.
fn f32_arena_pool() -> &'static Mutex<Vec<(usize, CudaSlice<f32>)>> {
    F32_ARENA_POOL.get_or_init(Mutex::default)
}
/// Takes an f32 device buffer with room for at least `n_f32` elements (never fewer than 4) from
/// the process-wide pool. If no pooled buffer is large enough, allocates a fresh one on `ctx`'s
/// default stream.
///
/// Selection is best-fit: among pooled buffers whose recorded capacity covers the request, the
/// smallest one is returned. This keeps a tiny request from consuming a large buffer that a
/// later large request could have reused.
///
/// The returned memory is either freshly allocated and uninitialized, or a pooled buffer holding
/// stale data from a previous arena. Callers must not assume any particular contents.
///
/// NOTE(review): pooled buffers are not checked against `ctx`. If more than one CUDA context
/// exists in the process, a buffer from a different context could be returned. Confirm that
/// only one context is used.
///
/// # Panics
/// Panics if the pool mutex is poisoned or the device allocation fails.
fn pool_acquire_f32(ctx: &Arc<CudaContext>, n_f32: usize) -> CudaSlice<f32> {
    let need = n_f32.max(4);
    {
        let mut pool = f32_arena_pool()
            .lock()
            .expect("rlx-cuda: arena pool lock poisoned");
        let best = pool
            .iter()
            .enumerate()
            .filter(|(_, (cap, _))| *cap >= need)
            .min_by_key(|(_, (cap, _))| *cap)
            .map(|(idx, _)| idx);
        if let Some(idx) = best {
            let (_, buf) = pool.swap_remove(idx);
            return buf;
        }
    }
    // SAFETY: `alloc` returns uninitialized device memory. This assumes arena users write a
    // slot before reading it; a pooled buffer would carry stale data anyway, so no caller can
    // rely on the contents.
    unsafe {
        ctx.default_stream()
            .alloc(need)
            .expect("rlx-cuda: device allocation failed")
    }
}
fn pool_release_f32(cap_f32: usize, buffer: CudaSlice<f32>) {
let mut pool = f32_arena_pool()
.lock()
.expect("rlx-cuda: arena pool lock poisoned");
if pool.len() >= F32_ARENA_POOL_CAP {
pool.sort_by_key(|(cap, _)| *cap);
pool.remove(0);
}
pool.push((cap_f32.max(4), buffer));
}
/// Builds a simple, non-reusing memory plan in which every node's value is stored as f32.
///
/// Nodes are visited in `graph.nodes()` order, and that order also becomes the plan's
/// `schedule`. Each node gets its own slot of `num_elements * 4` bytes, rounded up to a multiple
/// of `align` bytes and placed directly after the previous slot. Slots are never recycled, so
/// `arena_size` is the sum of all slot sizes.
///
/// A `Reshape` or `Cast` node whose first input already has a slot shares that slot instead of
/// getting a new one. Presumably this is because neither changes the stored bytes under uniform
/// f32 storage; confirm this for `Cast`.
///
/// A node whose element count is unknown (`num_elements()` returns `None`) gets a zero-byte slot.
///
/// An `align` of `0` is treated as `1` (no padding). Previously it panicked with a division by
/// zero.
pub fn plan_f32_uniform(graph: &Graph, align: usize) -> MemoryPlan {
    let align = align.max(1);
    let node_count = graph.nodes().len();
    let mut assignments: HashMap<NodeId, BufferSlot> = HashMap::with_capacity(node_count);
    let mut schedule = Vec::with_capacity(node_count);
    let mut cursor = 0usize;
    for node in graph.nodes() {
        schedule.push(node.id);
        if matches!(node.op, Op::Reshape { .. } | Op::Cast { .. })
            && let Some(in_id) = node.inputs.first()
            && let Some(slot) = assignments.get(in_id).cloned()
        {
            // View-like op: alias the producer's storage instead of allocating.
            assignments.insert(node.id, slot);
            continue;
        }
        let bytes = node.shape.num_elements().unwrap_or(0) * 4;
        let aligned = bytes.div_ceil(align) * align;
        assignments.insert(
            node.id,
            BufferSlot {
                offset: cursor,
                size: aligned,
            },
        );
        cursor += aligned;
    }
    MemoryPlan {
        arena_size: cursor,
        assignments,
        schedule,
    }
}
impl Arena {
pub fn from_plan(ctx: &Arc<CudaContext>, plan: &MemoryPlan) -> Self {
let n_f32 = plan.arena_size.div_ceil(4);
let buffer = ManuallyDrop::new(pool_acquire_f32(ctx, n_f32));
let mut offsets = HashMap::new();
let mut lens = HashMap::new();
for (id, slot) in &plan.assignments {
offsets.insert(*id, slot.offset);
lens.insert(*id, slot.size);
}
Self {
buffer,
offsets,
lens,
size: plan.arena_size,
half_buffer: None,
half_offsets: HashMap::new(),
half_by_f32_off: HashMap::new(),
half_size: 0,
}
}
pub fn has(&self, id: NodeId) -> bool {
self.offsets.contains_key(&id)
}
#[inline]
pub fn f32_buf(&self) -> &CudaSlice<f32> {
&self.buffer
}
#[inline]
pub fn f32_buf_mut(&mut self) -> &mut CudaSlice<f32> {
&mut self.buffer
}
#[inline]
pub fn f32_buf_and_size(&mut self) -> (&mut CudaSlice<f32>, usize) {
let size = self.size;
(self.f32_buf_mut(), size)
}
pub fn offset(&self, id: NodeId) -> usize {
self.offsets[&id]
}
pub fn len_of(&self, id: NodeId) -> usize {
self.lens[&id]
}
pub fn set_actual_len(&mut self, id: NodeId, bytes: usize) {
self.lens.insert(id, bytes);
}
pub fn register_half_param(
&mut self,
ctx: &Arc<CudaContext>,
id: NodeId,
f32_off: u32,
n_elems: usize,
dtype: HalfDtype,
) -> usize {
let off = self.half_size;
self.half_size += n_elems;
self.half_offsets.insert(id, (off, dtype));
self.half_by_f32_off.insert(f32_off, (off, dtype));
let stream = ctx.default_stream();
let new_buf = stream
.alloc_zeros::<u16>(self.half_size.max(4))
.expect("rlx-cuda: half-arena allocation failed");
if let Some(old) = self.half_buffer.take() {
let _ = stream.memcpy_dtod(&old, &mut { new_buf.clone() });
}
self.half_buffer = Some(new_buf);
off
}
pub fn is_half(&self, id: NodeId) -> bool {
self.half_offsets.contains_key(&id)
}
pub fn half_off(&self, id: NodeId) -> Option<(usize, HalfDtype)> {
self.half_offsets.get(&id).copied()
}
}
impl Drop for Arena {
fn drop(&mut self) {
let cap_f32 = self.size.div_ceil(4).max(4);
let buffer = unsafe { ManuallyDrop::take(&mut self.buffer) };
pool_release_f32(cap_f32, buffer);
}
}