async_tensorrt/ffi/
network.rs

1use cpp::cpp;
2
3use crate::ffi::parser::Parser;
4
/// Upper bound on tensor rank used when copying dimensions out of TensorRT
/// (presumably mirrors `Dims::MAX_DIMS` — confirm against the installed
/// TensorRT headers).
///
/// Defined in `NvInferRuntimeBase.h`
const MAX_DIMS: usize = 8;
7
/// A network definition for input to the builder.
///
/// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_network_definition.html)
pub struct NetworkDefinition {
    // Raw pointer to the underlying `INetworkDefinition`; owned by this
    // struct and released in `Drop`.
    internal: *mut std::ffi::c_void,
    // Keeps the parser that populated this network alive for as long as the
    // network exists (presumably the parser owns memory the network still
    // references — confirm against `ffi::parser`).
    pub(crate) _parser: Option<Parser>,
}
15
/// Implements [`Send`] for [`NetworkDefinition`].
///
/// # Safety
///
/// The TensorRT API is thread-safe with regards to all operations on [`NetworkDefinition`].
unsafe impl Send for NetworkDefinition {}

/// Implements [`Sync`] for [`NetworkDefinition`].
///
/// # Safety
///
/// The TensorRT API is thread-safe with regards to all operations on [`NetworkDefinition`].
unsafe impl Sync for NetworkDefinition {}
29
30impl NetworkDefinition {
31    /// Wrap internal pointer as [`NetworkDefinition`].
32    ///
33    /// # Safety
34    ///
35    /// The pointer must point to a valid `INetworkDefinition` object.
36    pub(crate) fn wrap(internal: *mut std::ffi::c_void) -> Self {
37        Self {
38            internal,
39            _parser: None,
40        }
41    }
42
43    /// Get network inputs.
44    pub fn inputs(&self) -> Vec<Tensor> {
45        let mut inputs = Vec::with_capacity(self.num_inputs());
46        for index in 0..self.num_inputs() {
47            inputs.push(self.input(index));
48        }
49        inputs
50    }
51
52    /// Get number of inputs.
53    ///
54    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_network_definition.html#a715d0ea103f1978c5b5e9173af2994a4)
55    pub fn num_inputs(&self) -> usize {
56        let internal = self.as_ptr();
57        let num_inputs = cpp!(unsafe [
58            internal as "const void*"
59        ] -> std::os::raw::c_int as "int" {
60            return ((const INetworkDefinition*) internal)->getNbInputs();
61        });
62        num_inputs as usize
63    }
64
65    /// Get network input at given index.
66    ///
67    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_network_definition.html#a3142a780be319b7f6a9e9e7f6ed12ca4)
68    ///
69    /// # Arguments
70    ///
71    /// * `index` - Input index.
72    pub fn input(&self, index: usize) -> Tensor<'_> {
73        let internal = self.as_ptr();
74        let index = index as std::os::raw::c_int;
75        let tensor_internal = cpp!(unsafe [
76            internal as "const void*",
77            index as "int"
78        ] -> *mut std::ffi::c_void as "void*" {
79            return ((const INetworkDefinition*) internal)->getInput(index);
80        });
81        Tensor::wrap(tensor_internal)
82    }
83
84    /// Get network outputs.
85    pub fn outputs(&self) -> Vec<Tensor<'_>> {
86        let mut outputs = Vec::with_capacity(self.num_outputs());
87        for index in 0..self.num_outputs() {
88            outputs.push(self.output(index));
89        }
90        outputs
91    }
92
93    /// Get number of outputs.
94    ///
95    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_network_definition.html#aef477421510ad25a342ecd950736a59a)
96    pub fn num_outputs(&self) -> usize {
97        let internal = self.as_ptr();
98        let num_outputs = cpp!(unsafe [
99            internal as "const void*"
100        ] -> std::os::raw::c_int as "int" {
101            return ((const INetworkDefinition*) internal)->getNbOutputs();
102        });
103        num_outputs as usize
104    }
105
106    /// Get network output at given index.
107    ///
108    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_network_definition.html#a2cb7b6ee73a876fc73076a559fa9e955)
109    ///
110    /// # Arguments
111    ///
112    /// * `index` - Output index.
113    pub fn output(&self, index: usize) -> Tensor<'_> {
114        let internal = self.as_ptr();
115        let index = index as std::os::raw::c_int;
116        let tensor_internal = cpp!(unsafe [
117            internal as "const void*",
118            index as "int"
119        ] -> *mut std::ffi::c_void as "void*" {
120            return ((const INetworkDefinition*) internal)->getOutput(index);
121        });
122        Tensor::wrap(tensor_internal)
123    }
124
125    /// Get internal readonly pointer.
126    #[inline(always)]
127    pub fn as_ptr(&self) -> *const std::ffi::c_void {
128        let NetworkDefinition { internal, .. } = *self;
129        internal
130    }
131
132    /// Get internal mutable pointer.
133    #[inline(always)]
134    pub fn as_mut_ptr(&mut self) -> *mut std::ffi::c_void {
135        let NetworkDefinition { internal, .. } = *self;
136        internal
137    }
138}
139
/// Releases the underlying TensorRT `INetworkDefinition` when the wrapper is
/// dropped.
impl Drop for NetworkDefinition {
    fn drop(&mut self) {
        let internal = self.as_mut_ptr();
        // NOTE(review): `destroy` is presumably a C++ helper declared in this
        // crate's shared cpp! prelude (newer TensorRT releases `delete` the
        // object, older ones call `->destroy()`) — confirm against that
        // prelude.
        cpp!(unsafe [
            internal as "void*"
        ] {
            destroy((INetworkDefinition*) internal);
        });
    }
}
150
/// Specifies immutable properties of [`NetworkDefinition`] expressed at creation time.
///
/// [TensorRT documentation of `NetworkDefinitionCreationFlags`](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/namespacenvinfer1.html#a77b643e855bcc302b30348276fa36504)
/// [TensorRT documentation of `NetworkDefinitionCreationFlag`](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/namespacenvinfer1.html#aa8f406be96c14b7dbea548cf19f09a08a85b8fdd336af67a4aa147b3430064945)
#[derive(Copy, Clone)]
pub enum NetworkDefinitionCreationFlags {
    /// No creation flags set.
    None,
    /// Network dimensions include the batch dimension explicitly (presumably
    /// maps to `NetworkDefinitionCreationFlag::kEXPLICIT_BATCH` — confirm
    /// where the flag is translated to the C++ bitmask).
    ExplicitBatchSize,
}
160
/// A tensor in a [`NetworkDefinition`].
///
/// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_tensor.html)
pub struct Tensor<'parent> {
    // Raw pointer to the underlying `ITensor`; not owned (the network owns it).
    internal: *mut std::ffi::c_void,
    // Ties this tensor's lifetime to its parent (e.g. the network definition
    // that returned it), so it cannot outlive the object that owns the
    // `ITensor`.
    _phantom: std::marker::PhantomData<&'parent ()>,
}
168
/// Implements [`Send`] for [`Tensor`].
///
/// # Safety
///
/// The TensorRT API is thread-safe with regards to all operations on [`Tensor`].
unsafe impl<'parent> Send for Tensor<'parent> {}

/// Implements [`Sync`] for [`Tensor`].
///
/// # Safety
///
/// The TensorRT API is thread-safe with regards to all operations on [`Tensor`].
unsafe impl<'parent> Sync for Tensor<'parent> {}
182
183impl<'parent> Tensor<'parent> {
184    /// Wrap internal pointer as [`Tensor`].
185    ///
186    /// # Safety
187    ///
188    /// The pointer must point to a valid `ITensor` object.
189    #[inline]
190    pub(crate) fn wrap(internal: *mut std::ffi::c_void) -> Self {
191        Self {
192            internal,
193            _phantom: Default::default(),
194        }
195    }
196
197    /// Get the tensor name.
198    ///
199    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_tensor.html#a684fd842a172ad300dbb31270fc675a2)
200    pub fn name(&self) -> String {
201        let internal = self.as_ptr();
202        let name = cpp!(unsafe [
203            internal as "const void*"
204        ] -> *const std::os::raw::c_char as "const char*" {
205            return ((const ITensor*) internal)->getName();
206        });
207        // SAFETY: This is safe because:
208        // * The pointer is valid because we just got it from TensorRT.
209        // * The pointer isn't kept after this block (we copy the string instead).
210        unsafe { std::ffi::CStr::from_ptr(name).to_string_lossy().to_string() }
211    }
212
213    /// Set the tensor name.
214    ///
215    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_tensor.html#a44ffc55db1d6e68908859596c4e4ef49)
216    ///
217    /// # Arguments
218    ///
219    /// * `name` - Name to set.
220    pub fn set_name(&mut self, name: &str) {
221        let internal = self.as_mut_ptr();
222        let name_ffi = std::ffi::CString::new(name).unwrap();
223        let name_ptr = name_ffi.as_ptr();
224        cpp!(unsafe [
225            internal as "void*",
226            name_ptr as "const char*"
227        ] {
228            return ((ITensor*) internal)->setName(name_ptr);
229        });
230    }
231
232    /// Get the dimensions of a tensor.
233    ///
234    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_tensor.html#aefa740255768fbe234730577cb24fac9)
235    pub fn get_dimensions(&self) -> Vec<i32> {
236        let internal = self.as_ptr();
237        let mut dims = Vec::with_capacity(MAX_DIMS);
238        let dims_ptr = dims.as_mut_ptr();
239
240        let num_dimensions = cpp!(unsafe [
241            internal as "void*",
242            dims_ptr as "int32_t*"
243        ] -> i32 as "int32_t" {
244            auto dims = ((const ITensor*) internal)->getDimensions();
245            if (dims.nbDims > 0) {
246                for (int i = 0; i < dims.nbDims; ++i) {
247                    dims_ptr[i] = dims.d[i];
248                }
249            }
250            return dims.nbDims;
251        });
252        if num_dimensions > 0 {
253            // Safety: The vec has been initialized up until num_dimensions elements
254            unsafe {
255                dims.set_len(num_dimensions as usize);
256            }
257        }
258        dims
259    }
260
261    /// Get internal readonly pointer.
262    #[inline(always)]
263    pub fn as_ptr(&self) -> *const std::ffi::c_void {
264        let Tensor { internal, .. } = *self;
265        internal
266    }
267
268    /// Get internal mutable pointer.
269    #[inline(always)]
270    pub fn as_mut_ptr(&mut self) -> *mut std::ffi::c_void {
271        let Tensor { internal, .. } = *self;
272        internal
273    }
274}
275
#[cfg(test)]
mod tests {
    use crate::tests::utils::*;

    /// The simple network fixture exposes exactly one input ("X") and one
    /// output ("Y").
    #[tokio::test]
    async fn test_network_inputs_and_outputs() {
        let (_, network) = simple_network!();
        assert_eq!(network.num_inputs(), 1);
        assert_eq!(network.num_outputs(), 1);
        assert_eq!(network.inputs()[0].name(), "X");
        assert_eq!(network.outputs()[0].name(), "Y");
    }

    /// Renaming an output tensor is visible on a freshly fetched handle.
    #[tokio::test]
    async fn test_tensor_set_name() {
        let (_, network) = simple_network!();
        let mut outputs = network.outputs();
        outputs[0].set_name("Z");
        assert_eq!(network.outputs()[0].name(), "Z");
    }
}
299}