Skip to main content

atomr_accel_tensorrt/
sys.rs

1//! Hand-written FFI surface for libnvinfer (and libnvonnxparser when
2//! the `tensorrt-onnx` feature is on).
3//!
//! TensorRT is a C++ API. We expose just the C-ABI shim functions we
//! need from a thin C++ glue layer. Every `extern "C"` block in this
//! file is gated on `#[cfg(feature = "tensorrt-link")]`, so with that
//! feature off the declarations are compiled out entirely — the linker
//! is never asked to resolve them, and the safe wrappers in
//! `builder.rs`/`engine.rs`/`runtime.rs`/`onnx.rs` only call them
//! through the same `#[cfg(feature = "tensorrt-link")]` gate.
10//!
//! The opaque object types (`IBuilder`, `IBuilderConfig`,
//! `INetworkDefinition`, `ICudaEngine`, `IExecutionContext`, `IRuntime`,
//! `IHostMemory`, `IRefitter`, `IInt8Calibrator`, `IPluginCreator`,
//! `IPluginV3`, `IOnnxParser`) are zero-sized stand-ins for the
//! corresponding TensorRT C++ classes. They are only ever held as
//! `*mut` raw pointers; `Send`/`Sync` for the safe wrappers is granted
//! via newtypes (see `engine.rs`).
17//!
18//! The C-ABI shim itself (a hand-written `nvinfer_shim.cpp`) is not
19//! shipped in this Phase 8 skeleton — it lives behind the
20//! `tensorrt-link` feature in a follow-up commit. Until then the FFI
21//! signatures here document the surface area and let downstream code
22//! type-check against a stable shape.
23
// Lint allowances for this FFI mirror:
// - `non_camel_case_types`: enum variants keep TensorRT's `kFOO` spelling
//   so they can be compared against the C++ headers at a glance.
// - `non_snake_case`: NOTE(review): no non-snake-case item is visible in
//   this file; presumably kept for parity with TensorRT naming — confirm.
// - `dead_code`, `unused_imports`: with the `tensorrt-link` feature off,
//   every cfg-gated item that uses `c_char`/`c_void` is compiled out,
//   leaving those imports (and possibly other items) unused.
#![allow(non_camel_case_types, dead_code, non_snake_case, unused_imports)]

use std::os::raw::{c_char, c_int, c_void};
27
28// -------- Opaque object pointers --------
29
// Every opaque type below follows the Rustonomicon FFI pattern:
// - a zero-length array field: the type is zero-sized, and because the
//   field is private it cannot be constructed outside this module;
// - a `PhantomData<(*mut u8, PhantomPinned)>` marker: this makes each
//   type `!Send`, `!Sync` and `!Unpin`, so the compiler never assumes
//   anything about a TensorRT object's thread-safety or address
//   stability. Thread-safety must be granted explicitly by the safe
//   wrapper newtypes.

/// Opaque stand-in for `nvinfer1::IBuilder`.
#[repr(C)]
pub struct IBuilder {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}

/// Opaque stand-in for `nvinfer1::IBuilderConfig`.
#[repr(C)]
pub struct IBuilderConfig {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}

/// Opaque stand-in for `nvinfer1::INetworkDefinition`.
#[repr(C)]
pub struct INetworkDefinition {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}

/// Opaque stand-in for `nvinfer1::ICudaEngine`.
#[repr(C)]
pub struct ICudaEngine {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}

/// Opaque stand-in for `nvinfer1::IExecutionContext`.
#[repr(C)]
pub struct IExecutionContext {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}

/// Opaque stand-in for `nvinfer1::IRuntime`.
#[repr(C)]
pub struct IRuntime {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}

/// Opaque stand-in for `nvinfer1::IHostMemory` (serialized plan bytes).
#[repr(C)]
pub struct IHostMemory {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}

/// Opaque stand-in for `nvinfer1::IRefitter`.
#[repr(C)]
pub struct IRefitter {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}

/// Opaque stand-in for `nvinfer1::IInt8Calibrator`.
#[repr(C)]
pub struct IInt8Calibrator {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}

/// Opaque stand-in for `nvinfer1::IPluginCreator`.
#[repr(C)]
pub struct IPluginCreator {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}

/// Opaque stand-in for `nvinfer1::IPluginV3`.
#[repr(C)]
pub struct IPluginV3 {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}

/// Opaque stand-in for the libnvonnxparser parser object
/// (`nvonnxparser::IParser`).
#[repr(C)]
pub struct IOnnxParser {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}
89
// -------- Enums (mirrored from the TensorRT headers: NvInfer.h and the NvInferRuntime*.h family) --------
91
/// Tensor element types, mirroring `nvinfer1::DataType`.
///
/// The discriminants are spelled out explicitly and the enum is
/// `repr(i32)`, so a value can cross the C ABI as a plain integer.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(i32)]
pub enum DataType {
    /// 32-bit IEEE float.
    kFLOAT = 0,
    /// 16-bit IEEE half.
    kHALF = 1,
    /// Signed 8-bit integer.
    kINT8 = 2,
    /// Signed 32-bit integer.
    kINT32 = 3,
    /// Boolean.
    kBOOL = 4,
    /// Unsigned 8-bit integer.
    kUINT8 = 5,
    /// 8-bit float.
    kFP8 = 6,
    /// bfloat16.
    kBF16 = 7,
    /// Signed 64-bit integer.
    kINT64 = 8,
    /// Signed 4-bit integer.
    kINT4 = 9,
}
106
/// Builder options, mirroring TensorRT 10's `nvinfer1::BuilderFlag`
/// (declared in `NvInfer.h`).
///
/// The shim's `atomr_trt_config_set_flag` takes a raw `u32`, so these
/// discriminants must equal the C++ enum values exactly. A mismatch does
/// not fail loudly; it silently toggles a *different* builder option.
/// For example, TensorRT's value 13 is `kVERSION_COMPATIBLE`, not `kBF16`.
///
/// Only a subset of TensorRT's flags is mirrored. The numbering gaps
/// (16 = `kERROR_ON_TIMING_CACHE_MISS`, 18 = `kDISABLE_COMPILATION_CACHE`)
/// and the values above 19 exist in TensorRT but are not exposed here.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BuilderFlag {
    kFP16 = 0,
    kINT8 = 1,
    kDEBUG = 2,
    kGPU_FALLBACK = 3,
    kREFIT = 4,
    kDISABLE_TIMING_CACHE = 5,
    kTF32 = 6,
    kSPARSE_WEIGHTS = 7,
    kSAFETY_SCOPE = 8,
    kOBEY_PRECISION_CONSTRAINTS = 9,
    kPREFER_PRECISION_CONSTRAINTS = 10,
    kDIRECT_IO = 11,
    kREJECT_EMPTY_ALGORITHMS = 12,
    kVERSION_COMPATIBLE = 13,
    kEXCLUDE_LEAN_RUNTIME = 14,
    kFP8 = 15,
    kBF16 = 17,
    kSTRIP_PLAN = 19,
}
129
/// Execution device for layers, mirroring `nvinfer1::DeviceType`.
///
/// Passed as a plain `c_int` to `atomr_trt_config_set_default_device_type`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(i32)]
pub enum DeviceType {
    /// Run on the GPU.
    kGPU = 0,
    /// Run on a Deep Learning Accelerator core
    /// (selected via `atomr_trt_config_set_dla_core`).
    kDLA = 1,
}
136
/// Kernel libraries the builder may draw tactics from, mirroring
/// `nvinfer1::TacticSource`.
///
/// Each discriminant is a bit *index*, not a mask. TensorRT builds its
/// `TacticSources` bitmask as `1 << (source as u32)`. Presumably the `mask`
/// argument of `atomr_trt_config_set_tactic_sources` follows the same
/// convention — confirm against the shim.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u32)]
pub enum TacticSource {
    kCUBLAS = 0,
    kCUBLAS_LT = 1,
    kCUDNN = 2,
    kEDGE_MASK_CONVOLUTIONS = 3,
    kJIT_CONVOLUTIONS = 4,
}
146
/// INT8 calibration algorithm selector, mirroring
/// `nvinfer1::CalibrationAlgoType`.
///
/// Explicit `repr(i32)` discriminants keep the values ABI-compatible with
/// the C++ enum.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(i32)]
pub enum CalibrationAlgoType {
    kLEGACY_CALIBRATION = 0,
    kENTROPY_CALIBRATION = 1,
    kENTROPY_CALIBRATION_2 = 2,
    kMINMAX_CALIBRATION = 3,
}
155
/// Tensor shape handed to `atomr_trt_context_set_input_shape`.
///
/// Field layout is `repr(C)`: a count followed by eight extents, mirroring
/// the `nbDims`/`d` pair of TensorRT's `Dims` (which has `MAX_DIMS == 8`).
///
/// NOTE(review): TensorRT 10's `nvinfer1::Dims` stores `int64_t` extents,
/// whereas `d` here is `c_int`. This struct is a contract with the C shim
/// only, so the shim presumably widens the values. Confirm before relying
/// on extents larger than `i32::MAX`.
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub struct Dims {
    /// Number of dimensions in use (TensorRT's `nbDims`).
    pub nb_dims: c_int,
    /// Per-dimension extents; presumably only the first `nb_dims` are read.
    pub d: [c_int; 8],
}
162
163// -------- Function declarations (link probe is gated by `tensorrt-link`) --------
164//
// The signatures below mirror the C-ABI shim that wraps the TensorRT
// C++ classes. Every item below is `#[cfg]`-gated on `tensorrt-link`,
// so with that feature off the declarations are compiled out entirely
// and can never produce link errors.
169
/// Callback that the C++ shim's `RustBridgeLogger::log()` invokes once
/// per TensorRT log line.
///
/// * `severity`: `nvinfer1::ILogger::Severity` as an integer, where
///   0 = INTERNAL_ERROR, 1 = ERROR, 2 = WARNING, 3 = INFO, 4 = VERBOSE.
/// * `msg` / `len`: the message text as a pointer plus byte length.
///   NOTE(review): whether `msg` is also NUL-terminated is not visible here.
/// * `user`: presumably the pointer registered via
///   `atomr_trt_install_logger`, passed back unchanged.
#[cfg(feature = "tensorrt-link")]
pub type AtomrTrtLogCb = unsafe extern "C" fn(
    severity: c_int,
    msg: *const c_char,
    len: usize,
    user: *mut c_void,
);
177
#[cfg(feature = "tensorrt-link")]
extern "C" {
    // ---- Logging ----

    /// Route every TensorRT log line seen by the shim's static `ILogger`
    /// to `cb`, with `user` passed through. Idempotent: the most recent
    /// registration wins.
    pub fn atomr_trt_install_logger(cb: AtomrTrtLogCb, user: *mut c_void);

    // ---- Builder lifecycle ----

    /// Create an `IBuilder`. `logger_severity` uses the 0..=4 scale of
    /// `AtomrTrtLogCb` (presumably a verbosity threshold; confirm).
    pub fn atomr_trt_builder_create(logger_severity: c_int) -> *mut IBuilder;

    /// Create an empty network definition. `flags` is presumably TensorRT's
    /// `NetworkDefinitionCreationFlags` bitmask.
    ///
    /// NOTE(review): this file declares no destroy function for
    /// `INetworkDefinition`. Confirm how the shim releases the network.
    pub fn atomr_trt_builder_create_network(
        builder: *mut IBuilder,
        flags: u32,
    ) -> *mut INetworkDefinition;

    /// Create a builder configuration (released by `atomr_trt_config_destroy`).
    pub fn atomr_trt_builder_create_config(builder: *mut IBuilder) -> *mut IBuilderConfig;

    /// Build `network` under `config` and return the serialized plan. The
    /// bytes are read through the `atomr_trt_host_memory_*` functions.
    pub fn atomr_trt_builder_build_serialized(
        builder: *mut IBuilder,
        network: *mut INetworkDefinition,
        config: *mut IBuilderConfig,
    ) -> *mut IHostMemory;

    /// Destroy a builder obtained from `atomr_trt_builder_create`.
    pub fn atomr_trt_builder_destroy(builder: *mut IBuilder);

    // ---- BuilderConfig knobs ----

    /// Set or clear one option. `flag` is a `BuilderFlag` discriminant;
    /// `on` is presumably a C boolean (non-zero = set).
    pub fn atomr_trt_config_set_flag(config: *mut IBuilderConfig, flag: u32, on: c_int);

    /// Cap a builder memory pool at `bytes`. `pool` is presumably a
    /// `nvinfer1::MemoryPoolType` value.
    pub fn atomr_trt_config_set_memory_pool_limit(
        config: *mut IBuilderConfig,
        pool: c_int,
        bytes: usize,
    );

    /// `dt` is a `DeviceType` discriminant.
    pub fn atomr_trt_config_set_default_device_type(config: *mut IBuilderConfig, dt: c_int);

    /// Select the DLA core used when layers target `DeviceType::kDLA`.
    pub fn atomr_trt_config_set_dla_core(config: *mut IBuilderConfig, core: c_int);

    /// `mask` is a bitmask of enabled `TacticSource`s (see that enum's docs).
    pub fn atomr_trt_config_set_tactic_sources(config: *mut IBuilderConfig, mask: u32);

    /// Attach an INT8 calibrator to the configuration.
    pub fn atomr_trt_config_set_int8_calibrator(
        config: *mut IBuilderConfig,
        calibrator: *mut IInt8Calibrator,
    );

    /// Seed the configuration with a serialized timing cache of `len` bytes.
    pub fn atomr_trt_config_set_timing_cache(
        config: *mut IBuilderConfig,
        blob: *const u8,
        len: usize,
    );

    /// Destroy a configuration from `atomr_trt_builder_create_config`.
    pub fn atomr_trt_config_destroy(config: *mut IBuilderConfig);

    // ---- HostMemory (serialized blobs) ----

    /// Pointer to the first byte of the blob.
    pub fn atomr_trt_host_memory_data(mem: *mut IHostMemory) -> *const u8;

    /// Length of the blob in bytes.
    pub fn atomr_trt_host_memory_size(mem: *mut IHostMemory) -> usize;

    /// Release a blob returned by `build_serialized` / `engine_serialize`.
    pub fn atomr_trt_host_memory_destroy(mem: *mut IHostMemory);

    // ---- Runtime + deserialisation ----

    /// Create an `IRuntime`. `logger_severity` uses the same scale as
    /// `atomr_trt_builder_create`.
    pub fn atomr_trt_runtime_create(logger_severity: c_int) -> *mut IRuntime;

    /// Deserialize `len` bytes of plan data into an engine.
    pub fn atomr_trt_runtime_deserialize(
        runtime: *mut IRuntime,
        blob: *const u8,
        len: usize,
    ) -> *mut ICudaEngine;

    /// Destroy a runtime from `atomr_trt_runtime_create`.
    pub fn atomr_trt_runtime_destroy(runtime: *mut IRuntime);

    // ---- Engine ----

    /// Create an execution context bound to `engine`.
    pub fn atomr_trt_engine_create_execution_context(
        engine: *mut ICudaEngine,
    ) -> *mut IExecutionContext;

    /// Serialize the engine into a host-memory blob.
    pub fn atomr_trt_engine_serialize(engine: *mut ICudaEngine) -> *mut IHostMemory;

    /// Number of I/O tensors; valid indices for `io_tensor_name` are
    /// presumably `0..count`.
    pub fn atomr_trt_engine_num_io_tensors(engine: *mut ICudaEngine) -> c_int;

    /// Name of the I/O tensor at `idx` as a C string owned by the shim/engine
    /// (presumably; do not free it).
    pub fn atomr_trt_engine_io_tensor_name(engine: *mut ICudaEngine, idx: c_int) -> *const c_char;

    /// Create a refitter for `engine`.
    pub fn atomr_trt_engine_create_refitter(engine: *mut ICudaEngine) -> *mut IRefitter;

    /// Destroy an engine.
    pub fn atomr_trt_engine_destroy(engine: *mut ICudaEngine);

    // ---- Refitter ----

    /// Stage `bytes` bytes of new weights for tensor `name`. `dtype` is a
    /// `DataType` discriminant. The `c_int` result is presumably a status.
    pub fn atomr_trt_refitter_set_named_weights(
        refitter: *mut IRefitter,
        name: *const c_char,
        weights: *const c_void,
        bytes: usize,
        dtype: c_int,
    ) -> c_int;

    /// Apply the staged weights to the engine (status code presumably).
    pub fn atomr_trt_refitter_refit_engine(refitter: *mut IRefitter) -> c_int;

    /// Destroy a refitter from `atomr_trt_engine_create_refitter`.
    pub fn atomr_trt_refitter_destroy(refitter: *mut IRefitter);

    // ---- ExecutionContext ----

    /// Bind the runtime shape of input tensor `name`.
    pub fn atomr_trt_context_set_input_shape(
        ctx: *mut IExecutionContext,
        name: *const c_char,
        dims: *const Dims,
    ) -> c_int;

    /// Bind the device buffer backing tensor `name`.
    pub fn atomr_trt_context_set_tensor_address(
        ctx: *mut IExecutionContext,
        name: *const c_char,
        addr: *mut c_void,
    ) -> c_int;

    /// Enqueue inference on `cuda_stream` (a `cudaStream_t` passed as `void*`).
    pub fn atomr_trt_context_enqueue_v3(
        ctx: *mut IExecutionContext,
        cuda_stream: *mut c_void,
    ) -> c_int;

    /// Destroy an execution context.
    pub fn atomr_trt_context_destroy(ctx: *mut IExecutionContext);

    // ---- Plugin registry (IPluginV3) ----

    /// Register a plugin creator with TensorRT's plugin registry.
    pub fn atomr_trt_register_plugin_creator(creator: *mut IPluginCreator) -> c_int;
}
274
/// Vtable mirrored in `csrc/rust_bridge.h`. Each function pointer
/// dispatches a vtable-method call from the C++ proxy back to the
/// corresponding Rust trait method on `dyn PluginV3`.
///
/// This struct is `#[repr(C)]` and shared with C++, so field order and
/// types are ABI. Never reorder, insert or remove fields without changing
/// `csrc/rust_bridge.h` in lockstep.
#[cfg(all(feature = "tensorrt-link", feature = "tensorrt-plugin"))]
#[repr(C)]
pub struct AtomrPluginVTable {
    /// Returns the plugin name as a C string. `user` is presumably the
    /// pointer given to `atomr_trt_make_plugin_creator`; confirm.
    pub get_name: unsafe extern "C" fn(user: *const c_void) -> *const c_char,
    /// Returns the plugin version as a C string.
    pub get_version: unsafe extern "C" fn(user: *const c_void) -> *const c_char,
    /// Returns the plugin namespace as a C string.
    pub get_namespace: unsafe extern "C" fn(user: *const c_void) -> *const c_char,
    /// Creates a plugin instance for layer `name`. The returned pointer is
    /// presumably later released through `destroy_instance`.
    pub create_plugin:
        unsafe extern "C" fn(user: *const c_void, name: *const c_char) -> *mut c_void,
    /// Releases `user` itself (note: takes `*mut`, unlike the getters).
    pub destroy: unsafe extern "C" fn(user: *mut c_void),
    /// Releases one instance produced by `create_plugin`.
    pub destroy_instance: unsafe extern "C" fn(instance: *mut c_void),
}
289
#[cfg(all(feature = "tensorrt-link", feature = "tensorrt-plugin"))]
extern "C" {
    /// Wrap a Rust plugin implementation in a C++ `IPluginCreator` proxy.
    ///
    /// `vt` supplies the callbacks documented on `AtomrPluginVTable`.
    /// `user` is presumably handed back as the `user` argument of those
    /// callbacks. The result can be passed to
    /// `atomr_trt_register_plugin_creator`.
    ///
    /// NOTE(review): confirm whether the shim copies `*vt` or keeps the
    /// pointer. In the latter case `vt` must outlive the creator.
    pub fn atomr_trt_make_plugin_creator(
        vt: *const AtomrPluginVTable,
        user: *mut c_void,
    ) -> *mut IPluginCreator;
}
297
#[cfg(all(feature = "tensorrt-link", feature = "tensorrt-onnx"))]
extern "C" {
    /// Create an ONNX parser attached to `network` (presumably the network
    /// it populates). `logger_severity` uses the same 0..=4 scale as the
    /// builder/runtime constructors.
    pub fn atomr_trt_onnx_parser_create(
        network: *mut INetworkDefinition,
        logger_severity: c_int,
    ) -> *mut IOnnxParser;

    /// Parse `len` bytes of serialized ONNX at `data`. `path` is presumably
    /// the model's on-disk location, used to resolve external weight files;
    /// confirm against the shim. The `c_int` result is presumably a success
    /// flag.
    pub fn atomr_trt_onnx_parser_parse(
        parser: *mut IOnnxParser,
        data: *const u8,
        len: usize,
        path: *const c_char,
    ) -> c_int;

    /// Number of errors the parser has recorded (presumably for the latest parse).
    pub fn atomr_trt_onnx_parser_num_errors(parser: *mut IOnnxParser) -> c_int;

    /// Destroy a parser from `atomr_trt_onnx_parser_create`.
    pub fn atomr_trt_onnx_parser_destroy(parser: *mut IOnnxParser);
}