1use std::{
11 collections::HashMap,
12 sync::{Arc, Mutex},
13};
14
15#[cfg(any(target_os = "linux", target_os = "windows"))]
16use std::sync::atomic::{AtomicU64, Ordering::Relaxed};
17
18use crate::{
19 device::LevelZeroDevice,
20 error::{LevelZeroError, LevelZeroResult},
21};
22
23#[cfg(any(target_os = "linux", target_os = "windows"))]
26use std::ffi::c_void;
27
28#[cfg(any(target_os = "linux", target_os = "windows"))]
29use crate::device::{
30 ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC, ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
31 ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC, ZeCommandListDesc, ZeCommandListHandle,
32 ZeDeviceMemAllocDesc, ZeHostMemAllocDesc,
33};
34
/// Bookkeeping for one live device allocation owned by `LevelZeroMemoryManager`.
///
/// On platforms without Level Zero support every field is compiled out and the
/// record is an empty marker.
struct L0BufferRecord {
    /// Size in bytes that was requested from `zeMemAllocDevice`.
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    size: u64,
    /// Device address returned by `zeMemAllocDevice`; only ever handed back to the driver.
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    device_ptr: *mut c_void,
}
46
// SAFETY: the raw `device_ptr` field is what suppresses the automatic `Send` impl.
// In this module that pointer is an opaque device address: it is only passed back to
// Level Zero entry points (`zeMemFree`, `zeCommandListAppendMemoryCopy`) and is never
// dereferenced on the host, so moving a record to another thread cannot create a
// host-side data race.
#[cfg(any(target_os = "linux", target_os = "windows"))]
unsafe impl Send for L0BufferRecord {}
51
/// Owns Level Zero device allocations and exposes them through opaque `u64` handles.
///
/// Callers allocate with `alloc`, move data with `copy_to_device` / `copy_from_device`,
/// and release with `free`. Any allocation still registered when the manager is dropped
/// is freed (and reported as leaked) by the `Drop` impl.
///
/// Field order matters: fields drop in declaration order after `Drop::drop` runs.
pub struct LevelZeroMemoryManager {
    /// Device supplying the API table, context, device handle and queue used for every
    /// driver call (Linux/Windows only).
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    device: Arc<LevelZeroDevice>,
    /// Live allocations keyed by handle. Every access goes through this mutex.
    buffers: Mutex<HashMap<u64, L0BufferRecord>>,
    /// Source of new handles; starts at 1 and is incremented once per allocation.
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    next_handle: AtomicU64,
}
67
68impl LevelZeroMemoryManager {
69 #[cfg(any(target_os = "linux", target_os = "windows"))]
71 pub fn new(device: Arc<LevelZeroDevice>) -> Self {
72 Self {
73 device,
74 buffers: Mutex::new(HashMap::new()),
75 next_handle: AtomicU64::new(1),
76 }
77 }
78
79 #[cfg(not(any(target_os = "linux", target_os = "windows")))]
83 pub fn new(_device: Arc<LevelZeroDevice>) -> Self {
84 Self {
85 buffers: Mutex::new(HashMap::new()),
86 }
87 }
88
89 #[cfg(any(target_os = "linux", target_os = "windows"))]
94 pub fn device_ptr(&self, handle: u64) -> LevelZeroResult<*mut c_void> {
95 let buffers = self
96 .buffers
97 .lock()
98 .map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?;
99 let rec = buffers
100 .get(&handle)
101 .ok_or_else(|| LevelZeroError::InvalidArgument(format!("unknown handle {handle}")))?;
102 Ok(rec.device_ptr)
103 }
104
105 pub fn alloc(&self, bytes: usize) -> LevelZeroResult<u64> {
109 #[cfg(any(target_os = "linux", target_os = "windows"))]
110 {
111 let api = &self.device.api;
112 let context = self.device.context;
113 let device_handle = self.device.device;
114
115 let desc = ZeDeviceMemAllocDesc {
116 stype: ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
117 p_next: std::ptr::null(),
118 flags: 0,
119 ordinal: 0,
120 };
121
122 let mut ptr: *mut c_void = std::ptr::null_mut();
123 let rc = unsafe {
126 (api.ze_mem_alloc_device)(
127 context,
128 &desc,
129 bytes,
130 64, device_handle,
132 &mut ptr as *mut *mut c_void,
133 )
134 };
135
136 if rc != 0 {
137 return Err(LevelZeroError::ZeError(
138 rc,
139 "zeMemAllocDevice failed".into(),
140 ));
141 }
142
143 let handle = self.next_handle.fetch_add(1, Relaxed);
144 self.buffers
145 .lock()
146 .map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?
147 .insert(
148 handle,
149 L0BufferRecord {
150 device_ptr: ptr,
151 size: bytes as u64,
152 },
153 );
154
155 Ok(handle)
156 }
157
158 #[cfg(not(any(target_os = "linux", target_os = "windows")))]
159 {
160 let _ = bytes;
161 Err(LevelZeroError::UnsupportedPlatform)
162 }
163 }
164
165 pub fn free(&self, handle: u64) -> LevelZeroResult<()> {
169 #[cfg(any(target_os = "linux", target_os = "windows"))]
170 {
171 let record = self
172 .buffers
173 .lock()
174 .map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?
175 .remove(&handle);
176
177 if let Some(rec) = record {
178 let api = &self.device.api;
179 let context = self.device.context;
180 let rc = unsafe { (api.ze_mem_free)(context, rec.device_ptr) };
183 if rc != 0 {
184 return Err(LevelZeroError::ZeError(rc, "zeMemFree failed".into()));
185 }
186 }
187 Ok(())
188 }
189
190 #[cfg(not(any(target_os = "linux", target_os = "windows")))]
191 {
192 let _ = handle;
193 Err(LevelZeroError::UnsupportedPlatform)
194 }
195 }
196
197 pub fn copy_to_device(&self, handle: u64, src: &[u8]) -> LevelZeroResult<()> {
202 #[cfg(any(target_os = "linux", target_os = "windows"))]
203 {
204 let device_ptr = {
205 let buffers = self
206 .buffers
207 .lock()
208 .map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?;
209 let rec = buffers.get(&handle).ok_or_else(|| {
210 LevelZeroError::InvalidArgument(format!("unknown handle {handle}"))
211 })?;
212 rec.device_ptr
213 };
214
215 let api = &self.device.api;
216 let context = self.device.context;
217 let device_handle = self.device.device;
218 let queue = self.device.queue;
219 let copy_len = src.len();
220
221 let host_desc = ZeHostMemAllocDesc {
223 stype: ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC,
224 p_next: std::ptr::null(),
225 flags: 0,
226 };
227 let mut host_ptr: *mut c_void = std::ptr::null_mut();
228 let rc = unsafe {
231 (api.ze_mem_alloc_host)(
232 context,
233 &host_desc,
234 copy_len,
235 64,
236 &mut host_ptr as *mut *mut c_void,
237 )
238 };
239 if rc != 0 {
240 return Err(LevelZeroError::ZeError(
241 rc,
242 "zeMemAllocHost (staging) failed".into(),
243 ));
244 }
245
246 unsafe {
250 std::ptr::copy_nonoverlapping(src.as_ptr(), host_ptr as *mut u8, copy_len);
251 }
252
253 let list_desc = ZeCommandListDesc {
255 stype: ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
256 p_next: std::ptr::null(),
257 command_queue_group_ordinal: 0,
258 flags: 0,
259 };
260 let mut list: ZeCommandListHandle = std::ptr::null_mut();
261 let rc = unsafe {
264 (api.ze_command_list_create)(
265 context,
266 device_handle,
267 &list_desc,
268 &mut list as *mut ZeCommandListHandle,
269 )
270 };
271 if rc != 0 {
272 unsafe { (api.ze_mem_free)(context, host_ptr) };
274 return Err(LevelZeroError::CommandListError(format!(
275 "zeCommandListCreate failed: 0x{rc:08x}"
276 )));
277 }
278
279 let rc = unsafe {
283 (api.ze_command_list_append_memory_copy)(
284 list,
285 device_ptr,
286 host_ptr as *const c_void,
287 copy_len,
288 0, 0, std::ptr::null(),
291 )
292 };
293 if rc != 0 {
294 unsafe {
296 (api.ze_command_list_destroy)(list);
297 (api.ze_mem_free)(context, host_ptr);
298 }
299 return Err(LevelZeroError::CommandListError(format!(
300 "zeCommandListAppendMemoryCopy failed: 0x{rc:08x}"
301 )));
302 }
303
304 let rc = unsafe { (api.ze_command_list_close)(list) };
307 if rc != 0 {
308 unsafe {
309 (api.ze_command_list_destroy)(list);
310 (api.ze_mem_free)(context, host_ptr);
311 }
312 return Err(LevelZeroError::CommandListError(format!(
313 "zeCommandListClose failed: 0x{rc:08x}"
314 )));
315 }
316
317 let rc = unsafe { (api.ze_command_queue_execute_command_lists)(queue, 1, &list, 0) };
319 if rc != 0 {
320 unsafe {
321 (api.ze_command_list_destroy)(list);
322 (api.ze_mem_free)(context, host_ptr);
323 }
324 return Err(LevelZeroError::CommandListError(format!(
325 "zeCommandQueueExecuteCommandLists failed: 0x{rc:08x}"
326 )));
327 }
328
329 let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
332 if rc != 0 {
333 unsafe {
334 (api.ze_command_list_destroy)(list);
335 (api.ze_mem_free)(context, host_ptr);
336 }
337 return Err(LevelZeroError::CommandListError(format!(
338 "zeCommandQueueSynchronize failed: 0x{rc:08x}"
339 )));
340 }
341
342 unsafe {
345 (api.ze_command_list_destroy)(list);
346 (api.ze_mem_free)(context, host_ptr);
347 }
348
349 Ok(())
350 }
351
352 #[cfg(not(any(target_os = "linux", target_os = "windows")))]
353 {
354 let _ = (handle, src);
355 Err(LevelZeroError::UnsupportedPlatform)
356 }
357 }
358
359 pub fn copy_from_device(&self, dst: &mut [u8], handle: u64) -> LevelZeroResult<()> {
363 #[cfg(any(target_os = "linux", target_os = "windows"))]
364 {
365 let device_ptr = {
366 let buffers = self
367 .buffers
368 .lock()
369 .map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?;
370 let rec = buffers.get(&handle).ok_or_else(|| {
371 LevelZeroError::InvalidArgument(format!("unknown handle {handle}"))
372 })?;
373 rec.device_ptr
374 };
375
376 let api = &self.device.api;
377 let context = self.device.context;
378 let device_handle = self.device.device;
379 let queue = self.device.queue;
380 let copy_len = dst.len();
381
382 let host_desc = ZeHostMemAllocDesc {
384 stype: ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC,
385 p_next: std::ptr::null(),
386 flags: 0,
387 };
388 let mut host_ptr: *mut c_void = std::ptr::null_mut();
389 let rc = unsafe {
392 (api.ze_mem_alloc_host)(
393 context,
394 &host_desc,
395 copy_len,
396 64,
397 &mut host_ptr as *mut *mut c_void,
398 )
399 };
400 if rc != 0 {
401 return Err(LevelZeroError::ZeError(
402 rc,
403 "zeMemAllocHost (staging) failed".into(),
404 ));
405 }
406
407 let list_desc = ZeCommandListDesc {
409 stype: ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
410 p_next: std::ptr::null(),
411 command_queue_group_ordinal: 0,
412 flags: 0,
413 };
414 let mut list: ZeCommandListHandle = std::ptr::null_mut();
415 let rc = unsafe {
418 (api.ze_command_list_create)(
419 context,
420 device_handle,
421 &list_desc,
422 &mut list as *mut ZeCommandListHandle,
423 )
424 };
425 if rc != 0 {
426 unsafe { (api.ze_mem_free)(context, host_ptr) };
427 return Err(LevelZeroError::CommandListError(format!(
428 "zeCommandListCreate failed: 0x{rc:08x}"
429 )));
430 }
431
432 let rc = unsafe {
436 (api.ze_command_list_append_memory_copy)(
437 list,
438 host_ptr,
439 device_ptr as *const c_void,
440 copy_len,
441 0, 0, std::ptr::null(),
444 )
445 };
446 if rc != 0 {
447 unsafe {
448 (api.ze_command_list_destroy)(list);
449 (api.ze_mem_free)(context, host_ptr);
450 }
451 return Err(LevelZeroError::CommandListError(format!(
452 "zeCommandListAppendMemoryCopy failed: 0x{rc:08x}"
453 )));
454 }
455
456 let rc = unsafe { (api.ze_command_list_close)(list) };
459 if rc != 0 {
460 unsafe {
461 (api.ze_command_list_destroy)(list);
462 (api.ze_mem_free)(context, host_ptr);
463 }
464 return Err(LevelZeroError::CommandListError(format!(
465 "zeCommandListClose failed: 0x{rc:08x}"
466 )));
467 }
468
469 let rc = unsafe { (api.ze_command_queue_execute_command_lists)(queue, 1, &list, 0) };
471 if rc != 0 {
472 unsafe {
473 (api.ze_command_list_destroy)(list);
474 (api.ze_mem_free)(context, host_ptr);
475 }
476 return Err(LevelZeroError::CommandListError(format!(
477 "zeCommandQueueExecuteCommandLists failed: 0x{rc:08x}"
478 )));
479 }
480
481 let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
484 if rc != 0 {
485 unsafe {
486 (api.ze_command_list_destroy)(list);
487 (api.ze_mem_free)(context, host_ptr);
488 }
489 return Err(LevelZeroError::CommandListError(format!(
490 "zeCommandQueueSynchronize failed: 0x{rc:08x}"
491 )));
492 }
493
494 unsafe {
498 std::ptr::copy_nonoverlapping(host_ptr as *const u8, dst.as_mut_ptr(), copy_len);
499 }
500
501 unsafe {
504 (api.ze_command_list_destroy)(list);
505 (api.ze_mem_free)(context, host_ptr);
506 }
507
508 Ok(())
509 }
510
511 #[cfg(not(any(target_os = "linux", target_os = "windows")))]
512 {
513 let _ = (dst, handle);
514 Err(LevelZeroError::UnsupportedPlatform)
515 }
516 }
517}
518
519impl Drop for LevelZeroMemoryManager {
522 fn drop(&mut self) {
523 #[cfg(any(target_os = "linux", target_os = "windows"))]
524 {
525 let api = &self.device.api;
526 let context = self.device.context;
527
528 if let Ok(mut map) = self.buffers.lock() {
529 for (handle, rec) in map.drain() {
530 tracing::warn!(
531 "LevelZeroMemoryManager: leaked buffer handle {handle} ({} bytes)",
532 rec.size
533 );
534 unsafe { (api.ze_mem_free)(context, rec.device_ptr) };
536 }
537 }
538 }
539
540 #[cfg(not(any(target_os = "linux", target_os = "windows")))]
541 {
542 }
544 }
545}
546
547impl std::fmt::Debug for LevelZeroMemoryManager {
550 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
551 let count = self.buffers.lock().map(|b| b.len()).unwrap_or(0);
552 write!(f, "LevelZeroMemoryManager(buffers={count})")
553 }
554}
555
// SAFETY(review): these impls let one manager be shared and moved across threads.
// - The buffer table is only reached through `buffers: Mutex<..>`, and `L0BufferRecord`
//   is `Send`, so the table is `Send + Sync`.
// - The handle counter is an `AtomicU64`.
// - The manual impls are presumably needed because `LevelZeroDevice` (held via `Arc`)
//   contains raw driver handles that are not auto-`Send`/`Sync`. Confirm this against
//   its definition.
// - Also confirm that the Level Zero driver allows the concurrent calls on the shared
//   context/queue that `alloc`, `free` and the copy methods can make.
// NOTE(review): the copy methods read `device_ptr` under the lock and then release it
// before submitting the copy. A concurrent `free` of the same handle can therefore free
// the buffer mid-copy.
unsafe impl Send for LevelZeroMemoryManager {}
unsafe impl Sync for LevelZeroMemoryManager {}
563
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a Level Zero device, or yields `None` when no driver/device is available so
    /// the hardware-dependent tests below can skip themselves.
    fn try_get_device() -> Option<Arc<LevelZeroDevice>> {
        match LevelZeroDevice::new() {
            Ok(device) => Some(Arc::new(device)),
            Err(_) => None,
        }
    }

    /// Builds a manager on the available device, if there is one.
    fn manager() -> Option<LevelZeroMemoryManager> {
        try_get_device().map(LevelZeroMemoryManager::new)
    }

    #[test]
    fn alloc_and_free_requires_device() {
        let Some(mm) = manager() else { return };

        let handle = mm.alloc(256).expect("alloc 256 bytes");
        assert!(handle > 0);

        mm.free(handle).expect("free");
        // Freeing an already-freed handle must succeed without touching the driver.
        mm.free(handle).expect("double-free is a no-op");
    }

    #[test]
    fn copy_roundtrip_requires_device() {
        let Some(mm) = manager() else { return };

        let payload: Vec<u8> = (0u8..64).collect();
        let handle = mm.alloc(payload.len()).expect("alloc");
        mm.copy_to_device(handle, &payload).expect("copy_to_device");

        let mut readback = vec![0u8; payload.len()];
        mm.copy_from_device(&mut readback, handle)
            .expect("copy_from_device");

        assert_eq!(payload, readback);
        mm.free(handle).expect("free");
    }

    #[test]
    fn unknown_handle_returns_error() {
        let Some(mm) = manager() else { return };

        let outcome = mm.copy_to_device(9999, b"hello");
        assert!(matches!(outcome, Err(LevelZeroError::InvalidArgument(_))));
    }

    #[test]
    fn debug_impl_smoke() {
        let Some(mm) = manager() else { return };

        let rendered = format!("{mm:?}");
        assert!(rendered.contains("LevelZeroMemoryManager"));
    }
}
624}