Skip to main content

rlx_oneapi/
level_zero.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// SPDX-License-Identifier: GPL-3.0-only
5
6//! Hand-rolled Level Zero (`ze_*`) FFI, dynamically loaded with `libloading`.
7//!
8//! Only the subset RLX needs is bound: init + driver/device enumeration and
9//! properties (for availability + device name), context/queue/command-list
10//! lifecycle, USM-shared allocation, and SPIR-V module → kernel → launch. The
11//! loader (`libze_loader`) is opened at runtime, so the crate links and builds
12//! on hosts with no oneAPI runtime (macOS, CI) — [`Lib::load`] simply returns
13//! `Err` there and the whole backend reports itself unavailable.
14//!
15//! ⚠️ The `ZE_STRUCTURE_TYPE_*` enum values and descriptor field layouts are
16//! transcribed from `ze_api.h` (Level Zero spec v1.x). They are exercised only
17//! on real Intel hardware (Arc / Data Center Max) during bring-up; verify them
18//! against the installed loader version there. Nothing in this module runs on
19//! the macOS dev box — the backend falls through to the CPU reference path.
20
21#![allow(
22    non_upper_case_globals,
23    non_camel_case_types,
24    clippy::missing_safety_doc
25)]
26
27use libloading::Library;
28use std::ffi::{c_char, c_void};
29
30// ── result + handles ───────────────────────────────────────────────────────
31
/// Raw status code returned by every Level Zero entry point bound in this
/// module (`ze_result_t` in `ze_api.h`). Pass it to [`check`] to turn it into
/// a `Result`.
pub type ZeResult = i32; // ze_result_t
/// The only status value that [`check`] accepts as success.
pub const ZE_RESULT_SUCCESS: ZeResult = 0;
34
// All handles are opaque to this crate and share one representation: an
// untyped raw pointer. Because these are plain aliases, the compiler will NOT
// reject passing one handle kind where another is expected.

/// Driver handle; written by the `zeDriverGet` binding.
pub type DriverHandle = *mut c_void;
/// Device handle; written by the `zeDeviceGet` binding.
pub type DeviceHandle = *mut c_void;
/// Context handle; written by the `zeContextCreate` binding.
pub type ContextHandle = *mut c_void;
/// Command-queue handle; written by the `zeCommandQueueCreate` binding.
pub type CommandQueueHandle = *mut c_void;
/// Command-list handle; written by the `zeCommandListCreate` binding.
pub type CommandListHandle = *mut c_void;
/// Module handle; written by the `zeModuleCreate` binding.
pub type ModuleHandle = *mut c_void;
/// Build-log handle; second out-parameter of the `zeModuleCreate` binding.
pub type ModuleBuildLogHandle = *mut c_void;
/// Kernel handle; written by the `zeKernelCreate` binding.
pub type KernelHandle = *mut c_void;
/// Fence handle; last argument of the `zeCommandQueueExecuteCommandLists`
/// binding. No fence-creation entry point is bound in this module.
pub type FenceHandle = *mut c_void;
/// Event handle; used for the signal/wait arguments of the launch-kernel and
/// barrier bindings. No event-creation entry point is bound in this module.
pub type EventHandle = *mut c_void;
45
46// ── ze_structure_type_t (subset) ───────────────────────────────────────────
47
// Values follow the declaration order of `ze_structure_type_t` in `ze_api.h`.
// Each one is the tag a caller stores in the `stype` field of the matching
// struct so the driver can identify what it was handed.

/// `stype` tag for [`DeviceProperties`] (`ze_device_properties_t`).
pub const ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES: u32 = 0x3;
/// `stype` tag for [`ContextDesc`] (`ze_context_desc_t`).
pub const ZE_STRUCTURE_TYPE_CONTEXT_DESC: u32 = 0xd;
/// `stype` tag for [`CommandQueueDesc`] (`ze_command_queue_desc_t`).
pub const ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC: u32 = 0xe;
/// `stype` tag for [`CommandListDesc`] (`ze_command_list_desc_t`).
pub const ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC: u32 = 0xf;
/// `stype` tag for [`DeviceMemAllocDesc`] (`ze_device_mem_alloc_desc_t`).
pub const ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC: u32 = 0x15;
/// `stype` tag for [`HostMemAllocDesc`] (`ze_host_mem_alloc_desc_t`).
pub const ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC: u32 = 0x16;
/// `stype` tag for [`ModuleDesc`] (`ze_module_desc_t`).
///
/// 0x1b, not 0x1d: in the header 0x1c is `MODULE_PROPERTIES` and 0x1d is
/// `KERNEL_DESC`, so 0x1d here would tag a module descriptor as a kernel one.
pub const ZE_STRUCTURE_TYPE_MODULE_DESC: u32 = 0x1b;
/// `stype` tag for [`KernelDesc`] (`ze_kernel_desc_t`).
///
/// 0x1d, not 0x1f: in the header 0x1e is `KERNEL_PROPERTIES` and 0x1f is
/// `SAMPLER_DESC`.
pub const ZE_STRUCTURE_TYPE_KERNEL_DESC: u32 = 0x1d;
56
/// `ze_device_type_t`: GPU = 1. Same width as [`DeviceProperties::type_`],
/// which is presumably the field it is compared against — confirm at the
/// call site.
pub const ZE_DEVICE_TYPE_GPU: u32 = 1;
/// `ze_module_format_t`: IL_SPIRV = 0. Intended for [`ModuleDesc::format`]
/// when the module input is SPIR-V.
pub const ZE_MODULE_FORMAT_IL_SPIRV: u32 = 0;
/// `ze_command_queue_mode_t`: SYNCHRONOUS = 1 (execute blocks until complete).
/// Intended for [`CommandQueueDesc::mode`].
pub const ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS: u32 = 1;
/// `ze_command_queue_mode_t`: DEFAULT = 0. Intended for
/// [`CommandQueueDesc::mode`].
pub const ZE_COMMAND_QUEUE_MODE_DEFAULT: u32 = 0;

/// Length of the `name` array in [`DeviceProperties`].
pub const ZE_MAX_DEVICE_NAME: usize = 256;
/// Length of the `id` array in [`DeviceUuid`].
pub const ZE_MAX_DEVICE_UUID_SIZE: usize = 16;
68
69// ── descriptor structs (repr(C), matching ze_api.h) ────────────────────────
70
/// Fixed-size device identifier embedded by value in [`DeviceProperties`]
/// (`ze_device_uuid_t` in `ze_api.h`).
#[repr(C)]
pub struct DeviceUuid {
    /// Raw identifier bytes; this module does not interpret them.
    pub id: [u8; ZE_MAX_DEVICE_UUID_SIZE],
}
75
/// Output structure for the `zeDeviceGetProperties` binding, mirroring
/// `ze_device_properties_t`.
///
/// Field order and widths are part of the C ABI — do not reorder, resize, or
/// insert fields. Obtain an instance through `Default`, which zero-fills the
/// struct and sets `stype`.
#[repr(C)]
pub struct DeviceProperties {
    /// Structure tag; `Default` sets it to
    /// [`ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES`].
    pub stype: u32,
    /// Extension pointer; null after `Default`.
    pub pnext: *mut c_void,
    /// Device kind (`ze_device_type_t`); see [`ZE_DEVICE_TYPE_GPU`].
    pub type_: u32,
    pub vendor_id: u32,
    pub device_id: u32,
    pub flags: u32,
    pub subdevice_id: u32,
    pub core_clock_rate: u32,
    pub max_mem_alloc_size: u64,
    pub max_hardware_contexts: u32,
    pub max_command_queue_priority: u32,
    pub num_threads_per_eu: u32,
    pub physical_eu_simd_width: u32,
    pub num_eus_per_subslice: u32,
    pub num_subslices_per_slice: u32,
    pub num_slices: u32,
    pub timer_resolution: u64,
    pub timestamp_valid_bits: u32,
    pub kernel_timestamp_valid_bits: u32,
    pub uuid: DeviceUuid,
    /// Device name as C characters; presumably NUL-terminated by the driver —
    /// confirm before converting with `CStr`.
    pub name: [c_char; ZE_MAX_DEVICE_NAME],
}
100
101impl Default for DeviceProperties {
102    fn default() -> Self {
103        // SAFETY: all-zero is a valid bit pattern for every field (the call
104        // overwrites them); we only set the required `stype` afterwards.
105        let mut p: Self = unsafe { std::mem::zeroed() };
106        p.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES;
107        p
108    }
109}
110
/// Descriptor passed by pointer to the `zeContextCreate` binding
/// (`ze_context_desc_t`). No `Default` is provided, so the caller fills every
/// field; `stype` is expected to be [`ZE_STRUCTURE_TYPE_CONTEXT_DESC`].
#[repr(C)]
pub struct ContextDesc {
    /// Structure tag identifying this descriptor to the driver.
    pub stype: u32,
    /// Extension pointer; null when no extension struct is attached.
    pub pnext: *const c_void,
    pub flags: u32,
}
117
/// Descriptor passed by pointer to the `zeCommandQueueCreate` binding
/// (`ze_command_queue_desc_t`). No `Default` is provided, so the caller fills
/// every field; `stype` is expected to be
/// [`ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC`].
#[repr(C)]
pub struct CommandQueueDesc {
    /// Structure tag identifying this descriptor to the driver.
    pub stype: u32,
    /// Extension pointer; null when no extension struct is attached.
    pub pnext: *const c_void,
    pub ordinal: u32,
    pub index: u32,
    pub flags: u32,
    /// Queue mode (`ze_command_queue_mode_t`); see
    /// [`ZE_COMMAND_QUEUE_MODE_DEFAULT`] and
    /// [`ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS`].
    pub mode: u32,
    pub priority: u32,
}
128
/// Descriptor passed by pointer to the `zeCommandListCreate` binding
/// (`ze_command_list_desc_t`). No `Default` is provided, so the caller fills
/// every field; `stype` is expected to be
/// [`ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC`].
#[repr(C)]
pub struct CommandListDesc {
    /// Structure tag identifying this descriptor to the driver.
    pub stype: u32,
    /// Extension pointer; null when no extension struct is attached.
    pub pnext: *const c_void,
    pub command_queue_group_ordinal: u32,
    pub flags: u32,
}
136
/// Descriptor passed by pointer to the `zeModuleCreate` binding
/// (`ze_module_desc_t`). No `Default` is provided, so the caller fills every
/// field; `stype` is expected to be [`ZE_STRUCTURE_TYPE_MODULE_DESC`].
///
/// The struct stores raw pointers only, so the buffers behind
/// `p_input_module` and `p_build_flags` must outlive the create call.
#[repr(C)]
pub struct ModuleDesc {
    /// Structure tag identifying this descriptor to the driver.
    pub stype: u32,
    /// Extension pointer; null when no extension struct is attached.
    pub pnext: *const c_void,
    /// Input format (`ze_module_format_t`); see [`ZE_MODULE_FORMAT_IL_SPIRV`].
    pub format: u32,
    /// Size of the buffer at `p_input_module` — presumably in bytes; confirm
    /// against `ze_api.h`.
    pub input_size: usize,
    /// Start of the module image (SPIR-V when `format` is IL_SPIRV).
    pub p_input_module: *const u8,
    /// C string of build flags; presumably may be null — confirm.
    pub p_build_flags: *const c_char,
    /// Specialization constants; declared untyped here.
    pub p_constants: *const c_void,
}
147
/// Descriptor passed by pointer to the `zeKernelCreate` binding
/// (`ze_kernel_desc_t`). No `Default` is provided, so the caller fills every
/// field; `stype` is expected to be [`ZE_STRUCTURE_TYPE_KERNEL_DESC`].
#[repr(C)]
pub struct KernelDesc {
    /// Structure tag identifying this descriptor to the driver.
    pub stype: u32,
    /// Extension pointer; null when no extension struct is attached.
    pub pnext: *const c_void,
    pub flags: u32,
    /// C string naming the kernel inside the module; the string must outlive
    /// the create call because only the pointer is stored here.
    pub p_kernel_name: *const c_char,
}
155
/// Device-side allocation descriptor; first descriptor argument of the
/// `zeMemAllocShared` binding (`ze_device_mem_alloc_desc_t`). No `Default` is
/// provided; `stype` is expected to be
/// [`ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC`].
#[repr(C)]
pub struct DeviceMemAllocDesc {
    /// Structure tag identifying this descriptor to the driver.
    pub stype: u32,
    /// Extension pointer; null when no extension struct is attached.
    pub pnext: *const c_void,
    pub flags: u32,
    pub ordinal: u32,
}
163
/// Host-side allocation descriptor; second descriptor argument of the
/// `zeMemAllocShared` binding (`ze_host_mem_alloc_desc_t`). No `Default` is
/// provided; `stype` is expected to be
/// [`ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC`].
#[repr(C)]
pub struct HostMemAllocDesc {
    /// Structure tag identifying this descriptor to the driver.
    pub stype: u32,
    /// Extension pointer; null when no extension struct is attached.
    pub pnext: *const c_void,
    pub flags: u32,
}
170
/// Launch dimensions passed by pointer to the
/// `zeCommandListAppendLaunchKernel` binding (`ze_group_count_t`): the number
/// of groups along each axis. Unlike the descriptors, it has no
/// `stype`/`pnext` header.
#[repr(C)]
pub struct GroupCount {
    pub group_count_x: u32,
    pub group_count_y: u32,
    pub group_count_z: u32,
}
177
178// ── function pointer types ─────────────────────────────────────────────────
179
// One alias per bound entry point. `Lib::load` casts each resolved symbol to
// the alias named here without any runtime check, so every signature must
// match the C prototype in `ze_api.h` exactly.

/// `zeInit`: takes a `u32` flags word.
type PfnInit = unsafe extern "C" fn(u32) -> ZeResult;
/// `zeDriverGet`: in/out count pointer, then the driver-handle output array.
type PfnDriverGet = unsafe extern "C" fn(*mut u32, *mut DriverHandle) -> ZeResult;
/// `zeDeviceGet`: driver, in/out count pointer, device-handle output array.
type PfnDeviceGet = unsafe extern "C" fn(DriverHandle, *mut u32, *mut DeviceHandle) -> ZeResult;
/// `zeDeviceGetProperties`: fills the caller-provided [`DeviceProperties`].
type PfnDeviceGetProperties = unsafe extern "C" fn(DeviceHandle, *mut DeviceProperties) -> ZeResult;
/// `zeContextCreate`: driver, descriptor, context-handle out-pointer.
type PfnContextCreate =
    unsafe extern "C" fn(DriverHandle, *const ContextDesc, *mut ContextHandle) -> ZeResult;
/// `zeContextDestroy`.
type PfnContextDestroy = unsafe extern "C" fn(ContextHandle) -> ZeResult;
/// `zeCommandQueueCreate`: context, device, descriptor, queue out-pointer.
type PfnCommandQueueCreate = unsafe extern "C" fn(
    ContextHandle,
    DeviceHandle,
    *const CommandQueueDesc,
    *mut CommandQueueHandle,
) -> ZeResult;
/// `zeCommandQueueDestroy`.
type PfnCommandQueueDestroy = unsafe extern "C" fn(CommandQueueHandle) -> ZeResult;
/// `zeCommandQueueExecuteCommandLists`: queue, number of lists, pointer to
/// the list handles, fence handle.
type PfnCommandQueueExecute = unsafe extern "C" fn(
    CommandQueueHandle,
    u32,
    *const CommandListHandle,
    FenceHandle,
) -> ZeResult;
/// `zeCommandQueueSynchronize`: queue and a `u64` timeout (units per
/// `ze_api.h`).
type PfnCommandQueueSynchronize = unsafe extern "C" fn(CommandQueueHandle, u64) -> ZeResult;
/// `zeCommandListCreate`: context, device, descriptor, list out-pointer.
type PfnCommandListCreate = unsafe extern "C" fn(
    ContextHandle,
    DeviceHandle,
    *const CommandListDesc,
    *mut CommandListHandle,
) -> ZeResult;
/// `zeCommandListDestroy`.
type PfnCommandListDestroy = unsafe extern "C" fn(CommandListHandle) -> ZeResult;
/// `zeCommandListClose`.
type PfnCommandListClose = unsafe extern "C" fn(CommandListHandle) -> ZeResult;
/// `zeCommandListReset`.
type PfnCommandListReset = unsafe extern "C" fn(CommandListHandle) -> ZeResult;
/// `zeCommandListAppendLaunchKernel`: list, kernel, launch dimensions, signal
/// event, number of wait events, pointer to the wait-event handles.
type PfnCommandListAppendLaunchKernel = unsafe extern "C" fn(
    CommandListHandle,
    KernelHandle,
    *const GroupCount,
    EventHandle,
    u32,
    *mut EventHandle,
) -> ZeResult;
/// `zeCommandListAppendBarrier`: list, signal event, number of wait events,
/// pointer to the wait-event handles.
type PfnCommandListAppendBarrier =
    unsafe extern "C" fn(CommandListHandle, EventHandle, u32, *mut EventHandle) -> ZeResult;
/// `zeMemAllocShared`: context, device descriptor, host descriptor, two
/// `usize` values (size then alignment, per `ze_api.h`), device, and the
/// out-pointer that receives the allocation.
type PfnMemAllocShared = unsafe extern "C" fn(
    ContextHandle,
    *const DeviceMemAllocDesc,
    *const HostMemAllocDesc,
    usize,
    usize,
    DeviceHandle,
    *mut *mut c_void,
) -> ZeResult;
/// `zeMemFree`: context and the pointer to release.
type PfnMemFree = unsafe extern "C" fn(ContextHandle, *mut c_void) -> ZeResult;
/// `zeModuleCreate`: context, device, descriptor, module out-pointer,
/// build-log out-pointer.
type PfnModuleCreate = unsafe extern "C" fn(
    ContextHandle,
    DeviceHandle,
    *const ModuleDesc,
    *mut ModuleHandle,
    *mut ModuleBuildLogHandle,
) -> ZeResult;
/// `zeModuleDestroy`.
type PfnModuleDestroy = unsafe extern "C" fn(ModuleHandle) -> ZeResult;
/// `zeKernelCreate`: module, descriptor, kernel out-pointer.
type PfnKernelCreate =
    unsafe extern "C" fn(ModuleHandle, *const KernelDesc, *mut KernelHandle) -> ZeResult;
/// `zeKernelDestroy`.
type PfnKernelDestroy = unsafe extern "C" fn(KernelHandle) -> ZeResult;
/// `zeKernelSetGroupSize`: kernel and three `u32` sizes (x, y, z per
/// `ze_api.h`).
type PfnKernelSetGroupSize = unsafe extern "C" fn(KernelHandle, u32, u32, u32) -> ZeResult;
/// `zeKernelSetArgumentValue`: kernel, argument index, argument size, pointer
/// to the argument value.
type PfnKernelSetArgumentValue =
    unsafe extern "C" fn(KernelHandle, u32, usize, *const c_void) -> ZeResult;
244
/// Resolved Level Zero entry points. Keeps the loaded [`Library`] alive so the
/// copied function pointers stay valid for the process lifetime.
///
/// Each public field is a plain function pointer copied out of the library
/// when the struct is built. A pointer copied out of a `Lib` must not be
/// called after that `Lib` has been dropped, because dropping it releases the
/// library. Every field is an `unsafe extern "C"` function: callers invoke it
/// inside `unsafe` and are responsible for passing valid handles and pointers.
pub struct Lib {
    // Never read; held only so the library stays loaded while the pointers
    // below are in use.
    _lib: Library,
    pub ze_init: PfnInit,
    pub driver_get: PfnDriverGet,
    pub device_get: PfnDeviceGet,
    pub device_get_properties: PfnDeviceGetProperties,
    pub context_create: PfnContextCreate,
    pub context_destroy: PfnContextDestroy,
    pub command_queue_create: PfnCommandQueueCreate,
    pub command_queue_destroy: PfnCommandQueueDestroy,
    /// Resolved from `zeCommandQueueExecuteCommandLists` (field name is
    /// shortened).
    pub command_queue_execute: PfnCommandQueueExecute,
    pub command_queue_synchronize: PfnCommandQueueSynchronize,
    pub command_list_create: PfnCommandListCreate,
    pub command_list_destroy: PfnCommandListDestroy,
    pub command_list_close: PfnCommandListClose,
    pub command_list_reset: PfnCommandListReset,
    pub command_list_append_launch_kernel: PfnCommandListAppendLaunchKernel,
    pub command_list_append_barrier: PfnCommandListAppendBarrier,
    pub mem_alloc_shared: PfnMemAllocShared,
    pub mem_free: PfnMemFree,
    pub module_create: PfnModuleCreate,
    pub module_destroy: PfnModuleDestroy,
    pub kernel_create: PfnKernelCreate,
    pub kernel_destroy: PfnKernelDestroy,
    pub kernel_set_group_size: PfnKernelSetGroupSize,
    pub kernel_set_argument_value: PfnKernelSetArgumentValue,
}
274
/// Loader filenames to try, in order, for the platform this crate was built
/// for. Windows gets the DLL name; every other target (macOS included) gets
/// the ELF sonames, which simply fail to open where no loader exists, leaving
/// the backend unavailable (graceful, like rlx-vulkan without MoltenVK).
fn loader_names() -> &'static [&'static str] {
    const WINDOWS_NAMES: &[&str] = &["ze_loader.dll"];
    const DEFAULT_NAMES: &[&str] = &["libze_loader.so.1", "libze_loader.so"];

    if !cfg!(target_os = "windows") {
        return DEFAULT_NAMES;
    }
    WINDOWS_NAMES
}
284
285impl Lib {
286    /// Open `libze_loader` and resolve every entry point. `Err` on any host
287    /// without the oneAPI runtime (no library or a missing symbol).
288    pub unsafe fn load() -> Result<Lib, String> {
289        // Honor an explicit override first, then the platform defaults.
290        let mut tried: Vec<String> = Vec::new();
291        let mut lib: Option<Library> = None;
292        let mut names: Vec<String> = Vec::new();
293        if let Ok(p) = std::env::var("RLX_ONEAPI_LOADER") {
294            names.push(p);
295        }
296        names.extend(loader_names().iter().map(|s| s.to_string()));
297        for name in &names {
298            match unsafe { Library::new(name) } {
299                Ok(l) => {
300                    lib = Some(l);
301                    break;
302                }
303                Err(e) => tried.push(format!("{name}: {e}")),
304            }
305        }
306        let lib = lib.ok_or_else(|| format!("no Level Zero loader ({})", tried.join("; ")))?;
307
308        macro_rules! sym {
309            ($name:expr) => {{
310                let s: libloading::Symbol<_> = unsafe { lib.get($name) }
311                    .map_err(|e| format!("missing {}: {e}", String::from_utf8_lossy($name)))?;
312                *s
313            }};
314        }
315
316        let out = Lib {
317            ze_init: sym!(b"zeInit\0"),
318            driver_get: sym!(b"zeDriverGet\0"),
319            device_get: sym!(b"zeDeviceGet\0"),
320            device_get_properties: sym!(b"zeDeviceGetProperties\0"),
321            context_create: sym!(b"zeContextCreate\0"),
322            context_destroy: sym!(b"zeContextDestroy\0"),
323            command_queue_create: sym!(b"zeCommandQueueCreate\0"),
324            command_queue_destroy: sym!(b"zeCommandQueueDestroy\0"),
325            command_queue_execute: sym!(b"zeCommandQueueExecuteCommandLists\0"),
326            command_queue_synchronize: sym!(b"zeCommandQueueSynchronize\0"),
327            command_list_create: sym!(b"zeCommandListCreate\0"),
328            command_list_destroy: sym!(b"zeCommandListDestroy\0"),
329            command_list_close: sym!(b"zeCommandListClose\0"),
330            command_list_reset: sym!(b"zeCommandListReset\0"),
331            command_list_append_launch_kernel: sym!(b"zeCommandListAppendLaunchKernel\0"),
332            command_list_append_barrier: sym!(b"zeCommandListAppendBarrier\0"),
333            mem_alloc_shared: sym!(b"zeMemAllocShared\0"),
334            mem_free: sym!(b"zeMemFree\0"),
335            module_create: sym!(b"zeModuleCreate\0"),
336            module_destroy: sym!(b"zeModuleDestroy\0"),
337            kernel_create: sym!(b"zeKernelCreate\0"),
338            kernel_destroy: sym!(b"zeKernelDestroy\0"),
339            kernel_set_group_size: sym!(b"zeKernelSetGroupSize\0"),
340            kernel_set_argument_value: sym!(b"zeKernelSetArgumentValue\0"),
341            _lib: lib,
342        };
343        Ok(out)
344    }
345}
346
347/// Map a non-success [`ZeResult`] to a printable error.
348pub fn check(res: ZeResult, ctx: &str) -> Result<(), String> {
349    if res == ZE_RESULT_SUCCESS {
350        Ok(())
351    } else {
352        Err(format!("{ctx} failed: ze_result 0x{res:x}"))
353    }
354}