nncombinator/cuda/cudnn/
tensor.rs1use std::marker::PhantomData;
4use rcudnn::API;
5use rcudnn_sys::{cudnnSetTensor4dDescriptor, cudnnTensorDescriptor_t,cudnnStatus_t};
6use rcudnn_sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW;
7use crate::cuda::DataTypeInfo;
8
9pub struct CudnnTensor4dDescriptor<T> where T: DataTypeInfo {
11 id: cudnnTensorDescriptor_t,
12 t:PhantomData<T>
13}
14impl<T> CudnnTensor4dDescriptor<T> where T: DataTypeInfo {
15 pub fn new(n:usize,c:usize,h:usize,w:usize) -> Result<CudnnTensor4dDescriptor<T>,rcudnn::Error> where T: DataTypeInfo {
23 let desc = API::create_tensor_descriptor()?;
24
25 unsafe {
26 match cudnnSetTensor4dDescriptor(desc,CUDNN_TENSOR_NCHW,T::cudnn_raw_data_type(),
27 n as libc::c_int,c as libc::c_int, h as libc::c_int, w as libc::c_int) {
28 cudnnStatus_t::CUDNN_STATUS_SUCCESS => (),
29 cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
30 return Err(rcudnn::Error::BadParam("The parameter passed to the vs is invalid."));
31 },
32 status => {
33 return Err(rcudnn::Error::Unknown("Unable to create the CUDA cuDNN context/resources.", status as i32 as u64));
34 }
35 }
36 }
37
38 Ok(CudnnTensor4dDescriptor {
39 id: desc,
40 t:PhantomData::<T>
41 })
42 }
43
44 pub fn id_c(&self) -> &cudnnTensorDescriptor_t {
45 &self.id
46 }
47}
48impl<T> Drop for CudnnTensor4dDescriptor<T> where T: DataTypeInfo {
49 #[allow(unused_must_use)]
50 fn drop(&mut self) {
51 API::destroy_tensor_descriptor(*self.id_c());
52 }
53}