rcudnn/
tensor_descriptor.rs1use super::utils::DataType;
8use super::{Error, API};
9use crate::ffi::*;
10
11#[derive(Debug, Clone)]
12pub struct TensorDescriptor {
14 id: cudnnTensorDescriptor_t,
15}
16
17pub fn tensor_vec_id_c(tensor_vec: &[TensorDescriptor]) -> Vec<cudnnTensorDescriptor_t> {
19 tensor_vec.iter().map(|tensor| *tensor.id_c()).collect()
20}
21
22impl Drop for TensorDescriptor {
23 #[allow(unused_must_use)]
24 fn drop(&mut self) {
25 API::destroy_tensor_descriptor(*self.id_c());
26 }
27}
28
29impl TensorDescriptor {
30 pub fn new(
32 dims: &[i32],
33 strides: &[i32],
34 data_type: DataType,
35 ) -> Result<TensorDescriptor, Error> {
36 let nb_dims = dims.len() as i32;
37 if nb_dims < 3 {
38 return Err(Error::BadParam(
39 "CUDA cuDNN only supports Tensors with 3 to 8 dimensions.",
40 ));
41 }
42
43 let dims_ptr = dims.as_ptr();
44 let strides_ptr = strides.as_ptr();
45 let generic_tensor_desc = API::create_tensor_descriptor()?;
46 let data_type = API::cudnn_data_type(data_type);
47
48 API::set_tensor_descriptor(
49 generic_tensor_desc,
50 data_type,
51 nb_dims,
52 dims_ptr,
53 strides_ptr,
54 )?;
55 Ok(TensorDescriptor::from_c(generic_tensor_desc))
56 }
57
58 pub fn from_c(id: cudnnTensorDescriptor_t) -> TensorDescriptor {
60 TensorDescriptor { id }
61 }
62
63 pub fn id_c(&self) -> &cudnnTensorDescriptor_t {
65 &self.id
66 }
67}