Skip to main content

rlx_oneapi/
arena.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// SPDX-License-Identifier: GPL-3.0-only
5
6//! The f32-uniform USM-shared GPU arena for the Level Zero dispatch path. Like
7//! rlx-vulkan's host-visible arena, every tensor is an f32 slot at a byte
8//! offset in one contiguous buffer; here the buffer is a single
9//! `zeMemAllocShared` allocation, which is CPU-dereferenceable on Intel's
10//! shared-memory GPUs, so host upload/readback and the CPU host-fallback are
11//! plain pointer writes with no staging. Only constructed when a live device is
12//! present (the dev-box path uses the value-map interpreter in `backend.rs`).
13
14use crate::device::{OneApiDevice, oneapi_device};
15use rlx_compile::memory::MemoryPlan;
16use rlx_ir::NodeId;
17use std::collections::HashMap;
18
/// A single shared allocation plus the per-node slot table that carves it up.
///
/// Slots are looked up by [`NodeId`]; each slot is described by a byte offset
/// from `base` and a byte length. The allocation is released in `Drop`.
pub struct Arena {
    /// Device handle used to allocate `base` in `from_plan` and to free it on drop.
    dev: &'static OneApiDevice,
    /// Start of the allocation returned by `alloc_shared`; slot offsets are
    /// byte offsets relative to this pointer.
    base: *mut std::ffi::c_void,
    /// Length requested from `alloc_shared` (the plan's arena size, at least 4).
    pub size: usize,
    /// Byte offset of each node's slot from `base`.
    offsets: HashMap<NodeId, usize>,
    /// Byte length of each node's slot; host copies are clamped to this.
    lens: HashMap<NodeId, usize>,
}
26
// `Arena` holds a raw pointer (`base`), so it is not `Send` automatically.
// Original justification: the USM pointer is only used behind `&mut self`
// writes / `&self` reads on a single executable at a time; the executable
// itself is not `Sync`.
// NOTE(review): `write_f32` and `write_bytes` take `&self`, not `&mut self`, so
// the exclusive-write half of that argument is not enforced by the borrow
// checker — confirm the owning executable serializes host writes.
unsafe impl Send for Arena {}
30
31impl Arena {
32    pub fn from_plan(plan: &MemoryPlan) -> Result<Self, String> {
33        let dev = oneapi_device().ok_or("rlx-oneapi: no device for arena")?;
34        let size = plan.arena_size.max(4);
35        let base = dev.alloc_shared(size)?;
36        let mut offsets = HashMap::new();
37        let mut lens = HashMap::new();
38        for (id, slot) in &plan.assignments {
39            offsets.insert(*id, slot.offset);
40            lens.insert(*id, slot.size);
41        }
42        Ok(Self {
43            dev,
44            base,
45            size,
46            offsets,
47            lens,
48        })
49    }
50
51    #[inline]
52    pub fn has(&self, id: NodeId) -> bool {
53        self.offsets.contains_key(&id)
54    }
55
56    /// Element offset (f32) of a node's slot — what the kernels index by.
57    #[inline]
58    pub fn elem_offset(&self, id: NodeId) -> u32 {
59        (self.offsets[&id] / 4) as u32
60    }
61
62    /// Raw USM base pointer (kernel argument 0).
63    #[inline]
64    pub fn base_ptr(&self) -> *mut std::ffi::c_void {
65        self.base
66    }
67
68    pub fn write_f32(&self, id: NodeId, data: &[f32]) {
69        let Some(&off) = self.offsets.get(&id) else {
70            return;
71        };
72        let cap = self.lens.get(&id).copied().unwrap_or(0) / 4;
73        let n = data.len().min(cap);
74        unsafe {
75            let dst = (self.base as *mut u8).add(off) as *mut f32;
76            std::ptr::copy_nonoverlapping(data.as_ptr(), dst, n);
77        }
78    }
79
80    pub fn write_bytes(&self, id: NodeId, data: &[u8]) {
81        let Some(&off) = self.offsets.get(&id) else {
82            return;
83        };
84        let cap = self.lens.get(&id).copied().unwrap_or(0);
85        let n = data.len().min(cap);
86        unsafe {
87            std::ptr::copy_nonoverlapping(data.as_ptr(), (self.base as *mut u8).add(off), n);
88        }
89    }
90
91    pub fn read_f32(&self, id: NodeId, n: usize) -> Vec<f32> {
92        let Some(&off) = self.offsets.get(&id) else {
93            return vec![0.0; n];
94        };
95        let cap = self.lens.get(&id).copied().unwrap_or(0) / 4;
96        let n = n.min(cap);
97        let mut out = vec![0.0f32; n];
98        unsafe {
99            let src = (self.base as *const u8).add(off) as *const f32;
100            std::ptr::copy_nonoverlapping(src, out.as_mut_ptr(), n);
101        }
102        out
103    }
104
105    pub fn read_bytes(&self, id: NodeId, nbytes: usize) -> Vec<u8> {
106        let Some(&off) = self.offsets.get(&id) else {
107            return vec![0u8; nbytes];
108        };
109        let cap = self.lens.get(&id).copied().unwrap_or(0);
110        let n = nbytes.min(cap);
111        let mut out = vec![0u8; nbytes];
112        unsafe {
113            std::ptr::copy_nonoverlapping((self.base as *const u8).add(off), out.as_mut_ptr(), n);
114        }
115        out
116    }
117}
118
119impl Drop for Arena {
120    fn drop(&mut self) {
121        let _ = &self.dev;
122        self.dev.free(self.base);
123    }
124}