nncombinator/cuda/cudnn/
tensor.rs

//! Wrapper implementation for tensors used in cuDNN
2
3use std::marker::PhantomData;
4use rcudnn::API;
5use rcudnn_sys::{cudnnSetTensor4dDescriptor, cudnnTensorDescriptor_t,cudnnStatus_t};
6use rcudnn_sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW;
7use crate::cuda::DataTypeInfo;
8
9/// Wrapper for cudnnTensorDescriptor_t initialized by cudnnSetTensor4dDescriptors
10pub struct CudnnTensor4dDescriptor<T> where T: DataTypeInfo {
11    id: cudnnTensorDescriptor_t,
12    t:PhantomData<T>
13}
14impl<T> CudnnTensor4dDescriptor<T> where T: DataTypeInfo {
15    /// Create an instance of CudnnTensor4dDescriptor
16    /// # Arguments
17    ///
18    /// * `n` - batch size
19    /// * `c` - Number of Channels
20    /// * `h` - height
21    /// * `w` - width
22    pub fn new(n:usize,c:usize,h:usize,w:usize) -> Result<CudnnTensor4dDescriptor<T>,rcudnn::Error> where T: DataTypeInfo {
23        let desc = API::create_tensor_descriptor()?;
24
25        unsafe {
26            match cudnnSetTensor4dDescriptor(desc,CUDNN_TENSOR_NCHW,T::cudnn_raw_data_type(),
27                                             n as libc::c_int,c as libc::c_int, h as libc::c_int, w as libc::c_int) {
28                cudnnStatus_t::CUDNN_STATUS_SUCCESS => (),
29                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
30                    return Err(rcudnn::Error::BadParam("The parameter passed to the vs is invalid."));
31                },
32                status => {
33                    return Err(rcudnn::Error::Unknown("Unable to create the CUDA cuDNN context/resources.", status as i32 as u64));
34                }
35            }
36        }
37
38        Ok(CudnnTensor4dDescriptor {
39            id: desc,
40            t:PhantomData::<T>
41        })
42    }
43
44    pub fn id_c(&self) -> &cudnnTensorDescriptor_t {
45        &self.id
46    }
47}
48impl<T> Drop for CudnnTensor4dDescriptor<T> where T: DataTypeInfo {
49    #[allow(unused_must_use)]
50    fn drop(&mut self) {
51        API::destroy_tensor_descriptor(*self.id_c());
52    }
53}