use std::rc::Rc;
use std::{ffi::CStr, marker::PhantomData};
use crate::engine_inspector::EngineInspector;
use crate::error::PropertySetAttempt;
use crate::host_memory::HostMemory;
use crate::runtime_config::RuntimeConfig;
use crate::{DataType, Error, ExecutionContext, Result};
use autocxx::cxx::UniquePtr;
use trtx_sys::{
nvinfer1::{self, ICudaEngine},
SerializationFlag, TensorIOMode,
};
use trtx_sys::{TensorFormat, TensorLocation};
/// Owning wrapper around an `nvinfer1::ISerializationConfig` handle.
///
/// The `'cuda_engine` lifetime is carried purely through `PhantomData` over a
/// borrowed `ICudaEngine`; no engine reference is actually stored.
pub struct SerializationConfig<'cuda_engine> {
// Owned pointer to the underlying serialization config object.
inner: UniquePtr<nvinfer1::ISerializationConfig>,
// Zero-sized marker that ties this value to the `'cuda_engine` lifetime.
_runtime: PhantomData<&'cuda_engine nvinfer1::ICudaEngine>,
}
/// Debug output shows the wrapped pointer's address as lowercase hex (rendered
/// as a quoted string) and marks the struct as non-exhaustive.
impl std::fmt::Debug for SerializationConfig<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Format the address up front so the builder chain stays short.
        let address = format!("{:x}", self.inner.as_ptr() as usize);
        let mut out = f.debug_struct("SerializationConfig");
        out.field("inner", &address);
        out.finish_non_exhaustive()
    }
}
impl SerializationConfig<'_> {
    /// Converts the boolean status returned by the flag setters into a
    /// `Result`, reporting `PropertySetAttempt::SerializationFlag` on `false`.
    fn flag_status(accepted: bool) -> Result<()> {
        if accepted {
            return Ok(());
        }
        Err(Error::FailedToSetProperty(
            PropertySetAttempt::SerializationFlag,
        ))
    }

    /// Returns the value of `getFlag` for the given serialization flag.
    pub fn flag(&self, flag: SerializationFlag) -> bool {
        self.inner.getFlag(flag.into())
    }

    /// Returns the raw flag word reported by `getFlags`.
    pub fn flags(&self) -> u32 {
        self.inner.getFlags()
    }

    /// Calls `setFlag` for `flag`.
    ///
    /// # Errors
    /// Returns `Error::FailedToSetProperty` when the underlying call returns `false`.
    pub fn set_flag(&mut self, flag: SerializationFlag) -> Result<()> {
        let accepted = self.inner.pin_mut().setFlag(flag.into());
        Self::flag_status(accepted)
    }

    /// Calls `setFlags` with the raw flag word `flags`.
    ///
    /// # Errors
    /// Returns `Error::FailedToSetProperty` when the underlying call returns `false`.
    pub fn set_flags(&mut self, flags: u32) -> Result<()> {
        let accepted = self.inner.pin_mut().setFlags(flags);
        Self::flag_status(accepted)
    }

    /// Calls `clearFlag` for `flag`.
    ///
    /// # Errors
    /// Returns `Error::FailedToSetProperty` when the underlying call returns `false`.
    pub fn clear_flag(&mut self, flag: SerializationFlag) -> Result<()> {
        let accepted = self.inner.pin_mut().clearFlag(flag.into());
        Self::flag_status(accepted)
    }
}
/// Owning wrapper around an `nvinfer1::ICudaEngine` handle.
///
/// The `'runtime` lifetime is carried purely through `PhantomData` over a
/// borrowed `IRuntime`; no runtime reference is actually stored.
pub struct CudaEngine<'runtime> {
// Owned pointer to the underlying engine; crate-visible so sibling modules can use it.
pub(crate) inner: UniquePtr<ICudaEngine>,
// Zero-sized marker that ties this value to the `'runtime` lifetime.
_runtime: PhantomData<&'runtime nvinfer1::IRuntime>,
}
/// Debug output shows the wrapped pointer's address as lowercase hex (rendered
/// as a quoted string) and marks the struct as non-exhaustive.
impl std::fmt::Debug for CudaEngine<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Format the address up front so the builder chain stays short.
        let address = format!("{:x}", self.inner.as_ptr() as usize);
        let mut out = f.debug_struct("CudaEngine");
        out.field("inner", &address);
        out.finish_non_exhaustive()
    }
}
impl<'engine> CudaEngine<'engine> {
pub(crate) unsafe fn from_ptr(ptr: *mut ICudaEngine) -> Self {
Self {
inner: unsafe { UniquePtr::from_raw(ptr) },
_runtime: Default::default(),
}
}
#[deprecated = "use nb_io_tensors instead"]
pub fn get_nb_io_tensors(&self) -> Result<i32> {
self.nb_io_tensors()
}
#[deprecated = "use tensor_shape instead"]
pub fn get_tensor_shape(&self, name: &str) -> Result<Vec<i64>> {
self.tensor_shape(name)
}
#[deprecated = "use io_tensor_name instead"]
pub fn get_tensor_name(&self, index: i32) -> Result<String> {
self.io_tensor_name(index)
}
#[deprecated = "use tensor_data_type instead"]
pub fn get_tensor_dtype(&self, name: &str) -> Result<DataType> {
self.tensor_data_type(name)
}
/// Returns the string reported by `getName`, or an empty string when the
/// returned pointer is null.
///
/// # Errors
/// Fails when the returned C string is not valid UTF-8.
pub fn name(&self) -> Result<String> {
    let raw = self.inner.getName();
    if raw.is_null() {
        Ok(String::new())
    } else {
        // SAFETY: `raw` was checked to be non-null; it is assumed to point at a
        // NUL-terminated string owned by the engine.
        let text = unsafe { CStr::from_ptr(raw) }.to_str()?;
        Ok(text.to_owned())
    }
}
/// Returns the value of `getNbIOTensors`, or `0` when the `mock_runtime`
/// feature is enabled.
pub fn nb_io_tensors(&self) -> Result<i32> {
    let count = if cfg!(feature = "mock_runtime") {
        0
    } else {
        self.inner.getNbIOTensors()
    };
    Ok(count)
}
/// Returns the name reported by `getIOTensorName` for `index`.
///
/// With the `mock_runtime` feature enabled the fixed string `"mock_runtime"`
/// is returned without touching the engine.
///
/// # Errors
/// `Error::InvalidArgument` when the engine returns a null pointer; a
/// conversion error when the name is not valid UTF-8.
pub fn io_tensor_name(&self, index: i32) -> Result<String> {
    if cfg!(feature = "mock_runtime") {
        return Ok("mock_runtime".to_string());
    }
    let raw = self.inner.getIOTensorName(index);
    if raw.is_null() {
        return Err(Error::InvalidArgument("Invalid tensor index".to_string()));
    }
    // SAFETY: `raw` was checked to be non-null; it is assumed to point at a
    // NUL-terminated string owned by the engine.
    let text = unsafe { CStr::from_ptr(raw) }.to_str()?;
    Ok(text.to_owned())
}
pub fn tensor_shape(&self, name: &str) -> Result<Vec<i64>> {
let name_cstr = std::ffi::CString::new(name)?;
let dims = unsafe { self.inner.getTensorShape(name_cstr.as_ptr()) };
let nb_dims = dims.nbDims as usize;
if nb_dims > 8 {
return Err(Error::Runtime("Tensor has too many dimensions".to_string()));
}
Ok((0..nb_dims).map(|i| dims.d[i]).collect())
}
/// Returns the data type reported by `getTensorDataType` for `name`, or
/// `DataType::kFLOAT` when the `mock_runtime` feature is enabled.
///
/// # Errors
/// Fails when `name` contains an interior NUL byte (non-mock path only).
pub fn tensor_data_type(&self, name: &str) -> Result<DataType> {
    if cfg!(feature = "mock_runtime") {
        return Ok(DataType::kFLOAT);
    }
    let c_name = std::ffi::CString::new(name)?;
    // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
    let raw_type = unsafe { self.inner.getTensorDataType(c_name.as_ptr()) };
    Ok(raw_type.into())
}
/// Returns the value of `getNbLayers`; this wrapper never produces an error.
pub fn nb_layers(&self) -> Result<i32> {
    let layers = self.inner.getNbLayers();
    Ok(layers)
}
/// Returns the value of `getNbOptimizationProfiles`; this wrapper never
/// produces an error.
pub fn nb_optimization_profiles(&self) -> Result<i32> {
    let profiles = self.inner.getNbOptimizationProfiles();
    Ok(profiles)
}
/// Returns the value of `getNbAuxStreams`; this wrapper never produces an error.
pub fn nb_aux_streams(&self) -> Result<i32> {
    let streams = self.inner.getNbAuxStreams();
    Ok(streams)
}
/// Returns the I/O mode reported by `getTensorIOMode` for `name`.
///
/// # Errors
/// Fails when `name` contains an interior NUL byte.
pub fn tensor_io_mode(&self, name: &str) -> Result<TensorIOMode> {
    let c_name = std::ffi::CString::new(name)?;
    // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
    let mode = unsafe { self.inner.getTensorIOMode(c_name.as_ptr()) };
    Ok(mode.into())
}
/// Returns the location reported by `getTensorLocation` for `name`.
///
/// # Errors
/// Fails when `name` contains an interior NUL byte.
pub fn tensor_location(&self, name: &str) -> Result<TensorLocation> {
    let c_name = std::ffi::CString::new(name)?;
    // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
    let location = unsafe { self.inner.getTensorLocation(c_name.as_ptr()) };
    Ok(location.into())
}
/// Returns the format reported by `getTensorFormat` for `name`.
///
/// # Errors
/// Fails when `name` contains an interior NUL byte.
pub fn tensor_format(&self, name: &str) -> Result<TensorFormat> {
    let c_name = std::ffi::CString::new(name)?;
    // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
    let format = unsafe { self.inner.getTensorFormat(c_name.as_ptr()) };
    Ok(format.into())
}
/// Returns the raw `nvinfer1::TensorFormat` reported by `getTensorFormat1` for
/// `name` and `profile_index` (no conversion is applied to the result).
///
/// # Errors
/// Fails when `name` contains an interior NUL byte.
pub fn tensor_format_for_profile(
    &self,
    name: &str,
    profile_index: i32,
) -> Result<nvinfer1::TensorFormat> {
    let c_name = std::ffi::CString::new(name)?;
    // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
    let format = unsafe { self.inner.getTensorFormat1(c_name.as_ptr(), profile_index) };
    Ok(format)
}
/// Returns the string reported by `getTensorFormatDesc` for `name`, or an
/// empty string when the returned pointer is null.
///
/// # Errors
/// Fails when `name` contains an interior NUL byte or the returned C string is
/// not valid UTF-8.
pub fn tensor_format_desc(&self, name: &str) -> Result<String> {
    let c_name = std::ffi::CString::new(name)?;
    // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
    let raw = unsafe { self.inner.getTensorFormatDesc(c_name.as_ptr()) };
    if raw.is_null() {
        Ok(String::new())
    } else {
        // SAFETY: `raw` is non-null; assumed to be a NUL-terminated string.
        let text = unsafe { CStr::from_ptr(raw) }.to_str()?;
        Ok(text.to_owned())
    }
}
/// Returns the string reported by `getTensorFormatDesc1` for `name` and
/// `profile_index`, or an empty string when the returned pointer is null.
///
/// # Errors
/// Fails when `name` contains an interior NUL byte or the returned C string is
/// not valid UTF-8.
pub fn tensor_format_desc_for_profile(&self, name: &str, profile_index: i32) -> Result<String> {
    let c_name = std::ffi::CString::new(name)?;
    // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
    let raw = unsafe { self.inner.getTensorFormatDesc1(c_name.as_ptr(), profile_index) };
    if raw.is_null() {
        Ok(String::new())
    } else {
        // SAFETY: `raw` is non-null; assumed to be a NUL-terminated string.
        let text = unsafe { CStr::from_ptr(raw) }.to_str()?;
        Ok(text.to_owned())
    }
}
/// Returns the value of `getTensorVectorizedDim` for `name`.
///
/// # Errors
/// Fails when `name` contains an interior NUL byte.
pub fn tensor_vectorized_dim(&self, name: &str) -> Result<i32> {
    let c_name = std::ffi::CString::new(name)?;
    // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
    let dim = unsafe { self.inner.getTensorVectorizedDim(c_name.as_ptr()) };
    Ok(dim)
}
/// Returns the value of `getTensorVectorizedDim1` for `name` and `profile_index`.
///
/// # Errors
/// Fails when `name` contains an interior NUL byte.
pub fn tensor_vectorized_dim_for_profile(&self, name: &str, profile_index: i32) -> Result<i32> {
    let c_name = std::ffi::CString::new(name)?;
    // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
    let dim = unsafe { self.inner.getTensorVectorizedDim1(c_name.as_ptr(), profile_index) };
    Ok(dim)
}
/// Returns the value of `getTensorBytesPerComponent` for `name`.
///
/// When the `mock_runtime` feature is enabled the engine is not consulted and
/// the placeholder value `42` is returned.
///
/// # Errors
/// Fails when `name` contains an interior NUL byte (non-mock build only).
pub fn tensor_bytes_per_component(&self, name: &str) -> Result<i32> {
    #[cfg(not(feature = "mock_runtime"))]
    {
        let c_name = std::ffi::CString::new(name)?;
        // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
        let bytes = unsafe { self.inner.getTensorBytesPerComponent(c_name.as_ptr()) };
        Ok(bytes)
    }
    #[cfg(feature = "mock_runtime")]
    {
        Ok(42)
    }
}
/// Returns the value of `getTensorBytesPerComponent1` for `name` and
/// `profile_index`.
///
/// Unlike the neighbouring accessors this one is guarded by a null check on
/// the wrapped pointer rather than a feature gate: a null engine yields `0`.
///
/// # Errors
/// Fails when `name` contains an interior NUL byte (non-null engine only).
pub fn tensor_bytes_per_component_for_profile(
    &self,
    name: &str,
    profile_index: i32,
) -> Result<i32> {
    if self.inner.is_null() {
        return Ok(0);
    }
    let c_name = std::ffi::CString::new(name)?;
    // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
    let bytes = unsafe {
        self.inner
            .getTensorBytesPerComponent1(c_name.as_ptr(), profile_index)
    };
    Ok(bytes)
}
/// Returns the value of `getTensorComponentsPerElement` for `name`.
///
/// When the `mock_runtime` feature is enabled the engine is not consulted and
/// the placeholder value `42` is returned.
///
/// # Errors
/// Fails when `name` contains an interior NUL byte (non-mock build only).
pub fn tensor_components_per_element(&self, name: &str) -> Result<i32> {
    #[cfg(not(feature = "mock_runtime"))]
    {
        let c_name = std::ffi::CString::new(name)?;
        // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
        let components = unsafe { self.inner.getTensorComponentsPerElement(c_name.as_ptr()) };
        Ok(components)
    }
    #[cfg(feature = "mock_runtime")]
    {
        Ok(42)
    }
}
/// Returns the value of `getTensorComponentsPerElement1` for `name` and
/// `profile_index`.
///
/// When the `mock_runtime` feature is enabled the engine is not consulted and
/// the placeholder value `42` is returned.
///
/// # Errors
/// Fails when `name` contains an interior NUL byte (non-mock build only).
pub fn tensor_components_per_element_for_profile(
    &self,
    name: &str,
    profile_index: i32,
) -> Result<i32> {
    #[cfg(not(feature = "mock_runtime"))]
    {
        let c_name = std::ffi::CString::new(name)?;
        // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
        let components = unsafe {
            self.inner
                .getTensorComponentsPerElement1(c_name.as_ptr(), profile_index)
        };
        Ok(components)
    }
    #[cfg(feature = "mock_runtime")]
    {
        Ok(42)
    }
}
pub fn create_engine_inspector(&self) -> Result<EngineInspector<'_>> {
#[cfg(not(feature = "mock_runtime"))]
{
use crate::engine_inspector::EngineInspector;
let inspector = self.inner.createEngineInspector();
let inspector = unsafe {
inspector.as_mut().ok_or_else(|| {
Error::Runtime("Failed to create engine inspector".to_string())
})?
};
Ok(EngineInspector {
inner: unsafe { UniquePtr::from_raw(inspector) },
_engine: Default::default(),
})
}
#[cfg(feature = "mock_runtime")]
{
Ok(EngineInspector {
inner: UniquePtr::null(),
_engine: Default::default(),
})
}
}
#[deprecated = "use tensor_data_type instead"]
pub fn tensor_dtype(&self, name: &str) -> Result<DataType> {
self.tensor_data_type(name)
}
pub fn create_execution_context(&'_ mut self) -> Result<ExecutionContext<'engine>> {
#[cfg(not(feature = "mock_runtime"))]
{
use crate::ExecutionContext;
let context_ptr = self
.inner
.pin_mut()
.createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kSTATIC);
Ok(unsafe { ExecutionContext::from_ptr(context_ptr, None)? })
}
#[cfg(feature = "mock_runtime")]
Ok(unsafe { ExecutionContext::from_ptr(std::ptr::null_mut(), None)? })
}
pub fn create_execution_context_with_config(
&'_ mut self,
runtime_config: Rc<RuntimeConfig<'engine>>,
) -> Result<ExecutionContext<'engine>> {
#[cfg(not(feature = "mock_runtime"))]
{
use crate::ExecutionContext;
let context_ptr = unsafe {
self.inner
.pin_mut()
.createExecutionContext1(runtime_config.inner.as_mut_ptr())
};
Ok(unsafe { ExecutionContext::from_ptr(context_ptr, Some(runtime_config))? })
}
#[cfg(feature = "mock_runtime")]
Ok(unsafe { ExecutionContext::from_ptr(std::ptr::null_mut(), None)? })
}
pub fn create_serialization_config(&mut self) -> Result<SerializationConfig<'engine>> {
let config = unsafe {
self.inner
.pin_mut()
.createSerializationConfig()
.as_mut()
.ok_or_else(|| Error::Runtime("SerializationConfig creation failed".to_string()))?
};
Ok(SerializationConfig {
inner: unsafe { UniquePtr::from_raw(config) },
_runtime: Default::default(),
})
}
pub fn serialize_with_config(
&'_ self,
config: &mut SerializationConfig,
) -> Result<HostMemory<'engine>> {
if !cfg!(feature = "mock_runtime") {
let host_mem = unsafe {
self.inner
.serializeWithConfig(config.inner.pin_mut())
.as_mut()
.ok_or_else(|| {
Error::Runtime("Failed to serialize ICudaEngine with config".to_string())
})?
};
Ok(unsafe { HostMemory::from_raw(host_mem) })
} else {
Ok(unsafe { HostMemory::from_raw(std::ptr::null_mut()) })
}
}
/// Returns the value of `isShapeInferenceIO` for `name`.
///
/// # Errors
/// Fails when `name` contains an interior NUL byte.
pub fn is_shape_inference_io(&self, name: &str) -> Result<bool> {
    let c_name = std::ffi::CString::new(name)?;
    // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
    let is_shape_io = unsafe { self.inner.isShapeInferenceIO(c_name.as_ptr()) };
    Ok(is_shape_io)
}
/// Creates a `RuntimeConfig` via `createRuntimeConfig`.
///
/// In mock builds the engine is not called and a null pointer is handed to
/// `RuntimeConfig::new` instead.
///
/// # Errors
/// Propagates whatever error `RuntimeConfig::new` reports.
pub fn create_runtime_config(&'_ mut self) -> Result<RuntimeConfig<'engine>> {
    // Every other method in this impl gates its mock path on `mock_runtime`;
    // this one only checked `mock`, so a `mock_runtime` build still called into
    // the engine pointer. Accept either name so existing `mock` users keep
    // their behaviour.
    #[cfg(not(any(feature = "mock", feature = "mock_runtime")))]
    let config_ptr = self.inner.pin_mut().createRuntimeConfig();
    #[cfg(any(feature = "mock", feature = "mock_runtime"))]
    let config_ptr = std::ptr::null_mut();
    RuntimeConfig::new(config_ptr)
}
/// Returns an iterator over the engine's I/O tensor names, starting at index
/// `0` and bounded by the count from `nb_io_tensors` at the time of the call.
///
/// # Errors
/// Propagates any error from `nb_io_tensors`.
pub fn io_tensor_names(&self) -> Result<CudaEngineIoTensorNamesIter<'_>> {
    let count = self.nb_io_tensors()?;
    Ok(CudaEngineIoTensorNamesIter {
        engine: self,
        index: 0,
        count,
    })
}
/// Iterates over the I/O tensor names whose `tensor_io_mode` is `kINPUT`.
///
/// Names for which the mode lookup fails are silently skipped.
///
/// # Errors
/// Propagates any error from `io_tensor_names`.
pub fn input_tensor_names(&self) -> Result<impl Iterator<Item = String> + '_> {
    let all_names = self.io_tensor_names()?;
    let inputs = all_names.filter(move |name| {
        self.tensor_io_mode(name)
            .map_or(false, |mode| mode == TensorIOMode::kINPUT)
    });
    Ok(inputs)
}
/// Iterates over the I/O tensor names whose `tensor_io_mode` is `kOUTPUT`.
///
/// Names for which the mode lookup fails are silently skipped.
///
/// # Errors
/// Propagates any error from `io_tensor_names`.
pub fn output_tensor_names(&self) -> Result<impl Iterator<Item = String> + '_> {
    let all_names = self.io_tensor_names()?;
    let outputs = all_names.filter(move |name| {
        self.tensor_io_mode(name)
            .map_or(false, |mode| mode == TensorIOMode::kOUTPUT)
    });
    Ok(outputs)
}
}
/// Iterator over the I/O tensor names of a borrowed `CudaEngine`.
///
/// Yields owned `String`s obtained through `CudaEngine::io_tensor_name` for
/// indices `index..count`.
#[derive(Debug)]
pub struct CudaEngineIoTensorNamesIter<'a> {
// Engine whose tensor names are being enumerated.
engine: &'a CudaEngine<'a>,
// Index of the next name to yield.
index: i32,
// Exclusive upper bound on `index`.
count: i32,
}
impl Iterator for CudaEngineIoTensorNamesIter<'_> {
type Item = String;
fn next(&mut self) -> Option<Self::Item> {
if self.index >= self.count {
return None;
}
let name = self
.engine
.io_tensor_name(self.index)
.expect("valid tensor index");
self.index += 1;
Some(name)
}
fn size_hint(&self) -> (usize, Option<usize>) {
let remaining = (self.count - self.index).max(0) as usize;
(remaining, Some(remaining))
}
}
// `size_hint` returns identical lower and upper bounds, which is what the
// default `ExactSizeIterator::len` relies on.
impl ExactSizeIterator for CudaEngineIoTensorNamesIter<'_> {}
#[cfg(test)]
#[cfg(not(feature = "mock_runtime"))]
mod tests {
use crate::builder::network_flags;
use crate::builder::{Builder, MemoryPoolType};
use crate::logger::Logger;
use crate::runtime::Runtime;
use crate::{CudaEngine, DataType};
use trtx_sys::{ActivationType, LayerInformationFormat};
/// Builds a tiny network (one `[1, 4]` float input named "input" followed by
/// two ReLU activations, with the final tensor named "output") and returns the
/// serialized engine bytes. Profiling verbosity is set to `kDETAILED`.
fn build_minimal_engine_with_verbose_profiling(logger: &Logger) -> crate::Result<Vec<u8>> {
let mut builder = Builder::new(logger)?;
let mut network = builder.create_network(network_flags::EXPLICIT_BATCH)?;
let mut tensor = network.add_input("input", DataType::kFLOAT, &[1, 4])?;
// First ReLU: feed the input tensor, continue with the layer's output 0.
tensor = network
.add_activation(&tensor, ActivationType::kRELU)
.unwrap()
.output(&network, 0)
.unwrap();
// Second ReLU chained onto the first.
tensor = network
.add_activation(&tensor, ActivationType::kRELU)
.unwrap()
.output(&network, 0)
.unwrap();
// Name the final tensor so the tests can look it up as "output".
tensor.set_name(&mut network, "output").unwrap();
network.mark_output(&tensor);
let mut config = builder.create_config()?;
// Workspace pool limit of 1 << 20 (1,048,576).
config.set_memory_pool_limit(MemoryPoolType::kWORKSPACE, 1 << 20);
config.set_profiling_verbosity(crate::ProfilingVerbosity::kDETAILED);
let engine_data = builder.build_serialized_network(&mut network, &mut config)?;
Ok(engine_data.to_vec())
}
// Checks that the input/output name iterators split the minimal engine's I/O
// tensors into exactly "input" and "output", and that together they account
// for every I/O tensor the engine reports.
#[test]
fn input_output_tensor_names_iter() {
let logger = Logger::stderr().expect("logger");
let engine_data =
build_minimal_engine_with_verbose_profiling(&logger).expect("build engine");
let mut runtime = Runtime::new(&logger).expect("runtime");
let engine = runtime
.deserialize_cuda_engine(&engine_data)
.expect("deserialize");
let inputs: Vec<_> = engine.input_tensor_names().unwrap().collect();
let outputs: Vec<_> = engine.output_tensor_names().unwrap().collect();
assert_eq!(inputs, ["input"]);
assert_eq!(outputs, ["output"]);
// Inputs plus outputs must cover the full I/O tensor count.
assert_eq!(
inputs.len() + outputs.len(),
engine.nb_io_tensors().unwrap() as usize
);
}
// Checks that `io_tensor_names` reports an exact length and yields the same
// names, in the same order, as indexing with `io_tensor_name`.
#[test]
fn io_tensor_names_iter() {
let logger = Logger::stderr().expect("logger");
let engine_data =
build_minimal_engine_with_verbose_profiling(&logger).expect("build engine");
let mut runtime = Runtime::new(&logger).expect("runtime");
let engine = runtime
.deserialize_cuda_engine(&engine_data)
.expect("deserialize");
let names: Vec<_> = engine.io_tensor_names().unwrap().collect();
// `len()` comes from the `ExactSizeIterator` impl on a fresh iterator.
assert_eq!(engine.io_tensor_names().unwrap().len(), names.len());
// Reference result built with the index-based accessor.
let mut old_style = Vec::new();
for i in 0..engine.nb_io_tensors().unwrap() {
old_style.push(engine.io_tensor_name(i).unwrap());
}
assert_eq!(names, old_style);
}
// Checks that an inspector created from the deserialized engine returns
// non-empty engine information that starts with '{' when JSON is requested.
#[test]
fn engine_inspector_json_verbose_profiling() {
let logger = Logger::stderr().expect("logger");
let engine_data =
build_minimal_engine_with_verbose_profiling(&logger).expect("build engine");
let mut runtime = Runtime::new(&logger).expect("runtime");
let engine: CudaEngine<'_> = runtime
.deserialize_cuda_engine(&engine_data)
.expect("deserialize");
let inspector = engine.create_engine_inspector().expect("engine inspector");
let json = inspector
.engine_information(LayerInformationFormat::kJSON)
.expect("get_engine_information JSON");
assert!(
!json.is_empty(),
"engine information JSON should not be empty"
);
// On failure, only the first 80 characters are printed.
assert!(
json.trim_start().starts_with('{'),
"engine information should be JSON (starts with '{{'); got: {}...",
json.chars().take(80).collect::<String>()
);
}
}