async_tensorrt/ffi/
parser.rs

1use cpp::cpp;
2
3use crate::error::last_error;
4use crate::ffi::network::NetworkDefinition;
5
6type Result<T> = std::result::Result<T, crate::error::Error>;
7
/// ONNX parser that fills a TensorRT network definition ([`crate::NetworkDefinition`]) from an
/// ONNX model.
///
/// The wrapped value is an opaque pointer to the underlying C++ `IParser` object. It is handed
/// to `destroy` when the [`Parser`] is dropped.
///
/// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvonnxparser_1_1_i_parser.html)
pub struct Parser(*mut std::ffi::c_void);
12
13impl Parser {
14    /// Create new parser, parse ONNX file and return a [`crate::NetworkDefinition`].
15    ///
16    /// Note that this function is CPU-intensive. Callers should not use it in async context or
17    /// spawn a blocking task for it.
18    ///
19    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvonnxparser_1_1_i_parser.html#a973ac2ed682f18c4c6258ed93fc8efa3)
20    ///
21    /// # Arguments
22    ///
23    /// * `network_definition` - Network definition to use.
24    /// * `path` - Path to file to parse.
25    ///
26    /// # Return value
27    ///
28    /// Parsed network definition.
29    pub fn parse_network_definition_from_file(
30        mut network_definition: NetworkDefinition,
31        path: &impl AsRef<std::path::Path>,
32    ) -> Result<NetworkDefinition> {
33        // SAFETY: The call to `Parser::new` is unsafe because we must ensure that the new parser
34        // outlives `network_definition`. We manually make sure of that here by putting the parser
35        // inside `NetworkDefinition` and such it will only be destroyed when `network_definition`
36        // is.
37        unsafe {
38            let mut parser = Self::new(&mut network_definition);
39            parser.parse_from_file(path)?;
40            // Put parser object in `network_definition` because destroying the parser before the
41            // network definition is not allowed.
42            network_definition._parser = Some(parser);
43        }
44        Ok(network_definition)
45    }
46
47    /// Parse ONNX file.
48    ///
49    /// Note that this function is CPU-intensive. Callers should not use it in async context or
50    /// spawn a blocking task for it.
51    ///
52    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvonnxparser_1_1_i_parser.html#a973ac2ed682f18c4c6258ed93fc8efa3)
53    ///
54    /// # Arguments
55    ///
56    /// * `path` - Path to file to parse.
57    fn parse_from_file(&mut self, path: &impl AsRef<std::path::Path>) -> Result<()> {
58        let internal = self.as_mut_ptr();
59        let path_ffi = std::ffi::CString::new(path.as_ref().as_os_str().to_str().unwrap()).unwrap();
60        let path_ptr = path_ffi.as_ptr();
61        let ret = cpp!(unsafe [
62            internal as "void*",
63            path_ptr as "const char*"
64        ] -> bool as "bool" {
65            return ((IParser*) internal)->parseFromFile(
66                path_ptr,
67                // Set to `VERBOSE` and let Rust code handle what message are passed on based on
68                // logger configuration.
69                static_cast<int>(ILogger::Severity::kVERBOSE)
70            );
71        });
72        if ret {
73            Ok(())
74        } else {
75            Err(last_error())
76        }
77    }
78
79    /// Create new parser.
80    ///
81    /// # Arguments
82    ///
83    /// * Reference to network definition to attach to parser.
84    ///
85    /// # Safety
86    ///
87    /// Caller must ensure that the [`Parser`] outlives the given [`NetworkDefinition`].
88    unsafe fn new(network_definition: &mut NetworkDefinition) -> Self {
89        let network_definition_internal = network_definition.as_ptr();
90        let internal = cpp!(unsafe [
91            network_definition_internal as "void*"
92        ] -> *mut std::ffi::c_void as "void*" {
93            return createParser(
94                *((INetworkDefinition*) network_definition_internal),
95                GLOBAL_LOGGER
96            );
97        });
98        Parser(internal)
99    }
100
101    /// Get internal readonly pointer.
102    #[inline(always)]
103    pub fn as_ptr(&self) -> *const std::ffi::c_void {
104        let Parser(internal) = *self;
105        internal
106    }
107
108    /// Get internal mutable pointer.
109    #[inline(always)]
110    pub fn as_mut_ptr(&mut self) -> *mut std::ffi::c_void {
111        let Parser(internal) = *self;
112        internal
113    }
114}
115
116impl Drop for Parser {
117    fn drop(&mut self) {
118        let internal = self.as_mut_ptr();
119        cpp!(unsafe [
120            internal as "void*"
121        ] {
122            destroy((IParser*) internal);
123        });
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    use crate::tests::onnx::*;
    use crate::{Builder, NetworkDefinitionCreationFlags};

    /// Parsing the file provided by `simple_onnx_file!` into a fresh network definition must
    /// succeed.
    #[tokio::test]
    async fn test_parser_parses_onnx_file() {
        let onnx_file = simple_onnx_file!();
        let mut builder = Builder::new().await.unwrap();
        let network_definition =
            builder.network_definition(NetworkDefinitionCreationFlags::ExplicitBatchSize);
        let parsed =
            Parser::parse_network_definition_from_file(network_definition, &onnx_file.path());
        assert!(parsed.is_ok());
    }
}