use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
#[cfg(any(target_os = "linux", target_os = "windows"))]
use std::sync::atomic::{AtomicU64, Ordering::Relaxed};
use crate::{
device::LevelZeroDevice,
error::{LevelZeroError, LevelZeroResult},
};
#[cfg(any(target_os = "linux", target_os = "windows"))]
use std::ffi::c_void;
#[cfg(any(target_os = "linux", target_os = "windows"))]
use crate::device::{
ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC, ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC, ZeCommandListDesc, ZeCommandListHandle,
ZeDeviceMemAllocDesc, ZeHostMemAllocDesc,
};
/// Book-keeping for one live device allocation held in the manager's handle table.
///
/// On targets other than Linux and Windows the record carries no fields, because nothing can
/// be allocated there.
struct L0BufferRecord {
    /// Requested allocation length in bytes (the `bytes` argument given to `alloc`).
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    size: u64,
    /// Pointer returned by `zeMemAllocDevice`; released with `zeMemFree` by `free` or `Drop`.
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    device_ptr: *mut c_void,
}

// SAFETY: a record is just a raw device pointer plus its length, and the manager only reaches
// records through its `buffers` mutex, so moving one to another thread introduces no
// unsynchronized access of its own. NOTE(review): this assumes Level Zero device pointers may
// be used from any thread of the owning process — confirm against the driver documentation.
#[cfg(any(target_os = "linux", target_os = "windows"))]
unsafe impl Send for L0BufferRecord {}
/// Owns Level Zero device allocations and exposes them through opaque `u64` handles.
///
/// On Linux and Windows the manager keeps a shared reference to the device it allocates from
/// and issues handles from an atomic counter. On other targets only the handle table exists,
/// and `alloc`, `free` and the copy methods report an unsupported platform.
pub struct LevelZeroMemoryManager {
    /// Device whose context, queue and API table serve every allocation and transfer.
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    device: Arc<LevelZeroDevice>,
    /// Live allocations, keyed by the handle that `alloc` returned for them.
    buffers: Mutex<HashMap<u64, L0BufferRecord>>,
    /// Next handle to issue; `new` seeds it with 1, so handle 0 is never handed out.
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    next_handle: AtomicU64,
}
impl LevelZeroMemoryManager {
#[cfg(any(target_os = "linux", target_os = "windows"))]
pub fn new(device: Arc<LevelZeroDevice>) -> Self {
Self {
device,
buffers: Mutex::new(HashMap::new()),
next_handle: AtomicU64::new(1),
}
}
#[cfg(not(any(target_os = "linux", target_os = "windows")))]
pub fn new(_device: Arc<LevelZeroDevice>) -> Self {
Self {
buffers: Mutex::new(HashMap::new()),
}
}
#[cfg(any(target_os = "linux", target_os = "windows"))]
pub fn device_ptr(&self, handle: u64) -> LevelZeroResult<*mut c_void> {
let buffers = self
.buffers
.lock()
.map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?;
let rec = buffers
.get(&handle)
.ok_or_else(|| LevelZeroError::InvalidArgument(format!("unknown handle {handle}")))?;
Ok(rec.device_ptr)
}
pub fn alloc(&self, bytes: usize) -> LevelZeroResult<u64> {
#[cfg(any(target_os = "linux", target_os = "windows"))]
{
let api = &self.device.api;
let context = self.device.context;
let device_handle = self.device.device;
let desc = ZeDeviceMemAllocDesc {
stype: ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
p_next: std::ptr::null(),
flags: 0,
ordinal: 0,
};
let mut ptr: *mut c_void = std::ptr::null_mut();
let rc = unsafe {
(api.ze_mem_alloc_device)(
context,
&desc,
bytes,
64, device_handle,
&mut ptr as *mut *mut c_void,
)
};
if rc != 0 {
return Err(LevelZeroError::ZeError(
rc,
"zeMemAllocDevice failed".into(),
));
}
let handle = self.next_handle.fetch_add(1, Relaxed);
self.buffers
.lock()
.map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?
.insert(
handle,
L0BufferRecord {
device_ptr: ptr,
size: bytes as u64,
},
);
Ok(handle)
}
#[cfg(not(any(target_os = "linux", target_os = "windows")))]
{
let _ = bytes;
Err(LevelZeroError::UnsupportedPlatform)
}
}
pub fn free(&self, handle: u64) -> LevelZeroResult<()> {
#[cfg(any(target_os = "linux", target_os = "windows"))]
{
let record = self
.buffers
.lock()
.map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?
.remove(&handle);
if let Some(rec) = record {
let api = &self.device.api;
let context = self.device.context;
let rc = unsafe { (api.ze_mem_free)(context, rec.device_ptr) };
if rc != 0 {
return Err(LevelZeroError::ZeError(rc, "zeMemFree failed".into()));
}
}
Ok(())
}
#[cfg(not(any(target_os = "linux", target_os = "windows")))]
{
let _ = handle;
Err(LevelZeroError::UnsupportedPlatform)
}
}
pub fn copy_to_device(&self, handle: u64, src: &[u8]) -> LevelZeroResult<()> {
#[cfg(any(target_os = "linux", target_os = "windows"))]
{
let device_ptr = {
let buffers = self
.buffers
.lock()
.map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?;
let rec = buffers.get(&handle).ok_or_else(|| {
LevelZeroError::InvalidArgument(format!("unknown handle {handle}"))
})?;
rec.device_ptr
};
let api = &self.device.api;
let context = self.device.context;
let device_handle = self.device.device;
let queue = self.device.queue;
let copy_len = src.len();
let host_desc = ZeHostMemAllocDesc {
stype: ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC,
p_next: std::ptr::null(),
flags: 0,
};
let mut host_ptr: *mut c_void = std::ptr::null_mut();
let rc = unsafe {
(api.ze_mem_alloc_host)(
context,
&host_desc,
copy_len,
64,
&mut host_ptr as *mut *mut c_void,
)
};
if rc != 0 {
return Err(LevelZeroError::ZeError(
rc,
"zeMemAllocHost (staging) failed".into(),
));
}
unsafe {
std::ptr::copy_nonoverlapping(src.as_ptr(), host_ptr as *mut u8, copy_len);
}
let list_desc = ZeCommandListDesc {
stype: ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
p_next: std::ptr::null(),
command_queue_group_ordinal: 0,
flags: 0,
};
let mut list: ZeCommandListHandle = std::ptr::null_mut();
let rc = unsafe {
(api.ze_command_list_create)(
context,
device_handle,
&list_desc,
&mut list as *mut ZeCommandListHandle,
)
};
if rc != 0 {
unsafe { (api.ze_mem_free)(context, host_ptr) };
return Err(LevelZeroError::CommandListError(format!(
"zeCommandListCreate failed: 0x{rc:08x}"
)));
}
let rc = unsafe {
(api.ze_command_list_append_memory_copy)(
list,
device_ptr,
host_ptr as *const c_void,
copy_len,
0, 0, std::ptr::null(),
)
};
if rc != 0 {
unsafe {
(api.ze_command_list_destroy)(list);
(api.ze_mem_free)(context, host_ptr);
}
return Err(LevelZeroError::CommandListError(format!(
"zeCommandListAppendMemoryCopy failed: 0x{rc:08x}"
)));
}
let rc = unsafe { (api.ze_command_list_close)(list) };
if rc != 0 {
unsafe {
(api.ze_command_list_destroy)(list);
(api.ze_mem_free)(context, host_ptr);
}
return Err(LevelZeroError::CommandListError(format!(
"zeCommandListClose failed: 0x{rc:08x}"
)));
}
let rc = unsafe { (api.ze_command_queue_execute_command_lists)(queue, 1, &list, 0) };
if rc != 0 {
unsafe {
(api.ze_command_list_destroy)(list);
(api.ze_mem_free)(context, host_ptr);
}
return Err(LevelZeroError::CommandListError(format!(
"zeCommandQueueExecuteCommandLists failed: 0x{rc:08x}"
)));
}
let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
if rc != 0 {
unsafe {
(api.ze_command_list_destroy)(list);
(api.ze_mem_free)(context, host_ptr);
}
return Err(LevelZeroError::CommandListError(format!(
"zeCommandQueueSynchronize failed: 0x{rc:08x}"
)));
}
unsafe {
(api.ze_command_list_destroy)(list);
(api.ze_mem_free)(context, host_ptr);
}
Ok(())
}
#[cfg(not(any(target_os = "linux", target_os = "windows")))]
{
let _ = (handle, src);
Err(LevelZeroError::UnsupportedPlatform)
}
}
pub fn copy_from_device(&self, dst: &mut [u8], handle: u64) -> LevelZeroResult<()> {
#[cfg(any(target_os = "linux", target_os = "windows"))]
{
let device_ptr = {
let buffers = self
.buffers
.lock()
.map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?;
let rec = buffers.get(&handle).ok_or_else(|| {
LevelZeroError::InvalidArgument(format!("unknown handle {handle}"))
})?;
rec.device_ptr
};
let api = &self.device.api;
let context = self.device.context;
let device_handle = self.device.device;
let queue = self.device.queue;
let copy_len = dst.len();
let host_desc = ZeHostMemAllocDesc {
stype: ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC,
p_next: std::ptr::null(),
flags: 0,
};
let mut host_ptr: *mut c_void = std::ptr::null_mut();
let rc = unsafe {
(api.ze_mem_alloc_host)(
context,
&host_desc,
copy_len,
64,
&mut host_ptr as *mut *mut c_void,
)
};
if rc != 0 {
return Err(LevelZeroError::ZeError(
rc,
"zeMemAllocHost (staging) failed".into(),
));
}
let list_desc = ZeCommandListDesc {
stype: ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
p_next: std::ptr::null(),
command_queue_group_ordinal: 0,
flags: 0,
};
let mut list: ZeCommandListHandle = std::ptr::null_mut();
let rc = unsafe {
(api.ze_command_list_create)(
context,
device_handle,
&list_desc,
&mut list as *mut ZeCommandListHandle,
)
};
if rc != 0 {
unsafe { (api.ze_mem_free)(context, host_ptr) };
return Err(LevelZeroError::CommandListError(format!(
"zeCommandListCreate failed: 0x{rc:08x}"
)));
}
let rc = unsafe {
(api.ze_command_list_append_memory_copy)(
list,
host_ptr,
device_ptr as *const c_void,
copy_len,
0, 0, std::ptr::null(),
)
};
if rc != 0 {
unsafe {
(api.ze_command_list_destroy)(list);
(api.ze_mem_free)(context, host_ptr);
}
return Err(LevelZeroError::CommandListError(format!(
"zeCommandListAppendMemoryCopy failed: 0x{rc:08x}"
)));
}
let rc = unsafe { (api.ze_command_list_close)(list) };
if rc != 0 {
unsafe {
(api.ze_command_list_destroy)(list);
(api.ze_mem_free)(context, host_ptr);
}
return Err(LevelZeroError::CommandListError(format!(
"zeCommandListClose failed: 0x{rc:08x}"
)));
}
let rc = unsafe { (api.ze_command_queue_execute_command_lists)(queue, 1, &list, 0) };
if rc != 0 {
unsafe {
(api.ze_command_list_destroy)(list);
(api.ze_mem_free)(context, host_ptr);
}
return Err(LevelZeroError::CommandListError(format!(
"zeCommandQueueExecuteCommandLists failed: 0x{rc:08x}"
)));
}
let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
if rc != 0 {
unsafe {
(api.ze_command_list_destroy)(list);
(api.ze_mem_free)(context, host_ptr);
}
return Err(LevelZeroError::CommandListError(format!(
"zeCommandQueueSynchronize failed: 0x{rc:08x}"
)));
}
unsafe {
std::ptr::copy_nonoverlapping(host_ptr as *const u8, dst.as_mut_ptr(), copy_len);
}
unsafe {
(api.ze_command_list_destroy)(list);
(api.ze_mem_free)(context, host_ptr);
}
Ok(())
}
#[cfg(not(any(target_os = "linux", target_os = "windows")))]
{
let _ = (dst, handle);
Err(LevelZeroError::UnsupportedPlatform)
}
}
}
impl Drop for LevelZeroMemoryManager {
    /// Frees every allocation that is still registered when the manager goes away.
    ///
    /// Each surviving handle is logged with `tracing::warn!` as a leak and then released with
    /// `zeMemFree`. Failed frees are logged too. `drop` has exclusive access, so the table is
    /// reached with `Mutex::get_mut`. A poisoned mutex is recovered rather than skipped, so a
    /// panic while the lock was held no longer leaks every outstanding device buffer.
    fn drop(&mut self) {
        #[cfg(any(target_os = "linux", target_os = "windows"))]
        {
            let api = &self.device.api;
            let context = self.device.context;
            let buffers = self
                .buffers
                .get_mut()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            for (handle, rec) in buffers.drain() {
                tracing::warn!(
                    "LevelZeroMemoryManager: leaked buffer handle {handle} ({} bytes)",
                    rec.size
                );
                // SAFETY: the record has been drained out of the table, so this pointer is
                // freed exactly once and no manager method can reach it afterwards.
                let rc = unsafe { (api.ze_mem_free)(context, rec.device_ptr) };
                if rc != 0 {
                    tracing::warn!(
                        "LevelZeroMemoryManager: zeMemFree failed for handle {handle}: 0x{rc:08x}"
                    );
                }
            }
        }
    }
}
impl std::fmt::Debug for LevelZeroMemoryManager {
    /// Formats as `LevelZeroMemoryManager(buffers=N)`, where `N` is the number of live handles.
    ///
    /// A poisoned table is still read through `PoisonError::into_inner`, so the count stays
    /// accurate after a panic elsewhere instead of falling back to 0.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = match self.buffers.lock() {
            Ok(buffers) => buffers.len(),
            Err(poisoned) => poisoned.into_inner().len(),
        };
        write!(f, "LevelZeroMemoryManager(buffers={count})")
    }
}
// SAFETY (Sync): the manager's own fields are an `Arc`, a `Mutex` guarding every raw device
// pointer it owns, and an `AtomicU64` for handle issuance, so shared access to them is
// synchronized. NOTE(review): this also assumes the `LevelZeroDevice` handles (context, device,
// queue) may be used from several threads — confirm against `LevelZeroDevice` and the Level
// Zero specification.
unsafe impl Sync for LevelZeroMemoryManager {}
// SAFETY (Send): same reasoning as for `Sync` above. Raw pointers are reachable only via the
// mutex-guarded handle table.
unsafe impl Send for LevelZeroMemoryManager {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a real Level Zero device, or yields `None` when no driver or device is present,
    /// so hardware-dependent tests can skip quietly.
    fn try_get_device() -> Option<Arc<LevelZeroDevice>> {
        match LevelZeroDevice::new() {
            Ok(device) => Some(Arc::new(device)),
            Err(_) => None,
        }
    }

    /// Builds a manager on top of a real device, if one is available.
    fn try_get_manager() -> Option<LevelZeroMemoryManager> {
        try_get_device().map(LevelZeroMemoryManager::new)
    }

    #[test]
    fn alloc_and_free_requires_device() {
        if let Some(manager) = try_get_manager() {
            let handle = manager.alloc(256).expect("alloc 256 bytes");
            assert!(handle > 0);
            manager.free(handle).expect("free");
            manager.free(handle).expect("double-free is a no-op");
        }
    }

    #[test]
    fn copy_roundtrip_requires_device() {
        if let Some(manager) = try_get_manager() {
            let payload: Vec<u8> = (0..64u8).collect();
            let handle = manager.alloc(payload.len()).expect("alloc");
            manager
                .copy_to_device(handle, &payload)
                .expect("copy_to_device");
            let mut readback = vec![0u8; payload.len()];
            manager
                .copy_from_device(&mut readback, handle)
                .expect("copy_from_device");
            assert_eq!(payload, readback);
            manager.free(handle).expect("free");
        }
    }

    #[test]
    fn unknown_handle_returns_error() {
        if let Some(manager) = try_get_manager() {
            let outcome = manager.copy_to_device(9999, b"hello");
            assert!(matches!(outcome, Err(LevelZeroError::InvalidArgument(_))));
        }
    }

    #[test]
    fn debug_impl_smoke() {
        if let Some(manager) = try_get_manager() {
            let rendered = format!("{manager:?}");
            assert!(rendered.contains("LevelZeroMemoryManager"));
        }
    }
}