1use crate::device::{OneApiDevice, oneapi_device};
15use rlx_compile::memory::MemoryPlan;
16use rlx_ir::NodeId;
17use std::collections::HashMap;
18
19pub struct Arena {
20 dev: &'static OneApiDevice,
21 base: *mut std::ffi::c_void,
22 pub size: usize,
23 offsets: HashMap<NodeId, usize>,
24 lens: HashMap<NodeId, usize>,
25}
26
// SAFETY: the raw `base` pointer is the reason `Send` is not derived automatically.
// `Arena` is the only owner of that allocation (it is freed only in `Drop`), so
// moving the whole value to another thread moves that ownership with it.
// NOTE(review): moving an `Arena` also moves its `&'static OneApiDevice`, which is
// only sound if `OneApiDevice: Sync`. Confirm this for the device type.
// `Sync` is intentionally not implemented: the `write_*` methods mutate arena memory
// through `&self`, so concurrent shared access would be a data race.
unsafe impl Send for Arena {}
30
31impl Arena {
32 pub fn from_plan(plan: &MemoryPlan) -> Result<Self, String> {
33 let dev = oneapi_device().ok_or("rlx-oneapi: no device for arena")?;
34 let size = plan.arena_size.max(4);
35 let base = dev.alloc_shared(size)?;
36 let mut offsets = HashMap::new();
37 let mut lens = HashMap::new();
38 for (id, slot) in &plan.assignments {
39 offsets.insert(*id, slot.offset);
40 lens.insert(*id, slot.size);
41 }
42 Ok(Self {
43 dev,
44 base,
45 size,
46 offsets,
47 lens,
48 })
49 }
50
51 #[inline]
52 pub fn has(&self, id: NodeId) -> bool {
53 self.offsets.contains_key(&id)
54 }
55
56 #[inline]
58 pub fn elem_offset(&self, id: NodeId) -> u32 {
59 (self.offsets[&id] / 4) as u32
60 }
61
62 #[inline]
64 pub fn base_ptr(&self) -> *mut std::ffi::c_void {
65 self.base
66 }
67
68 pub fn write_f32(&self, id: NodeId, data: &[f32]) {
69 let Some(&off) = self.offsets.get(&id) else {
70 return;
71 };
72 let cap = self.lens.get(&id).copied().unwrap_or(0) / 4;
73 let n = data.len().min(cap);
74 unsafe {
75 let dst = (self.base as *mut u8).add(off) as *mut f32;
76 std::ptr::copy_nonoverlapping(data.as_ptr(), dst, n);
77 }
78 }
79
80 pub fn write_bytes(&self, id: NodeId, data: &[u8]) {
81 let Some(&off) = self.offsets.get(&id) else {
82 return;
83 };
84 let cap = self.lens.get(&id).copied().unwrap_or(0);
85 let n = data.len().min(cap);
86 unsafe {
87 std::ptr::copy_nonoverlapping(data.as_ptr(), (self.base as *mut u8).add(off), n);
88 }
89 }
90
91 pub fn read_f32(&self, id: NodeId, n: usize) -> Vec<f32> {
92 let Some(&off) = self.offsets.get(&id) else {
93 return vec![0.0; n];
94 };
95 let cap = self.lens.get(&id).copied().unwrap_or(0) / 4;
96 let n = n.min(cap);
97 let mut out = vec![0.0f32; n];
98 unsafe {
99 let src = (self.base as *const u8).add(off) as *const f32;
100 std::ptr::copy_nonoverlapping(src, out.as_mut_ptr(), n);
101 }
102 out
103 }
104
105 pub fn read_bytes(&self, id: NodeId, nbytes: usize) -> Vec<u8> {
106 let Some(&off) = self.offsets.get(&id) else {
107 return vec![0u8; nbytes];
108 };
109 let cap = self.lens.get(&id).copied().unwrap_or(0);
110 let n = nbytes.min(cap);
111 let mut out = vec![0u8; nbytes];
112 unsafe {
113 std::ptr::copy_nonoverlapping((self.base as *const u8).add(off), out.as_mut_ptr(), n);
114 }
115 out
116 }
117}
118
119impl Drop for Arena {
120 fn drop(&mut self) {
121 let _ = &self.dev;
122 self.dev.free(self.base);
123 }
124}