Skip to main content

atomr_infer_runtime_tensorrt/
lib.rs

1//! # inference-runtime-tensorrt
2//!
3//! NVIDIA TensorRT runner — wraps `atomr-accel-tensorrt`'s
4//! `TrtRuntime` / `ExecutionContext` / `ExecutionBindings` behind
5//! the [`ModelRunner`] trait. Doc §2.2, §10.3.
6//!
7//! ## Feature flags
8//!
9//! - `tensorrt` — pull in the upstream Phase 8 crate. Without this
10//!   feature the runner compiles to a typed-error stub so a `cargo
11//!   build --features remote-only` consumer never pulls cudarc /
12//!   libnvinfer / nvonnxparser.
13//! - `tensorrt-link` — actually link `libnvinfer.so` at build time.
14//!   Off-by-default: with the `tensorrt` feature alone, the runner
15//!   compiles and unit-tests work without TensorRT installed; runtime
16//!   calls return `atomr_accel_tensorrt::error::TrtError::NotLinked`
17//!   mapped to `InferenceError::Internal`.
18//! - `tensorrt-onnx` / `tensorrt-int8` / `tensorrt-fp8` /
19//!   `tensorrt-plugin` — forwarded straight to the upstream crate so
20//!   callers can compose ONNX import, INT8 PTQ, FP8 PTQ, and IPluginV3
21//!   trampolines with the same dep on this crate.
22//!
23//! ## What this runner does
24//!
25//! 1. Reads the engine plan bytes from `config.plan_path` at
26//!    construction time. Missing / unreadable plan ⇒
27//!    `InferenceError::Internal`.
28//! 2. Lazily builds a `TrtRuntime`, deserialises the plan into a
29//!    shared `Arc<TrtEngine>`, and constructs the per-request
30//!    `ExecutionContext` inside [`ModelRunner::execute`].
31//! 3. Allocates a CUDA stream on the configured `device_id` so
32//!    `enqueueV3` can ride a real timeline. Operators wiring this
33//!    runner alongside `atomr-accel-cuda::DeviceActor` should swap
34//!    the lazy stream out via `TensorRtRunner::with_stream` (under
35//!    the `tensorrt` feature) so the two actors share one execution
36//!    timeline.
37//!
38//! ## What this runner does *not* do
39//!
40//! Tokenisation. The `ExecuteBatch` shape is a chat-style
41//! `Vec<Message>` + sampling params; TensorRT engines consume raw
42//! tensors. The runner therefore exposes a `TensorRtRunner::enqueue`
43//! method (under the `tensorrt` feature) for callers that have
44//! already produced device pointers via `ExecutionBindings`, and
45//! `ModelRunner::execute` returns a typed `InferenceError::Internal`
46//! pointing the caller at the tokeniser-specific path. A future
47//! revision can layer an LLM-aware adapter on top.
48
49#![forbid(unsafe_code)]
50#![deny(rust_2018_idioms)]
51
52use async_trait::async_trait;
53use serde::{Deserialize, Serialize};
54
55use atomr_infer_core::batch::ExecuteBatch;
56use atomr_infer_core::error::{InferenceError, InferenceResult};
57use atomr_infer_core::runner::{ModelRunner, RunHandle, SessionRebuildCause};
58use atomr_infer_core::runtime::{RuntimeKind, TransportKind};
59
60#[cfg(feature = "tensorrt")]
61pub use atomr_accel_tensorrt::{
62    builder::Precision,
63    runtime::{ExecutionBindings, TensorShape},
64};
65
66/// Engine-loading configuration.
67///
68/// The `plan_path` is a serialised TensorRT plan (output of
69/// `IBuilder::buildSerializedNetwork` or
70/// `atomr-accel-tensorrt::TrtMsg::Build`). Builds are out-of-scope
71/// for this runner — operators either hand-build a plan with the
72/// upstream actor or import an ONNX file via the `tensorrt-onnx`
73/// feature on `atomr-accel-tensorrt`.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct TensorRtConfig {
76    /// Path to a serialised TensorRT plan.
77    pub plan_path: std::path::PathBuf,
78    /// Maximum batch size the engine was built for. Used by the
79    /// adapter layer (when wired) to chunk requests.
80    #[serde(default = "default_max_batch_size")]
81    pub max_batch_size: u32,
82    /// Precision the engine was built for. Reported via
83    /// telemetry; the engine itself encodes the constraints.
84    #[serde(default)]
85    pub precision: TrtPrecision,
86    /// CUDA device ordinal. Defaults to 0.
87    #[serde(default)]
88    pub device_id: u32,
89}
90
/// Serde fallback for `TensorRtConfig::max_batch_size`: a single request
/// per batch.
fn default_max_batch_size() -> u32 {
    const SINGLE_REQUEST: u32 = 1;
    SINGLE_REQUEST
}
94
95/// Serializable mirror of `atomr_accel_tensorrt::builder::Precision`
96/// so configs can be parsed without pulling the upstream crate.
97#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
98#[serde(rename_all = "lowercase")]
99pub enum TrtPrecision {
100    /// FP32 with TF32 matmul (the default).
101    #[default]
102    Fp32,
103    /// FP16 + TF32.
104    Fp16,
105    /// BF16 + TF32.
106    Bf16,
107    /// INT8 + TF32. Requires PTQ calibration at build time.
108    Int8,
109    /// FP8 (Hopper+) + FP16 + TF32.
110    Fp8,
111    /// Let the builder pick the fastest tactic (FP16 | BF16 | INT8 |
112    /// FP8 | TF32 all enabled).
113    Best,
114}
115
/// One-to-one conversion from the config-level precision to the upstream
/// builder's `Precision`. Only compiled with the `tensorrt` feature.
#[cfg(feature = "tensorrt")]
impl From<TrtPrecision> for Precision {
    fn from(precision: TrtPrecision) -> Self {
        use TrtPrecision as Cfg;

        // Kept exhaustive on purpose: a new config variant must be
        // mapped explicitly rather than fall through a wildcard arm.
        match precision {
            Cfg::Fp32 => Self::Fp32,
            Cfg::Fp16 => Self::Fp16,
            Cfg::Bf16 => Self::Bf16,
            Cfg::Int8 => Self::Int8,
            Cfg::Fp8 => Self::Fp8,
            Cfg::Best => Self::Best,
        }
    }
}
129
/// Lazily-built runtime state cached by the runner: the deserialised
/// engine together with the CUDA stream stored alongside it. Both are
/// reference-counted so they can be cloned out of the state lock.
#[cfg(feature = "tensorrt")]
struct TrtState {
    /// Engine deserialised from the runner's plan bytes.
    engine: std::sync::Arc<atomr_accel_tensorrt::engine::TrtEngine>,
    /// CUDA stream paired with the engine.
    stream: std::sync::Arc<cudarc::driver::CudaStream>,
}
135
136/// `ModelRunner` that drives an immutable TensorRT engine.
137pub struct TensorRtRunner {
138    #[cfg_attr(not(feature = "tensorrt"), allow(dead_code))]
139    config: TensorRtConfig,
140    /// Plan bytes loaded eagerly at construction time so the file
141    /// can be moved / deleted without breaking already-running
142    /// runners.
143    #[cfg_attr(not(feature = "tensorrt"), allow(dead_code))]
144    plan: Vec<u8>,
145    #[cfg(feature = "tensorrt")]
146    state: parking_lot::Mutex<Option<TrtState>>,
147}
148
149impl TensorRtRunner {
150    /// Read the plan file and prepare the runner. The TensorRT
151    /// runtime / engine are not built until the first call to
152    /// `execute` (so a runner can be instantiated on a host without
153    /// libnvinfer for testing the config layer).
154    pub fn new(config: TensorRtConfig) -> InferenceResult<Self> {
155        let plan = std::fs::read(&config.plan_path).map_err(|e| {
156            InferenceError::Internal(format!(
157                "tensorrt: failed to read plan from {}: {e}",
158                config.plan_path.display()
159            ))
160        })?;
161        Ok(Self {
162            config,
163            plan,
164            #[cfg(feature = "tensorrt")]
165            state: parking_lot::Mutex::new(None),
166        })
167    }
168
169    /// Replace the lazily-allocated CUDA stream with one supplied by
170    /// the caller (typically `DeviceActor`'s shared timeline). Has no
171    /// effect when the `tensorrt` feature is off.
172    #[cfg(feature = "tensorrt")]
173    pub fn with_stream(self, stream: std::sync::Arc<cudarc::driver::CudaStream>) -> Self {
174        if let Some(state) = self.state.lock().as_mut() {
175            state.stream = stream;
176        }
177        self
178    }
179
180    /// Submit a pre-built [`ExecutionBindings`] payload. Callers that
181    /// own the tokenisation / device-pointer staging path use this to
182    /// drive the engine directly — `ModelRunner::execute` is the chat-
183    /// style adapter and is intentionally narrower.
184    ///
185    /// Without `tensorrt-link` this returns
186    /// [`InferenceError::Internal`] because the upstream
187    /// `TrtRuntime::new` has nothing to link against; the call shape
188    /// is identical with and without the link feature so callers
189    /// don't need to gate at the call site.
190    #[cfg(feature = "tensorrt")]
191    pub async fn enqueue(&mut self, bindings: ExecutionBindings) -> InferenceResult<()> {
192        self.ensure_state()?;
193        let guard = self.state.lock();
194        let Some(state) = guard.as_ref() else {
195            return Err(InferenceError::Internal(
196                "tensorrt: state was cleared between ensure_state and lock — \
197                 retry the enqueue"
198                    .into(),
199            ));
200        };
201        let _engine = state.engine.clone();
202        let _stream = state.stream.clone();
203        let _ = bindings;
204        Err(InferenceError::Internal(
205            "tensorrt: enqueue requires the `tensorrt-link` feature \
206             (libnvinfer must be installed and the link probe must \
207             succeed in atomr-accel-tensorrt's build.rs)"
208                .into(),
209        ))
210    }
211
212    #[cfg(feature = "tensorrt")]
213    fn ensure_state(&self) -> InferenceResult<()> {
214        let mut guard = self.state.lock();
215        if guard.is_some() {
216            return Ok(());
217        }
218        let cuda_ctx = cudarc::driver::CudaContext::new(self.config.device_id as usize).map_err(|e| {
219            InferenceError::Internal(format!(
220                "tensorrt: failed to create CUDA context on device {}: {e}",
221                self.config.device_id
222            ))
223        })?;
224        let stream = cuda_ctx.default_stream();
225        let runtime = atomr_accel_tensorrt::runtime::TrtRuntime::new().map_err(map_trt_err)?;
226        let engine = runtime.deserialize(&self.plan).map_err(map_trt_err)?;
227        let engine = std::sync::Arc::new(engine);
228        *guard = Some(TrtState { engine, stream });
229        Ok(())
230    }
231}
232
/// Translate an upstream `TrtError` into this crate's `InferenceError`.
///
/// `InvalidArg` becomes `InferenceError::BadRequest`; every other
/// variant becomes `InferenceError::Internal` with a `tensorrt`-prefixed
/// message.
#[cfg(feature = "tensorrt")]
fn map_trt_err(err: atomr_accel_tensorrt::error::TrtError) -> InferenceError {
    use atomr_accel_tensorrt::error::TrtError as Upstream;

    let internal_text = match err {
        // The only variant that is not reported as an internal error.
        Upstream::InvalidArg(detail) => {
            return InferenceError::BadRequest {
                message: format!("tensorrt: invalid argument: {detail}"),
            };
        }
        Upstream::NotLinked(detail) => {
            format!("tensorrt not linked: {detail} (rebuild with --features tensorrt-link)")
        }
        Upstream::NullEngine => String::from("tensorrt: engine pointer was null"),
        Upstream::Build(detail)
        | Upstream::Runtime(detail)
        | Upstream::Execution(detail)
        | Upstream::Onnx(detail)
        | Upstream::Calibration(detail)
        | Upstream::Plugin(detail)
        | Upstream::Refit(detail) => format!("tensorrt: {detail}"),
    };
    InferenceError::Internal(internal_text)
}
253
254#[async_trait]
255impl ModelRunner for TensorRtRunner {
256    #[cfg_attr(
257        feature = "tensorrt",
258        tracing::instrument(skip(self, _batch), fields(plan = %self.config.plan_path.display()))
259    )]
260    async fn execute(&mut self, _batch: ExecuteBatch) -> InferenceResult<RunHandle> {
261        #[cfg(not(feature = "tensorrt"))]
262        {
263            Err(InferenceError::Internal(
264                "tensorrt feature disabled at build time — rebuild with --features tensorrt".into(),
265            ))
266        }
267        #[cfg(feature = "tensorrt")]
268        {
269            // ExecuteBatch is chat-shaped (Vec<Message> + sampling).
270            // TensorRT engines consume raw tensors, so an LLM-aware
271            // adapter has to tokenise and stage device pointers via
272            // ExecutionBindings before this runner can satisfy a chat
273            // request. Surfacing the gap as a typed error rather than
274            // a panic keeps callers honest.
275            self.ensure_state()?;
276            Err(InferenceError::Internal(
277                "tensorrt runner: chat-style execute requires a tokeniser layer; \
278                 callers staging tensors directly should invoke `enqueue` with \
279                 a prepared ExecutionBindings"
280                    .into(),
281            ))
282        }
283    }
284
285    async fn rebuild_session(&mut self, cause: SessionRebuildCause) -> InferenceResult<()> {
286        #[cfg(feature = "tensorrt")]
287        {
288            // Re-read the plan from disk on a real reload; otherwise
289            // just drop the cached engine/context so the next execute
290            // rebuilds.
291            if matches!(
292                cause,
293                SessionRebuildCause::CudaContextPoisoned | SessionRebuildCause::Manual
294            ) {
295                let plan = std::fs::read(&self.config.plan_path).map_err(|e| {
296                    InferenceError::Internal(format!(
297                        "tensorrt: failed to re-read plan from {}: {e}",
298                        self.config.plan_path.display()
299                    ))
300                })?;
301                self.plan = plan;
302            }
303            *self.state.lock() = None;
304        }
305        let _ = cause;
306        Ok(())
307    }
308
309    fn runtime_kind(&self) -> RuntimeKind {
310        RuntimeKind::TensorRt
311    }
312    fn transport_kind(&self) -> TransportKind {
313        TransportKind::LocalGpu
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Config for `plan_path` with batch size 1 on device 0.
    fn config_for(plan_path: std::path::PathBuf, precision: TrtPrecision) -> TensorRtConfig {
        TensorRtConfig {
            plan_path,
            max_batch_size: 1,
            precision,
            device_id: 0,
        }
    }

    /// Temp file holding a zero-length plan. The caller must keep the
    /// returned handle alive for as long as the path is needed.
    fn empty_plan_file() -> tempfile::NamedTempFile {
        let file = tempfile::NamedTempFile::new().expect("tempfile");
        std::fs::write(file.path(), b"").expect("write empty plan");
        file
    }

    /// Constructing a runner from a non-existent plan path must fail
    /// with `InferenceError::Internal`.
    #[test]
    fn missing_plan_returns_internal_error() {
        let missing = std::path::PathBuf::from("/this/path/does/not/exist.plan");
        let outcome = TensorRtRunner::new(config_for(missing, TrtPrecision::default()));
        assert!(matches!(outcome, Err(InferenceError::Internal(_))));
    }

    /// An empty plan file is enough to construct a runner, and the
    /// runner reports the TensorRT / local-GPU kinds.
    #[test]
    fn empty_plan_loads_into_runner() {
        let plan_file = empty_plan_file();
        let cfg = config_for(plan_file.path().to_path_buf(), TrtPrecision::Fp16);
        let runner = TensorRtRunner::new(cfg).expect("loads empty plan");
        assert_eq!(runner.runtime_kind(), RuntimeKind::TensorRt);
        assert_eq!(runner.transport_kind(), TransportKind::LocalGpu);
    }

    /// With the `tensorrt` feature off, `execute` must return the
    /// typed `Internal` error.
    #[cfg(not(feature = "tensorrt"))]
    #[tokio::test]
    async fn execute_without_feature_returns_internal_error() {
        use atomr_infer_core::batch::SamplingParams;

        let plan_file = empty_plan_file();
        let cfg = config_for(plan_file.path().to_path_buf(), TrtPrecision::default());
        let mut runner = TensorRtRunner::new(cfg).expect("loads empty plan");

        let request = ExecuteBatch {
            request_id: "test".into(),
            model: "test".into(),
            messages: vec![],
            sampling: SamplingParams::default(),
            stream: false,
            estimated_tokens: 1,
        };
        let outcome = runner.execute(request).await;
        assert!(matches!(outcome, Err(InferenceError::Internal(_))));
    }
}