1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
use cpp::cpp;
use crate::error::last_error;
use crate::ffi::network::NetworkDefinition;
type Result<T> = std::result::Result<T, crate::error::Error>;
/// For parsing an ONNX model into a TensorRT network definition ([`crate::NetworkDefinition`]).
///
/// This is a thin wrapper around a raw pointer to the underlying `IParser` object. The pointer is
/// obtained from `createParser` when the parser is constructed and handed to `destroy` when the
/// wrapper is dropped.
///
/// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvonnxparser_1_1_i_parser.html)
pub struct Parser(*mut std::ffi::c_void);
impl Parser {
    /// Create a new parser, parse the ONNX file at `path` into `network_definition` and return
    /// the resulting [`crate::NetworkDefinition`].
    ///
    /// Note that this function is CPU-intensive. Callers should not call it directly from an
    /// async context; spawn a blocking task for it instead.
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvonnxparser_1_1_i_parser.html#a973ac2ed682f18c4c6258ed93fc8efa3)
    ///
    /// # Arguments
    ///
    /// * `network_definition` - Network definition to use.
    /// * `path` - Path to file to parse.
    ///
    /// # Return value
    ///
    /// Parsed network definition.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `last_error` if the parser reports a failure.
    ///
    /// # Panics
    ///
    /// Panics if `path` cannot be converted to a C string (see `parse_from_file`).
    pub fn parse_network_definition_from_file(
        mut network_definition: NetworkDefinition,
        path: &impl AsRef<std::path::Path>,
    ) -> Result<NetworkDefinition> {
        // SAFETY: The call to `Parser::new` is unsafe because we must ensure that the new parser
        // outlives `network_definition`. We manually make sure of that here by putting the parser
        // inside `NetworkDefinition`, so that it will only be destroyed when `network_definition`
        // is.
        let parser = unsafe { Self::new(&mut network_definition) };
        // Put the parser object in `network_definition` *before* parsing, because destroying the
        // parser before the network definition is not allowed. Storing it first means that this
        // also holds when parsing fails (or panics) and `network_definition` is dropped here
        // instead of being returned to the caller.
        network_definition
            ._parser
            .insert(parser)
            .parse_from_file(path)?;
        Ok(network_definition)
    }

    /// Parse ONNX file.
    ///
    /// Note that this function is CPU-intensive. Callers should not call it directly from an
    /// async context; spawn a blocking task for it instead.
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvonnxparser_1_1_i_parser.html#a973ac2ed682f18c4c6258ed93fc8efa3)
    ///
    /// # Arguments
    ///
    /// * `path` - Path to file to parse.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `last_error` if `parseFromFile` returns `false`.
    ///
    /// # Panics
    ///
    /// Panics if `path` contains a NUL byte or, on non-Unix platforms, is not valid UTF-8.
    fn parse_from_file(&mut self, path: &impl AsRef<std::path::Path>) -> Result<()> {
        let internal = self.as_mut_ptr();
        // `path_ffi` owns the C string and must stay alive until the C++ call below has returned,
        // since `path_ptr` points into it.
        let path_ffi = Self::path_to_c_string(path.as_ref());
        let path_ptr = path_ffi.as_ptr();
        let ret = cpp!(unsafe [
            internal as "void*",
            path_ptr as "const char*"
        ] -> bool as "bool" {
            return ((IParser*) internal)->parseFromFile(
                path_ptr,
                // Set to `VERBOSE` and let Rust code handle what message are passed on based on
                // logger configuration.
                static_cast<int>(ILogger::Severity::kVERBOSE)
            );
        });
        if ret {
            Ok(())
        } else {
            Err(last_error())
        }
    }

    /// Convert a filesystem path into a NUL-terminated C string that can be passed to the parser.
    ///
    /// On Unix the raw bytes of the path are used, so paths that are not valid UTF-8 are passed
    /// through unchanged. On other platforms the path must be valid UTF-8.
    ///
    /// # Panics
    ///
    /// Panics if the path contains a NUL byte or, on non-Unix platforms, is not valid UTF-8.
    fn path_to_c_string(path: &std::path::Path) -> std::ffi::CString {
        #[cfg(unix)]
        let bytes: &[u8] = std::os::unix::ffi::OsStrExt::as_bytes(path.as_os_str());
        #[cfg(not(unix))]
        let bytes: &[u8] = path
            .to_str()
            .expect("path to ONNX file must be valid UTF-8")
            .as_bytes();
        std::ffi::CString::new(bytes).expect("path to ONNX file must not contain NUL bytes")
    }

    /// Create new parser.
    ///
    /// # Arguments
    ///
    /// * `network_definition` - Reference to network definition to attach to parser.
    ///
    /// # Safety
    ///
    /// Caller must ensure that the [`Parser`] outlives the given [`NetworkDefinition`].
    unsafe fn new(network_definition: &mut NetworkDefinition) -> Self {
        let network_definition_internal = network_definition.as_ptr();
        // NOTE(review): the pointer returned by `createParser` is wrapped without a null check —
        // confirm that `createParser` cannot fail for a valid network definition.
        let internal = cpp!(unsafe [
            network_definition_internal as "void*"
        ] -> *mut std::ffi::c_void as "void*" {
            return createParser(
                *((INetworkDefinition*) network_definition_internal),
                GLOBAL_LOGGER
            );
        });
        Parser(internal)
    }

    /// Get internal readonly pointer.
    #[inline(always)]
    pub fn as_ptr(&self) -> *const std::ffi::c_void {
        self.0
    }

    /// Get internal mutable pointer.
    #[inline(always)]
    pub fn as_mut_ptr(&mut self) -> *mut std::ffi::c_void {
        self.0
    }
}
impl Drop for Parser {
    /// Hand the wrapped parser pointer to `destroy` when the wrapper goes out of scope.
    fn drop(&mut self) {
        // Read the raw pointer straight from the tuple field; the C++ side receives it as a
        // plain `void*` and casts it back to `IParser*`.
        let internal: *mut std::ffi::c_void = self.0;
        cpp!(unsafe [
            internal as "void*"
        ] {
            destroy((IParser*) internal);
        });
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::onnx::*;
    use crate::{Builder, NetworkDefinitionCreationFlags};

    /// Parsing the simple ONNX fixture into a freshly created network definition succeeds.
    #[tokio::test]
    async fn test_parser_parses_onnx_file() {
        let onnx_fixture = simple_onnx_file!();
        let mut builder = Builder::new().await;
        let network_definition =
            builder.network_definition(NetworkDefinitionCreationFlags::ExplicitBatchSize);
        let parsed =
            Parser::parse_network_definition_from_file(network_definition, &onnx_fixture.path());
        assert!(parsed.is_ok());
    }
}