Skip to main content

oxicuda_levelzero/
memory.rs

//! Level Zero memory manager — allocates, copies, and frees device memory
//! buffers using the Level Zero API with host-staging for transfers.
//!
//! Device memory is not directly CPU-accessible; all host↔device copies
//! use a temporary host-side staging allocation and a command list.
//!
//! All buffers are tracked by opaque `u64` handles (starting at 1) that
//! mirror the CUDA device-pointer model used by the rest of OxiCUDA.
9
10use std::{
11    collections::HashMap,
12    sync::{Arc, Mutex},
13};
14
15#[cfg(any(target_os = "linux", target_os = "windows"))]
16use std::sync::atomic::{AtomicU64, Ordering::Relaxed};
17
18use crate::{
19    device::LevelZeroDevice,
20    error::{LevelZeroError, LevelZeroResult},
21};
22
23// ─── Platform-specific imports ───────────────────────────────────────────────
24
25#[cfg(any(target_os = "linux", target_os = "windows"))]
26use std::ffi::c_void;
27
28#[cfg(any(target_os = "linux", target_os = "windows"))]
29use crate::device::{
30    ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC, ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
31    ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC, ZeCommandListDesc, ZeCommandListHandle,
32    ZeDeviceMemAllocDesc, ZeHostMemAllocDesc,
33};
34
35// ─── Internal buffer record ──────────────────────────────────────────────────
36
/// Per-allocation record the memory manager keeps for every live device buffer.
///
/// On unsupported platforms the record carries no data at all.
struct L0BufferRecord {
    /// Device pointer produced by `zeMemAllocDevice` (Linux and Windows only).
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    device_ptr: *mut c_void,
    /// Size in bytes that was requested for this allocation (Linux and Windows only).
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    size: u64,
}
46
47// SAFETY: `L0BufferRecord` contains a raw pointer that is logically owned
48// by the `LevelZeroMemoryManager`.  Access is serialized through a `Mutex`.
49#[cfg(any(target_os = "linux", target_os = "windows"))]
50unsafe impl Send for L0BufferRecord {}
51
52// ─── Memory manager ──────────────────────────────────────────────────────────
53
54/// Manages a pool of Level Zero device buffers, returning opaque `u64` handles.
55///
56/// Uses explicit host-staging buffers and command lists for host↔device
57/// data transfers, matching the Level Zero programming model.
58///
59/// All public methods take `&self` so the manager can be shared behind `Arc`.
60pub struct LevelZeroMemoryManager {
61    #[cfg(any(target_os = "linux", target_os = "windows"))]
62    device: Arc<LevelZeroDevice>,
63    buffers: Mutex<HashMap<u64, L0BufferRecord>>,
64    #[cfg(any(target_os = "linux", target_os = "windows"))]
65    next_handle: AtomicU64,
66}
67
68impl LevelZeroMemoryManager {
69    /// Create a new memory manager backed by `device`.
70    #[cfg(any(target_os = "linux", target_os = "windows"))]
71    pub fn new(device: Arc<LevelZeroDevice>) -> Self {
72        Self {
73            device,
74            buffers: Mutex::new(HashMap::new()),
75            next_handle: AtomicU64::new(1),
76        }
77    }
78
79    /// Stub constructor on unsupported platforms.
80    ///
81    /// All methods return [`LevelZeroError::UnsupportedPlatform`].
82    #[cfg(not(any(target_os = "linux", target_os = "windows")))]
83    pub fn new(_device: Arc<LevelZeroDevice>) -> Self {
84        Self {
85            buffers: Mutex::new(HashMap::new()),
86        }
87    }
88
89    /// Get the raw device pointer for a buffer handle.
90    ///
91    /// Returns the `*mut c_void` device pointer that Level Zero allocated for
92    /// the given handle.  This is needed to pass as a kernel argument.
93    #[cfg(any(target_os = "linux", target_os = "windows"))]
94    pub fn device_ptr(&self, handle: u64) -> LevelZeroResult<*mut c_void> {
95        let buffers = self
96            .buffers
97            .lock()
98            .map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?;
99        let rec = buffers
100            .get(&handle)
101            .ok_or_else(|| LevelZeroError::InvalidArgument(format!("unknown handle {handle}")))?;
102        Ok(rec.device_ptr)
103    }
104
105    /// Allocate `bytes` bytes of device memory.
106    ///
107    /// Returns an opaque handle.  The caller must eventually call [`free`](Self::free).
108    pub fn alloc(&self, bytes: usize) -> LevelZeroResult<u64> {
109        #[cfg(any(target_os = "linux", target_os = "windows"))]
110        {
111            let api = &self.device.api;
112            let context = self.device.context;
113            let device_handle = self.device.device;
114
115            let desc = ZeDeviceMemAllocDesc {
116                stype: ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
117                p_next: std::ptr::null(),
118                flags: 0,
119                ordinal: 0,
120            };
121
122            let mut ptr: *mut c_void = std::ptr::null_mut();
123            // SAFETY: `context` and `device_handle` are valid Level Zero handles;
124            // `desc` is properly initialized; `ptr` is a valid output pointer.
125            let rc = unsafe {
126                (api.ze_mem_alloc_device)(
127                    context,
128                    &desc,
129                    bytes,
130                    64, // 64-byte alignment
131                    device_handle,
132                    &mut ptr as *mut *mut c_void,
133                )
134            };
135
136            if rc != 0 {
137                return Err(LevelZeroError::ZeError(
138                    rc,
139                    "zeMemAllocDevice failed".into(),
140                ));
141            }
142
143            let handle = self.next_handle.fetch_add(1, Relaxed);
144            self.buffers
145                .lock()
146                .map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?
147                .insert(
148                    handle,
149                    L0BufferRecord {
150                        device_ptr: ptr,
151                        size: bytes as u64,
152                    },
153                );
154
155            Ok(handle)
156        }
157
158        #[cfg(not(any(target_os = "linux", target_os = "windows")))]
159        {
160            let _ = bytes;
161            Err(LevelZeroError::UnsupportedPlatform)
162        }
163    }
164
165    /// Release the device buffer associated with `handle`.
166    ///
167    /// Unknown handles are silently ignored (idempotent free).
168    pub fn free(&self, handle: u64) -> LevelZeroResult<()> {
169        #[cfg(any(target_os = "linux", target_os = "windows"))]
170        {
171            let record = self
172                .buffers
173                .lock()
174                .map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?
175                .remove(&handle);
176
177            if let Some(rec) = record {
178                let api = &self.device.api;
179                let context = self.device.context;
180                // SAFETY: `rec.device_ptr` was allocated by `zeMemAllocDevice`
181                // and has not been freed yet (we just removed it from the map).
182                let rc = unsafe { (api.ze_mem_free)(context, rec.device_ptr) };
183                if rc != 0 {
184                    return Err(LevelZeroError::ZeError(rc, "zeMemFree failed".into()));
185                }
186            }
187            Ok(())
188        }
189
190        #[cfg(not(any(target_os = "linux", target_os = "windows")))]
191        {
192            let _ = handle;
193            Err(LevelZeroError::UnsupportedPlatform)
194        }
195    }
196
197    /// Upload host bytes `src` into the device buffer identified by `handle`.
198    ///
199    /// Allocates a temporary host-side staging buffer, copies the data into it,
200    /// then uses a command list to schedule the device copy and waits for completion.
201    pub fn copy_to_device(&self, handle: u64, src: &[u8]) -> LevelZeroResult<()> {
202        #[cfg(any(target_os = "linux", target_os = "windows"))]
203        {
204            let device_ptr = {
205                let buffers = self
206                    .buffers
207                    .lock()
208                    .map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?;
209                let rec = buffers.get(&handle).ok_or_else(|| {
210                    LevelZeroError::InvalidArgument(format!("unknown handle {handle}"))
211                })?;
212                rec.device_ptr
213            };
214
215            let api = &self.device.api;
216            let context = self.device.context;
217            let device_handle = self.device.device;
218            let queue = self.device.queue;
219            let copy_len = src.len();
220
221            // Allocate a host staging buffer.
222            let host_desc = ZeHostMemAllocDesc {
223                stype: ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC,
224                p_next: std::ptr::null(),
225                flags: 0,
226            };
227            let mut host_ptr: *mut c_void = std::ptr::null_mut();
228            // SAFETY: `context` is valid; `host_desc` is properly initialized;
229            // `host_ptr` is a valid output pointer.
230            let rc = unsafe {
231                (api.ze_mem_alloc_host)(
232                    context,
233                    &host_desc,
234                    copy_len,
235                    64,
236                    &mut host_ptr as *mut *mut c_void,
237                )
238            };
239            if rc != 0 {
240                return Err(LevelZeroError::ZeError(
241                    rc,
242                    "zeMemAllocHost (staging) failed".into(),
243                ));
244            }
245
246            // Copy host data into the staging buffer.
247            // SAFETY: `host_ptr` is a valid CPU-accessible pointer allocated
248            // by `zeMemAllocHost`; `src` is a valid slice of `copy_len` bytes.
249            unsafe {
250                std::ptr::copy_nonoverlapping(src.as_ptr(), host_ptr as *mut u8, copy_len);
251            }
252
253            // Create a command list for the copy.
254            let list_desc = ZeCommandListDesc {
255                stype: ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
256                p_next: std::ptr::null(),
257                command_queue_group_ordinal: 0,
258                flags: 0,
259            };
260            let mut list: ZeCommandListHandle = std::ptr::null_mut();
261            // SAFETY: `context` and `device_handle` are valid; `list_desc` is
262            // properly initialized; `list` is a valid output pointer.
263            let rc = unsafe {
264                (api.ze_command_list_create)(
265                    context,
266                    device_handle,
267                    &list_desc,
268                    &mut list as *mut ZeCommandListHandle,
269                )
270            };
271            if rc != 0 {
272                // SAFETY: host_ptr was successfully allocated above.
273                unsafe { (api.ze_mem_free)(context, host_ptr) };
274                return Err(LevelZeroError::CommandListError(format!(
275                    "zeCommandListCreate failed: 0x{rc:08x}"
276                )));
277            }
278
279            // Append the host→device memory copy to the command list.
280            // SAFETY: `list`, `device_ptr`, and `host_ptr` are valid;
281            // the copy length matches the data we staged.
282            let rc = unsafe {
283                (api.ze_command_list_append_memory_copy)(
284                    list,
285                    device_ptr,
286                    host_ptr as *const c_void,
287                    copy_len,
288                    0, // no signal event
289                    0, // no wait events
290                    std::ptr::null(),
291                )
292            };
293            if rc != 0 {
294                // SAFETY: list and host_ptr were allocated above.
295                unsafe {
296                    (api.ze_command_list_destroy)(list);
297                    (api.ze_mem_free)(context, host_ptr);
298                }
299                return Err(LevelZeroError::CommandListError(format!(
300                    "zeCommandListAppendMemoryCopy failed: 0x{rc:08x}"
301                )));
302            }
303
304            // Close and execute the command list.
305            // SAFETY: `list` is in the recording state.
306            let rc = unsafe { (api.ze_command_list_close)(list) };
307            if rc != 0 {
308                unsafe {
309                    (api.ze_command_list_destroy)(list);
310                    (api.ze_mem_free)(context, host_ptr);
311                }
312                return Err(LevelZeroError::CommandListError(format!(
313                    "zeCommandListClose failed: 0x{rc:08x}"
314                )));
315            }
316
317            // SAFETY: `queue` is valid; `list` is closed and ready for submission.
318            let rc = unsafe { (api.ze_command_queue_execute_command_lists)(queue, 1, &list, 0) };
319            if rc != 0 {
320                unsafe {
321                    (api.ze_command_list_destroy)(list);
322                    (api.ze_mem_free)(context, host_ptr);
323                }
324                return Err(LevelZeroError::CommandListError(format!(
325                    "zeCommandQueueExecuteCommandLists failed: 0x{rc:08x}"
326                )));
327            }
328
329            // Wait for completion.
330            // SAFETY: `queue` is valid; u64::MAX means "wait indefinitely".
331            let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
332            if rc != 0 {
333                unsafe {
334                    (api.ze_command_list_destroy)(list);
335                    (api.ze_mem_free)(context, host_ptr);
336                }
337                return Err(LevelZeroError::CommandListError(format!(
338                    "zeCommandQueueSynchronize failed: 0x{rc:08x}"
339                )));
340            }
341
342            // Clean up.
343            // SAFETY: `list` was created above and is no longer needed.
344            unsafe {
345                (api.ze_command_list_destroy)(list);
346                (api.ze_mem_free)(context, host_ptr);
347            }
348
349            Ok(())
350        }
351
352        #[cfg(not(any(target_os = "linux", target_os = "windows")))]
353        {
354            let _ = (handle, src);
355            Err(LevelZeroError::UnsupportedPlatform)
356        }
357    }
358
359    /// Download device buffer `handle` into `dst`.
360    ///
361    /// Uses a host-staging buffer and a command list for the device→host copy.
362    pub fn copy_from_device(&self, dst: &mut [u8], handle: u64) -> LevelZeroResult<()> {
363        #[cfg(any(target_os = "linux", target_os = "windows"))]
364        {
365            let device_ptr = {
366                let buffers = self
367                    .buffers
368                    .lock()
369                    .map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?;
370                let rec = buffers.get(&handle).ok_or_else(|| {
371                    LevelZeroError::InvalidArgument(format!("unknown handle {handle}"))
372                })?;
373                rec.device_ptr
374            };
375
376            let api = &self.device.api;
377            let context = self.device.context;
378            let device_handle = self.device.device;
379            let queue = self.device.queue;
380            let copy_len = dst.len();
381
382            // Allocate a host staging buffer.
383            let host_desc = ZeHostMemAllocDesc {
384                stype: ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC,
385                p_next: std::ptr::null(),
386                flags: 0,
387            };
388            let mut host_ptr: *mut c_void = std::ptr::null_mut();
389            // SAFETY: `context` is valid; `host_desc` is properly initialized;
390            // `host_ptr` is a valid output pointer.
391            let rc = unsafe {
392                (api.ze_mem_alloc_host)(
393                    context,
394                    &host_desc,
395                    copy_len,
396                    64,
397                    &mut host_ptr as *mut *mut c_void,
398                )
399            };
400            if rc != 0 {
401                return Err(LevelZeroError::ZeError(
402                    rc,
403                    "zeMemAllocHost (staging) failed".into(),
404                ));
405            }
406
407            // Create a command list for the copy.
408            let list_desc = ZeCommandListDesc {
409                stype: ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
410                p_next: std::ptr::null(),
411                command_queue_group_ordinal: 0,
412                flags: 0,
413            };
414            let mut list: ZeCommandListHandle = std::ptr::null_mut();
415            // SAFETY: `context` and `device_handle` are valid; `list_desc` is
416            // properly initialized; `list` is a valid output pointer.
417            let rc = unsafe {
418                (api.ze_command_list_create)(
419                    context,
420                    device_handle,
421                    &list_desc,
422                    &mut list as *mut ZeCommandListHandle,
423                )
424            };
425            if rc != 0 {
426                unsafe { (api.ze_mem_free)(context, host_ptr) };
427                return Err(LevelZeroError::CommandListError(format!(
428                    "zeCommandListCreate failed: 0x{rc:08x}"
429                )));
430            }
431
432            // Append the device→host memory copy to the command list.
433            // SAFETY: `list`, `host_ptr`, and `device_ptr` are valid;
434            // the copy length matches the destination buffer.
435            let rc = unsafe {
436                (api.ze_command_list_append_memory_copy)(
437                    list,
438                    host_ptr,
439                    device_ptr as *const c_void,
440                    copy_len,
441                    0, // no signal event
442                    0, // no wait events
443                    std::ptr::null(),
444                )
445            };
446            if rc != 0 {
447                unsafe {
448                    (api.ze_command_list_destroy)(list);
449                    (api.ze_mem_free)(context, host_ptr);
450                }
451                return Err(LevelZeroError::CommandListError(format!(
452                    "zeCommandListAppendMemoryCopy failed: 0x{rc:08x}"
453                )));
454            }
455
456            // Close and execute the command list.
457            // SAFETY: `list` is in the recording state.
458            let rc = unsafe { (api.ze_command_list_close)(list) };
459            if rc != 0 {
460                unsafe {
461                    (api.ze_command_list_destroy)(list);
462                    (api.ze_mem_free)(context, host_ptr);
463                }
464                return Err(LevelZeroError::CommandListError(format!(
465                    "zeCommandListClose failed: 0x{rc:08x}"
466                )));
467            }
468
469            // SAFETY: `queue` is valid; `list` is closed and ready for submission.
470            let rc = unsafe { (api.ze_command_queue_execute_command_lists)(queue, 1, &list, 0) };
471            if rc != 0 {
472                unsafe {
473                    (api.ze_command_list_destroy)(list);
474                    (api.ze_mem_free)(context, host_ptr);
475                }
476                return Err(LevelZeroError::CommandListError(format!(
477                    "zeCommandQueueExecuteCommandLists failed: 0x{rc:08x}"
478                )));
479            }
480
481            // Wait for completion.
482            // SAFETY: `queue` is valid; u64::MAX means "wait indefinitely".
483            let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
484            if rc != 0 {
485                unsafe {
486                    (api.ze_command_list_destroy)(list);
487                    (api.ze_mem_free)(context, host_ptr);
488                }
489                return Err(LevelZeroError::CommandListError(format!(
490                    "zeCommandQueueSynchronize failed: 0x{rc:08x}"
491                )));
492            }
493
494            // Copy staging buffer to destination.
495            // SAFETY: `host_ptr` is valid and contains `copy_len` bytes of data
496            // transferred from the device; `dst` is a valid mutable slice.
497            unsafe {
498                std::ptr::copy_nonoverlapping(host_ptr as *const u8, dst.as_mut_ptr(), copy_len);
499            }
500
501            // Clean up.
502            // SAFETY: `list` and `host_ptr` were allocated above.
503            unsafe {
504                (api.ze_command_list_destroy)(list);
505                (api.ze_mem_free)(context, host_ptr);
506            }
507
508            Ok(())
509        }
510
511        #[cfg(not(any(target_os = "linux", target_os = "windows")))]
512        {
513            let _ = (dst, handle);
514            Err(LevelZeroError::UnsupportedPlatform)
515        }
516    }
517}
518
519// ─── Drop ────────────────────────────────────────────────────────────────────
520
521impl Drop for LevelZeroMemoryManager {
522    fn drop(&mut self) {
523        #[cfg(any(target_os = "linux", target_os = "windows"))]
524        {
525            let api = &self.device.api;
526            let context = self.device.context;
527
528            if let Ok(mut map) = self.buffers.lock() {
529                for (handle, rec) in map.drain() {
530                    tracing::warn!(
531                        "LevelZeroMemoryManager: leaked buffer handle {handle} ({} bytes)",
532                        rec.size
533                    );
534                    // SAFETY: `rec.device_ptr` is a valid outstanding allocation.
535                    unsafe { (api.ze_mem_free)(context, rec.device_ptr) };
536                }
537            }
538        }
539
540        #[cfg(not(any(target_os = "linux", target_os = "windows")))]
541        {
542            // Nothing to do on unsupported platforms.
543        }
544    }
545}
546
547// ─── Debug ───────────────────────────────────────────────────────────────────
548
549impl std::fmt::Debug for LevelZeroMemoryManager {
550    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
551        let count = self.buffers.lock().map(|b| b.len()).unwrap_or(0);
552        write!(f, "LevelZeroMemoryManager(buffers={count})")
553    }
554}
555
556// ─── Send + Sync ─────────────────────────────────────────────────────────────
557
558// SAFETY: `LevelZeroMemoryManager` serializes all access through a `Mutex`.
559// The raw pointer inside `L0BufferRecord` is owned and not aliased.
560unsafe impl Send for LevelZeroMemoryManager {}
561// SAFETY: See `Send` impl above.  All mutable operations go through a `Mutex`.
562unsafe impl Sync for LevelZeroMemoryManager {}
563
564// ─── Tests ───────────────────────────────────────────────────────────────────
565
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a Level Zero device if one is available; `None` lets the
    /// hardware-dependent tests below return early on machines without one.
    fn try_get_device() -> Option<Arc<LevelZeroDevice>> {
        match LevelZeroDevice::new() {
            Ok(device) => Some(Arc::new(device)),
            Err(_) => None,
        }
    }

    #[test]
    fn alloc_and_free_requires_device() {
        let Some(device) = try_get_device() else {
            return;
        };
        let manager = LevelZeroMemoryManager::new(device);

        let handle = manager.alloc(256).expect("alloc 256 bytes");
        assert!(handle > 0);
        manager.free(handle).expect("free");
        // The record is already gone, so a second free finds nothing to release.
        manager.free(handle).expect("double-free is a no-op");
    }

    #[test]
    fn copy_roundtrip_requires_device() {
        let Some(device) = try_get_device() else {
            return;
        };
        let manager = LevelZeroMemoryManager::new(device);

        let payload: Vec<u8> = (0..64u8).collect();
        let handle = manager.alloc(payload.len()).expect("alloc");
        manager.copy_to_device(handle, &payload).expect("copy_to_device");

        let mut readback = vec![0u8; payload.len()];
        manager
            .copy_from_device(&mut readback, handle)
            .expect("copy_from_device");

        assert_eq!(payload, readback);
        manager.free(handle).expect("free");
    }

    #[test]
    fn unknown_handle_returns_error() {
        let Some(device) = try_get_device() else {
            return;
        };
        let manager = LevelZeroMemoryManager::new(device);

        let error = manager.copy_to_device(9999, b"hello").unwrap_err();
        assert!(matches!(error, LevelZeroError::InvalidArgument(_)));
    }

    #[test]
    fn debug_impl_smoke() {
        let Some(device) = try_get_device() else {
            return;
        };
        let manager = LevelZeroMemoryManager::new(device);

        let rendered = format!("{manager:?}");
        assert!(rendered.contains("LevelZeroMemoryManager"));
    }
}