kn_cuda_sys/wrapper/mem/device.rs

use std::ffi::c_void;
use std::ptr::null_mut;
use std::sync::Arc;

use crate::bindings::{cudaFree, cudaMalloc, cudaMemcpy, cudaMemcpyAsync, cudaMemcpyKind};
use crate::wrapper::handle::{CudaDevice, CudaStream};
use crate::wrapper::mem::pinned::PinnedMem;
use crate::wrapper::status::Status;
/// A reference-counted pointer into a [DeviceBuffer]. The buffer cannot be constructed directly;
/// it can only be created by allocating a new [DevicePtr] with [DevicePtr::alloc].
///
/// The inner [DeviceBuffer] is automatically freed once no [DevicePtr] refers to it any more.
/// Since the memory may be shared, all accessor methods are marked unsafe.
///
/// Cloning this type does not copy the underlying memory, it only increases the reference count.
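///
/// A minimal usage sketch (the `CudaDevice::new(0)` constructor and the import paths are assumed
/// here for illustration):
/// ```no_run
/// # use kn_cuda_sys::wrapper::handle::CudaDevice;
/// # use kn_cuda_sys::wrapper::mem::device::DevicePtr;
/// // allocate 4 bytes of device memory (CudaDevice::new is an assumed constructor)
/// let ptr = DevicePtr::alloc(CudaDevice::new(0), 4);
///
/// // round-trip some bytes through the device
/// let input = [1u8, 2, 3, 4];
/// let mut output = [0u8; 4];
/// unsafe {
///     ptr.copy_linear_from_host(&input);
///     ptr.copy_linear_to_host(&mut output);
/// }
/// assert_eq!(input, output);
/// ```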
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct DevicePtr {
    buffer: Arc<DeviceBuffer>,
    offset: isize,
}

// TODO it's a bit weird that this is public without being able to construct it,
//   but it's useful to let the user know it exists in the docs.
/// A device allocation as returned by cudaMalloc.
#[derive(Debug, Eq, PartialEq, Hash)]
pub struct DeviceBuffer {
    device: CudaDevice,
    base_ptr: *mut c_void,
    len_bytes: isize,
}

// TODO is this correct? We don't even attempt device-side memory safety, but can this cause CPU-side issues?
// TODO should we implement Sync? It's probably never a good idea to actually share pointers between threads...
unsafe impl Send for DeviceBuffer {}

unsafe impl Sync for DeviceBuffer {}

impl Drop for DeviceBuffer {
    fn drop(&mut self) {
        unsafe {
            self.device.switch_to();
            cudaFree(self.base_ptr).unwrap_in_drop()
        }
    }
}

impl DevicePtr {
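    /// Allocate a new buffer of `len_bytes` bytes on the given device.
    ///
    /// Panics if the underlying cudaMalloc call fails.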
    pub fn alloc(device: CudaDevice, len_bytes: usize) -> Self {
        unsafe {
            let mut device_ptr = null_mut();

            device.switch_to();
            cudaMalloc(&mut device_ptr as *mut _, len_bytes).unwrap();

            let inner = DeviceBuffer {
                device,
                base_ptr: device_ptr,
                len_bytes: len_bytes as isize,
            };
            DevicePtr {
                buffer: Arc::new(inner),
                offset: 0,
            }
        }
    }

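    /// The device the underlying buffer was allocated on.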
    pub fn device(&self) -> CudaDevice {
        self.buffer.device
    }

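    /// Build a new [DevicePtr] that points `offset` bytes further into the same underlying buffer.
    ///
    /// Panics if the resulting offset falls outside of the buffer.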
    pub fn offset_bytes(self, offset: isize) -> DevicePtr {
        let new_offset = self.offset + offset;

        if self.buffer.len_bytes == 0 {
            assert_eq!(offset, 0, "Non-zero offset not allowed on empty buffer");
        } else {
            assert!(
                (0..self.buffer.len_bytes).contains(&new_offset),
                "Offset {} is out of range on {:?}",
                offset,
                self
            );
        }

        DevicePtr {
            buffer: self.buffer,
            offset: new_offset,
        }
    }

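    /// The raw device pointer, shifted by this pointer's offset into the buffer.
    ///
    /// Unsafe because the underlying memory may be shared between multiple [DevicePtr] values.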
    pub unsafe fn ptr(&self) -> *mut c_void {
        self.buffer.base_ptr.offset(self.offset)
    }

    /// The number of `DevicePtr`s that share the underlying buffer and are still alive.
    pub fn shared_count(&self) -> usize {
        Arc::strong_count(&self.buffer)
    }

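    /// Synchronously copy `buffer` from host memory into this device allocation.
    ///
    /// Panics if `buffer` does not fit within the bounds of the allocation.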
    pub unsafe fn copy_linear_from_host(&self, buffer: &[u8]) {
        self.assert_linear_in_bounds(buffer.len());

        self.device().switch_to();
        cudaMemcpy(
            self.ptr(),
            buffer as *const _ as *const _,
            buffer.len(),
            cudaMemcpyKind::cudaMemcpyHostToDevice,
        )
        .unwrap();
    }

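    /// Asynchronously copy `buffer` from pinned host memory into this device allocation on `stream`.
    ///
    /// The copy may still be in flight when this function returns: the caller has to keep `buffer`
    /// alive and unchanged until the stream has been synchronized.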
    pub unsafe fn copy_linear_from_host_async(&self, buffer: &PinnedMem, stream: &CudaStream) {
        self.assert_linear_in_bounds(buffer.len_bytes());

        self.device().switch_to();
        cudaMemcpyAsync(
            self.ptr(),
            buffer.ptr(),
            buffer.len_bytes(),
            cudaMemcpyKind::cudaMemcpyHostToDevice,
            stream.inner(),
        )
        .unwrap();
    }

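    /// Synchronously copy the contents of this device allocation into `buffer` in host memory.
    ///
    /// Panics if `buffer` does not fit within the bounds of the allocation.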
    pub unsafe fn copy_linear_to_host(&self, buffer: &mut [u8]) {
        self.assert_linear_in_bounds(buffer.len());

        self.device().switch_to();
        cudaMemcpy(
            buffer as *mut _ as *mut _,
            self.ptr(),
            buffer.len(),
            cudaMemcpyKind::cudaMemcpyDeviceToHost,
        )
        .unwrap();
    }

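    /// Asynchronously copy the contents of this device allocation into pinned host memory on `stream`.
    ///
    /// The copy may still be in flight when this function returns: the caller has to keep `buffer`
    /// alive and must not read from it until the stream has been synchronized.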
    pub unsafe fn copy_linear_to_host_async(&self, buffer: &mut PinnedMem, stream: &CudaStream) {
        self.assert_linear_in_bounds(buffer.len_bytes());

        self.device().switch_to();
        cudaMemcpyAsync(
            buffer.ptr(),
            self.ptr(),
            buffer.len_bytes(),
            cudaMemcpyKind::cudaMemcpyDeviceToHost,
            stream.inner(),
        )
        .unwrap();
    }

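    /// Synchronously copy `len_bytes` bytes from `other` into this allocation.
    ///
    /// Panics if the two pointers live on different devices or if the copy is out of bounds on either side.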
    pub unsafe fn copy_linear_from_device(&self, other: &DevicePtr, len_bytes: usize) {
        assert_eq!(
            self.device(),
            other.device(),
            "Can only copy between buffers on the same device"
        );

        self.assert_linear_in_bounds(len_bytes);
        other.assert_linear_in_bounds(len_bytes);

        self.device().switch_to();
        cudaMemcpy(
            self.ptr(),
            other.ptr(),
            len_bytes,
            cudaMemcpyKind::cudaMemcpyDeviceToDevice,
        )
        .unwrap();
    }

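    /// Assert that a linear slice of `len` bytes starting at this pointer stays within the underlying buffer.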
    fn assert_linear_in_bounds(&self, len: usize) {
        assert!(
            (self.offset + len as isize) <= self.buffer.len_bytes,
            "Linear slice with length {} out of bounds for {:?}",
            len,
            self
        );
    }
}