Skip to main content

oxicuda_levelzero/
memory.rs

//! Level Zero memory manager — allocates, copies, and frees device memory
//! buffers using the Level Zero API with host-staging for transfers.
//!
//! Device memory is not directly CPU-accessible; all host↔device copies
//! go through a temporary host-side staging allocation and a command list.
//!
//! All buffers are tracked by opaque `u64` handles (starting at 1) that
//! mirror the CUDA device-pointer model used by the rest of OxiCUDA.
9
10use std::{
11    collections::HashMap,
12    sync::{Arc, Mutex},
13};
14
15#[cfg(any(target_os = "linux", target_os = "windows"))]
16use std::sync::atomic::{AtomicU64, Ordering::Relaxed};
17
18use crate::{
19    device::LevelZeroDevice,
20    error::{LevelZeroError, LevelZeroResult},
21};
22
23// ─── Platform-specific imports ───────────────────────────────────────────────
24
25#[cfg(any(target_os = "linux", target_os = "windows"))]
26use std::ffi::c_void;
27
28#[cfg(any(target_os = "linux", target_os = "windows"))]
29use crate::device::{
30    ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC, ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
31    ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC, ZeCommandListDesc, ZeCommandListHandle,
32    ZeDeviceMemAllocDesc, ZeHostMemAllocDesc,
33};
34
35// ─── Internal buffer record ──────────────────────────────────────────────────
36
/// Per-allocation entry stored in the manager's handle table.
///
/// On platforms without Level Zero support the record has no fields.
struct L0BufferRecord {
    /// Length of the allocation in bytes, as requested from `zeMemAllocDevice`.
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    size: u64,
    /// Device pointer returned by `zeMemAllocDevice`; not CPU-dereferenceable.
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    device_ptr: *mut c_void,
}

// SAFETY: the raw `device_ptr` is never dereferenced on the host; it is only
// handed back to Level Zero API calls. Records live inside the manager's
// `Mutex`-guarded table, which owns them exclusively.
#[cfg(any(target_os = "linux", target_os = "windows"))]
unsafe impl Send for L0BufferRecord {}
51
52// ─── Memory manager ──────────────────────────────────────────────────────────
53
54/// Manages a pool of Level Zero device buffers, returning opaque `u64` handles.
55///
56/// Uses explicit host-staging buffers and command lists for host↔device
57/// data transfers, matching the Level Zero programming model.
58///
59/// All public methods take `&self` so the manager can be shared behind `Arc`.
60pub struct LevelZeroMemoryManager {
61    #[cfg(any(target_os = "linux", target_os = "windows"))]
62    device: Arc<LevelZeroDevice>,
63    buffers: Mutex<HashMap<u64, L0BufferRecord>>,
64    #[cfg(any(target_os = "linux", target_os = "windows"))]
65    next_handle: AtomicU64,
66}
67
68impl LevelZeroMemoryManager {
69    /// Create a new memory manager backed by `device`.
70    #[cfg(any(target_os = "linux", target_os = "windows"))]
71    pub fn new(device: Arc<LevelZeroDevice>) -> Self {
72        Self {
73            device,
74            buffers: Mutex::new(HashMap::new()),
75            next_handle: AtomicU64::new(1),
76        }
77    }
78
79    /// Stub constructor on unsupported platforms.
80    ///
81    /// All methods return [`LevelZeroError::UnsupportedPlatform`].
82    #[cfg(not(any(target_os = "linux", target_os = "windows")))]
83    pub fn new(_device: Arc<LevelZeroDevice>) -> Self {
84        Self {
85            buffers: Mutex::new(HashMap::new()),
86        }
87    }
88
89    /// Allocate `bytes` bytes of device memory.
90    ///
91    /// Returns an opaque handle.  The caller must eventually call [`free`](Self::free).
92    pub fn alloc(&self, bytes: usize) -> LevelZeroResult<u64> {
93        #[cfg(any(target_os = "linux", target_os = "windows"))]
94        {
95            let api = &self.device.api;
96            let context = self.device.context;
97            let device_handle = self.device.device;
98
99            let desc = ZeDeviceMemAllocDesc {
100                stype: ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
101                p_next: std::ptr::null(),
102                flags: 0,
103                ordinal: 0,
104            };
105
106            let mut ptr: *mut c_void = std::ptr::null_mut();
107            // SAFETY: `context` and `device_handle` are valid Level Zero handles;
108            // `desc` is properly initialized; `ptr` is a valid output pointer.
109            let rc = unsafe {
110                (api.ze_mem_alloc_device)(
111                    context,
112                    &desc,
113                    bytes,
114                    64, // 64-byte alignment
115                    device_handle,
116                    &mut ptr as *mut *mut c_void,
117                )
118            };
119
120            if rc != 0 {
121                return Err(LevelZeroError::ZeError(
122                    rc,
123                    "zeMemAllocDevice failed".into(),
124                ));
125            }
126
127            let handle = self.next_handle.fetch_add(1, Relaxed);
128            self.buffers
129                .lock()
130                .map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?
131                .insert(
132                    handle,
133                    L0BufferRecord {
134                        device_ptr: ptr,
135                        size: bytes as u64,
136                    },
137                );
138
139            Ok(handle)
140        }
141
142        #[cfg(not(any(target_os = "linux", target_os = "windows")))]
143        {
144            let _ = bytes;
145            Err(LevelZeroError::UnsupportedPlatform)
146        }
147    }
148
149    /// Release the device buffer associated with `handle`.
150    ///
151    /// Unknown handles are silently ignored (idempotent free).
152    pub fn free(&self, handle: u64) -> LevelZeroResult<()> {
153        #[cfg(any(target_os = "linux", target_os = "windows"))]
154        {
155            let record = self
156                .buffers
157                .lock()
158                .map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?
159                .remove(&handle);
160
161            if let Some(rec) = record {
162                let api = &self.device.api;
163                let context = self.device.context;
164                // SAFETY: `rec.device_ptr` was allocated by `zeMemAllocDevice`
165                // and has not been freed yet (we just removed it from the map).
166                let rc = unsafe { (api.ze_mem_free)(context, rec.device_ptr) };
167                if rc != 0 {
168                    return Err(LevelZeroError::ZeError(rc, "zeMemFree failed".into()));
169                }
170            }
171            Ok(())
172        }
173
174        #[cfg(not(any(target_os = "linux", target_os = "windows")))]
175        {
176            let _ = handle;
177            Err(LevelZeroError::UnsupportedPlatform)
178        }
179    }
180
181    /// Upload host bytes `src` into the device buffer identified by `handle`.
182    ///
183    /// Allocates a temporary host-side staging buffer, copies the data into it,
184    /// then uses a command list to schedule the device copy and waits for completion.
185    pub fn copy_to_device(&self, handle: u64, src: &[u8]) -> LevelZeroResult<()> {
186        #[cfg(any(target_os = "linux", target_os = "windows"))]
187        {
188            let device_ptr = {
189                let buffers = self
190                    .buffers
191                    .lock()
192                    .map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?;
193                let rec = buffers.get(&handle).ok_or_else(|| {
194                    LevelZeroError::InvalidArgument(format!("unknown handle {handle}"))
195                })?;
196                rec.device_ptr
197            };
198
199            let api = &self.device.api;
200            let context = self.device.context;
201            let device_handle = self.device.device;
202            let queue = self.device.queue;
203            let copy_len = src.len();
204
205            // Allocate a host staging buffer.
206            let host_desc = ZeHostMemAllocDesc {
207                stype: ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC,
208                p_next: std::ptr::null(),
209                flags: 0,
210            };
211            let mut host_ptr: *mut c_void = std::ptr::null_mut();
212            // SAFETY: `context` is valid; `host_desc` is properly initialized;
213            // `host_ptr` is a valid output pointer.
214            let rc = unsafe {
215                (api.ze_mem_alloc_host)(
216                    context,
217                    &host_desc,
218                    copy_len,
219                    64,
220                    &mut host_ptr as *mut *mut c_void,
221                )
222            };
223            if rc != 0 {
224                return Err(LevelZeroError::ZeError(
225                    rc,
226                    "zeMemAllocHost (staging) failed".into(),
227                ));
228            }
229
230            // Copy host data into the staging buffer.
231            // SAFETY: `host_ptr` is a valid CPU-accessible pointer allocated
232            // by `zeMemAllocHost`; `src` is a valid slice of `copy_len` bytes.
233            unsafe {
234                std::ptr::copy_nonoverlapping(src.as_ptr(), host_ptr as *mut u8, copy_len);
235            }
236
237            // Create a command list for the copy.
238            let list_desc = ZeCommandListDesc {
239                stype: ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
240                p_next: std::ptr::null(),
241                command_queue_group_ordinal: 0,
242                flags: 0,
243            };
244            let mut list: ZeCommandListHandle = std::ptr::null_mut();
245            // SAFETY: `context` and `device_handle` are valid; `list_desc` is
246            // properly initialized; `list` is a valid output pointer.
247            let rc = unsafe {
248                (api.ze_command_list_create)(
249                    context,
250                    device_handle,
251                    &list_desc,
252                    &mut list as *mut ZeCommandListHandle,
253                )
254            };
255            if rc != 0 {
256                // SAFETY: host_ptr was successfully allocated above.
257                unsafe { (api.ze_mem_free)(context, host_ptr) };
258                return Err(LevelZeroError::CommandListError(format!(
259                    "zeCommandListCreate failed: 0x{rc:08x}"
260                )));
261            }
262
263            // Append the host→device memory copy to the command list.
264            // SAFETY: `list`, `device_ptr`, and `host_ptr` are valid;
265            // the copy length matches the data we staged.
266            let rc = unsafe {
267                (api.ze_command_list_append_memory_copy)(
268                    list,
269                    device_ptr,
270                    host_ptr as *const c_void,
271                    copy_len,
272                    0, // no signal event
273                    0, // no wait events
274                    std::ptr::null(),
275                )
276            };
277            if rc != 0 {
278                // SAFETY: list and host_ptr were allocated above.
279                unsafe {
280                    (api.ze_command_list_destroy)(list);
281                    (api.ze_mem_free)(context, host_ptr);
282                }
283                return Err(LevelZeroError::CommandListError(format!(
284                    "zeCommandListAppendMemoryCopy failed: 0x{rc:08x}"
285                )));
286            }
287
288            // Close and execute the command list.
289            // SAFETY: `list` is in the recording state.
290            let rc = unsafe { (api.ze_command_list_close)(list) };
291            if rc != 0 {
292                unsafe {
293                    (api.ze_command_list_destroy)(list);
294                    (api.ze_mem_free)(context, host_ptr);
295                }
296                return Err(LevelZeroError::CommandListError(format!(
297                    "zeCommandListClose failed: 0x{rc:08x}"
298                )));
299            }
300
301            // SAFETY: `queue` is valid; `list` is closed and ready for submission.
302            let rc = unsafe { (api.ze_command_queue_execute_command_lists)(queue, 1, &list, 0) };
303            if rc != 0 {
304                unsafe {
305                    (api.ze_command_list_destroy)(list);
306                    (api.ze_mem_free)(context, host_ptr);
307                }
308                return Err(LevelZeroError::CommandListError(format!(
309                    "zeCommandQueueExecuteCommandLists failed: 0x{rc:08x}"
310                )));
311            }
312
313            // Wait for completion.
314            // SAFETY: `queue` is valid; u64::MAX means "wait indefinitely".
315            let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
316            if rc != 0 {
317                unsafe {
318                    (api.ze_command_list_destroy)(list);
319                    (api.ze_mem_free)(context, host_ptr);
320                }
321                return Err(LevelZeroError::CommandListError(format!(
322                    "zeCommandQueueSynchronize failed: 0x{rc:08x}"
323                )));
324            }
325
326            // Clean up.
327            // SAFETY: `list` was created above and is no longer needed.
328            unsafe {
329                (api.ze_command_list_destroy)(list);
330                (api.ze_mem_free)(context, host_ptr);
331            }
332
333            Ok(())
334        }
335
336        #[cfg(not(any(target_os = "linux", target_os = "windows")))]
337        {
338            let _ = (handle, src);
339            Err(LevelZeroError::UnsupportedPlatform)
340        }
341    }
342
343    /// Download device buffer `handle` into `dst`.
344    ///
345    /// Uses a host-staging buffer and a command list for the device→host copy.
346    pub fn copy_from_device(&self, dst: &mut [u8], handle: u64) -> LevelZeroResult<()> {
347        #[cfg(any(target_os = "linux", target_os = "windows"))]
348        {
349            let device_ptr = {
350                let buffers = self
351                    .buffers
352                    .lock()
353                    .map_err(|_| LevelZeroError::CommandListError("mutex poisoned".into()))?;
354                let rec = buffers.get(&handle).ok_or_else(|| {
355                    LevelZeroError::InvalidArgument(format!("unknown handle {handle}"))
356                })?;
357                rec.device_ptr
358            };
359
360            let api = &self.device.api;
361            let context = self.device.context;
362            let device_handle = self.device.device;
363            let queue = self.device.queue;
364            let copy_len = dst.len();
365
366            // Allocate a host staging buffer.
367            let host_desc = ZeHostMemAllocDesc {
368                stype: ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC,
369                p_next: std::ptr::null(),
370                flags: 0,
371            };
372            let mut host_ptr: *mut c_void = std::ptr::null_mut();
373            // SAFETY: `context` is valid; `host_desc` is properly initialized;
374            // `host_ptr` is a valid output pointer.
375            let rc = unsafe {
376                (api.ze_mem_alloc_host)(
377                    context,
378                    &host_desc,
379                    copy_len,
380                    64,
381                    &mut host_ptr as *mut *mut c_void,
382                )
383            };
384            if rc != 0 {
385                return Err(LevelZeroError::ZeError(
386                    rc,
387                    "zeMemAllocHost (staging) failed".into(),
388                ));
389            }
390
391            // Create a command list for the copy.
392            let list_desc = ZeCommandListDesc {
393                stype: ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
394                p_next: std::ptr::null(),
395                command_queue_group_ordinal: 0,
396                flags: 0,
397            };
398            let mut list: ZeCommandListHandle = std::ptr::null_mut();
399            // SAFETY: `context` and `device_handle` are valid; `list_desc` is
400            // properly initialized; `list` is a valid output pointer.
401            let rc = unsafe {
402                (api.ze_command_list_create)(
403                    context,
404                    device_handle,
405                    &list_desc,
406                    &mut list as *mut ZeCommandListHandle,
407                )
408            };
409            if rc != 0 {
410                unsafe { (api.ze_mem_free)(context, host_ptr) };
411                return Err(LevelZeroError::CommandListError(format!(
412                    "zeCommandListCreate failed: 0x{rc:08x}"
413                )));
414            }
415
416            // Append the device→host memory copy to the command list.
417            // SAFETY: `list`, `host_ptr`, and `device_ptr` are valid;
418            // the copy length matches the destination buffer.
419            let rc = unsafe {
420                (api.ze_command_list_append_memory_copy)(
421                    list,
422                    host_ptr,
423                    device_ptr as *const c_void,
424                    copy_len,
425                    0, // no signal event
426                    0, // no wait events
427                    std::ptr::null(),
428                )
429            };
430            if rc != 0 {
431                unsafe {
432                    (api.ze_command_list_destroy)(list);
433                    (api.ze_mem_free)(context, host_ptr);
434                }
435                return Err(LevelZeroError::CommandListError(format!(
436                    "zeCommandListAppendMemoryCopy failed: 0x{rc:08x}"
437                )));
438            }
439
440            // Close and execute the command list.
441            // SAFETY: `list` is in the recording state.
442            let rc = unsafe { (api.ze_command_list_close)(list) };
443            if rc != 0 {
444                unsafe {
445                    (api.ze_command_list_destroy)(list);
446                    (api.ze_mem_free)(context, host_ptr);
447                }
448                return Err(LevelZeroError::CommandListError(format!(
449                    "zeCommandListClose failed: 0x{rc:08x}"
450                )));
451            }
452
453            // SAFETY: `queue` is valid; `list` is closed and ready for submission.
454            let rc = unsafe { (api.ze_command_queue_execute_command_lists)(queue, 1, &list, 0) };
455            if rc != 0 {
456                unsafe {
457                    (api.ze_command_list_destroy)(list);
458                    (api.ze_mem_free)(context, host_ptr);
459                }
460                return Err(LevelZeroError::CommandListError(format!(
461                    "zeCommandQueueExecuteCommandLists failed: 0x{rc:08x}"
462                )));
463            }
464
465            // Wait for completion.
466            // SAFETY: `queue` is valid; u64::MAX means "wait indefinitely".
467            let rc = unsafe { (api.ze_command_queue_synchronize)(queue, u64::MAX) };
468            if rc != 0 {
469                unsafe {
470                    (api.ze_command_list_destroy)(list);
471                    (api.ze_mem_free)(context, host_ptr);
472                }
473                return Err(LevelZeroError::CommandListError(format!(
474                    "zeCommandQueueSynchronize failed: 0x{rc:08x}"
475                )));
476            }
477
478            // Copy staging buffer to destination.
479            // SAFETY: `host_ptr` is valid and contains `copy_len` bytes of data
480            // transferred from the device; `dst` is a valid mutable slice.
481            unsafe {
482                std::ptr::copy_nonoverlapping(host_ptr as *const u8, dst.as_mut_ptr(), copy_len);
483            }
484
485            // Clean up.
486            // SAFETY: `list` and `host_ptr` were allocated above.
487            unsafe {
488                (api.ze_command_list_destroy)(list);
489                (api.ze_mem_free)(context, host_ptr);
490            }
491
492            Ok(())
493        }
494
495        #[cfg(not(any(target_os = "linux", target_os = "windows")))]
496        {
497            let _ = (dst, handle);
498            Err(LevelZeroError::UnsupportedPlatform)
499        }
500    }
501}
502
503// ─── Drop ────────────────────────────────────────────────────────────────────
504
505impl Drop for LevelZeroMemoryManager {
506    fn drop(&mut self) {
507        #[cfg(any(target_os = "linux", target_os = "windows"))]
508        {
509            let api = &self.device.api;
510            let context = self.device.context;
511
512            if let Ok(mut map) = self.buffers.lock() {
513                for (handle, rec) in map.drain() {
514                    tracing::warn!(
515                        "LevelZeroMemoryManager: leaked buffer handle {handle} ({} bytes)",
516                        rec.size
517                    );
518                    // SAFETY: `rec.device_ptr` is a valid outstanding allocation.
519                    unsafe { (api.ze_mem_free)(context, rec.device_ptr) };
520                }
521            }
522        }
523
524        #[cfg(not(any(target_os = "linux", target_os = "windows")))]
525        {
526            // Nothing to do on unsupported platforms.
527        }
528    }
529}
530
531// ─── Debug ───────────────────────────────────────────────────────────────────
532
533impl std::fmt::Debug for LevelZeroMemoryManager {
534    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
535        let count = self.buffers.lock().map(|b| b.len()).unwrap_or(0);
536        write!(f, "LevelZeroMemoryManager(buffers={count})")
537    }
538}
539
540// ─── Send + Sync ─────────────────────────────────────────────────────────────
541
542// SAFETY: `LevelZeroMemoryManager` serializes all access through a `Mutex`.
543// The raw pointer inside `L0BufferRecord` is owned and not aliased.
544unsafe impl Send for LevelZeroMemoryManager {}
545// SAFETY: See `Send` impl above.  All mutable operations go through a `Mutex`.
546unsafe impl Sync for LevelZeroMemoryManager {}
547
548// ─── Tests ───────────────────────────────────────────────────────────────────
549
#[cfg(test)]
mod tests {
    use super::*;

    /// Open a Level Zero device if one is present. Tests that need hardware
    /// return early (and pass) when this yields `None`.
    fn try_get_device() -> Option<Arc<LevelZeroDevice>> {
        match LevelZeroDevice::new() {
            Ok(device) => Some(Arc::new(device)),
            Err(_) => None,
        }
    }

    #[test]
    fn alloc_and_free_requires_device() {
        let Some(device) = try_get_device() else {
            return;
        };
        let manager = LevelZeroMemoryManager::new(device);

        let handle = manager.alloc(256).expect("alloc 256 bytes");
        assert!(handle > 0);

        manager.free(handle).expect("free");
        // The handle is no longer in the table, so a second free is a no-op.
        manager.free(handle).expect("double-free is a no-op");
    }

    #[test]
    fn copy_roundtrip_requires_device() {
        let Some(device) = try_get_device() else {
            return;
        };
        let manager = LevelZeroMemoryManager::new(device);

        let pattern: Vec<u8> = (0..64u8).collect();
        let handle = manager.alloc(pattern.len()).expect("alloc");
        manager
            .copy_to_device(handle, &pattern)
            .expect("copy_to_device");

        let mut readback = vec![0u8; pattern.len()];
        manager
            .copy_from_device(&mut readback, handle)
            .expect("copy_from_device");

        assert_eq!(pattern, readback);
        manager.free(handle).expect("free");
    }

    #[test]
    fn unknown_handle_returns_error() {
        let Some(device) = try_get_device() else {
            return;
        };
        let manager = LevelZeroMemoryManager::new(device);

        let outcome = manager.copy_to_device(9999, b"hello");
        assert!(matches!(outcome, Err(LevelZeroError::InvalidArgument(_))));
    }

    #[test]
    fn debug_impl_smoke() {
        let Some(device) = try_get_device() else {
            return;
        };
        let mm = LevelZeroMemoryManager::new(device);
        let rendered = format!("{mm:?}");
        assert!(rendered.contains("LevelZeroMemoryManager"));
    }
}
608}