1#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals)]
10#![warn(missing_debug_implementations)]
11
12use core::ffi::{c_char, c_int, c_void};
13use std::sync::OnceLock;
14
15use baracuda_core::{platform, Library, LoaderError};
16use baracuda_cuda_sys::runtime::cudaStream_t;
17use baracuda_types::CudaStatus;
18
// Opaque handle types.
//
// Every handle is an untyped `*mut c_void` at this layer, so the compiler does
// not distinguish one handle kind from another; the aliases only document
// which kind of object a given parameter is expected to be.
// NOTE(review): the names suggest these stand for TensorRT's `IRuntime`,
// `ICudaEngine`, `IExecutionContext`, `ILogger`, `IPluginRegistry` and
// `IHostMemory` objects — confirm against the native side.

/// Opaque handle to a runtime object.
pub type trtIRuntime_t = *mut c_void;
/// Opaque handle to an engine object.
pub type trtICudaEngine_t = *mut c_void;
/// Opaque handle to an execution context object.
pub type trtIExecutionContext_t = *mut c_void;
/// Opaque handle to a logger object.
pub type trtILogger_t = *mut c_void;
/// Opaque handle to a plugin registry object.
pub type trtIPluginRegistry_t = *mut c_void;
/// Opaque handle to a host memory object.
pub type trtIHostMemory_t = *mut c_void;
27
/// Tensor element type tag, represented as a 32-bit integer (`#[repr(i32)]`).
///
/// NOTE(review): the discriminants look like they mirror TensorRT's
/// `nvinfer1::DataType` — confirm against the headers of the targeted release.
///
/// NOTE(review): this enum is used directly as an FFI return type (see
/// `PFN_engineGetTensorDataType`). A native value outside `0..=10` would not be
/// a valid value of this type — confirm the native side can never produce one.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum trtDataType_t {
    Float = 0,
    Half = 1,
    Int8 = 2,
    Int32 = 3,
    Bool = 4,
    Uint8 = 5,
    Fp8 = 6,
    BFloat16 = 7,
    Int64 = 8,
    Int4 = 9,
    Fp4 = 10,
}
45
/// Direction tag for an engine tensor, represented as a 32-bit integer.
///
/// Returned by value from `PFN_engineGetTensorIOMode`.
/// NOTE(review): `None` presumably means "not an I/O tensor" — confirm.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum trtTensorIOMode_t {
    None = 0,
    Input = 1,
    Output = 2,
}
53
/// Log severity tag passed to a [`trtLogCallback`], represented as a 32-bit
/// integer. Lower discriminants are the more serious levels as declared here
/// (`InternalError = 0` … `Verbose = 4`).
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum trtSeverity_t {
    InternalError = 0,
    Error = 1,
    Warning = 2,
    Info = 3,
    Verbose = 4,
}
63
/// Strategy selector passed to
/// `PFN_engineCreateExecutionContextWithStrategy`, represented as a 32-bit
/// integer.
///
/// NOTE(review): the variants presumably control how the context's device
/// memory is allocated — confirm the exact semantics against the TensorRT
/// documentation.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum trtExecutionContextAllocationStrategy_t {
    Static = 0,
    OnProfileChange = 1,
    UserManaged = 2,
}
71
/// Capacity of the extent array in [`trtDims_t`].
pub const TRT_MAX_DIMS: usize = 8;

/// Fixed-capacity tensor shape with C layout (`#[repr(C)]`).
///
/// Passed by pointer to `PFN_contextSetInputShape` and returned by value from
/// the `…GetTensorShape` entry points.
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct trtDims_t {
    /// NOTE(review): presumably the number of leading entries of `d` that are
    /// meaningful — confirm, including how an invalid shape is signalled.
    pub nb_dims: i32,
    /// Per-dimension extents; always `TRT_MAX_DIMS` slots wide.
    pub d: [i64; TRT_MAX_DIMS],
}
83
/// Integer status code used by these bindings.
///
/// `#[repr(transparent)]` gives the wrapper the same layout as the `i32` it
/// holds. Ordering and hashing are those of the wrapped integer.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct trtStatus_t(pub i32);

impl trtStatus_t {
    /// The success code, `0`.
    pub const SUCCESS: Self = Self(0);
    /// The generic failure code, `-1`.
    pub const FAILURE: Self = Self(-1);

    /// Returns `true` exactly when the wrapped code is `0`.
    pub const fn is_success(self) -> bool {
        matches!(self.0, 0)
    }
}
101
102impl CudaStatus for trtStatus_t {
103 fn code(self) -> i32 {
104 self.0
105 }
106 fn name(self) -> &'static str {
107 match self.0 {
108 0 => "TRT_SUCCESS",
109 -1 => "TRT_FAILURE",
110 _ => "TRT_UNRECOGNIZED",
111 }
112 }
113 fn description(self) -> &'static str {
114 match self.0 {
115 0 => "success",
116 -1 => "TensorRT call failed (check logger output)",
117 _ => "unrecognized TensorRT status code",
118 }
119 }
120 fn is_success(self) -> bool {
121 trtStatus_t::is_success(self)
122 }
123 fn library(self) -> &'static str {
124 "tensorrt"
125 }
126}
127
/// Signature bound to the `getInferLibVersion` symbol: no arguments, returns
/// an `i32`.
/// NOTE(review): how the version is encoded in that integer is not visible
/// here — confirm before decoding it.
pub type PFN_getInferLibVersion = unsafe extern "C" fn() -> i32;

/// C-ABI log callback: receives a severity, a C string pointer and an opaque
/// user pointer.
/// NOTE(review): `msg` is presumably NUL-terminated and only valid for the
/// duration of the call, and `user` presumably is whatever pointer was
/// registered together with the callback — confirm.
pub type trtLogCallback =
    unsafe extern "C" fn(severity: trtSeverity_t, msg: *const c_char, user: *mut c_void);

/// Takes a logger handle and returns a runtime handle.
/// NOTE(review): a null return presumably signals failure — confirm.
pub type PFN_createInferRuntime =
    unsafe extern "C" fn(logger: trtILogger_t) -> trtIRuntime_t;
/// Takes a runtime handle and returns nothing.
pub type PFN_destroyInferRuntime = unsafe extern "C" fn(runtime: trtIRuntime_t);
139
/// Takes a runtime handle plus a byte blob (`blob`, `size` bytes) and returns
/// an engine handle.
/// NOTE(review): a null return presumably signals failure — confirm.
pub type PFN_deserializeCudaEngine = unsafe extern "C" fn(
    runtime: trtIRuntime_t,
    blob: *const c_void,
    size: usize,
) -> trtICudaEngine_t;
/// Takes an engine handle and returns nothing.
pub type PFN_destroyCudaEngine = unsafe extern "C" fn(engine: trtICudaEngine_t);

/// Returns an `i32` count for the given engine.
pub type PFN_engineGetNbIOTensors =
    unsafe extern "C" fn(engine: trtICudaEngine_t) -> i32;
/// Returns a C string pointer for the tensor at `index`.
/// NOTE(review): ownership and lifetime of the returned string are not visible
/// here — presumably owned by the engine; confirm.
pub type PFN_engineGetIOTensorName =
    unsafe extern "C" fn(engine: trtICudaEngine_t, index: i32) -> *const c_char;
/// Returns the I/O mode of the tensor called `name`.
pub type PFN_engineGetTensorIOMode = unsafe extern "C" fn(
    engine: trtICudaEngine_t,
    name: *const c_char,
) -> trtTensorIOMode_t;
/// Returns the element type of the tensor called `name`.
pub type PFN_engineGetTensorDataType = unsafe extern "C" fn(
    engine: trtICudaEngine_t,
    name: *const c_char,
) -> trtDataType_t;
/// Returns the shape of the tensor called `name`, by value.
pub type PFN_engineGetTensorShape =
    unsafe extern "C" fn(engine: trtICudaEngine_t, name: *const c_char) -> trtDims_t;
/// Returns an `i32` for the tensor called `name`.
/// NOTE(review): presumably the byte size of one component — confirm.
pub type PFN_engineGetTensorBytesPerComponent =
    unsafe extern "C" fn(engine: trtICudaEngine_t, name: *const c_char) -> i32;
/// Takes an engine handle and returns an execution context handle.
pub type PFN_engineCreateExecutionContext =
    unsafe extern "C" fn(engine: trtICudaEngine_t) -> trtIExecutionContext_t;
/// Same as [`PFN_engineCreateExecutionContext`] with an explicit allocation
/// strategy argument.
pub type PFN_engineCreateExecutionContextWithStrategy = unsafe extern "C" fn(
    engine: trtICudaEngine_t,
    strategy: trtExecutionContextAllocationStrategy_t,
) -> trtIExecutionContext_t;
/// Takes an execution context handle and returns nothing.
pub type PFN_destroyExecutionContext = unsafe extern "C" fn(ctx: trtIExecutionContext_t);
170
/// Passes a pointer to a shape for the tensor called `name`; returns a `bool`.
/// NOTE(review): `true` presumably means the shape was accepted — confirm.
pub type PFN_contextSetInputShape = unsafe extern "C" fn(
    ctx: trtIExecutionContext_t,
    name: *const c_char,
    dims: *const trtDims_t,
) -> bool;
/// Returns the shape of the tensor called `name` for this context, by value.
pub type PFN_contextGetTensorShape = unsafe extern "C" fn(
    ctx: trtIExecutionContext_t,
    name: *const c_char,
) -> trtDims_t;
/// Associates the raw pointer `data` with the tensor called `name`; returns a
/// `bool`.
/// NOTE(review): `data` is presumably a device address that must stay valid
/// while the context uses it — confirm.
pub type PFN_contextSetTensorAddress = unsafe extern "C" fn(
    ctx: trtIExecutionContext_t,
    name: *const c_char,
    data: *mut c_void,
) -> bool;
/// Returns a raw pointer for the tensor called `name`.
pub type PFN_contextGetTensorAddress = unsafe extern "C" fn(
    ctx: trtIExecutionContext_t,
    name: *const c_char,
) -> *mut c_void;
/// Takes a context and a CUDA stream handle; returns a `bool`.
/// NOTE(review): presumably enqueues inference asynchronously on `stream`,
/// with `true` meaning the enqueue succeeded — confirm.
pub type PFN_contextEnqueueV3 =
    unsafe extern "C" fn(ctx: trtIExecutionContext_t, stream: cudaStream_t) -> bool;
191
/// Returns a C string pointer for the given engine.
/// NOTE(review): ownership and lifetime of the returned string are not visible
/// here — confirm.
pub type PFN_engineGetName =
    unsafe extern "C" fn(engine: trtICudaEngine_t) -> *const c_char;
/// Returns an `i32` count for the given engine.
pub type PFN_engineGetNbOptimizationProfiles =
    unsafe extern "C" fn(engine: trtICudaEngine_t) -> i32;

/// Takes an engine handle and returns a host memory handle.
/// NOTE(review): the caller presumably has to release the result through
/// [`PFN_hostMemoryDestroy`] — confirm.
pub type PFN_engineSerialize = unsafe extern "C" fn(engine: trtICudaEngine_t) -> trtIHostMemory_t;
/// Returns a raw pointer for the given host memory handle.
pub type PFN_hostMemoryData = unsafe extern "C" fn(mem: trtIHostMemory_t) -> *mut c_void;
/// Returns a `usize` for the given host memory handle.
/// NOTE(review): presumably the byte length of the buffer returned by
/// [`PFN_hostMemoryData`] — confirm.
pub type PFN_hostMemorySize = unsafe extern "C" fn(mem: trtIHostMemory_t) -> usize;
/// Takes a host memory handle and returns nothing.
pub type PFN_hostMemoryDestroy = unsafe extern "C" fn(mem: trtIHostMemory_t);
201
202fn tensorrt_candidates() -> Vec<String> {
205 platform::versioned_library_candidates("nvinfer", &["10", "9", "8"])
207}
208
// Generates the `TensorRt` symbol table from a list of
// `accessor_name as "exported_symbol": FnPointerType;` entries.
//
// For every entry the macro emits one `OnceLock` field holding the resolved
// function pointer and one public accessor that resolves the symbol on first
// use. It also emits a `Debug` impl and the private `empty` constructor.
macro_rules! trt_fns {
    ($($name:ident as $sym:literal : $pfn:ty);* $(;)?) => {
        /// Lazily populated table of function pointers resolved from one
        /// loaded library.
        pub struct TensorRt {
            lib: Library,
            $($name: OnceLock<$pfn>,)*
        }
        impl core::fmt::Debug for TensorRt {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                // Only the library handle is printed; the cached pointers are
                // elided via `finish_non_exhaustive`.
                f.debug_struct("TensorRt").field("lib", &self.lib).finish_non_exhaustive()
            }
        }
        impl TensorRt {
            $(
                /// Returns the function pointer for this entry, resolving the
                /// symbol from the library on first use and caching it.
                ///
                /// # Errors
                /// Propagates the `LoaderError` from `Library::raw_symbol`;
                /// nothing is cached in that case.
                pub fn $name(&self) -> Result<$pfn, LoaderError> {
                    // Fast path: already resolved.
                    if let Some(&p) = self.$name.get() { return Ok(p); }
                    // SAFETY (assumed — TODO confirm against `Library::raw_symbol`'s
                    // contract): expected to return the address of the exported
                    // symbol `$sym`, or an error if it is missing.
                    let raw: *mut () = unsafe { self.lib.raw_symbol($sym)? };
                    // SAFETY (assumed): reinterprets the pointer's bits as `$pfn`.
                    // `transmute_copy` does not check sizes, so this relies on
                    // `$pfn` being a pointer-sized function pointer type.
                    // Nothing verifies that the symbol really has the signature
                    // `$pfn`, so each table entry must pair them correctly.
                    // NOTE(review): a null `raw` would not be a valid function
                    // pointer — confirm `raw_symbol` never returns `Ok(null)`.
                    let p: $pfn = unsafe { core::mem::transmute_copy::<*mut (), $pfn>(&raw) };
                    // If `set` fails the cell already holds a value; the result
                    // is ignored and the locally resolved pointer is returned.
                    let _ = self.$name.set(p);
                    Ok(p)
                }
            )*
            // Wraps an opened library with every cache slot still unresolved.
            fn empty(lib: Library) -> Self {
                Self { lib, $($name: OnceLock::new(),)* }
            }
        }
    };
}
236
// Symbol table: `accessor as "exported symbol": function pointer type`.
//
// NOTE(review): NVIDIA's `NvInferRuntime.h` appears to declare
// `createInferRuntime_INTERNAL(void* logger, int32_t version)`, while
// `PFN_createInferRuntime` takes only the logger handle. Confirm the expected
// signature of the symbol that is actually loaded before calling it.
//
// NOTE(review): the `trt…`-prefixed names and `destroyInferRuntime` look like
// C shim exports, not C++ API symbols of libnvinfer itself — confirm that the
// library opened by `tensorrt()` is the one exporting them.
trt_fns! {
    get_infer_lib_version as "getInferLibVersion": PFN_getInferLibVersion;
    create_infer_runtime as "createInferRuntime_INTERNAL": PFN_createInferRuntime;
    destroy_infer_runtime as "destroyInferRuntime": PFN_destroyInferRuntime;
    deserialize_cuda_engine as "trtRuntimeDeserializeCudaEngine": PFN_deserializeCudaEngine;
    destroy_cuda_engine as "trtCudaEngineDestroy": PFN_destroyCudaEngine;
    engine_get_nb_io_tensors as "trtCudaEngineGetNbIOTensors": PFN_engineGetNbIOTensors;
    engine_get_io_tensor_name as "trtCudaEngineGetIOTensorName": PFN_engineGetIOTensorName;
    engine_get_tensor_io_mode as "trtCudaEngineGetTensorIOMode": PFN_engineGetTensorIOMode;
    engine_get_tensor_data_type as "trtCudaEngineGetTensorDataType": PFN_engineGetTensorDataType;
    engine_get_tensor_shape as "trtCudaEngineGetTensorShape": PFN_engineGetTensorShape;
    engine_get_tensor_bytes_per_component as "trtCudaEngineGetTensorBytesPerComponent": PFN_engineGetTensorBytesPerComponent;
    engine_create_execution_context as "trtCudaEngineCreateExecutionContext": PFN_engineCreateExecutionContext;
    engine_create_execution_context_with_strategy as "trtCudaEngineCreateExecutionContextWithStrategy": PFN_engineCreateExecutionContextWithStrategy;
    destroy_execution_context as "trtExecutionContextDestroy": PFN_destroyExecutionContext;
    context_set_input_shape as "trtExecutionContextSetInputShape": PFN_contextSetInputShape;
    context_get_tensor_shape as "trtExecutionContextGetTensorShape": PFN_contextGetTensorShape;
    context_set_tensor_address as "trtExecutionContextSetTensorAddress": PFN_contextSetTensorAddress;
    context_get_tensor_address as "trtExecutionContextGetTensorAddress": PFN_contextGetTensorAddress;
    context_enqueue_v3 as "trtExecutionContextEnqueueV3": PFN_contextEnqueueV3;
    engine_get_name as "trtCudaEngineGetName": PFN_engineGetName;
    engine_get_nb_optimization_profiles as "trtCudaEngineGetNbOptimizationProfiles": PFN_engineGetNbOptimizationProfiles;
    engine_serialize as "trtCudaEngineSerialize": PFN_engineSerialize;
    host_memory_data as "trtHostMemoryData": PFN_hostMemoryData;
    host_memory_size as "trtHostMemorySize": PFN_hostMemorySize;
    host_memory_destroy as "trtHostMemoryDestroy": PFN_hostMemoryDestroy;
}
268
269pub fn tensorrt() -> Result<&'static TensorRt, LoaderError> {
270 static TRT: OnceLock<TensorRt> = OnceLock::new();
271 if let Some(c) = TRT.get() {
272 return Ok(c);
273 }
274 let candidates: Vec<&'static str> = tensorrt_candidates()
275 .into_iter()
276 .map(|s| Box::leak(s.into_boxed_str()) as &'static str)
277 .collect();
278 let candidates_leaked: &'static [&'static str] = Box::leak(candidates.into_boxed_slice());
279 let lib = Library::open("nvinfer", candidates_leaked)?;
280 let c = TensorRt::empty(lib);
281 let _ = TRT.set(c);
282 Ok(TRT.get().expect("OnceLock set or lost race"))
283}
284
// Hidden from rustdoc. In the code visible here this constant is the only use
// of the imported `c_int`.
// NOTE(review): it presumably exists just to keep that import from being
// flagged as unused — confirm, and consider dropping both together.
#[doc(hidden)]
pub const _UNUSED_C_INT_MARKER: c_int = 0;