Skip to main content

rlx_oneapi/
device.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// SPDX-License-Identifier: GPL-3.0-only
5
//! The process-wide Level Zero driver / device / context / compute-queue
//! singleton, brought up through the dynamically-loaded `libze_loader`. If the
//! loader is missing, `zeInit` fails, no GPU device is exposed, or the context
//! or command queue cannot be created, [`oneapi_device`] returns `None` and the
//! backend reports itself unavailable — mirroring rlx-cuda / rlx-rocm /
//! rlx-vulkan on a host with no driver.
11
12use crate::level_zero::*;
13use std::ffi::{CStr, c_void};
14use std::sync::{Mutex, OnceLock};
15
/// Owned Level Zero state for the selected GPU: loader function table, driver,
/// device, context and compute queue. One per process, reached through
/// [`oneapi_device`].
pub struct OneApiDevice {
    /// Loaded Level Zero entry points; every FFI call in this file goes
    /// through one of its function fields.
    pub lib: Lib,
    /// Driver the selected device was enumerated from.
    pub driver: DriverHandle,
    /// First device that reported `ZE_DEVICE_TYPE_GPU` during enumeration.
    pub device: DeviceHandle,
    /// Context created on `driver`; destroyed in `Drop`.
    pub context: ContextHandle,
    /// Command queue created on `context` / `device`; destroyed in `Drop`.
    pub queue: CommandQueueHandle,
    /// Command-queue-group ordinal the compute queue/list were created on.
    pub queue_ordinal: u32,
    /// Device name as reported in the device properties.
    pub name: String,
    /// Level Zero command queues require external synchronization; serialize.
    submit_lock: Mutex<()>,
}
29
// SAFETY: the handle fields are never reassigned after `OneApiDevice::new`
// returns, and queue submission is serialized behind `submit_lock` (mirrors
// rlx-vulkan's VulkanDevice).
// NOTE(review): `alloc_shared`, `free` and `create_command_list` use
// `context` / `device` without taking the lock; this presumably relies on
// Level Zero permitting concurrent calls on those handles — confirm against
// the spec's threading rules.
unsafe impl Send for OneApiDevice {}
unsafe impl Sync for OneApiDevice {}
34
/// Lazily-initialised singleton behind [`oneapi_device`]. Holds `None` when
/// bring-up failed; the outcome of the first attempt is cached, so
/// initialisation is not retried.
static DEVICE: OnceLock<Option<OneApiDevice>> = OnceLock::new();
36
37/// The process-wide oneAPI device, or `None` when unavailable.
38pub fn oneapi_device() -> Option<&'static OneApiDevice> {
39    DEVICE.get_or_init(|| OneApiDevice::new().ok()).as_ref()
40}
41
42impl OneApiDevice {
43    fn new() -> Result<Self, String> {
44        // SAFETY: dynamic-load the Level Zero loader; Err on hosts without it.
45        let lib = unsafe { Lib::load() }?;
46
47        unsafe {
48            check((lib.ze_init)(0), "zeInit")?;
49
50            // Enumerate drivers.
51            let mut driver_count: u32 = 0;
52            check(
53                (lib.driver_get)(&mut driver_count, std::ptr::null_mut()),
54                "zeDriverGet(count)",
55            )?;
56            if driver_count == 0 {
57                return Err("no Level Zero drivers".into());
58            }
59            let mut drivers: Vec<DriverHandle> = vec![std::ptr::null_mut(); driver_count as usize];
60            check(
61                (lib.driver_get)(&mut driver_count, drivers.as_mut_ptr()),
62                "zeDriverGet",
63            )?;
64
65            // Pick the first GPU device across all drivers.
66            let mut chosen: Option<(DriverHandle, DeviceHandle, String)> = None;
67            for &driver in &drivers {
68                let mut dev_count: u32 = 0;
69                if check(
70                    (lib.device_get)(driver, &mut dev_count, std::ptr::null_mut()),
71                    "zeDeviceGet(count)",
72                )
73                .is_err()
74                    || dev_count == 0
75                {
76                    continue;
77                }
78                let mut devices: Vec<DeviceHandle> = vec![std::ptr::null_mut(); dev_count as usize];
79                if check(
80                    (lib.device_get)(driver, &mut dev_count, devices.as_mut_ptr()),
81                    "zeDeviceGet",
82                )
83                .is_err()
84                {
85                    continue;
86                }
87                for &device in &devices {
88                    let mut props = DeviceProperties::default();
89                    if (lib.device_get_properties)(device, &mut props) != ZE_RESULT_SUCCESS {
90                        continue;
91                    }
92                    if props.type_ != ZE_DEVICE_TYPE_GPU {
93                        continue;
94                    }
95                    let name = CStr::from_ptr(props.name.as_ptr())
96                        .to_string_lossy()
97                        .into_owned();
98                    chosen = Some((driver, device, name));
99                    break;
100                }
101                if chosen.is_some() {
102                    break;
103                }
104            }
105            let (driver, device, name) =
106                chosen.ok_or_else(|| "no Level Zero GPU device".to_string())?;
107
108            // Context.
109            let ctx_desc = ContextDesc {
110                stype: ZE_STRUCTURE_TYPE_CONTEXT_DESC,
111                pnext: std::ptr::null(),
112                flags: 0,
113            };
114            let mut context: ContextHandle = std::ptr::null_mut();
115            check(
116                (lib.context_create)(driver, &ctx_desc, &mut context),
117                "zeContextCreate",
118            )?;
119
120            // Compute command queue. Group ordinal 0 is the primary compute
121            // group on Intel GPUs; a follow-up can query
122            // zeDeviceGetCommandQueueGroupProperties to pick the COMPUTE group
123            // explicitly (it's almost always 0).
124            let queue_ordinal = 0u32;
125            let q_desc = CommandQueueDesc {
126                stype: ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC,
127                pnext: std::ptr::null(),
128                ordinal: queue_ordinal,
129                index: 0,
130                flags: 0,
131                mode: ZE_COMMAND_QUEUE_MODE_DEFAULT,
132                priority: 0,
133            };
134            let mut queue: CommandQueueHandle = std::ptr::null_mut();
135            check(
136                (lib.command_queue_create)(context, device, &q_desc, &mut queue),
137                "zeCommandQueueCreate",
138            )?;
139
140            Ok(Self {
141                lib,
142                driver,
143                device,
144                context,
145                queue,
146                queue_ordinal,
147                name,
148                submit_lock: Mutex::new(()),
149            })
150        }
151    }
152
153    /// Allocate a host-accessible USM-shared buffer of `size` bytes, zeroed.
154    /// Returns the device pointer (also CPU-dereferenceable on integrated /
155    /// shared-memory Intel GPUs).
156    pub fn alloc_shared(&self, size: usize) -> Result<*mut c_void, String> {
157        let dev_desc = DeviceMemAllocDesc {
158            stype: ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
159            pnext: std::ptr::null(),
160            flags: 0,
161            ordinal: 0,
162        };
163        let host_desc = HostMemAllocDesc {
164            stype: ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC,
165            pnext: std::ptr::null(),
166            flags: 0,
167        };
168        let mut ptr: *mut c_void = std::ptr::null_mut();
169        unsafe {
170            check(
171                (self.lib.mem_alloc_shared)(
172                    self.context,
173                    &dev_desc,
174                    &host_desc,
175                    size.max(1),
176                    64,
177                    self.device,
178                    &mut ptr,
179                ),
180                "zeMemAllocShared",
181            )?;
182            std::ptr::write_bytes(ptr as *mut u8, 0, size);
183        }
184        Ok(ptr)
185    }
186
187    /// Free a USM buffer previously returned by [`alloc_shared`](Self::alloc_shared).
188    #[allow(clippy::not_unsafe_ptr_arg_deref)] // ptr is an opaque USM handle; we only pass it to FFI
189    pub fn free(&self, ptr: *mut c_void) {
190        if !ptr.is_null() {
191            unsafe {
192                let _ = (self.lib.mem_free)(self.context, ptr);
193            }
194        }
195    }
196
197    /// Create a fresh closed-on-demand command list on the compute group.
198    pub fn create_command_list(&self) -> Result<CommandListHandle, String> {
199        let desc = CommandListDesc {
200            stype: ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
201            pnext: std::ptr::null(),
202            command_queue_group_ordinal: self.queue_ordinal,
203            flags: 0,
204        };
205        let mut list: CommandListHandle = std::ptr::null_mut();
206        unsafe {
207            check(
208                (self.lib.command_list_create)(self.context, self.device, &desc, &mut list),
209                "zeCommandListCreate",
210            )?;
211        }
212        Ok(list)
213    }
214
215    /// Close `list`, execute it on the compute queue, and block until complete.
216    /// Serialized across threads (queues need external synchronization).
217    #[allow(clippy::not_unsafe_ptr_arg_deref)] // list is an opaque handle; we only pass it to FFI
218    pub fn execute_sync(&self, list: CommandListHandle) -> Result<(), String> {
219        let _guard = self.submit_lock.lock().unwrap();
220        unsafe {
221            check((self.lib.command_list_close)(list), "zeCommandListClose")?;
222            let lists = [list];
223            check(
224                (self.lib.command_queue_execute)(
225                    self.queue,
226                    1,
227                    lists.as_ptr(),
228                    std::ptr::null_mut(),
229                ),
230                "zeCommandQueueExecuteCommandLists",
231            )?;
232            check(
233                (self.lib.command_queue_synchronize)(self.queue, u64::MAX),
234                "zeCommandQueueSynchronize",
235            )?;
236        }
237        Ok(())
238    }
239}
240
241impl Drop for OneApiDevice {
242    fn drop(&mut self) {
243        unsafe {
244            if !self.queue.is_null() {
245                let _ = (self.lib.command_queue_destroy)(self.queue);
246            }
247            if !self.context.is_null() {
248                let _ = (self.lib.context_destroy)(self.context);
249            }
250        }
251    }
252}