trtx 0.4.0

Safe Rust bindings to NVIDIA TensorRT-RTX (EXPERIMENTAL - NOT FOR PRODUCTION)
Documentation
//! Safe Rust bindings to NVIDIA TensorRT-RTX
//!
//! ⚠️ **EXPERIMENTAL - NOT FOR PRODUCTION USE**
//!
//! This crate is in early experimental development. The API is unstable and will change.
//! This is NOT production-ready software. Use at your own risk.
//!
//! This crate provides safe, ergonomic Rust bindings to the TensorRT-RTX library
//! for high-performance deep learning inference on NVIDIA GPUs.
//!
//! # Overview
//!
//! TensorRT-RTX enables efficient inference by:
//! - Optimizing neural network graphs
//! - Fusing layers and operations
//! - Selecting optimal kernels for your hardware
//! - Supporting dynamic shapes and batching
//!
//! # Workflow
//!
//! Using TensorRT-RTX typically follows two phases:
//!
//! ## Build Phase (Ahead-of-Time)
//!
//! 1. Create a [`Logger`] to capture TensorRT messages
//! 2. Create a [`Builder`] to construct an optimized engine
//! 3. Define your network using [`NetworkDefinition`]
//! 4. Configure optimization with [`BuilderConfig`]
//! 5. Build and serialize the engine to disk
//!
//! ## Inference Phase (Runtime)
//!
//! 1. Create a [`Runtime`] with a logger
//! 2. Deserialize the engine using [`Runtime::deserialize_cuda_engine`]
//! 3. Create an [`ExecutionContext`] from the engine
//! 4. Bind input/output tensors
//! 5. Execute inference with [`ExecutionContext::enqueue_v3`]
//!
//! # Example
//!
//! ```rust,no_run
//! use trtx::{Logger, Builder, Runtime};
//! use trtx::builder::{BuilderConfig, MemoryPoolType, network_flags};
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Load the TensorRT shared library at runtime; passing `None` uses the default library name.
//! // This applies when the crate's `dlopen_tensorrt_rtx` feature is enabled (the default).
//! // The call is optional and a no-op when `link_tensorrt_rtx` is also enabled.
//! trtx::dynamically_load_tensorrt(None::<String>).unwrap();
//!
//! // Create logger
//! let logger = Logger::stderr()?;
//!
//! // Build phase
//! let mut builder = Builder::new(&logger)?;
//! let mut network = builder.create_network(network_flags::EXPLICIT_BATCH)?;
//! let mut config = builder.create_config()?;
//!
//! // Configure memory
//! config.set_memory_pool_limit(MemoryPoolType::kWORKSPACE, 1 << 30);
//!
//! // Build and serialize
//! let engine_data = builder.build_serialized_network(&mut network, &mut config)?;
//! std::fs::write("model.engine", &engine_data)?;
//!
//! // Inference phase
//! let mut runtime = Runtime::new(&logger)?;
//! let mut engine = runtime.deserialize_cuda_engine(&engine_data)?;
//! let context = engine.create_execution_context()?;
//!
//! // List I/O tensors
//! let num_tensors = engine.get_nb_io_tensors()?;
//! for i in 0..num_tensors {
//!     let name = engine.get_tensor_name(i)?;
//!     println!("Tensor {}: {}", i, name);
//! }
//! # Ok(())
//! # }
//! ```
//!
//! # Safety
//!
//! This crate provides safe abstractions over the underlying C++ API. However,
//! some operations (like setting tensor addresses and enqueueing inference)
//! require careful management of CUDA memory and are marked as `unsafe`.
//!
//! ### Required (building)
//!
//! 1. **Clang**: Required for autocxx. On Windows: `winget install LLVM.LLVM`
//!
//! TensorRT is dynamically loaded by default, so the TensorRT SDK is only required at build time
//! when the Cargo features `link_tensorrt_rtx` / `link_tensorrt_onnxparser` are enabled, because
//! those features link the TensorRT libraries.
//! Use `TENSORRT_RTX_DIR` to point to the TensorRT SDK root directory (the path that contains the `lib` folder with the shared libraries).
//!
//! ### Required (GPU execution)
//!
//! 1. **NVIDIA TensorRT-RTX**: Download and install from [NVIDIA Developer](https://developer.nvidia.com/tensorrt)
//!      - The TensorRT libraries should be in a location where they can be dynamically loaded.
//!        (e.g. by setting PATH on Windows or LD_LIBRARY_PATH on Linux)
//!      - This crate currently requires TensorRT RTX version 1.3 or 1.4 (see Cargo features `v_1_3` and `v_1_4`).
//!        Use `default-features = false` plus a version feature to select the version.
//!        You will also have to enable either `dlopen_tensorrt_rtx` or `link_tensorrt_rtx`.
//!
//! 2. **NVIDIA GPU**: Compatible with TensorRT-RTX requirements
//!
//! # C++ API reference
//!
//! Rust types in this crate wrap TensorRT for RTX C++ interfaces. The authoritative class list and
//! method documentation is the
//! [TensorRT for RTX C++ API (annotated)](https://docs.nvidia.com/deeplearning/tensorrt-rtx/latest/_static/cpp-api/annotated.html).
//! Each wrapper’s docs also link the Rust FFI type in [`trtx_sys::nvinfer1`] or [`trtx_sys::nvonnxparser`]
//! alongside the matching C++ class on NVIDIA’s site.

// Allow unnecessary casts - they're needed for real mode (u32) but not mock mode (i32)
#![cfg_attr(
    any(feature = "mock", feature = "mock_runtime"),
    allow(clippy::unnecessary_cast)
)]
// We don't use real parameters in mocks
#![cfg_attr(any(feature = "mock", feature = "mock_runtime"), allow(unused))]
#![cfg_attr(
    any(feature = "mock", feature = "mock_runtime"),
    allow(unused_variables)
)]

pub mod axes;
pub mod builder;
pub mod builder_config;
pub mod cuda;
pub mod cuda_engine;
pub mod engine_inspector;
pub mod error;
#[cfg(feature = "onnxparser")]
pub mod executor;
pub mod host_memory;
pub mod interfaces;
pub mod logger;
pub mod network;
#[cfg(feature = "onnxparser")]
pub mod onnx_parser;
pub mod optimization_profile;
pub mod refitter;
pub mod runtime;

// Re-export commonly used types
pub use axes::Axes;
pub use builder::{Builder, BuilderConfig, ProfilingVerbosity};
pub use cuda::{get_default_stream, synchronize, DeviceBuffer};
pub use error::{Error, Result};
#[cfg(feature = "onnxparser")]
pub use executor::{run_onnx_with_tensorrt, run_onnx_zeroed};
#[cfg(feature = "onnxparser")]
pub use executor::{TensorInput, TensorOutput};
#[cfg(feature = "dlopen_tensorrt_rtx")]
use libloading::AsFilename;
pub use logger::{LogHandler, Logger, Severity, StderrLogger};
pub use network::{ConvWeights, NetworkDefinition, OwnedConvWeights, OwnedWeights, Tensor};
#[cfg(feature = "onnxparser")]
pub use onnx_parser::OnnxParser;
pub use refitter::Refitter;
pub use runtime::{CudaEngine, EngineInspector, ExecutionContext, Runtime};

/// Handle to the dynamically loaded TensorRT library.
///
/// Starts out as `None`; `dynamically_load_tensorrt` stores the loaded
/// `libloading::Library` here. Only compiled when the library is opened at
/// runtime, i.e. with `dlopen_tensorrt_rtx` enabled and neither
/// `link_tensorrt_rtx` nor `mock` active.
#[cfg(all(
    feature = "dlopen_tensorrt_rtx",
    not(any(feature = "link_tensorrt_rtx", feature = "mock"))
))]
pub(crate) static TRTLIB: std::sync::RwLock<Option<libloading::Library>> =
    std::sync::RwLock::new(None);

/// Loads the TensorRT shared library at runtime and keeps the handle in a crate-internal static.
///
/// * `_filename` - `Some(name)` hands `name` to `libloading::Library::new`. `None` selects a
///   default library name from the target platform (`cfg!(unix)` or not) and the enabled Cargo
///   features (`enterprise`, then `v_1_4`, otherwise the 1.3 library).
///
/// If a library has already been stored, the function returns `Ok(())` without loading anything.
/// When the `link_tensorrt_rtx` or `mock` feature is enabled, the body is compiled out and the
/// function simply returns `Ok(())`.
///
/// # Errors
///
/// Returns an error if the lock guarding the library handle is poisoned or if the library
/// cannot be loaded.
#[cfg(feature = "dlopen_tensorrt_rtx")]
pub fn dynamically_load_tensorrt(_filename: Option<impl AsFilename>) -> Result<()> {
    #[cfg(not(any(feature = "link_tensorrt_rtx", feature = "mock")))]
    {
        use log::debug;

        // Fast path: an already-loaded library only needs the shared read lock.
        if TRTLIB.read()?.is_some() {
            return Ok(());
        }

        // Take the write lock *before* loading and check again. Without the second check two
        // concurrent callers could both pass the read-locked test above, both load a library,
        // and the later store would drop (unload) the handle stored by the earlier one.
        let mut slot = TRTLIB.write()?;
        if slot.is_some() {
            return Ok(());
        }

        let lib = match _filename {
            Some(filename) => {
                debug!("Loading library TensorRT library");
                // SAFETY: loading a shared library runs its initialisation routines. The caller
                // picked this file, so we rely on it being a well-behaved TensorRT library.
                unsafe { libloading::Library::new(filename) }
            }
            None => {
                // (is unix, enterprise build, TensorRT-RTX 1.4) -> default library file name.
                let filename = match (
                    cfg!(unix),
                    cfg!(feature = "enterprise"),
                    cfg!(feature = "v_1_4"),
                ) {
                    (true, true, _) => "libnvinfer.so",
                    (true, false, true) => "libtensorrt_rtx.so.1.4.0",
                    (true, false, false) => "libtensorrt_rtx.so.1.3.0",
                    (false, true, _) => "nvinfer.dll",
                    (false, false, true) => "tensorrt_rtx_1_4.dll",
                    (false, false, false) => "tensorrt_rtx_1_3.dll",
                };

                debug!("Loading library {filename} as TensorRT library");
                // SAFETY: same as above; the default name is resolved through the platform's
                // library search path and is expected to be the TensorRT library.
                unsafe { libloading::Library::new(filename) }
            }
        }?;

        *slot = Some(lib);
    }
    Ok(())
}

/// Handle to the dynamically loaded TensorRT ONNX parser library.
///
/// Starts out as `None`; `dynamically_load_tensorrt_onnxparser` stores the loaded
/// `libloading::Library` here. Only compiled when the parser library is opened at
/// runtime, i.e. with `dlopen_tensorrt_onnxparser` enabled and neither
/// `link_tensorrt_onnxparser` nor `mock` active.
#[cfg(all(
    feature = "dlopen_tensorrt_onnxparser",
    not(any(feature = "link_tensorrt_onnxparser", feature = "mock"))
))]
pub(crate) static TRT_ONNXPARSER_LIB: std::sync::RwLock<Option<libloading::Library>> =
    std::sync::RwLock::new(None);

/// Loads the TensorRT ONNX parser shared library at runtime and keeps the handle in a
/// crate-internal static.
///
/// * `_filename` - `Some(name)` hands `name` to `libloading::Library::new`. `None` selects a
///   default library name from the target platform (`cfg!(unix)` or not) and the enabled Cargo
///   features (`enterprise`, then `v_1_4`, otherwise the 1.3 library).
///
/// If a library has already been stored, the function returns `Ok(())` without loading anything.
/// When the `link_tensorrt_onnxparser` or `mock` feature is enabled, the body is compiled out
/// and the function simply returns `Ok(())`.
///
/// # Errors
///
/// Returns an error if the lock guarding the library handle is poisoned or if the library
/// cannot be loaded.
#[cfg(feature = "dlopen_tensorrt_onnxparser")]
pub fn dynamically_load_tensorrt_onnxparser(_filename: Option<impl AsFilename>) -> Result<()> {
    #[cfg(not(any(feature = "link_tensorrt_onnxparser", feature = "mock")))]
    {
        use log::debug;

        // Fast path: an already-loaded library only needs the shared read lock.
        if TRT_ONNXPARSER_LIB.read()?.is_some() {
            return Ok(());
        }

        // Take the write lock *before* loading and check again. Without the second check two
        // concurrent callers could both pass the read-locked test above, both load a library,
        // and the later store would drop (unload) the handle stored by the earlier one.
        let mut slot = TRT_ONNXPARSER_LIB.write()?;
        if slot.is_some() {
            return Ok(());
        }

        let lib = match _filename {
            Some(filename) => {
                debug!("Loading library TensorRT onnxparser library",);
                // SAFETY: loading a shared library runs its initialisation routines. The caller
                // picked this file, so we rely on it being a well-behaved ONNX parser library.
                unsafe { libloading::Library::new(filename) }
            }
            None => {
                // (is unix, enterprise build, TensorRT-RTX 1.4) -> default library file name.
                let filename = match (
                    cfg!(unix),
                    cfg!(feature = "enterprise"),
                    cfg!(feature = "v_1_4"),
                ) {
                    (true, true, _) => "libnvonnxparser.so",
                    (true, false, true) => "libtensorrt_onnxparser_rtx.so.1.4.0",
                    (true, false, false) => "libtensorrt_onnxparser_rtx.so.1.3.0",
                    (false, true, _) => "nvonnxparser.dll",
                    (false, false, true) => "tensorrt_onnxparser_rtx_1_4.dll",
                    (false, false, false) => "tensorrt_onnxparser_rtx_1_3.dll",
                };

                debug!("Loading library {filename} as TensorRT onnxparser library");
                // SAFETY: same as above; the default name is resolved through the platform's
                // library search path and is expected to be the ONNX parser library.
                unsafe { libloading::Library::new(filename) }
            }
        }?;

        *slot = Some(lib);
    }
    Ok(())
}

// Re-export TensorRT enums
pub use trtx_sys::{
    ActivationType, CumulativeOperation, DataType, ElementWiseOperation, GatherMode,
    InterpolationMode, LayerInformationFormat, LayerType, MatrixOperation, PaddingMode,
    PoolingType, ReduceOperation, ResizeCoordinateTransformation, ResizeMode, ResizeRoundMode,
    ResizeSelector, SampleMode, ScaleMode, ScatterMode, TensorFormat, TensorIOMode, TopKOperation,
    UnaryOperation,
};

#[cfg(not(feature = "enterprise"))]
pub use trtx_sys::ComputeCapability;