Skip to main content

rlx_vulkan/
buffer.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// SPDX-License-Identifier: GPL-3.0-only
5
6//! The f32-uniform GPU arena. Like rlx-cuda / rlx-wgpu, every tensor is an
7//! f32 slot at a byte offset in one contiguous buffer. We allocate the
8//! arena as `HOST_VISIBLE | HOST_COHERENT` memory and keep it persistently
9//! mapped, so host upload/readback is a plain `memcpy` with no staging
10//! buffer or transfer command. (On discrete GPUs a `DEVICE_LOCAL` arena +
11//! staging would have higher bandwidth — a documented follow-up; correctness
12//! first.)
13
14use crate::device::{VulkanDevice, vulkan_device};
15use ash::vk;
16use rlx_compile::memory::MemoryPlan;
17use rlx_ir::NodeId;
18use std::collections::HashMap;
19
20pub struct Arena {
21    dev: &'static VulkanDevice,
22    pub buffer: vk::Buffer,
23    memory: vk::DeviceMemory,
24    /// Total arena size in bytes.
25    pub size: usize,
26    /// Persistent host mapping of the whole arena.
27    mapped: *mut u8,
28    /// Per-node byte offset into the arena.
29    offsets: HashMap<NodeId, usize>,
30    /// Per-node slot byte length (capacity, ≥ used).
31    lens: HashMap<NodeId, usize>,
32}
33
34// The mapped pointer is only used behind `&mut self` writes / `&self` reads
35// on a single executable at a time; the executable itself is not `Sync`.
36unsafe impl Send for Arena {}
37
38impl Arena {
39    pub fn from_plan(plan: &MemoryPlan) -> Self {
40        let dev = vulkan_device().expect("rlx-vulkan: no device for arena");
41        let size = plan.arena_size.max(4);
42        if std::env::var("RLX_VULKAN_ARENA_DEBUG").ok().as_deref() == Some("1") {
43            eprintln!(
44                "[rlx-vulkan arena] {:.2} GiB ({} bytes)",
45                size as f64 / (1u64 << 30) as f64,
46                size
47            );
48        }
49
50        let info = vk::BufferCreateInfo::default()
51            .size(size as u64)
52            .usage(
53                vk::BufferUsageFlags::STORAGE_BUFFER
54                    | vk::BufferUsageFlags::TRANSFER_SRC
55                    | vk::BufferUsageFlags::TRANSFER_DST,
56            )
57            .sharing_mode(vk::SharingMode::EXCLUSIVE);
58        let buffer = unsafe { dev.device.create_buffer(&info, None) }.expect("vk create_buffer");
59
60        let req = unsafe { dev.device.get_buffer_memory_requirements(buffer) };
61        let mem_type = dev
62            .find_memory_type(
63                req.memory_type_bits,
64                vk::MemoryPropertyFlags::HOST_VISIBLE | vk::MemoryPropertyFlags::HOST_COHERENT,
65            )
66            .expect("rlx-vulkan: no HOST_VISIBLE|HOST_COHERENT memory type");
67        let memory = unsafe {
68            dev.device.allocate_memory(
69                &vk::MemoryAllocateInfo::default()
70                    .allocation_size(req.size)
71                    .memory_type_index(mem_type),
72                None,
73            )
74        }
75        .expect("vk allocate_memory");
76        unsafe { dev.device.bind_buffer_memory(buffer, memory, 0) }.expect("vk bind_buffer_memory");
77
78        let mapped = unsafe {
79            dev.device
80                .map_memory(memory, 0, req.size, vk::MemoryMapFlags::empty())
81        }
82        .expect("vk map_memory") as *mut u8;
83        // A fresh arena is zero-initialized (scratch slots rely on it).
84        unsafe { std::ptr::write_bytes(mapped, 0, size) };
85
86        let mut offsets = HashMap::new();
87        let mut lens = HashMap::new();
88        for (id, slot) in &plan.assignments {
89            offsets.insert(*id, slot.offset);
90            lens.insert(*id, slot.size);
91        }
92
93        Self {
94            dev,
95            buffer,
96            memory,
97            size,
98            mapped,
99            offsets,
100            lens,
101        }
102    }
103
104    #[inline]
105    pub fn has(&self, id: NodeId) -> bool {
106        self.offsets.contains_key(&id)
107    }
108
109    /// Byte offset of a node's slot.
110    #[inline]
111    pub fn byte_offset(&self, id: NodeId) -> usize {
112        self.offsets[&id]
113    }
114
115    /// f32-element offset of a node's slot (for push constants).
116    #[inline]
117    pub fn elem_offset(&self, id: NodeId) -> u32 {
118        (self.offsets[&id] / 4) as u32
119    }
120
121    /// Slot capacity in f32 elements.
122    #[inline]
123    pub fn slot_elems(&self, id: NodeId) -> usize {
124        self.lens.get(&id).copied().unwrap_or(0) / 4
125    }
126
127    /// Upload f32 data into a node's slot (clamped to the slot capacity).
128    pub fn write_f32(&self, id: NodeId, data: &[f32]) {
129        let Some(&off) = self.offsets.get(&id) else {
130            return;
131        };
132        let cap = self.lens.get(&id).copied().unwrap_or(0) / 4;
133        let n = data.len().min(cap);
134        unsafe {
135            let dst = self.mapped.add(off) as *mut f32;
136            std::ptr::copy_nonoverlapping(data.as_ptr(), dst, n);
137        }
138    }
139
140    /// Upload raw bytes into a node's slot (for non-f32 packed params).
141    pub fn write_bytes(&self, id: NodeId, data: &[u8]) {
142        let Some(&off) = self.offsets.get(&id) else {
143            return;
144        };
145        let cap = self.lens.get(&id).copied().unwrap_or(0);
146        let n = data.len().min(cap);
147        unsafe {
148            std::ptr::copy_nonoverlapping(data.as_ptr(), self.mapped.add(off), n);
149        }
150    }
151
152    /// Read `n` f32 elements from a node's slot.
153    pub fn read_f32(&self, id: NodeId, n: usize) -> Vec<f32> {
154        let Some(&off) = self.offsets.get(&id) else {
155            return vec![0.0; n];
156        };
157        let cap = self.lens.get(&id).copied().unwrap_or(0) / 4;
158        let n = n.min(cap);
159        let mut out = vec![0.0f32; n];
160        unsafe {
161            let src = self.mapped.add(off) as *const f32;
162            std::ptr::copy_nonoverlapping(src, out.as_mut_ptr(), n);
163        }
164        out
165    }
166
167    /// Byte-for-byte copy this arena's contents into `dst` (same plan/size).
168    /// Used by `clone_for_cache` to carry params + constants into a twin.
169    pub fn copy_into(&self, dst: &Arena) {
170        let n = self.size.min(dst.size);
171        unsafe {
172            std::ptr::copy_nonoverlapping(self.mapped, dst.mapped, n);
173        }
174    }
175
176    /// Read `nbytes` raw bytes from a node's slot (packed quant weights).
177    pub fn read_bytes(&self, id: NodeId, nbytes: usize) -> Vec<u8> {
178        let Some(&off) = self.offsets.get(&id) else {
179            return vec![0u8; nbytes];
180        };
181        let cap = self.lens.get(&id).copied().unwrap_or(0);
182        let n = nbytes.min(cap);
183        let mut out = vec![0u8; nbytes];
184        unsafe {
185            std::ptr::copy_nonoverlapping(self.mapped.add(off), out.as_mut_ptr(), n);
186        }
187        out
188    }
189
190    /// In-arena copy of `n` f32 elements from `src`'s slot into `dst`'s slot,
191    /// clamped to both slot capacities. Used by the GPU-resident K/V feed to
192    /// fold a decode step's new-token K/V output back into the `past_k_*` input
193    /// slot without a host round-trip. The arena is HOST_COHERENT and the GPU
194    /// queue is idle by the time this runs (see `submit_and_wait` in `run`), so
195    /// a plain mapped `memcpy` is safe and visible to the next dispatch.
196    pub fn copy_node_f32_prefix(&self, dst: NodeId, src: NodeId, n: usize) {
197        let (Some(&doff), Some(&soff)) = (self.offsets.get(&dst), self.offsets.get(&src)) else {
198            return;
199        };
200        if doff == soff {
201            return; // aliased slot — nothing to do
202        }
203        let dcap = self.lens.get(&dst).copied().unwrap_or(0) / 4;
204        let scap = self.lens.get(&src).copied().unwrap_or(0) / 4;
205        let n = n.min(dcap).min(scap);
206        if n == 0 {
207            return;
208        }
209        unsafe {
210            let src_p = self.mapped.add(soff) as *const f32;
211            let dst_p = self.mapped.add(doff) as *mut f32;
212            std::ptr::copy_nonoverlapping(src_p, dst_p, n);
213        }
214    }
215
216    /// In-arena copy of `n` f32 elements from `src` slot (starting at element
217    /// offset `src_elem`) into `dst` slot (starting at `dst_elem`), clamped to
218    /// both slot capacities. Used by the decode K/V feed to drop a single new
219    /// token row (output row `upper`) into the resident `past_k_*` slot at the
220    /// active row — without disturbing the already-resident prefix.
221    pub fn copy_node_f32_range(
222        &self,
223        dst: NodeId,
224        dst_elem: usize,
225        src: NodeId,
226        src_elem: usize,
227        n: usize,
228    ) {
229        let (Some(&doff), Some(&soff)) = (self.offsets.get(&dst), self.offsets.get(&src)) else {
230            return;
231        };
232        let dcap = self.lens.get(&dst).copied().unwrap_or(0) / 4;
233        let scap = self.lens.get(&src).copied().unwrap_or(0) / 4;
234        if dst_elem + n > dcap || src_elem + n > scap || n == 0 {
235            return;
236        }
237        let dbyte = doff + dst_elem * 4;
238        let sbyte = soff + src_elem * 4;
239        if dbyte == sbyte {
240            return;
241        }
242        unsafe {
243            let src_p = self.mapped.add(sbyte) as *const f32;
244            let dst_p = self.mapped.add(dbyte) as *mut f32;
245            std::ptr::copy_nonoverlapping(src_p, dst_p, n);
246        }
247    }
248
249    /// Read `n` f32 elements starting at an arbitrary f32-element offset.
250    pub fn read_f32_at_elem(&self, elem_off: usize, n: usize) -> Vec<f32> {
251        let mut out = vec![0.0f32; n];
252        let byte_off = elem_off * 4;
253        if byte_off + n * 4 > self.size {
254            return out;
255        }
256        unsafe {
257            let src = self.mapped.add(byte_off) as *const f32;
258            std::ptr::copy_nonoverlapping(src, out.as_mut_ptr(), n);
259        }
260        out
261    }
262}
263
264impl Drop for Arena {
265    fn drop(&mut self) {
266        unsafe {
267            self.dev.device.unmap_memory(self.memory);
268            self.dev.device.destroy_buffer(self.buffer, None);
269            self.dev.device.free_memory(self.memory, None);
270        }
271    }
272}