use rlx_ir::{DType, NodeId};
/// An arena of per-node buffers addressed by [`NodeId`].
///
/// NOTE(review): the per-method descriptions below are inferred from names and signatures
/// only; confirm them against the concrete implementors.
pub trait DeviceArena {
    /// Whether `id` has a buffer region assigned in this arena (presumably).
    fn has_buffer(&self, id: NodeId) -> bool;

    /// Byte offset of `id`'s buffer region (presumably relative to the arena base).
    fn byte_offset(&self, id: NodeId) -> usize;

    /// Total size of the arena, in bytes (presumably).
    fn size_bytes(&self) -> usize;

    /// Stores `data`, supplied as `f32`, into `id`'s buffer using `dtype`'s storage format
    /// (presumably converting as needed — TODO confirm).
    fn write_input_f32(&mut self, id: NodeId, dtype: DType, data: &[f32]);

    /// Reads `n_elements` values of `dtype` from `id`'s buffer and returns them widened to
    /// `f32` (presumably — TODO confirm).
    fn read_output_f32(&self, id: NodeId, dtype: DType, n_elements: usize) -> Vec<f32>;
}
/// Writes `f32` values to raw memory, converting each one to `dtype`'s storage format.
///
/// At most `max_elems` elements are written. Fewer are written if `src` is shorter.
/// `DType::F16` and `DType::BF16` are converted element by element via the `half` crate.
/// Every other `dtype` is stored as raw `f32` (4 bytes per element).
///
/// # Safety
///
/// Let `n = min(src.len(), max_elems)`. When `n > 0`, the caller must guarantee that:
/// - `dst_ptr` is valid for writes of `n` elements of the storage type selected by `dtype`.
///   That is `n * 2` bytes for F16/BF16 and `n * 4` bytes otherwise.
/// - The destination range does not overlap `src`.
///
/// `dst_ptr` does not need any particular alignment, because all stores are unaligned.
/// When `n == 0` the pointer is never touched, so it may be null or dangling.
pub unsafe fn write_typed_from_f32(dst_ptr: *mut u8, dtype: DType, src: &[f32], max_elems: usize) {
    let n = src.len().min(max_elems);
    if n == 0 {
        // Even zero-length pointer copies require a non-null, aligned pointer; skip entirely.
        return;
    }
    let src = &src[..n];
    match dtype {
        DType::F16 => {
            let dst = dst_ptr.cast::<half::f16>();
            for (i, &x) in src.iter().enumerate() {
                // SAFETY: the caller guarantees `dst_ptr` is valid for `n` f16 writes and
                // `i < n`. `write_unaligned` removes any alignment requirement on the
                // arena offset `dst_ptr` may point to.
                unsafe { dst.add(i).write_unaligned(half::f16::from_f32(x)) };
            }
        }
        DType::BF16 => {
            let dst = dst_ptr.cast::<half::bf16>();
            for (i, &x) in src.iter().enumerate() {
                // SAFETY: as above, for `n` bf16 writes.
                unsafe { dst.add(i).write_unaligned(half::bf16::from_f32(x)) };
            }
        }
        _ => {
            // SAFETY: the caller guarantees `dst_ptr` is valid for `n * 4` byte writes and
            // does not overlap `src`. Copying as bytes imposes no alignment requirement on
            // `dst_ptr`, and `size_of_val(src)` is exactly `n * size_of::<f32>()`.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    src.as_ptr().cast::<u8>(),
                    dst_ptr,
                    std::mem::size_of_val(src),
                );
            }
        }
    }
}
/// Reads `n_elems` values stored in `dtype`'s format from raw memory and widens them to `f32`.
///
/// `DType::F16` and `DType::BF16` are converted element by element via the `half` crate.
/// Every other `dtype` is read as raw `f32` (4 bytes per element).
///
/// # Safety
///
/// When `n_elems > 0`, the caller must guarantee that `src_ptr` is valid for reads of
/// `n_elems` initialized elements of the storage type selected by `dtype`. That is
/// `n_elems * 2` bytes for F16/BF16 and `n_elems * 4` bytes otherwise.
///
/// `src_ptr` does not need any particular alignment, because all loads are unaligned.
/// When `n_elems == 0` the pointer is never touched, so it may be null or dangling.
pub unsafe fn read_typed_to_f32(src_ptr: *const u8, dtype: DType, n_elems: usize) -> Vec<f32> {
    if n_elems == 0 {
        // `slice::from_raw_parts` / pointer copies require a non-null, aligned pointer
        // even for zero elements; avoid touching `src_ptr` at all.
        return Vec::new();
    }
    match dtype {
        DType::F16 => {
            let src = src_ptr.cast::<half::f16>();
            (0..n_elems)
                // SAFETY: the caller guarantees `src_ptr` is valid for `n_elems` f16 reads
                // and `i < n_elems`. `read_unaligned` tolerates any byte alignment.
                .map(|i| unsafe { src.add(i).read_unaligned().to_f32() })
                .collect()
        }
        DType::BF16 => {
            let src = src_ptr.cast::<half::bf16>();
            (0..n_elems)
                // SAFETY: as above, for `n_elems` bf16 reads.
                .map(|i| unsafe { src.add(i).read_unaligned().to_f32() })
                .collect()
        }
        _ => {
            // `with_capacity` panics before the byte count below could overflow.
            let mut out = Vec::<f32>::with_capacity(n_elems);
            // SAFETY:
            // - The caller guarantees `src_ptr` is valid for `n_elems * 4` byte reads.
            // - `out` owns a fresh, non-overlapping allocation with room for `n_elems` f32s.
            // - The byte-wise copy needs no alignment on `src_ptr`.
            // - Every byte of the first `n_elems` slots is initialized before `set_len`.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    src_ptr,
                    out.as_mut_ptr().cast::<u8>(),
                    n_elems * std::mem::size_of::<f32>(),
                );
                out.set_len(n_elems);
            }
            out
        }
    }
}