1#![allow(non_camel_case_types, dead_code, non_snake_case, unused_imports)]
25
26use std::os::raw::{c_char, c_int, c_void};
27
/// Opaque FFI handle type for builder objects.
///
/// Values are only ever reached through raw pointers such as the one returned by
/// `atomr_trt_builder_create`; the type has no readable fields and cannot be constructed
/// outside this module.
#[repr(C)]
pub struct IBuilder {
    _private: [u8; 0],
    // Zero-sized marker: the raw pointer opts the type out of the auto `Send`/`Sync` impls and
    // `PhantomPinned` out of `Unpin`, none of which Rust can vouch for on a foreign object.
    _marker: ::std::marker::PhantomData<(*mut u8, ::std::marker::PhantomPinned)>,
}
34
/// Opaque FFI handle type for builder-configuration objects.
///
/// Pointers to it are returned by `atomr_trt_builder_create_config` and consumed by the
/// `atomr_trt_config_*` functions; Rust never looks inside.
#[repr(C)]
pub struct IBuilderConfig {
    _private: [u8; 0],
    // Zero-sized marker: the raw pointer opts the type out of the auto `Send`/`Sync` impls and
    // `PhantomPinned` out of `Unpin`, none of which Rust can vouch for on a foreign object.
    _marker: ::std::marker::PhantomData<(*mut u8, ::std::marker::PhantomPinned)>,
}
39
/// Opaque FFI handle type for network-definition objects.
///
/// Pointers to it are returned by `atomr_trt_builder_create_network`; Rust only passes them
/// back across the FFI boundary.
#[repr(C)]
pub struct INetworkDefinition {
    _private: [u8; 0],
    // Zero-sized marker: the raw pointer opts the type out of the auto `Send`/`Sync` impls and
    // `PhantomPinned` out of `Unpin`, none of which Rust can vouch for on a foreign object.
    _marker: ::std::marker::PhantomData<(*mut u8, ::std::marker::PhantomPinned)>,
}
44
/// Opaque FFI handle type for engine objects.
///
/// Pointers to it are returned by `atomr_trt_runtime_deserialize` and consumed by the
/// `atomr_trt_engine_*` functions; Rust never looks inside.
#[repr(C)]
pub struct ICudaEngine {
    _private: [u8; 0],
    // Zero-sized marker: the raw pointer opts the type out of the auto `Send`/`Sync` impls and
    // `PhantomPinned` out of `Unpin`, none of which Rust can vouch for on a foreign object.
    _marker: ::std::marker::PhantomData<(*mut u8, ::std::marker::PhantomPinned)>,
}
49
/// Opaque FFI handle type for execution-context objects.
///
/// Pointers to it are returned by `atomr_trt_engine_create_execution_context` and consumed by
/// the `atomr_trt_context_*` functions; Rust never looks inside.
#[repr(C)]
pub struct IExecutionContext {
    _private: [u8; 0],
    // Zero-sized marker: the raw pointer opts the type out of the auto `Send`/`Sync` impls and
    // `PhantomPinned` out of `Unpin`, none of which Rust can vouch for on a foreign object.
    _marker: ::std::marker::PhantomData<(*mut u8, ::std::marker::PhantomPinned)>,
}
54
/// Opaque FFI handle type for runtime objects.
///
/// Pointers to it are returned by `atomr_trt_runtime_create` and consumed by the other
/// `atomr_trt_runtime_*` functions; Rust never looks inside.
#[repr(C)]
pub struct IRuntime {
    _private: [u8; 0],
    // Zero-sized marker: the raw pointer opts the type out of the auto `Send`/`Sync` impls and
    // `PhantomPinned` out of `Unpin`, none of which Rust can vouch for on a foreign object.
    _marker: ::std::marker::PhantomData<(*mut u8, ::std::marker::PhantomPinned)>,
}
59
/// Opaque FFI handle type for host-memory blobs.
///
/// Pointers to it are returned by `atomr_trt_builder_build_serialized` and
/// `atomr_trt_engine_serialize`; the bytes are reached through the `atomr_trt_host_memory_*`
/// accessors rather than through any Rust-visible field.
#[repr(C)]
pub struct IHostMemory {
    _private: [u8; 0],
    // Zero-sized marker: the raw pointer opts the type out of the auto `Send`/`Sync` impls and
    // `PhantomPinned` out of `Unpin`, none of which Rust can vouch for on a foreign object.
    _marker: ::std::marker::PhantomData<(*mut u8, ::std::marker::PhantomPinned)>,
}
64
/// Opaque FFI handle type for refitter objects.
///
/// Pointers to it are returned by `atomr_trt_engine_create_refitter` and consumed by the
/// `atomr_trt_refitter_*` functions; Rust never looks inside.
#[repr(C)]
pub struct IRefitter {
    _private: [u8; 0],
    // Zero-sized marker: the raw pointer opts the type out of the auto `Send`/`Sync` impls and
    // `PhantomPinned` out of `Unpin`, none of which Rust can vouch for on a foreign object.
    _marker: ::std::marker::PhantomData<(*mut u8, ::std::marker::PhantomPinned)>,
}
69
/// Opaque FFI handle type for INT8 calibrator objects.
///
/// Pointers to it are accepted by `atomr_trt_config_set_int8_calibrator`; Rust only passes
/// them through and never looks inside.
#[repr(C)]
pub struct IInt8Calibrator {
    _private: [u8; 0],
    // Zero-sized marker: the raw pointer opts the type out of the auto `Send`/`Sync` impls and
    // `PhantomPinned` out of `Unpin`, none of which Rust can vouch for on a foreign object.
    _marker: ::std::marker::PhantomData<(*mut u8, ::std::marker::PhantomPinned)>,
}
74
/// Opaque FFI handle type for plugin-creator objects.
///
/// Pointers to it are accepted by `atomr_trt_register_plugin_creator` and produced by
/// `atomr_trt_make_plugin_creator`; Rust never looks inside.
#[repr(C)]
pub struct IPluginCreator {
    _private: [u8; 0],
    // Zero-sized marker: the raw pointer opts the type out of the auto `Send`/`Sync` impls and
    // `PhantomPinned` out of `Unpin`, none of which Rust can vouch for on a foreign object.
    _marker: ::std::marker::PhantomData<(*mut u8, ::std::marker::PhantomPinned)>,
}
79
/// Opaque FFI handle type for plugin instances.
///
/// Declared so that raw pointers to plugin objects have a distinct Rust type; it has no
/// readable fields and cannot be constructed outside this module.
#[repr(C)]
pub struct IPluginV3 {
    _private: [u8; 0],
    // Zero-sized marker: the raw pointer opts the type out of the auto `Send`/`Sync` impls and
    // `PhantomPinned` out of `Unpin`, none of which Rust can vouch for on a foreign object.
    _marker: ::std::marker::PhantomData<(*mut u8, ::std::marker::PhantomPinned)>,
}
84
/// Opaque FFI handle type for ONNX parser objects.
///
/// Pointers to it are returned by `atomr_trt_onnx_parser_create` and consumed by the other
/// `atomr_trt_onnx_parser_*` functions; Rust never looks inside.
#[repr(C)]
pub struct IOnnxParser {
    _private: [u8; 0],
    // Zero-sized marker: the raw pointer opts the type out of the auto `Send`/`Sync` impls and
    // `PhantomPinned` out of `Unpin`, none of which Rust can vouch for on a foreign object.
    _marker: ::std::marker::PhantomData<(*mut u8, ::std::marker::PhantomPinned)>,
}
89
/// Element-type tags with fixed `i32` discriminants.
///
/// `#[repr(i32)]` pins the representation, so `DataType::kHALF as i32` is exactly `1`.
/// `Hash` is derived alongside `Eq` so the tags can key maps and sets.
///
/// NOTE(review): presumably the value expected by `dtype: c_int` parameters such as the one on
/// `atomr_trt_refitter_set_named_weights` — confirm against the native side.
#[repr(i32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    kFLOAT = 0,
    kHALF = 1,
    kINT8 = 2,
    kINT32 = 3,
    kBOOL = 4,
    kUINT8 = 5,
    kFP8 = 6,
    kBF16 = 7,
    kINT64 = 8,
    kINT4 = 9,
}
106
/// Builder option identifiers with fixed `u32` discriminants.
///
/// `#[repr(u32)]` pins the representation, so `BuilderFlag::kTF32 as u32` is exactly `6`.
/// `Hash` is derived alongside `Eq` so the flags can key maps and sets.
///
/// NOTE(review): presumably the value passed as `flag: u32` to `atomr_trt_config_set_flag` —
/// confirm. The discriminants must agree with whatever the C shim expects; if the shim forwards
/// them verbatim to `nvinfer1::BuilderFlag`, verify each value against the `NvInfer.h` of the
/// targeted TensorRT release, since that enum's numbering has changed between releases.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BuilderFlag {
    kFP16 = 0,
    kINT8 = 1,
    kDEBUG = 2,
    kGPU_FALLBACK = 3,
    kREFIT = 4,
    kDISABLE_TIMING_CACHE = 5,
    kTF32 = 6,
    kSPARSE_WEIGHTS = 7,
    kSAFETY_SCOPE = 8,
    kOBEY_PRECISION_CONSTRAINTS = 9,
    kPREFER_PRECISION_CONSTRAINTS = 10,
    kDIRECT_IO = 11,
    kREJECT_EMPTY_ALGORITHMS = 12,
    kBF16 = 13,
    kFP8 = 14,
    kSTRIP_PLAN = 15,
    kVERSION_COMPATIBLE = 16,
    kEXCLUDE_LEAN_RUNTIME = 17,
}
129
/// Device selector with fixed `i32` discriminants (`kGPU` is `0`, `kDLA` is `1`).
///
/// `Hash` is derived alongside `Eq` so the selector can key maps and sets.
///
/// NOTE(review): presumably the value passed as `dt: c_int` to
/// `atomr_trt_config_set_default_device_type` — confirm against the native side.
#[repr(i32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    kGPU = 0,
    kDLA = 1,
}
136
/// Tactic-source identifiers with fixed `u32` discriminants.
///
/// `Hash` is derived alongside `Eq` so the identifiers can key maps and sets.
///
/// NOTE(review): presumably each variant is a bit position, combined as `1 << variant` into the
/// `mask: u32` taken by `atomr_trt_config_set_tactic_sources` — confirm against the native side.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TacticSource {
    kCUBLAS = 0,
    kCUBLAS_LT = 1,
    kCUDNN = 2,
    kEDGE_MASK_CONVOLUTIONS = 3,
    kJIT_CONVOLUTIONS = 4,
}
146
/// Calibration-algorithm identifiers with fixed `i32` discriminants.
///
/// `#[repr(i32)]` pins the representation, so `as i32` yields exactly the value listed next to
/// each variant. `Hash` is derived alongside `Eq` so the identifiers can key maps and sets.
#[repr(i32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CalibrationAlgoType {
    kLEGACY_CALIBRATION = 0,
    kENTROPY_CALIBRATION = 1,
    kENTROPY_CALIBRATION_2 = 2,
    kMINMAX_CALIBRATION = 3,
}
155
/// C-layout shape descriptor: a dimension count followed by eight `c_int` slots.
///
/// Passed by pointer to `atomr_trt_context_set_input_shape`.
///
/// `Default` yields `nb_dims == 0` with every slot zeroed. `PartialEq`/`Eq` compare field by
/// field, i.e. *all eight* slots of `d` take part in the comparison regardless of `nb_dims`.
///
/// NOTE(review): presumably only the first `nb_dims` entries of `d` are meaningful — confirm.
/// NOTE(review): the field widths must match the struct declared by the C shim; TensorRT's own
/// `Dims` has used 64-bit extents in recent releases, so confirm the shim converts — TODO.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Dims {
    /// Number of dimensions.
    pub nb_dims: c_int,
    /// Per-dimension extents; fixed capacity of eight entries.
    pub d: [c_int; 8],
}
162
/// Signature of the logging callback accepted by `atomr_trt_install_logger`.
///
/// The callback receives a `severity` code, a `msg` pointer accompanied by an explicit `len`,
/// and an opaque `user` pointer. Only available with the `tensorrt-link` feature.
///
/// NOTE(review): presumably `msg` points at `len` bytes that are valid only for the duration of
/// the call, and `user` is the pointer given to `atomr_trt_install_logger` — confirm against
/// the native logger.
#[cfg(feature = "tensorrt-link")]
pub type AtomrTrtLogCb =
    unsafe extern "C" fn(severity: c_int, msg: *const c_char, len: usize, user: *mut c_void);
177
// Foreign declarations for the `atomr_trt_*` C ABI, compiled only with the `tensorrt-link`
// feature. Every item is implicitly `unsafe` to call because its arguments are raw pointers
// the compiler cannot validate.
//
// NOTE(review): what the signatures alone do not show — whether returned pointers may be null,
// who owns them, and what the `c_int` results mean — is hedged in the per-function notes and
// must be confirmed against the native implementation.
#[cfg(feature = "tensorrt-link")]
extern "C" {
    // ---- logging ---------------------------------------------------------------------------

    /// Installs the logging callback `cb` together with an opaque `user` pointer.
    /// NOTE(review): presumably `user` is handed back to `cb` on every call and must stay valid
    /// while the logger is installed — confirm.
    pub fn atomr_trt_install_logger(cb: AtomrTrtLogCb, user: *mut c_void);

    // ---- builder ---------------------------------------------------------------------------

    /// Returns a raw `IBuilder` pointer for the given `logger_severity`.
    /// NOTE(review): presumably null on failure; presumably released with
    /// `atomr_trt_builder_destroy` — confirm.
    pub fn atomr_trt_builder_create(logger_severity: c_int) -> *mut IBuilder;
    /// Returns a raw `INetworkDefinition` pointer obtained from `builder` with the given
    /// `flags` word. NOTE(review): meaning of `flags` bits is defined natively — confirm.
    pub fn atomr_trt_builder_create_network(
        builder: *mut IBuilder,
        flags: u32,
    ) -> *mut INetworkDefinition;
    /// Returns a raw `IBuilderConfig` pointer obtained from `builder`.
    pub fn atomr_trt_builder_create_config(builder: *mut IBuilder) -> *mut IBuilderConfig;
    /// Takes a builder, a network and a config and returns a raw `IHostMemory` pointer.
    /// NOTE(review): presumably the serialized build result, null on failure — confirm.
    pub fn atomr_trt_builder_build_serialized(
        builder: *mut IBuilder,
        network: *mut INetworkDefinition,
        config: *mut IBuilderConfig,
    ) -> *mut IHostMemory;
    /// Destroys `builder`. NOTE(review): presumably the pointer is dangling afterwards — confirm.
    pub fn atomr_trt_builder_destroy(builder: *mut IBuilder);

    // ---- builder config --------------------------------------------------------------------

    /// Sets or clears `flag` on `config` according to `on`.
    /// NOTE(review): presumably `flag` is a `BuilderFlag as u32` and `on` is a C boolean —
    /// confirm.
    pub fn atomr_trt_config_set_flag(config: *mut IBuilderConfig, flag: u32, on: c_int);
    /// Sets the limit of memory pool `pool` on `config` to `bytes`.
    /// NOTE(review): valid `pool` identifiers are defined natively — confirm.
    pub fn atomr_trt_config_set_memory_pool_limit(
        config: *mut IBuilderConfig,
        pool: c_int,
        bytes: usize,
    );
    /// Sets the default device type on `config`.
    /// NOTE(review): presumably `dt` is a `DeviceType as i32` — confirm.
    pub fn atomr_trt_config_set_default_device_type(config: *mut IBuilderConfig, dt: c_int);
    /// Sets the DLA core index on `config`.
    pub fn atomr_trt_config_set_dla_core(config: *mut IBuilderConfig, core: c_int);
    /// Sets the tactic-source mask on `config`.
    /// NOTE(review): presumably `mask` is an OR of `1 << TacticSource` bits — confirm.
    pub fn atomr_trt_config_set_tactic_sources(config: *mut IBuilderConfig, mask: u32);
    /// Attaches `calibrator` to `config`.
    /// NOTE(review): lifetime requirements on `calibrator` are not visible here — confirm.
    pub fn atomr_trt_config_set_int8_calibrator(
        config: *mut IBuilderConfig,
        calibrator: *mut IInt8Calibrator,
    );
    /// Supplies a timing-cache blob of `len` bytes starting at `blob` to `config`.
    /// NOTE(review): whether the bytes are copied or borrowed is not visible here — confirm.
    pub fn atomr_trt_config_set_timing_cache(
        config: *mut IBuilderConfig,
        blob: *const u8,
        len: usize,
    );
    /// Destroys `config`. NOTE(review): presumably the pointer is dangling afterwards — confirm.
    pub fn atomr_trt_config_destroy(config: *mut IBuilderConfig);

    // ---- engine ----------------------------------------------------------------------------

    /// Returns a raw `IExecutionContext` pointer obtained from `engine`.
    pub fn atomr_trt_engine_create_execution_context(
        engine: *mut ICudaEngine,
    ) -> *mut IExecutionContext;
    /// Returns a raw `IHostMemory` pointer holding the serialized form of `engine`.
    /// NOTE(review): presumably released with `atomr_trt_host_memory_destroy` — confirm.
    pub fn atomr_trt_engine_serialize(engine: *mut ICudaEngine) -> *mut IHostMemory;
    /// Returns the number of I/O tensors of `engine` as a `c_int`.
    pub fn atomr_trt_engine_num_io_tensors(engine: *mut ICudaEngine) -> c_int;
    /// Returns a C string pointer naming the I/O tensor at index `idx` of `engine`.
    /// NOTE(review): presumably NUL-terminated, owned by the engine, and null for an
    /// out-of-range `idx` — confirm.
    pub fn atomr_trt_engine_io_tensor_name(engine: *mut ICudaEngine, idx: c_int) -> *const c_char;
    /// Returns a raw `IRefitter` pointer obtained from `engine`.
    pub fn atomr_trt_engine_create_refitter(engine: *mut ICudaEngine) -> *mut IRefitter;
    /// Destroys `engine`. NOTE(review): presumably the pointer is dangling afterwards — confirm.
    pub fn atomr_trt_engine_destroy(engine: *mut ICudaEngine);

    // ---- refitter --------------------------------------------------------------------------

    /// Supplies `bytes` bytes of weight data at `weights`, tagged with `dtype`, for the
    /// weights called `name`; returns a `c_int` status.
    /// NOTE(review): presumably `dtype` is a `DataType as i32`; the status convention and
    /// whether `weights` must outlive the call are not visible here — confirm.
    pub fn atomr_trt_refitter_set_named_weights(
        refitter: *mut IRefitter,
        name: *const c_char,
        weights: *const c_void,
        bytes: usize,
        dtype: c_int,
    ) -> c_int;
    /// Refits the engine through `refitter`; returns a `c_int` status.
    /// NOTE(review): status convention is defined natively — confirm.
    pub fn atomr_trt_refitter_refit_engine(refitter: *mut IRefitter) -> c_int;
    /// Destroys `refitter`. NOTE(review): presumably the pointer is dangling afterwards —
    /// confirm.
    pub fn atomr_trt_refitter_destroy(refitter: *mut IRefitter);

    // ---- execution context -----------------------------------------------------------------

    /// Sets the shape of the input tensor `name` on `ctx` from `*dims`; returns a `c_int`
    /// status. NOTE(review): status convention is defined natively — confirm.
    pub fn atomr_trt_context_set_input_shape(
        ctx: *mut IExecutionContext,
        name: *const c_char,
        dims: *const Dims,
    ) -> c_int;
    /// Binds the tensor `name` on `ctx` to the address `addr`; returns a `c_int` status.
    /// NOTE(review): presumably `addr` is a device pointer that must stay valid until the
    /// enqueued work completes — confirm.
    pub fn atomr_trt_context_set_tensor_address(
        ctx: *mut IExecutionContext,
        name: *const c_char,
        addr: *mut c_void,
    ) -> c_int;
    /// Enqueues execution of `ctx` on `cuda_stream`; returns a `c_int` status.
    /// NOTE(review): presumably `cuda_stream` is a `cudaStream_t` and the call is
    /// asynchronous — confirm.
    pub fn atomr_trt_context_enqueue_v3(
        ctx: *mut IExecutionContext,
        cuda_stream: *mut c_void,
    ) -> c_int;
    /// Destroys `ctx`. NOTE(review): presumably the pointer is dangling afterwards — confirm.
    pub fn atomr_trt_context_destroy(ctx: *mut IExecutionContext);

    // ---- runtime ---------------------------------------------------------------------------

    /// Returns a raw `IRuntime` pointer for the given `logger_severity`.
    /// NOTE(review): presumably null on failure — confirm.
    pub fn atomr_trt_runtime_create(logger_severity: c_int) -> *mut IRuntime;
    /// Deserializes `len` bytes starting at `blob` through `runtime` and returns a raw
    /// `ICudaEngine` pointer. NOTE(review): presumably null on failure — confirm.
    pub fn atomr_trt_runtime_deserialize(
        runtime: *mut IRuntime,
        blob: *const u8,
        len: usize,
    ) -> *mut ICudaEngine;
    /// Destroys `runtime`. NOTE(review): presumably the pointer is dangling afterwards —
    /// confirm.
    pub fn atomr_trt_runtime_destroy(runtime: *mut IRuntime);

    // ---- host memory -----------------------------------------------------------------------

    /// Returns a pointer to the first byte held by `mem`.
    /// NOTE(review): presumably valid until `atomr_trt_host_memory_destroy(mem)` — confirm.
    pub fn atomr_trt_host_memory_data(mem: *mut IHostMemory) -> *const u8;
    /// Returns the size of `mem` as a `usize`. NOTE(review): presumably in bytes — confirm.
    pub fn atomr_trt_host_memory_size(mem: *mut IHostMemory) -> usize;
    /// Destroys `mem`. NOTE(review): presumably the pointer is dangling afterwards — confirm.
    pub fn atomr_trt_host_memory_destroy(mem: *mut IHostMemory);

    // ---- plugins ---------------------------------------------------------------------------

    /// Registers `creator`; returns a `c_int` status.
    /// NOTE(review): status convention and ownership of `creator` after the call are defined
    /// natively — confirm.
    pub fn atomr_trt_register_plugin_creator(creator: *mut IPluginCreator) -> c_int;
}
274
/// C-layout table of callbacks describing a plugin creator.
///
/// A pointer to this table is the `vt` argument of `atomr_trt_make_plugin_creator`. Every entry
/// is an `unsafe extern "C"` function pointer. Only available when both the `tensorrt-link`
/// and `tensorrt-plugin` features are enabled.
///
/// NOTE(review): presumably the `user` argument each callback receives is the `user` pointer
/// given to `atomr_trt_make_plugin_creator`, and returned C strings must be NUL-terminated and
/// outlive the call — confirm against the native side.
#[cfg(all(feature = "tensorrt-link", feature = "tensorrt-plugin"))]
#[repr(C)]
pub struct AtomrPluginVTable {
    /// Returns the plugin name as a C string pointer.
    pub get_name: unsafe extern "C" fn(user: *const c_void) -> *const c_char,
    /// Returns the plugin version as a C string pointer.
    pub get_version: unsafe extern "C" fn(user: *const c_void) -> *const c_char,
    /// Returns the plugin namespace as a C string pointer.
    pub get_namespace: unsafe extern "C" fn(user: *const c_void) -> *const c_char,
    /// Creates a plugin instance for `name`, returned as an untyped pointer.
    /// NOTE(review): presumably later handed to `destroy_instance` — confirm.
    pub create_plugin:
        unsafe extern "C" fn(user: *const c_void, name: *const c_char) -> *mut c_void,
    /// Destroys the state behind `user`.
    pub destroy: unsafe extern "C" fn(user: *mut c_void),
    /// Destroys a plugin `instance`.
    pub destroy_instance: unsafe extern "C" fn(instance: *mut c_void),
}
289
// Plugin-creator entry point of the `atomr_trt_*` C ABI; compiled only when both the
// `tensorrt-link` and `tensorrt-plugin` features are enabled.
#[cfg(all(feature = "tensorrt-link", feature = "tensorrt-plugin"))]
extern "C" {
    /// Builds a raw `IPluginCreator` pointer from the callback table `vt` and an opaque `user`
    /// pointer.
    ///
    /// NOTE(review): presumably `*vt` and `user` must stay valid for as long as the returned
    /// creator is in use, and the result can be passed to
    /// `atomr_trt_register_plugin_creator` — confirm against the native side.
    pub fn atomr_trt_make_plugin_creator(
        vt: *const AtomrPluginVTable,
        user: *mut c_void,
    ) -> *mut IPluginCreator;
}
297
// ONNX-parser entry points of the `atomr_trt_*` C ABI; compiled only when both the
// `tensorrt-link` and `tensorrt-onnx` features are enabled.
//
// NOTE(review): nullability of returned pointers and the meaning of `c_int` results are
// defined natively and not visible from these declarations — confirm.
#[cfg(all(feature = "tensorrt-link", feature = "tensorrt-onnx"))]
extern "C" {
    /// Returns a raw `IOnnxParser` pointer bound to `network`, for the given `logger_severity`.
    /// NOTE(review): presumably `network` must outlive the parser — confirm.
    pub fn atomr_trt_onnx_parser_create(
        network: *mut INetworkDefinition,
        logger_severity: c_int,
    ) -> *mut IOnnxParser;
    /// Parses `len` bytes of model data starting at `data`; `path` is a C string pointer.
    /// Returns a `c_int` status.
    /// NOTE(review): presumably `path` is an optional model path used to resolve external
    /// data and may be null — confirm.
    pub fn atomr_trt_onnx_parser_parse(
        parser: *mut IOnnxParser,
        data: *const u8,
        len: usize,
        path: *const c_char,
    ) -> c_int;
    /// Returns the number of errors recorded by `parser` as a `c_int`.
    pub fn atomr_trt_onnx_parser_num_errors(parser: *mut IOnnxParser) -> c_int;
    /// Destroys `parser`. NOTE(review): presumably the pointer is dangling afterwards — confirm.
    pub fn atomr_trt_onnx_parser_destroy(parser: *mut IOnnxParser);
}