Skip to main content

rlx_driver/
arena.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Backend-agnostic arena trait — the contract every backend's memory
17//! plan obeys.
18//!
19//! Lifted from CpuExecutable / MetalExecutable's previously duplicated
20//! arena helpers. Each new backend (CUDA, ROCm, wgpu, WASM, TPU) implements
21//! this trait once and gets:
22//!   - typed input feed (`f32 → arena_dtype`)
23//!   - typed output read (`arena_dtype → f32`)
24//!   - per-node byte offset resolution
25//!
26//! The trait deliberately exposes raw pointers / byte offsets rather than
27//! Rust slices so the same implementation works for host-resident memory
28//! (CPU/WASM), unified memory (Apple Silicon Metal/MPSGraph), and
29//! discrete-VRAM backends (CUDA/ROCm) where reading involves a copy.
30
31use rlx_ir::{DType, NodeId};
32
33/// Per-backend arena interface.
34///
35/// All concrete arenas — `rlx-cpu::Arena`, `rlx-metal::Arena`, future
36/// `rlx-cuda::Arena`, `rlx-wgpu::Arena` — implement this trait so the
37/// runtime can drive them uniformly. The actual byte layout is owned
38/// by the backend; we only require offset-based access.
39pub trait DeviceArena {
40    /// Byte offset of `id`'s buffer slot in the arena. `usize::MAX` for
41    /// nodes that don't have an arena slot (e.g. fused-away intermediates).
42    fn byte_offset(&self, id: NodeId) -> usize;
43
44    /// True if `id` has a real arena slot.
45    fn has_buffer(&self, id: NodeId) -> bool;
46
47    /// Total arena size in bytes.
48    fn size_bytes(&self) -> usize;
49
50    /// Write a host-side `f32` slice into `id`'s slot, casting to `dtype`
51    /// if necessary. Truncates to the buffer's capacity (no panic on overflow).
52    ///
53    /// For discrete-memory backends this involves a host→device copy; for
54    /// unified-memory backends (Apple Silicon, integrated GPUs) it's a
55    /// direct write.
56    fn write_input_f32(&mut self, id: NodeId, dtype: DType, data: &[f32]);
57
58    /// Read `id`'s slot as a host-side `Vec<f32>`, casting from `dtype` if
59    /// necessary. The number of elements is determined by the backend
60    /// based on the memory plan (typically `shape.num_elements()`).
61    fn read_output_f32(&self, id: NodeId, dtype: DType, n_elements: usize) -> Vec<f32>;
62}
63
64/// Helper: cast f32 input to bytes of `dtype` and write to `dst_ptr`.
65/// Used by every CPU-resident-arena backend. GPU backends can call this
66/// after staging into a host buffer, then upload.
67///
68/// Currently supports F32 / F16 / BF16. Other dtypes fall through to F32.
69pub unsafe fn write_typed_from_f32(dst_ptr: *mut u8, dtype: DType, src: &[f32], max_elems: usize) {
70    let n = src.len().min(max_elems);
71    match dtype {
72        DType::F16 => unsafe {
73            let dst = dst_ptr as *mut half::f16;
74            for i in 0..n {
75                *dst.add(i) = half::f16::from_f32(src[i]);
76            }
77        },
78        DType::BF16 => unsafe {
79            let dst = dst_ptr as *mut half::bf16;
80            for i in 0..n {
81                *dst.add(i) = half::bf16::from_f32(src[i]);
82            }
83        },
84        _ => unsafe {
85            let dst = dst_ptr as *mut f32;
86            std::ptr::copy_nonoverlapping(src.as_ptr(), dst, n);
87        },
88    }
89}
90
91/// Helper: read `n_elems` of `dtype` from `src_ptr`, returning `Vec<f32>`.
92pub unsafe fn read_typed_to_f32(src_ptr: *const u8, dtype: DType, n_elems: usize) -> Vec<f32> {
93    match dtype {
94        DType::F16 => {
95            let mut out = Vec::with_capacity(n_elems);
96            unsafe {
97                let src = src_ptr as *const half::f16;
98                for i in 0..n_elems {
99                    out.push((*src.add(i)).to_f32());
100                }
101            }
102            out
103        }
104        DType::BF16 => {
105            let mut out = Vec::with_capacity(n_elems);
106            unsafe {
107                let src = src_ptr as *const half::bf16;
108                for i in 0..n_elems {
109                    out.push((*src.add(i)).to_f32());
110                }
111            }
112            out
113        }
114        _ => unsafe {
115            let src = src_ptr as *const f32;
116            std::slice::from_raw_parts(src, n_elems).to_vec()
117        },
118    }
119}