use std::path::PathBuf;
use cxx::UniquePtr;
use std::ffi::{c_void, CString};
use trtx_sys::nvonnxparser;
use crate::error::{Error, Result};
use crate::logger::Logger;
use crate::network::NetworkDefinition;
/// Owns a TensorRT ONNX parser together with the network definition it was created for.
///
/// The network is handed over when the parser is constructed and can be reached afterwards
/// through [`OnnxParser::network`] / [`OnnxParser::network_mut`].
pub struct OnnxParser<'network> {
    // Parser handle, created from a raw pointer into `network` (see `new`).
    // Declared before `network` on purpose: struct fields are dropped in
    // declaration order, so the parser is released before the network it
    // points into. Do not reorder these fields.
    // With the `mock` feature this is a null `UniquePtr`.
    inner: UniquePtr<nvonnxparser::IParser>,
    // The network the parser was bound to at construction; owned here so it
    // is still alive whenever the parser is.
    network: NetworkDefinition<'network>,
}
impl<'parser> OnnxParser<'parser> {
#[cfg(not(any(
feature = "link_tensorrt_onnxparser",
feature = "dlopen_tensorrt_onnxparser"
)))]
pub fn new(network: &mut NetworkDefinition, logger: &Logger) -> Result<Self> {
Err(Error::TrtOnnxParserLibraryNotLoaded)
}
#[cfg(any(
feature = "link_tensorrt_onnxparser",
feature = "dlopen_tensorrt_onnxparser"
))]
pub fn new(network: NetworkDefinition<'parser>, logger: &Logger) -> Result<Self> {
#[cfg(not(feature = "mock"))]
{
let network_ptr = network.inner.as_mut_ptr();
let logger_ptr = logger.as_logger_ptr();
let parser_ptr = {
#[cfg(feature = "link_tensorrt_onnxparser")]
unsafe {
trtx_sys::create_onnx_parser(network_ptr, logger_ptr)
}
#[cfg(not(feature = "link_tensorrt_onnxparser"))]
#[cfg(feature = "dlopen_tensorrt_onnxparser")]
unsafe {
use libloading::Symbol;
use trtx_sys::nvinfer1::INetworkDefinition;
use crate::TRT_ONNXPARSER_LIB;
if !TRT_ONNXPARSER_LIB.read()?.is_some() {
crate::dynamically_load_tensorrt_onnxparser(None::<String>)?;
}
let lock = TRT_ONNXPARSER_LIB
.read()
.map_err(|_| Error::LockPoisining)?;
let create_onnx_parser: Symbol<
fn(*mut INetworkDefinition, *mut c_void, u32) -> *mut c_void,
> = lock
.as_ref()
.ok_or(Error::TrtOnnxParserLibraryNotLoaded)?
.get(b"createNvOnnxParser_INTERNAL")?;
create_onnx_parser(
network_ptr,
logger_ptr,
trtx_sys::get_nvonnxparser_version(),
)
}
} as *mut nvonnxparser::IParser;
if parser_ptr.is_null() {
return Err(Error::Runtime("Failed to create ONNX parser".to_string()));
}
Ok(OnnxParser {
inner: unsafe { UniquePtr::from_raw(parser_ptr) },
network,
})
}
#[cfg(feature = "mock")]
Ok(OnnxParser {
inner: UniquePtr::null(),
network,
})
}
pub fn parse_from_file(&mut self, path: &str, verbosity: i32) -> Result<()> {
let cpath = CString::new(path)?;
unsafe {
if self
.inner
.pin_mut()
.parseFromFile(cpath.as_ptr(), verbosity.into())
{
Ok(())
} else {
Err(Error::FailedToParseOnnx(PathBuf::from(path)))
}
}
}
pub fn parse(&mut self, model_bytes: &[u8]) -> Result<()> {
#[cfg(not(feature = "mock"))]
{
if self.inner.is_null() {
return Err(Error::Runtime("Invalid parser".to_string()));
}
let parser_ptr = self.inner.as_mut_ptr() as *mut c_void;
let success = unsafe {
trtx_sys::parser_parse(
parser_ptr,
model_bytes.as_ptr() as *const std::ffi::c_void,
model_bytes.len(),
)
};
if !success {
let error_msg = unsafe {
let num_errors = trtx_sys::parser_get_nb_errors(parser_ptr);
if num_errors > 0 {
let err_ptr = trtx_sys::parser_get_error(parser_ptr, 0);
if !err_ptr.is_null() {
let desc_ptr = trtx_sys::parser_error_desc(err_ptr);
if !desc_ptr.is_null() {
std::ffi::CStr::from_ptr(desc_ptr)
.to_str()
.unwrap_or("Failed to parse ONNX model")
.to_string()
} else {
"Failed to parse ONNX model".to_string()
}
} else {
"Failed to parse ONNX model".to_string()
}
} else {
"Failed to parse ONNX model".to_string()
}
};
return Err(Error::Runtime(error_msg));
}
}
Ok(())
}
pub fn network(&'parser self) -> &'parser NetworkDefinition<'parser> {
&self.network
}
pub fn network_mut(&'parser mut self) -> &'parser mut NetworkDefinition<'parser> {
&mut self.network
}
}
// These tests need an ONNX parser backend: without one, `OnnxParser::new` always
// returns `TrtOnnxParserLibraryNotLoaded`, so asserting success would always fail.
#[cfg(all(
    test,
    any(
        feature = "link_tensorrt_onnxparser",
        feature = "dlopen_tensorrt_onnxparser"
    )
))]
mod tests {
    use super::*;
    use crate::{Builder, Logger};

    /// A parser can be created over a freshly created network.
    #[test]
    fn test_onnx_parser_creation() {
        let logger = Logger::stderr().unwrap();
        let mut builder = Builder::new(&logger).unwrap();
        let network = builder.create_network(0).unwrap();
        let parser = OnnxParser::new(network, &logger);
        assert!(parser.is_ok());
    }

    /// Parses a real ONNX model shipped with the crate's test data.
    #[test]
    #[ignore = "requires a TensorRT ONNX parser and tests/data/super-resolution-10.onnx"]
    fn test_onnx_parser_with_real_model() {
        let model_path = concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/tests/data/super-resolution-10.onnx"
        );
        let model_bytes = std::fs::read(model_path).expect("Failed to read test ONNX model");
        let logger = Logger::stderr().unwrap();
        let mut builder = Builder::new(&logger).unwrap();
        let network = builder.create_network(0).unwrap();
        let mut parser = OnnxParser::new(network, &logger).unwrap();
        let result = parser.parse(&model_bytes);
        assert!(
            result.is_ok(),
            "Failed to parse ONNX model: {:?}",
            result.err()
        );
    }
}