trtx 0.7.0+rtx1.5

Safe Rust bindings to NVIDIA TensorRT-RTX (EXPERIMENTAL - NOT FOR PRODUCTION)
Documentation
use std::pin::Pin;

use crate::{network::check_network, Error, NetworkDefinition, Result};
use trtx_sys::{nvinfer1, DataType, Dims64};

/// Lightweight handle to a [`trtx_sys::nvinfer1::ITensor`] — C++ [`nvinfer1::ITensor`](https://docs.nvidia.com/deeplearning/tensorrt-rtx/latest/_static/cpp-api/classnvinfer1_1_1_i_tensor.html).
///
/// The handle is `Copy` and therefore has no `Drop`: it never frees the tensor it
/// points at. It pairs a raw pointer to the C++ tensor with a borrow of the network
/// definition it is associated with.
#[derive(Copy, Clone)]
pub struct Tensor<'network> {
    /// Network definition this tensor is associated with.
    pub(crate) network: &'network nvinfer1::INetworkDefinition,
    /// Raw pointer to the underlying C++ `ITensor`.
    pub(crate) inner: *mut nvinfer1::ITensor,
}

impl std::fmt::Debug for Tensor<'_> {
    /// Renders the handle as `Tensor { inner: "<hex address>", .. }`.
    ///
    /// Only the raw tensor address is shown; the network borrow is deliberately
    /// left out, which `finish_non_exhaustive` signals with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Formatted into a `String`, so the Debug output shows it quoted.
        let address = format!("{:x}", self.inner as usize);
        let mut builder = f.debug_struct("Tensor");
        builder.field("inner", &address);
        builder.finish_non_exhaustive()
    }
}

impl Tensor<'_> {
    /// Wraps a raw `ITensor` pointer obtained from TensorRT.
    ///
    /// # Errors
    ///
    /// Returns [`Error::GetTensorFailed`] if `ptr` is null.
    ///
    /// # Panics
    ///
    /// Panics if `network` is null (checked only after `ptr` is known to be non-null).
    ///
    /// # Safety
    ///
    /// `network` must point to a live `INetworkDefinition` for as long as the returned
    /// `Tensor` is used. `ptr`, when non-null, must point to a live `ITensor` for the
    /// same duration, since every accessor on this type dereferences it.
    pub(crate) unsafe fn new(
        network: *const nvinfer1::INetworkDefinition,
        ptr: *mut nvinfer1::ITensor,
    ) -> Result<Self> {
        if ptr.is_null() {
            return Err(Error::GetTensorFailed);
        }
        // SAFETY: the caller guarantees `network` is either null (rejected by `expect`)
        // or valid for the lifetime of the returned handle.
        let network = unsafe { network.as_ref() }
            .expect("Tensor::new: network pointer must not be null");
        Ok(Self {
            inner: ptr,
            network,
        })
    }

    /// Pinned mutable access to the underlying C++ tensor, needed to call the
    /// non-const `ITensor` methods exposed by the bindings.
    // `&mut` from `&self` is intentional: the `ITensor` is not owned by this `Copy`
    // handle, so exclusivity cannot be expressed through `self`.
    // NOTE(review): the public setters take `&mut NetworkDefinition`, which appears to
    // be the intended guard against aliased mutation — confirm.
    #[allow(clippy::mut_from_ref)]
    pub(crate) fn pin_mut(&self) -> Pin<&mut nvinfer1::ITensor> {
        // SAFETY: `inner` is non-null when built via `new`. The C++ object lives behind
        // a raw pointer that Rust never moves out of, so pinning it is sound.
        unsafe { Pin::new_unchecked(self.inner.as_mut().unwrap()) }
    }

    /// Shared access to the underlying C++ tensor.
    pub(crate) fn as_ref(&self) -> &nvinfer1::ITensor {
        // SAFETY: `inner` is non-null and valid per the contract of `new`.
        unsafe { self.inner.as_ref().unwrap() }
    }

    /// Unpinned mutable access to the underlying C++ tensor.
    // See `pin_mut` for why `&mut` is handed out from `&self`.
    #[allow(clippy::mut_from_ref)]
    pub(crate) fn as_mut(&self) -> &mut nvinfer1::ITensor {
        // SAFETY: `inner` is non-null and valid per the contract of `new`.
        unsafe { self.inner.as_mut().unwrap() }
    }

    /// Returns the tensor's name as an owned `String`.
    ///
    /// See [nvinfer1::ITensor::getName]
    ///
    /// # Errors
    ///
    /// Returns [`Error::Runtime`] if TensorRT returns a null name pointer. If the name
    /// is not valid UTF-8, returns the crate error converted from
    /// [`std::str::Utf8Error`].
    pub fn name(&self, network: &NetworkDefinition) -> Result<String> {
        check_network!(network, self);
        let name_ptr = self.as_ref().getName();
        if name_ptr.is_null() {
            return Err(Error::Runtime("Failed to get tensor name".to_string()));
        }
        // SAFETY: `name_ptr` is non-null and is assumed to be a NUL-terminated string
        // owned by TensorRT. It is copied into an owned `String` before returning.
        let name = unsafe { std::ffi::CStr::from_ptr(name_ptr) };
        Ok(name.to_str()?.to_string())
    }

    /// Sets the tensor's name.
    ///
    /// See [nvinfer1::ITensor::setName]
    ///
    /// # Errors
    ///
    /// Fails if `name` contains an interior NUL byte.
    pub fn set_name(&self, network: &'_ mut NetworkDefinition, name: &str) -> Result<()> {
        check_network!(network, self);
        let name_cstr = std::ffi::CString::new(name)?;
        // SAFETY: `name_cstr` outlives the call, so the pointer stays valid throughout.
        unsafe {
            self.pin_mut().setName(name_cstr.as_ptr());
        }
        Ok(())
    }

    /// Names dimension `index` of this tensor.
    ///
    /// See [nvinfer1::ITensor::setDimensionName]
    ///
    /// # Errors
    ///
    /// Fails if `name` contains an interior NUL byte.
    pub fn set_dimension_name(
        &self,
        network: &'_ mut NetworkDefinition,
        index: i32,
        name: &str,
    ) -> Result<()> {
        check_network!(network, self);
        let name_cstr = std::ffi::CString::new(name)?;
        // SAFETY: `name_cstr` outlives the call, so the pointer stays valid throughout.
        unsafe {
            self.pin_mut().setDimensionName(index, name_cstr.as_ptr());
        }
        Ok(())
    }

    /// Sets the tensor's dimensions from `dims`.
    ///
    /// See [nvinfer1::ITensor::setDimensions]
    // NOTE(review): how `Dims64::from_slice` handles more than `MAX_DIMS` entries is
    // not visible here — confirm.
    pub fn set_dimensions(&mut self, network: &mut NetworkDefinition, dims: &[i64]) {
        check_network!(network, self);
        let dims = Dims64::from_slice(dims);
        self.pin_mut().setDimensions(&dims);
    }

    /// Returns the tensor's dimensions, one entry per axis.
    ///
    /// See [nvinfer1::ITensor::getDimensions]
    ///
    /// # Errors
    ///
    /// Returns [`Error::FailedToGetTensorDimensions`] if TensorRT reports a negative
    /// rank, or a rank larger than the `d` array can hold.
    pub fn dimensions(&self, network: &NetworkDefinition) -> Result<Vec<i64>> {
        check_network!(network, self);
        let result = self.as_ref().getDimensions();
        // The valid rank range is 0..=MAX_DIMS, and MAX_DIMS equals `d.len()` (8).
        // A rank-8 tensor is valid, so only ranks *above* that capacity are errors.
        // The `< 0` test short-circuits first, so the `as usize` cast never sees a
        // negative value.
        if result.nbDims < 0 || result.nbDims as usize > result.d.len() {
            let tensor_name = self
                .name(network)
                .unwrap_or_else(|e| format!("<failed to get tensor name: {e}>"));
            return Err(Error::FailedToGetTensorDimensions { tensor_name });
        }
        Ok(result.d[..result.nbDims as usize].to_vec())
    }

    /// See [nvinfer1::ITensor::isExecutionTensor]
    pub fn is_execution_tensor(&self, network: &NetworkDefinition) -> bool {
        check_network!(network, self);
        self.as_ref().isExecutionTensor()
    }

    /// See [nvinfer1::ITensor::isShapeTensor]
    pub fn is_shape_tensor(&self, network: &NetworkDefinition) -> bool {
        check_network!(network, self);
        self.as_ref().isShapeTensor()
    }

    /// See [nvinfer1::ITensor::isNetworkInput]
    pub fn is_network_input(&self, network: &NetworkDefinition) -> bool {
        check_network!(network, self);
        self.as_ref().isNetworkInput()
    }

    /// See [nvinfer1::ITensor::isNetworkOutput]
    pub fn is_network_output(&self, network: &NetworkDefinition) -> bool {
        check_network!(network, self);
        self.as_ref().isNetworkOutput()
    }

    /// Returns the tensor's element type. This is the same as [`Tensor::r#type`].
    ///
    /// See [nvinfer1::ITensor::getType]
    pub fn data_type(&self, network: &NetworkDefinition) -> DataType {
        self.r#type(network)
    }

    /// Returns the tensor's element type.
    ///
    /// See [nvinfer1::ITensor::getType]
    pub fn r#type(&self, network: &NetworkDefinition) -> DataType {
        check_network!(network, self);
        self.as_ref().getType().into()
    }

    /// Returns the tensor's element type. This is the same as [`Tensor::r#type`].
    ///
    /// See [nvinfer1::ITensor::getType]
    pub fn get_type(&self, network: &NetworkDefinition) -> DataType {
        self.r#type(network)
    }

    /// Sets the allowed tensor formats as a bitmask of `TensorFormat` values.
    /// For example, `1u32 << (TensorFormat::kHWC as u32)` allows channels-last only.
    /// TensorRT may insert reformat layers when connecting tensors with different formats.
    ///
    /// See [nvinfer1::ITensor::setAllowedFormats]
    pub fn set_allowed_formats(
        &mut self,
        network: &mut NetworkDefinition,
        formats: u32,
    ) -> Result<()> {
        check_network!(network, self);
        self.pin_mut().setAllowedFormats(formats);
        Ok(())
    }
}