rcudnn/
tensor_descriptor.rs

1//! Defines a Tensor Descriptor.
2//!
//! A Tensor Descriptor holds the metadata that cuDNN operations need about a
//! tensor's data: its data type, the number and size of its dimensions, and
//! its strides.
6
7use super::utils::DataType;
8use super::{Error, API};
9use crate::ffi::*;
10
11#[derive(Debug, Clone)]
12/// Describes a TensorDescriptor.
13pub struct TensorDescriptor {
14    id: cudnnTensorDescriptor_t,
15}
16
17/// Return C Handle for a Vector of Tensor Descriptors
18pub fn tensor_vec_id_c(tensor_vec: &[TensorDescriptor]) -> Vec<cudnnTensorDescriptor_t> {
19    tensor_vec.iter().map(|tensor| *tensor.id_c()).collect()
20}
21
22impl Drop for TensorDescriptor {
23    #[allow(unused_must_use)]
24    fn drop(&mut self) {
25        API::destroy_tensor_descriptor(*self.id_c());
26    }
27}
28
29impl TensorDescriptor {
30    /// Initializes a new CUDA cuDNN Tensor Descriptor.
31    pub fn new(
32        dims: &[i32],
33        strides: &[i32],
34        data_type: DataType,
35    ) -> Result<TensorDescriptor, Error> {
36        let nb_dims = dims.len() as i32;
37        if nb_dims < 3 {
38            return Err(Error::BadParam(
39                "CUDA cuDNN only supports Tensors with 3 to 8 dimensions.",
40            ));
41        }
42
43        let dims_ptr = dims.as_ptr();
44        let strides_ptr = strides.as_ptr();
45        let generic_tensor_desc = API::create_tensor_descriptor()?;
46        let data_type = API::cudnn_data_type(data_type);
47
48        API::set_tensor_descriptor(
49            generic_tensor_desc,
50            data_type,
51            nb_dims,
52            dims_ptr,
53            strides_ptr,
54        )?;
55        Ok(TensorDescriptor::from_c(generic_tensor_desc))
56    }
57
58    /// Initializes a new CUDA cuDNN Tensor Descriptor from its C type.
59    pub fn from_c(id: cudnnTensorDescriptor_t) -> TensorDescriptor {
60        TensorDescriptor { id }
61    }
62
63    /// Returns the CUDA cuDNN Tensor Descriptor as its C type.
64    pub fn id_c(&self) -> &cudnnTensorDescriptor_t {
65        &self.id
66    }
67}