use std::error::Error;
use trtx::executor::{run_onnx_with_tensorrt, run_onnx_zeroed, TensorInput};
/// Demonstrates the `trtx` executor API on a placeholder model.
///
/// Prints progress to stdout while running three steps:
/// 1. execution with zero-filled inputs via `run_onnx_zeroed`;
/// 2. execution with generated sample data via `run_onnx_with_tensorrt`;
/// 3. a printed sketch of how rustnn could wrap the executor.
///
/// Execution failures in steps 1 and 2 are printed, not propagated; the
/// printed messages note that failure is expected with the dummy model.
///
/// # Errors
///
/// Returns an error if the `dlopen_tensorrt_rtx` feature is enabled and
/// `trtx::dynamically_load_tensorrt` fails (previously this panicked).
fn main() -> Result<(), Box<dyn Error>> {
    // Surface a runtime-loading failure as an error from `main` instead of a
    // panic. `{:?}` is safe here: the original `.unwrap()` already required
    // the error type to implement `Debug`.
    #[cfg(feature = "dlopen_tensorrt_rtx")]
    trtx::dynamically_load_tensorrt(None::<String>)
        .map_err(|e| format!("failed to load TensorRT-RTX: {:?}", e))?;

    println!("TensorRT-RTX Executor for rustnn");
    println!("==================================\n");

    let dummy_onnx = create_dummy_onnx_model();
    run_zeroed_demo(&dummy_onnx);
    run_sample_input_demo(&dummy_onnx);
    print_integration_pattern();

    println!("\n✓ Example completed");
    Ok(())
}

/// Step 1: runs `model` with its `input` tensor zero-filled and prints each
/// output's name, shape and element count. Failures are printed, not returned.
fn run_zeroed_demo(model: &[u8]) {
    println!("1. Testing with zero-filled inputs...");
    let input_descriptors = vec![("input".to_string(), vec![1, 3, 224, 224])];
    match run_onnx_zeroed(model, &input_descriptors) {
        Ok(outputs) => {
            println!(" ✓ Execution succeeded");
            println!(" Outputs:");
            for output in outputs {
                println!(
                    " - {}: shape {:?}, {} values",
                    output.name,
                    output.shape,
                    output.data.len()
                );
            }
        }
        Err(e) => {
            println!(" ✗ Execution failed: {}", e);
            println!(" (This is expected with a dummy model in mock mode)");
        }
    }
}

/// Step 2: runs `model` with generated sample data for its `input` tensor and
/// prints each output's shape and its first few values. Failures are printed,
/// not returned.
fn run_sample_input_demo(model: &[u8]) {
    println!("\n2. Testing with actual input data...");
    let shape = vec![1, 3, 224, 224];
    // Derive the buffer length from the declared shape, so the data can never
    // disagree with the tensor size the executor is told about.
    let element_count: usize = shape.iter().map(|&dim| dim as usize).product();
    let inputs = vec![TensorInput {
        name: "input".to_string(),
        shape,
        data: create_sample_input(element_count),
    }];
    match run_onnx_with_tensorrt(model, &inputs) {
        Ok(outputs) => {
            println!(" ✓ Execution succeeded");
            for output in outputs {
                println!(" - {}: shape {:?}", output.name, output.shape);
                // Clamp to the output length so outputs shorter than 5 values
                // don't panic on slicing.
                println!(
                    " First 5 values: {:?}",
                    &output.data[..output.data.len().min(5)]
                );
            }
        }
        Err(e) => {
            println!(" ✗ Execution failed: {}", e);
            println!(" (Expected with dummy model - use real ONNX for actual inference)");
        }
    }
}

/// Step 3: prints an example of how rustnn could expose the executor behind a
/// `trtx-runtime` feature flag. This only prints text; nothing is executed.
fn print_integration_pattern() {
    println!("\n3. rustnn Integration Pattern");
    println!(" To use in rustnn, implement:");
    println!(" ```rust");
    println!(" #[cfg(feature = \"trtx-runtime\")]");
    println!(" pub fn run_trtx_with_inputs(");
    println!(" model_bytes: &[u8],");
    println!(" inputs: &[TrtxInput],");
    println!(" ) -> Result<Vec<TrtxOutputWithData>> {{");
    println!(" trtx::run_onnx_with_tensorrt(model_bytes, inputs)");
    println!(" }}");
    println!(" ```");
}
/// Returns a placeholder model: a buffer of 100 zero bytes.
///
/// The bytes carry no real ONNX graph. They only give the executor something
/// to consume, which is why the example treats execution failures as expected.
fn create_dummy_onnx_model() -> Vec<u8> {
    const DUMMY_MODEL_LEN: usize = 100;
    let mut bytes = Vec::new();
    bytes.resize(DUMMY_MODEL_LEN, 0u8);
    bytes
}
/// Generates `size` deterministic sample values.
///
/// Element `i` is `sin(i * 0.001)`, computed in `f32`. Every value therefore
/// lies in `[-1, 1]`, and the sequence starts at `0.0`.
fn create_sample_input(size: usize) -> Vec<f32> {
    let mut samples = Vec::with_capacity(size);
    for index in 0..size {
        let phase = index as f32 * 0.001_f32;
        samples.push(phase.sin());
    }
    samples
}