use std::pin::Pin;
use crate::{network::check_network, Error, NetworkDefinition, Result};
use trtx_sys::{nvinfer1, DataType, Dims64};
/// Lightweight handle to a TensorRT `ITensor` that is tied to a network definition.
///
/// The handle stores a raw pointer plus a borrow of the network, and it is `Copy`, so
/// several copies of a handle may refer to the same underlying tensor object.
#[derive(Clone, Copy)]
pub struct Tensor<'network> {
/// Raw pointer to the wrapped `nvinfer1::ITensor`. The accessor methods of this type
/// dereference it and panic if it is null.
pub(crate) inner: *mut nvinfer1::ITensor,
/// Network definition this handle was created with; the `'network` lifetime keeps the
/// handle from outliving that borrow.
// NOTE(review): presumably the network owns the tensor behind `inner` — confirm.
pub(crate) network: &'network nvinfer1::INetworkDefinition,
}
impl std::fmt::Debug for Tensor<'_> {
    /// Renders the handle as a non-exhaustive `Tensor { inner: "<hex address>", .. }`.
    ///
    /// Only the address of the wrapped pointer is printed (lower-case hex, as a string);
    /// the `network` field is intentionally left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let address = self.inner as usize;
        let rendered = format!("{address:x}");

        let mut builder = f.debug_struct("Tensor");
        builder.field("inner", &rendered);
        builder.finish_non_exhaustive()
    }
}
impl Tensor<'_> {
    /// Wraps a raw `ITensor` pointer together with the network it was obtained from.
    ///
    /// # Errors
    /// Returns [`Error::GetTensorFailed`] when `ptr` is null.
    ///
    /// # Panics
    /// Panics when `ptr` is non-null but `network` is null.
    ///
    /// # Safety
    /// `network`, when non-null, must point to a valid `INetworkDefinition` that stays
    /// alive for the lifetime the caller picks for the returned handle. `ptr`, when
    /// non-null, must point to a valid `ITensor` for as long as the handle is used.
    pub(crate) unsafe fn new(
        network: *const nvinfer1::INetworkDefinition,
        ptr: *mut nvinfer1::ITensor,
    ) -> Result<Self> {
        if ptr.is_null() {
            return Err(Error::GetTensorFailed);
        }
        // SAFETY: the caller guarantees `network` is either null or valid for the
        // lifetime of the returned handle; the null case is turned into a panic below.
        let network = unsafe { network.as_ref() }
            .expect("Tensor::new requires a non-null network pointer");
        Ok(Self {
            inner: ptr,
            network,
        })
    }

    /// Returns a pinned mutable reference to the wrapped `ITensor`.
    ///
    /// # Panics
    /// Panics if the stored pointer is null.
    // The lint is silenced on purpose: the mutable reference comes from the raw pointer,
    // not from `&self`.
    // NOTE(review): `Tensor` is `Copy`, so the compiler cannot enforce exclusivity of the
    // returned reference — callers must not keep overlapping references alive.
    #[allow(clippy::mut_from_ref)]
    pub(crate) fn pin_mut(&self) -> Pin<&mut nvinfer1::ITensor> {
        // SAFETY: relies on `inner` pointing to a live `ITensor` (the contract of
        // `Tensor::new`); the pointee is never moved through the returned reference.
        unsafe {
            Pin::new_unchecked(
                self.inner
                    .as_mut()
                    .expect("Tensor::inner must not be null"),
            )
        }
    }

    /// Returns a shared reference to the wrapped `ITensor`.
    ///
    /// # Panics
    /// Panics if the stored pointer is null.
    pub(crate) fn as_ref(&self) -> &nvinfer1::ITensor {
        // SAFETY: relies on `inner` pointing to a live `ITensor` (the contract of
        // `Tensor::new`).
        unsafe { self.inner.as_ref() }.expect("Tensor::inner must not be null")
    }

    /// Returns a mutable reference to the wrapped `ITensor`.
    ///
    /// # Panics
    /// Panics if the stored pointer is null.
    // See `pin_mut` for why the lint is silenced and for the aliasing caveat.
    #[allow(clippy::mut_from_ref)]
    pub(crate) fn as_mut(&self) -> &mut nvinfer1::ITensor {
        // SAFETY: relies on `inner` pointing to a live `ITensor` (the contract of
        // `Tensor::new`).
        unsafe { self.inner.as_mut() }.expect("Tensor::inner must not be null")
    }

    /// Returns the tensor's name as an owned `String`.
    ///
    /// `network` is passed to `check_network!` together with `self`.
    // NOTE(review): `check_network!` is defined in `crate::network`; presumably it
    // asserts that `network` is the one this tensor belongs to — confirm.
    ///
    /// # Errors
    /// Returns [`Error::Runtime`] when TensorRT hands back a null name pointer, and a
    /// conversion error when the name is not valid UTF-8.
    pub fn name(&self, network: &NetworkDefinition) -> Result<String> {
        check_network!(network, self);
        let name_ptr = self.as_ref().getName();
        if name_ptr.is_null() {
            return Err(Error::Runtime("Failed to get tensor name".to_string()));
        }
        // SAFETY: `name_ptr` was checked to be non-null; it is assumed to reference a
        // NUL-terminated string that stays valid while it is copied below.
        let raw_name = unsafe { std::ffi::CStr::from_ptr(name_ptr) };
        Ok(raw_name.to_str()?.to_string())
    }

    /// Sets the tensor's name.
    ///
    /// # Errors
    /// Returns an error when `name` contains an interior NUL byte and therefore cannot
    /// be converted into a C string.
    pub fn set_name(&self, network: &'_ mut NetworkDefinition, name: &str) -> Result<()> {
        check_network!(network, self);
        let name_cstr = std::ffi::CString::new(name)?;
        // SAFETY: `name_cstr` outlives the call, so the pointer is valid and
        // NUL-terminated for its whole duration.
        // NOTE(review): assumes TensorRT copies the string rather than keeping the
        // pointer — confirm against the `ITensor::setName` documentation.
        unsafe {
            self.pin_mut().setName(name_cstr.as_ptr());
        }
        Ok(())
    }

    /// Sets the name of the dimension at position `index`.
    ///
    /// # Errors
    /// Returns an error when `name` contains an interior NUL byte and therefore cannot
    /// be converted into a C string.
    pub fn set_dimension_name(
        &self,
        network: &'_ mut NetworkDefinition,
        index: i32,
        name: &str,
    ) -> Result<()> {
        check_network!(network, self);
        let name_cstr = std::ffi::CString::new(name)?;
        // SAFETY: `name_cstr` outlives the call, so the pointer is valid and
        // NUL-terminated for its whole duration.
        // NOTE(review): assumes TensorRT copies the string — confirm.
        unsafe {
            self.pin_mut().setDimensionName(index, name_cstr.as_ptr());
        }
        Ok(())
    }

    /// Replaces the tensor's dimensions with `dims`.
    // NOTE(review): the behaviour of `Dims64::from_slice` for slices longer than the
    // maximum rank is not visible here — verify before passing untrusted lengths.
    pub fn set_dimensions(&mut self, network: &mut NetworkDefinition, dims: &[i64]) {
        check_network!(network, self);
        let dims = Dims64::from_slice(dims);
        self.pin_mut().setDimensions(&dims);
    }

    /// Returns the tensor's dimensions, one entry per axis.
    ///
    /// # Errors
    /// Returns [`Error::FailedToGetTensorDimensions`] (carrying the tensor name, or a
    /// placeholder if the name itself cannot be read) when the reported rank is negative
    /// or larger than the number of slots in the dimension array.
    pub fn dimensions(&self, network: &NetworkDefinition) -> Result<Vec<i64>> {
        check_network!(network, self);
        let result = self.as_ref().getDimensions();
        let rank = result.nbDims;
        // A rank equal to the capacity of `d` is valid (the maximum rank uses every
        // slot), so only ranks strictly above the capacity are rejected.
        if rank < 0 || (rank as usize) > result.d.len() {
            let tensor_name = self
                .name(network)
                .unwrap_or_else(|e| format!("<failed to get tensor name: {e}>"));
            return Err(Error::FailedToGetTensorDimensions { tensor_name });
        }
        Ok(result.d[..rank as usize].to_vec())
    }

    /// Returns the result of `ITensor::isExecutionTensor` for this tensor.
    pub fn is_execution_tensor(&self, network: &NetworkDefinition) -> bool {
        check_network!(network, self);
        self.as_ref().isExecutionTensor()
    }

    /// Returns the result of `ITensor::isShapeTensor` for this tensor.
    pub fn is_shape_tensor(&self, network: &NetworkDefinition) -> bool {
        check_network!(network, self);
        self.as_ref().isShapeTensor()
    }

    /// Returns the result of `ITensor::isNetworkInput` for this tensor.
    pub fn is_network_input(&self, network: &NetworkDefinition) -> bool {
        check_network!(network, self);
        self.as_ref().isNetworkInput()
    }

    /// Returns the result of `ITensor::isNetworkOutput` for this tensor.
    pub fn is_network_output(&self, network: &NetworkDefinition) -> bool {
        check_network!(network, self);
        self.as_ref().isNetworkOutput()
    }

    /// Returns the tensor's element type, converted from `ITensor::getType`.
    pub fn data_type(&self, network: &NetworkDefinition) -> DataType {
        check_network!(network, self);
        self.as_ref().getType().into()
    }

    /// Alias of [`Tensor::data_type`].
    pub fn r#type(&self, network: &NetworkDefinition) -> DataType {
        self.data_type(network)
    }

    /// Alias of [`Tensor::data_type`], named after the underlying `getType` call.
    pub fn get_type(&self, network: &NetworkDefinition) -> DataType {
        self.data_type(network)
    }

    /// Forwards `formats` to `ITensor::setAllowedFormats`.
    // NOTE(review): `formats` is presumably a bitmask of TensorRT tensor formats —
    // confirm against the TensorRT documentation.
    ///
    /// # Errors
    /// Currently always returns `Ok(())`.
    pub fn set_allowed_formats(
        &mut self,
        network: &mut NetworkDefinition,
        formats: u32,
    ) -> Result<()> {
        check_network!(network, self);
        self.pin_mut().setAllowedFormats(formats);
        Ok(())
    }
}