use cpp::cpp;

use async_cuda::device::DeviceId;
use async_cuda::ffi::device::Device;

use crate::ffi::memory::HostBuffer;
use crate::ffi::result;
use crate::ffi::sync::engine::Engine;

type Result<T> = std::result::Result<T, crate::error::Error>;

/// Synchronous implementation of [`crate::Runtime`].
///
/// Refer to [`crate::Runtime`] for documentation.
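///
/// # Usage
///
/// A minimal sketch of typical use (illustrative only; `plan_bytes` is a placeholder for a
/// serialized engine plan, e.g. loaded from disk):
///
/// ```ignore
/// let runtime = Runtime::new();
/// let engine = runtime.deserialize_engine(&plan_bytes)?;
/// ```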
pub struct Runtime {
    addr: *mut std::ffi::c_void,
    device: DeviceId,
}

/// Implements [`Send`] for [`Runtime`].
///
/// # Safety
///
/// The TensorRT API is thread-safe with regard to all operations on [`Runtime`].
unsafe impl Send for Runtime {}

/// Implements [`Sync`] for [`Runtime`].
///
/// # Safety
///
/// The TensorRT API is thread-safe with regard to all operations on [`Runtime`].
unsafe impl Sync for Runtime {}

impl Runtime {
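    /// Create a new [`Runtime`] bound to the currently active CUDA device.
    ///
    /// # Panics
    ///
    /// Panics if the current CUDA device cannot be determined.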
    pub fn new() -> Self {
        let device = Device::get_or_panic();
        let addr = cpp!(unsafe [] -> *mut std::ffi::c_void as "void*" {
            return createInferRuntime(GLOBAL_LOGGER);
        });
        Runtime { addr, device }
    }
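
    /// Deserialize an engine from a serialized plan held in a [`HostBuffer`].
    ///
    /// # Arguments
    ///
    /// * `plan` - Host buffer that holds the serialized engine plan.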
    pub fn deserialize_engine_from_plan(self, plan: &HostBuffer) -> Result<Engine> {
        unsafe {
            // SAFETY: Since we have a reference to the buffer for the duration of this call, we
            // know the internal pointers will be and remain valid until the end of the block.
            self.deserialize_engine_raw(plan.data(), plan.size())
        }
    }
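
    /// Deserialize an engine from a byte slice containing a serialized plan.
    ///
    /// # Arguments
    ///
    /// * `buffer` - Slice that holds the serialized engine plan.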
    pub fn deserialize_engine(self, buffer: &[u8]) -> Result<Engine> {
        unsafe {
            // SAFETY: Since we have a reference to the slice for the duration of this call, we
            // know the internal pointers will be and remain valid until the end of the block.
            self.deserialize_engine_raw(buffer.as_ptr() as *const std::ffi::c_void, buffer.len())
        }
    }

    /// Deserialize an engine from a buffer.
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_runtime.html#ad0dc765e77cab99bfad901e47216a767)
    ///
    /// # Safety
    ///
    /// `buffer_ptr` must be a valid pointer to a buffer of at least `buffer_size` bytes, and it
    /// must remain valid for the duration of the call.
    ///
    /// # Arguments
    ///
    /// * `buffer_ptr` - Pointer to the buffer to read from.
    /// * `buffer_size` - Size in bytes of the buffer to read from.
    unsafe fn deserialize_engine_raw(
        mut self,
        buffer_ptr: *const std::ffi::c_void,
        buffer_size: usize,
    ) -> Result<Engine> {
        Device::set(self.device)?;
        let internal = self.as_mut_ptr();
        let internal_engine = cpp!(unsafe [
            internal as "void*",
            buffer_ptr as "const void*",
            buffer_size as "std::size_t"
        ] -> *mut std::ffi::c_void as "void*" {
            return ((IRuntime*) internal)->deserializeCudaEngine(buffer_ptr, buffer_size);
        });
        result!(internal_engine, Engine::wrap(internal_engine, self))
    }
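
    /// Get internal read-only pointer to the underlying TensorRT `IRuntime`.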
    #[inline(always)]
    pub fn as_ptr(&self) -> *const std::ffi::c_void {
        self.addr
    }
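
    /// Get internal mutable pointer to the underlying TensorRT `IRuntime`.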
    #[inline(always)]
    pub fn as_mut_ptr(&mut self) -> *mut std::ffi::c_void {
        self.addr
    }
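
    /// Get the CUDA device that this runtime is bound to.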
    #[inline(always)]
    pub fn device(&self) -> DeviceId {
        self.device
    }
}
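
/// Implements [`Drop`] for [`Runtime`].
///
/// Destroys the underlying TensorRT `IRuntime` after switching to the device it was created on.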
impl Drop for Runtime {
    fn drop(&mut self) {
        Device::set_or_panic(self.device);
        let internal = self.as_mut_ptr();
        cpp!(unsafe [
            internal as "void*"
        ] {
            destroy((IRuntime*) internal);
        });
    }
}

impl Default for Runtime {
    fn default() -> Self {
        Runtime::new()
    }
}