use crate::builder::network_flags;
use crate::error::Result;
use crate::{Builder, DeviceBuffer, OnnxParser};
use crate::{Logger, Runtime};
/// A named host-side `f32` tensor supplied as input to an inference run.
///
/// `execute_engine` matches `name` against the engine's I/O tensor names and
/// requires `data.len()` to equal the product of `shape`.
#[derive(Debug, Clone, PartialEq)]
pub struct TensorInput {
    /// Tensor name, compared against the engine's I/O tensor names.
    pub name: String,
    /// Dimensions of the tensor; their product is the expected element count.
    pub shape: Vec<usize>,
    /// Flat element buffer holding `shape.iter().product()` values.
    pub data: Vec<f32>,
}
/// A named host-side `f32` tensor produced by an inference run.
///
/// `execute_engine` fills one of these per output tensor, copying the data
/// back from the device after execution.
#[derive(Debug, Clone, PartialEq)]
pub struct TensorOutput {
    /// Tensor name as reported by the engine.
    pub name: String,
    /// Dimensions of the tensor as reported by the engine.
    pub shape: Vec<usize>,
    /// Flat element buffer holding `shape.iter().product()` values.
    pub data: Vec<f32>,
}
/// Builds a TensorRT engine from an in-memory ONNX model and runs it once.
///
/// A stderr [`Logger`] is created and shared by the build and execution
/// phases. The model bytes are first turned into a serialized engine by
/// `build_engine_from_onnx`, which is then handed to `execute_engine` together
/// with `inputs`.
///
/// # Errors
///
/// Returns the first error raised while creating the logger, building the
/// engine, or executing it.
pub fn run_onnx_with_tensorrt(
    onnx_model_bytes: &[u8],
    inputs: &[TensorInput],
) -> Result<Vec<TensorOutput>> {
    let logger = Logger::stderr()?;
    build_engine_from_onnx(&logger, onnx_model_bytes)
        .and_then(|serialized_engine| execute_engine(&logger, &serialized_engine, inputs))
}
/// Parses `onnx_bytes` into a network and returns the serialized engine bytes.
///
/// The network is created with the `EXPLICIT_BATCH` flag and the builder
/// config's workspace memory pool limit is set to `1 << 30` before building.
///
/// # Errors
///
/// Propagates any failure from creating the builder, network, parser or
/// config, from parsing the model, or from building the serialized network.
fn build_engine_from_onnx(logger: &Logger, onnx_bytes: &[u8]) -> Result<Vec<u8>> {
    use crate::builder::MemoryPoolType;

    let mut builder = Builder::new(logger)?;

    // The freshly created network is handed straight to the parser, which
    // populates it from the ONNX bytes.
    let mut onnx_parser = OnnxParser::new(
        builder.create_network(network_flags::EXPLICIT_BATCH)?,
        logger,
    )?;
    onnx_parser.parse(onnx_bytes)?;
    let parsed_network = onnx_parser.network_mut();

    let mut build_config = builder.create_config()?;
    // 1 << 30 is 1 GiB, assuming the limit is expressed in bytes as in
    // TensorRT's own API — TODO confirm against the wrapper.
    build_config.set_memory_pool_limit(MemoryPoolType::kWORKSPACE, 1 << 30);

    let serialized = builder.build_serialized_network(parsed_network, &mut build_config)?;
    Ok(serialized.to_vec())
}
fn execute_engine(
logger: &Logger,
engine_data: &[u8],
inputs: &[TensorInput],
) -> Result<Vec<TensorOutput>> {
let mut runtime = Runtime::new(logger)?;
let mut engine = runtime.deserialize_cuda_engine(engine_data)?;
let mut context = engine.create_execution_context()?;
let num_tensors = engine.nb_io_tensors()?;
let mut device_buffers: Vec<(String, DeviceBuffer)> = Vec::new();
let mut output_info: Vec<(String, Vec<usize>)> = Vec::new();
for i in 0..num_tensors {
let name = engine.io_tensor_name(i)?;
if let Some(input) = inputs.iter().find(|inp| inp.name == name) {
let expected_shape_i64 = engine.tensor_shape(&name)?;
let expected_shape: Vec<usize> =
expected_shape_i64.iter().map(|&d| d as usize).collect();
let expected_elements: usize = expected_shape.iter().product();
let provided_elements: usize = input.shape.iter().product();
if provided_elements != expected_elements {
return Err(crate::Error::InvalidArgument(format!(
"Input tensor '{}' shape mismatch: expected {:?} ({} elements), got {:?} ({} elements)",
name, expected_shape, expected_elements, input.shape, provided_elements
)));
}
if input.data.len() != provided_elements {
return Err(crate::Error::InvalidArgument(format!(
"Input tensor '{}' data length ({}) doesn't match shape {:?} ({} elements)",
name,
input.data.len(),
input.shape,
provided_elements
)));
}
let size_bytes = input.data.len() * std::mem::size_of::<f32>();
let mut buffer = DeviceBuffer::new(size_bytes)?;
let input_bytes =
unsafe { std::slice::from_raw_parts(input.data.as_ptr() as *const u8, size_bytes) };
buffer.copy_from_host(input_bytes)?;
unsafe {
context.set_tensor_address(&name, buffer.as_ptr())?;
}
device_buffers.push((name.clone(), buffer));
} else {
let shape_i64 = engine.tensor_shape(&name)?;
let shape: Vec<usize> = shape_i64.iter().map(|&d| d as usize).collect();
let num_elements: usize = shape.iter().product();
let size_bytes = num_elements * std::mem::size_of::<f32>();
let buffer = DeviceBuffer::new(size_bytes)?;
unsafe {
context.set_tensor_address(&name, buffer.as_ptr())?;
}
output_info.push((name.clone(), shape));
device_buffers.push((name.clone(), buffer));
}
}
unsafe {
context.enqueue_v3(crate::cuda::default_stream())?;
}
crate::cuda::synchronize()?;
let mut outputs = Vec::new();
for (name, shape) in output_info {
if let Some((_, buffer)) = device_buffers.iter().find(|(n, _)| n == &name) {
let size_bytes = shape.iter().product::<usize>() * std::mem::size_of::<f32>();
let mut host_data: Vec<u8> = vec![0u8; size_bytes];
buffer.copy_to_host(host_data.as_mut_slice())?;
let data: Vec<f32> = unsafe {
std::slice::from_raw_parts(
host_data.as_ptr() as *const f32,
size_bytes / std::mem::size_of::<f32>(),
)
}
.to_vec();
outputs.push(TensorOutput { name, shape, data });
}
}
Ok(outputs)
}
/// Runs an ONNX model with every described input filled with zeros.
///
/// Each `(name, shape)` pair in `input_descriptors` becomes a [`TensorInput`]
/// whose `data` holds `shape.iter().product()` zeros. The resulting inputs are
/// passed to [`run_onnx_with_tensorrt`].
///
/// # Errors
///
/// Returns whatever error [`run_onnx_with_tensorrt`] reports.
pub fn run_onnx_zeroed(
    onnx_model_bytes: &[u8],
    input_descriptors: &[(String, Vec<usize>)],
) -> Result<Vec<TensorOutput>> {
    let mut zeroed_inputs = Vec::with_capacity(input_descriptors.len());
    for (tensor_name, dims) in input_descriptors {
        let element_count: usize = dims.iter().product();
        zeroed_inputs.push(TensorInput {
            name: tensor_name.clone(),
            shape: dims.clone(),
            data: vec![0.0; element_count],
        });
    }
    run_onnx_with_tensorrt(onnx_model_bytes, &zeroed_inputs)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `TensorInput` literal keeps the name, shape and data it was given.
    #[test]
    fn test_tensor_input_creation() {
        let expected_shape = vec![1, 3, 224, 224];
        let element_count = 3 * 224 * 224;

        let tensor = TensorInput {
            name: "input".to_string(),
            shape: expected_shape.clone(),
            data: vec![0.0; element_count],
        };

        assert_eq!(tensor.name, "input");
        assert_eq!(tensor.shape, expected_shape);
        assert_eq!(tensor.data.len(), element_count);
    }

    /// Drives the full pipeline with 100 zero bytes as the model.
    ///
    /// Ignored by default; success is asserted only when the `mock_runtime`
    /// feature is enabled.
    #[test]
    #[ignore]
    fn test_executor_basic() {
        let placeholder_model = vec![0u8; 100];
        let descriptors = vec![("input".to_string(), vec![1, 3, 224, 224])];

        // Underscore prefix avoids an unused-variable warning when the
        // feature-gated assertion below is compiled out.
        let _outcome = run_onnx_zeroed(&placeholder_model, &descriptors);

        #[cfg(feature = "mock_runtime")]
        assert!(_outcome.is_ok());
    }
}