kn_cuda_sys/wrapper/mem/device.rs

use std::ffi::c_void;
use std::ptr::null_mut;
use std::sync::Arc;

use crate::bindings::{cudaFree, cudaMalloc, cudaMemcpy, cudaMemcpyAsync, cudaMemcpyKind};
use crate::wrapper::handle::{CudaDevice, CudaStream};
use crate::wrapper::mem::pinned::PinnedMem;
use crate::wrapper::status::Status;

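/// A pointer into device memory: a shared [`DeviceBuffer`] plus a byte offset into it.
/// Cloning is cheap and does not copy the underlying allocation.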
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct DevicePtr {
    buffer: Arc<DeviceBuffer>,
    offset: isize,
}

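/// An owned device allocation. The memory is freed on the owning device when the buffer is dropped.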
#[derive(Debug, Eq, PartialEq, Hash)]
pub struct DeviceBuffer {
    device: CudaDevice,
    base_ptr: *mut c_void,
    len_bytes: isize,
}

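// Safety: `DeviceBuffer` only holds a raw device pointer and the device it was allocated on;
// CUDA device pointers are not tied to a particular host thread.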
unsafe impl Send for DeviceBuffer {}

unsafe impl Sync for DeviceBuffer {}

impl Drop for DeviceBuffer {
    fn drop(&mut self) {
        unsafe {
            self.device.switch_to();
            cudaFree(self.base_ptr).unwrap_in_drop()
        }
    }
}

impl DevicePtr {
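    /// Allocate `len_bytes` bytes of device memory on `device`.
    /// Panics if the underlying `cudaMalloc` call fails.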
    pub fn alloc(device: CudaDevice, len_bytes: usize) -> Self {
        unsafe {
            let mut device_ptr = null_mut();

            device.switch_to();
            cudaMalloc(&mut device_ptr as *mut _, len_bytes).unwrap();

            let inner = DeviceBuffer {
                device,
                base_ptr: device_ptr,
                len_bytes: len_bytes as isize,
            };
            DevicePtr {
                buffer: Arc::new(inner),
                offset: 0,
            }
        }
    }

    pub fn device(&self) -> CudaDevice {
        self.buffer.device
    }

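    /// Return a new pointer to the same buffer, moved by `offset` bytes.
    /// Panics if the resulting offset falls outside the buffer.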
    pub fn offset_bytes(self, offset: isize) -> DevicePtr {
        let new_offset = self.offset + offset;

        if self.buffer.len_bytes == 0 {
            assert_eq!(offset, 0, "Non-zero offset not allowed on empty buffer");
        } else {
            assert!(
                (0..self.buffer.len_bytes).contains(&new_offset),
                "Offset {} is out of range on {:?}",
                offset,
                self
            );
        }

        DevicePtr {
            buffer: self.buffer,
            offset: new_offset,
        }
    }

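    /// The raw device pointer, including the offset.
    /// Unsafe because the pointer must not be used after all `DevicePtr`s sharing the buffer are dropped.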
    pub unsafe fn ptr(&self) -> *mut c_void {
        self.buffer.base_ptr.offset(self.offset)
    }

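    /// The number of `DevicePtr` values (including this one) currently sharing the underlying buffer.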
    pub fn shared_count(&self) -> usize {
        Arc::strong_count(&self.buffer)
    }

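    /// Blocking copy of `buffer` from host memory into device memory starting at this pointer.
    /// Panics if the copy would run past the end of the buffer.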
    pub unsafe fn copy_linear_from_host(&self, buffer: &[u8]) {
        self.assert_linear_in_bounds(buffer.len());

        self.device().switch_to();
        cudaMemcpy(
            self.ptr(),
            buffer as *const _ as *const _,
            buffer.len(),
            cudaMemcpyKind::cudaMemcpyHostToDevice,
        )
        .unwrap();
    }

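    /// Asynchronous host-to-device copy on `stream` from pinned host memory.
    /// The pinned buffer must stay alive (and unmodified) until the stream has synchronized.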
    pub unsafe fn copy_linear_from_host_async(&self, buffer: &PinnedMem, stream: &CudaStream) {
        self.assert_linear_in_bounds(buffer.len_bytes());

        self.device().switch_to();
        cudaMemcpyAsync(
            self.ptr(),
            buffer.ptr(),
            buffer.len_bytes(),
            cudaMemcpyKind::cudaMemcpyHostToDevice,
            stream.inner(),
        )
        .unwrap();
    }

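    /// Blocking copy of device memory starting at this pointer into the host slice `buffer`.
    /// Panics if the copy would run past the end of the buffer.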
    pub unsafe fn copy_linear_to_host(&self, buffer: &mut [u8]) {
        self.assert_linear_in_bounds(buffer.len());

        self.device().switch_to();
        cudaMemcpy(
            buffer as *mut _ as *mut _,
            self.ptr(),
            buffer.len(),
            cudaMemcpyKind::cudaMemcpyDeviceToHost,
        )
        .unwrap();
    }

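    /// Asynchronous device-to-host copy on `stream` into pinned host memory.
    /// The pinned buffer must stay alive until the stream has synchronized.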
    pub unsafe fn copy_linear_to_host_async(&self, buffer: &mut PinnedMem, stream: &CudaStream) {
        self.assert_linear_in_bounds(buffer.len_bytes());

        self.device().switch_to();
        cudaMemcpyAsync(
            buffer.ptr(),
            self.ptr(),
            buffer.len_bytes(),
            cudaMemcpyKind::cudaMemcpyDeviceToHost,
            stream.inner(),
        )
        .unwrap();
    }

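    /// Blocking device-to-device copy of `len_bytes` bytes from `other` into this pointer.
    /// Both pointers must be on the same device and the range must be in bounds for both.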
    pub unsafe fn copy_linear_from_device(&self, other: &DevicePtr, len_bytes: usize) {
        assert_eq!(
            self.device(),
            other.device(),
            "Can only copy between tensors on the same device"
        );

        self.assert_linear_in_bounds(len_bytes);
        other.assert_linear_in_bounds(len_bytes);

        self.device().switch_to();
        cudaMemcpy(
            self.ptr(),
            other.ptr(),
            len_bytes,
            cudaMemcpyKind::cudaMemcpyDeviceToDevice,
        )
        .unwrap();
    }

    fn assert_linear_in_bounds(&self, len: usize) {
        assert!(
            (self.offset + len as isize) <= self.buffer.len_bytes,
            "Linear slice with length {} out of bounds for {:?}",
            len,
            self
        );
    }
}