atomr_infer_runtime_tensorrt/lib.rs

#![forbid(unsafe_code)]
#![deny(rust_2018_idioms)]

//! TensorRT-backed implementation of the [`ModelRunner`] trait for the atomr inference stack.
//!
//! The runner reads a serialized TensorRT engine plan from disk up front and lazily creates
//! the CUDA context, stream, runtime and engine on first use. Everything that touches
//! TensorRT or CUDA is gated behind the `tensorrt` cargo feature; without it the plan bytes
//! are still loaded, but `execute` returns an [`InferenceError::Internal`] explaining that
//! the feature is disabled.
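//!
//! A minimal usage sketch (the plan path is hypothetical, and this example is not taken from
//! the crate's own docs):
//!
//! ```ignore
//! use atomr_infer_runtime_tensorrt::{TensorRtConfig, TensorRtRunner, TrtPrecision};
//!
//! let config = TensorRtConfig {
//!     plan_path: "/models/resnet50.plan".into(),
//!     max_batch_size: 8,
//!     precision: TrtPrecision::Fp16,
//!     device_id: 0,
//! };
//! let runner = TensorRtRunner::new(config).expect("plan file should be readable");
//! ```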

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use atomr_infer_core::batch::ExecuteBatch;
use atomr_infer_core::error::{InferenceError, InferenceResult};
use atomr_infer_core::runner::{ModelRunner, RunHandle, SessionRebuildCause};
use atomr_infer_core::runtime::{RuntimeKind, TransportKind};

#[cfg(feature = "tensorrt")]
pub use atomr_accel_tensorrt::{
    builder::Precision,
    runtime::{ExecutionBindings, TensorShape},
};

/// Configuration for a [`TensorRtRunner`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorRtConfig {
    /// Path to the serialized TensorRT engine plan read at construction time.
    pub plan_path: std::path::PathBuf,
    /// Maximum batch size the runner is expected to serve. Defaults to 1.
    #[serde(default = "default_max_batch_size")]
    pub max_batch_size: u32,
    /// Requested execution precision. Defaults to [`TrtPrecision::Fp32`].
    #[serde(default)]
    pub precision: TrtPrecision,
    /// CUDA device ordinal on which the context is created. Defaults to 0.
    #[serde(default)]
    pub device_id: u32,
}

fn default_max_batch_size() -> u32 {
    1
}

/// Numerical precision requested from TensorRT.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TrtPrecision {
    /// 32-bit floating point (the default).
    #[default]
    Fp32,
    /// 16-bit floating point.
    Fp16,
    /// bfloat16.
    Bf16,
    /// 8-bit integer quantisation.
    Int8,
    /// 8-bit floating point.
    Fp8,
    /// Let TensorRT choose the fastest precision it supports.
    Best,
}

#[cfg(feature = "tensorrt")]
impl From<TrtPrecision> for Precision {
    fn from(p: TrtPrecision) -> Self {
        match p {
            TrtPrecision::Fp32 => Precision::Fp32,
            TrtPrecision::Fp16 => Precision::Fp16,
            TrtPrecision::Bf16 => Precision::Bf16,
            TrtPrecision::Int8 => Precision::Int8,
            TrtPrecision::Fp8 => Precision::Fp8,
            TrtPrecision::Best => Precision::Best,
        }
    }
}

/// Lazily-initialised CUDA stream and deserialized TensorRT engine, created by `ensure_state`.
#[cfg(feature = "tensorrt")]
struct TrtState {
    engine: std::sync::Arc<atomr_accel_tensorrt::engine::TrtEngine>,
    stream: std::sync::Arc<cudarc::driver::CudaStream>,
}

/// [`ModelRunner`] implementation backed by a serialized TensorRT engine plan.
pub struct TensorRtRunner {
    #[cfg_attr(not(feature = "tensorrt"), allow(dead_code))]
    config: TensorRtConfig,
    /// Raw bytes of the engine plan, read once at construction time.
    #[cfg_attr(not(feature = "tensorrt"), allow(dead_code))]
    plan: Vec<u8>,
    /// Lazily-created CUDA/TensorRT state; `None` until first use or after a session rebuild.
    #[cfg(feature = "tensorrt")]
    state: parking_lot::Mutex<Option<TrtState>>,
}

impl TensorRtRunner {
    /// Creates a runner from `config`, eagerly reading the engine plan from
    /// `config.plan_path` so that a missing or unreadable plan surfaces as a constructor
    /// error rather than at execution time.
    pub fn new(config: TensorRtConfig) -> InferenceResult<Self> {
        let plan = std::fs::read(&config.plan_path).map_err(|e| {
            InferenceError::Internal(format!(
                "tensorrt: failed to read plan from {}: {e}",
                config.plan_path.display()
            ))
        })?;
        Ok(Self {
            config,
            plan,
            #[cfg(feature = "tensorrt")]
            state: parking_lot::Mutex::new(None),
        })
    }

    /// Replaces the CUDA stream used for execution.
    ///
    /// This only takes effect if the lazy state has already been created; otherwise the
    /// default stream obtained in `ensure_state` is used.
    #[cfg(feature = "tensorrt")]
    pub fn with_stream(self, stream: std::sync::Arc<cudarc::driver::CudaStream>) -> Self {
        if let Some(state) = self.state.lock().as_mut() {
            state.stream = stream;
        }
        self
    }

    /// Stages prepared device bindings for execution on the runner's CUDA stream.
    ///
    /// The CUDA context, stream and engine are created lazily on the first call. Driving the
    /// engine additionally requires the `tensorrt-link` feature (libnvinfer installed and the
    /// link probe in `atomr-accel-tensorrt`'s build.rs succeeding); otherwise this returns an
    /// [`InferenceError::Internal`] explaining what is missing.
    #[cfg(feature = "tensorrt")]
    pub async fn enqueue(&mut self, bindings: ExecutionBindings) -> InferenceResult<()> {
        self.ensure_state()?;
        let guard = self.state.lock();
        let Some(state) = guard.as_ref() else {
            return Err(InferenceError::Internal(
                "tensorrt: state was cleared between ensure_state and lock — \
                 retry the enqueue"
                    .into(),
            ));
        };
        let _engine = state.engine.clone();
        let _stream = state.stream.clone();
        let _ = bindings;
        Err(InferenceError::Internal(
            "tensorrt: enqueue requires the `tensorrt-link` feature \
             (libnvinfer must be installed and the link probe must \
             succeed in atomr-accel-tensorrt's build.rs)"
                .into(),
        ))
    }

    /// Lazily creates the CUDA context, default stream, TensorRT runtime and deserialized
    /// engine the first time the runner is used.
    #[cfg(feature = "tensorrt")]
    fn ensure_state(&self) -> InferenceResult<()> {
        let mut guard = self.state.lock();
        if guard.is_some() {
            return Ok(());
        }
        let cuda_ctx = cudarc::driver::CudaContext::new(self.config.device_id as usize)
            .map_err(|e| {
                InferenceError::Internal(format!(
                    "tensorrt: failed to create CUDA context on device {}: {e}",
                    self.config.device_id
                ))
            })?;
        let stream = cuda_ctx.default_stream();
        let runtime = atomr_accel_tensorrt::runtime::TrtRuntime::new().map_err(map_trt_err)?;
        let engine = runtime.deserialize(&self.plan).map_err(map_trt_err)?;
        let engine = std::sync::Arc::new(engine);
        *guard = Some(TrtState { engine, stream });
        Ok(())
    }
}

/// Maps accelerator-level `TrtError`s onto the core [`InferenceError`] type.
#[cfg(feature = "tensorrt")]
fn map_trt_err(err: atomr_accel_tensorrt::error::TrtError) -> InferenceError {
    use atomr_accel_tensorrt::error::TrtError;
    match err {
        TrtError::NotLinked(msg) => InferenceError::Internal(format!(
            "tensorrt not linked: {msg} (rebuild with --features tensorrt-link)"
        )),
        TrtError::Build(m)
        | TrtError::Runtime(m)
        | TrtError::Execution(m)
        | TrtError::Onnx(m)
        | TrtError::Calibration(m)
        | TrtError::Plugin(m)
        | TrtError::Refit(m) => InferenceError::Internal(format!("tensorrt: {m}")),
        TrtError::NullEngine => InferenceError::Internal("tensorrt: engine pointer was null".into()),
        TrtError::InvalidArg(m) => InferenceError::BadRequest {
            message: format!("tensorrt: invalid argument: {m}"),
        },
    }
}

#[async_trait]
impl ModelRunner for TensorRtRunner {
    #[cfg_attr(
        feature = "tensorrt",
        tracing::instrument(skip(self, _batch), fields(plan = %self.config.plan_path.display()))
    )]
    async fn execute(&mut self, _batch: ExecuteBatch) -> InferenceResult<RunHandle> {
        #[cfg(not(feature = "tensorrt"))]
        {
            Err(InferenceError::Internal(
                "tensorrt feature disabled at build time — rebuild with --features tensorrt".into(),
            ))
        }
        #[cfg(feature = "tensorrt")]
        {
            self.ensure_state()?;
            Err(InferenceError::Internal(
                "tensorrt runner: chat-style execute requires a tokeniser layer; \
                 callers staging tensors directly should invoke `enqueue` with \
                 a prepared ExecutionBindings"
                    .into(),
            ))
        }
    }

    async fn rebuild_session(&mut self, cause: SessionRebuildCause) -> InferenceResult<()> {
        #[cfg(feature = "tensorrt")]
        {
            // Re-read the plan from disk for explicit rebuilds and poisoned CUDA contexts so
            // that a refreshed plan file is picked up before the state is recreated.
            if matches!(
                cause,
                SessionRebuildCause::CudaContextPoisoned | SessionRebuildCause::Manual
            ) {
                let plan = std::fs::read(&self.config.plan_path).map_err(|e| {
                    InferenceError::Internal(format!(
                        "tensorrt: failed to re-read plan from {}: {e}",
                        self.config.plan_path.display()
                    ))
                })?;
                self.plan = plan;
            }
            *self.state.lock() = None;
        }
        let _ = cause;
        Ok(())
    }

    fn runtime_kind(&self) -> RuntimeKind {
        RuntimeKind::TensorRt
    }

    fn transport_kind(&self) -> TransportKind {
        TransportKind::LocalGpu
    }
}

#[cfg(test)]
mod tests {
    use super::*;

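    // Hypothetical example (not part of the original test suite): checks that the
    // feature-gated `From<TrtPrecision> for Precision` mapping preserves the variant.
    #[cfg(feature = "tensorrt")]
    #[test]
    fn precision_converts_to_accel_precision() {
        assert!(matches!(Precision::from(TrtPrecision::Fp16), Precision::Fp16));
        assert!(matches!(Precision::from(TrtPrecision::Best), Precision::Best));
    }
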
    #[test]
    fn missing_plan_returns_internal_error() {
        let cfg = TensorRtConfig {
            plan_path: std::path::PathBuf::from("/this/path/does/not/exist.plan"),
            max_batch_size: 1,
            precision: TrtPrecision::default(),
            device_id: 0,
        };
        let result = TensorRtRunner::new(cfg);
        assert!(matches!(result, Err(InferenceError::Internal(_))));
    }
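
    // Hypothetical example (not part of the original test suite): shows how the
    // `#[serde(default)]` attributes on `TensorRtConfig` fill in `max_batch_size`,
    // `precision` and `device_id` when only `plan_path` is supplied. Assumes `serde_json`
    // is available as a dev-dependency.
    #[test]
    fn config_deserializes_with_serde_defaults() {
        let cfg: TensorRtConfig =
            serde_json::from_str(r#"{ "plan_path": "/tmp/model.plan" }"#).expect("valid config");
        assert_eq!(cfg.max_batch_size, 1);
        assert_eq!(cfg.device_id, 0);
        assert!(matches!(cfg.precision, TrtPrecision::Fp32));
    }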

    #[test]
    fn empty_plan_loads_into_runner() {
        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
        std::fs::write(tmp.path(), b"").expect("write empty plan");
        let cfg = TensorRtConfig {
            plan_path: tmp.path().to_path_buf(),
            max_batch_size: 1,
            precision: TrtPrecision::Fp16,
            device_id: 0,
        };
        let runner = TensorRtRunner::new(cfg).expect("loads empty plan");
        assert_eq!(runner.runtime_kind(), RuntimeKind::TensorRt);
        assert_eq!(runner.transport_kind(), TransportKind::LocalGpu);
    }
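
    // Hypothetical example (not part of the original test suite): the
    // `rename_all = "lowercase"` attribute means precisions round-trip as lowercase
    // strings such as "fp16". Also assumes `serde_json` as a dev-dependency.
    #[test]
    fn precision_uses_lowercase_serde_names() {
        let json = serde_json::to_string(&TrtPrecision::Fp16).expect("serialize");
        assert_eq!(json, "\"fp16\"");
        let parsed: TrtPrecision = serde_json::from_str("\"int8\"").expect("deserialize");
        assert!(matches!(parsed, TrtPrecision::Int8));
    }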

    #[cfg(not(feature = "tensorrt"))]
    #[tokio::test]
    async fn execute_without_feature_returns_internal_error() {
        use atomr_infer_core::batch::SamplingParams;

        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
        std::fs::write(tmp.path(), b"").expect("write empty plan");
        let cfg = TensorRtConfig {
            plan_path: tmp.path().to_path_buf(),
            max_batch_size: 1,
            precision: TrtPrecision::default(),
            device_id: 0,
        };
        let mut runner = TensorRtRunner::new(cfg).expect("loads empty plan");
        let batch = ExecuteBatch {
            request_id: "test".into(),
            model: "test".into(),
            messages: vec![],
            sampling: SamplingParams::default(),
            stream: false,
            estimated_tokens: 1,
        };
        let result = runner.execute(batch).await;
        assert!(matches!(result, Err(InferenceError::Internal(_))));
    }
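
    // Hypothetical example (not part of the original test suite): without the `tensorrt`
    // feature, `rebuild_session` only acknowledges the cause, so it should always succeed.
    #[cfg(not(feature = "tensorrt"))]
    #[tokio::test]
    async fn rebuild_session_without_feature_is_a_noop() {
        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
        let cfg = TensorRtConfig {
            plan_path: tmp.path().to_path_buf(),
            max_batch_size: 1,
            precision: TrtPrecision::default(),
            device_id: 0,
        };
        let mut runner = TensorRtRunner::new(cfg).expect("loads empty plan");
        let result = runner.rebuild_session(SessionRebuildCause::Manual).await;
        assert!(result.is_ok());
    }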
}