cudarc/cudnn/safe/
core.rs

1use crate::{
2    cudnn::{result, result::CudnnError, sys},
3    driver::CudaStream,
4};
5
6use std::{marker::PhantomData, sync::Arc};
7
/// A handle to cuDNN.
///
/// This type is not send/sync because of <https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#thread-safety>
#[derive(Debug)]
pub struct Cudnn {
    // Raw cuDNN handle; owned by this struct and destroyed in Drop.
    pub(crate) handle: sys::cudnnHandle_t,
    // The CUDA stream this handle's work is ordered on; held so the stream
    // outlives the handle that references it.
    pub(crate) stream: Arc<CudaStream>,
}
16
impl Cudnn {
    /// Creates a new cudnn handle and sets the stream to the `device`'s stream.
    ///
    /// The stream's CUDA context is bound to the calling thread before the
    /// handle is created, since handle creation acts on the thread's current
    /// context.
    #[allow(clippy::arc_with_non_send_sync)]
    pub fn new(stream: Arc<CudaStream>) -> Result<Arc<Self>, CudnnError> {
        // NOTE(review): `bind_to_thread` fails with a driver-side error type,
        // not a `CudnnError`, which is presumably why it is unwrapped rather
        // than propagated — confirm there is no conversion available.
        stream.context().bind_to_thread().unwrap();
        let handle = result::create_handle()?;
        // Associate the handle with `stream` so all cuDNN work issued through
        // it is ordered on that stream.
        unsafe { result::set_stream(handle, stream.cu_stream as *mut _) }?;
        Ok(Arc::new(Self { handle, stream }))
    }

    /// Sets the handle's current to either the stream specified, or the device's default work
    /// stream.
    ///
    /// # Safety
    /// This is unsafe because you can end up scheduling multiple concurrent kernels that all
    /// write to the same memory address.
    pub unsafe fn set_stream(&mut self, stream: Arc<CudaStream>) -> Result<(), CudnnError> {
        // Store the Arc first so the stream stays alive for as long as the
        // handle refers to it.
        self.stream = stream;
        unsafe { result::set_stream(self.handle, self.stream.cu_stream as *mut _) }
    }
}
38
39impl Drop for Cudnn {
40    fn drop(&mut self) {
41        let handle = std::mem::replace(&mut self.handle, std::ptr::null_mut());
42        if !handle.is_null() {
43            unsafe { result::destroy_handle(handle) }.unwrap();
44        }
45    }
46}
47
48/// Maps a rust type to a [sys::cudnnDataType_t]
/// Maps a rust type to a [sys::cudnnDataType_t]
pub trait CudnnDataType {
    /// The cuDNN enum value corresponding to `Self`.
    const DATA_TYPE: sys::cudnnDataType_t;

    /// Certain CUDNN data types have a scaling parameter (usually called alpha/beta)
    /// that is a different type. See [nvidia docs](https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters)
    /// for more info, but basically f16 has a scalar of f32.
    type Scalar;

    /// Converts the type into the scaling parameter type. See [Self::Scalar].
    fn into_scaling_parameter(self) -> Self::Scalar;
}
60
/// Implements [CudnnDataType] for a rust type whose scaling parameter type
/// is itself (i.e. no alpha/beta type promotion is needed).
macro_rules! cudnn_dtype {
    ($RustTy:ty, $CudnnTy:tt) => {
        impl CudnnDataType for $RustTy {
            const DATA_TYPE: sys::cudnnDataType_t = sys::cudnnDataType_t::$CudnnTy;
            type Scalar = Self;
            fn into_scaling_parameter(self) -> Self::Scalar {
                self
            }
        }
    };
}
72
cudnn_dtype!(f32, CUDNN_DATA_FLOAT);
cudnn_dtype!(f64, CUDNN_DATA_DOUBLE);
cudnn_dtype!(i8, CUDNN_DATA_INT8);
cudnn_dtype!(i32, CUDNN_DATA_INT32);
cudnn_dtype!(i64, CUDNN_DATA_INT64);
cudnn_dtype!(u8, CUDNN_DATA_UINT8);
// CUDNN_DATA_BOOLEAN is not present in the bindings for the CUDA 11.4
// toolkit, so the bool impl is compiled out for that feature.
#[cfg(not(feature = "cuda-11040"))]
cudnn_dtype!(bool, CUDNN_DATA_BOOLEAN);
81
// f16 is one of the cuDNN types whose alpha/beta scaling parameters are a
// wider type (f32) — see the scaling-parameters note on [CudnnDataType].
#[cfg(feature = "f16")]
impl CudnnDataType for half::f16 {
    const DATA_TYPE: sys::cudnnDataType_t = sys::cudnnDataType_t::CUDNN_DATA_HALF;
    type Scalar = f32;
    fn into_scaling_parameter(self) -> Self::Scalar {
        // Lossless widening conversion to the f32 scalar type.
        f32::from(self)
    }
}
// bf16, like f16, scales with f32 alpha/beta parameters — see the
// scaling-parameters note on [CudnnDataType].
#[cfg(feature = "f16")]
impl CudnnDataType for half::bf16 {
    const DATA_TYPE: sys::cudnnDataType_t = sys::cudnnDataType_t::CUDNN_DATA_BFLOAT16;
    type Scalar = f32;
    fn into_scaling_parameter(self) -> Self::Scalar {
        // Lossless widening conversion to the f32 scalar type.
        f32::from(self)
    }
}
98
/// A descriptor of a tensor. Create with:
/// 1. [`Cudnn::create_4d_tensor()`]
/// 2. [`Cudnn::create_4d_tensor_ex()`]
/// 3. [`Cudnn::create_nd_tensor()`]
#[derive(Debug)]
pub struct TensorDescriptor<T> {
    // Raw descriptor; owned here and destroyed in Drop.
    pub(crate) desc: sys::cudnnTensorDescriptor_t,
    // Never read, but held so the cudnn handle outlives this descriptor.
    #[allow(unused)]
    pub(crate) handle: Arc<Cudnn>,
    // Ties the descriptor to the rust element type `T` without storing one.
    pub(crate) marker: PhantomData<T>,
}
110
impl Cudnn {
    /// Creates a 4d tensor descriptor.
    ///
    /// `format` controls the memory layout of the 4 `dims`.
    ///
    /// # Errors
    /// Returns an error if cuDNN fails to create or configure the descriptor.
    pub fn create_4d_tensor<T: CudnnDataType>(
        self: &Arc<Cudnn>,
        format: sys::cudnnTensorFormat_t,
        dims: [std::ffi::c_int; 4],
    ) -> Result<TensorDescriptor<T>, CudnnError> {
        let desc = result::create_tensor_descriptor()?;
        // Wrap the raw descriptor immediately so its Drop impl frees it even
        // if the fallible `set_*` call below errors out.
        let desc = TensorDescriptor {
            desc,
            handle: self.clone(),
            marker: PhantomData,
        };
        unsafe { result::set_tensor4d_descriptor(desc.desc, format, T::DATA_TYPE, dims) }?;
        Ok(desc)
    }

    /// Creates a 4d tensor descriptor.
    ///
    /// Unlike [`Cudnn::create_4d_tensor()`], the layout is given explicitly
    /// via per-dimension `strides` instead of a format enum.
    ///
    /// # Errors
    /// Returns an error if cuDNN fails to create or configure the descriptor.
    pub fn create_4d_tensor_ex<T: CudnnDataType>(
        self: &Arc<Cudnn>,
        dims: [std::ffi::c_int; 4],
        strides: [std::ffi::c_int; 4],
    ) -> Result<TensorDescriptor<T>, CudnnError> {
        let desc = result::create_tensor_descriptor()?;
        // Wrap before configuring so Drop cleans up on error (see above).
        let desc = TensorDescriptor {
            desc,
            handle: self.clone(),
            marker: PhantomData,
        };
        unsafe { result::set_tensor4d_descriptor_ex(desc.desc, T::DATA_TYPE, dims, strides) }?;
        Ok(desc)
    }

    /// Creates an nd (at LEAST 4d) tensor descriptor.
    ///
    /// # Panics
    /// Panics if `dims.len() < 4` or if `dims` and `strides` have different
    /// lengths.
    ///
    /// # Errors
    /// Returns an error if cuDNN fails to create or configure the descriptor.
    pub fn create_nd_tensor<T: CudnnDataType>(
        self: &Arc<Cudnn>,
        dims: &[std::ffi::c_int],
        strides: &[std::ffi::c_int],
    ) -> Result<TensorDescriptor<T>, CudnnError> {
        assert!(dims.len() >= 4);
        assert_eq!(dims.len(), strides.len());
        let desc = result::create_tensor_descriptor()?;
        // Wrap before configuring so Drop cleans up on error (see above).
        let desc = TensorDescriptor {
            desc,
            handle: self.clone(),
            marker: PhantomData,
        };
        unsafe {
            result::set_tensornd_descriptor(
                desc.desc,
                T::DATA_TYPE,
                dims.len() as std::ffi::c_int,
                dims.as_ptr(),
                strides.as_ptr(),
            )
        }?;
        Ok(desc)
    }
}
170
171impl<T> Drop for TensorDescriptor<T> {
172    fn drop(&mut self) {
173        let desc = std::mem::replace(&mut self.desc, std::ptr::null_mut());
174        if !desc.is_null() {
175            unsafe { result::destroy_tensor_descriptor(desc) }.unwrap()
176        }
177    }
178}