1use crate::level_zero::*;
13use std::ffi::{CStr, c_void};
14use std::sync::{Mutex, OnceLock};
15
/// Handles for one Level Zero GPU device.
///
/// Holds the loaded library, the chosen driver and device, a context, and a
/// single command queue created in that context. Constructed by
/// `OneApiDevice::new`; the `Drop` impl destroys the queue and the context.
pub struct OneApiDevice {
    /// Loaded Level Zero entry points; every driver call goes through these.
    pub lib: Lib,
    /// Driver that exposes `device`.
    pub driver: DriverHandle,
    /// The selected GPU device.
    pub device: DeviceHandle,
    /// Context in which the queue, command lists and shared allocations are created.
    pub context: ContextHandle,
    /// Command queue that `execute_sync` submits to.
    pub queue: CommandQueueHandle,
    /// Command-queue-group ordinal used for the queue and for new command lists.
    pub queue_ordinal: u32,
    /// Device name as reported in the driver's device properties.
    pub name: String,
    /// Serialises submissions to `queue`; it protects no data of its own.
    submit_lock: Mutex<()>,
}
29
// SAFETY: the fields that block the auto-impl are the raw Level Zero handles
// (they are compared against null pointers, so they are raw pointers). Level
// Zero handles are not bound to the thread that created them.
// NOTE(review): confirm that `Lib` itself is safe to move across threads.
unsafe impl Send for OneApiDevice {}
// SAFETY: all methods take `&self` and may call into the driver from several
// threads at once. The Level Zero spec treats most entry points as
// free-threaded but forbids concurrent `zeCommandQueueExecuteCommandLists` on
// the same queue; `execute_sync` holds `submit_lock` around submission for
// exactly that reason.
unsafe impl Sync for OneApiDevice {}
34
35static DEVICE: OnceLock<Option<OneApiDevice>> = OnceLock::new();
36
37pub fn oneapi_device() -> Option<&'static OneApiDevice> {
39 DEVICE.get_or_init(|| OneApiDevice::new().ok()).as_ref()
40}
41
42impl OneApiDevice {
43 fn new() -> Result<Self, String> {
44 let lib = unsafe { Lib::load() }?;
46
47 unsafe {
48 check((lib.ze_init)(0), "zeInit")?;
49
50 let mut driver_count: u32 = 0;
52 check(
53 (lib.driver_get)(&mut driver_count, std::ptr::null_mut()),
54 "zeDriverGet(count)",
55 )?;
56 if driver_count == 0 {
57 return Err("no Level Zero drivers".into());
58 }
59 let mut drivers: Vec<DriverHandle> = vec![std::ptr::null_mut(); driver_count as usize];
60 check(
61 (lib.driver_get)(&mut driver_count, drivers.as_mut_ptr()),
62 "zeDriverGet",
63 )?;
64
65 let mut chosen: Option<(DriverHandle, DeviceHandle, String)> = None;
67 for &driver in &drivers {
68 let mut dev_count: u32 = 0;
69 if check(
70 (lib.device_get)(driver, &mut dev_count, std::ptr::null_mut()),
71 "zeDeviceGet(count)",
72 )
73 .is_err()
74 || dev_count == 0
75 {
76 continue;
77 }
78 let mut devices: Vec<DeviceHandle> = vec![std::ptr::null_mut(); dev_count as usize];
79 if check(
80 (lib.device_get)(driver, &mut dev_count, devices.as_mut_ptr()),
81 "zeDeviceGet",
82 )
83 .is_err()
84 {
85 continue;
86 }
87 for &device in &devices {
88 let mut props = DeviceProperties::default();
89 if (lib.device_get_properties)(device, &mut props) != ZE_RESULT_SUCCESS {
90 continue;
91 }
92 if props.type_ != ZE_DEVICE_TYPE_GPU {
93 continue;
94 }
95 let name = CStr::from_ptr(props.name.as_ptr())
96 .to_string_lossy()
97 .into_owned();
98 chosen = Some((driver, device, name));
99 break;
100 }
101 if chosen.is_some() {
102 break;
103 }
104 }
105 let (driver, device, name) =
106 chosen.ok_or_else(|| "no Level Zero GPU device".to_string())?;
107
108 let ctx_desc = ContextDesc {
110 stype: ZE_STRUCTURE_TYPE_CONTEXT_DESC,
111 pnext: std::ptr::null(),
112 flags: 0,
113 };
114 let mut context: ContextHandle = std::ptr::null_mut();
115 check(
116 (lib.context_create)(driver, &ctx_desc, &mut context),
117 "zeContextCreate",
118 )?;
119
120 let queue_ordinal = 0u32;
125 let q_desc = CommandQueueDesc {
126 stype: ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC,
127 pnext: std::ptr::null(),
128 ordinal: queue_ordinal,
129 index: 0,
130 flags: 0,
131 mode: ZE_COMMAND_QUEUE_MODE_DEFAULT,
132 priority: 0,
133 };
134 let mut queue: CommandQueueHandle = std::ptr::null_mut();
135 check(
136 (lib.command_queue_create)(context, device, &q_desc, &mut queue),
137 "zeCommandQueueCreate",
138 )?;
139
140 Ok(Self {
141 lib,
142 driver,
143 device,
144 context,
145 queue,
146 queue_ordinal,
147 name,
148 submit_lock: Mutex::new(()),
149 })
150 }
151 }
152
153 pub fn alloc_shared(&self, size: usize) -> Result<*mut c_void, String> {
157 let dev_desc = DeviceMemAllocDesc {
158 stype: ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
159 pnext: std::ptr::null(),
160 flags: 0,
161 ordinal: 0,
162 };
163 let host_desc = HostMemAllocDesc {
164 stype: ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC,
165 pnext: std::ptr::null(),
166 flags: 0,
167 };
168 let mut ptr: *mut c_void = std::ptr::null_mut();
169 unsafe {
170 check(
171 (self.lib.mem_alloc_shared)(
172 self.context,
173 &dev_desc,
174 &host_desc,
175 size.max(1),
176 64,
177 self.device,
178 &mut ptr,
179 ),
180 "zeMemAllocShared",
181 )?;
182 std::ptr::write_bytes(ptr as *mut u8, 0, size);
183 }
184 Ok(ptr)
185 }
186
187 #[allow(clippy::not_unsafe_ptr_arg_deref)] pub fn free(&self, ptr: *mut c_void) {
190 if !ptr.is_null() {
191 unsafe {
192 let _ = (self.lib.mem_free)(self.context, ptr);
193 }
194 }
195 }
196
197 pub fn create_command_list(&self) -> Result<CommandListHandle, String> {
199 let desc = CommandListDesc {
200 stype: ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
201 pnext: std::ptr::null(),
202 command_queue_group_ordinal: self.queue_ordinal,
203 flags: 0,
204 };
205 let mut list: CommandListHandle = std::ptr::null_mut();
206 unsafe {
207 check(
208 (self.lib.command_list_create)(self.context, self.device, &desc, &mut list),
209 "zeCommandListCreate",
210 )?;
211 }
212 Ok(list)
213 }
214
215 #[allow(clippy::not_unsafe_ptr_arg_deref)] pub fn execute_sync(&self, list: CommandListHandle) -> Result<(), String> {
219 let _guard = self.submit_lock.lock().unwrap();
220 unsafe {
221 check((self.lib.command_list_close)(list), "zeCommandListClose")?;
222 let lists = [list];
223 check(
224 (self.lib.command_queue_execute)(
225 self.queue,
226 1,
227 lists.as_ptr(),
228 std::ptr::null_mut(),
229 ),
230 "zeCommandQueueExecuteCommandLists",
231 )?;
232 check(
233 (self.lib.command_queue_synchronize)(self.queue, u64::MAX),
234 "zeCommandQueueSynchronize",
235 )?;
236 }
237 Ok(())
238 }
239}
240
241impl Drop for OneApiDevice {
242 fn drop(&mut self) {
243 unsafe {
244 if !self.queue.is_null() {
245 let _ = (self.lib.command_queue_destroy)(self.queue);
246 }
247 if !self.context.is_null() {
248 let _ = (self.lib.context_destroy)(self.context);
249 }
250 }
251 }
252}