1use std::fmt;
32use std::ops::{BitAnd, BitOr, Deref, DerefMut};
33
34use oxicuda_driver::error::{CudaError, CudaResult};
35use oxicuda_driver::ffi::{
36 CU_MEMHOSTREGISTER_DEVICEMAP, CU_MEMHOSTREGISTER_IOMEMORY, CU_MEMHOSTREGISTER_PORTABLE,
37 CU_MEMHOSTREGISTER_READ_ONLY, CUdeviceptr,
38};
39
40#[cfg(not(target_os = "macos"))]
41use oxicuda_driver::ffi;
42#[cfg(not(target_os = "macos"))]
43use oxicuda_driver::loader::try_driver;
44#[cfg(not(target_os = "macos"))]
45use std::ffi::c_void;
46
47#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
53pub struct RegisterFlags(u32);
54
55impl RegisterFlags {
56 pub const PORTABLE: Self = Self(CU_MEMHOSTREGISTER_PORTABLE);
58
59 pub const DEVICE_MAP: Self = Self(CU_MEMHOSTREGISTER_DEVICEMAP);
62
63 pub const IO_MEMORY: Self = Self(CU_MEMHOSTREGISTER_IOMEMORY);
65
66 pub const READ_ONLY: Self = Self(CU_MEMHOSTREGISTER_READ_ONLY);
68
69 pub const DEFAULT: Self = Self(CU_MEMHOSTREGISTER_PORTABLE | CU_MEMHOSTREGISTER_DEVICEMAP);
71
72 pub const NONE: Self = Self(0);
74
75 #[inline]
77 pub const fn bits(self) -> u32 {
78 self.0
79 }
80
81 #[inline]
83 pub const fn from_bits(bits: u32) -> Self {
84 Self(bits)
85 }
86
87 #[inline]
89 pub const fn contains(self, other: Self) -> bool {
90 (self.0 & other.0) == other.0
91 }
92}
93
94impl BitOr for RegisterFlags {
95 type Output = Self;
96
97 #[inline]
98 fn bitor(self, rhs: Self) -> Self {
99 Self(self.0 | rhs.0)
100 }
101}
102
103impl BitAnd for RegisterFlags {
104 type Output = Self;
105
106 #[inline]
107 fn bitand(self, rhs: Self) -> Self {
108 Self(self.0 & rhs.0)
109 }
110}
111
112impl fmt::Display for RegisterFlags {
113 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114 let mut parts = Vec::new();
115 if self.contains(Self::PORTABLE) {
116 parts.push("PORTABLE");
117 }
118 if self.contains(Self::DEVICE_MAP) {
119 parts.push("DEVICE_MAP");
120 }
121 if self.contains(Self::IO_MEMORY) {
122 parts.push("IO_MEMORY");
123 }
124 if self.contains(Self::READ_ONLY) {
125 parts.push("READ_ONLY");
126 }
127 if parts.is_empty() {
128 write!(f, "NONE")
129 } else {
130 write!(f, "{}", parts.join(" | "))
131 }
132 }
133}
134
/// Handle for a caller-provided host buffer that has been registered with the CUDA driver.
///
/// The handle records the raw host pointer, the element count, the flags used for the
/// registration and a device-side address. It never allocates or frees the buffer itself;
/// on non-macOS targets dropping the handle asks the driver to unregister the range.
///
/// NOTE(review): the handle carries no lifetime, so nothing ties it to the buffer it points
/// at — callers must keep the buffer alive and un-moved while the handle exists.
pub struct RegisteredMemory<T: Copy> {
    /// Host address of the first element, exactly as given to `register`.
    ptr: *mut T,
    /// Number of `T` elements (not bytes) covered by the registration.
    len: usize,
    /// Flags the buffer was registered with.
    flags: RegisterFlags,
    /// Device-side address: obtained from the driver when `DEVICE_MAP` was requested and
    /// `0` otherwise; on macOS it is the host address cast to `CUdeviceptr`.
    device_ptr: CUdeviceptr,
}

// SAFETY (assumption to confirm): the handle is treated like an exclusive view of `[T]`,
// so moving it to another thread is as safe as moving the `T` values themselves.
unsafe impl<T: Copy + Send> Send for RegisteredMemory<T> {}
// SAFETY (assumption to confirm): shared access only hands out `&[T]` and plain-value
// getters, so sharing the handle is as safe as sharing `T`.
unsafe impl<T: Copy + Sync> Sync for RegisteredMemory<T> {}
165
166impl<T: Copy> RegisteredMemory<T> {
167 #[inline]
169 pub fn as_ptr(&self) -> *const T {
170 self.ptr
171 }
172
173 #[inline]
175 pub fn as_mut_ptr(&mut self) -> *mut T {
176 self.ptr
177 }
178
179 #[inline]
183 pub fn device_ptr(&self) -> CUdeviceptr {
184 self.device_ptr
185 }
186
187 #[inline]
189 pub fn len(&self) -> usize {
190 self.len
191 }
192
193 #[inline]
195 pub fn is_empty(&self) -> bool {
196 self.len == 0
197 }
198
199 #[inline]
201 pub fn flags(&self) -> RegisterFlags {
202 self.flags
203 }
204
205 #[inline]
207 pub fn as_slice(&self) -> &[T] {
208 unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
211 }
212
213 #[inline]
215 pub fn as_mut_slice(&mut self) -> &mut [T] {
216 unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
219 }
220}
221
222impl<T: Copy> Deref for RegisteredMemory<T> {
223 type Target = [T];
224
225 #[inline]
226 fn deref(&self) -> &[T] {
227 self.as_slice()
228 }
229}
230
231impl<T: Copy> DerefMut for RegisteredMemory<T> {
232 #[inline]
233 fn deref_mut(&mut self) -> &mut [T] {
234 self.as_mut_slice()
235 }
236}
237
238impl<T: Copy> Drop for RegisteredMemory<T> {
239 fn drop(&mut self) {
240 #[cfg(not(target_os = "macos"))]
241 {
242 if let Ok(api) = try_driver() {
243 let rc = unsafe { (api.cu_mem_host_unregister)(self.ptr.cast::<c_void>()) };
244 if rc != 0 {
245 tracing::warn!(
246 cuda_error = rc,
247 len = self.len,
248 "cuMemHostUnregister failed during RegisteredMemory drop"
249 );
250 }
251 }
252 }
253 }
254}
255
256pub fn register<T: Copy>(
273 ptr: *mut T,
274 len: usize,
275 flags: RegisterFlags,
276) -> CudaResult<RegisteredMemory<T>> {
277 if len == 0 {
278 return Err(CudaError::InvalidValue);
279 }
280 if ptr.is_null() {
281 return Err(CudaError::InvalidValue);
282 }
283 let byte_size = len
284 .checked_mul(std::mem::size_of::<T>())
285 .ok_or(CudaError::InvalidValue)?;
286
287 #[cfg(target_os = "macos")]
288 {
289 let _ = byte_size;
292 Ok(RegisteredMemory {
293 ptr,
294 len,
295 flags,
296 device_ptr: ptr as CUdeviceptr,
297 })
298 }
299
300 #[cfg(not(target_os = "macos"))]
301 {
302 let api = try_driver()?;
303
304 let rc =
306 unsafe { (api.cu_mem_host_register_v2)(ptr.cast::<c_void>(), byte_size, flags.bits()) };
307 oxicuda_driver::check(rc)?;
308
309 let device_ptr = if flags.contains(RegisterFlags::DEVICE_MAP) {
311 let mut dptr: CUdeviceptr = 0;
312 let rc2 = unsafe {
313 (api.cu_mem_host_get_device_pointer_v2)(&mut dptr, ptr.cast::<c_void>(), 0)
314 };
315 oxicuda_driver::check(rc2)?;
316 dptr
317 } else {
318 0
319 };
320
321 Ok(RegisteredMemory {
322 ptr,
323 len,
324 flags,
325 device_ptr,
326 })
327 }
328}
329
330pub fn register_slice<T: Copy>(
336 slice: &mut [T],
337 flags: RegisterFlags,
338) -> CudaResult<RegisteredMemory<T>> {
339 register(slice.as_mut_ptr(), slice.len(), flags)
340}
341
342pub fn register_vec<T: Copy>(
351 vec: &mut Vec<T>,
352 flags: RegisterFlags,
353) -> CudaResult<RegisteredMemory<T>> {
354 register(vec.as_mut_ptr(), vec.len(), flags)
355}
356
/// Memory classification reported by [`query_registered_pointer_info`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RegisteredMemoryType {
    /// The driver reported `CU_MEMORYTYPE_HOST`.
    Host,
    /// The driver reported `CU_MEMORYTYPE_DEVICE`.
    Device,
    /// The driver reported `CU_MEMORYTYPE_UNIFIED`.
    Unified,
    /// The query failed or returned a value that is none of the above.
    Unregistered,
}

/// Prints the variant name verbatim (`Host`, `Device`, `Unified`, `Unregistered`).
impl fmt::Display for RegisteredMemoryType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            RegisteredMemoryType::Host => "Host",
            RegisteredMemoryType::Device => "Device",
            RegisteredMemoryType::Unified => "Unified",
            RegisteredMemoryType::Unregistered => "Unregistered",
        };
        f.write_str(label)
    }
}
384
/// Result of [`query_registered_pointer_info`]: what the driver reports about a pointer.
#[derive(Debug, Clone, Copy)]
pub struct RegisteredPointerInfo {
    /// Device-side address reported for the pointer, or `0` when that query failed.
    /// On macOS this is the queried host address itself.
    pub device_ptr: CUdeviceptr,
    /// `true` only when the managed-memory query succeeded and returned a non-zero value.
    pub is_managed: bool,
    /// Memory classification of the pointer.
    pub memory_type: RegisteredMemoryType,
}
395
396pub fn query_registered_pointer_info(ptr: *const u8) -> CudaResult<RegisteredPointerInfo> {
404 if ptr.is_null() {
405 return Err(CudaError::InvalidValue);
406 }
407
408 #[cfg(target_os = "macos")]
409 {
410 Ok(RegisteredPointerInfo {
412 device_ptr: ptr as CUdeviceptr,
413 is_managed: false,
414 memory_type: RegisteredMemoryType::Host,
415 })
416 }
417
418 #[cfg(not(target_os = "macos"))]
419 {
420 let api = try_driver()?;
421 let dev_ptr_val = ptr as CUdeviceptr;
422
423 let mut mem_type: u32 = 0;
425 let rc = unsafe {
426 (api.cu_pointer_get_attribute)(
427 (&mut mem_type as *mut u32).cast::<c_void>(),
428 ffi::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
429 dev_ptr_val,
430 )
431 };
432 let memory_type = if rc != 0 {
433 RegisteredMemoryType::Unregistered
435 } else {
436 match mem_type {
437 ffi::CU_MEMORYTYPE_HOST => RegisteredMemoryType::Host,
438 ffi::CU_MEMORYTYPE_DEVICE => RegisteredMemoryType::Device,
439 ffi::CU_MEMORYTYPE_UNIFIED => RegisteredMemoryType::Unified,
440 _ => RegisteredMemoryType::Unregistered,
441 }
442 };
443
444 let mut managed: u32 = 0;
446 let rc2 = unsafe {
447 (api.cu_pointer_get_attribute)(
448 (&mut managed as *mut u32).cast::<c_void>(),
449 ffi::CU_POINTER_ATTRIBUTE_IS_MANAGED,
450 dev_ptr_val,
451 )
452 };
453 let is_managed = rc2 == 0 && managed != 0;
454
455 let mut dptr: CUdeviceptr = 0;
457 let rc3 = unsafe {
458 (api.cu_pointer_get_attribute)(
459 (&mut dptr as *mut CUdeviceptr).cast::<c_void>(),
460 ffi::CU_POINTER_ATTRIBUTE_DEVICE_POINTER,
461 dev_ptr_val,
462 )
463 };
464 if rc3 != 0 {
465 dptr = 0;
466 }
467
468 Ok(RegisteredPointerInfo {
469 device_ptr: dptr,
470 is_managed,
471 memory_type,
472 })
473 }
474}
475
// Unit tests. The top-level tests need no CUDA driver: they cover the flag algebra, the
// `Display` impls and the argument validation that runs before any driver call.
// `macos_tests` exercises the driver-free macOS path; `gpu_tests` needs a real device and
// is gated behind the `gpu-tests` feature.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- RegisterFlags ----

    /// `DEFAULT` is exactly `PORTABLE | DEVICE_MAP`.
    #[test]
    fn flags_default_contains_portable_and_device_map() {
        assert!(RegisterFlags::DEFAULT.contains(RegisterFlags::PORTABLE));
        assert!(RegisterFlags::DEFAULT.contains(RegisterFlags::DEVICE_MAP));
        assert!(!RegisterFlags::DEFAULT.contains(RegisterFlags::IO_MEMORY));
        assert!(!RegisterFlags::DEFAULT.contains(RegisterFlags::READ_ONLY));
    }

    /// `|` yields the union of both operands.
    #[test]
    fn flags_bitor_combines() {
        let combined = RegisterFlags::PORTABLE | RegisterFlags::READ_ONLY;
        assert!(combined.contains(RegisterFlags::PORTABLE));
        assert!(combined.contains(RegisterFlags::READ_ONLY));
        assert!(!combined.contains(RegisterFlags::IO_MEMORY));
    }

    /// `&` keeps only the flags present in both operands.
    #[test]
    fn flags_bitand_intersects() {
        let a = RegisterFlags::PORTABLE | RegisterFlags::DEVICE_MAP;
        let b = RegisterFlags::PORTABLE | RegisterFlags::READ_ONLY;
        let intersected = a & b;
        assert!(intersected.contains(RegisterFlags::PORTABLE));
        assert!(!intersected.contains(RegisterFlags::DEVICE_MAP));
        assert!(!intersected.contains(RegisterFlags::READ_ONLY));
    }

    /// The empty set prints as `NONE`; named flags print by name.
    #[test]
    fn flags_display() {
        assert_eq!(RegisterFlags::NONE.to_string(), "NONE");
        assert_eq!(RegisterFlags::PORTABLE.to_string(), "PORTABLE");
        let default_str = RegisterFlags::DEFAULT.to_string();
        assert!(default_str.contains("PORTABLE"));
        assert!(default_str.contains("DEVICE_MAP"));
    }

    /// `from_bits(bits())` reproduces the original value.
    #[test]
    fn flags_bits_roundtrip() {
        let flags = RegisterFlags::PORTABLE | RegisterFlags::IO_MEMORY;
        let bits = flags.bits();
        assert_eq!(RegisterFlags::from_bits(bits), flags);
    }

    #[test]
    fn flags_none_is_zero() {
        assert_eq!(RegisterFlags::NONE.bits(), 0);
    }

    // ---- RegisteredMemoryType ----

    /// Each variant prints as its own name.
    #[test]
    fn memory_type_display() {
        assert_eq!(RegisteredMemoryType::Host.to_string(), "Host");
        assert_eq!(RegisteredMemoryType::Device.to_string(), "Device");
        assert_eq!(RegisteredMemoryType::Unified.to_string(), "Unified");
        assert_eq!(
            RegisteredMemoryType::Unregistered.to_string(),
            "Unregistered"
        );
    }

    #[test]
    fn memory_type_equality() {
        assert_eq!(RegisteredMemoryType::Host, RegisteredMemoryType::Host);
        assert_ne!(RegisteredMemoryType::Host, RegisteredMemoryType::Device);
    }

    // ---- Argument validation (rejected before any driver call) ----

    /// A zero element count is rejected even for a valid pointer.
    #[test]
    fn register_zero_len_fails() {
        let mut buf = [0u8; 16];
        let result = register(buf.as_mut_ptr(), 0, RegisterFlags::DEFAULT);
        assert!(matches!(result, Err(CudaError::InvalidValue)));
    }

    /// A null pointer is rejected even for a non-zero length.
    #[test]
    fn register_null_ptr_fails() {
        let result = register::<u8>(std::ptr::null_mut(), 10, RegisterFlags::DEFAULT);
        assert!(matches!(result, Err(CudaError::InvalidValue)));
    }

    #[test]
    fn register_slice_zero_len_fails() {
        let mut empty: [f32; 0] = [];
        let result = register_slice(&mut empty, RegisterFlags::DEFAULT);
        assert!(matches!(result, Err(CudaError::InvalidValue)));
    }

    #[test]
    fn register_vec_zero_len_fails() {
        let mut v: Vec<i32> = Vec::new();
        let result = register_vec(&mut v, RegisterFlags::DEFAULT);
        assert!(matches!(result, Err(CudaError::InvalidValue)));
    }

    #[test]
    fn query_null_ptr_fails() {
        let result = query_registered_pointer_info(std::ptr::null());
        assert!(matches!(result, Err(CudaError::InvalidValue)));
    }

    // macOS path: no driver is involved, so registration always succeeds for valid input
    // and the handle simply mirrors the host buffer.
    #[cfg(target_os = "macos")]
    mod macos_tests {
        use super::*;

        /// The handle reports the slice's length, flags and contents.
        #[test]
        fn register_slice_succeeds_on_macos() {
            let mut data = vec![1.0f32, 2.0, 3.0, 4.0];
            let reg = register_slice(data.as_mut_slice(), RegisterFlags::DEFAULT);
            let reg = reg.ok();
            assert!(reg.is_some());
            let reg = reg.inspect(|r| {
                assert_eq!(r.len(), 4);
                assert!(!r.is_empty());
                assert_eq!(r.flags(), RegisterFlags::DEFAULT);
                assert_eq!(r.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
            });
            // Drop the handle explicitly while `data` is still alive.
            drop(reg);
        }

        /// On macOS the device pointer is the (non-null) host address.
        #[test]
        fn register_vec_succeeds_on_macos() {
            let mut v = vec![10u32, 20, 30];
            let reg = register_vec(&mut v, RegisterFlags::PORTABLE);
            assert!(reg.is_ok());
            if let Ok(r) = reg {
                assert_eq!(r.len(), 3);
                assert_eq!(r.flags(), RegisterFlags::PORTABLE);
                assert_ne!(r.device_ptr(), 0);
            }
        }

        /// `&handle` coerces to `&[T]` through `Deref`.
        #[test]
        fn registered_memory_deref_works() {
            let mut data = vec![100i64, 200, 300];
            let reg = register_vec(&mut data, RegisterFlags::DEFAULT);
            assert!(reg.is_ok());
            if let Ok(r) = reg {
                let slice: &[i64] = &r;
                assert_eq!(slice.len(), 3);
                assert_eq!(slice[0], 100);
                assert_eq!(slice[2], 300);
            }
        }

        /// Indexing and `&mut handle` both write through `DerefMut`.
        #[test]
        fn registered_memory_deref_mut_works() {
            let mut data = vec![1u8, 2, 3, 4, 5];
            let reg = register_slice(&mut data, RegisterFlags::DEFAULT);
            assert!(reg.is_ok());
            if let Ok(mut r) = reg {
                r[0] = 99;
                assert_eq!(r[0], 99);
                let mslice: &mut [u8] = &mut r;
                mslice[4] = 88;
                assert_eq!(mslice[4], 88);
            }
        }

        /// On macOS every non-null pointer is reported as unmanaged host memory.
        #[test]
        fn query_pointer_info_on_macos() {
            let data = [42u8; 64];
            let info = query_registered_pointer_info(data.as_ptr());
            assert!(info.is_ok());
            if let Ok(info) = info {
                assert!(!info.is_managed);
                assert_eq!(info.memory_type, RegisteredMemoryType::Host);
                assert_ne!(info.device_ptr, 0);
            }
        }

        /// The raw-pointer getters return the address that was registered.
        #[test]
        fn registered_memory_as_ptr_mut_ptr() {
            let mut data = vec![5.0f64; 10];
            let original_ptr = data.as_mut_ptr();
            let reg = register_vec(&mut data, RegisterFlags::DEFAULT);
            assert!(reg.is_ok());
            if let Ok(mut r) = reg {
                assert_eq!(r.as_ptr(), original_ptr as *const f64);
                assert_eq!(r.as_mut_ptr(), original_ptr);
            }
        }
    }

    // Real-hardware path: only built with the `gpu-tests` feature, and each test returns
    // early (passing) when no usable device or context is available.
    #[cfg(feature = "gpu-tests")]
    mod gpu_tests {
        use super::*;

        /// Registers a buffer on device 0 and lets the handle's drop unregister it.
        #[test]
        fn register_and_unregister_on_gpu() {
            if oxicuda_driver::init().is_err() || oxicuda_driver::Device::count().unwrap_or(0) == 0
            {
                return;
            }
            let Ok(dev) = oxicuda_driver::Device::get(0) else {
                return;
            };
            // Keep the context alive for the duration of the test.
            let Ok(_ctx) = oxicuda_driver::Context::new(&dev) else {
                return;
            };
            let mut data = vec![0.0f32; 4096];
            let reg = register_vec(&mut data, RegisterFlags::DEFAULT);
            assert!(reg.is_ok(), "registration failed: {:?}", reg.err());
            if let Ok(r) = reg {
                assert_eq!(r.len(), 4096);
                assert!(r.device_ptr() != 0, "device_ptr should be non-zero");
            }
        }
    }
}