1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
use super::utils::DataType;
use super::{Error, API};
use crate::ffi::*;
/// Owning wrapper around a raw cuDNN tensor descriptor handle.
///
/// The wrapped handle is passed to `API::destroy_tensor_descriptor` when this
/// value is dropped.
// NOTE(review): `Clone` copies the raw handle, so a clone and its original
// will both destroy the same handle when dropped (double destroy). Consider
// removing `Clone` or sharing the handle behind a reference count — confirm
// no callers depend on cloning first.
#[derive(Debug, Clone)]
pub struct TensorDescriptor {
    /// Raw cuDNN handle owned by this wrapper.
    id: cudnnTensorDescriptor_t,
}
pub fn tensor_vec_id_c(tensor_vec: &[TensorDescriptor]) -> Vec<cudnnTensorDescriptor_t> {
tensor_vec.iter().map(|tensor| *tensor.id_c()).collect()
}
impl Drop for TensorDescriptor {
    /// Releases the wrapped cuDNN handle.
    fn drop(&mut self) {
        // `drop` has no way to report a failure, so any error from the
        // destroy call is deliberately discarded instead of panicking here.
        let _ = API::destroy_tensor_descriptor(self.id);
    }
}
impl TensorDescriptor {
    /// Creates a cuDNN tensor descriptor and configures its dimensions,
    /// strides and element type.
    ///
    /// `dims` holds the size of each dimension and `strides` the matching
    /// stride. `dims` must have between 3 and 8 entries, and `strides` must
    /// have at least as many entries as `dims`.
    ///
    /// # Errors
    ///
    /// Returns `Error::BadParam` in two cases: `dims` has fewer than 3 or
    /// more than 8 entries, or `strides` is shorter than `dims`. Errors from
    /// `API::create_tensor_descriptor` and `API::set_tensor_descriptor` are
    /// propagated. If setting fails, the newly created handle is destroyed
    /// before returning, so nothing leaks.
    pub fn new(
        dims: &[i32],
        strides: &[i32],
        data_type: DataType,
    ) -> Result<TensorDescriptor, Error> {
        // Supported range, as stated in the error message below
        // (cuDNN's CUDNN_DIM_MAX is 8).
        const MIN_DIMS: usize = 3;
        const MAX_DIMS: usize = 8;

        if dims.len() < MIN_DIMS || dims.len() > MAX_DIMS {
            return Err(Error::BadParam(
                "CUDA cuDNN only supports Tensors with 3 to 8 dimensions.",
            ));
        }
        // cuDNN reads `nb_dims` entries through the strides pointer. A shorter
        // slice would be read out of bounds on the other side of the FFI call.
        if strides.len() < dims.len() {
            return Err(Error::BadParam(
                "Tensor strides must have at least one entry per dimension.",
            ));
        }
        // Lossless: the length was just checked to be at most MAX_DIMS.
        let nb_dims = dims.len() as i32;

        // Wrap the raw handle immediately. If configuring it fails below,
        // `desc` is dropped and its `Drop` impl destroys the handle instead
        // of leaking it.
        let desc = TensorDescriptor::from_c(API::create_tensor_descriptor()?);
        let data_type = API::cudnn_data_type(data_type);
        API::set_tensor_descriptor(
            *desc.id_c(),
            data_type,
            nb_dims,
            dims.as_ptr(),
            strides.as_ptr(),
        )?;
        Ok(desc)
    }

    /// Wraps an existing raw cuDNN tensor descriptor handle.
    ///
    /// The returned value takes ownership of `id`. Dropping it passes `id`
    /// to `API::destroy_tensor_descriptor`.
    pub fn from_c(id: cudnnTensorDescriptor_t) -> TensorDescriptor {
        TensorDescriptor { id }
    }

    /// Returns a reference to the raw cuDNN handle, for use in FFI calls.
    ///
    /// Ownership stays with `self`. The handle must not be used after `self`
    /// is dropped.
    pub fn id_c(&self) -> &cudnnTensorDescriptor_t {
        &self.id
    }
}