1use std::{
11 collections::HashMap,
12 sync::{Arc, Mutex},
13};
14
15#[cfg(any(target_os = "linux", target_os = "windows"))]
16use std::sync::atomic::{AtomicU64, Ordering::Relaxed};
17
18use crate::{
19 device::LevelZeroDevice,
20 error::{LevelZeroError, LevelZeroResult},
21};
22
23#[cfg(any(target_os = "linux", target_os = "windows"))]
26use std::ffi::c_void;
27
28#[cfg(any(target_os = "linux", target_os = "windows"))]
29use crate::device::{
30 ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC, ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
31 ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC, ZeCommandListDesc, ZeCommandListHandle,
32 ZeDeviceMemAllocDesc, ZeHostMemAllocDesc,
33};
34
/// Bookkeeping for one device allocation handed out by `LevelZeroMemoryManager`.
///
/// On targets other than Linux/Windows the record has no fields.
struct L0BufferRecord {
    /// Pointer returned by `zeMemAllocDevice`; later handed to `zeMemFree` and used as a
    /// copy source/destination. It is never dereferenced on the host.
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    device_ptr: *mut c_void,
    /// Size in bytes that was requested from `zeMemAllocDevice` (the `bytes` passed to `alloc`).
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    size: u64,
}
46
// SAFETY: the only non-`Send` field is `device_ptr`, a device allocation address that this
// module never dereferences on the host; it is only passed back to the Level Zero driver
// (`zeMemFree`, `zeCommandListAppendMemoryCopy`). Moving a record between threads therefore
// moves no host-side data that a raw pointer would normally protect.
// NOTE(review): assumes Level Zero allocations are not tied to the thread that created
// them — confirm against the driver's threading rules.
#[cfg(any(target_os = "linux", target_os = "windows"))]
unsafe impl Send for L0BufferRecord {}
51
/// Tracks Level Zero device allocations behind opaque `u64` handles.
///
/// On targets other than Linux/Windows only the buffer table exists. There, `alloc`,
/// `free`, `copy_to_device` and `copy_from_device` return
/// `LevelZeroError::UnsupportedPlatform`.
pub struct LevelZeroMemoryManager {
    /// Device whose context, device handle, queue and API table back every operation.
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    device: Arc<LevelZeroDevice>,
    /// Live allocations keyed by the handle returned from `alloc`.
    buffers: Mutex<HashMap<u64, L0BufferRecord>>,
    /// Source of new handles; `new` initialises it to 1, so 0 is never issued.
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    next_handle: AtomicU64,
}
67
68impl LevelZeroMemoryManager {
69 #[cfg(any(target_os = "linux", target_os = "windows"))]
71 pub fn new(device: Arc<LevelZeroDevice>) -> Self {
72 Self {
73 device,
74 buffers: Mutex::new(HashMap::new()),
75 next_handle: AtomicU64::new(1),
76 }
77 }
78
79 #[cfg(not(any(target_os = "linux", target_os = "windows")))]
83 pub fn new(_device: Arc<LevelZeroDevice>) -> Self {
84 Self {
85 buffers: Mutex::new(HashMap::new()),
86 }
87 }
88
89 pub fn alloc(&self, bytes: usize) -> LevelZeroResult<u64> {
93 #[cfg(any(target_os = "linux", target_os = "windows"))]
94 {
95 let api = &self.device.api;
96 let context = self.device.context;
97 let device_handle = self.device.device;
98
99 let desc = ZeDeviceMemAllocDesc {
100 stype: ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
101 p_next: std::ptr::null(),
102 flags: 0,
103 ordinal: 0,
104 };
105
106 let mut ptr: *mut c_void = std::ptr::null_mut();
107 let rc = unsafe {
110 (api.ze_mem_alloc_device)(
111 context,
112 &desc,
113 bytes,
114 64, device_handle,
116 &mut ptr as *mut *mut c_void,
117 )
118 };
119
120 if rc != 0 {
121 return Err(LevelZeroError::ZeError(
122 rc,
123 "zeMemAllocDevice failed".into(),
124 ));
125 }
126
127 let handle = self.next_handle.fetch_add(1, Relaxed);
128 self.buffers
129 .lock()
130 .map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?
131 .insert(
132 handle,
133 L0BufferRecord {
134 device_ptr: ptr,
135 size: bytes as u64,
136 },
137 );
138
139 Ok(handle)
140 }
141
142 #[cfg(not(any(target_os = "linux", target_os = "windows")))]
143 {
144 let _ = bytes;
145 Err(LevelZeroError::UnsupportedPlatform)
146 }
147 }
148
149 pub fn free(&self, handle: u64) -> LevelZeroResult<()> {
153 #[cfg(any(target_os = "linux", target_os = "windows"))]
154 {
155 let record = self
156 .buffers
157 .lock()
158 .map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?
159 .remove(&handle);
160
161 if let Some(rec) = record {
162 let api = &self.device.api;
163 let context = self.device.context;
164 let rc = unsafe { (api.ze_mem_free)(context, rec.device_ptr) };
167 if rc != 0 {
168 return Err(LevelZeroError::ZeError(rc, "zeMemFree failed".into()));
169 }
170 }
171 Ok(())
172 }
173
174 #[cfg(not(any(target_os = "linux", target_os = "windows")))]
175 {
176 let _ = handle;
177 Err(LevelZeroError::UnsupportedPlatform)
178 }
179 }
180
181 pub fn copy_to_device(&self, handle: u64, src: &[u8]) -> LevelZeroResult<()> {
186 #[cfg(any(target_os = "linux", target_os = "windows"))]
187 {
188 let device_ptr = {
189 let buffers = self
190 .buffers
191 .lock()
192 .map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?;
193 let rec = buffers.get(&handle).ok_or_else(|| {
194 LevelZeroError::InvalidArgument(format!("unknown handle {handle}"))
195 })?;
196 rec.device_ptr
197 };
198
199 let api = &self.device.api;
200 let context = self.device.context;
201 let device_handle = self.device.device;
202 let queue = self.device.queue;
203 let copy_len = src.len();
204
205 let host_desc = ZeHostMemAllocDesc {
207 stype: ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC,
208 p_next: std::ptr::null(),
209 flags: 0,
210 };
211 let mut host_ptr: *mut c_void = std::ptr::null_mut();
212 let rc = unsafe {
215 (api.ze_mem_alloc_host)(
216 context,
217 &host_desc,
218 copy_len,
219 64,
220 &mut host_ptr as *mut *mut c_void,
221 )
222 };
223 if rc != 0 {
224 return Err(LevelZeroError::ZeError(
225 rc,
226 "zeMemAllocHost (staging) failed".into(),
227 ));
228 }
229
230 unsafe {
234 std::ptr::copy_nonoverlapping(src.as_ptr(), host_ptr as *mut u8, copy_len);
235 }
236
237 let list_desc = ZeCommandListDesc {
239 stype: ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
240 p_next: std::ptr::null(),
241 command_queue_group_ordinal: 0,
242 flags: 0,
243 };
244 let mut list: ZeCommandListHandle = std::ptr::null_mut();
245 let rc = unsafe {
248 (api.ze_command_list_create)(
249 context,
250 device_handle,
251 &list_desc,
252 &mut list as *mut ZeCommandListHandle,
253 )
254 };
255 if rc != 0 {
256 unsafe { (api.ze_mem_free)(context, host_ptr) };
258 return Err(LevelZeroError::CommandListError(format!(
259 "zeCommandListCreate failed: 0x{rc:08x}"
260 )));
261 }
262
263 let rc = unsafe {
267 (api.ze_command_list_append_memory_copy)(
268 list,
269 device_ptr,
270 host_ptr as *const c_void,
271 copy_len,
272 0, 0, std::ptr::null(),
275 )
276 };
277 if rc != 0 {
278 unsafe {
280 (api.ze_command_list_destroy)(list);
281 (api.ze_mem_free)(context, host_ptr);
282 }
283 return Err(LevelZeroError::CommandListError(format!(
284 "zeCommandListAppendMemoryCopy failed: 0x{rc:08x}"
285 )));
286 }
287
288 let rc = unsafe { (api.ze_command_list_close)(list) };
291 if rc != 0 {
292 unsafe {
293 (api.ze_command_list_destroy)(list);
294 (api.ze_mem_free)(context, host_ptr);
295 }
296 return Err(LevelZeroError::CommandListError(format!(
297 "zeCommandListClose failed: 0x{rc:08x}"
298 )));
299 }
300
301 let rc = unsafe { (api.ze_command_queue_execute_command_lists)(queue, 1, &list, 0) };
303 if rc != 0 {
304 unsafe {
305 (api.ze_command_list_destroy)(list);
306 (api.ze_mem_free)(context, host_ptr);
307 }
308 return Err(LevelZeroError::CommandListError(format!(
309 "zeCommandQueueExecuteCommandLists failed: 0x{rc:08x}"
310 )));
311 }
312
313 let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
316 if rc != 0 {
317 unsafe {
318 (api.ze_command_list_destroy)(list);
319 (api.ze_mem_free)(context, host_ptr);
320 }
321 return Err(LevelZeroError::CommandListError(format!(
322 "zeCommandQueueSynchronize failed: 0x{rc:08x}"
323 )));
324 }
325
326 unsafe {
329 (api.ze_command_list_destroy)(list);
330 (api.ze_mem_free)(context, host_ptr);
331 }
332
333 Ok(())
334 }
335
336 #[cfg(not(any(target_os = "linux", target_os = "windows")))]
337 {
338 let _ = (handle, src);
339 Err(LevelZeroError::UnsupportedPlatform)
340 }
341 }
342
343 pub fn copy_from_device(&self, dst: &mut [u8], handle: u64) -> LevelZeroResult<()> {
347 #[cfg(any(target_os = "linux", target_os = "windows"))]
348 {
349 let device_ptr = {
350 let buffers = self
351 .buffers
352 .lock()
353 .map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?;
354 let rec = buffers.get(&handle).ok_or_else(|| {
355 LevelZeroError::InvalidArgument(format!("unknown handle {handle}"))
356 })?;
357 rec.device_ptr
358 };
359
360 let api = &self.device.api;
361 let context = self.device.context;
362 let device_handle = self.device.device;
363 let queue = self.device.queue;
364 let copy_len = dst.len();
365
366 let host_desc = ZeHostMemAllocDesc {
368 stype: ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC,
369 p_next: std::ptr::null(),
370 flags: 0,
371 };
372 let mut host_ptr: *mut c_void = std::ptr::null_mut();
373 let rc = unsafe {
376 (api.ze_mem_alloc_host)(
377 context,
378 &host_desc,
379 copy_len,
380 64,
381 &mut host_ptr as *mut *mut c_void,
382 )
383 };
384 if rc != 0 {
385 return Err(LevelZeroError::ZeError(
386 rc,
387 "zeMemAllocHost (staging) failed".into(),
388 ));
389 }
390
391 let list_desc = ZeCommandListDesc {
393 stype: ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
394 p_next: std::ptr::null(),
395 command_queue_group_ordinal: 0,
396 flags: 0,
397 };
398 let mut list: ZeCommandListHandle = std::ptr::null_mut();
399 let rc = unsafe {
402 (api.ze_command_list_create)(
403 context,
404 device_handle,
405 &list_desc,
406 &mut list as *mut ZeCommandListHandle,
407 )
408 };
409 if rc != 0 {
410 unsafe { (api.ze_mem_free)(context, host_ptr) };
411 return Err(LevelZeroError::CommandListError(format!(
412 "zeCommandListCreate failed: 0x{rc:08x}"
413 )));
414 }
415
416 let rc = unsafe {
420 (api.ze_command_list_append_memory_copy)(
421 list,
422 host_ptr,
423 device_ptr as *const c_void,
424 copy_len,
425 0, 0, std::ptr::null(),
428 )
429 };
430 if rc != 0 {
431 unsafe {
432 (api.ze_command_list_destroy)(list);
433 (api.ze_mem_free)(context, host_ptr);
434 }
435 return Err(LevelZeroError::CommandListError(format!(
436 "zeCommandListAppendMemoryCopy failed: 0x{rc:08x}"
437 )));
438 }
439
440 let rc = unsafe { (api.ze_command_list_close)(list) };
443 if rc != 0 {
444 unsafe {
445 (api.ze_command_list_destroy)(list);
446 (api.ze_mem_free)(context, host_ptr);
447 }
448 return Err(LevelZeroError::CommandListError(format!(
449 "zeCommandListClose failed: 0x{rc:08x}"
450 )));
451 }
452
453 let rc = unsafe { (api.ze_command_queue_execute_command_lists)(queue, 1, &list, 0) };
455 if rc != 0 {
456 unsafe {
457 (api.ze_command_list_destroy)(list);
458 (api.ze_mem_free)(context, host_ptr);
459 }
460 return Err(LevelZeroError::CommandListError(format!(
461 "zeCommandQueueExecuteCommandLists failed: 0x{rc:08x}"
462 )));
463 }
464
465 let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
468 if rc != 0 {
469 unsafe {
470 (api.ze_command_list_destroy)(list);
471 (api.ze_mem_free)(context, host_ptr);
472 }
473 return Err(LevelZeroError::CommandListError(format!(
474 "zeCommandQueueSynchronize failed: 0x{rc:08x}"
475 )));
476 }
477
478 unsafe {
482 std::ptr::copy_nonoverlapping(host_ptr as *const u8, dst.as_mut_ptr(), copy_len);
483 }
484
485 unsafe {
488 (api.ze_command_list_destroy)(list);
489 (api.ze_mem_free)(context, host_ptr);
490 }
491
492 Ok(())
493 }
494
495 #[cfg(not(any(target_os = "linux", target_os = "windows")))]
496 {
497 let _ = (dst, handle);
498 Err(LevelZeroError::UnsupportedPlatform)
499 }
500 }
501}
502
503impl Drop for LevelZeroMemoryManager {
506 fn drop(&mut self) {
507 #[cfg(any(target_os = "linux", target_os = "windows"))]
508 {
509 let api = &self.device.api;
510 let context = self.device.context;
511
512 if let Ok(mut map) = self.buffers.lock() {
513 for (handle, rec) in map.drain() {
514 tracing::warn!(
515 "LevelZeroMemoryManager: leaked buffer handle {handle} ({} bytes)",
516 rec.size
517 );
518 unsafe { (api.ze_mem_free)(context, rec.device_ptr) };
520 }
521 }
522 }
523
524 #[cfg(not(any(target_os = "linux", target_os = "windows")))]
525 {
526 }
528 }
529}
530
531impl std::fmt::Debug for LevelZeroMemoryManager {
534 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
535 let count = self.buffers.lock().map(|b| b.len()).unwrap_or(0);
536 write!(f, "LevelZeroMemoryManager(buffers={count})")
537 }
538}
539
// SAFETY: all mutable state is synchronized. The buffer table sits behind a `Mutex` and
// the handle counter is an `AtomicU64`. The remaining field, `device`, is only read: its
// context, device handle, queue and API table are copied or borrowed per call.
// NOTE(review): these manual impls are presumably needed because `LevelZeroDevice` holds
// raw driver handles that are not auto-`Send`/`Sync` — confirm. Also, `copy_to_device`
// and `copy_from_device` submit to and synchronize the shared `device.queue` without
// holding any lock. Confirm the driver permits that from concurrent threads.
unsafe impl Send for LevelZeroMemoryManager {}
unsafe impl Sync for LevelZeroMemoryManager {}
547
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager on the first Level Zero device that opens successfully.
    ///
    /// Returns `None` when no device is available; each hardware test then returns early
    /// and passes trivially.
    fn manager() -> Option<LevelZeroMemoryManager> {
        let device = LevelZeroDevice::new().ok()?;
        Some(LevelZeroMemoryManager::new(Arc::new(device)))
    }

    #[test]
    fn alloc_and_free_requires_device() {
        let Some(mm) = manager() else { return };

        let handle = mm.alloc(256).expect("alloc 256 bytes");
        assert!(handle > 0);
        mm.free(handle).expect("free");
        // Releasing an already-freed handle must be accepted silently.
        mm.free(handle).expect("double-free is a no-op");
    }

    #[test]
    fn copy_roundtrip_requires_device() {
        let Some(mm) = manager() else { return };

        let payload: Vec<u8> = (0..64u8).collect();
        let handle = mm.alloc(payload.len()).expect("alloc");
        mm.copy_to_device(handle, &payload).expect("copy_to_device");

        let mut readback = vec![0u8; payload.len()];
        mm.copy_from_device(&mut readback, handle)
            .expect("copy_from_device");

        assert_eq!(payload, readback);
        mm.free(handle).expect("free");
    }

    #[test]
    fn unknown_handle_returns_error() {
        let Some(mm) = manager() else { return };

        let outcome = mm.copy_to_device(9999, b"hello");
        assert!(matches!(outcome, Err(LevelZeroError::InvalidArgument(_))));
    }

    #[test]
    fn debug_impl_smoke() {
        let Some(mm) = manager() else { return };

        let rendered = format!("{mm:?}");
        assert!(rendered.contains("LevelZeroMemoryManager"));
    }
}