async_tensorrt/ffi/sync/runtime.rs

1use cpp::cpp;
2
3use async_cuda::device::DeviceId;
4use async_cuda::ffi::device::Device;
5
6use crate::ffi::memory::HostBuffer;
7use crate::ffi::result;
8use crate::ffi::sync::engine::Engine;
9
/// Crate-local result alias using [`crate::error::Error`] as the error type.
type Result<T> = std::result::Result<T, crate::error::Error>;
11
/// Synchronous implementation of [`crate::Runtime`].
///
/// Refer to [`crate::Runtime`] for documentation.
pub struct Runtime {
    // Opaque pointer to the underlying TensorRT `IRuntime` object.
    // Owned by this struct and released in the `Drop` impl below.
    addr: *mut std::ffi::c_void,
    // CUDA device the runtime was created on; made current again before
    // device-sensitive FFI calls (deserialization, destruction).
    device: DeviceId,
}
19
/// Implements [`Send`] for [`Runtime`].
///
/// # Safety
///
/// The TensorRT API is thread-safe with regards to all operations on [`Runtime`].
/// The raw `addr` pointer is only a handle into TensorRT-managed memory, so
/// moving the wrapper between threads is sound under that guarantee.
unsafe impl Send for Runtime {}
26
/// Implements [`Sync`] for [`Runtime`].
///
/// # Safety
///
/// The TensorRT API is thread-safe with regards to all operations on [`Runtime`],
/// so shared references may be used concurrently from multiple threads.
unsafe impl Sync for Runtime {}
33
impl Runtime {
    /// Create a new [`Runtime`] bound to the currently active CUDA device.
    ///
    /// Calls TensorRT's `createInferRuntime` with the global logger on the C++
    /// side and stores the returned handle together with the current device id.
    ///
    /// # Panics
    ///
    /// Panics if the current CUDA device cannot be determined
    /// (via [`Device::get_or_panic`]).
    ///
    /// NOTE(review): the pointer returned by `createInferRuntime` is not
    /// checked for null here; a failed creation would only surface on first
    /// use — confirm this is the intended contract.
    pub fn new() -> Self {
        let device = Device::get_or_panic();
        let addr = cpp!(unsafe [] -> *mut std::ffi::c_void as "void*" {
            return createInferRuntime(GLOBAL_LOGGER);
        });
        Runtime { addr, device }
    }

    /// Deserialize an [`Engine`] from a serialized plan held in a [`HostBuffer`].
    ///
    /// Consumes the runtime: on success, ownership of `self` moves into the
    /// returned [`Engine`] (see [`Engine::wrap`] in `deserialize_engine_raw`).
    ///
    /// # Arguments
    ///
    /// * `plan` - Host buffer containing the serialized engine plan.
    pub fn deserialize_engine_from_plan(self, plan: &HostBuffer) -> Result<Engine> {
        unsafe {
            // SAFETY: Since we have a reference to the buffer for the duration of this call, we
            // know the internal pointers will be and remain valid until the end of the block.
            self.deserialize_engine_raw(plan.data(), plan.size())
        }
    }

    /// Deserialize an [`Engine`] from an in-memory byte slice.
    ///
    /// Consumes the runtime: on success, ownership of `self` moves into the
    /// returned [`Engine`].
    ///
    /// # Arguments
    ///
    /// * `buffer` - Byte slice containing the serialized engine plan.
    pub fn deserialize_engine(self, buffer: &[u8]) -> Result<Engine> {
        unsafe {
            // SAFETY: Since we have a reference to the slice for the duration of this call, we
            // know the internal pointers will be and remain valid until the end of the block.
            self.deserialize_engine_raw(buffer.as_ptr() as *const std::ffi::c_void, buffer.len())
        }
    }

    /// Deserialize an engine from a buffer.
    ///
    /// Makes the runtime's device current, then calls
    /// `IRuntime::deserializeCudaEngine` over FFI. On success the runtime is
    /// consumed and wrapped into the returned [`Engine`]; on a null result the
    /// `result!` macro produces the error variant instead.
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_runtime.html#ad0dc765e77cab99bfad901e47216a767)
    ///
    /// # Safety
    ///
    /// Both provided pointers must be valid pointers.
    ///
    /// # Arguments
    ///
    /// * `buffer_ptr` - Pointer to buffer to read from.
    /// * `buffer_size` - Size of buffer to read from.
    unsafe fn deserialize_engine_raw(
        mut self,
        buffer_ptr: *const std::ffi::c_void,
        buffer_size: usize,
    ) -> Result<Engine> {
        // Deserialization must happen with the runtime's device active.
        Device::set(self.device)?;
        let internal = self.as_mut_ptr();
        let internal_engine = cpp!(unsafe [
            internal as "void*",
            buffer_ptr as "const void*",
            buffer_size as "std::size_t"
        ] -> *mut std::ffi::c_void as "void*" {
            return ((IRuntime*) internal)->deserializeCudaEngine(buffer_ptr, buffer_size);
        });
        result!(internal_engine, Engine::wrap(internal_engine, self))
    }

    /// Const pointer to the underlying TensorRT `IRuntime` object.
    #[inline(always)]
    pub fn as_ptr(&self) -> *const std::ffi::c_void {
        self.addr
    }

    /// Mutable pointer to the underlying TensorRT `IRuntime` object.
    #[inline(always)]
    pub fn as_mut_ptr(&mut self) -> *mut std::ffi::c_void {
        self.addr
    }

    /// The CUDA device this runtime was created on.
    #[inline(always)]
    pub fn device(&self) -> DeviceId {
        self.device
    }
}
103
/// Releases the underlying TensorRT `IRuntime` when the wrapper is dropped.
impl Drop for Runtime {
    fn drop(&mut self) {
        // Make the runtime's device current before destroying the TensorRT
        // object; panics (rather than silently destroying on the wrong
        // device) if the device cannot be set.
        Device::set_or_panic(self.device);
        let internal = self.as_mut_ptr();
        cpp!(unsafe [
            internal as "void*"
        ] {
            destroy((IRuntime*) internal);
        });
    }
}
115
116impl Default for Runtime {
117    fn default() -> Self {
118        Runtime::new()
119    }
120}