#![forbid(unsafe_code)]
#![deny(rust_2018_idioms)]
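//! TensorRT-backed [`ModelRunner`] for the atomr inference stack.
//!
//! The runner reads a serialised TensorRT plan from disk eagerly but defers
//! all CUDA and TensorRT initialisation until the first call that needs it,
//! so construction succeeds on hosts without a GPU. Everything that touches
//! `atomr_accel_tensorrt` is gated behind the `tensorrt` feature; without it
//! the runner still compiles but returns errors at execution time.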
use async_trait::async_trait;
use atomr_infer_core::batch::ExecuteBatch;
use atomr_infer_core::error::{InferenceError, InferenceResult};
use atomr_infer_core::runner::{ModelRunner, RunHandle, SessionRebuildCause};
use atomr_infer_core::runtime::{RuntimeKind, TransportKind};
use serde::{Deserialize, Serialize};
#[cfg(feature = "tensorrt")]
pub use atomr_accel_tensorrt::{
builder::Precision,
runtime::{ExecutionBindings, TensorShape},
};
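/// Configuration for [`TensorRtRunner`], deserialisable from any serde
/// format. Only `plan_path` is required: `max_batch_size` defaults to 1,
/// `precision` to [`TrtPrecision::Fp32`], and `device_id` to 0.
///
/// A minimal sketch of loading it from JSON (the path is hypothetical and
/// `serde_json` is assumed to be available):
///
/// ```ignore
/// let cfg: TensorRtConfig =
///     serde_json::from_str(r#"{ "plan_path": "/models/example.plan", "precision": "fp16" }"#)?;
/// assert_eq!(cfg.max_batch_size, 1); // filled in by the serde default
/// ```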
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorRtConfig {
pub plan_path: std::path::PathBuf,
#[serde(default = "default_max_batch_size")]
pub max_batch_size: u32,
#[serde(default)]
pub precision: TrtPrecision,
#[serde(default)]
pub device_id: u32,
}
fn default_max_batch_size() -> u32 {
1
}
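/// Requested engine precision, serialised in lowercase (`"fp32"`, `"fp16"`,
/// `"bf16"`, `"int8"`, `"fp8"`, `"best"`). Mirrors
/// `atomr_accel_tensorrt::builder::Precision`, into which it converts when
/// the `tensorrt` feature is enabled.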
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TrtPrecision {
#[default]
Fp32,
Fp16,
Bf16,
Int8,
Fp8,
Best,
}
#[cfg(feature = "tensorrt")]
impl From<TrtPrecision> for Precision {
fn from(p: TrtPrecision) -> Self {
match p {
TrtPrecision::Fp32 => Precision::Fp32,
TrtPrecision::Fp16 => Precision::Fp16,
TrtPrecision::Bf16 => Precision::Bf16,
TrtPrecision::Int8 => Precision::Int8,
TrtPrecision::Fp8 => Precision::Fp8,
TrtPrecision::Best => Precision::Best,
}
}
}
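/// Lazily-built GPU session state: the deserialised engine and the CUDA
/// stream work is enqueued on. Kept behind a mutex in the runner so a bad
/// session can be dropped and rebuilt in place.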
#[cfg(feature = "tensorrt")]
struct TrtState {
engine: std::sync::Arc<atomr_accel_tensorrt::engine::TrtEngine>,
stream: std::sync::Arc<cudarc::driver::CudaStream>,
}
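/// A [`ModelRunner`] backed by a pre-built, serialised TensorRT engine plan.
///
/// Plan bytes are read at construction; the CUDA context, stream, and engine
/// are created lazily on first use and discarded by `rebuild_session`.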
pub struct TensorRtRunner {
    #[cfg_attr(not(feature = "tensorrt"), allow(dead_code))]
    config: TensorRtConfig,
    #[cfg_attr(not(feature = "tensorrt"), allow(dead_code))]
    plan: Vec<u8>,
    #[cfg(feature = "tensorrt")]
    state: parking_lot::Mutex<Option<TrtState>>,
    /// Stream requested via [`TensorRtRunner::with_stream`]; when unset,
    /// `ensure_state` falls back to the context's default stream.
    #[cfg(feature = "tensorrt")]
    stream_override: Option<std::sync::Arc<cudarc::driver::CudaStream>>,
}
impl TensorRtRunner {
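    /// Reads the serialised engine plan from `config.plan_path` into memory.
    /// No CUDA or TensorRT initialisation happens here, so construction
    /// succeeds on hosts without a GPU.
    ///
    /// A minimal sketch (the plan path is hypothetical):
    ///
    /// ```ignore
    /// let runner = TensorRtRunner::new(TensorRtConfig {
    ///     plan_path: "/models/example.plan".into(),
    ///     max_batch_size: 8,
    ///     precision: TrtPrecision::Fp16,
    ///     device_id: 0,
    /// })?;
    /// ```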
pub fn new(config: TensorRtConfig) -> InferenceResult<Self> {
let plan = std::fs::read(&config.plan_path).map_err(|e| {
InferenceError::Internal(format!(
"tensorrt: failed to read plan from {}: {e}",
config.plan_path.display()
))
})?;
        Ok(Self {
            config,
            plan,
            #[cfg(feature = "tensorrt")]
            state: parking_lot::Mutex::new(None),
            #[cfg(feature = "tensorrt")]
            stream_override: None,
        })
}
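    /// Pins execution to the given CUDA stream instead of the context's
    /// default stream. Applies to the current session, if one exists, and to
    /// every session built afterwards.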
#[cfg(feature = "tensorrt")]
    pub fn with_stream(mut self, stream: std::sync::Arc<cudarc::driver::CudaStream>) -> Self {
        // Update any live session and remember the stream so sessions built
        // later by `ensure_state` (including after a rebuild) use it too.
        if let Some(state) = self.state.lock().as_mut() {
            state.stream = stream.clone();
        }
        self.stream_override = Some(stream);
        self
    }
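    /// Enqueues prepared [`ExecutionBindings`] on the session's CUDA stream.
    ///
    /// Today this always returns an error: real execution additionally
    /// requires the `tensorrt-link` feature, i.e. a successful libnvinfer
    /// link probe in `atomr-accel-tensorrt`'s build script.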
#[cfg(feature = "tensorrt")]
pub async fn enqueue(&mut self, bindings: ExecutionBindings) -> InferenceResult<()> {
self.ensure_state()?;
let guard = self.state.lock();
let Some(state) = guard.as_ref() else {
return Err(InferenceError::Internal(
"tensorrt: state was cleared between ensure_state and lock — \
retry the enqueue"
.into(),
));
};
let _engine = state.engine.clone();
let _stream = state.stream.clone();
let _ = bindings;
Err(InferenceError::Internal(
"tensorrt: enqueue requires the `tensorrt-link` feature \
(libnvinfer must be installed and the link probe must \
succeed in atomr-accel-tensorrt's build.rs)"
.into(),
))
}
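    /// Creates the CUDA context, TensorRT runtime, and engine on first use,
    /// caching them in `self.state`; returns immediately when a session
    /// already exists.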
#[cfg(feature = "tensorrt")]
fn ensure_state(&self) -> InferenceResult<()> {
let mut guard = self.state.lock();
if guard.is_some() {
return Ok(());
}
        let cuda_ctx =
            cudarc::driver::CudaContext::new(self.config.device_id as usize).map_err(|e| {
InferenceError::Internal(format!(
"tensorrt: failed to create CUDA context on device {}: {e}",
self.config.device_id
))
})?;
        let stream = self
            .stream_override
            .clone()
            .unwrap_or_else(|| cuda_ctx.default_stream());
let runtime = atomr_accel_tensorrt::runtime::TrtRuntime::new().map_err(map_trt_err)?;
let engine = runtime.deserialize(&self.plan).map_err(map_trt_err)?;
let engine = std::sync::Arc::new(engine);
*guard = Some(TrtState { engine, stream });
Ok(())
}
}
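/// Maps `atomr_accel_tensorrt` errors onto [`InferenceError`], keeping
/// caller faults (`InvalidArg` becomes `BadRequest`) distinct from internal
/// failures.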
#[cfg(feature = "tensorrt")]
fn map_trt_err(err: atomr_accel_tensorrt::error::TrtError) -> InferenceError {
use atomr_accel_tensorrt::error::TrtError;
match err {
TrtError::NotLinked(msg) => InferenceError::Internal(format!(
"tensorrt not linked: {msg} (rebuild with --features tensorrt-link)"
)),
TrtError::Build(m)
| TrtError::Runtime(m)
| TrtError::Execution(m)
| TrtError::Onnx(m)
| TrtError::Calibration(m)
| TrtError::Plugin(m)
| TrtError::Refit(m) => InferenceError::Internal(format!("tensorrt: {m}")),
TrtError::NullEngine => InferenceError::Internal("tensorrt: engine pointer was null".into()),
TrtError::InvalidArg(m) => InferenceError::BadRequest {
message: format!("tensorrt: invalid argument: {m}"),
},
}
}
#[async_trait]
impl ModelRunner for TensorRtRunner {
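    /// Chat-style entry point required by [`ModelRunner`]. A tokeniser layer
    /// has not been wired up yet, so this builds the session eagerly (when
    /// the feature is on) and then directs callers to `enqueue`.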
#[cfg_attr(
feature = "tensorrt",
tracing::instrument(skip(self, _batch), fields(plan = %self.config.plan_path.display()))
)]
async fn execute(&mut self, _batch: ExecuteBatch) -> InferenceResult<RunHandle> {
#[cfg(not(feature = "tensorrt"))]
{
Err(InferenceError::Internal(
"tensorrt feature disabled at build time — rebuild with --features tensorrt".into(),
))
}
#[cfg(feature = "tensorrt")]
{
self.ensure_state()?;
Err(InferenceError::Internal(
"tensorrt runner: chat-style execute requires a tokeniser layer; \
callers staging tensors directly should invoke `enqueue` with \
a prepared ExecutionBindings"
.into(),
))
}
}
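    /// Drops the cached session so the next call lazily rebuilds it. For
    /// `CudaContextPoisoned` and `Manual` causes the plan bytes are also
    /// re-read from disk first, so a plan file replaced on disk is picked up.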
async fn rebuild_session(&mut self, cause: SessionRebuildCause) -> InferenceResult<()> {
#[cfg(feature = "tensorrt")]
{
if matches!(
cause,
SessionRebuildCause::CudaContextPoisoned | SessionRebuildCause::Manual
) {
let plan = std::fs::read(&self.config.plan_path).map_err(|e| {
InferenceError::Internal(format!(
"tensorrt: failed to re-read plan from {}: {e}",
self.config.plan_path.display()
))
})?;
self.plan = plan;
}
*self.state.lock() = None;
}
        // `cause` is only inspected when the `tensorrt` feature is enabled.
        #[cfg(not(feature = "tensorrt"))]
        let _ = cause;
Ok(())
}
fn runtime_kind(&self) -> RuntimeKind {
RuntimeKind::TensorRt
}
fn transport_kind(&self) -> TransportKind {
TransportKind::LocalGpu
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn missing_plan_returns_internal_error() {
let cfg = TensorRtConfig {
plan_path: std::path::PathBuf::from("/this/path/does/not/exist.plan"),
max_batch_size: 1,
precision: TrtPrecision::default(),
device_id: 0,
};
let result = TensorRtRunner::new(cfg);
assert!(matches!(result, Err(InferenceError::Internal(_))));
}
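    // A sketch of the serde behaviour declared above; assumes `serde_json`
    // is available as a dev-dependency.
    #[test]
    fn config_defaults_apply_when_fields_are_omitted() {
        let cfg: TensorRtConfig =
            serde_json::from_str(r#"{ "plan_path": "model.plan" }"#).expect("deserialise");
        assert_eq!(cfg.max_batch_size, 1);
        assert!(matches!(cfg.precision, TrtPrecision::Fp32));
        assert_eq!(cfg.device_id, 0);
    }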
#[test]
fn empty_plan_loads_into_runner() {
        // A fresh NamedTempFile is already empty, which serves as the plan.
        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
let cfg = TensorRtConfig {
plan_path: tmp.path().to_path_buf(),
max_batch_size: 1,
precision: TrtPrecision::Fp16,
device_id: 0,
};
let runner = TensorRtRunner::new(cfg).expect("loads empty plan");
assert_eq!(runner.runtime_kind(), RuntimeKind::TensorRt);
assert_eq!(runner.transport_kind(), TransportKind::LocalGpu);
}
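    // Round-trips the lowercase serde names declared on `TrtPrecision`;
    // like the defaults test, this assumes `serde_json` as a dev-dependency.
    #[test]
    fn precision_uses_lowercase_serde_names() {
        let json = serde_json::to_string(&TrtPrecision::Fp16).expect("serialise");
        assert_eq!(json, r#""fp16""#);
        let parsed: TrtPrecision = serde_json::from_str(r#""int8""#).expect("deserialise");
        assert!(matches!(parsed, TrtPrecision::Int8));
    }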
#[cfg(not(feature = "tensorrt"))]
#[tokio::test]
async fn execute_without_feature_returns_internal_error() {
use atomr_infer_core::batch::SamplingParams;
        // As above, a fresh NamedTempFile doubles as an empty plan file.
        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
let cfg = TensorRtConfig {
plan_path: tmp.path().to_path_buf(),
max_batch_size: 1,
precision: TrtPrecision::default(),
device_id: 0,
};
let mut runner = TensorRtRunner::new(cfg).expect("loads empty plan");
let batch = ExecuteBatch {
request_id: "test".into(),
model: "test".into(),
messages: vec![],
sampling: SamplingParams::default(),
stream: false,
estimated_tokens: 1,
};
let result = runner.execute(batch).await;
assert!(matches!(result, Err(InferenceError::Internal(_))));
}
}