Skip to main content

baracuda_tensorrt_sys/
lib.rs

1//! Raw FFI + dynamic loader for NVIDIA TensorRT (C API surface).
2//!
3//! TensorRT's native public API is C++; NVIDIA ships a partial C-ABI surface
4//! suitable for language bindings in `NvInferRuntimeCAPI.h` (TRT 10+). This
5//! crate wraps that surface for runtime deserialization and inference. The
6//! builder side of TensorRT remains C++-only; use the TRT `trtexec` tool or
7//! the Python bindings to produce serialized engines, then load them here.
8
// Identifiers below intentionally mirror TensorRT/C naming (e.g. `trtDims_t`,
// `PFN_getInferLibVersion`) so the bindings read like the C surface they wrap;
// the Rust naming lints are silenced for that reason.
#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals)]
// Ask for `Debug` on every public type (the macro-generated `TensorRt`
// implements it by hand).
#![warn(missing_debug_implementations)]
11
12use core::ffi::{c_char, c_int, c_void};
13use std::sync::OnceLock;
14
15use baracuda_core::{platform, Library, LoaderError};
16use baracuda_cuda_sys::runtime::cudaStream_t;
17use baracuda_types::CudaStatus;
18
19// ---- opaque handles ------------------------------------------------------
20
/// Opaque handle to a TensorRT runtime object.
pub type trtIRuntime_t = *mut c_void;

/// Opaque handle to a deserialized TensorRT engine.
pub type trtICudaEngine_t = *mut c_void;

/// Opaque handle to an execution context created from an engine.
pub type trtIExecutionContext_t = *mut c_void;

/// Opaque handle to a logger object passed when creating a runtime.
pub type trtILogger_t = *mut c_void;

/// Opaque handle to a plugin registry.
pub type trtIPluginRegistry_t = *mut c_void;

/// Opaque handle to a TensorRT-owned host buffer (e.g. a serialized engine).
pub type trtIHostMemory_t = *mut c_void;
27
28// ---- enums ---------------------------------------------------------------
29
/// Element type of a TensorRT tensor; discriminants mirror `nvinfer1::DataType`.
///
/// A `#[repr(i32)]` enum holding a value that is not one of its listed
/// discriminants is undefined behavior in Rust. Raw codes obtained from C
/// (for example a type added by a newer TensorRT release) should therefore be
/// validated with the `TryFrom<i32>` impl below before being used as this enum.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum trtDataType_t {
    Float = 0,
    Half = 1,
    Int8 = 2,
    Int32 = 3,
    Bool = 4,
    Uint8 = 5,
    Fp8 = 6,
    BFloat16 = 7,
    Int64 = 8,
    Int4 = 9,
    Fp4 = 10,
}

impl core::convert::TryFrom<i32> for trtDataType_t {
    /// The unrecognized raw code is handed back unchanged.
    type Error = i32;

    /// Maps a raw TensorRT data-type code to the enum.
    ///
    /// # Errors
    /// Returns `Err(code)` when `code` is not a known discriminant.
    fn try_from(code: i32) -> Result<Self, Self::Error> {
        Ok(match code {
            0 => Self::Float,
            1 => Self::Half,
            2 => Self::Int8,
            3 => Self::Int32,
            4 => Self::Bool,
            5 => Self::Uint8,
            6 => Self::Fp8,
            7 => Self::BFloat16,
            8 => Self::Int64,
            9 => Self::Int4,
            10 => Self::Fp4,
            other => return Err(other),
        })
    }
}
45
/// Whether a named engine tensor is an input, an output, or neither;
/// discriminants mirror `nvinfer1::TensorIOMode`.
///
/// Raw codes obtained from C should go through the `TryFrom<i32>` impl below:
/// an unlisted discriminant in a `#[repr(i32)]` enum is undefined behavior.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum trtTensorIOMode_t {
    None = 0,
    Input = 1,
    Output = 2,
}

impl core::convert::TryFrom<i32> for trtTensorIOMode_t {
    /// The unrecognized raw code is handed back unchanged.
    type Error = i32;

    /// Maps a raw TensorRT I/O-mode code to the enum.
    ///
    /// # Errors
    /// Returns `Err(code)` when `code` is not a known discriminant.
    fn try_from(code: i32) -> Result<Self, Self::Error> {
        Ok(match code {
            0 => Self::None,
            1 => Self::Input,
            2 => Self::Output,
            other => return Err(other),
        })
    }
}
53
54#[repr(i32)]
55#[derive(Copy, Clone, Debug, Eq, PartialEq)]
56pub enum trtSeverity_t {
57    InternalError = 0,
58    Error = 1,
59    Warning = 2,
60    Info = 3,
61    Verbose = 4,
62}
63
/// Allocation policy requested when creating an execution context through
/// `PFN_engineCreateExecutionContextWithStrategy`. Discriminants mirror
/// `nvinfer1::ExecutionContextAllocationStrategy`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[repr(i32)]
pub enum trtExecutionContextAllocationStrategy_t {
    /// `kSTATIC`
    Static = 0,
    /// `kON_PROFILE_CHANGE`
    OnProfileChange = 1,
    /// `kUSER_MANAGED`
    UserManaged = 2,
}
71
72// ---- Dims container ------------------------------------------------------
73
/// Maximum number of dimensions a [`trtDims_t`] can hold.
pub const TRT_MAX_DIMS: usize = 8;

/// Tensor shape exchanged with TensorRT (analog of `nvinfer1::Dims`): up to
/// [`TRT_MAX_DIMS`] extents, of which only the first `nb_dims` are meaningful.
///
/// `#[repr(C)]` so it can be passed to and returned from the C entry points
/// by value or pointer.
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct trtDims_t {
    /// Number of valid entries in `d`. TensorRT uses a negative value (`-1`)
    /// to report an invalid/unknown shape.
    pub nb_dims: i32,
    /// Per-dimension extents; entries at index `>= nb_dims` are ignored.
    pub d: [i64; TRT_MAX_DIMS],
}

impl trtDims_t {
    /// Builds a shape from `dims`, zero-filling the unused trailing slots.
    ///
    /// Returns `None` when `dims` has more than [`TRT_MAX_DIMS`] entries.
    pub fn from_slice(dims: &[i64]) -> Option<Self> {
        if dims.len() > TRT_MAX_DIMS {
            return None;
        }
        let mut d = [0_i64; TRT_MAX_DIMS];
        d[..dims.len()].copy_from_slice(dims);
        // Cannot truncate: `dims.len() <= TRT_MAX_DIMS` was checked above.
        Some(Self { nb_dims: dims.len() as i32, d })
    }

    /// Returns the valid extents `d[..nb_dims]`.
    ///
    /// Returns `None` when `nb_dims` is negative (TensorRT's "invalid shape"
    /// marker) or larger than [`TRT_MAX_DIMS`], instead of panicking.
    pub fn as_slice(&self) -> Option<&[i64]> {
        if self.nb_dims < 0 {
            return None;
        }
        // Non-negative i32 always fits in usize.
        self.d.get(..self.nb_dims as usize)
    }
}
83
84// ---- status --------------------------------------------------------------
85
/// Status value for the error-reporting subset of TensorRT calls.
///
/// Depending on the function, the TensorRT C surface reports results either as
/// a `bool` (0/1) or as an `int32_t` code. This `#[repr(transparent)]` wrapper
/// around `i32` gives such results a type that implements [`CudaStatus`].
/// Only code `0` counts as success.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[repr(transparent)]
pub struct trtStatus_t(pub i32);

impl trtStatus_t {
    /// Code `0`: the call succeeded.
    pub const SUCCESS: Self = Self(0);
    /// Code `-1`: generic failure; details, if any, go to the logger.
    pub const FAILURE: Self = Self(-1);

    /// Returns `true` only for code `0`; every other code is a failure.
    pub const fn is_success(self) -> bool {
        let Self(code) = self;
        code == Self::SUCCESS.0
    }
}
101
102impl CudaStatus for trtStatus_t {
103    fn code(self) -> i32 {
104        self.0
105    }
106    fn name(self) -> &'static str {
107        match self.0 {
108            0 => "TRT_SUCCESS",
109            -1 => "TRT_FAILURE",
110            _ => "TRT_UNRECOGNIZED",
111        }
112    }
113    fn description(self) -> &'static str {
114        match self.0 {
115            0 => "success",
116            -1 => "TensorRT call failed (check logger output)",
117            _ => "unrecognized TensorRT status code",
118        }
119    }
120    fn is_success(self) -> bool {
121        trtStatus_t::is_success(self)
122    }
123    fn library(self) -> &'static str {
124        "tensorrt"
125    }
126}
127
128// ---- function-pointer types ----------------------------------------------
129
130pub type PFN_getInferLibVersion = unsafe extern "C" fn() -> i32;
131
132/// Logger callback signature (matches `nvinfer1::ILogger::log`).
133pub type trtLogCallback =
134    unsafe extern "C" fn(severity: trtSeverity_t, msg: *const c_char, user: *mut c_void);
135
136pub type PFN_createInferRuntime =
137    unsafe extern "C" fn(logger: trtILogger_t) -> trtIRuntime_t;
138pub type PFN_destroyInferRuntime = unsafe extern "C" fn(runtime: trtIRuntime_t);
139
140pub type PFN_deserializeCudaEngine = unsafe extern "C" fn(
141    runtime: trtIRuntime_t,
142    blob: *const c_void,
143    size: usize,
144) -> trtICudaEngine_t;
145pub type PFN_destroyCudaEngine = unsafe extern "C" fn(engine: trtICudaEngine_t);
146
147pub type PFN_engineGetNbIOTensors =
148    unsafe extern "C" fn(engine: trtICudaEngine_t) -> i32;
149pub type PFN_engineGetIOTensorName =
150    unsafe extern "C" fn(engine: trtICudaEngine_t, index: i32) -> *const c_char;
151pub type PFN_engineGetTensorIOMode = unsafe extern "C" fn(
152    engine: trtICudaEngine_t,
153    name: *const c_char,
154) -> trtTensorIOMode_t;
155pub type PFN_engineGetTensorDataType = unsafe extern "C" fn(
156    engine: trtICudaEngine_t,
157    name: *const c_char,
158) -> trtDataType_t;
159pub type PFN_engineGetTensorShape =
160    unsafe extern "C" fn(engine: trtICudaEngine_t, name: *const c_char) -> trtDims_t;
161pub type PFN_engineGetTensorBytesPerComponent =
162    unsafe extern "C" fn(engine: trtICudaEngine_t, name: *const c_char) -> i32;
163pub type PFN_engineCreateExecutionContext =
164    unsafe extern "C" fn(engine: trtICudaEngine_t) -> trtIExecutionContext_t;
165pub type PFN_engineCreateExecutionContextWithStrategy = unsafe extern "C" fn(
166    engine: trtICudaEngine_t,
167    strategy: trtExecutionContextAllocationStrategy_t,
168) -> trtIExecutionContext_t;
169pub type PFN_destroyExecutionContext = unsafe extern "C" fn(ctx: trtIExecutionContext_t);
170
171pub type PFN_contextSetInputShape = unsafe extern "C" fn(
172    ctx: trtIExecutionContext_t,
173    name: *const c_char,
174    dims: *const trtDims_t,
175) -> bool;
176pub type PFN_contextGetTensorShape = unsafe extern "C" fn(
177    ctx: trtIExecutionContext_t,
178    name: *const c_char,
179) -> trtDims_t;
180pub type PFN_contextSetTensorAddress = unsafe extern "C" fn(
181    ctx: trtIExecutionContext_t,
182    name: *const c_char,
183    data: *mut c_void,
184) -> bool;
185pub type PFN_contextGetTensorAddress = unsafe extern "C" fn(
186    ctx: trtIExecutionContext_t,
187    name: *const c_char,
188) -> *mut c_void;
189pub type PFN_contextEnqueueV3 =
190    unsafe extern "C" fn(ctx: trtIExecutionContext_t, stream: cudaStream_t) -> bool;
191
192pub type PFN_engineGetName =
193    unsafe extern "C" fn(engine: trtICudaEngine_t) -> *const c_char;
194pub type PFN_engineGetNbOptimizationProfiles =
195    unsafe extern "C" fn(engine: trtICudaEngine_t) -> i32;
196
197pub type PFN_engineSerialize = unsafe extern "C" fn(engine: trtICudaEngine_t) -> trtIHostMemory_t;
198pub type PFN_hostMemoryData = unsafe extern "C" fn(mem: trtIHostMemory_t) -> *mut c_void;
199pub type PFN_hostMemorySize = unsafe extern "C" fn(mem: trtIHostMemory_t) -> usize;
200pub type PFN_hostMemoryDestroy = unsafe extern "C" fn(mem: trtIHostMemory_t);
201
202// ---- loader --------------------------------------------------------------
203
204fn tensorrt_candidates() -> Vec<String> {
205    // TensorRT 10 ships libnvinfer.so.10 / nvinfer_10.dll; 8 uses "8".
206    platform::versioned_library_candidates("nvinfer", &["10", "9", "8"])
207}
208
/// Generates the [`TensorRt`] function table.
///
/// Each `name as "symbol": PFN_type;` entry produces a private
/// `OnceLock<PFN_type>` cache field and a public accessor `name()`. The accessor
/// resolves `symbol` from the loaded library on first use and returns the cached
/// function pointer on later calls.
macro_rules! trt_fns {
    ($($name:ident as $sym:literal : $pfn:ty);* $(;)?) => {
        /// Dynamically loaded TensorRT library with lazily resolved entry points.
        pub struct TensorRt {
            lib: Library,
            $($name: OnceLock<$pfn>,)*
        }

        impl core::fmt::Debug for TensorRt {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                // The function-pointer caches are omitted; only the library is shown.
                f.debug_struct("TensorRt").field("lib", &self.lib).finish_non_exhaustive()
            }
        }

        impl TensorRt {
            $(
                #[doc = concat!(
                    "Resolves `", $sym,
                    "` on first call and returns the cached function pointer afterwards."
                )]
                pub fn $name(&self) -> Result<$pfn, LoaderError> {
                    if let Some(&p) = self.$name.get() {
                        return Ok(p);
                    }
                    // SAFETY: looks up an exported symbol of the loaded library; the
                    // `$pfn` declared in the table must match its real C signature.
                    let raw: *mut () = unsafe { self.lib.raw_symbol($sym)? };
                    // SAFETY: unlike `transmute_copy`, `transmute` refuses to compile
                    // if `$pfn` is not exactly pointer-sized.
                    // NOTE(review): a null `raw` would make this UB (fn pointers are
                    // non-null); confirm `raw_symbol` never returns `Ok(null)`.
                    let p: $pfn = unsafe { core::mem::transmute::<*mut (), $pfn>(raw) };
                    // A racing caller may already have stored the same pointer; the
                    // losing `set` is harmless.
                    let _ = self.$name.set(p);
                    Ok(p)
                }
            )*

            fn empty(lib: Library) -> Self {
                Self { lib, $($name: OnceLock::new(),)* }
            }
        }
    };
}
236
237trt_fns! {
238    // The symbol names below mirror the C API exported by TensorRT 10
239    // (`NvInferRuntimeCAPI.h`). Symbol-name mismatches against older TRT
240    // versions fall back to `LoaderError::SymbolUnavailable`, which the safe
241    // crate maps to `Error::FeatureNotSupported`.
242    get_infer_lib_version as "getInferLibVersion": PFN_getInferLibVersion;
243    create_infer_runtime as "createInferRuntime_INTERNAL": PFN_createInferRuntime;
244    destroy_infer_runtime as "destroyInferRuntime": PFN_destroyInferRuntime;
245    deserialize_cuda_engine as "trtRuntimeDeserializeCudaEngine": PFN_deserializeCudaEngine;
246    destroy_cuda_engine as "trtCudaEngineDestroy": PFN_destroyCudaEngine;
247    engine_get_nb_io_tensors as "trtCudaEngineGetNbIOTensors": PFN_engineGetNbIOTensors;
248    engine_get_io_tensor_name as "trtCudaEngineGetIOTensorName": PFN_engineGetIOTensorName;
249    engine_get_tensor_io_mode as "trtCudaEngineGetTensorIOMode": PFN_engineGetTensorIOMode;
250    engine_get_tensor_data_type as "trtCudaEngineGetTensorDataType": PFN_engineGetTensorDataType;
251    engine_get_tensor_shape as "trtCudaEngineGetTensorShape": PFN_engineGetTensorShape;
252    engine_get_tensor_bytes_per_component as "trtCudaEngineGetTensorBytesPerComponent": PFN_engineGetTensorBytesPerComponent;
253    engine_create_execution_context as "trtCudaEngineCreateExecutionContext": PFN_engineCreateExecutionContext;
254    engine_create_execution_context_with_strategy as "trtCudaEngineCreateExecutionContextWithStrategy": PFN_engineCreateExecutionContextWithStrategy;
255    destroy_execution_context as "trtExecutionContextDestroy": PFN_destroyExecutionContext;
256    context_set_input_shape as "trtExecutionContextSetInputShape": PFN_contextSetInputShape;
257    context_get_tensor_shape as "trtExecutionContextGetTensorShape": PFN_contextGetTensorShape;
258    context_set_tensor_address as "trtExecutionContextSetTensorAddress": PFN_contextSetTensorAddress;
259    context_get_tensor_address as "trtExecutionContextGetTensorAddress": PFN_contextGetTensorAddress;
260    context_enqueue_v3 as "trtExecutionContextEnqueueV3": PFN_contextEnqueueV3;
261    engine_get_name as "trtCudaEngineGetName": PFN_engineGetName;
262    engine_get_nb_optimization_profiles as "trtCudaEngineGetNbOptimizationProfiles": PFN_engineGetNbOptimizationProfiles;
263    engine_serialize as "trtCudaEngineSerialize": PFN_engineSerialize;
264    host_memory_data as "trtHostMemoryData": PFN_hostMemoryData;
265    host_memory_size as "trtHostMemorySize": PFN_hostMemorySize;
266    host_memory_destroy as "trtHostMemoryDestroy": PFN_hostMemoryDestroy;
267}
268
269pub fn tensorrt() -> Result<&'static TensorRt, LoaderError> {
270    static TRT: OnceLock<TensorRt> = OnceLock::new();
271    if let Some(c) = TRT.get() {
272        return Ok(c);
273    }
274    let candidates: Vec<&'static str> = tensorrt_candidates()
275        .into_iter()
276        .map(|s| Box::leak(s.into_boxed_str()) as &'static str)
277        .collect();
278    let candidates_leaked: &'static [&'static str] = Box::leak(candidates.into_boxed_slice());
279    let lib = Library::open("nvinfer", candidates_leaked)?;
280    let c = TensorRt::empty(lib);
281    let _ = TRT.set(c);
282    Ok(TRT.get().expect("OnceLock set or lost race"))
283}
284
/// Integer constant whose only role is to reference `c_int`, keeping that
/// import used so it does not raise an `unused_imports` warning. Its value
/// carries no meaning.
#[doc(hidden)]
pub const _UNUSED_C_INT_MARKER: c_int = 0;