kn_cuda_sys/wrapper/handle.rs

1use std::mem::MaybeUninit;
2use std::ptr::null_mut;
3
4use bytemuck::cast_slice;
5
6use crate::bindings::{
7    cublasCreate_v2, cublasDestroy_v2, cublasHandle_t, cublasLtCreate, cublasLtDestroy, cublasLtHandle_t,
8    cublasSetStream_v2, cudaDeviceAttr, cudaDeviceGetAttribute, cudaDeviceProp, cudaEventRecord, cudaGetDevice,
9    cudaGetDeviceCount, cudaSetDevice, cudaStream_t, cudaStreamBeginCapture, cudaStreamCaptureMode,
10    cudaStreamCreate, cudaStreamDestroy, cudaStreamEndCapture, cudaStreamSynchronize, cudaStreamWaitEvent, cudnnCreate,
11    cudnnDestroy, cudnnHandle_t, cudnnSetStream,
12};
13// TODO fix this annoying v2 import once https://github.com/rust-lang/rust-bindgen/issues/2544 is fixed
14use crate::bindings::cudaGetDeviceProperties_v2 as cudaGetDeviceProperties;
15use crate::wrapper::event::CudaEvent;
16use crate::wrapper::graph::CudaGraph;
17use crate::wrapper::mem::device::DevicePtr;
18use crate::wrapper::status::Status;
19
/// Index of a cuda device.
///
/// This crate avoids depending on cuda's implicit, global "current device": every cuda
/// call whose behaviour depends on the device is expected to be preceded by
/// `device.switch_to()`, which wraps [cudaSetDevice].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct CudaDevice(i32);
27
/// Compute capability of a cuda device, as a `major.minor` pair.
///
/// Ordering is derived and therefore lexicographic on `(major, minor)` (fields are
/// compared in declaration order), so capabilities can be compared directly, e.g.
/// `cc >= ComputeCapability { major: 7, minor: 0 }`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, PartialOrd, Ord)]
pub struct ComputeCapability {
    /// Major revision number.
    pub major: i32,
    /// Minor revision number.
    pub minor: i32,
}
33
/// Error returned by `CudaDevice::new` when the requested index does not refer to an
/// existing cuda device.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct CudaDeviceNotAvailable;

impl std::fmt::Display for CudaDeviceNotAvailable {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("cuda device not available")
    }
}

// Lets the error flow through `?` into `Box<dyn Error>` / other error types.
impl std::error::Error for CudaDeviceNotAvailable {}
36
37impl CudaDevice {
38    pub fn new(device: i32) -> Result<Self, CudaDeviceNotAvailable> {
39        if 0 <= device && device <= cuda_device_count() {
40            Ok(CudaDevice(device))
41        } else {
42            Err(CudaDeviceNotAvailable)
43        }
44    }
45
46    pub fn all() -> impl Iterator<Item = Self> {
47        (0..cuda_device_count()).map(CudaDevice)
48    }
49
50    pub unsafe fn current() -> CudaDevice {
51        let mut inner = 0;
52        cudaGetDevice(&mut inner as *mut _).unwrap();
53        CudaDevice(inner)
54    }
55
56    pub fn inner(self) -> i32 {
57        self.0
58    }
59
60    // Set the current cuda device to this device.
61    //TODO is this enough when there are multiple threads running?
62    pub fn switch_to(self) {
63        unsafe { cudaSetDevice(self.inner()).unwrap() }
64    }
65
66    pub fn alloc(self, len_bytes: usize) -> DevicePtr {
67        DevicePtr::alloc(self, len_bytes)
68    }
69
70    pub fn properties(self) -> cudaDeviceProp {
71        unsafe {
72            self.switch_to();
73            let mut properties = MaybeUninit::uninit();
74            cudaGetDeviceProperties(properties.as_mut_ptr(), self.inner()).unwrap();
75            properties.assume_init()
76        }
77    }
78
79    pub fn attribute(self, attribute: cudaDeviceAttr) -> i32 {
80        unsafe {
81            let mut value: i32 = 0;
82            cudaDeviceGetAttribute(&mut value as *mut _, attribute, self.inner()).unwrap();
83            value
84        }
85    }
86
87    pub fn compute_capability(self) -> ComputeCapability {
88        ComputeCapability {
89            major: self.attribute(cudaDeviceAttr::cudaDevAttrComputeCapabilityMajor),
90            minor: self.attribute(cudaDeviceAttr::cudaDevAttrComputeCapabilityMinor),
91        }
92    }
93
94    pub fn name(self) -> String {
95        let properties = self.properties();
96        let name = &properties.name;
97
98        let len = name.iter().position(|&c| c == 0).unwrap_or(name.len());
99        std::str::from_utf8(cast_slice::<i8, u8>(&name[..len]))
100            .unwrap()
101            .to_owned()
102    }
103}
104
105fn cuda_device_count() -> i32 {
106    unsafe {
107        let mut count = 0;
108        cudaGetDeviceCount(&mut count as *mut _).unwrap();
109        count
110    }
111}
112
113//TODO copy? clone? default stream?
114#[derive(Debug)]
115pub struct CudaStream {
116    device: CudaDevice,
117    inner: cudaStream_t,
118}
119
120impl Drop for CudaStream {
121    fn drop(&mut self) {
122        unsafe {
123            cudaStreamDestroy(self.inner).unwrap_in_drop();
124        }
125    }
126}
127
128impl CudaStream {
129    pub fn new(device: CudaDevice) -> Self {
130        unsafe {
131            let mut inner = null_mut();
132            device.switch_to();
133            cudaStreamCreate(&mut inner as *mut _).unwrap();
134            CudaStream { device, inner }
135        }
136    }
137
138    pub fn synchronize(&self) {
139        unsafe { cudaStreamSynchronize(self.inner()).unwrap() }
140    }
141
142    pub fn device(&self) -> CudaDevice {
143        self.device
144    }
145
146    pub unsafe fn inner(&self) -> cudaStream_t {
147        self.inner
148    }
149
150    pub fn record_event(&self) -> CudaEvent {
151        let event = CudaEvent::new();
152        self.record_existing_event(&event);
153        event
154    }
155
156    pub fn record_existing_event(&self, event: &CudaEvent) {
157        unsafe { cudaEventRecord(event.inner(), self.inner()).unwrap() }
158    }
159
160    pub fn wait_for_event(&self, event: &CudaEvent) {
161        unsafe {
162            cudaStreamWaitEvent(self.inner, event.inner(), 0).unwrap();
163        }
164    }
165
166    pub unsafe fn begin_capture(&self) {
167        cudaStreamBeginCapture(self.inner(), cudaStreamCaptureMode::cudaStreamCaptureModeGlobal).unwrap()
168    }
169
170    pub unsafe fn end_capture(&self) -> CudaGraph {
171        let mut graph = null_mut();
172        cudaStreamEndCapture(self.inner(), &mut graph as *mut _).unwrap();
173        CudaGraph::new_from_inner(graph)
174    }
175}
176
177#[derive(Debug)]
178pub struct CudnnHandle {
179    inner: cudnnHandle_t,
180    stream: CudaStream,
181}
182
183impl Drop for CudnnHandle {
184    fn drop(&mut self) {
185        unsafe {
186            self.device().switch_to();
187            cudnnDestroy(self.inner).unwrap_in_drop()
188        }
189    }
190}
191
192impl CudnnHandle {
193    pub fn new(device: CudaDevice) -> Self {
194        CudnnHandle::new_with_stream(CudaStream::new(device))
195    }
196
197    pub fn new_with_stream(stream: CudaStream) -> Self {
198        unsafe {
199            let mut inner = null_mut();
200            stream.device.switch_to();
201            cudnnCreate(&mut inner as *mut _).unwrap();
202            cudnnSetStream(inner, stream.inner()).unwrap();
203            CudnnHandle { inner, stream }
204        }
205    }
206
207    pub fn device(&self) -> CudaDevice {
208        self.stream.device()
209    }
210
211    pub fn stream(&self) -> &CudaStream {
212        &self.stream
213    }
214
215    pub unsafe fn inner(&self) -> cudnnHandle_t {
216        self.inner
217    }
218}
219
220#[derive(Debug)]
221pub struct CublasHandle {
222    inner: cublasHandle_t,
223    stream: CudaStream,
224}
225
226impl Drop for CublasHandle {
227    fn drop(&mut self) {
228        unsafe { cublasDestroy_v2(self.inner).unwrap_in_drop() }
229    }
230}
231
232impl CublasHandle {
233    pub fn new(device: CudaDevice) -> Self {
234        CublasHandle::new_with_stream(CudaStream::new(device))
235    }
236
237    pub fn new_with_stream(stream: CudaStream) -> Self {
238        unsafe {
239            let mut inner = null_mut();
240            stream.device.switch_to();
241            cublasCreate_v2(&mut inner as *mut _).unwrap();
242            cublasSetStream_v2(inner, stream.inner()).unwrap();
243            CublasHandle { inner, stream }
244        }
245    }
246
247    pub fn stream(&self) -> &CudaStream {
248        &self.stream
249    }
250
251    pub unsafe fn inner(&self) -> cublasHandle_t {
252        self.inner
253    }
254}
255
256#[derive(Debug)]
257pub struct CublasLtHandle {
258    inner: cublasLtHandle_t,
259}
260
261impl Drop for CublasLtHandle {
262    fn drop(&mut self) {
263        unsafe { cublasLtDestroy(self.inner).unwrap_in_drop() }
264    }
265}
266
267impl CublasLtHandle {
268    pub fn new(device: CudaDevice) -> Self {
269        unsafe {
270            let mut inner = null_mut();
271            device.switch_to();
272            cublasLtCreate(&mut inner as *mut _).unwrap();
273            CublasLtHandle { inner }
274        }
275    }
276
277    pub unsafe fn inner(&self) -> cublasLtHandle_t {
278        self.inner
279    }
280}