1use crate::error::{LevelZeroError, LevelZeroResult};
8
9#[cfg(any(target_os = "linux", target_os = "windows"))]
12use std::{ffi::c_void, sync::Arc};
13
14#[cfg(any(target_os = "linux", target_os = "windows"))]
15use libloading::Library;
16
// Opaque Level Zero handle types. Every handle is an untyped raw pointer owned
// by the driver; Rust never dereferences them, it only passes them back to
// the API functions declared below.

/// Driver handle, written into the array passed to `zeDriverGet`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeDriverHandle = *mut c_void;

/// Device handle, written into the array passed to `zeDeviceGet`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeDeviceHandle = *mut c_void;

/// Context handle produced by `zeContextCreate` and released with `zeContextDestroy`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeContextHandle = *mut c_void;

/// Command-queue handle produced by `zeCommandQueueCreate`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
pub(crate) type ZeCommandQueueHandle = *mut c_void;

/// Command-list handle produced by `zeCommandListCreate`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
pub(crate) type ZeCommandListHandle = *mut c_void;
33
// Numeric values copied from the Level Zero core header (`ze_api.h`).
// The `ZE_STRUCTURE_TYPE_*` values come from `ze_structure_type_t`. Each one
// must match the header exactly, because the driver (and the validation
// layer) uses `stype` to identify the descriptor behind a pointer.

/// `ze_result_t` value returned by every entry point on success.
#[cfg(any(target_os = "linux", target_os = "windows"))]
const ZE_RESULT_SUCCESS: u32 = 0;

/// `ze_device_type_t::ZE_DEVICE_TYPE_GPU`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
const ZE_DEVICE_TYPE_GPU: u32 = 1;

/// `stype` tag for `ze_context_desc_t`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
const ZE_STRUCTURE_TYPE_CONTEXT_DESC: u32 = 0xd;

/// `stype` tag for `ze_command_queue_desc_t`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
const ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC: u32 = 0xe;

/// `stype` tag for `ze_command_list_desc_t`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
pub(crate) const ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC: u32 = 0xf;

/// `stype` tag for `ze_device_mem_alloc_desc_t`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
pub(crate) const ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC: u32 = 0x15;

/// `stype` tag for `ze_host_mem_alloc_desc_t`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
pub(crate) const ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC: u32 = 0x16;

/// `stype` tag for `ze_device_properties_t`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
const ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES: u32 = 0x3;
59
/// `#[repr(C)]` descriptor whose pointer type appears in the `zeContextCreate`
/// signature (`ZeContextCreateFn`).
///
/// NOTE(review): `ze_context_desc_t` in the Level Zero header appears to end
/// with a `flags` word that this struct does not declare. A pointer to this
/// struct therefore gives the driver a descriptor 8 bytes shorter than it
/// reads. Confirm against `ze_api.h`; if so, add `flags: u32` and set it to 0
/// at every construction site.
#[cfg(any(target_os = "linux", target_os = "windows"))]
#[repr(C)]
pub(crate) struct ZeContextDesc {
    // Structure-type tag; callers set `ZE_STRUCTURE_TYPE_CONTEXT_DESC`.
    pub(crate) stype: u32,
    // Extension chain pointer; null when no extensions are chained.
    pub(crate) p_next: *const c_void,
}
68
/// `#[repr(C)]` descriptor passed to `zeCommandQueueCreate`.
///
/// Field order is part of the C ABI and must not be changed.
#[cfg(any(target_os = "linux", target_os = "windows"))]
#[repr(C)]
pub(crate) struct ZeCommandQueueDesc {
    // Structure-type tag (`ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC`).
    stype: u32,
    // Extension chain pointer; null when unused.
    p_next: *const c_void,
    // Command-queue group ordinal on the device.
    ordinal: u32,
    // Queue index within that group.
    index: u32,
    // Queue flags bit set.
    flags: u32,
    // Queue mode enum; `LevelZeroDevice::new` passes 0.
    mode: u32,
    // Queue priority enum; `LevelZeroDevice::new` passes 0.
    priority: u32,
}
80
/// `#[repr(C)]` descriptor passed to `zeCommandListCreate`.
///
/// Field order is part of the C ABI and must not be changed.
#[cfg(any(target_os = "linux", target_os = "windows"))]
#[repr(C)]
pub(crate) struct ZeCommandListDesc {
    /// Structure-type tag (`ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC`).
    pub stype: u32,
    /// Extension chain pointer; null when unused.
    pub p_next: *const c_void,
    /// Ordinal of the command-queue group the list will be submitted to.
    pub command_queue_group_ordinal: u32,
    /// Command-list flags bit set.
    pub flags: u32,
}
89
/// `#[repr(C)]` descriptor passed to `zeMemAllocDevice`.
///
/// Field order is part of the C ABI and must not be changed.
#[cfg(any(target_os = "linux", target_os = "windows"))]
#[repr(C)]
pub(crate) struct ZeDeviceMemAllocDesc {
    /// Structure-type tag (`ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC`).
    pub stype: u32,
    /// Extension chain pointer; null when unused.
    pub p_next: *const c_void,
    /// Device-allocation flags bit set.
    pub flags: u32,
    /// Device memory ordinal to allocate from.
    pub ordinal: u32,
}
98
/// `#[repr(C)]` descriptor passed to `zeMemAllocHost`.
///
/// Field order is part of the C ABI and must not be changed.
#[cfg(any(target_os = "linux", target_os = "windows"))]
#[repr(C)]
pub(crate) struct ZeHostMemAllocDesc {
    /// Structure-type tag (`ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC`).
    pub stype: u32,
    /// Extension chain pointer; null when unused.
    pub p_next: *const c_void,
    /// Host-allocation flags bit set.
    pub flags: u32,
}
106
/// `#[repr(C)]` mirror of `ze_device_properties_t`, filled in by
/// `zeDeviceGetProperties`.
///
/// The field order and types follow the Level Zero header exactly. The driver
/// writes into this memory by offset, so any mismatch silently corrupts every
/// field that follows it. In particular, `name` sits after `uuid` at byte
/// offset 112, not earlier in the struct. Fields prefixed with `_` exist only
/// to keep the layout correct.
#[cfg(any(target_os = "linux", target_os = "windows"))]
#[repr(C)]
pub(crate) struct ZeDeviceProperties {
    /// Structure-type tag (`ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES`).
    stype: u32,
    /// Extension chain pointer; null when unused.
    p_next: *const c_void,
    /// `ze_device_type_t` (compared against `ZE_DEVICE_TYPE_GPU`).
    device_type: u32,
    vendor_id: u32,
    device_id: u32,
    _flags: u32,
    _subdevice_id: u32,
    _core_clock_rate: u32,
    _max_mem_alloc_size: u64,
    _max_hardware_contexts: u32,
    _max_command_queue_priority: u32,
    _num_threads_per_eu: u32,
    _physical_eu_simd_width: u32,
    _num_eus_per_subslice: u32,
    _num_subslices_per_slice: u32,
    _num_slices: u32,
    _timer_resolution: u64,
    _timestamp_valid_bits: u32,
    _kernel_timestamp_valid_bits: u32,
    _uuid: [u8; 16],
    /// NUL-terminated device name (`ZE_MAX_DEVICE_NAME` = 256 bytes).
    name: [u8; 256],
}

// Compile-time guard: on 64-bit targets `ze_device_properties_t` is 368 bytes.
// A layout regression then fails the build instead of producing garbage names.
#[cfg(all(
    any(target_os = "linux", target_os = "windows"),
    target_pointer_width = "64"
))]
const _: () = assert!(std::mem::size_of::<ZeDeviceProperties>() == 368);
130
// C signatures of the Level Zero entry points resolved by `L0Api::load`.
// Each returns a `ze_result_t` as `u32`; callers in this file compare it
// against `ZE_RESULT_SUCCESS`. Parameter order and types are the C ABI and
// must match the header exactly.

/// `zeInit(flags)`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeInitFn = unsafe extern "C" fn(flags: u32) -> u32;

/// `zeDriverGet(pCount, phDrivers)`; `new()` calls it once with a null array
/// to read the count, then again to fill the array.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeDriverGetFn = unsafe extern "C" fn(count: *mut u32, drivers: *mut ZeDriverHandle) -> u32;

/// `zeDeviceGet(hDriver, pCount, phDevices)`; same two-call pattern as `zeDriverGet`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeDeviceGetFn = unsafe extern "C" fn(
    driver: ZeDriverHandle,
    count: *mut u32,
    devices: *mut ZeDeviceHandle,
) -> u32;

/// `zeDeviceGetProperties(hDevice, pDeviceProperties)`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeDeviceGetPropertiesFn =
    unsafe extern "C" fn(device: ZeDeviceHandle, props: *mut ZeDeviceProperties) -> u32;

/// `zeContextCreate(hDriver, desc, phContext)`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeContextCreateFn = unsafe extern "C" fn(
    driver: ZeDriverHandle,
    desc: *const ZeContextDesc,
    context: *mut ZeContextHandle,
) -> u32;

/// `zeContextDestroy(hContext)`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeContextDestroyFn = unsafe extern "C" fn(context: ZeContextHandle) -> u32;

/// `zeCommandQueueCreate(hContext, hDevice, desc, phCommandQueue)`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeCommandQueueCreateFn = unsafe extern "C" fn(
    context: ZeContextHandle,
    device: ZeDeviceHandle,
    desc: *const ZeCommandQueueDesc,
    queue: *mut ZeCommandQueueHandle,
) -> u32;

/// `zeCommandQueueDestroy(hCommandQueue)`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeCommandQueueDestroyFn = unsafe extern "C" fn(queue: ZeCommandQueueHandle) -> u32;

/// `zeCommandQueueSynchronize(hCommandQueue, timeout)`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeCommandQueueSynchronizeFn =
    unsafe extern "C" fn(queue: ZeCommandQueueHandle, timeout: u64) -> u32;

/// `zeCommandQueueExecuteCommandLists(hCommandQueue, count, phCommandLists, hFence)`.
///
/// NOTE(review): the fence handle is typed as `usize` rather than a pointer.
/// That has the same width on the supported targets; presumably callers pass 0
/// for "no fence" — confirm.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeCommandQueueExecuteCommandListsFn = unsafe extern "C" fn(
    queue: ZeCommandQueueHandle,
    count: u32,
    lists: *const ZeCommandListHandle,
    fence: usize,
) -> u32;

/// `zeCommandListCreate(hContext, hDevice, desc, phCommandList)`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeCommandListCreateFn = unsafe extern "C" fn(
    context: ZeContextHandle,
    device: ZeDeviceHandle,
    desc: *const ZeCommandListDesc,
    list: *mut ZeCommandListHandle,
) -> u32;

/// `zeCommandListDestroy(hCommandList)`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeCommandListDestroyFn = unsafe extern "C" fn(list: ZeCommandListHandle) -> u32;

/// `zeCommandListClose(hCommandList)`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeCommandListCloseFn = unsafe extern "C" fn(list: ZeCommandListHandle) -> u32;

/// `zeCommandListReset(hCommandList)`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeCommandListResetFn = unsafe extern "C" fn(list: ZeCommandListHandle) -> u32;

/// `zeCommandListAppendMemoryCopy(hCommandList, dst, src, size, hSignalEvent,
/// numWaitEvents, phWaitEvents)`.
///
/// NOTE(review): event handles are typed as `usize`; presumably 0 / null means
/// "no event" at call sites — confirm.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeCommandListAppendMemoryCopyFn = unsafe extern "C" fn(
    list: ZeCommandListHandle,
    dst: *mut c_void,
    src: *const c_void,
    size: usize,
    signal_event: usize,
    wait_count: u32,
    wait_events: *const usize,
) -> u32;

/// `zeMemAllocDevice(hContext, desc, size, alignment, hDevice, pptr)`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeMemAllocDeviceFn = unsafe extern "C" fn(
    context: ZeContextHandle,
    desc: *const ZeDeviceMemAllocDesc,
    size: usize,
    alignment: usize,
    device: ZeDeviceHandle,
    ptr: *mut *mut c_void,
) -> u32;

/// `zeMemAllocHost(hContext, desc, size, alignment, pptr)`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeMemAllocHostFn = unsafe extern "C" fn(
    context: ZeContextHandle,
    desc: *const ZeHostMemAllocDesc,
    size: usize,
    alignment: usize,
    ptr: *mut *mut c_void,
) -> u32;

/// `zeMemFree(hContext, ptr)`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
type ZeMemFreeFn = unsafe extern "C" fn(context: ZeContextHandle, ptr: *mut c_void) -> u32;
232
/// Function table resolved from the Level Zero loader by [`L0Api::load`].
///
/// The function pointers are copied out of the loaded library. `_lib` keeps
/// that library loaded for as long as this struct lives, so the table must
/// outlive every use of the pointers. `LevelZeroDevice` shares it through an `Arc`.
#[cfg(any(target_os = "linux", target_os = "windows"))]
pub(crate) struct L0Api {
    // Owning handle to the loader; dropping it would invalidate the pointers below.
    _lib: Library,
    pub ze_init: ZeInitFn,
    pub ze_driver_get: ZeDriverGetFn,
    pub ze_device_get: ZeDeviceGetFn,
    pub ze_device_get_properties: ZeDeviceGetPropertiesFn,
    pub ze_context_create: ZeContextCreateFn,
    pub ze_context_destroy: ZeContextDestroyFn,
    pub ze_command_queue_create: ZeCommandQueueCreateFn,
    pub ze_command_queue_destroy: ZeCommandQueueDestroyFn,
    pub ze_command_queue_synchronize: ZeCommandQueueSynchronizeFn,
    pub ze_command_queue_execute_command_lists: ZeCommandQueueExecuteCommandListsFn,
    pub ze_command_list_create: ZeCommandListCreateFn,
    pub ze_command_list_destroy: ZeCommandListDestroyFn,
    pub ze_command_list_close: ZeCommandListCloseFn,
    // Resolved so the table is complete; the allow suggests nothing reads it yet.
    #[allow(dead_code)]
    pub ze_command_list_reset: ZeCommandListResetFn,
    pub ze_command_list_append_memory_copy: ZeCommandListAppendMemoryCopyFn,
    pub ze_mem_alloc_device: ZeMemAllocDeviceFn,
    pub ze_mem_alloc_host: ZeMemAllocHostFn,
    pub ze_mem_free: ZeMemFreeFn,
}
262
263#[cfg(any(target_os = "linux", target_os = "windows"))]
264impl L0Api {
265 unsafe fn load() -> LevelZeroResult<Self> {
273 #[cfg(target_os = "linux")]
274 let lib_name = "libze_loader.so.1";
275 #[cfg(target_os = "windows")]
276 let lib_name = "ze_loader.dll";
277
278 let lib = unsafe {
281 Library::new(lib_name)
282 .map_err(|e| LevelZeroError::LibraryNotFound(format!("{lib_name}: {e}")))?
283 };
284
285 macro_rules! sym {
286 ($name:literal, $ty:ty) => {{
287 *unsafe {
291 lib.get::<$ty>($name).map_err(|e| {
292 LevelZeroError::LibraryNotFound(format!(
293 "symbol {}: {e}",
294 stringify!($name)
295 ))
296 })?
297 }
298 }};
299 }
300
301 let ze_init = sym!(b"zeInit\0", ZeInitFn);
302 let ze_driver_get = sym!(b"zeDriverGet\0", ZeDriverGetFn);
303 let ze_device_get = sym!(b"zeDeviceGet\0", ZeDeviceGetFn);
304 let ze_device_get_properties = sym!(b"zeDeviceGetProperties\0", ZeDeviceGetPropertiesFn);
305 let ze_context_create = sym!(b"zeContextCreate\0", ZeContextCreateFn);
306 let ze_context_destroy = sym!(b"zeContextDestroy\0", ZeContextDestroyFn);
307 let ze_command_queue_create = sym!(b"zeCommandQueueCreate\0", ZeCommandQueueCreateFn);
308 let ze_command_queue_destroy = sym!(b"zeCommandQueueDestroy\0", ZeCommandQueueDestroyFn);
309 let ze_command_queue_synchronize =
310 sym!(b"zeCommandQueueSynchronize\0", ZeCommandQueueSynchronizeFn);
311 let ze_command_queue_execute_command_lists = sym!(
312 b"zeCommandQueueExecuteCommandLists\0",
313 ZeCommandQueueExecuteCommandListsFn
314 );
315 let ze_command_list_create = sym!(b"zeCommandListCreate\0", ZeCommandListCreateFn);
316 let ze_command_list_destroy = sym!(b"zeCommandListDestroy\0", ZeCommandListDestroyFn);
317 let ze_command_list_close = sym!(b"zeCommandListClose\0", ZeCommandListCloseFn);
318 let ze_command_list_reset = sym!(b"zeCommandListReset\0", ZeCommandListResetFn);
319 let ze_command_list_append_memory_copy = sym!(
320 b"zeCommandListAppendMemoryCopy\0",
321 ZeCommandListAppendMemoryCopyFn
322 );
323 let ze_mem_alloc_device = sym!(b"zeMemAllocDevice\0", ZeMemAllocDeviceFn);
324 let ze_mem_alloc_host = sym!(b"zeMemAllocHost\0", ZeMemAllocHostFn);
325 let ze_mem_free = sym!(b"zeMemFree\0", ZeMemFreeFn);
326
327 Ok(Self {
328 _lib: lib,
329 ze_init,
330 ze_driver_get,
331 ze_device_get,
332 ze_device_get_properties,
333 ze_context_create,
334 ze_context_destroy,
335 ze_command_queue_create,
336 ze_command_queue_destroy,
337 ze_command_queue_synchronize,
338 ze_command_queue_execute_command_lists,
339 ze_command_list_create,
340 ze_command_list_destroy,
341 ze_command_list_close,
342 ze_command_list_reset,
343 ze_command_list_append_memory_copy,
344 ze_mem_alloc_device,
345 ze_mem_alloc_host,
346 ze_mem_free,
347 })
348 }
349}
350
/// A Level Zero GPU device opened by [`LevelZeroDevice::new`], together with a
/// context and a command queue created for it.
///
/// The queue and context are destroyed when the value is dropped.
pub struct LevelZeroDevice {
    /// Loader function table shared with anything that needs to call Level Zero.
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    pub(crate) api: Arc<L0Api>,
    /// Context created on the driver that reported `device`.
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    pub(crate) context: ZeContextHandle,
    /// The selected GPU device.
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    pub(crate) device: ZeDeviceHandle,
    /// Command queue created in `context` on `device`.
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    pub(crate) queue: ZeCommandQueueHandle,
    /// Name returned by [`LevelZeroDevice::name`] and shown by `Debug`.
    device_name: String,
}
373
374impl LevelZeroDevice {
375 pub fn new() -> LevelZeroResult<Self> {
379 #[cfg(any(target_os = "linux", target_os = "windows"))]
380 {
381 let api = Arc::new(unsafe { L0Api::load()? });
384
385 let rc = unsafe { (api.ze_init)(0) };
388 if rc != ZE_RESULT_SUCCESS {
389 return Err(LevelZeroError::ZeError(rc, "zeInit failed".into()));
390 }
391
392 let mut driver_count: u32 = 0;
394 let rc =
396 unsafe { (api.ze_driver_get)(&mut driver_count as *mut u32, std::ptr::null_mut()) };
397 if rc != ZE_RESULT_SUCCESS {
398 return Err(LevelZeroError::ZeError(
399 rc,
400 "zeDriverGet (count) failed".into(),
401 ));
402 }
403 if driver_count == 0 {
404 return Err(LevelZeroError::NoSuitableDevice);
405 }
406
407 let mut drivers: Vec<ZeDriverHandle> =
408 vec![std::ptr::null_mut(); driver_count as usize];
409 let rc =
411 unsafe { (api.ze_driver_get)(&mut driver_count as *mut u32, drivers.as_mut_ptr()) };
412 if rc != ZE_RESULT_SUCCESS {
413 return Err(LevelZeroError::ZeError(
414 rc,
415 "zeDriverGet (enumerate) failed".into(),
416 ));
417 }
418
419 let driver = drivers[0];
420
421 let mut device_count: u32 = 0;
423 let rc = unsafe {
425 (api.ze_device_get)(driver, &mut device_count as *mut u32, std::ptr::null_mut())
426 };
427 if rc != ZE_RESULT_SUCCESS {
428 return Err(LevelZeroError::ZeError(
429 rc,
430 "zeDeviceGet (count) failed".into(),
431 ));
432 }
433 if device_count == 0 {
434 return Err(LevelZeroError::NoSuitableDevice);
435 }
436
437 let mut devices: Vec<ZeDeviceHandle> =
438 vec![std::ptr::null_mut(); device_count as usize];
439 let rc = unsafe {
441 (api.ze_device_get)(driver, &mut device_count as *mut u32, devices.as_mut_ptr())
442 };
443 if rc != ZE_RESULT_SUCCESS {
444 return Err(LevelZeroError::ZeError(
445 rc,
446 "zeDeviceGet (enumerate) failed".into(),
447 ));
448 }
449
450 let mut chosen_device: Option<ZeDeviceHandle> = None;
452 let mut device_name = String::from("Intel GPU");
453
454 for &dev in &devices {
455 let mut props =
458 unsafe { std::mem::MaybeUninit::<ZeDeviceProperties>::zeroed().assume_init() };
459 props.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES;
460 props.p_next = std::ptr::null();
461
462 let rc = unsafe {
465 (api.ze_device_get_properties)(dev, &mut props as *mut ZeDeviceProperties)
466 };
467 if rc != ZE_RESULT_SUCCESS {
468 continue;
469 }
470
471 if props.device_type == ZE_DEVICE_TYPE_GPU {
472 let name_len = props
474 .name
475 .iter()
476 .position(|&b| b == 0)
477 .unwrap_or(props.name.len());
478 device_name = String::from_utf8_lossy(&props.name[..name_len]).into_owned();
479 chosen_device = Some(dev);
480 break;
481 }
482 }
483
484 let device = chosen_device.ok_or(LevelZeroError::NoSuitableDevice)?;
485
486 let ctx_desc = ZeContextDesc {
488 stype: ZE_STRUCTURE_TYPE_CONTEXT_DESC,
489 p_next: std::ptr::null(),
490 };
491 let mut context: ZeContextHandle = std::ptr::null_mut();
492 let rc = unsafe {
494 (api.ze_context_create)(driver, &ctx_desc, &mut context as *mut ZeContextHandle)
495 };
496 if rc != ZE_RESULT_SUCCESS {
497 return Err(LevelZeroError::ZeError(rc, "zeContextCreate failed".into()));
498 }
499
500 let queue_desc = ZeCommandQueueDesc {
502 stype: ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC,
503 p_next: std::ptr::null(),
504 ordinal: 0,
505 index: 0,
506 flags: 0,
507 mode: 0, priority: 0,
509 };
510 let mut queue: ZeCommandQueueHandle = std::ptr::null_mut();
511 let rc = unsafe {
514 (api.ze_command_queue_create)(
515 context,
516 device,
517 &queue_desc,
518 &mut queue as *mut ZeCommandQueueHandle,
519 )
520 };
521 if rc != ZE_RESULT_SUCCESS {
522 unsafe { (api.ze_context_destroy)(context) };
525 return Err(LevelZeroError::ZeError(
526 rc,
527 "zeCommandQueueCreate failed".into(),
528 ));
529 }
530
531 tracing::info!("Level Zero device selected: {device_name}");
532
533 Ok(Self {
534 api,
535 context,
536 device,
537 queue,
538 device_name,
539 })
540 }
541
542 #[cfg(not(any(target_os = "linux", target_os = "windows")))]
543 {
544 Err(LevelZeroError::UnsupportedPlatform)
545 }
546 }
547
548 pub fn name(&self) -> &str {
550 &self.device_name
551 }
552}
553
554impl Drop for LevelZeroDevice {
557 fn drop(&mut self) {
558 #[cfg(any(target_os = "linux", target_os = "windows"))]
559 {
560 unsafe {
563 (self.api.ze_command_queue_destroy)(self.queue);
564 (self.api.ze_context_destroy)(self.context);
565 }
566 }
567 }
568}
569
// SAFETY: `LevelZeroDevice` holds raw Level Zero handles (`*mut c_void`),
// which make it `!Send` by default. This impl asserts those handles may be
// moved to, and used from, another thread.
// NOTE(review): confirm against the Level Zero threading rules for contexts
// and command queues.
unsafe impl Send for LevelZeroDevice {}
// SAFETY: this impl lets several threads use one device, and therefore one
// command queue, through `&LevelZeroDevice`.
// NOTE(review): confirm that Level Zero allows concurrent submission to a
// single queue, or that every caller serialises access to `queue`.
unsafe impl Sync for LevelZeroDevice {}
581
582impl std::fmt::Debug for LevelZeroDevice {
585 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
586 write!(f, "LevelZeroDevice({})", self.device_name)
587 }
588}
589
#[cfg(test)]
mod tests {
    use super::*;

    /// `LevelZeroDevice::new` must either return a named device or fail with an
    /// error. A host without the loader, a GPU, or a working driver is not a
    /// test failure; the reason is printed so skipped runs stay diagnosable.
    #[test]
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    fn level_zero_device_graceful_init() {
        match LevelZeroDevice::new() {
            Ok(dev) => {
                assert!(!dev.name().is_empty());
                let dbg = format!("{dev:?}");
                assert!(dbg.contains("LevelZeroDevice"));
            }
            // Expected on machines without Level Zero hardware or drivers.
            Err(
                e @ (LevelZeroError::LibraryNotFound(_)
                | LevelZeroError::NoSuitableDevice
                | LevelZeroError::ZeError(_, _)),
            ) => {
                eprintln!("Level Zero unavailable, skipping: {e}");
            }
            Err(e) => {
                eprintln!("Level Zero device init error (non-fatal): {e}");
            }
        }
    }

    /// Targets other than Linux and Windows must report `UnsupportedPlatform`.
    #[test]
    #[cfg(not(any(target_os = "linux", target_os = "windows")))]
    fn level_zero_device_unsupported_on_macos() {
        let result = LevelZeroDevice::new();
        assert!(matches!(result, Err(LevelZeroError::UnsupportedPlatform)));
    }
}