1use crate::device::{VulkanDevice, vulkan_device};
15use ash::vk;
16use rlx_compile::memory::MemoryPlan;
17use rlx_ir::NodeId;
18use std::collections::HashMap;
19
/// A single Vulkan buffer plus its host mapping, subdivided into per-node
/// slots according to a `MemoryPlan`.
pub struct Arena {
    /// Device that created `buffer` / `memory`; also used to release them on drop.
    dev: &'static VulkanDevice,
    /// Buffer handle covering the whole arena (created in `from_plan` with
    /// storage + transfer-src + transfer-dst usage).
    pub buffer: vk::Buffer,
    /// Device memory bound to `buffer` at offset 0.
    memory: vk::DeviceMemory,
    /// Arena size in bytes: the plan's `arena_size`, clamped to at least 4.
    pub size: usize,
    /// Host pointer to the start of the mapped `memory`; stays mapped until drop.
    mapped: *mut u8,
    /// Byte offset of each node's slot from the start of the arena.
    offsets: HashMap<NodeId, usize>,
    /// Size in bytes of each node's slot.
    lens: HashMap<NodeId, usize>,
}
33
// SAFETY (review): `Arena` is not automatically `Send` because it holds the raw
// `mapped` pointer. That pointer refers to the mapping created in `from_plan`
// and released in `Drop`, so it moves together with the arena.
// NOTE(review): this assumes the Vulkan handles and the `&'static VulkanDevice`
// may be used from whichever thread ends up owning the arena — TODO confirm.
// `Arena` is deliberately not `Sync` here: its methods write through `&self`
// without any synchronisation.
unsafe impl Send for Arena {}
37
38impl Arena {
39 pub fn from_plan(plan: &MemoryPlan) -> Self {
40 let dev = vulkan_device().expect("rlx-vulkan: no device for arena");
41 let size = plan.arena_size.max(4);
42 if std::env::var("RLX_VULKAN_ARENA_DEBUG").ok().as_deref() == Some("1") {
43 eprintln!(
44 "[rlx-vulkan arena] {:.2} GiB ({} bytes)",
45 size as f64 / (1u64 << 30) as f64,
46 size
47 );
48 }
49
50 let info = vk::BufferCreateInfo::default()
51 .size(size as u64)
52 .usage(
53 vk::BufferUsageFlags::STORAGE_BUFFER
54 | vk::BufferUsageFlags::TRANSFER_SRC
55 | vk::BufferUsageFlags::TRANSFER_DST,
56 )
57 .sharing_mode(vk::SharingMode::EXCLUSIVE);
58 let buffer = unsafe { dev.device.create_buffer(&info, None) }.expect("vk create_buffer");
59
60 let req = unsafe { dev.device.get_buffer_memory_requirements(buffer) };
61 let mem_type = dev
62 .find_memory_type(
63 req.memory_type_bits,
64 vk::MemoryPropertyFlags::HOST_VISIBLE | vk::MemoryPropertyFlags::HOST_COHERENT,
65 )
66 .expect("rlx-vulkan: no HOST_VISIBLE|HOST_COHERENT memory type");
67 let memory = unsafe {
68 dev.device.allocate_memory(
69 &vk::MemoryAllocateInfo::default()
70 .allocation_size(req.size)
71 .memory_type_index(mem_type),
72 None,
73 )
74 }
75 .expect("vk allocate_memory");
76 unsafe { dev.device.bind_buffer_memory(buffer, memory, 0) }.expect("vk bind_buffer_memory");
77
78 let mapped = unsafe {
79 dev.device
80 .map_memory(memory, 0, req.size, vk::MemoryMapFlags::empty())
81 }
82 .expect("vk map_memory") as *mut u8;
83 unsafe { std::ptr::write_bytes(mapped, 0, size) };
85
86 let mut offsets = HashMap::new();
87 let mut lens = HashMap::new();
88 for (id, slot) in &plan.assignments {
89 offsets.insert(*id, slot.offset);
90 lens.insert(*id, slot.size);
91 }
92
93 Self {
94 dev,
95 buffer,
96 memory,
97 size,
98 mapped,
99 offsets,
100 lens,
101 }
102 }
103
104 #[inline]
105 pub fn has(&self, id: NodeId) -> bool {
106 self.offsets.contains_key(&id)
107 }
108
109 #[inline]
111 pub fn byte_offset(&self, id: NodeId) -> usize {
112 self.offsets[&id]
113 }
114
115 #[inline]
117 pub fn elem_offset(&self, id: NodeId) -> u32 {
118 (self.offsets[&id] / 4) as u32
119 }
120
121 #[inline]
123 pub fn slot_elems(&self, id: NodeId) -> usize {
124 self.lens.get(&id).copied().unwrap_or(0) / 4
125 }
126
127 pub fn write_f32(&self, id: NodeId, data: &[f32]) {
129 let Some(&off) = self.offsets.get(&id) else {
130 return;
131 };
132 let cap = self.lens.get(&id).copied().unwrap_or(0) / 4;
133 let n = data.len().min(cap);
134 unsafe {
135 let dst = self.mapped.add(off) as *mut f32;
136 std::ptr::copy_nonoverlapping(data.as_ptr(), dst, n);
137 }
138 }
139
140 pub fn write_bytes(&self, id: NodeId, data: &[u8]) {
142 let Some(&off) = self.offsets.get(&id) else {
143 return;
144 };
145 let cap = self.lens.get(&id).copied().unwrap_or(0);
146 let n = data.len().min(cap);
147 unsafe {
148 std::ptr::copy_nonoverlapping(data.as_ptr(), self.mapped.add(off), n);
149 }
150 }
151
152 pub fn read_f32(&self, id: NodeId, n: usize) -> Vec<f32> {
154 let Some(&off) = self.offsets.get(&id) else {
155 return vec![0.0; n];
156 };
157 let cap = self.lens.get(&id).copied().unwrap_or(0) / 4;
158 let n = n.min(cap);
159 let mut out = vec![0.0f32; n];
160 unsafe {
161 let src = self.mapped.add(off) as *const f32;
162 std::ptr::copy_nonoverlapping(src, out.as_mut_ptr(), n);
163 }
164 out
165 }
166
167 pub fn copy_into(&self, dst: &Arena) {
170 let n = self.size.min(dst.size);
171 unsafe {
172 std::ptr::copy_nonoverlapping(self.mapped, dst.mapped, n);
173 }
174 }
175
176 pub fn read_bytes(&self, id: NodeId, nbytes: usize) -> Vec<u8> {
178 let Some(&off) = self.offsets.get(&id) else {
179 return vec![0u8; nbytes];
180 };
181 let cap = self.lens.get(&id).copied().unwrap_or(0);
182 let n = nbytes.min(cap);
183 let mut out = vec![0u8; nbytes];
184 unsafe {
185 std::ptr::copy_nonoverlapping(self.mapped.add(off), out.as_mut_ptr(), n);
186 }
187 out
188 }
189
190 pub fn copy_node_f32_prefix(&self, dst: NodeId, src: NodeId, n: usize) {
197 let (Some(&doff), Some(&soff)) = (self.offsets.get(&dst), self.offsets.get(&src)) else {
198 return;
199 };
200 if doff == soff {
201 return; }
203 let dcap = self.lens.get(&dst).copied().unwrap_or(0) / 4;
204 let scap = self.lens.get(&src).copied().unwrap_or(0) / 4;
205 let n = n.min(dcap).min(scap);
206 if n == 0 {
207 return;
208 }
209 unsafe {
210 let src_p = self.mapped.add(soff) as *const f32;
211 let dst_p = self.mapped.add(doff) as *mut f32;
212 std::ptr::copy_nonoverlapping(src_p, dst_p, n);
213 }
214 }
215
216 pub fn copy_node_f32_range(
222 &self,
223 dst: NodeId,
224 dst_elem: usize,
225 src: NodeId,
226 src_elem: usize,
227 n: usize,
228 ) {
229 let (Some(&doff), Some(&soff)) = (self.offsets.get(&dst), self.offsets.get(&src)) else {
230 return;
231 };
232 let dcap = self.lens.get(&dst).copied().unwrap_or(0) / 4;
233 let scap = self.lens.get(&src).copied().unwrap_or(0) / 4;
234 if dst_elem + n > dcap || src_elem + n > scap || n == 0 {
235 return;
236 }
237 let dbyte = doff + dst_elem * 4;
238 let sbyte = soff + src_elem * 4;
239 if dbyte == sbyte {
240 return;
241 }
242 unsafe {
243 let src_p = self.mapped.add(sbyte) as *const f32;
244 let dst_p = self.mapped.add(dbyte) as *mut f32;
245 std::ptr::copy_nonoverlapping(src_p, dst_p, n);
246 }
247 }
248
249 pub fn read_f32_at_elem(&self, elem_off: usize, n: usize) -> Vec<f32> {
251 let mut out = vec![0.0f32; n];
252 let byte_off = elem_off * 4;
253 if byte_off + n * 4 > self.size {
254 return out;
255 }
256 unsafe {
257 let src = self.mapped.add(byte_off) as *const f32;
258 std::ptr::copy_nonoverlapping(src, out.as_mut_ptr(), n);
259 }
260 out
261 }
262}
263
264impl Drop for Arena {
265 fn drop(&mut self) {
266 unsafe {
267 self.dev.device.unmap_memory(self.memory);
268 self.dev.device.destroy_buffer(self.buffer, None);
269 self.dev.device.free_memory(self.memory, None);
270 }
271 }
272}