trtx 0.7.0+rtx1.5

Safe Rust bindings to NVIDIA TensorRT-RTX (EXPERIMENTAL - NOT FOR PRODUCTION)
Documentation
//! Safe Rust bindings to NVIDIA TensorRT-RTX
//!
//! ⚠️ **EXPERIMENTAL - NOT FOR PRODUCTION USE**
//!
//! This crate is in early experimental development. The API is unstable and will change.
//! This is NOT production-ready software. Use at your own risk.
//!
//! This crate provides safe, ergonomic Rust bindings to the TensorRT-RTX library
//! for high-performance deep learning inference on NVIDIA GPUs.
//!
//! # Overview
//!
//! TensorRT-RTX enables efficient inference by:
//! - Optimizing neural network graphs
//! - Fusing layers and operations
//! - Selecting optimal kernels for your hardware
//! - Supporting dynamic shapes and batching
//!
//! # Workflow
//!
//! Using TensorRT-RTX typically follows two phases:
//!
//! ## Build Phase (Ahead-of-Time)
//!
//! 1. Create a [`Logger`] to capture TensorRT messages
//! 2. Create a [`Builder`] to construct an optimized engine
//! 3. Define your network using [`NetworkDefinition`]
//! 4. Configure optimization with [`BuilderConfig`]
//! 5. Build and serialize the engine to disk
//!
//! ## Inference Phase (Runtime)
//!
//! 1. Create a [`Runtime`] with a logger
//! 2. Deserialize the engine using [`Runtime::deserialize_cuda_engine`]
//! 3. Create an [`ExecutionContext`] from the engine
//! 4. Bind input/output tensors
//! 5. Execute inference with [`ExecutionContext::enqueue_v3`]
//!
//! # Example
//!
//! ```rust,no_run
//! use trtx::{Logger, Builder, Runtime};
//! use trtx::builder::{BuilderConfig, MemoryPoolType, network_flags};
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Dynamically load TensorRT, optionally from an explicit library path.
//! // This requires the crate's `dlopen_tensorrt_rtx` feature (enabled by default). The call is
//! // optional and a no-op when `link_tensorrt_rtx` is also enabled.
//! trtx::dynamically_load_tensorrt(None::<String>).unwrap();
//!
//! // Create logger
//! let logger = Logger::stderr()?;
//!
//! // Build phase
//! let mut builder = Builder::new(&logger)?;
//! let mut network = builder.create_network(network_flags::EXPLICIT_BATCH)?;
//! let mut config = builder.create_config()?;
//!
//! // Configure memory
//! config.set_memory_pool_limit(MemoryPoolType::kWORKSPACE, 1 << 30);
//!
//! // Build and serialize
//! let engine_data = builder.build_serialized_network(&mut network, &mut config)?;
//! std::fs::write("model.engine", &engine_data)?;
//!
//! // Inference phase
//! let mut runtime = Runtime::new(&logger)?;
//! let mut engine = runtime.deserialize_cuda_engine(&engine_data)?;
//! let context = engine.create_execution_context()?;
//!
//! // List I/O tensors
//! let num_tensors = engine.nb_io_tensors()?;
//! for i in 0..num_tensors {
//!     let name = engine.io_tensor_name(i)?;
//!     println!("Tensor {}: {}", i, name);
//! }
//! # Ok(())
//! # }
//! ```
//!
//! # Safety
//!
//! This crate provides safe abstractions over the underlying C++ API. However,
//! some operations (like setting tensor addresses and enqueueing inference)
//! require careful management of CUDA memory and are marked as `unsafe`.
//!
//! ### Required (building)
//!
//! 1. **libclang < 22**: Required for autocxx. On Windows: `winget install LLVM.LLVM -v 20.1.2`
//!
//!    *Important*: libclang version 22 or greater will cause a compilation error.
//!
//!    You can steer discovery of libclang using the `LIBCLANG_PATH` environment variable if
//!    auto-discovery picks up the wrong version of libclang, e.g.
//!
//!    ```text
//!    $env:LIBCLANG_PATH="D:\programs\LLVM\bin"  # PowerShell (Windows)
//!    export LIBCLANG_PATH=/usr/lib/llvm-19/lib  # Linux
//!    ```
//!
//!    See <https://rust-lang.github.io/rust-bindgen/requirements.html> (note that autocxx uses an older fork of bindgen).
//!
//! TensorRT is dynamically loaded by default, so the TensorRT SDK is only required when building
//! with the Cargo features `link_tensorrt_rtx` / `link_tensorrt_onnxparser`, which link the TensorRT libraries.
//! Use `TENSORRT_RTX_DIR` to point to the TensorRT SDK root directory (the path that contains the `lib` folder with the shared libraries).
//!
//! ### Required (GPU execution)
//!
//! 1. **NVIDIA TensorRT-RTX**: Download and install from [NVIDIA Developer](https://developer.nvidia.com/tensorrt)
//!      - The TensorRT libraries should be in a location where they can be dynamically loaded.
//!        (e.g. by setting PATH on Windows or LD_LIBRARY_PATH on Linux)
//!      - This crate currently requires TensorRT-RTX version 1.3, 1.4 or 1.5 (see the Cargo features `v_1_3`, `v_1_4`, `v_1_5`).
//!        Use `default-features = false` plus a version feature to select the version.
//!        You will also have to enable either `dlopen_tensorrt_rtx` or `link_tensorrt_rtx`.
//!
//! 2. **NVIDIA GPU**: Compatible with TensorRT-RTX requirements
//!
//! # C++ API reference
//!
//! Rust types in this crate wrap TensorRT for RTX C++ interfaces. The authoritative class list and
//! method documentation is the
//! [TensorRT for RTX C++ API (annotated)](https://docs.nvidia.com/deeplearning/tensorrt-rtx/latest/_static/cpp-api/annotated.html).
//! Each wrapper’s docs also link the Rust FFI type in [`trtx_sys::nvinfer1`] or [`trtx_sys::nvonnxparser`]
//! alongside the matching C++ class on NVIDIA’s site.

// Mock builds only: silence `clippy::unnecessary_cast`. The casts are needed in real mode
// (u32) but become redundant in mock mode (i32), so the lint would fire only for mocks.
#![cfg_attr(
    any(feature = "mock", feature = "mock_runtime"),
    allow(clippy::unnecessary_cast)
)]
// Mock builds only: the mock implementations do not use their real parameters, so the
// `unused` lint group is silenced crate-wide.
#![cfg_attr(any(feature = "mock", feature = "mock_runtime"), allow(unused))]
// NOTE(review): `unused_variables` is a member of the `unused` lint group allowed above,
// so this attribute looks redundant — confirm before removing.
#![cfg_attr(
    any(feature = "mock", feature = "mock_runtime"),
    allow(unused_variables)
)]

pub mod axes;
pub mod builder;
pub mod builder_config;
pub mod cuda;
pub mod cuda_engine;
pub mod engine_inspector;
pub mod error;
pub mod execution_context;
#[cfg(feature = "onnxparser")]
pub mod executor;
pub mod host_memory;
pub mod interfaces;
pub mod logger;
pub mod network;
#[cfg(feature = "onnxparser")]
pub mod onnx_parser;
pub mod optimization_profile;
pub mod refitter;
pub mod runtime;
#[cfg(not(feature = "enterprise"))]
pub mod runtime_cache;
pub mod runtime_config;
pub mod tensor;

// Re-export commonly used types
pub use axes::Axes;
pub use builder::{Builder, BuilderConfig};
pub use cuda::{default_stream, synchronize, DeviceBuffer};
pub use error::{Error, Result};
#[cfg(feature = "onnxparser")]
pub use executor::{run_onnx_with_tensorrt, run_onnx_zeroed};
#[cfg(feature = "onnxparser")]
pub use executor::{TensorInput, TensorOutput};
#[cfg(feature = "dlopen_tensorrt_rtx")]
use libloading::AsFilename;
pub use logger::{LogHandler, Logger, Severity, StderrLogger};
pub use network::{ConvWeights, NetworkDefinition, OwnedConvWeights, OwnedWeights, Tensor};
#[cfg(feature = "onnxparser")]
pub use onnx_parser::OnnxParser;
pub use refitter::Refitter;
#[cfg(not(feature = "enterprise"))]
pub use runtime::RuntimeCache;
pub use runtime::{CudaEngine, EngineInspector, ExecutionContext, Runtime, RuntimeConfig};

/// Process-wide slot holding the dynamically loaded TensorRT library handle.
///
/// Starts out as `None`; `dynamically_load_tensorrt` stores the handle here after a
/// successful load. Only compiled when `dlopen_tensorrt_rtx` is enabled and neither
/// `link_tensorrt_rtx` nor `mock` is.
#[cfg(feature = "dlopen_tensorrt_rtx")]
#[cfg(not(any(feature = "link_tensorrt_rtx", feature = "mock")))]
pub(crate) static TRTLIB: std::sync::RwLock<Option<libloading::Library>> =
    std::sync::RwLock::new(None);

/// Loads the TensorRT shared library at runtime and stores its handle in `TRTLIB`.
///
/// Once a library has been stored, later calls return `Ok(())` immediately without
/// loading anything again.
///
/// When the `link_tensorrt_rtx` or `mock` feature is enabled, the loading code is compiled
/// out and this function is a no-op that returns `Ok(())`.
///
/// # Arguments
///
/// * `_filename` - Explicit file name or path of the library to load. With `None`, a
///   default name is chosen:
///   - on Unix: `libtensorrt_rtx.so.<major>.<minor>.<patch>`
///     (or `libnvinfer.so` with the `enterprise` feature),
///   - elsewhere: `tensorrt_rtx_<major>_<minor>.dll`
///     (or `nvinfer.dll` with the `enterprise` feature).
///
///   The version components come from the `trtx_sys::get_tensorrt_*_version` functions.
///
/// # Errors
///
/// Returns an error if the lock guarding the library slot is poisoned or if the library
/// cannot be loaded. A load failure is also logged at `warn` level.
#[cfg(feature = "dlopen_tensorrt_rtx")]
pub fn dynamically_load_tensorrt(_filename: Option<impl AsFilename>) -> Result<()> {
    #[cfg(not(any(feature = "link_tensorrt_rtx", feature = "mock")))]
    {
        use log::{debug, warn};

        // Fast path: a shared read lock is enough to see that the library is already loaded.
        if TRTLIB.read()?.is_some() {
            return Ok(());
        }

        let mut slot = TRTLIB.write()?;
        // Re-check under the write lock: another thread may have stored a library between
        // releasing the read lock and acquiring the write lock.
        if slot.is_some() {
            return Ok(());
        }

        let lib = match _filename {
            Some(filename) => {
                debug!("Loading library TensorRT library");
                // SAFETY: loading a library runs its initialisation routines, which is
                // equivalent to calling unknown foreign code. The caller is trusted to name
                // a genuine TensorRT library.
                unsafe { libloading::Library::new(filename) }
            }
            None => {
                // SAFETY: same reasoning as above, for the default TensorRT library name.
                // NOTE(review): the version getters are kept inside this block, as in the
                // original code; confirm whether they are `unsafe fn`s in `trtx_sys`.
                unsafe {
                    use trtx_sys::{
                        get_tensorrt_major_version, get_tensorrt_minor_version,
                        get_tensorrt_patch_version,
                    };

                    // (is_unix, is_enterprise) -> default library file name.
                    let filename = match (cfg!(unix), cfg!(feature = "enterprise")) {
                        (true, true) => "libnvinfer.so".to_string(),
                        (true, false) => format!(
                            "libtensorrt_rtx.so.{}.{}.{}",
                            get_tensorrt_major_version(),
                            get_tensorrt_minor_version(),
                            get_tensorrt_patch_version()
                        ),
                        (false, true) => "nvinfer.dll".to_string(),
                        // The DLL name carries only the major and minor version.
                        (false, false) => format!(
                            "tensorrt_rtx_{}_{}.dll",
                            get_tensorrt_major_version(),
                            get_tensorrt_minor_version()
                        ),
                    };

                    debug!("Loading library {filename} as TensorRT library");
                    libloading::Library::new(filename)
                }
            }
        }
        .inspect_err(|e| warn!("Failed to load TensorRT library: {e:?}"))?;

        *slot = Some(lib);
    }
    Ok(())
}

/// Process-wide slot holding the dynamically loaded TensorRT ONNX parser library handle.
///
/// Starts out as `None`; `dynamically_load_tensorrt_onnxparser` stores the handle here
/// after a successful load. Only compiled when `dlopen_tensorrt_onnxparser` is enabled and
/// neither `link_tensorrt_onnxparser` nor `mock` is.
#[cfg(feature = "dlopen_tensorrt_onnxparser")]
#[cfg(not(any(feature = "link_tensorrt_onnxparser", feature = "mock")))]
pub(crate) static TRT_ONNXPARSER_LIB: std::sync::RwLock<Option<libloading::Library>> =
    std::sync::RwLock::new(None);

/// Loads the TensorRT ONNX parser shared library at runtime and stores its handle in
/// `TRT_ONNXPARSER_LIB`.
///
/// Once a library has been stored, later calls return `Ok(())` immediately without
/// loading anything again.
///
/// When the `link_tensorrt_onnxparser` or `mock` feature is enabled, the loading code is
/// compiled out and this function is a no-op that returns `Ok(())`.
///
/// # Arguments
///
/// * `_filename` - Explicit file name or path of the library to load. With `None`, a
///   default name is chosen:
///   - on Unix: `libtensorrt_onnxparser_rtx.so.<major>.<minor>.<patch>`
///     (or `libnvonnxparser.so` with the `enterprise` feature),
///   - elsewhere: `tensorrt_onnxparser_rtx_<major>_<minor>.dll`
///     (or `nvonnxparser.dll` with the `enterprise` feature).
///
///   The version components come from the `trtx_sys::get_tensorrt_*_version` functions,
///   i.e. the TensorRT version is used to name the ONNX parser library.
///
/// # Errors
///
/// Returns an error if the lock guarding the library slot is poisoned or if the library
/// cannot be loaded. A load failure is also logged at `warn` level.
// The trait is named by its full path so this signature does not rely on the file-level
// `use libloading::AsFilename`, which is gated on a different feature (`dlopen_tensorrt_rtx`).
#[cfg(feature = "dlopen_tensorrt_onnxparser")]
pub fn dynamically_load_tensorrt_onnxparser(
    _filename: Option<impl libloading::AsFilename>,
) -> Result<()> {
    #[cfg(not(any(feature = "link_tensorrt_onnxparser", feature = "mock")))]
    {
        use log::{debug, warn};

        // Fast path: a shared read lock is enough to see that the library is already loaded.
        if TRT_ONNXPARSER_LIB.read()?.is_some() {
            return Ok(());
        }

        let mut slot = TRT_ONNXPARSER_LIB.write()?;
        // Re-check under the write lock: another thread may have stored a library between
        // releasing the read lock and acquiring the write lock.
        if slot.is_some() {
            return Ok(());
        }

        let lib = match _filename {
            Some(filename) => {
                debug!("Loading library TensorRT onnxparser library");
                // SAFETY: loading a library runs its initialisation routines, which is
                // equivalent to calling unknown foreign code. The caller is trusted to name
                // a genuine TensorRT ONNX parser library.
                unsafe { libloading::Library::new(filename) }
            }
            None => {
                // SAFETY: same reasoning as above, for the default ONNX parser library name.
                // NOTE(review): the version getters are kept inside this block, as in the
                // original code; confirm whether they are `unsafe fn`s in `trtx_sys`.
                unsafe {
                    use trtx_sys::{
                        get_tensorrt_major_version, get_tensorrt_minor_version,
                        get_tensorrt_patch_version,
                    };

                    // (is_unix, is_enterprise) -> default library file name.
                    let filename = match (cfg!(unix), cfg!(feature = "enterprise")) {
                        (true, true) => "libnvonnxparser.so".to_string(),
                        (true, false) => format!(
                            "libtensorrt_onnxparser_rtx.so.{}.{}.{}",
                            get_tensorrt_major_version(),
                            get_tensorrt_minor_version(),
                            get_tensorrt_patch_version()
                        ),
                        (false, true) => "nvonnxparser.dll".to_string(),
                        // The DLL name carries only the major and minor version.
                        (false, false) => format!(
                            "tensorrt_onnxparser_rtx_{}_{}.dll",
                            get_tensorrt_major_version(),
                            get_tensorrt_minor_version()
                        ),
                    };

                    debug!("Loading library {filename} as TensorRT onnxparser library");
                    libloading::Library::new(filename)
                }
            }
        }
        .inspect_err(|e| warn!("Failed to load TensorRT onnxparser library: {e:?}"))?;

        *slot = Some(lib);
    }
    Ok(())
}

// Re-export TensorRT enums
pub use trtx_sys::{
    self, ActivationType, AttentionNormalizationOp, BuilderFlag, CumulativeOperation, DataType,
    DeviceType, ElementWiseOperation, EngineCapability, ExecutionContextAllocationStrategy,
    FillOperation, GatherMode, HardwareCompatibilityLevel, InterpolationMode, KVCacheMode,
    LayerInformationFormat, LayerType, LoopOutput, MatrixOperation, MemoryPoolType,
    OptProfileSelector, PaddingMode, PoolingType, PreviewFeature, ProfilingVerbosity,
    ReduceOperation, ResizeCoordinateTransformation, ResizeMode, ResizeRoundMode, ResizeSelector,
    RuntimePlatform, SampleMode, ScaleMode, ScatterMode, SeekPosition, SerializationFlag,
    TensorFormat, TensorIOMode, TensorLocation, TilingOptimizationLevel, TopKOperation, TripLimit,
    UnaryOperation, WeightsRole,
};

#[cfg(not(feature = "enterprise"))]
pub use trtx_sys::{
    ComputeCapability, CudaGraphStrategy, DynamicShapesKernelSpecializationStrategy,
};

#[cfg(feature = "v_1_4")]
pub use trtx_sys::{CollectiveOperation, MoEActType};

#[cfg(feature = "v_1_5")]
pub use trtx_sys::{AttentionIOForm, CausalMaskKind};