atomr_accel_tensorrt/lib.rs
1//! # atomr-accel-tensorrt
2//!
3//! TensorRT engine builder + runtime as supervised atomr actors.
4//! Wraps NVIDIA's libnvinfer (and optionally libnvonnxparser) at
5//! runtime — the library itself is **not** vendored because it is
6//! proprietary; users opt in via the `tensorrt-link` feature and
7//! either install TensorRT system-wide or set `LIBNVINFER_PATH`.
8//!
9//! ## Features
10//!
11//! - `tensorrt-link` — actually link libnvinfer at build time.
12//! Off-by-default so the crate compiles on hosts without
13//! TensorRT (used by CI + unit tests).
14//! - `tensorrt-onnx` — pull in `nvonnxparser` for ONNX import.
15//! - `tensorrt-plugin` — `IPluginV3` Rust trampolines.
16//! - `tensorrt-int8` — INT8 calibration helpers (entropy / minmax).
17//! - `tensorrt-fp8` — FP8 PTQ helpers (Hopper-class GPUs).
18//!
19//! ## Public surface
20//!
21//! - [`actor::TrtActor`] / [`actor::TrtMsg`] — sibling actor to
22//! `atomr_accel_cuda::DeviceActor`. Shares `Arc<CudaStream>` with
23//! the device actor so inference rides the same execution
24//! timeline.
25//! - [`builder::IBuilderConfig`] — pure-Rust mirror of the TensorRT
26//! builder config, with knobs for precision, DLA, structured
27//! sparsity, tactic sources, timing cache, and engine refit.
28//! - [`engine::TrtEngine`] — owned, immutable engine handle that's
29//! `Send + Sync` via newtype.
30//! - [`runtime::TrtRuntime`] / [`runtime::ExecutionContext`] — load
31//! serialised plans + drive `enqueueV3` on a shared CUDA stream.
32//! - [`onnx::OnnxParser`] — gated on `tensorrt-onnx`.
33//! - [`calibration`] — gated on `tensorrt-int8` / `tensorrt-fp8`.
34//! - [`plugin`] — gated on `tensorrt-plugin`.
35
36#![allow(
37 clippy::type_complexity,
38 clippy::too_many_arguments,
39 clippy::arc_with_non_send_sync
40)]
41
42// The `tensorrt-link` feature compiles `csrc/nvinfer_shim.cpp` and
43// links the resulting static lib against system libnvinfer (and
44// libnvonnxparser / libnvinfer_plugin when their sub-features are
45// also on). See `build.rs` for the probe order and the env-var
46// contract (`LIBNVINFER_PATH`, `TENSORRT_INCLUDE_PATH`, `CUDA_PATH`).
47
48pub mod actor;
49pub mod builder;
50pub mod engine;
51pub mod error;
52pub mod runtime;
53pub mod sys;
54
55#[cfg(feature = "tensorrt-onnx")]
56pub mod onnx;
57
58#[cfg(feature = "tensorrt-int8")]
59pub mod calibration;
60
61#[cfg(feature = "tensorrt-plugin")]
62pub mod plugin;
63
64pub use actor::{
65 BuildFromOnnxReply, BuildReply, CreateContextReply, DeserializeReply, EnqueueReply,
66 ExecuteReply, NetworkSource, RefitReply, RefitWeights, TrtActor, TrtMsg,
67};
68pub use builder::{
69 BuilderFlags, DeviceType, IBuilderConfig, Precision, RefitPolicy, TacticSources,
70};
71pub use engine::{EnginePlan, TrtEngine, TrtRefitter};
72pub use error::TrtError;
73pub use runtime::{EnqueueRequest, ExecutionBindings, ExecutionContext, TensorShape, TrtRuntime};
74
/// Route log output from the C++ shim's `RustBridgeLogger` into Rust
/// `tracing` by registering [`rust_log_trampoline`] as its callback.
///
/// `TrtActor::new`, `TrtRuntime::new`, and `IBuilderConfig::default`
/// all call this, so whichever entry point runs first wires up the
/// bridge before any TRT call is made. Registration is guarded by a
/// process-wide [`std::sync::Once`]; repeated calls after the first
/// are no-ops.
#[cfg(feature = "tensorrt-link")]
pub fn init_logger() {
    static INSTALL_LOGGER: std::sync::Once = std::sync::Once::new();

    INSTALL_LOGGER.call_once(|| {
        // SAFETY: `rust_log_trampoline` is a plain `extern "C"` function
        // item, so the pointer handed to the shim never dangles. The
        // user-data pointer is null, which is fine because the
        // trampoline ignores its `_user` argument.
        // NOTE(review): assumes the shim imposes no further
        // preconditions on this call — confirm against
        // `csrc/nvinfer_shim.cpp`.
        unsafe {
            sys::atomr_trt_install_logger(rust_log_trampoline, std::ptr::null_mut());
        }
    });
}
88
/// Fallback used when the `tensorrt-link` feature is disabled.
///
/// Does nothing. It exists so that call sites can invoke
/// `init_logger()` unconditionally instead of wrapping every call in
/// a `#[cfg(feature = "tensorrt-link")]` gate.
#[cfg(not(feature = "tensorrt-link"))]
pub fn init_logger() {
    // Intentionally empty: the C++ shim is only compiled with
    // `tensorrt-link`, so there is no logger to install here.
}
93
/// C-ABI callback that forwards one log message from the C++ shim to
/// `tracing` under the `tensorrt` target.
///
/// * `sev` — numeric severity: `0` and `1` are logged as errors, `2`
///   as a warning, `3` as info, and every other value as debug.
///   (Presumably mirrors `nvinfer1::ILogger::Severity` — confirm
///   against the shim.)
/// * `msg` / `len` — start and byte length of the message. The bytes
///   are decoded as UTF-8 with invalid sequences replaced, so the
///   message does not have to be valid UTF-8 or NUL-terminated.
/// * `_user` — opaque user-data pointer; ignored.
///
/// A null `msg` or a zero `len` is silently dropped.
///
/// # Safety
///
/// When `msg` is non-null it must point to at least `len` readable
/// bytes that stay valid and unmodified for the duration of the call.
#[cfg(feature = "tensorrt-link")]
unsafe extern "C" fn rust_log_trampoline(
    sev: std::os::raw::c_int,
    msg: *const std::os::raw::c_char,
    len: usize,
    _user: *mut std::os::raw::c_void,
) {
    // Nothing to log for an empty or missing message.
    if len == 0 || msg.is_null() {
        return;
    }

    // SAFETY: `msg` was checked to be non-null just above, and the
    // caller guarantees it addresses `len` readable bytes (see the
    // `# Safety` section). `u8` has alignment 1, so any address is
    // suitably aligned.
    let raw = unsafe { std::slice::from_raw_parts(msg.cast::<u8>(), len) };

    // Lossy decode: borrows when the bytes are already valid UTF-8 and
    // only allocates when replacement characters are needed.
    let text = String::from_utf8_lossy(raw);

    match sev {
        2 => tracing::warn!(target: "tensorrt", "{text}"),
        3 => tracing::info!(target: "tensorrt", "{text}"),
        0 | 1 => tracing::error!(target: "tensorrt", "{text}"),
        // Anything else (including negative values) is treated as
        // low-priority diagnostic output.
        _ => tracing::debug!(target: "tensorrt", "{text}"),
    }
}