trtx-sys 0.3.0

Raw FFI bindings to NVIDIA TensorRT-RTX (EXPERIMENTAL - NOT FOR PRODUCTION)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
use std::env;
use std::path::PathBuf;

/// Build script for `trtx-sys`.
///
/// * With the `mock` feature it compiles `mock.c` and writes stub bindings to
///   `$OUT_DIR/bindings.rs`; no TensorRT-RTX installation is needed.
/// * Otherwise it compiles the C++ logger bridge and the autocxx bindings against the
///   headers under `TensorRT-Headers/`, and emits link directives for the TensorRT-RTX
///   libraries in `$TENSORRT_RTX_DIR/lib` (default `/usr/local/tensorrt-rtx/lib`).
fn main() {
    let out_path =
        PathBuf::from(env::var("OUT_DIR").expect("cargo always sets OUT_DIR for build scripts"));
    let link_trt = env::var("CARGO_FEATURE_LINK_TENSORRT_RTX").is_ok();
    let link_trt_onnxparser = env::var("CARGO_FEATURE_LINK_TENSORRT_ONNXPARSER").is_ok();
    println!("cargo:rerun-if-env-changed=CARGO_FEATURE_LINK_TENSORRT_RTX");
    println!("cargo:rerun-if-env-changed=CARGO_FEATURE_LINK_TENSORRT_ONNXPARSER");

    // Mock mode: build the C stubs and emit stub bindings, nothing else.
    if env::var("CARGO_FEATURE_MOCK").is_ok() {
        // Emitting any `rerun-if-*` directive (see above) disables cargo's default
        // "rerun on any package change" behaviour, so the mock source must be listed
        // explicitly or edits to it would never trigger a rebuild.
        println!("cargo:rerun-if-changed=mock.c");
        println!("cargo:warning=Building in MOCK mode - no TensorRT-RTX required");

        cc::Build::new().file("mock.c").compile("trtx_mock");

        generate_mock_bindings(&out_path);
        return;
    }

    println!("cargo:rerun-if-changed=src/lib.rs");
    println!("cargo:rerun-if-changed=logger_bridge.hpp");
    println!("cargo:rerun-if-changed=logger_bridge.cpp");
    println!("cargo:rerun-if-env-changed=TENSORRT_RTX_DIR");
    println!("cargo:rerun-if-env-changed=CUDA_ROOT");
    println!("cargo:rerun-if-env-changed=LIBCLANG_PATH");

    // Locate the TensorRT-RTX installation (used for its libraries; headers come from
    // the crate's own TensorRT-Headers directory). Override with TENSORRT_RTX_DIR.
    let trtx_dir = match env::var("TENSORRT_RTX_DIR") {
        Ok(dir) => {
            println!("cargo:warning=Using TENSORRT_RTX_DIR={}", dir);
            dir
        }
        Err(_) => {
            println!(
                "cargo:warning=TENSORRT_RTX_DIR not set, using default: /usr/local/tensorrt-rtx"
            );
            "/usr/local/tensorrt-rtx".to_string()
        }
    };

    // `cfg!` / `#[cfg]` inside a build script describe the *host* the script runs on.
    // Cargo exposes the *target* configuration through CARGO_CFG_* variables, which is
    // what matters for library names and compiler flags when cross-compiling.
    let target_os = env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
    let target_env = env::var("CARGO_CFG_TARGET_ENV").unwrap_or_default();
    let is_msvc = target_os == "windows" && target_env == "msvc";

    let crate_root = env::var("CARGO_MANIFEST_DIR")
        .expect("cargo always sets CARGO_MANIFEST_DIR for build scripts");
    let trt_version = tensorrt_rtx_version();

    let include_dir = format!("{crate_root}/TensorRT-Headers/TRT-RTX-{trt_version}");
    let cuda_shim_include_dir = format!("{crate_root}/TensorRT-Headers");
    let lib_dir = format!("{trtx_dir}/lib");

    // On Windows the link names carry the version (e.g. `tensorrt_rtx_1_3`);
    // other targets use the unsuffixed names.
    let trt_version_suffix = if target_os == "windows" {
        format!("_{}", trt_version.replace('.', "_"))
    } else {
        String::new()
    };

    println!("cargo:rustc-link-search=native={}", lib_dir);
    if link_trt {
        println!("cargo:rustc-link-lib=dylib=tensorrt_rtx{trt_version_suffix}");
    }
    if link_trt_onnxparser {
        println!("cargo:rustc-link-lib=dylib=tensorrt_onnxparser{trt_version_suffix}");
    }

    // Build the logger bridge C++ wrapper.
    let mut cc_build = cc::Build::new();
    cc_build
        .cpp(true)
        .file("logger_bridge.cpp")
        .include(&include_dir)
        .include(&cuda_shim_include_dir);

    if link_trt {
        cc_build.define("TRTX_LINK_TENSORRT_RTX", "1");
    }
    if link_trt_onnxparser {
        cc_build.define("TRTX_LINK_TENSORRT_ONNXPARSER", "1");
    }
    for &flag in cxx17_flags(is_msvc) {
        cc_build.flag(flag);
    }

    cc_build.compile("trtx_logger_bridge");

    // Build the autocxx bindings for the main TensorRT API. These arguments go to
    // autocxx's clang-based parser, which always takes GCC/Clang-style flags.
    let clang_args = vec![
        "-std=c++17",
        "-Wno-unused-parameter", // Suppress unused parameter warnings from TensorRT headers
        "-Wno-deprecated-declarations", // Suppress deprecated warnings from TensorRT headers
    ];

    let mut autocxx_build =
        autocxx_build::Builder::new("src/lib.rs", [&include_dir, &cuda_shim_include_dir])
            .extra_clang_args(&clang_args)
            .build()
            .expect("Failed to build autocxx bindings");

    // The generated C++ is compiled by the target's real C++ compiler.
    for &flag in cxx17_flags(is_msvc) {
        autocxx_build.flag(flag);
    }

    autocxx_build.compile("trtx_autocxx");
}

/// Returns the TensorRT-RTX header version selected by the crate's version feature.
///
/// The version picks the `TensorRT-Headers/TRT-RTX-<version>` include directory and,
/// on Windows, the suffix of the linked library names.
///
/// # Panics
///
/// Panics with an actionable message when no version feature is enabled.
fn tensorrt_rtx_version() -> &'static str {
    if env::var("CARGO_FEATURE_V_1_3").is_ok() {
        "1.3"
    } else {
        panic!("no TensorRT-RTX version feature enabled; enable the `v_1_3` feature")
    }
}

/// Compiler flags that select C++17 and silence the unused-parameter and
/// deprecated-declaration warnings triggered by the TensorRT headers.
///
/// `is_msvc` selects MSVC-style (`/...`) flags instead of GCC/Clang-style (`-...`) ones.
fn cxx17_flags(is_msvc: bool) -> &'static [&'static str] {
    if is_msvc {
        // C4100: unused parameter, C4996: deprecated declaration.
        &["/std:c++17", "/wd4100", "/wd4996"]
    } else {
        &["-std=c++17", "-Wno-unused-parameter", "-Wno-deprecated-declarations"]
    }
}

/// Writes stub FFI bindings to `<out_path>/bindings.rs` for mock builds.
///
/// The generated file declares the same `trtx_*` C entry points, opaque handle types,
/// and a minimal `nvinfer1` module that the real bindings provide, so dependent crates
/// compile without a TensorRT-RTX installation. The `extern "C"` functions are
/// expected to be provided by `mock.c`.
///
/// # Panics
///
/// Panics if the bindings file cannot be written; the message includes the path and
/// the I/O error.
fn generate_mock_bindings(out_path: &std::path::Path) {
    let mock_bindings = r#"
// Mock bindings for development without TensorRT-RTX

// Error codes
pub const TRTX_SUCCESS: i32 = 0;
pub const TRTX_ERROR_INVALID_ARGUMENT: i32 = 1;
pub const TRTX_ERROR_OUT_OF_MEMORY: i32 = 2;
pub const TRTX_ERROR_RUNTIME_ERROR: i32 = 3;
pub const TRTX_ERROR_CUDA_ERROR: i32 = 4;
pub const TRTX_ERROR_UNKNOWN: i32 = 99;

// Logger severity levels
#[repr(u32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum TrtxLoggerSeverity {
    TRTX_SEVERITY_INTERNAL_ERROR = 0,
    TRTX_SEVERITY_ERROR = 1,
    TRTX_SEVERITY_WARNING = 2,
    TRTX_SEVERITY_INFO = 3,
    TRTX_SEVERITY_VERBOSE = 4,
}

// Opaque types (just markers in mock mode)
#[repr(C)]
pub struct TrtxLogger {
    _unused: [u8; 0],
}

#[repr(C)]
pub struct TrtxBuilder {
    _unused: [u8; 0],
}

#[repr(C)]
pub struct TrtxBuilderConfig {
    _unused: [u8; 0],
}

#[repr(C)]
pub struct TrtxNetworkDefinition {
    _unused: [u8; 0],
}

#[repr(C)]
pub struct TrtxRuntime {
    _unused: [u8; 0],
}

#[repr(C)]
pub struct TrtxCudaEngine {
    _unused: [u8; 0],
}

#[repr(C)]
pub struct TrtxExecutionContext {
    _unused: [u8; 0],
}

#[repr(C)]
pub struct TrtxOnnxParser {
    _unused: [u8; 0],
}

// Logger callback type
pub type TrtxLoggerCallback = ::std::option::Option<
    unsafe extern "C" fn(
        user_data: *mut ::std::os::raw::c_void,
        severity: TrtxLoggerSeverity,
        msg: *const ::std::os::raw::c_char,
    ),
>;

// Stub implementations that return success
extern "C" {
    pub fn trtx_logger_create(
        callback: TrtxLoggerCallback,
        user_data: *mut ::std::os::raw::c_void,
        out_logger: *mut *mut TrtxLogger,
        error_msg: *mut ::std::os::raw::c_char,
        error_msg_len: usize,
    ) -> i32;

    pub fn trtx_logger_destroy(logger: *mut TrtxLogger);

    pub fn trtx_builder_create(
        logger: *mut TrtxLogger,
        out_builder: *mut *mut TrtxBuilder,
        error_msg: *mut ::std::os::raw::c_char,
        error_msg_len: usize,
    ) -> i32;

    pub fn trtx_builder_destroy(builder: *mut TrtxBuilder);

    pub fn trtx_builder_create_network(
        builder: *mut TrtxBuilder,
        flags: u32,
        out_network: *mut *mut TrtxNetworkDefinition,
        error_msg: *mut ::std::os::raw::c_char,
        error_msg_len: usize,
    ) -> i32;

    pub fn trtx_builder_create_builder_config(
        builder: *mut TrtxBuilder,
        out_config: *mut *mut TrtxBuilderConfig,
        error_msg: *mut ::std::os::raw::c_char,
        error_msg_len: usize,
    ) -> i32;

    pub fn trtx_builder_build_serialized_network(
        builder: *mut TrtxBuilder,
        network: *mut TrtxNetworkDefinition,
        config: *mut TrtxBuilderConfig,
        out_data: *mut *mut ::std::os::raw::c_void,
        out_size: *mut usize,
        error_msg: *mut ::std::os::raw::c_char,
        error_msg_len: usize,
    ) -> i32;

    pub fn trtx_builder_config_destroy(config: *mut TrtxBuilderConfig);

    pub fn trtx_builder_config_set_memory_pool_limit(
        config: *mut TrtxBuilderConfig,
        pool_type: i32,
        pool_size: usize,
        error_msg: *mut ::std::os::raw::c_char,
        error_msg_len: usize,
    ) -> i32;

    pub fn trtx_network_destroy(network: *mut TrtxNetworkDefinition);

    pub fn trtx_runtime_create(
        logger: *mut TrtxLogger,
        out_runtime: *mut *mut TrtxRuntime,
        error_msg: *mut ::std::os::raw::c_char,
        error_msg_len: usize,
    ) -> i32;

    pub fn trtx_runtime_destroy(runtime: *mut TrtxRuntime);

    pub fn trtx_runtime_deserialize_cuda_engine(
        runtime: *mut TrtxRuntime,
        data: *const ::std::os::raw::c_void,
        size: usize,
        out_engine: *mut *mut TrtxCudaEngine,
        error_msg: *mut ::std::os::raw::c_char,
        error_msg_len: usize,
    ) -> i32;

    pub fn trtx_cuda_engine_destroy(engine: *mut TrtxCudaEngine);

    pub fn trtx_cuda_engine_create_execution_context(
        engine: *mut TrtxCudaEngine,
        out_context: *mut *mut TrtxExecutionContext,
        error_msg: *mut ::std::os::raw::c_char,
        error_msg_len: usize,
    ) -> i32;

    pub fn trtx_cuda_engine_get_tensor_name(
        engine: *mut TrtxCudaEngine,
        index: i32,
        out_name: *mut *const ::std::os::raw::c_char,
        error_msg: *mut ::std::os::raw::c_char,
        error_msg_len: usize,
    ) -> i32;

    pub fn trtx_cuda_engine_get_nb_io_tensors(
        engine: *mut TrtxCudaEngine,
        out_count: *mut i32,
    ) -> i32;

    pub fn trtx_execution_context_destroy(context: *mut TrtxExecutionContext);

    pub fn trtx_execution_context_set_tensor_address(
        context: *mut TrtxExecutionContext,
        tensor_name: *const ::std::os::raw::c_char,
        data: *mut ::std::os::raw::c_void,
        error_msg: *mut ::std::os::raw::c_char,
        error_msg_len: usize,
    ) -> i32;

    pub fn trtx_execution_context_enqueue_v3(
        context: *mut TrtxExecutionContext,
        cuda_stream: *mut ::std::os::raw::c_void,
        error_msg: *mut ::std::os::raw::c_char,
        error_msg_len: usize,
    ) -> i32;

    pub fn trtx_free_buffer(buffer: *mut ::std::os::raw::c_void);

    // ONNX Parser functions
    pub fn trtx_onnx_parser_create(
        network: *mut TrtxNetworkDefinition,
        logger: *mut TrtxLogger,
        out_parser: *mut *mut TrtxOnnxParser,
        error_msg: *mut ::std::os::raw::c_char,
        error_msg_len: usize,
    ) -> i32;

    pub fn trtx_onnx_parser_destroy(parser: *mut TrtxOnnxParser);

    pub fn trtx_onnx_parser_parse(
        parser: *mut TrtxOnnxParser,
        model_data: *const ::std::os::raw::c_void,
        model_size: usize,
        error_msg: *mut ::std::os::raw::c_char,
        error_msg_len: usize,
    ) -> i32;

    // CUDA Memory Management functions
    pub fn trtx_cuda_malloc(
        ptr: *mut *mut ::std::os::raw::c_void,
        size: usize,
        error_msg: *mut ::std::os::raw::c_char,
        error_msg_len: usize,
    ) -> i32;

    pub fn trtx_cuda_free(
        ptr: *mut ::std::os::raw::c_void,
        error_msg: *mut ::std::os::raw::c_char,
        error_msg_len: usize,
    ) -> i32;

    pub fn trtx_cuda_memcpy_host_to_device(
        dst: *mut ::std::os::raw::c_void,
        src: *const ::std::os::raw::c_void,
        size: usize,
        error_msg: *mut ::std::os::raw::c_char,
        error_msg_len: usize,
    ) -> i32;

    pub fn trtx_cuda_memcpy_device_to_host(
        dst: *mut ::std::os::raw::c_void,
        src: *const ::std::os::raw::c_void,
        size: usize,
        error_msg: *mut ::std::os::raw::c_char,
        error_msg_len: usize,
    ) -> i32;

    pub fn trtx_cuda_synchronize(
        error_msg: *mut ::std::os::raw::c_char,
        error_msg_len: usize,
    ) -> i32;

    pub fn trtx_cuda_get_default_stream() -> *mut ::std::os::raw::c_void;
}

// Mock nvinfer1 module - stub types for trtx crate compatibility in mock mode.
// Enum discriminants follow TensorRT's NvInfer.h so `as i32` conversions agree with
// real builds. Variants that were not found in NvInfer.h are kept (callers may name
// them) and numbered after the real ones -- verify against TensorRT-Headers/TRT-RTX-*.
pub mod nvinfer1 {
    #[repr(i32)]
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    pub enum DataType {
        kFLOAT = 0,
        kHALF = 1,
        kINT8 = 2,
        kINT32 = 3,
        kBOOL = 4,
        kUINT8 = 5,
        kFP8 = 6,
        kBF16 = 7,
        kINT64 = 8,
        kINT4 = 9,
        kFP4 = 10,
        kE8M0 = 11,
    }

    #[repr(i32)]
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    pub enum ActivationType {
        kRELU = 0,
        kSIGMOID = 1,
        kTANH = 2,
        kLEAKY_RELU = 3,
        kELU = 4,
        kSELU = 5,
        kSOFTSIGN = 6,
        kSOFTPLUS = 7,
        kCLIP = 8,
        kHARD_SIGMOID = 9,
        kSCALED_TANH = 10,
        kTHRESHOLDED_RELU = 11,
        kGELU_ERF = 12,
        kGELU_TANH = 13,
    }

    #[repr(i32)]
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    pub enum PoolingType {
        kMAX = 0,
        kAVERAGE = 1,
        kMAX_AVERAGE_BLEND = 2,
    }

    #[repr(i32)]
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    pub enum ElementWiseOperation {
        kSUM = 0,
        kPROD = 1,
        kMAX = 2,
        kMIN = 3,
        kSUB = 4,
        kDIV = 5,
        kPOW = 6,
        kFLOOR_DIV = 7,
        kAND = 8,
        kOR = 9,
        kXOR = 10,
        kEQUAL = 11,
        kGREATER = 12,
        kLESS = 13,
    }

    #[repr(i32)]
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    pub enum MatrixOperation {
        kNONE = 0,
        kTRANSPOSE = 1,
        kVECTOR = 2,
    }

    #[repr(i32)]
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    pub enum UnaryOperation {
        kEXP = 0,
        kLOG = 1,
        kSQRT = 2,
        kRECIP = 3,
        kABS = 4,
        kNEG = 5,
        kSIN = 6,
        kCOS = 7,
        kTAN = 8,
        kSINH = 9,
        kCOSH = 10,
        kASIN = 11,
        kACOS = 12,
        kATAN = 13,
        kASINH = 14,
        kACOSH = 15,
        kATANH = 16,
        kCEIL = 17,
        kFLOOR = 18,
        kERF = 19,
        kNOT = 20,
        kSIGN = 21,
        kROUND = 22,
        kISINF = 23,
        kISNAN = 24,
    }

    #[repr(i32)]
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    pub enum ReduceOperation {
        kSUM = 0,
        kPROD = 1,
        kMAX = 2,
        kMIN = 3,
        kAVG = 4,
    }

    #[repr(i32)]
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    pub enum CumulativeOperation {
        kSUM = 0,
        kPROD = 1,
        kMIN = 2,
        kMAX = 3,
    }

    #[repr(i32)]
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    pub enum GatherMode {
        kDEFAULT = 0,
        kELEMENT = 1,
        kND = 2,
    }

    #[repr(i32)]
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    pub enum ScatterMode {
        kELEMENT = 0,
        kND = 1,
    }

    #[repr(i32)]
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    pub enum InterpolationMode {
        kNEAREST = 0,
        kLINEAR = 1,
        kCUBIC = 2,
    }

    #[repr(i32)]
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    pub enum ResizeCoordinateTransformation {
        kALIGN_CORNERS = 0,
        kASYMMETRIC = 1,
        kHALF_PIXEL = 2,
        kHALF_PIXEL_SYMMETRIC = 3,
    }

    #[repr(i32)]
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    pub enum ResizeRoundMode {
        kHALF_UP = 0,
        kHALF_DOWN = 1,
        kFLOOR = 2,
        kCEIL = 3,
        kROUND = 4,
    }

    #[repr(i32)]
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    pub enum ResizeSelector {
        kFORMULA = 0,
        kUPPER = 1,
        kSIZES = 2,
    }

    #[repr(i32)]
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    pub enum TopKOperation {
        kMAX = 0,
        kMIN = 1,
    }

    #[repr(i32)]
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    pub enum ScaleMode {
        kUNIFORM = 0,
        kCHANNEL = 1,
        kELEMENTWISE = 2,
    }

    #[repr(i32)]
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    pub enum ExecutionContextAllocationStrategy {
        kSTATIC = 0,
        // 1 is kON_PROFILE_CHANGE in NvInfer.h (not stubbed here).
        kUSER_MANAGED = 2,
    }

    // Layer interface types (opaque stubs for mock - only used in type positions)
    #[repr(C)]
    pub struct IShuffleLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct IActivationLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct IResizeLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct ITopKLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct IGatherLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct IScatterLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct ISelectLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct IMatrixMultiplyLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct ISoftMaxLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct IReduceLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct ICumulativeLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct IPoolingLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct IConvolutionLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct IDeconvolutionLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct IQuantizeLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct IDequantizeLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct IConstantLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct IConcatenationLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct IScaleLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct ISliceLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct IUnaryLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct IIdentityLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct IPaddingLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct ICastLayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct ITensor { _unused: [u8; 0] }
    #[repr(C)]
    pub struct ILayer { _unused: [u8; 0] }
    #[repr(C)]
    pub struct INetworkDefinition { _unused: [u8; 0] }
    #[repr(C)]
    pub struct ICudaEngine { _unused: [u8; 0] }
    #[repr(C)]
    pub struct IExecutionContext { _unused: [u8; 0] }

    #[repr(C)]
    #[derive(Clone, Copy)]
    pub struct Weights {
        pub type_: DataType,
        pub values: *const ::std::ffi::c_void,
        pub count: i64,
    }

    impl Weights {
        pub fn new_float(values_ptr: *const ::std::ffi::c_void, count_val: i64) -> Self {
            Self { type_: DataType::kFLOAT, values: values_ptr, count: count_val }
        }
        pub fn new_with_type(
            data_type: DataType,
            values_ptr: *const ::std::ffi::c_void,
            count_val: i64,
        ) -> Self {
            Self { type_: data_type, values: values_ptr, count: count_val }
        }
    }
}

// Dims64/Dims - mock version
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct Dims64 {
    pub nbDims: i32,
    pub d: [i64; 8],
}

pub type Dims = Dims64;

impl Dims64 {
    pub fn from_slice(dims: &[i64]) -> Self {
        let mut d = [0i64; 8];
        let nb_dims = dims.len().min(8) as i32;
        d[..nb_dims as usize].copy_from_slice(&dims[..nb_dims as usize]);
        Self { nbDims: nb_dims, d }
    }
    pub fn new_2d(d0: i64, d1: i64) -> Self {
        Self { nbDims: 2, d: [d0, d1, 0, 0, 0, 0, 0, 0] }
    }
    pub fn new_3d(d0: i64, d1: i64, d2: i64) -> Self {
        Self { nbDims: 3, d: [d0, d1, d2, 0, 0, 0, 0, 0] }
    }
    pub fn new_4d(d0: i64, d1: i64, d2: i64, d3: i64) -> Self {
        Self { nbDims: 4, d: [d0, d1, d2, d3, 0, 0, 0, 0] }
    }
}

// ResizeMode is InterpolationMode in TensorRT
pub use nvinfer1::InterpolationMode as ResizeMode;
"#;

    let dest = out_path.join("bindings.rs");
    std::fs::write(&dest, mock_bindings).unwrap_or_else(|err| {
        panic!("Couldn't write mock bindings to {}: {err}", dest.display())
    });
}