use std::collections::HashMap;
use std::sync::Arc;
use cudarc::driver::{CudaContext, CudaSlice};
use rlx_ir::{Graph, NodeId, Op};
use rlx_opt::memory::{BufferSlot, MemoryPlan};
/// Storage format of a 16-bit floating-point parameter kept in the
/// half-precision side arena (whose device buffer holds raw `u16` elements).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HalfDtype {
    /// `f16` half-precision encoding.
    F16,
    /// `bfloat16` encoding.
    Bf16,
}
/// Device memory backing a planned graph: one f32 arena addressed by
/// per-node slots, plus an optional, separately grown `u16` arena for
/// half-precision parameters.
pub struct Arena {
    /// f32 device buffer; `from_plan` allocates `ceil(size / 4)` elements
    /// (at least 4), zero-initialised.
    pub buffer: CudaSlice<f32>,
    /// Per-node offset into `buffer`, copied from the plan's
    /// `BufferSlot::offset` (a byte offset in plans built by `plan_f32_uniform`).
    pub offsets: HashMap<NodeId, usize>,
    /// Per-node length in bytes; initialised from the plan's slot size and
    /// overridable through `set_actual_len`.
    pub lens: HashMap<NodeId, usize>,
    /// Arena size as recorded in the plan (`plan.arena_size`, in bytes).
    pub size: usize,
    /// `u16` device buffer for half-precision parameters; `None` until the
    /// first `register_half_param` call.
    pub half_buffer: Option<CudaSlice<u16>>,
    /// Per-node `(element offset into half_buffer, storage dtype)`.
    pub half_offsets: HashMap<NodeId, (usize, HalfDtype)>,
    /// Same entries as `half_offsets`, keyed by the caller-supplied f32
    /// offset of the parameter. NOTE(review): whether that key is in bytes
    /// or elements is not visible here — confirm against callers.
    pub half_by_f32_off: HashMap<u32, (usize, HalfDtype)>,
    /// Total number of `u16` elements reserved so far in the half arena.
    pub half_size: usize,
}
/// Builds a [`MemoryPlan`] that lays every node of `graph` out in one arena,
/// sizing each node as if all of its elements were `f32` (4 bytes).
///
/// * Each fresh region's byte size is rounded up to a multiple of `align`, and
///   regions are placed back to back in node order.
/// * `Reshape` and `Cast` nodes whose first input already has a slot reuse
///   that exact slot instead of receiving a new region.
/// * Nodes whose element count is unknown (`num_elements()` is `None`) get a
///   zero-byte region.
///
/// The schedule lists node ids in the order `graph.nodes()` yields them.
///
/// # Panics
/// Panics if `align` is zero (`usize::div_ceil` by zero).
pub fn plan_f32_uniform(graph: &Graph, align: usize) -> MemoryPlan {
    const F32_BYTES: usize = 4;

    let nodes = graph.nodes();
    let mut schedule = Vec::with_capacity(nodes.len());
    let mut assignments: HashMap<NodeId, BufferSlot> = HashMap::new();
    let mut arena_size = 0usize;

    for node in nodes {
        schedule.push(node.id);

        // View-like ops share their source's storage when it is already planned.
        let shared = if matches!(node.op, Op::Reshape { .. } | Op::Cast { .. }) {
            node.inputs
                .first()
                .and_then(|src| assignments.get(src))
                .cloned()
        } else {
            None
        };

        let slot = shared.unwrap_or_else(|| {
            let bytes = node.shape.num_elements().unwrap_or(0) * F32_BYTES;
            let padded = bytes.div_ceil(align) * align;
            let fresh = BufferSlot {
                offset: arena_size,
                size: padded,
            };
            arena_size += padded;
            fresh
        });
        assignments.insert(node.id, slot);
    }

    MemoryPlan {
        arena_size,
        assignments,
        schedule,
    }
}
impl Arena {
    /// Allocates a zero-initialised f32 device buffer for `plan` and records
    /// every node's slot offset and length.
    ///
    /// `plan.arena_size` is treated as a byte count: the buffer holds
    /// `ceil(arena_size / 4)` f32 elements, with a floor of 4 elements
    /// (presumably so an empty plan never requests a zero-length allocation).
    /// The half-precision arena starts empty.
    ///
    /// # Panics
    /// Panics if the device allocation fails.
    pub fn from_plan(ctx: &Arc<CudaContext>, plan: &MemoryPlan) -> Self {
        let n_f32 = plan.arena_size.div_ceil(4);
        let stream = ctx.default_stream();
        let buffer = stream
            .alloc_zeros::<f32>(n_f32.max(4))
            .expect("rlx-cuda: device allocation failed");

        let mut offsets = HashMap::with_capacity(plan.assignments.len());
        let mut lens = HashMap::with_capacity(plan.assignments.len());
        for (id, slot) in &plan.assignments {
            offsets.insert(*id, slot.offset);
            lens.insert(*id, slot.size);
        }

        Self {
            buffer,
            offsets,
            lens,
            size: plan.arena_size,
            half_buffer: None,
            half_offsets: HashMap::new(),
            half_by_f32_off: HashMap::new(),
            half_size: 0,
        }
    }

    /// Returns whether `id` has a slot in the f32 arena.
    pub fn has(&self, id: NodeId) -> bool {
        self.offsets.contains_key(&id)
    }

    /// Returns the f32-arena offset recorded for `id`.
    ///
    /// # Panics
    /// Panics if `id` has no slot (see [`Arena::has`]).
    pub fn offset(&self, id: NodeId) -> usize {
        self.offsets[&id]
    }

    /// Returns the byte length currently recorded for `id`.
    ///
    /// # Panics
    /// Panics if `id` has no recorded length.
    pub fn len_of(&self, id: NodeId) -> usize {
        self.lens[&id]
    }

    /// Overrides the recorded byte length of `id` (inserting it if absent),
    /// leaving its offset untouched.
    pub fn set_actual_len(&mut self, id: NodeId, bytes: usize) {
        self.lens.insert(id, bytes);
    }

    /// Reserves `n_elems` `u16` elements in the half-precision arena for
    /// parameter `id` and returns the element offset of the reservation.
    ///
    /// The reservation is recorded under both `id` (in `half_offsets`) and
    /// `f32_off` (in `half_by_f32_off`). If the device buffer is too small for
    /// the new total (`max(half_size, 4)` elements), a larger zeroed buffer is
    /// allocated and the previously registered contents are copied into its
    /// prefix, so earlier parameters are preserved.
    ///
    /// # Panics
    /// Panics if the device allocation or the device-to-device copy fails.
    pub fn register_half_param(
        &mut self,
        ctx: &Arc<CudaContext>,
        id: NodeId,
        f32_off: u32,
        n_elems: usize,
        dtype: HalfDtype,
    ) -> usize {
        let off = self.half_size;
        self.half_size += n_elems;
        self.half_offsets.insert(id, (off, dtype));
        self.half_by_f32_off.insert(f32_off, (off, dtype));

        let needed = self.half_size.max(4);
        let buf = match self.half_buffer.take() {
            // Existing buffer already covers the new total: keep it as-is.
            Some(old) if old.len() >= needed => old,
            old => {
                let stream = ctx.default_stream();
                let mut grown = stream
                    .alloc_zeros::<u16>(needed)
                    .expect("rlx-cuda: half-arena allocation failed");
                if let Some(old) = old {
                    // Copy into a view of `grown` itself. Copying into
                    // `grown.clone()` would land the data in a separate device
                    // allocation that is dropped immediately, leaving `grown`
                    // zeroed and losing every earlier parameter.
                    stream
                        .memcpy_dtod(&old, &mut grown.slice_mut(0..old.len()))
                        .expect("rlx-cuda: half-arena copy failed");
                }
                grown
            }
        };
        self.half_buffer = Some(buf);
        off
    }

    /// Returns whether `id` was registered in the half-precision arena.
    pub fn is_half(&self, id: NodeId) -> bool {
        self.half_offsets.contains_key(&id)
    }

    /// Returns `(element offset, dtype)` of `id` in the half-precision arena,
    /// or `None` if it was never registered.
    pub fn half_off(&self, id: NodeId) -> Option<(usize, HalfDtype)> {
        self.half_offsets.get(&id).copied()
    }
}