use super::{API, Error};
use ffi::*;
/// Wrapper around a raw cuDNN tensor descriptor handle.
///
/// The handle is stored as an `isize` and converted back for FFI calls via
/// `id_c()`. Dropping the wrapper calls `API::destroy_tensor_descriptor` on
/// the handle.
///
/// NOTE(review): `Clone` is derived, so cloning copies the raw handle value
/// instead of creating a new cuDNN descriptor. The original and the clone will
/// each call `API::destroy_tensor_descriptor` on the same handle when dropped
/// (a double destroy). Confirm whether any caller relies on `clone()`; if not,
/// dropping the `Clone` derive would remove this hazard.
#[derive(Debug, Clone)]
pub struct TensorDescriptor {
// Raw cuDNN tensor descriptor handle, stored as an integer.
id: isize,
}
impl Drop for TensorDescriptor {
    /// Releases the underlying cuDNN tensor descriptor.
    fn drop(&mut self) {
        // `drop` cannot return an error, so a failed destroy is deliberately ignored.
        let _ = API::destroy_tensor_descriptor(self.id_c());
    }
}
impl TensorDescriptor {
pub fn new(dims: &[i32], strides: &[i32], data_type: DataType) -> Result<TensorDescriptor, Error> {
let nb_dims = dims.len() as i32;
if nb_dims < 3 { return Err(Error::BadParam("CUDA cuDNN only supports Tensors with 3 to 8 dimensions.")) }
let dims_ptr = dims.as_ptr();
let strides_ptr = strides.as_ptr();
let generic_tensor_desc = try!(API::create_tensor_descriptor());
match data_type {
DataType::Float => {
let d_type = cudnnDataType_t::CUDNN_DATA_FLOAT;
try!(API::set_tensor_descriptor(generic_tensor_desc, d_type, nb_dims, dims_ptr, strides_ptr));
Ok(TensorDescriptor::from_c(generic_tensor_desc))
},
DataType::Double => {
let d_type = cudnnDataType_t::CUDNN_DATA_DOUBLE;
try!(API::set_tensor_descriptor(generic_tensor_desc, d_type, nb_dims, dims_ptr, strides_ptr));
Ok(TensorDescriptor::from_c(generic_tensor_desc))
},
DataType::Half => {
let d_type = cudnnDataType_t::CUDNN_DATA_HALF;
try!(API::set_tensor_descriptor(generic_tensor_desc, d_type, nb_dims, dims_ptr, strides_ptr));
Ok(TensorDescriptor::from_c(generic_tensor_desc))
}
}
}
pub fn from_c(id: cudnnTensorDescriptor_t) -> TensorDescriptor {
TensorDescriptor { id: id as isize }
}
pub fn id(&self) -> isize {
self.id
}
pub fn id_c(&self) -> cudnnTensorDescriptor_t {
self.id as cudnnTensorDescriptor_t
}
}
/// Element data type used when building a [`TensorDescriptor`].
///
/// `TensorDescriptor::new` translates each variant into the matching
/// `cudnnDataType_t` value.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DataType {
    /// Maps to `CUDNN_DATA_FLOAT`.
    Float,
    /// Maps to `CUDNN_DATA_DOUBLE`.
    Double,
    /// Maps to `CUDNN_DATA_HALF`.
    Half,
}