use std::env;
use std::fs::File;
use std::io::Write;
use std::path::{Path, PathBuf};
use bindgen::EnumVariation;
use regex::Regex;
/// Writes `contents` to `path` unless the file already holds exactly those bytes.
///
/// When the existing contents match, the file is not touched at all, so its
/// modification time is preserved. Any error while reading the existing file
/// (for example, the file does not exist yet) is treated as "changed", and the
/// file is (re)written.
///
/// # Panics
///
/// Panics, naming the offending path, if the file cannot be created or written.
fn write_if_changed(path: &Path, contents: &[u8]) {
    if let Ok(existing) = std::fs::read(path) {
        if existing == contents {
            return;
        }
    }
    let mut file = File::create(path)
        .unwrap_or_else(|e| panic!("Failed to create {}: {e}", path.display()));
    file.write_all(contents)
        .unwrap_or_else(|e| panic!("Failed to write {}: {e}", path.display()));
}
/// Copies every regular file in `header_dir` into `<out_dir>/trtx_transformed_headers`,
/// applying a fixed set of textual rewrites to each one:
///
/// * Line endings are normalized to `\n`.
/// * Doxygen markup is turned into Markdown/HTML:
///   * `\warning …` (to end of line) becomes `<div class="warning"> … </div>`.
///   * `\see X` becomes ``See [`X`]``.
///   * `\param X` becomes ``- `X` ``.
///   * `\returns` becomes ` - Returns `.
///   * `//!` becomes `///`.
///   * Any other `\command` is stripped.
/// * Bare `https://` URLs are wrapped in `<…>`.
/// * A few C++ declarations are rewritten:
///   * `std::size_t` becomes `size_t`.
///   * The `v_1_0` and `impl` namespaces become `inline`.
///   * The `getErrorCode`, `reportError` and `log` signatures use `int32_t` / `char`.
/// * `#include <cstdint>` is prepended, and forward declarations of a few
///   `nvinfer1` types are appended.
/// * Runs of extra spaces after `/// ` are collapsed.
///
/// Outputs are written with `write_if_changed`, so unchanged files keep their
/// timestamps. Returns the path of the transformed-header directory.
///
/// # Panics
///
/// Panics if a directory cannot be created or listed, or if a header cannot be
/// read as UTF-8 text.
fn prepare_transformed_headers(header_dir: &Path, out_dir: &Path) -> PathBuf {
    // All patterns are compiled once, outside the per-file loop.
    // Matches any Doxygen command; used to strip those left after the specific rewrites.
    let doxy_regex = Regex::new(r"\\(\w+)").unwrap();
    // `(?m)` makes `$` match at every line end. Without it, `$` matches only at the
    // end of the whole file, so `\warning` lines were effectively never wrapped.
    let warn_regex = Regex::new(r"(?m)\\warning (.*)$").unwrap();
    let see_regex = Regex::new(r"\\see ([`\w:()]+)").unwrap();
    let param_regex = Regex::new(r"\\param ([\w:()]+)").unwrap();
    // `/// ` followed by one or more extra spaces; collapsed to just `/// `.
    let comment_indent_regex = Regex::new(r"(///\ )(\ +)").unwrap();
    let http_link_regex = Regex::new(r"https:\/\/[\w.\/\-#]+").unwrap();

    let transformed_dir = out_dir.join("trtx_transformed_headers");
    std::fs::create_dir_all(&transformed_dir).expect("Failed to create transformed headers dir");

    let entries = std::fs::read_dir(header_dir).unwrap_or_else(|e| {
        panic!("Failed to read header dir {}: {e}", header_dir.display())
    });
    for entry in entries {
        let path = entry
            .unwrap_or_else(|e| panic!("Failed to read entry of {}: {e}", header_dir.display()))
            .path();
        // Subdirectories and other non-regular entries are skipped.
        if !path.is_file() {
            continue;
        }
        let source = std::fs::read_to_string(&path)
            .unwrap_or_else(|e| panic!("Failed to read header {}: {e}", path.display()));

        // Normalize line endings first so `(.*)$` in `warn_regex` cannot capture a
        // trailing `\r` into the generated `<div>`.
        let replaced = source.replace("\r\n", "\n");
        let replaced = warn_regex.replace_all(&replaced, "<div class=\"warning\"> $1 </div>");
        let replaced = see_regex.replace_all(&replaced, "See [`$1`]");
        let replaced = param_regex.replace_all(&replaced, "- `$1`");
        let replaced = http_link_regex.replace_all(&replaced, "<$0>");
        let replaced = replaced
            .replace("std::size_t", "size_t")
            .replace("namespace v_1_0", "inline namespace v_1_0")
            .replace("namespace impl", "inline namespace impl")
            .replace("ErrorCode getErrorCode", "int32_t getErrorCode")
            .replace(
                "bool reportError(ErrorCode val",
                "bool reportError(int32_t val",
            )
            .replace(
                "void log(Severity severity, AsciiChar const* msg)",
                "void log(int32_t severity, char const* msg)",
            )
            .replace("//!", "///")
            .replace(r"\returns", " - Returns ");
        let replaced = doxy_regex.replace_all(&replaced, "");
        let replaced = "#include <cstdint>\n".to_string()
            + &replaced
            + r#"
namespace nvinfer1 {
class IMoELayer;
class IDistCollectiveLayer;
enum class ComputeCapability : int32_t;
}"#;
        let replaced = comment_indent_regex.replace_all(&replaced, "$1");

        let out_file = transformed_dir.join(path.file_name().expect("file path has a file name"));
        write_if_changed(&out_file, replaced.as_bytes());
    }
    transformed_dir
}
/// Runs bindgen over `<include_dir>/NvInfer.h` and writes the selected enum
/// definitions to `<out_path>/enums.rs`. The file is only rewritten when the
/// generated text changes.
///
/// `crate_root` locates the `TensorRT-Headers` directory, which is added to
/// the include path (referred to here as the CUDA shim).
///
/// # Panics
///
/// Panics if bindgen fails to parse the header or generate bindings.
fn generate_enum_bindings(crate_root: &str, out_path: &Path, include_dir: &Path) {
    // Type-name patterns whose definitions are emitted into `enums.rs`.
    const ALLOWED_TYPES: &[&str] = &[
        ".*Type",
        ".*Mode",
        ".*Operation",
        ".*Strategy",
        ".*Severity",
        ".*Format",
        ".*Verbosity",
        ".*Feature",
        ".*Platform",
        ".*Level",
        ".*Capability",
        ".*ErrorCode",
        ".*Flag",
        ".*Selector",
        ".*Transformation",
        ".*Location",
        ".*Role",
        ".*Limit",
        ".*AttentionNormalizationOp",
        ".*SeekPosition",
        ".*LoopOutput",
    ];
    // Types never emitted, even when an allowlist pattern above matches them
    // (e.g. `.*Capability` would otherwise match `IPluginCapability`).
    const BLOCKED_TYPES: &[&str] = &[
        ".*IPluginCapability",
        ".*IVersionedInterface",
        ".*InterfaceInfo",
        ".*InterfaceKind",
    ];
    // Textual fix-ups applied to bindgen's output. Order matters: each entry
    // operates on the result of the previous ones.
    const OUTPUT_REWRITES: &[(&str, &str)] = &[
        ("extern \"C\"", "extern \"system\""),
        ("nvinfer1_", ""),
        ("ILogger_", ""),
        ("impl__EnumMaxImpl", "impl_EnumMaxImpl"),
    ];

    let header = include_dir.join("NvInfer.h").to_string_lossy().into_owned();
    let cuda_shim = format!("{crate_root}/TensorRT-Headers");
    println!("cargo:rerun-if-changed={header}");

    let builder = bindgen::Builder::default()
        .header(header)
        .default_enum_style(EnumVariation::Rust {
            non_exhaustive: false,
        })
        .derive_default(true)
        .derive_eq(true)
        .derive_hash(true)
        .derive_ord(true)
        .blocklist_type("cu.*")
        .clang_arg("-x")
        .clang_arg("c++")
        // Parse with the same C++ dialect the cc and autocxx builds use, rather than
        // whatever default the installed libclang happens to pick.
        .clang_arg("-std=c++17")
        .clang_arg(format!("-I{}", include_dir.to_string_lossy()))
        .clang_arg(format!("-I{cuda_shim}"));
    let builder = ALLOWED_TYPES
        .iter()
        .fold(builder, |b, &pattern| b.allowlist_type(pattern));
    let builder = BLOCKED_TYPES
        .iter()
        .fold(builder, |b, &pattern| b.blocklist_type(pattern));

    let bindings = builder
        .generate()
        .expect("Failed to generate enum bindings from NvInfer.h");
    let output = OUTPUT_REWRITES
        .iter()
        .fold(bindings.to_string(), |text, &(from, to)| text.replace(from, to));

    let enums_src = format!("/* automatically generated by bindgen */\n\n{output}");
    write_if_changed(&out_path.join("enums.rs"), enums_src.as_bytes());
}
/// Build-script entry point.
///
/// Steps, in order:
/// 1. Pick the TensorRT version and flavor from Cargo features. Resolve the
///    header and library directories from `TENSORRT_INCLUDE_DIR`,
///    `TENSORRT_LIB_DIR` and `TENSORRT_SDK_DIR`, falling back to the headers
///    under `TensorRT-Headers/` in the crate root.
/// 2. Generate `enums.rs` with bindgen and write transformed copies of the
///    headers into `OUT_DIR`.
/// 3. Emit link directives for the TensorRT (and optionally ONNX parser)
///    libraries when the corresponding link features are enabled.
/// 4. Compile `logger_bridge.cpp` and the autocxx bindings for `src/lib.rs`.
///
/// Platform decisions (library-name suffix, MSVC vs. GCC/Clang flags) are made
/// for the compilation *target* via Cargo's `CARGO_CFG_*` variables. Inside a
/// build script, `cfg!(unix)` / `cfg!(target_os = ...)` describe the *host*,
/// which is wrong when cross-compiling.
fn main() {
    let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
    let crate_root = env::var("CARGO_MANIFEST_DIR").unwrap();
    let link_trt = env::var("CARGO_FEATURE_LINK_TENSORRT_RTX").is_ok();
    let link_trt_onnxparser = env::var("CARGO_FEATURE_LINK_TENSORRT_ONNXPARSER").is_ok();
    println!("cargo:rerun-if-env-changed=CARGO_FEATURE_LINK_TENSORRT_RTX");
    println!("cargo:rerun-if-env-changed=CARGO_FEATURE_LINK_TENSORRT_ONNXPARSER");

    let trt_version = if cfg!(feature = "v_1_4") {
        "1.4"
    } else if cfg!(feature = "v_1_3") {
        "1.3"
    } else {
        panic!("No version feature enabled! Need to at least enable v_1_3 or v_1_4");
    };
    // Library names carry a version suffix only for the non-enterprise flavor on
    // non-Unix targets.
    let trt_version_suffix = if target_is_unix() || cfg!(feature = "enterprise") {
        ""
    } else if cfg!(feature = "v_1_4") {
        "_1_4"
    } else {
        "_1_3"
    };

    let header_overwrite = env::var("TENSORRT_INCLUDE_DIR").ok();
    println!("cargo:rerun-if-env-changed=TENSORRT_INCLUDE_DIR");
    let lib_overwrite = env::var("TENSORRT_LIB_DIR").ok();
    println!("cargo:rerun-if-env-changed=TENSORRT_LIB_DIR");
    let sdk_overwrite = env::var("TENSORRT_SDK_DIR").ok();
    println!("cargo:rerun-if-env-changed=TENSORRT_SDK_DIR");
    if let Some(sdk) = &sdk_overwrite {
        println!("cargo:warning=Using TENSORRT_SDK_DIR={sdk}");
    }
    if let Some(lib) = &lib_overwrite {
        println!("cargo:warning=Using TENSORRT_LIB_DIR={lib}");
    }
    if let Some(header) = &header_overwrite {
        println!("cargo:warning=Using TENSORRT_INCLUDE_DIR={header}");
    }

    // Precedence: TENSORRT_INCLUDE_DIR, then <TENSORRT_SDK_DIR>/include, then the
    // headers bundled in the crate.
    let include_dir = header_overwrite
        .or_else(|| sdk_overwrite.as_ref().map(|p| format!("{p}/include")))
        .map(PathBuf::from)
        .unwrap_or_else(|| {
            PathBuf::from(if cfg!(feature = "enterprise") {
                format!("{crate_root}/TensorRT-Headers/TRT-Enterprise-10.16")
            } else {
                format!("{crate_root}/TensorRT-Headers/TRT-RTX-{trt_version}")
            })
        });
    println!("cargo:rerun-if-changed={}", include_dir.display());
    let cuda_shim_include_dir = format!("{crate_root}/TensorRT-Headers");
    // Precedence: TENSORRT_LIB_DIR, then <TENSORRT_SDK_DIR>/lib; otherwise no
    // extra link search path is emitted.
    let lib_dir = lib_overwrite.or_else(|| sdk_overwrite.map(|p| format!("{p}/lib")));

    generate_enum_bindings(&crate_root, &out_path, &include_dir);

    let is_mock = env::var("CARGO_FEATURE_MOCK").is_ok();
    println!("cargo:rerun-if-changed=src/lib.rs");
    println!("cargo:rerun-if-changed=logger_bridge.hpp");
    println!("cargo:rerun-if-changed=logger_bridge.cpp");
    println!("cargo:rerun-if-env-changed=CUDA_ROOT");
    println!("cargo:rerun-if-env-changed=LIBCLANG_PATH");

    let transformed_include_dir = prepare_transformed_headers(&include_dir, &out_path);
    let transformed_include_dir_str = transformed_include_dir.to_string_lossy();

    if let Some(lib_dir) = lib_dir {
        println!("cargo:rustc-link-search=native={lib_dir}");
    }
    if link_trt {
        if cfg!(feature = "enterprise") {
            println!("cargo:rustc-link-lib=dylib=nvinfer");
        } else {
            println!("cargo:rustc-link-lib=dylib=tensorrt_rtx{trt_version_suffix}");
        }
    }
    if link_trt_onnxparser {
        if cfg!(feature = "enterprise") {
            println!("cargo:rustc-link-lib=dylib=nvonnxparser");
        } else {
            println!("cargo:rustc-link-lib=dylib=tensorrt_onnxparser_rtx{trt_version_suffix}");
        }
    }

    // Flags for the target's C++ compiler, shared by both cc builds below.
    // MSVC warning C4100 = unreferenced formal parameter, C4996 = deprecated
    // declaration (the counterparts of the GCC/Clang `-Wno-*` flags).
    let cxx_flags: &[&str] = if target_is_windows_msvc() {
        &["/std:c++17", "/wd4100", "/wd4996"]
    } else {
        &["-std=c++17", "-Wno-unused-parameter", "-Wno-deprecated-declarations"]
    };

    let mut cc_build = cc::Build::new();
    cc_build
        .cpp(true)
        .file("logger_bridge.cpp")
        .include(&transformed_include_dir)
        .include(&cuda_shim_include_dir);
    if is_mock {
        cc_build.define("TRTX_MOCK_MODE", "1");
    }
    if link_trt {
        cc_build.define("TRTX_LINK_TENSORRT_RTX", "1");
    }
    if link_trt_onnxparser {
        cc_build.define("TRTX_LINK_TENSORRT_ONNXPARSER", "1");
    }
    for &flag in cxx_flags {
        cc_build.flag(flag);
    }
    cc_build.compile("trtx_logger_bridge");

    // These go to libclang for header parsing, so they are always Clang-style,
    // regardless of the target toolchain.
    let clang_args = [
        "-std=c++17",
        "-Wno-unused-parameter",
        "-Wno-deprecated-declarations",
    ];
    let mut autocxx_build = autocxx_build::Builder::new(
        "src/lib.rs",
        [
            transformed_include_dir_str.as_ref(),
            cuda_shim_include_dir.as_str(),
        ],
    )
    .extra_clang_args(&clang_args)
    .build()
    .expect("Failed to build autocxx bindings");
    for &flag in cxx_flags {
        autocxx_build.flag(flag);
    }
    autocxx_build.compile("trtx_autocxx");
}

/// Returns `true` when the compilation *target* is Unix-like.
///
/// Cargo sets `CARGO_CFG_UNIX` for such targets. This is the build-script-safe
/// replacement for `cfg!(unix)`, which reports on the host.
fn target_is_unix() -> bool {
    env::var_os("CARGO_CFG_UNIX").is_some()
}

/// Returns `true` when the compilation *target* is Windows with the MSVC ABI.
///
/// This is the target-aware equivalent of
/// `cfg!(target_os = "windows") && cfg!(target_env = "msvc")`.
fn target_is_windows_msvc() -> bool {
    matches!(env::var("CARGO_CFG_TARGET_OS").as_deref(), Ok("windows"))
        && matches!(env::var("CARGO_CFG_TARGET_ENV").as_deref(), Ok("msvc"))
}