tensorrt 0.1.0

Rust wrapper for NVIDIA TensorRT
use crate::error::{TRTError, TRTResult};
use cuda_rs::{memory::DeviceMemory, stream::CuStream};
use tensorrt_rs_sys::runtime::DataType;

/// Dimensions of a tensor, stored as `i32`.
///
/// NOTE(review): `i32` presumably mirrors TensorRT's `Dims`, where a negative
/// value marks a dynamic (not yet resolved) dimension — confirm.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Shape(pub Vec<i32>);

impl Shape {
    /// Returns the total number of elements, i.e. the product of all dimensions.
    ///
    /// An empty shape (a scalar) has size 1.
    ///
    /// # Panics
    ///
    /// Panics if any dimension is negative or the product overflows `usize`,
    /// rather than silently returning a meaningless count. Use
    /// [`Shape::checked_size`] to handle those cases without panicking.
    pub fn size(&self) -> usize {
        self.checked_size().unwrap_or_else(|| {
            panic!(
                "shape {:?} has a negative dimension or its size overflows usize",
                self.0
            )
        })
    }

    /// Returns the total number of elements, or `None` if any dimension is
    /// negative or the product overflows `usize`.
    ///
    /// An empty shape (a scalar) has size `Some(1)`.
    pub fn checked_size(&self) -> Option<usize> {
        self.0.iter().try_fold(1usize, |acc, &dim| {
            if dim < 0 {
                None
            } else {
                // A non-negative i32 fits in usize on every 32/64-bit target.
                acc.checked_mul(dim as usize)
            }
        })
    }
}

/// A shaped, typed tensor backed by a `DeviceMemory` allocation.
///
/// `reset_shape` never lets the element count implied by `shape` exceed the
/// element count the tensor was constructed with (tracked in `capacity`).
pub struct Tensor {
    mem: DeviceMemory,
    shape: Shape,
    dtype: DataType,
    // Element count the backing memory was sized for at construction time.
    // `reset_shape` checks against this rather than the current shape, so a
    // tensor that was shrunk can later be grown back within its allocation.
    capacity: usize,
}

impl Tensor {
    /// Number of bytes needed for `shape.size()` elements of `dtype`.
    ///
    /// # Panics
    ///
    /// Panics if the byte count overflows `usize`; a wrapped value would size
    /// the memory smaller than the shape claims.
    fn byte_size(shape: &Shape, dtype: DataType) -> usize {
        shape
            .size()
            .checked_mul(dtype.get_elem_size())
            .expect("tensor byte size overflows usize")
    }

    /// Element count a tensor of `shape` is guaranteed room for, computed
    /// without panicking: a negative dimension makes the count 0 and
    /// overflow saturates.
    fn capacity_of(shape: &Shape) -> usize {
        shape.0.iter().fold(1usize, |acc, &dim| {
            if dim < 0 {
                0
            } else {
                acc.saturating_mul(dim as usize)
            }
        })
    }

    /// Allocates device memory for a tensor of the given `shape` and `dtype`.
    ///
    /// The allocation is made with `DeviceMemory::new` on `stream`; its
    /// contents are not initialized here.
    ///
    /// # Errors
    ///
    /// Returns any error from `DeviceMemory::new`, converted into `TRTError`.
    ///
    /// # Panics
    ///
    /// Panics if the required byte count overflows `usize`.
    pub fn empty(shape: &Shape, dtype: DataType, stream: &CuStream) -> TRTResult<Self> {
        let mem = DeviceMemory::new(Self::byte_size(shape, dtype), stream)?;
        Ok(Self::from_memory(mem, shape, dtype))
    }

    /// Wraps an existing device allocation as a tensor of `shape` and `dtype`.
    ///
    /// The size of `mem` is not checked against `shape` and `dtype`.
    /// NOTE(review): callers presumably must pass at least
    /// `shape.size() * dtype.get_elem_size()` bytes — confirm.
    pub fn from_memory(mem: DeviceMemory, shape: &Shape, dtype: DataType) -> Self {
        Self {
            mem,
            shape: shape.clone(),
            dtype,
            capacity: Self::capacity_of(shape),
        }
    }

    /// Returns the backing device memory.
    pub fn get_memory(&self) -> &DeviceMemory {
        &self.mem
    }

    /// Wraps a raw pointer as a tensor of `shape` and `dtype`.
    ///
    /// The resulting memory spans `shape.size() * dtype.get_elem_size()` bytes
    /// and is built with `DeviceMemory::from_raw` using `stream`.
    ///
    /// NOTE(review): this is a safe fn, but it forwards an arbitrary address to
    /// the `unsafe` `DeviceMemory::from_raw` without any validation. It should
    /// probably be an `unsafe fn` with a `# Safety` section. It is left safe
    /// here to keep the public signature unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the required byte count overflows `usize`.
    pub fn from_raw_ptr(ptr: usize, shape: &Shape, dtype: DataType, stream: &CuStream) -> Self {
        let mem_size = Self::byte_size(shape, dtype);
        // SAFETY: not established here. That `ptr` is valid for `mem_size`
        // bytes is entirely the caller's responsibility (see NOTE above).
        let mem = unsafe { DeviceMemory::from_raw(ptr as _, mem_size, stream) };
        Self::from_memory(mem, shape, dtype)
    }

    /// Returns the backing memory's raw pointer (`DeviceMemory::get_raw`) as a `usize`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is undocumented. Presumably the returned
    /// address must not be used after this tensor's memory is dropped — confirm.
    pub unsafe fn get_raw_ptr(&self) -> usize {
        self.mem.get_raw() as usize
    }

    /// Returns the tensor's current logical shape.
    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    /// Returns the tensor's element data type.
    pub fn dtype(&self) -> DataType {
        self.dtype
    }

    /// Changes the logical shape without touching the backing memory.
    ///
    /// Succeeds whenever `shape.size()` is at most the element count the
    /// tensor was constructed with. This lets a tensor be shrunk and later
    /// grown back to its original size.
    ///
    /// # Errors
    ///
    /// Returns `TRTError::ResetShapesError` if `shape` needs more elements than
    /// the tensor was constructed with.
    ///
    /// # Safety
    ///
    /// The backing `DeviceMemory` is not resized. Afterwards, its byte size
    /// may exceed `shape.size() * dtype.get_elem_size()`.
    /// NOTE(review): why this fn is `unsafe` is undocumented. Presumably callers
    /// must not rely on the memory size matching the shape — confirm.
    pub unsafe fn reset_shape(&mut self, shape: &Shape) -> TRTResult<()> {
        if shape.size() > self.capacity {
            return Err(TRTError::ResetShapesError);
        }
        self.shape = shape.clone();
        Ok(())
    }

    /// Copies the contents of `src` into this tensor.
    ///
    /// `stream` is forwarded unchanged to `DeviceMemory::copy_from`.
    ///
    /// # Errors
    ///
    /// - `TRTError::ShapeMismatch` if the shapes differ.
    /// - `TRTError::DTypeMismatch` if the data types differ.
    /// - Any error from `DeviceMemory::copy_from`, converted into `TRTError`.
    ///
    /// NOTE(review): the number of bytes copied is decided by
    /// `DeviceMemory::copy_from`, not by `shape`. After `reset_shape` the two
    /// may disagree — confirm.
    pub fn copy_from(&mut self, src: &Self, stream: Option<&CuStream>) -> TRTResult<()> {
        if self.shape != src.shape {
            return Err(TRTError::ShapeMismatch);
        }
        if self.dtype != src.dtype {
            return Err(TRTError::DTypeMismatch);
        }
        self.mem.copy_from(&src.mem, stream)?;
        Ok(())
    }
}