trtx-sys 0.5.0

Raw FFI bindings to NVIDIA TensorRT-RTX (EXPERIMENTAL - NOT FOR PRODUCTION)
Documentation
use std::env;
use std::fs::File;
use std::io::Write;
use std::path::{Path, PathBuf};

use bindgen::EnumVariation;
use regex::Regex;

/// Writes `contents` only if the file is missing or differs. Avoids bumping mtimes when the
/// build script re-runs unchanged; autocxx emits `cargo:rerun-if-changed` for headers under
/// `OUT_DIR`, and rewriting them every run made Cargo think those outputs changed after the
/// previous fingerprint (perpetual rebuild loop).
/// Parsed from an RTX SDK `NvInferVersion.h` (`TRT_*_RTX` defines).
#[derive(Clone, Copy, Debug)]
struct RtxHeaderVersion {
    major: u32,
    minor: u32,
    patch: u32,
}

/// Reads `NvInferVersion.h` and matches RTX version lines, e.g.
/// `#define TRT_MAJOR_RTX 1` / `TRT_MINOR_RTX` / `TRT_PATCH_RTX`.
fn parse_trt_rtx_version_from_nvinfer_version_h(version_h: &Path) -> Option<RtxHeaderVersion> {
    let text = std::fs::read_to_string(version_h).ok()?;
    let major_re = Regex::new(r"(?m)^\s*#define\s+TRT_MAJOR_RTX\s+(\d+)\s*$").ok()?;
    let minor_re = Regex::new(r"(?m)^\s*#define\s+TRT_MINOR_RTX\s+(\d+)\s*$").ok()?;
    let patch_re = Regex::new(r"(?m)^\s*#define\s+TRT_PATCH_RTX\s+(\d+)\s*$").ok()?;
    let major: u32 = major_re.captures(&text)?.get(1)?.as_str().parse().ok()?;
    let minor: u32 = minor_re.captures(&text)?.get(1)?.as_str().parse().ok()?;
    let patch: u32 = patch_re.captures(&text)?.get(1)?.as_str().parse().ok()?;
    Some(RtxHeaderVersion {
        major,
        minor,
        patch,
    })
}

fn trt_version_suffix_from_feature_flags() -> &'static str {
    if cfg!(feature = "v_1_4") {
        "_1_4"
    } else {
        "_1_3"
    }
}

fn write_if_changed(path: &Path, contents: &[u8]) {
    let unchanged = std::fs::read(path)
        .map(|existing| existing == contents)
        .unwrap_or(false);
    if !unchanged {
        let mut file = File::create(path).unwrap();
        file.write_all(contents).unwrap();
    }
}

/// Creates a folder in `out_dir`, copies headers from `trt_dir` with transformations applied,
/// and returns the path to the new folder. Original headers are never modified.
fn prepare_transformed_headers(header_dir: &Path, out_dir: &Path) -> PathBuf {
    let doxy_regex = Regex::new(r"\\(\w+)").unwrap();
    let warn_regex = Regex::new(r"\\warning (.*)$").unwrap();
    let see_regex = Regex::new(r"\\see ([`\w:()]+)").unwrap();
    let param_regex = Regex::new(r"\\param ([\w:()]+)").unwrap();
    let comment_indent_regex = Regex::new(r"(///\ )(\ +)").unwrap();
    let http_link_regex = Regex::new(r"https:\/\/[\w.\/\-#]+").unwrap();

    let transformed_dir = out_dir.join("trtx_transformed_headers");
    std::fs::create_dir_all(&transformed_dir).expect("Failed to create transformed headers dir");

    for entry in std::fs::read_dir(header_dir).unwrap() {
        let entry = entry.unwrap();
        let path = entry.path();
        if path.is_file() {
            let replaced = std::fs::read_to_string(&path).unwrap();
            let replaced = warn_regex.replace_all(&replaced, "<div class=\"warning\"> $1 </div>");
            let replaced = see_regex.replace_all(&replaced, "See [`$1`]");
            let replaced = param_regex.replace_all(&replaced, "- `$1`");
            let replaced = http_link_regex.replace_all(&replaced, "<$0>");
            let replaced = replaced
                .replace("std::size_t", "size_t")
                // workaround autocxx limitation where there can't be the same type in different
                // namespaces
                .replace("namespace v_1_0", "inline namespace v_1_0")
                .replace("namespace impl", "inline namespace impl")
                .replace("ErrorCode getErrorCode", "int32_t getErrorCode")
                .replace(
                    "bool reportError(ErrorCode val",
                    "bool reportError(int32_t val",
                )
                .replace(
                    "void log(Severity severity, AsciiChar const* msg)",
                    "void log(int32_t severity, char const* msg)",
                )
                .replace("//!", "///")
                .replace(r"\returns", " - Returns ");
            let replaced = doxy_regex.replace_all(&replaced, "");
            // We need to declare everything added after v_1_3 because we can't do features for
            // autocxx
            let replaced = "#include <cstdint>\n".to_string()
                + &replaced
                + r#"
namespace nvinfer1 {
    class IMoELayer;
    class IDistCollectiveLayer;
    enum class ComputeCapability : int32_t;
}"#;

            // trimming of indentation is necessary, so that rustdoc doesn't interpret
            // indented sections as Rust code.
            let replaced = comment_indent_regex
                .replace_all(&replaced, "$1")
                .replace("\r\n", "\n");

            let out_file = transformed_dir.join(path.file_name().unwrap());
            write_if_changed(&out_file, replaced.as_bytes());
        }
    }

    transformed_dir
}

/// Generate enum bindings from NvInfer.h using bindgen (replaces generate_debug_enum.sh).
fn generate_enum_bindings(crate_root: &str, out_path: &Path, include_dir: &Path) {
    let header = include_dir.join("NvInfer.h").to_string_lossy().to_string();
    let cuda_shim = format!("{crate_root}/TensorRT-Headers");

    println!("cargo:rerun-if-changed={header}");

    let mut builder = bindgen::Builder::default()
        .header(header)
        .default_enum_style(EnumVariation::Rust {
            non_exhaustive: false,
        })
        .derive_default(true)
        .derive_eq(true)
        .derive_hash(true)
        .derive_ord(true)
        .blocklist_type("cu.*")
        .clang_arg("-x")
        .clang_arg("c++")
        .clang_arg(format!("-I{}", include_dir.to_string_lossy()))
        .clang_arg(format!("-I{cuda_shim}"));

    // Allowlist types matching the script's regex: (.*Type|.*Mode|.*Operation|...)
    for pattern in [
        ".*Type",
        ".*Mode",
        ".*Operation",
        ".*Strategy",
        ".*Severity",
        ".*Format",
        ".*Verbosity",
        ".*Feature",
        ".*Platform",
        ".*Level",
        ".*Capability",
        ".*ErrorCode",
        ".*Flag",
        ".*Selector",
        ".*Transformation",
        ".*Location",
        ".*Role",
        ".*Limit",
        ".*AttentionNormalizationOp",
        ".*SeekPosition",
        ".*LoopOutput",
    ] {
        builder = builder.allowlist_type(pattern);
    }
    builder = builder.blocklist_type(".*IPluginCapability");
    builder = builder.blocklist_type(".*IVersionedInterface");
    builder = builder.blocklist_type(".*InterfaceInfo");
    builder = builder.blocklist_type(".*InterfaceKind");

    let bindings = builder
        .generate()
        .expect("Failed to generate enum bindings from NvInfer.h");

    let mut output = bindings.to_string();
    output = output.replace("extern \"C\"", "extern \"system\"");
    output = output.replace("nvinfer1_", "");
    output = output.replace("ILogger_", "");
    output = output.replace("impl__EnumMaxImpl", "impl_EnumMaxImpl");

    let out_file = out_path.join("enums.rs");
    let enums_src = format!("/* automatically generated by bindgen */\n\n{output}");
    write_if_changed(&out_file, enums_src.as_bytes());
}

fn main() {
    let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
    let crate_root = env::var("CARGO_MANIFEST_DIR").unwrap();
    let link_trt = env::var("CARGO_FEATURE_LINK_TENSORRT_RTX").is_ok();
    let link_trt_onnxparser = env::var("CARGO_FEATURE_LINK_TENSORRT_ONNXPARSER").is_ok();

    println!("cargo:rerun-if-env-changed=CARGO_FEATURE_LINK_TENSORRT_RTX");
    println!("cargo:rerun-if-env-changed=CARGO_FEATURE_LINK_TENSORRT_ONNXPARSER");

    let trt_version = if cfg!(feature = "v_1_4") {
        "1.4"
    } else if cfg!(feature = "v_1_3") {
        "1.3"
    } else {
        panic!("No version feature enabled! Need to at least enable v_1_3 or v_1_4");
    };

    let header_overwrite = env::var("TENSORRT_INCLUDE_DIR").ok();
    println!("cargo:rerun-if-env-changed=TENSORRT_INCLUDE_DIR");
    let lib_overwrite = env::var("TENSORRT_LIB_DIR").ok();
    println!("cargo:rerun-if-env-changed=TENSORRT_LIB_DIR");
    let sdk_overwrite = env::var("TENSORRT_SDK_DIR").ok();
    println!("cargo:rerun-if-env-changed=TENSORRT_SDK_DIR");
    if let Some(sdk_overwrite) = sdk_overwrite.as_ref() {
        println!("cargo:warning=Using TENSORRT_SDK_DIR={sdk_overwrite}");
    }
    if let Some(lib_overwrite) = lib_overwrite.as_ref() {
        println!("cargo:warning=Using TENSORRT_LIB_DIR={lib_overwrite}");
    }
    if let Some(header_overwrite) = header_overwrite.as_ref() {
        println!("cargo:warning=Using TENSORRT_INCLUDE_DIR={header_overwrite}");
    }

    let include_dir = header_overwrite
        .clone()
        .or_else(|| sdk_overwrite.clone().map(|p| format!("{p}/include")))
        .map(PathBuf::from)
        .unwrap_or_else(|| {
            PathBuf::from(if cfg!(feature = "enterprise") {
                format!("{crate_root}/TensorRT-Headers/TRT-Enterprise-10.16")
            } else {
                format!("{crate_root}/TensorRT-Headers/TRT-RTX-{trt_version}")
            })
        });
    println!("cargo:rerun-if-changed={}", include_dir.display());
    let nvinfer_version_h = include_dir.join("NvInferVersion.h");
    println!("cargo:rerun-if-changed={}", nvinfer_version_h.display());
    let cuda_shim_include_dir = format!("{crate_root}/TensorRT-Headers");
    let lib_dir = lib_overwrite.or_else(|| sdk_overwrite.clone().map(|p| format!("{p}/lib")));

    // Windows link names use a `_major_minor` suffix for RTX (e.g. `tensorrt_rtx_1_4`). When an
    // SDK root is set, derive the suffix from `include/NvInferVersion.h` instead of Cargo features.
    let trt_version_suffix: String = if cfg!(unix) || cfg!(feature = "enterprise") {
        String::new()
    } else if sdk_overwrite.is_some() {
        match parse_trt_rtx_version_from_nvinfer_version_h(&nvinfer_version_h) {
            Some(v) => {
                println!(
                    "cargo:warning=Linking tensorrt_rtx_{}_{} / onnxparser (from NvInferVersion.h: {}.{}.{})",
                    v.major, v.minor, v.major, v.minor, v.patch
                );
                format!("_{}_{}", v.major, v.minor)
            }
            None => {
                println!(
                    "cargo:warning=SDK dir set but could not parse TRT_*_RTX in {}; using Cargo feature suffix",
                    nvinfer_version_h.display()
                );
                trt_version_suffix_from_feature_flags().to_string()
            }
        }
    } else {
        trt_version_suffix_from_feature_flags().to_string()
    };

    // Generate enum bindings from NvInfer.h (used in both mock and real builds)
    generate_enum_bindings(&crate_root, &out_path, &include_dir);

    // Check if we're in mock mode
    let is_mock = env::var("CARGO_FEATURE_MOCK").is_ok();

    println!("cargo:rerun-if-changed=src/lib.rs");
    println!("cargo:rerun-if-changed=logger_bridge.hpp");
    println!("cargo:rerun-if-changed=logger_bridge.cpp");
    println!("cargo:rerun-if-env-changed=CUDA_ROOT");
    println!("cargo:rerun-if-env-changed=LIBCLANG_PATH");

    let transformed_include_dir = prepare_transformed_headers(&include_dir, &out_path);
    let transformed_include_dir_str = transformed_include_dir.to_string_lossy();

    if let Some(lib_dir) = lib_dir {
        println!("cargo:rustc-link-search=native={}", lib_dir);
    }
    if link_trt {
        if cfg!(feature = "enterprise") {
            println!("cargo:rustc-link-lib=dylib=nvinfer");
        } else {
            println!("cargo:rustc-link-lib=dylib=tensorrt_rtx{trt_version_suffix}");
        }
    }
    if link_trt_onnxparser {
        if cfg!(feature = "enterprise") {
            println!("cargo:rustc-link-lib=dylib=nvonnxparser");
        } else {
            println!("cargo:rustc-link-lib=dylib=tensorrt_onnxparser_rtx{trt_version_suffix}");
        }
    }

    // Build logger bridge C++ wrapper
    let mut cc_build = cc::Build::new();
    cc_build
        .cpp(true)
        .file("logger_bridge.cpp")
        .include(&transformed_include_dir)
        .include(&cuda_shim_include_dir);

    if is_mock {
        cc_build.define("TRTX_MOCK_MODE", "1");
    }
    if link_trt {
        cc_build.define("TRTX_LINK_TENSORRT_RTX", "1");
    }
    if link_trt_onnxparser {
        cc_build.define("TRTX_LINK_TENSORRT_ONNXPARSER", "1");
    }

    // Use correct C++17 flag based on compiler
    if cfg!(target_os = "windows") && cfg!(target_env = "msvc") {
        cc_build.flag("/std:c++17");
        cc_build.flag("/wd4100"); // Disable unused parameter warning on MSVC
        cc_build.flag("/wd4996"); // Disable deprecated declaration warning on MSVC
    } else {
        cc_build.flag("-std=c++17");
        cc_build.flag("-Wno-unused-parameter"); // Suppress unused parameter warnings
        cc_build.flag("-Wno-deprecated-declarations"); // Suppress deprecated warnings
    }

    cc_build.compile("trtx_logger_bridge");

    // Build autocxx bindings for main TensorRT API
    // Prepare CUDA include paths for autocxx clang parser
    let clang_args = vec![
        "-std=c++17",
        "-Wno-unused-parameter", // Suppress unused parameter warnings from TensorRT headers
        "-Wno-deprecated-declarations", // Suppress deprecated warnings from TensorRT headers
    ];

    let mut autocxx_build = autocxx_build::Builder::new(
        "src/lib.rs",
        [
            transformed_include_dir_str.as_ref(),
            cuda_shim_include_dir.as_str(),
        ],
    )
    .extra_clang_args(&clang_args)
    .build()
    .expect("Failed to build autocxx bindings");

    // Set C++17 standard and suppress warnings
    if cfg!(target_os = "windows") && cfg!(target_env = "msvc") {
        autocxx_build.flag("/std:c++17");
        autocxx_build.flag("/wd4100"); // Disable unused parameter warning
        autocxx_build.flag("/wd4996"); // Disable deprecated declaration warning
    } else {
        autocxx_build.flag("-std=c++17");
        autocxx_build.flag("-Wno-unused-parameter"); // Suppress unused parameter warnings
        autocxx_build.flag("-Wno-deprecated-declarations"); // Suppress deprecated warnings
    }

    autocxx_build.compile("trtx_autocxx");
}