async_tensorrt/ffi/sync/
engine.rs

use cpp::cpp;

use async_cuda::device::DeviceId;
use async_cuda::ffi::device::Device;

use crate::error::last_error;
use crate::ffi::memory::HostBuffer;
use crate::ffi::result;
use crate::ffi::sync::runtime::Runtime;

type Result<T> = std::result::Result<T, crate::error::Error>;

/// Synchronous implementation of [`crate::Engine`].
///
/// Refer to [`crate::Engine`] for documentation.
pub struct Engine {
    internal: *mut std::ffi::c_void,
    runtime: Runtime,
}

/// Implements [`Send`] for [`Engine`].
///
/// # Safety
///
/// The TensorRT API is thread-safe with regard to all operations on [`Engine`].
unsafe impl Send for Engine {}

/// Implements [`Sync`] for [`Engine`].
///
/// # Safety
///
/// The TensorRT API is thread-safe with regard to all operations on [`Engine`].
unsafe impl Sync for Engine {}

impl Engine {
    #[inline]
    pub(crate) fn wrap(internal: *mut std::ffi::c_void, runtime: Runtime) -> Self {
        Engine { internal, runtime }
    }

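    /// Serialize the engine into a [`HostBuffer`] containing the TensorRT plan.
    ///
    /// A minimal usage sketch (assuming `engine` is a valid [`Engine`];
    /// persisting the plan bytes is left to [`HostBuffer`]'s own API):
    ///
    /// ```ignore
    /// let plan: HostBuffer = engine.serialize()?;
    /// ```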
    pub fn serialize(&self) -> Result<HostBuffer> {
        let internal = self.as_ptr();
        let internal_buffer = cpp!(unsafe [
            internal as "const void*"
        ] -> *mut std::ffi::c_void as "void*" {
            return (void*) ((const ICudaEngine*) internal)->serialize();
        });
        result!(internal_buffer, HostBuffer::wrap(internal_buffer))
    }

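    /// Number of IO tensors (both inputs and outputs) in the engine.
    ///
    /// Together with [`Engine::io_tensor_name`], [`Engine::tensor_shape`], and
    /// [`Engine::tensor_io_mode`], this enables the usual introspection loop.
    /// A sketch, assuming `engine` is a valid [`Engine`]:
    ///
    /// ```ignore
    /// for index in 0..engine.num_io_tensors() {
    ///     let name = engine.io_tensor_name(index);
    ///     let shape = engine.tensor_shape(&name);
    ///     let mode = engine.tensor_io_mode(&name);
    ///     println!("{name}: {shape:?} ({mode:?})");
    /// }
    /// ```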
    pub fn num_io_tensors(&self) -> usize {
        let internal = self.as_ptr();
        let num_io_tensors = cpp!(unsafe [
            internal as "const void*"
        ] -> std::os::raw::c_int as "int" {
            return ((const ICudaEngine*) internal)->getNbIOTensors();
        });
        num_io_tensors as usize
    }

    pub fn io_tensor_name(&self, io_tensor_index: usize) -> String {
        let internal = self.as_ptr();
        let io_tensor_index = io_tensor_index as std::os::raw::c_int;
        let io_tensor_name_ptr = cpp!(unsafe [
            internal as "const void*",
            io_tensor_index as "int"
        ] -> *const std::os::raw::c_char as "const char*" {
            return ((const ICudaEngine*) internal)->getIOTensorName(io_tensor_index);
        });

        // SAFETY: This is safe because:
        // * The pointer is valid because we just got it from TensorRT.
        // * The pointer isn't kept after this block (we copy the string instead).
        unsafe {
            std::ffi::CStr::from_ptr(io_tensor_name_ptr)
                .to_string_lossy()
                .to_string()
        }
    }

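    /// Shape of the tensor with the given name, as reported by TensorRT.
    ///
    /// Note: TensorRT reports dynamic dimensions as `-1`; since this wrapper
    /// casts each dimension to `usize`, it assumes the engine was built with
    /// static shapes.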
    pub fn tensor_shape(&self, tensor_name: &str) -> Vec<usize> {
        let internal = self.as_ptr();
        let tensor_name_cstr = std::ffi::CString::new(tensor_name).unwrap();
        let tensor_name_ptr = tensor_name_cstr.as_ptr();
        let tensor_dimensions = cpp!(unsafe [
            internal as "const void*",
            tensor_name_ptr as "const char*"
        ] -> Dims as "Dims64" {
            #if NV_TENSORRT_MAJOR >= 10
            return ((const ICudaEngine*) internal)->getTensorShape(tensor_name_ptr);
            #else
            Dims32 dims32 = ((const ICudaEngine*) internal)->getTensorShape(tensor_name_ptr);
            Dims64 dims64;
            dims64.nbDims = dims32.nbDims;
            for (int i = 0; i < dims32.nbDims; i++) {
                dims64.d[i] = dims32.d[i];
            }
            return dims64;
            #endif
        });

        let mut dimensions = Vec::with_capacity(tensor_dimensions.nbDims as usize);
        for i in 0..tensor_dimensions.nbDims {
            dimensions.push(tensor_dimensions.d[i as usize] as usize);
        }

        dimensions
    }

    pub fn tensor_io_mode(&self, tensor_name: &str) -> TensorIoMode {
        let internal = self.as_ptr();
        let tensor_name_cstr = std::ffi::CString::new(tensor_name).unwrap();
        let tensor_name_ptr = tensor_name_cstr.as_ptr();
        let tensor_io_mode = cpp!(unsafe [
            internal as "const void*",
            tensor_name_ptr as "const char*"
        ] -> i32 as "std::int32_t" {
            return (std::int32_t) ((const ICudaEngine*) internal)->getTensorIOMode(tensor_name_ptr);
        });
        TensorIoMode::from_i32(tensor_io_mode)
    }

    #[inline(always)]
    pub fn as_ptr(&self) -> *const std::ffi::c_void {
        let Engine { internal, .. } = *self;
        internal
    }

    #[inline(always)]
    pub fn as_mut_ptr(&mut self) -> *mut std::ffi::c_void {
        let Engine { internal, .. } = *self;
        internal
    }

    #[inline(always)]
    pub fn device(&self) -> DeviceId {
        self.runtime.device()
    }
}

impl Drop for Engine {
    fn drop(&mut self) {
        Device::set_or_panic(self.runtime.device());
        let Engine { internal, .. } = *self;
        cpp!(unsafe [
            internal as "void*"
        ] {
            destroy((ICudaEngine*) internal);
        });
    }
}

/// Synchronous implementation of [`crate::ExecutionContext`].
///
/// Refer to [`crate::ExecutionContext`] for documentation.
pub struct ExecutionContext<'engine> {
    internal: *mut std::ffi::c_void,
    device: DeviceId,
    _parent: Option<std::sync::Arc<Engine>>,
    _phantom: std::marker::PhantomData<&'engine ()>,
}

/// Implements [`Send`] for [`ExecutionContext`].
///
/// # Safety
///
/// The TensorRT API is thread-safe with regard to all operations on [`ExecutionContext`].
unsafe impl<'engine> Send for ExecutionContext<'engine> {}

/// Implements [`Sync`] for [`ExecutionContext`].
///
/// # Safety
///
/// The TensorRT API is thread-safe with regard to all operations on [`ExecutionContext`].
unsafe impl<'engine> Sync for ExecutionContext<'engine> {}

impl ExecutionContext<'static> {
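    /// Create an execution context from an owned [`Engine`], yielding a
    /// context with the `'static` lifetime. The engine is kept alive behind an
    /// [`std::sync::Arc`] for as long as the context exists.
    ///
    /// A sketch (assuming `engine` was obtained elsewhere, e.g. by
    /// deserializing a plan through this crate's [`Runtime`]):
    ///
    /// ```ignore
    /// let context = ExecutionContext::from_engine(engine)?;
    /// ```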
    pub fn from_engine(mut engine: Engine) -> Result<Self> {
        let internal = unsafe { Self::new_internal(&mut engine) };
        result!(
            internal,
            Self {
                internal,
                device: engine.device(),
                _parent: Some(std::sync::Arc::new(engine)),
                _phantom: Default::default(),
            }
        )
    }

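    /// Create `num` execution contexts that share ownership of a single
    /// [`Engine`] through an [`std::sync::Arc`].
    ///
    /// A sketch (assuming a valid `engine`; the count is illustrative):
    ///
    /// ```ignore
    /// let contexts = ExecutionContext::from_engine_many(engine, 4)?;
    /// ```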
    pub fn from_engine_many(mut engine: Engine, num: usize) -> Result<Vec<Self>> {
        let mut internals = Vec::with_capacity(num);
        for _ in 0..num {
            internals.push(unsafe { Self::new_internal(&mut engine) });
        }
        let device = engine.device();
        let parent = std::sync::Arc::new(engine);
        internals
            .into_iter()
            .map(|internal| {
                result!(
                    internal,
                    Self {
                        internal,
                        device,
                        _parent: Some(parent.clone()),
                        _phantom: Default::default(),
                    }
                )
            })
            .collect()
    }
}

impl<'engine> ExecutionContext<'engine> {
    pub fn new(engine: &'engine mut Engine) -> Result<Self> {
        let internal = unsafe { Self::new_internal(engine) };
        result!(
            internal,
            Self {
                internal,
                device: engine.device(),
                _parent: None,
                _phantom: Default::default(),
            }
        )
    }

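    /// Bind each named tensor to its device buffer, then enqueue inference on
    /// the given stream.
    ///
    /// A usage sketch; `input_buffer`, `output_buffer`, and `stream` are
    /// assumed to exist already, and the tensor names are illustrative:
    ///
    /// ```ignore
    /// let mut io_tensors = std::collections::HashMap::new();
    /// io_tensors.insert("input", &mut input_buffer);
    /// io_tensors.insert("output", &mut output_buffer);
    /// context.enqueue(&mut io_tensors, &stream)?;
    /// ```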
    pub fn enqueue<T: Copy>(
        &mut self,
        io_tensors: &mut std::collections::HashMap<
            &str,
            &mut async_cuda::ffi::memory::DeviceBuffer<T>,
        >,
        stream: &async_cuda::ffi::stream::Stream,
    ) -> Result<()> {
        let internal = self.as_mut_ptr();
        for (tensor_name, buffer) in io_tensors {
            unsafe {
                self.set_tensor_address(tensor_name, buffer)?;
            }
        }
        let stream_ptr = stream.as_internal().as_ptr();
        let success = cpp!(unsafe [
            internal as "void*",
            stream_ptr as "const void*"
        ] -> bool as "bool" {
            return ((IExecutionContext*) internal)->enqueueV3((cudaStream_t) stream_ptr);
        });
        if success {
            Ok(())
        } else {
            Err(last_error())
        }
    }

    #[inline(always)]
    pub fn as_ptr(&self) -> *const std::ffi::c_void {
        let ExecutionContext { internal, .. } = *self;
        internal
    }

    #[inline(always)]
    pub fn as_mut_ptr(&mut self) -> *mut std::ffi::c_void {
        let ExecutionContext { internal, .. } = *self;
        internal
    }

    #[inline(always)]
    pub fn device(&self) -> DeviceId {
        self.device
    }

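    /// # Safety
    ///
    /// The caller must ensure `engine` wraps a valid `ICudaEngine`. The
    /// returned pointer may be null if TensorRT fails to create the context;
    /// callers are expected to check it (as the constructors above do via
    /// `result!`).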
    unsafe fn new_internal(engine: &mut Engine) -> *mut std::ffi::c_void {
        Device::set_or_panic(engine.device());
        let internal_engine = engine.as_mut_ptr();
        let internal = cpp!(unsafe [
            internal_engine as "void*"
        ] -> *mut std::ffi::c_void as "void*" {
            return (void*) ((ICudaEngine*) internal_engine)->createExecutionContext();
        });
        internal
    }

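    /// # Safety
    ///
    /// TensorRT stores the raw buffer address, so the caller must ensure
    /// `buffer` stays alive until all enqueued work that reads or writes it
    /// has completed.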
    unsafe fn set_tensor_address<T: Copy>(
        &mut self,
        tensor_name: &str,
        buffer: &mut async_cuda::ffi::memory::DeviceBuffer<T>,
    ) -> Result<()> {
        let internal = self.as_mut_ptr();
        let tensor_name_cstr = std::ffi::CString::new(tensor_name).unwrap();
        let tensor_name_ptr = tensor_name_cstr.as_ptr();
        let buffer_ptr = buffer.as_mut_internal().as_mut_ptr();
        let success = cpp!(unsafe [
            internal as "void*",
            tensor_name_ptr as "const char*",
            buffer_ptr as "void*"
        ] -> bool as "bool" {
            return ((IExecutionContext*) internal)->setTensorAddress(
                tensor_name_ptr,
                buffer_ptr
            );
        });
        if success {
            Ok(())
        } else {
            Err(last_error())
        }
    }
}

impl<'engine> Drop for ExecutionContext<'engine> {
    fn drop(&mut self) {
        Device::set_or_panic(self.device);
        let ExecutionContext { internal, .. } = *self;
        cpp!(unsafe [
            internal as "void*"
        ] {
            destroy((IExecutionContext*) internal);
        });
    }
}

/// Tensor IO mode.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum TensorIoMode {
    None,
    Input,
    Output,
}

impl TensorIoMode {
    /// Create [`TensorIoMode`] from `value`.
    ///
    /// # Arguments
    ///
    /// * `value` - Integer representation of the IO mode.
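    ///
    /// The values follow TensorRT's `TensorIOMode` enumeration: `1` marks an
    /// input tensor and `2` an output tensor; any other value maps to
    /// [`TensorIoMode::None`].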
    fn from_i32(value: i32) -> Self {
        match value {
            1 => TensorIoMode::Input,
            2 => TensorIoMode::Output,
            _ => TensorIoMode::None,
        }
    }
}

/// Internal representation of the `Dims64` struct in TensorRT.
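/// `nbDims` holds the number of valid entries in `d`; the fixed array length
/// of 8 matches TensorRT's `Dims::MAX_DIMS`.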
#[repr(C)]
#[derive(Debug, Copy, Clone)]
#[allow(non_snake_case)]
struct Dims {
    pub nbDims: i32,
    pub d: [i64; 8usize],
}