trtx 0.7.0+rtx1.5

Safe Rust bindings to NVIDIA TensorRT-RTX (EXPERIMENTAL - NOT FOR PRODUCTION)
Documentation
//! ONNX model parser for TensorRT.
//!
//! [`OnnxParser`] wraps [`trtx_sys::nvonnxparser::IParser`] (C++ [`nvonnxparser::IParser`](https://docs.nvidia.com/deeplearning/tensorrt-rtx/latest/_static/cpp-api/classnvonnxparser_1_1_i_parser.html)).

use std::path::PathBuf;

use cxx::UniquePtr;
use std::ffi::{c_void, CString};
use trtx_sys::nvonnxparser;

use crate::error::{Error, Result};
use crate::logger::Logger;
use crate::network::NetworkDefinition;

/// Owning wrapper around [`trtx_sys::nvonnxparser::IParser`] — C++ [`nvonnxparser::IParser`](https://docs.nvidia.com/deeplearning/tensorrt-rtx/latest/_static/cpp-api/classnvonnxparser_1_1_i_parser.html).
///
/// The parser also owns the [`NetworkDefinition`] it populates; the network is only handed
/// out as a borrow.
pub struct OnnxParser<'parser> {
    /// Owned TensorRT parser handle. It is null only in `mock` builds.
    ///
    /// Struct fields drop in declaration order, so this handle is released before `network`.
    /// NOTE(review): keep this order unless TensorRT's required teardown order is confirmed.
    inner: UniquePtr<nvonnxparser::IParser>,
    /// Network definition that the parser writes layers into.
    network: NetworkDefinition<'parser>,
}

impl<'parser> OnnxParser<'parser> {
    /// Fallback constructor, compiled when neither `link_tensorrt_onnxparser` nor
    /// `dlopen_tensorrt_onnxparser` is enabled.
    ///
    /// It has the same signature as the real constructor, so calling code compiles under
    /// every feature combination. The network passed in is simply dropped.
    ///
    /// # Errors
    ///
    /// Always returns [`Error::TrtOnnxParserLibraryNotLoaded`].
    #[cfg(not(any(
        feature = "link_tensorrt_onnxparser",
        feature = "dlopen_tensorrt_onnxparser"
    )))]
    pub fn new(_network: NetworkDefinition<'parser>, _logger: &Logger) -> Result<Self> {
        Err(Error::TrtOnnxParserLibraryNotLoaded)
    }

    /// Creates a new parser that writes into `network`.
    ///
    /// This consumes the [`NetworkDefinition`]. The parser holds the weights that the parsed
    /// layers refer to, so it must not be dropped before the network build is finished.
    /// Use [`OnnxParser::network`] / [`OnnxParser::network_mut`] to borrow the network.
    ///
    /// Creation paths:
    /// - With `link_tensorrt_onnxparser`, the parser is created through
    ///   `trtx_sys::create_onnx_parser`.
    /// - Otherwise (`dlopen_tensorrt_onnxparser`), the ONNX parser library is loaded if it is
    ///   not loaded yet, and `createNvOnnxParser_INTERNAL` is resolved from it.
    /// - With `mock`, no TensorRT call is made and the parser handle is null.
    ///
    /// # Errors
    ///
    /// - [`Error::LockPoisining`] if the global library lock is poisoned (dlopen path).
    /// - [`Error::TrtOnnxParserLibraryNotLoaded`] if the library is still absent after
    ///   loading (dlopen path).
    /// - Any error from loading the library or resolving the symbol (dlopen path).
    /// - [`Error::Runtime`] if TensorRT returns a null parser.
    #[cfg(any(
        feature = "link_tensorrt_onnxparser",
        feature = "dlopen_tensorrt_onnxparser"
    ))]
    pub fn new(network: NetworkDefinition<'parser>, logger: &Logger) -> Result<Self> {
        // The mock build never hands the logger to TensorRT.
        #[cfg(feature = "mock")]
        let _ = logger;

        #[cfg(not(feature = "mock"))]
        {
            let network_ptr = network.inner.as_mut_ptr();
            let logger_ptr = logger.as_logger_ptr();
            let parser_ptr = {
                // SAFETY: both pointers come from `network` and `logger`, which are alive for
                // the duration of the call.
                // NOTE(review): `logger` is only borrowed for this call; confirm it is
                // guaranteed to outlive the parser.
                #[cfg(feature = "link_tensorrt_onnxparser")]
                unsafe {
                    trtx_sys::create_onnx_parser(network_ptr, logger_ptr)
                }
                // SAFETY: the resolved symbol is called through a C-ABI function pointer
                // matching its C declaration. The pointers come from live wrappers, and the
                // library stays loaded because `lock` is held until the call returns.
                #[cfg(not(feature = "link_tensorrt_onnxparser"))]
                #[cfg(feature = "dlopen_tensorrt_onnxparser")]
                unsafe {
                    use libloading::Symbol;
                    use trtx_sys::nvinfer1::INetworkDefinition;

                    use crate::TRT_ONNXPARSER_LIB;

                    // `createNvOnnxParser_INTERNAL` is an `extern "C"` export, so it must be
                    // called through a C-ABI pointer; a Rust-ABI `fn` pointer here is UB.
                    // NOTE(review): the C parameter is an `int` version; this assumes
                    // `get_nvonnxparser_version()` yields a 32-bit value — confirm.
                    type CreateParserFn = unsafe extern "C" fn(
                        *mut INetworkDefinition,
                        *mut c_void,
                        u32,
                    ) -> *mut c_void;

                    // The read guard is a temporary of this `let`, so it is released before
                    // the loader (which may need to write the lock) runs.
                    let library_loaded = TRT_ONNXPARSER_LIB
                        .read()
                        .map_err(|_| Error::LockPoisining)?
                        .is_some();
                    if !library_loaded {
                        crate::dynamically_load_tensorrt_onnxparser(None::<String>)?;
                    }

                    let lock = TRT_ONNXPARSER_LIB
                        .read()
                        .map_err(|_| Error::LockPoisining)?;
                    let create_onnx_parser: Symbol<CreateParserFn> = lock
                        .as_ref()
                        .ok_or(Error::TrtOnnxParserLibraryNotLoaded)?
                        .get(b"createNvOnnxParser_INTERNAL")?;
                    create_onnx_parser(
                        network_ptr,
                        logger_ptr,
                        trtx_sys::get_nvonnxparser_version(),
                    )
                }
            } as *mut nvonnxparser::IParser;
            if parser_ptr.is_null() {
                return Err(Error::Runtime("Failed to create ONNX parser".to_string()));
            }
            Ok(OnnxParser {
                // SAFETY: `parser_ptr` is non-null and freshly returned by TensorRT; ownership
                // is transferred to the `UniquePtr`.
                inner: unsafe { UniquePtr::from_raw(parser_ptr) },
                network,
            })
        }
        #[cfg(feature = "mock")]
        Ok(OnnxParser {
            inner: UniquePtr::null(),
            network,
        })
    }

    /// Parses the ONNX model stored at `path` into the owned network.
    ///
    /// `verbosity` is forwarded (via `Into`) to TensorRT's `IParser::parseFromFile`.
    /// In `mock` builds this only validates `path` and returns `Ok(())`, mirroring
    /// [`OnnxParser::parse`].
    ///
    /// # Errors
    ///
    /// - An error if `path` contains an interior NUL byte.
    /// - [`Error::Runtime`] (`"Invalid parser"`) if the parser handle is null.
    /// - [`Error::FailedToParseOnnx`] carrying `path` if TensorRT reports a parse failure.
    pub fn parse_from_file(&mut self, path: &str, verbosity: i32) -> Result<()> {
        let cpath = CString::new(path)?;
        // Mock builds have a null handle; `pin_mut()` would panic on it.
        if cfg!(feature = "mock") {
            return Ok(());
        }
        if self.inner.is_null() {
            return Err(Error::Runtime("Invalid parser".to_string()));
        }
        // SAFETY: the handle is non-null, and `cpath` is a NUL-terminated string that
        // outlives the call.
        let parsed = unsafe {
            self.inner
                .pin_mut()
                .parseFromFile(cpath.as_ptr(), verbosity.into())
        };
        if parsed {
            Ok(())
        } else {
            Err(Error::FailedToParseOnnx(PathBuf::from(path)))
        }
    }

    /// Parses an in-memory ONNX model (`model_bytes`) into the owned network.
    ///
    /// In `mock` builds this is a no-op that returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// - [`Error::Runtime`] (`"Invalid parser"`) if the parser handle is null.
    /// - [`Error::Runtime`] if parsing fails. It carries the description of the first error
    ///   the parser recorded, or `"Failed to parse ONNX model"` if none is available.
    pub fn parse(&mut self, model_bytes: &[u8]) -> Result<()> {
        #[cfg(not(feature = "mock"))]
        {
            if self.inner.is_null() {
                return Err(Error::Runtime("Invalid parser".to_string()));
            }
            let parser_ptr = self.inner.as_mut_ptr() as *mut c_void;
            // SAFETY: `parser_ptr` is a live, non-null parser, and the pointer/length pair
            // describes `model_bytes`, which outlives the call.
            let success = unsafe {
                trtx_sys::parser_parse(
                    parser_ptr,
                    model_bytes.as_ptr().cast::<c_void>(),
                    model_bytes.len(),
                )
            };
            if !success {
                let error_msg = self
                    .first_error_description()
                    .unwrap_or_else(|| "Failed to parse ONNX model".to_string());
                return Err(Error::Runtime(error_msg));
            }
        }
        Ok(())
    }

    /// Returns the description of the first error recorded by the parser, or `None` if the
    /// handle is null, no error is recorded, or TensorRT returns a null error/description.
    /// Invalid UTF-8 in the description is replaced rather than discarded.
    #[cfg(not(feature = "mock"))]
    fn first_error_description(&mut self) -> Option<String> {
        if self.inner.is_null() {
            return None;
        }
        let parser_ptr = self.inner.as_mut_ptr() as *mut c_void;
        // SAFETY: `parser_ptr` is a non-null parser owned by `self.inner`, and every pointer
        // returned by TensorRT is null-checked before it is dereferenced.
        unsafe {
            let has_errors = trtx_sys::parser_get_nb_errors(parser_ptr) > 0;
            if !has_errors {
                return None;
            }
            let err_ptr = trtx_sys::parser_get_error(parser_ptr, 0);
            if err_ptr.is_null() {
                return None;
            }
            let desc_ptr = trtx_sys::parser_error_desc(err_ptr);
            if desc_ptr.is_null() {
                return None;
            }
            Some(
                std::ffi::CStr::from_ptr(desc_ptr)
                    .to_string_lossy()
                    .into_owned(),
            )
        }
    }

    /// Borrows the network definition owned by this parser.
    ///
    /// The borrow lasts only as long as the returned reference. It does not lock the parser
    /// for its whole lifetime.
    pub fn network(&self) -> &NetworkDefinition<'parser> {
        &self.network
    }

    /// Mutably borrows the network definition owned by this parser.
    ///
    /// The borrow lasts only as long as the returned reference, so the parser can still be
    /// used (e.g. to call [`OnnxParser::parse`]) afterwards.
    pub fn network_mut(&mut self) -> &mut NetworkDefinition<'parser> {
        &mut self.network
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Builder, Logger};

    /// A parser can be constructed on top of a freshly created network.
    #[test]
    fn test_onnx_parser_creation() {
        let log = Logger::stderr().unwrap();
        let mut network_builder = Builder::new(&log).unwrap();
        let created = OnnxParser::new(network_builder.create_network(0).unwrap(), &log);
        assert!(created.is_ok());
    }

    /// Parses a real ONNX model from the crate's test data directory.
    #[test]
    #[ignore] // Requires GPU and TensorRT runtime
    fn test_onnx_parser_with_real_model() {
        let onnx_path = concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/tests/data/super-resolution-10.onnx"
        );
        let onnx_bytes = std::fs::read(onnx_path).expect("Failed to read test ONNX model");

        let log = Logger::stderr().unwrap();
        let mut network_builder = Builder::new(&log).unwrap();
        let mut onnx_parser =
            OnnxParser::new(network_builder.create_network(0).unwrap(), &log).unwrap();

        let outcome = onnx_parser.parse(&onnx_bytes);
        assert!(
            outcome.is_ok(),
            "Failed to parse ONNX model: {:?}",
            outcome.err()
        );
    }
}