//! accel/memory/registered.rs
use super::*;
2use crate::*;
3use cuda::*;
4use std::{
5 ffi::c_void,
6 ops::{Deref, DerefMut},
7 sync::Arc,
8};
9
/// Host memory region registered into the CUDA memory system
/// (via `cuMemHostRegister_v2` in [`RegisteredMemory::new`]) and
/// automatically unregistered on drop.
pub struct RegisteredMemory<'a, T> {
    // Context the registration was performed under; kept alive so the
    // unregister call in `Drop` can run in the same context.
    ctx: Arc<Context>,
    // Borrowed host slice; stays registered for the lifetime of this wrapper.
    mem: &'a mut [T],
}
14
15impl<T> Deref for RegisteredMemory<'_, T> {
16 type Target = [T];
17 fn deref(&self) -> &[T] {
18 self.mem
19 }
20}
21
22impl<T> DerefMut for RegisteredMemory<'_, T> {
23 fn deref_mut(&mut self) -> &mut [T] {
24 self.mem
25 }
26}
27
28impl<T> Drop for RegisteredMemory<'_, T> {
29 fn drop(&mut self) {
30 if let Err(e) = unsafe {
31 contexted_call!(
32 &self.ctx,
33 cuMemHostUnregister,
34 self.mem.as_mut_ptr() as *mut c_void
35 )
36 } {
37 log::error!("Failed to unregister memory: {:?}", e);
38 }
39 }
40}
41
42impl<'a, T: Scalar> RegisteredMemory<'a, T> {
43 pub fn new(ctx: Arc<Context>, mem: &'a mut [T]) -> Self {
44 unsafe {
45 contexted_call!(
46 &ctx,
47 cuMemHostRegister_v2,
48 mem.as_mut_ptr() as *mut c_void,
49 mem.len() * T::size_of(),
50 0
51 )
52 }
53 .expect("Failed to register host memory into CUDA memory system");
54 Self { ctx, mem }
55 }
56}
57
58impl<T: Scalar> Memory for RegisteredMemory<'_, T> {
59 type Elem = T;
60
61 fn head_addr(&self) -> *const T {
62 self.mem.as_ptr()
63 }
64
65 fn head_addr_mut(&mut self) -> *mut T {
66 self.mem.as_mut_ptr()
67 }
68
69 fn num_elem(&self) -> usize {
70 self.mem.len()
71 }
72
73 fn memory_type(&self) -> MemoryType {
74 MemoryType::Host
75 }
76}
77
78impl<T> Contexted for RegisteredMemory<'_, T> {
79 fn get_context(&self) -> Arc<Context> {
80 self.ctx.clone()
81 }
82}
83
84impl<T: Scalar> Memcpy<Self> for RegisteredMemory<'_, T> {
85 fn copy_from(&mut self, src: &Self) {
86 assert_ne!(self.head_addr(), src.head_addr());
87 assert_eq!(self.num_elem(), src.num_elem());
88 self.copy_from_slice(src)
89 }
90}
91
92impl<T: Scalar> Memcpy<PageLockedMemory<T>> for RegisteredMemory<'_, T> {
93 fn copy_from(&mut self, src: &PageLockedMemory<T>) {
94 assert_ne!(self.head_addr(), src.head_addr());
95 assert_eq!(self.num_elem(), src.num_elem());
96 self.copy_from_slice(src)
97 }
98}
99
100impl<T: Scalar> Memcpy<DeviceMemory<T>> for RegisteredMemory<'_, T> {
101 fn copy_from(&mut self, src: &DeviceMemory<T>) {
102 assert_ne!(self.head_addr(), src.head_addr());
103 assert_eq!(self.num_elem(), src.num_elem());
104 unsafe {
105 contexted_call!(
106 &self.get_context(),
107 cuMemcpyDtoH_v2,
108 self.as_mut_ptr() as *mut _,
109 src.as_ptr() as CUdeviceptr,
110 self.num_elem() * T::size_of()
111 )
112 }
113 .expect("memcpy from Device to registered host memory failed")
114 }
115}
116
117impl<T: Scalar> Memset for RegisteredMemory<'_, T> {
118 fn set(&mut self, value: Self::Elem) {
119 self.iter_mut().for_each(|v| *v = value);
120 }
121}
122
123impl<T: Scalar> Continuous for RegisteredMemory<'_, T> {
124 fn as_slice(&self) -> &[T] {
125 self
126 }
127 fn as_mut_slice(&mut self) -> &mut [T] {
128 self
129 }
130}