//! `accel/memory/registered.rs` — host slices registered (page-locked) with
//! the CUDA driver for the lifetime of the wrapper.
1use super::*;
2use crate::*;
3use cuda::*;
4use std::{
5    ffi::c_void,
6    ops::{Deref, DerefMut},
7    sync::Arc,
8};
9
/// Host memory registered with the CUDA driver (`cuMemHostRegister_v2` on
/// construction, `cuMemHostUnregister` on drop).
///
/// Unlike an owning allocation, this borrows a caller-provided slice and only
/// pins/unpins it; the allocation itself stays owned by the caller.
pub struct RegisteredMemory<'a, T> {
    // Context used for the register/unregister driver calls.
    ctx: Arc<Context>,
    // Borrowed host slice; remains registered while this value is alive.
    mem: &'a mut [T],
}
14
15impl<T> Deref for RegisteredMemory<'_, T> {
16    type Target = [T];
17    fn deref(&self) -> &[T] {
18        self.mem
19    }
20}
21
22impl<T> DerefMut for RegisteredMemory<'_, T> {
23    fn deref_mut(&mut self) -> &mut [T] {
24        self.mem
25    }
26}
27
impl<T> Drop for RegisteredMemory<'_, T> {
    /// Unregister the slice from the CUDA memory system
    /// (`cuMemHostUnregister`).
    ///
    /// A failing driver call is logged rather than propagated: `drop` cannot
    /// return an error, and panicking here would abort if we are already
    /// unwinding.
    fn drop(&mut self) {
        if let Err(e) = unsafe {
            contexted_call!(
                &self.ctx,
                cuMemHostUnregister,
                self.mem.as_mut_ptr() as *mut c_void
            )
        } {
            log::error!("Failed to unregister memory: {:?}", e);
        }
    }
}
41
impl<'a, T: Scalar> RegisteredMemory<'a, T> {
    /// Register `mem` with the CUDA memory system (`cuMemHostRegister_v2`),
    /// pinning it in the given context. The slice is unregistered again when
    /// the returned wrapper is dropped.
    ///
    /// # Panics
    /// Panics if the driver call fails (e.g. the region is already
    /// registered).
    pub fn new(ctx: Arc<Context>, mem: &'a mut [T]) -> Self {
        unsafe {
            contexted_call!(
                &ctx,
                cuMemHostRegister_v2,
                mem.as_mut_ptr() as *mut c_void,
                // Size argument is in bytes, not elements.
                mem.len() * T::size_of(),
                0 // flags: no CU_MEMHOSTREGISTER_* options requested
            )
        }
        .expect("Failed to register host memory into CUDA memory system")
        Self { ctx, mem }
    }
}
57
58impl<T: Scalar> Memory for RegisteredMemory<'_, T> {
59    type Elem = T;
60
61    fn head_addr(&self) -> *const T {
62        self.mem.as_ptr()
63    }
64
65    fn head_addr_mut(&mut self) -> *mut T {
66        self.mem.as_mut_ptr()
67    }
68
69    fn num_elem(&self) -> usize {
70        self.mem.len()
71    }
72
73    fn memory_type(&self) -> MemoryType {
74        MemoryType::Host
75    }
76}
77
78impl<T> Contexted for RegisteredMemory<'_, T> {
79    fn get_context(&self) -> Arc<Context> {
80        self.ctx.clone()
81    }
82}
83
84impl<T: Scalar> Memcpy<Self> for RegisteredMemory<'_, T> {
85    fn copy_from(&mut self, src: &Self) {
86        assert_ne!(self.head_addr(), src.head_addr());
87        assert_eq!(self.num_elem(), src.num_elem());
88        self.copy_from_slice(src)
89    }
90}
91
92impl<T: Scalar> Memcpy<PageLockedMemory<T>> for RegisteredMemory<'_, T> {
93    fn copy_from(&mut self, src: &PageLockedMemory<T>) {
94        assert_ne!(self.head_addr(), src.head_addr());
95        assert_eq!(self.num_elem(), src.num_elem());
96        self.copy_from_slice(src)
97    }
98}
99
impl<T: Scalar> Memcpy<DeviceMemory<T>> for RegisteredMemory<'_, T> {
    /// Copy from device memory into this registered host region via
    /// `cuMemcpyDtoH_v2`.
    ///
    /// # Panics
    /// Panics if the head addresses are equal, the element counts differ,
    /// or the driver call fails.
    fn copy_from(&mut self, src: &DeviceMemory<T>) {
        assert_ne!(self.head_addr(), src.head_addr());
        assert_eq!(self.num_elem(), src.num_elem());
        unsafe {
            contexted_call!(
                &self.get_context(),
                cuMemcpyDtoH_v2,
                self.as_mut_ptr() as *mut _,
                src.as_ptr() as CUdeviceptr,
                // Transfer size is in bytes: element count times element size.
                self.num_elem() * T::size_of()
            )
        }
        .expect("memcpy from Device to registered host memory failed")
    }
}
116
117impl<T: Scalar> Memset for RegisteredMemory<'_, T> {
118    fn set(&mut self, value: Self::Elem) {
119        self.iter_mut().for_each(|v| *v = value);
120    }
121}
122
123impl<T: Scalar> Continuous for RegisteredMemory<'_, T> {
124    fn as_slice(&self) -> &[T] {
125        self
126    }
127    fn as_mut_slice(&mut self) -> &mut [T] {
128        self
129    }
130}