// Crate-wide Clippy suppressions.
// NOTE(review): no rationale is recorded for these; the per-lint notes below are
// guesses from the lint names alone — confirm each is still needed before keeping it.
#![allow(
// presumably long generic / reply types in the actor and runtime APIs — TODO confirm
clippy::type_complexity,
// presumably wide constructors or FFI-shaped helpers — TODO confirm
clippy::too_many_arguments,
// presumably `Arc`-wrapped handles around raw native pointers that are not Send/Sync — TODO confirm
clippy::arc_with_non_send_sync
)]
pub mod actor;
pub mod builder;
pub mod engine;
pub mod error;
pub mod runtime;
pub mod sys;
#[cfg(feature = "tensorrt-onnx")]
pub mod onnx;
#[cfg(feature = "tensorrt-int8")]
pub mod calibration;
#[cfg(feature = "tensorrt-plugin")]
pub mod plugin;
pub use actor::{
BuildFromOnnxReply, BuildReply, CreateContextReply, DeserializeReply, EnqueueReply,
ExecuteReply, NetworkSource, RefitReply, RefitWeights, TrtActor, TrtMsg,
};
pub use builder::{
BuilderFlags, DeviceType, IBuilderConfig, Precision, RefitPolicy, TacticSources,
};
pub use engine::{EnginePlan, TrtEngine, TrtRefitter};
pub use error::TrtError;
pub use runtime::{EnqueueRequest, ExecutionBindings, ExecutionContext, TensorShape, TrtRuntime};
/// Routes native TensorRT log output into `tracing`.
///
/// Registers `rust_log_trampoline` as the native logger callback with a null
/// user-data pointer. The registration is guarded by a process-wide `Once`, so
/// only the first call performs it; every later call is a cheap no-op.
#[cfg(feature = "tensorrt-link")]
pub fn init_logger() {
    static LOGGER_INSTALLED: std::sync::Once = std::sync::Once::new();

    LOGGER_INSTALLED.call_once(|| {
        // SAFETY: `rust_log_trampoline` is an `extern "C"` function with the
        // callback signature the FFI binding expects, and it never reads its
        // user-data argument, so passing null is fine. `Once` guarantees this
        // registration executes at most once per process.
        unsafe {
            sys::atomr_trt_install_logger(rust_log_trampoline, std::ptr::null_mut());
        }
    });
}

/// No-op fallback used when the crate is built without the `tensorrt-link`
/// feature, so callers can invoke `init_logger` unconditionally.
#[cfg(not(feature = "tensorrt-link"))]
pub fn init_logger() {}
/// C-ABI callback that receives native TensorRT log messages and re-emits them
/// through `tracing` under the `"tensorrt"` target.
///
/// Severity codes map as: `0 | 1` → `error`, `2` → `warn`, `3` → `info`,
/// anything else → `debug`.
/// NOTE(review): this presumably mirrors TensorRT's `ILogger::Severity` order
/// (internal error, error, warning, info, verbose) — confirm the C shim forwards
/// the enum values unchanged.
///
/// Null or zero-length messages are ignored. Invalid UTF-8 is replaced with
/// U+FFFD rather than dropped.
///
/// # Safety
///
/// If `msg` is non-null, it must point to at least `len` bytes that stay readable
/// and unmodified for the duration of the call. `_user` is never dereferenced.
///
/// # Panics
///
/// Never unwinds into the caller. Any panic raised while emitting the event
/// (e.g. by a `tracing` subscriber) is caught, and that one message is dropped.
/// Unwinding out of an `extern "C"` function is undefined behaviour on older
/// toolchains and a process abort on newer ones.
#[cfg(feature = "tensorrt-link")]
unsafe extern "C" fn rust_log_trampoline(
    sev: std::os::raw::c_int,
    msg: *const std::os::raw::c_char,
    len: usize,
    _user: *mut std::os::raw::c_void,
) {
    if msg.is_null() || len == 0 {
        return;
    }

    // SAFETY: `msg` is non-null (checked above), and the caller guarantees it
    // points to `len` readable bytes. `c_char` and `u8` share size and alignment.
    let bytes = unsafe { std::slice::from_raw_parts(msg.cast::<u8>(), len) };
    let text = String::from_utf8_lossy(bytes);

    // The closure only borrows `sev` and `text`, and nothing is observed after a
    // caught panic, so asserting unwind-safety is sound. The panic hook has
    // already reported the panic; we just stop it from crossing into C++.
    let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| match sev {
        0 | 1 => tracing::error!(target: "tensorrt", "{text}"),
        2 => tracing::warn!(target: "tensorrt", "{text}"),
        3 => tracing::info!(target: "tensorrt", "{text}"),
        _ => tracing::debug!(target: "tensorrt", "{text}"),
    }));
}