rlx_driver/arena.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Backend-agnostic arena trait — the contract every backend's memory
17//! plan obeys.
18//!
19//! Lifted from CpuExecutable / MetalExecutable's previously duplicated
20//! arena helpers. Each new backend (CUDA, ROCm, wgpu, WASM, TPU) implements
21//! this trait once and gets:
22//! - typed input feed (`f32 → arena_dtype`)
23//! - typed output read (`arena_dtype → f32`)
24//! - per-node byte offset resolution
25//!
26//! The trait deliberately exposes raw pointers / byte offsets rather than
27//! Rust slices so the same implementation works for host-resident memory
28//! (CPU/WASM), unified memory (Apple Silicon Metal/MPSGraph), and
29//! discrete-VRAM backends (CUDA/ROCm) where reading involves a copy.
30
31use rlx_ir::{DType, NodeId};
32
33/// Per-backend arena interface.
34///
35/// All concrete arenas — `rlx-cpu::Arena`, `rlx-metal::Arena`, future
36/// `rlx-cuda::Arena`, `rlx-wgpu::Arena` — implement this trait so the
37/// runtime can drive them uniformly. The actual byte layout is owned
38/// by the backend; we only require offset-based access.
39pub trait DeviceArena {
40 /// Byte offset of `id`'s buffer slot in the arena. `usize::MAX` for
41 /// nodes that don't have an arena slot (e.g. fused-away intermediates).
42 fn byte_offset(&self, id: NodeId) -> usize;
43
44 /// True if `id` has a real arena slot.
45 fn has_buffer(&self, id: NodeId) -> bool;
46
47 /// Total arena size in bytes.
48 fn size_bytes(&self) -> usize;
49
50 /// Write a host-side `f32` slice into `id`'s slot, casting to `dtype`
51 /// if necessary. Truncates to the buffer's capacity (no panic on overflow).
52 ///
53 /// For discrete-memory backends this involves a host→device copy; for
54 /// unified-memory backends (Apple Silicon, integrated GPUs) it's a
55 /// direct write.
56 fn write_input_f32(&mut self, id: NodeId, dtype: DType, data: &[f32]);
57
58 /// Read `id`'s slot as a host-side `Vec<f32>`, casting from `dtype` if
59 /// necessary. The number of elements is determined by the backend
60 /// based on the memory plan (typically `shape.num_elements()`).
61 fn read_output_f32(&self, id: NodeId, dtype: DType, n_elements: usize) -> Vec<f32>;
62}
63
64/// Helper: cast f32 input to bytes of `dtype` and write to `dst_ptr`.
65/// Used by every CPU-resident-arena backend. GPU backends can call this
66/// after staging into a host buffer, then upload.
67///
68/// Currently supports F32 / F16 / BF16. Other dtypes fall through to F32.
69pub unsafe fn write_typed_from_f32(dst_ptr: *mut u8, dtype: DType, src: &[f32], max_elems: usize) {
70 let n = src.len().min(max_elems);
71 match dtype {
72 DType::F16 => unsafe {
73 let dst = dst_ptr as *mut half::f16;
74 for i in 0..n {
75 *dst.add(i) = half::f16::from_f32(src[i]);
76 }
77 },
78 DType::BF16 => unsafe {
79 let dst = dst_ptr as *mut half::bf16;
80 for i in 0..n {
81 *dst.add(i) = half::bf16::from_f32(src[i]);
82 }
83 },
84 _ => unsafe {
85 let dst = dst_ptr as *mut f32;
86 std::ptr::copy_nonoverlapping(src.as_ptr(), dst, n);
87 },
88 }
89}
90
91/// Helper: read `n_elems` of `dtype` from `src_ptr`, returning `Vec<f32>`.
92pub unsafe fn read_typed_to_f32(src_ptr: *const u8, dtype: DType, n_elems: usize) -> Vec<f32> {
93 match dtype {
94 DType::F16 => {
95 let mut out = Vec::with_capacity(n_elems);
96 unsafe {
97 let src = src_ptr as *const half::f16;
98 for i in 0..n_elems {
99 out.push((*src.add(i)).to_f32());
100 }
101 }
102 out
103 }
104 DType::BF16 => {
105 let mut out = Vec::with_capacity(n_elems);
106 unsafe {
107 let src = src_ptr as *const half::bf16;
108 for i in 0..n_elems {
109 out.push((*src.add(i)).to_f32());
110 }
111 }
112 out
113 }
114 _ => unsafe {
115 let src = src_ptr as *const f32;
116 std::slice::from_raw_parts(src, n_elems).to_vec()
117 },
118 }
119}