async_tensorrt/ffi/parser.rs
1use cpp::cpp;
2
3use crate::error::last_error;
4use crate::ffi::network::NetworkDefinition;
5
6type Result<T> = std::result::Result<T, crate::error::Error>;
7
/// ONNX parser that fills a TensorRT network definition ([`crate::NetworkDefinition`]) with the
/// contents of an ONNX model.
///
/// Holds an opaque pointer to the underlying `nvonnxparser::IParser` object.
///
/// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvonnxparser_1_1_i_parser.html)
pub struct Parser(*mut std::ffi::c_void);
12
13impl Parser {
14 /// Create new parser, parse ONNX file and return a [`crate::NetworkDefinition`].
15 ///
16 /// Note that this function is CPU-intensive. Callers should not use it in async context or
17 /// spawn a blocking task for it.
18 ///
19 /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvonnxparser_1_1_i_parser.html#a973ac2ed682f18c4c6258ed93fc8efa3)
20 ///
21 /// # Arguments
22 ///
23 /// * `network_definition` - Network definition to use.
24 /// * `path` - Path to file to parse.
25 ///
26 /// # Return value
27 ///
28 /// Parsed network definition.
29 pub fn parse_network_definition_from_file(
30 mut network_definition: NetworkDefinition,
31 path: &impl AsRef<std::path::Path>,
32 ) -> Result<NetworkDefinition> {
33 // SAFETY: The call to `Parser::new` is unsafe because we must ensure that the new parser
34 // outlives `network_definition`. We manually make sure of that here by putting the parser
35 // inside `NetworkDefinition` and such it will only be destroyed when `network_definition`
36 // is.
37 unsafe {
38 let mut parser = Self::new(&mut network_definition);
39 parser.parse_from_file(path)?;
40 // Put parser object in `network_definition` because destroying the parser before the
41 // network definition is not allowed.
42 network_definition._parser = Some(parser);
43 }
44 Ok(network_definition)
45 }
46
47 /// Parse ONNX file.
48 ///
49 /// Note that this function is CPU-intensive. Callers should not use it in async context or
50 /// spawn a blocking task for it.
51 ///
52 /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvonnxparser_1_1_i_parser.html#a973ac2ed682f18c4c6258ed93fc8efa3)
53 ///
54 /// # Arguments
55 ///
56 /// * `path` - Path to file to parse.
57 fn parse_from_file(&mut self, path: &impl AsRef<std::path::Path>) -> Result<()> {
58 let internal = self.as_mut_ptr();
59 let path_ffi = std::ffi::CString::new(path.as_ref().as_os_str().to_str().unwrap()).unwrap();
60 let path_ptr = path_ffi.as_ptr();
61 let ret = cpp!(unsafe [
62 internal as "void*",
63 path_ptr as "const char*"
64 ] -> bool as "bool" {
65 return ((IParser*) internal)->parseFromFile(
66 path_ptr,
67 // Set to `VERBOSE` and let Rust code handle what message are passed on based on
68 // logger configuration.
69 static_cast<int>(ILogger::Severity::kVERBOSE)
70 );
71 });
72 if ret {
73 Ok(())
74 } else {
75 Err(last_error())
76 }
77 }
78
79 /// Create new parser.
80 ///
81 /// # Arguments
82 ///
83 /// * Reference to network definition to attach to parser.
84 ///
85 /// # Safety
86 ///
87 /// Caller must ensure that the [`Parser`] outlives the given [`NetworkDefinition`].
88 unsafe fn new(network_definition: &mut NetworkDefinition) -> Self {
89 let network_definition_internal = network_definition.as_ptr();
90 let internal = cpp!(unsafe [
91 network_definition_internal as "void*"
92 ] -> *mut std::ffi::c_void as "void*" {
93 return createParser(
94 *((INetworkDefinition*) network_definition_internal),
95 GLOBAL_LOGGER
96 );
97 });
98 Parser(internal)
99 }
100
101 /// Get internal readonly pointer.
102 #[inline(always)]
103 pub fn as_ptr(&self) -> *const std::ffi::c_void {
104 let Parser(internal) = *self;
105 internal
106 }
107
108 /// Get internal mutable pointer.
109 #[inline(always)]
110 pub fn as_mut_ptr(&mut self) -> *mut std::ffi::c_void {
111 let Parser(internal) = *self;
112 internal
113 }
114}
115
116impl Drop for Parser {
117 fn drop(&mut self) {
118 let internal = self.as_mut_ptr();
119 cpp!(unsafe [
120 internal as "void*"
121 ] {
122 destroy((IParser*) internal);
123 });
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    use crate::tests::onnx::*;
    use crate::{Builder, NetworkDefinitionCreationFlags};

    /// Parsing a known-good ONNX file into a fresh explicit-batch network must succeed.
    #[tokio::test]
    async fn test_parser_parses_onnx_file() {
        let onnx_file = simple_onnx_file!();
        let mut builder = Builder::new().await.unwrap();
        let network_definition =
            builder.network_definition(NetworkDefinitionCreationFlags::ExplicitBatchSize);
        let parsed =
            Parser::parse_network_definition_from_file(network_definition, &onnx_file.path());
        assert!(parsed.is_ok());
    }
}