cudarc/cudnn/safe/
core.rs1use crate::{
2 cudnn::{result, result::CudnnError, sys},
3 driver::CudaStream,
4};
5
6use std::{marker::PhantomData, sync::Arc};
7
/// A cuDNN library handle paired with the CUDA stream it submits work to.
///
/// Built by [`Cudnn::new`], which returns it inside an [`Arc`]; tensor descriptors
/// created from it hold a clone of that `Arc`. The handle is destroyed in `Drop`.
#[derive(Debug)]
pub struct Cudnn {
    /// Raw cuDNN handle; `Drop` replaces it with null before destroying it.
    pub(crate) handle: sys::cudnnHandle_t,
    /// The stream most recently installed on `handle` through `result::set_stream`.
    pub(crate) stream: Arc<CudaStream>,
}
16
17impl Cudnn {
18 #[allow(clippy::arc_with_non_send_sync)]
20 pub fn new(stream: Arc<CudaStream>) -> Result<Arc<Self>, CudnnError> {
21 stream.context().bind_to_thread().unwrap();
22 let handle = result::create_handle()?;
23 unsafe { result::set_stream(handle, stream.cu_stream as *mut _) }?;
24 Ok(Arc::new(Self { handle, stream }))
25 }
26
27 pub unsafe fn set_stream(&mut self, stream: Arc<CudaStream>) -> Result<(), CudnnError> {
34 self.stream = stream;
35 unsafe { result::set_stream(self.handle, self.stream.cu_stream as *mut _) }
36 }
37}
38
39impl Drop for Cudnn {
40 fn drop(&mut self) {
41 let handle = std::mem::replace(&mut self.handle, std::ptr::null_mut());
42 if !handle.is_null() {
43 unsafe { result::destroy_handle(handle) }.unwrap();
44 }
45 }
46}
47
48pub trait CudnnDataType {
50 const DATA_TYPE: sys::cudnnDataType_t;
51
52 type Scalar;
56
57 fn into_scaling_parameter(self) -> Self::Scalar;
59}
60
/// Implements [`CudnnDataType`] for a type whose `Scalar` is the type itself.
///
/// Arguments: the Rust type, then the name of the matching `sys::cudnnDataType_t`
/// variant.
macro_rules! cudnn_dtype {
    ($rust_ty:ty, $variant:tt) => {
        impl CudnnDataType for $rust_ty {
            const DATA_TYPE: sys::cudnnDataType_t = sys::cudnnDataType_t::$variant;
            type Scalar = $rust_ty;
            // Identity conversion: the value is already in its scalar form.
            fn into_scaling_parameter(self) -> Self::Scalar {
                self
            }
        }
    };
}
72
73cudnn_dtype!(f32, CUDNN_DATA_FLOAT);
74cudnn_dtype!(f64, CUDNN_DATA_DOUBLE);
75cudnn_dtype!(i8, CUDNN_DATA_INT8);
76cudnn_dtype!(i32, CUDNN_DATA_INT32);
77cudnn_dtype!(i64, CUDNN_DATA_INT64);
78cudnn_dtype!(u8, CUDNN_DATA_UINT8);
79#[cfg(not(feature = "cuda-11040"))]
80cudnn_dtype!(bool, CUDNN_DATA_BOOLEAN);
81
/// IEEE half precision: tagged `CUDNN_DATA_HALF`, with `f32` as its scalar type.
#[cfg(feature = "f16")]
impl CudnnDataType for half::f16 {
    type Scalar = f32;
    const DATA_TYPE: sys::cudnnDataType_t = sys::cudnnDataType_t::CUDNN_DATA_HALF;

    /// Widens the half-precision value to `f32`.
    fn into_scaling_parameter(self) -> Self::Scalar {
        f32::from(self)
    }
}
/// bfloat16: tagged `CUDNN_DATA_BFLOAT16`, with `f32` as its scalar type.
#[cfg(feature = "f16")]
impl CudnnDataType for half::bf16 {
    type Scalar = f32;
    const DATA_TYPE: sys::cudnnDataType_t = sys::cudnnDataType_t::CUDNN_DATA_BFLOAT16;

    /// Widens the bfloat16 value to `f32`.
    fn into_scaling_parameter(self) -> Self::Scalar {
        f32::from(self)
    }
}
98
99#[derive(Debug)]
104pub struct TensorDescriptor<T> {
105 pub(crate) desc: sys::cudnnTensorDescriptor_t,
106 #[allow(unused)]
107 pub(crate) handle: Arc<Cudnn>,
108 pub(crate) marker: PhantomData<T>,
109}
110
111impl Cudnn {
112 pub fn create_4d_tensor<T: CudnnDataType>(
114 self: &Arc<Cudnn>,
115 format: sys::cudnnTensorFormat_t,
116 dims: [std::ffi::c_int; 4],
117 ) -> Result<TensorDescriptor<T>, CudnnError> {
118 let desc = result::create_tensor_descriptor()?;
119 let desc = TensorDescriptor {
120 desc,
121 handle: self.clone(),
122 marker: PhantomData,
123 };
124 unsafe { result::set_tensor4d_descriptor(desc.desc, format, T::DATA_TYPE, dims) }?;
125 Ok(desc)
126 }
127
128 pub fn create_4d_tensor_ex<T: CudnnDataType>(
130 self: &Arc<Cudnn>,
131 dims: [std::ffi::c_int; 4],
132 strides: [std::ffi::c_int; 4],
133 ) -> Result<TensorDescriptor<T>, CudnnError> {
134 let desc = result::create_tensor_descriptor()?;
135 let desc = TensorDescriptor {
136 desc,
137 handle: self.clone(),
138 marker: PhantomData,
139 };
140 unsafe { result::set_tensor4d_descriptor_ex(desc.desc, T::DATA_TYPE, dims, strides) }?;
141 Ok(desc)
142 }
143
144 pub fn create_nd_tensor<T: CudnnDataType>(
146 self: &Arc<Cudnn>,
147 dims: &[std::ffi::c_int],
148 strides: &[std::ffi::c_int],
149 ) -> Result<TensorDescriptor<T>, CudnnError> {
150 assert!(dims.len() >= 4);
151 assert_eq!(dims.len(), strides.len());
152 let desc = result::create_tensor_descriptor()?;
153 let desc = TensorDescriptor {
154 desc,
155 handle: self.clone(),
156 marker: PhantomData,
157 };
158 unsafe {
159 result::set_tensornd_descriptor(
160 desc.desc,
161 T::DATA_TYPE,
162 dims.len() as std::ffi::c_int,
163 dims.as_ptr(),
164 strides.as_ptr(),
165 )
166 }?;
167 Ok(desc)
168 }
169}
170
171impl<T> Drop for TensorDescriptor<T> {
172 fn drop(&mut self) {
173 let desc = std::mem::replace(&mut self.desc, std::ptr::null_mut());
174 if !desc.is_null() {
175 unsafe { result::destroy_tensor_descriptor(desc) }.unwrap()
176 }
177 }
178}