async_tensorrt/ffi/
network.rs1use cpp::cpp;
2
3use crate::ffi::parser::Parser;
4
/// Upper bound on the number of dimensions copied out of a tensor shape; sizes the scratch
/// buffer used by `Tensor::get_dimensions`.
///
/// NOTE(review): presumably mirrors TensorRT's `nvinfer1::Dims::MAX_DIMS` (8) — confirm against
/// the TensorRT version this crate links.
const MAX_DIMS: usize = 8;
7
8pub struct NetworkDefinition {
12 internal: *mut std::ffi::c_void,
13 pub(crate) _parser: Option<Parser>,
14}
15
16unsafe impl Send for NetworkDefinition {}
22
23unsafe impl Sync for NetworkDefinition {}
29
30impl NetworkDefinition {
31 pub(crate) fn wrap(internal: *mut std::ffi::c_void) -> Self {
37 Self {
38 internal,
39 _parser: None,
40 }
41 }
42
43 pub fn inputs(&self) -> Vec<Tensor> {
45 let mut inputs = Vec::with_capacity(self.num_inputs());
46 for index in 0..self.num_inputs() {
47 inputs.push(self.input(index));
48 }
49 inputs
50 }
51
52 pub fn num_inputs(&self) -> usize {
56 let internal = self.as_ptr();
57 let num_inputs = cpp!(unsafe [
58 internal as "const void*"
59 ] -> std::os::raw::c_int as "int" {
60 return ((const INetworkDefinition*) internal)->getNbInputs();
61 });
62 num_inputs as usize
63 }
64
65 pub fn input(&self, index: usize) -> Tensor<'_> {
73 let internal = self.as_ptr();
74 let index = index as std::os::raw::c_int;
75 let tensor_internal = cpp!(unsafe [
76 internal as "const void*",
77 index as "int"
78 ] -> *mut std::ffi::c_void as "void*" {
79 return ((const INetworkDefinition*) internal)->getInput(index);
80 });
81 Tensor::wrap(tensor_internal)
82 }
83
84 pub fn outputs(&self) -> Vec<Tensor<'_>> {
86 let mut outputs = Vec::with_capacity(self.num_outputs());
87 for index in 0..self.num_outputs() {
88 outputs.push(self.output(index));
89 }
90 outputs
91 }
92
93 pub fn num_outputs(&self) -> usize {
97 let internal = self.as_ptr();
98 let num_outputs = cpp!(unsafe [
99 internal as "const void*"
100 ] -> std::os::raw::c_int as "int" {
101 return ((const INetworkDefinition*) internal)->getNbOutputs();
102 });
103 num_outputs as usize
104 }
105
106 pub fn output(&self, index: usize) -> Tensor<'_> {
114 let internal = self.as_ptr();
115 let index = index as std::os::raw::c_int;
116 let tensor_internal = cpp!(unsafe [
117 internal as "const void*",
118 index as "int"
119 ] -> *mut std::ffi::c_void as "void*" {
120 return ((const INetworkDefinition*) internal)->getOutput(index);
121 });
122 Tensor::wrap(tensor_internal)
123 }
124
125 #[inline(always)]
127 pub fn as_ptr(&self) -> *const std::ffi::c_void {
128 let NetworkDefinition { internal, .. } = *self;
129 internal
130 }
131
132 #[inline(always)]
134 pub fn as_mut_ptr(&mut self) -> *mut std::ffi::c_void {
135 let NetworkDefinition { internal, .. } = *self;
136 internal
137 }
138}
139
140impl Drop for NetworkDefinition {
141 fn drop(&mut self) {
142 let internal = self.as_mut_ptr();
143 cpp!(unsafe [
144 internal as "void*"
145 ] {
146 destroy((INetworkDefinition*) internal);
147 });
148 }
149}
150
/// Flags selecting how a network definition is created.
///
/// NOTE(review): the conversion of these variants into TensorRT creation flags happens outside
/// this file — confirm the exact mapping there.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum NetworkDefinitionCreationFlags {
    /// Presumably: no creation flags set — confirm at the conversion site.
    None,
    /// Presumably requests an explicit-batch network (TensorRT `kEXPLICIT_BATCH`) — confirm at
    /// the conversion site.
    ExplicitBatchSize,
}
160
/// Non-owning handle to a TensorRT `ITensor`.
///
/// The `'parent` lifetime ties the handle to the object it was obtained from (for example, a
/// borrowed network definition), so the handle cannot outlive it.
pub struct Tensor<'parent> {
    /// Raw `ITensor*`; nothing in this type frees it.
    internal: *mut std::ffi::c_void,
    /// Carries the `'parent` borrow without storing a reference.
    _phantom: std::marker::PhantomData<&'parent ()>,
}

// SAFETY: NOTE(review): assumes the underlying `ITensor` may be used from any thread — confirm
// against the TensorRT threading documentation.
unsafe impl Send for Tensor<'_> {}

// SAFETY: NOTE(review): several `Tensor` handles can point at the same `ITensor`, and `set_name`
// mutates it, so concurrent use from different threads is not synchronized — confirm this is
// acceptable.
unsafe impl Sync for Tensor<'_> {}
182
183impl<'parent> Tensor<'parent> {
184 #[inline]
190 pub(crate) fn wrap(internal: *mut std::ffi::c_void) -> Self {
191 Self {
192 internal,
193 _phantom: Default::default(),
194 }
195 }
196
197 pub fn name(&self) -> String {
201 let internal = self.as_ptr();
202 let name = cpp!(unsafe [
203 internal as "const void*"
204 ] -> *const std::os::raw::c_char as "const char*" {
205 return ((const ITensor*) internal)->getName();
206 });
207 unsafe { std::ffi::CStr::from_ptr(name).to_string_lossy().to_string() }
211 }
212
213 pub fn set_name(&mut self, name: &str) {
221 let internal = self.as_mut_ptr();
222 let name_ffi = std::ffi::CString::new(name).unwrap();
223 let name_ptr = name_ffi.as_ptr();
224 cpp!(unsafe [
225 internal as "void*",
226 name_ptr as "const char*"
227 ] {
228 return ((ITensor*) internal)->setName(name_ptr);
229 });
230 }
231
232 pub fn get_dimensions(&self) -> Vec<i32> {
236 let internal = self.as_ptr();
237 let mut dims = Vec::with_capacity(MAX_DIMS);
238 let dims_ptr = dims.as_mut_ptr();
239
240 let num_dimensions = cpp!(unsafe [
241 internal as "void*",
242 dims_ptr as "int32_t*"
243 ] -> i32 as "int32_t" {
244 auto dims = ((const ITensor*) internal)->getDimensions();
245 if (dims.nbDims > 0) {
246 for (int i = 0; i < dims.nbDims; ++i) {
247 dims_ptr[i] = dims.d[i];
248 }
249 }
250 return dims.nbDims;
251 });
252 if num_dimensions > 0 {
253 unsafe {
255 dims.set_len(num_dimensions as usize);
256 }
257 }
258 dims
259 }
260
261 #[inline(always)]
263 pub fn as_ptr(&self) -> *const std::ffi::c_void {
264 let Tensor { internal, .. } = *self;
265 internal
266 }
267
268 #[inline(always)]
270 pub fn as_mut_ptr(&mut self) -> *mut std::ffi::c_void {
271 let Tensor { internal, .. } = *self;
272 internal
273 }
274}
275
#[cfg(test)]
mod tests {
    use crate::tests::utils::*;

    /// The sample network exposes exactly one input named "X" and one output named "Y".
    #[tokio::test]
    async fn test_network_inputs_and_outputs() {
        let (_, network) = simple_network!();

        assert_eq!(network.num_inputs(), 1);
        assert_eq!(network.num_outputs(), 1);

        assert_eq!(network.inputs().first().unwrap().name(), "X");
        assert_eq!(network.outputs().first().unwrap().name(), "Y");
    }

    /// Renaming an output through one handle is visible through a freshly fetched handle.
    #[tokio::test]
    async fn test_tensor_set_name() {
        let (_, network) = simple_network!();

        let mut outputs = network.outputs();
        outputs[0].set_name("Z");

        let renamed = network.outputs();
        assert_eq!(renamed[0].name(), "Z");
    }
}