Skip to main content

rlx_cpu/
arena.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Arena allocator — ONE allocation, zero per-call overhead.
17//!
18//! The memory planner computes the total arena size and per-buffer offsets
19//! at compile time. At runtime, the arena is allocated once and slices
20//! are handed out by offset. Between forward calls, just reset the
21//! generation counter — no deallocation, no reallocation.
22
23use rlx_ir::NodeId;
24use rlx_opt::memory::MemoryPlan;
25
26/// Pre-allocated memory arena for graph execution.
27#[derive(Clone)]
28pub struct Arena {
29    buf: Vec<u8>,
30    plan: MemoryPlan,
31}
32
33impl Arena {
34    /// Allocate arena from a memory plan.
35    pub fn from_plan(plan: MemoryPlan) -> Self {
36        let buf = vec![0u8; plan.arena_size];
37        Self { buf, plan }
38    }
39
40    /// Total arena size in bytes.
41    pub fn size(&self) -> usize {
42        self.plan.arena_size
43    }
44
45    /// Get a mutable f32 slice for a node's buffer.
46    ///
47    /// # Panics
48    /// Panics if the node has no buffer assignment.
49    pub fn slice_mut(&mut self, id: NodeId) -> &mut [f32] {
50        let slot = self
51            .plan
52            .assignments
53            .get(&id)
54            .unwrap_or_else(|| panic!("no buffer for {id}"));
55        let bytes = &mut self.buf[slot.offset..slot.offset + slot.size];
56        // SAFETY: buf is aligned to at least 1, but we need f32 alignment.
57        // The memory planner aligns to 64 bytes, so this is safe.
58        unsafe { std::slice::from_raw_parts_mut(bytes.as_mut_ptr() as *mut f32, slot.size / 4) }
59    }
60
61    /// Get a read-only f32 slice for a node's buffer.
62    pub fn slice(&self, id: NodeId) -> &[f32] {
63        let slot = self
64            .plan
65            .assignments
66            .get(&id)
67            .unwrap_or_else(|| panic!("no buffer for {id}"));
68        let bytes = &self.buf[slot.offset..slot.offset + slot.size];
69        unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, slot.size / 4) }
70    }
71
72    /// Get a mutable f64 slice for a node's buffer.
73    ///
74    /// # Panics
75    /// Panics if the node has no buffer assignment, or if the slot's
76    /// byte size is not 8-aligned.
77    pub fn slice_mut_f64(&mut self, id: NodeId) -> &mut [f64] {
78        let slot = self
79            .plan
80            .assignments
81            .get(&id)
82            .unwrap_or_else(|| panic!("no buffer for {id}"));
83        debug_assert!(
84            slot.size.is_multiple_of(8),
85            "slice_mut_f64: slot {} has size {} not divisible by 8",
86            id,
87            slot.size
88        );
89        let bytes = &mut self.buf[slot.offset..slot.offset + slot.size];
90        // SAFETY: planner aligns slots to 64 bytes ⇒ f64-aligned.
91        unsafe { std::slice::from_raw_parts_mut(bytes.as_mut_ptr() as *mut f64, slot.size / 8) }
92    }
93
94    /// Get a read-only f64 slice for a node's buffer.
95    pub fn slice_f64(&self, id: NodeId) -> &[f64] {
96        let slot = self
97            .plan
98            .assignments
99            .get(&id)
100            .unwrap_or_else(|| panic!("no buffer for {id}"));
101        debug_assert!(
102            slot.size.is_multiple_of(8),
103            "slice_f64: slot {} has size {} not divisible by 8",
104            id,
105            slot.size
106        );
107        let bytes = &self.buf[slot.offset..slot.offset + slot.size];
108        unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f64, slot.size / 8) }
109    }
110
111    /// Check if a node has a buffer assignment.
112    pub fn has_buffer(&self, id: NodeId) -> bool {
113        self.plan.assignments.contains_key(&id)
114    }
115
116    /// Get a raw pointer + length for a node's buffer.
117    /// SAFETY: caller must ensure no aliasing writes to the same buffer.
118    pub fn raw_ptr(&self, id: NodeId) -> (*mut f32, usize) {
119        let slot = self
120            .plan
121            .assignments
122            .get(&id)
123            .unwrap_or_else(|| panic!("no buffer for {id}"));
124        let ptr = unsafe { self.buf.as_ptr().add(slot.offset) as *mut f32 };
125        (ptr, slot.size / 4)
126    }
127
128    /// The execution schedule from the memory plan.
129    pub fn schedule(&self) -> &[NodeId] {
130        &self.plan.schedule
131    }
132
133    /// Byte offset of a node's buffer within the arena.
134    pub fn byte_offset(&self, id: NodeId) -> usize {
135        self.plan
136            .assignments
137            .get(&id)
138            .map(|s| s.offset)
139            .unwrap_or(usize::MAX)
140    }
141
142    /// Raw mutable access to the arena buffer (for thunk executor).
143    pub fn raw_buf_mut(&mut self) -> &mut [u8] {
144        &mut self.buf
145    }
146
147    /// Read-only access to the arena buffer (for typed reads).
148    pub fn raw_buf(&self) -> &[u8] {
149        &self.buf
150    }
151
152    /// Raw pointer to arena start (for zero-copy output reads).
153    pub fn raw_buf_mut_ptr(&self) -> *const u8 {
154        self.buf.as_ptr()
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_opt::memory::BufferSlot;
    use std::collections::HashMap;

    /// Two adjacent slots in a 1 KiB arena: checks the f32 element counts and
    /// that a value written to the first slot is still there after the second
    /// slot has been borrowed mutably.
    #[test]
    fn arena_slice_access() {
        let mut assignments = HashMap::new();
        assignments.insert(
            NodeId(0),
            BufferSlot {
                offset: 0,
                size: 256,
            },
        );
        assignments.insert(
            NodeId(1),
            BufferSlot {
                offset: 256,
                size: 512,
            },
        );
        let plan = MemoryPlan {
            arena_size: 1024,
            assignments,
            schedule: vec![NodeId(0), NodeId(1)],
        };

        let mut arena = Arena::from_plan(plan);

        {
            // 256 bytes hold 64 f32 values.
            let first = arena.slice_mut(NodeId(0));
            assert_eq!(first.len(), 64);
            first[0] = 42.0;
        }

        // 512 bytes hold 128 f32 values.
        assert_eq!(arena.slice_mut(NodeId(1)).len(), 128);

        // The write to the first slot is still visible.
        assert_eq!(arena.slice(NodeId(0))[0], 42.0);
    }
}