kn_cuda_sys/wrapper/
handle.rs1use std::mem::MaybeUninit;
2use std::ptr::null_mut;
3
4use bytemuck::cast_slice;
5
6use crate::bindings::{
7 cublasCreate_v2, cublasDestroy_v2, cublasHandle_t, cublasLtCreate, cublasLtDestroy, cublasLtHandle_t,
8 cublasSetStream_v2, cudaDeviceAttr, cudaDeviceGetAttribute, cudaDeviceProp, cudaEventRecord, cudaGetDevice,
9 cudaGetDeviceCount, cudaSetDevice, cudaStream_t, cudaStreamBeginCapture, cudaStreamCaptureMode,
10 cudaStreamCreate, cudaStreamDestroy, cudaStreamEndCapture, cudaStreamSynchronize, cudaStreamWaitEvent, cudnnCreate,
11 cudnnDestroy, cudnnHandle_t, cudnnSetStream,
12};
13use crate::bindings::cudaGetDeviceProperties_v2 as cudaGetDeviceProperties;
15use crate::wrapper::event::CudaEvent;
16use crate::wrapper::graph::CudaGraph;
17use crate::wrapper::mem::device::DevicePtr;
18use crate::wrapper::status::Status;
19
/// Identifies a CUDA device by its integer index.
///
/// The wrapped index is private; values are created through
/// [`CudaDevice::new`], [`CudaDevice::all`] or [`CudaDevice::current`] and
/// read back with [`CudaDevice::inner`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct CudaDevice(i32);
27
/// Compute capability of a device, split into its major and minor parts.
///
/// Filled in by `CudaDevice::compute_capability` from the
/// `cudaDevAttrComputeCapabilityMajor` / `...Minor` device attributes.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ComputeCapability {
    /// Major part of the compute capability.
    pub major: i32,
    /// Minor part of the compute capability.
    pub minor: i32,
}
33
/// Error returned by `CudaDevice::new` when the requested device index is
/// rejected.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct CudaDeviceNotAvailable;
36
37impl CudaDevice {
38 pub fn new(device: i32) -> Result<Self, CudaDeviceNotAvailable> {
39 if 0 <= device && device <= cuda_device_count() {
40 Ok(CudaDevice(device))
41 } else {
42 Err(CudaDeviceNotAvailable)
43 }
44 }
45
46 pub fn all() -> impl Iterator<Item = Self> {
47 (0..cuda_device_count()).map(CudaDevice)
48 }
49
50 pub unsafe fn current() -> CudaDevice {
51 let mut inner = 0;
52 cudaGetDevice(&mut inner as *mut _).unwrap();
53 CudaDevice(inner)
54 }
55
56 pub fn inner(self) -> i32 {
57 self.0
58 }
59
60 pub fn switch_to(self) {
63 unsafe { cudaSetDevice(self.inner()).unwrap() }
64 }
65
66 pub fn alloc(self, len_bytes: usize) -> DevicePtr {
67 DevicePtr::alloc(self, len_bytes)
68 }
69
70 pub fn properties(self) -> cudaDeviceProp {
71 unsafe {
72 self.switch_to();
73 let mut properties = MaybeUninit::uninit();
74 cudaGetDeviceProperties(properties.as_mut_ptr(), self.inner()).unwrap();
75 properties.assume_init()
76 }
77 }
78
79 pub fn attribute(self, attribute: cudaDeviceAttr) -> i32 {
80 unsafe {
81 let mut value: i32 = 0;
82 cudaDeviceGetAttribute(&mut value as *mut _, attribute, self.inner()).unwrap();
83 value
84 }
85 }
86
87 pub fn compute_capability(self) -> ComputeCapability {
88 ComputeCapability {
89 major: self.attribute(cudaDeviceAttr::cudaDevAttrComputeCapabilityMajor),
90 minor: self.attribute(cudaDeviceAttr::cudaDevAttrComputeCapabilityMinor),
91 }
92 }
93
94 pub fn name(self) -> String {
95 let properties = self.properties();
96 let name = &properties.name;
97
98 let len = name.iter().position(|&c| c == 0).unwrap_or(name.len());
99 std::str::from_utf8(cast_slice::<i8, u8>(&name[..len]))
100 .unwrap()
101 .to_owned()
102 }
103}
104
105fn cuda_device_count() -> i32 {
106 unsafe {
107 let mut count = 0;
108 cudaGetDeviceCount(&mut count as *mut _).unwrap();
109 count
110 }
111}
112
113#[derive(Debug)]
115pub struct CudaStream {
116 device: CudaDevice,
117 inner: cudaStream_t,
118}
119
120impl Drop for CudaStream {
121 fn drop(&mut self) {
122 unsafe {
123 cudaStreamDestroy(self.inner).unwrap_in_drop();
124 }
125 }
126}
127
128impl CudaStream {
129 pub fn new(device: CudaDevice) -> Self {
130 unsafe {
131 let mut inner = null_mut();
132 device.switch_to();
133 cudaStreamCreate(&mut inner as *mut _).unwrap();
134 CudaStream { device, inner }
135 }
136 }
137
138 pub fn synchronize(&self) {
139 unsafe { cudaStreamSynchronize(self.inner()).unwrap() }
140 }
141
142 pub fn device(&self) -> CudaDevice {
143 self.device
144 }
145
146 pub unsafe fn inner(&self) -> cudaStream_t {
147 self.inner
148 }
149
150 pub fn record_event(&self) -> CudaEvent {
151 let event = CudaEvent::new();
152 self.record_existing_event(&event);
153 event
154 }
155
156 pub fn record_existing_event(&self, event: &CudaEvent) {
157 unsafe { cudaEventRecord(event.inner(), self.inner()).unwrap() }
158 }
159
160 pub fn wait_for_event(&self, event: &CudaEvent) {
161 unsafe {
162 cudaStreamWaitEvent(self.inner, event.inner(), 0).unwrap();
163 }
164 }
165
166 pub unsafe fn begin_capture(&self) {
167 cudaStreamBeginCapture(self.inner(), cudaStreamCaptureMode::cudaStreamCaptureModeGlobal).unwrap()
168 }
169
170 pub unsafe fn end_capture(&self) -> CudaGraph {
171 let mut graph = null_mut();
172 cudaStreamEndCapture(self.inner(), &mut graph as *mut _).unwrap();
173 CudaGraph::new_from_inner(graph)
174 }
175}
176
177#[derive(Debug)]
178pub struct CudnnHandle {
179 inner: cudnnHandle_t,
180 stream: CudaStream,
181}
182
183impl Drop for CudnnHandle {
184 fn drop(&mut self) {
185 unsafe {
186 self.device().switch_to();
187 cudnnDestroy(self.inner).unwrap_in_drop()
188 }
189 }
190}
191
192impl CudnnHandle {
193 pub fn new(device: CudaDevice) -> Self {
194 CudnnHandle::new_with_stream(CudaStream::new(device))
195 }
196
197 pub fn new_with_stream(stream: CudaStream) -> Self {
198 unsafe {
199 let mut inner = null_mut();
200 stream.device.switch_to();
201 cudnnCreate(&mut inner as *mut _).unwrap();
202 cudnnSetStream(inner, stream.inner()).unwrap();
203 CudnnHandle { inner, stream }
204 }
205 }
206
207 pub fn device(&self) -> CudaDevice {
208 self.stream.device()
209 }
210
211 pub fn stream(&self) -> &CudaStream {
212 &self.stream
213 }
214
215 pub unsafe fn inner(&self) -> cudnnHandle_t {
216 self.inner
217 }
218}
219
220#[derive(Debug)]
221pub struct CublasHandle {
222 inner: cublasHandle_t,
223 stream: CudaStream,
224}
225
226impl Drop for CublasHandle {
227 fn drop(&mut self) {
228 unsafe { cublasDestroy_v2(self.inner).unwrap_in_drop() }
229 }
230}
231
232impl CublasHandle {
233 pub fn new(device: CudaDevice) -> Self {
234 CublasHandle::new_with_stream(CudaStream::new(device))
235 }
236
237 pub fn new_with_stream(stream: CudaStream) -> Self {
238 unsafe {
239 let mut inner = null_mut();
240 stream.device.switch_to();
241 cublasCreate_v2(&mut inner as *mut _).unwrap();
242 cublasSetStream_v2(inner, stream.inner()).unwrap();
243 CublasHandle { inner, stream }
244 }
245 }
246
247 pub fn stream(&self) -> &CudaStream {
248 &self.stream
249 }
250
251 pub unsafe fn inner(&self) -> cublasHandle_t {
252 self.inner
253 }
254}
255
256#[derive(Debug)]
257pub struct CublasLtHandle {
258 inner: cublasLtHandle_t,
259}
260
261impl Drop for CublasLtHandle {
262 fn drop(&mut self) {
263 unsafe { cublasLtDestroy(self.inner).unwrap_in_drop() }
264 }
265}
266
267impl CublasLtHandle {
268 pub fn new(device: CudaDevice) -> Self {
269 unsafe {
270 let mut inner = null_mut();
271 device.switch_to();
272 cublasLtCreate(&mut inner as *mut _).unwrap();
273 CublasLtHandle { inner }
274 }
275 }
276
277 pub unsafe fn inner(&self) -> cublasLtHandle_t {
278 self.inner
279 }
280}