atomr_accel_tensorrt/actor.rs
1//! `TrtActor` — sibling of `atomr_accel_cuda::DeviceActor`.
2//!
3//! Lifecycle:
4//! - On `Build` it consumes a network builder (or ONNX bytes when
5//! `tensorrt-onnx` is enabled) plus an [`IBuilderConfig`], drives
6//! `IBuilder::buildSerializedNetwork` and returns an
7//! [`EnginePlan`].
8//! - On `Deserialize` it loads a previously built plan into an
9//! [`TrtEngine`].
10//! - On `CreateContext` it creates a fresh [`ExecutionContext`].
11//! - On `EnqueueOnStream { stream, context, reply }` it submits the
12//! inference on the supplied `Arc<cudarc::driver::CudaStream>` —
13//! the same stream type carried by `DeviceActor` so the two actors
14//! share one CUDA execution timeline.
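//! - On `Refit` it patches engine weights via [`TrtRefitter`].
//! - On `Execute` it creates a fresh execution context, applies the
//!   supplied input shapes, binds every tensor address and calls
//!   `enqueueV3` on the supplied stream.
//! - On `BuildFromOnnx` it parses an ONNX model and returns a serialised
//!   [`EnginePlan`] (requires `tensorrt-onnx`).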
16//!
17//! The actor keeps the `TrtEngine` alive in an `Arc` so multiple
18//! `ExecutionContext`s can share it.
19
20use std::sync::Arc;
21
22use tokio::sync::oneshot;
23
24use crate::builder::IBuilderConfig;
25use crate::engine::{EnginePlan, TrtEngine};
26use crate::error::TrtError;
27use crate::runtime::{ExecutionBindings, ExecutionContext};
28
/// Network description carried by `TrtMsg::Build`.
///
/// The TensorRT builder API has many entry points; for now two inputs are
/// accepted: a serialised ONNX model (under `tensorrt-onnx`) or a
/// precompiled TensorRT plan to import as-is.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NetworkSource {
    /// Raw ONNX bytes. Requires the `tensorrt-onnx` feature.
    Onnx(Vec<u8>),
    /// A previously serialised TensorRT plan; deserialised rather than rebuilt.
    SerializedPlan(Vec<u8>),
}
39
40/// Descriptor of a single weight blob to push into the engine via
41/// the refitter. The pointer / device pointer is **not** held inside
42/// the message; instead callers pass a host-side blob (refitter
43/// stages it). Future variants can add a `DevicePtr` tag if direct
44/// device-to-device refit is desired.
45pub struct RefitWeights {
46 pub name: String,
47 pub bytes: Vec<u8>,
48 pub dtype: crate::sys::DataType,
49}
50
51/// Reply types for each `TrtMsg` variant. Each is a `oneshot::Sender`
52/// so the actor never blocks on IO.
53pub type BuildReply = oneshot::Sender<Result<EnginePlan, TrtError>>;
54pub type DeserializeReply = oneshot::Sender<Result<Arc<TrtEngine>, TrtError>>;
55pub type CreateContextReply = oneshot::Sender<Result<ExecutionContext, TrtError>>;
56pub type EnqueueReply = oneshot::Sender<Result<(), TrtError>>;
57pub type RefitReply = oneshot::Sender<Result<(), TrtError>>;
58pub type ExecuteReply = oneshot::Sender<Result<(), TrtError>>;
59pub type BuildFromOnnxReply = oneshot::Sender<Result<EnginePlan, TrtError>>;
60
61/// Public message surface for `TrtActor`.
62///
63/// The variant `EnqueueOnStream` accepts the `Arc<CudaStream>` from
64/// `atomr-accel-cuda::DeviceActor` so the TensorRT runtime shares
65/// the device's stream timeline (no cross-stream synchronisation,
66/// no extra event hops).
67pub enum TrtMsg {
68 /// Build a TensorRT engine from a network source + config.
69 /// Returns the serialised plan on success.
70 Build {
71 source: NetworkSource,
72 config: Box<IBuilderConfig>,
73 reply: BuildReply,
74 },
75
76 /// Deserialise a plan blob into a shared engine handle.
77 Deserialize {
78 plan: EnginePlan,
79 reply: DeserializeReply,
80 },
81
82 /// Create a fresh `IExecutionContext` from an existing engine.
83 /// Returns the new context (caller owns it).
84 CreateContext {
85 engine: Arc<TrtEngine>,
86 reply: CreateContextReply,
87 },
88
89 /// Submit `enqueueV3` on the supplied CUDA stream. The actor
90 /// returns immediately after submission; real GPU completion is
91 /// observed by `atomr-accel-cuda`'s completion strategy on the
92 /// shared stream.
93 EnqueueOnStream {
94 stream: Arc<cudarc::driver::CudaStream>,
95 context: ExecutionContext,
96 bindings: ExecutionBindings,
97 reply: EnqueueReply,
98 },
99
100 /// Refit a built engine in-place with new weights. Requires the
101 /// engine to have been built with `RefitPolicy::OnDemand` or
102 /// `WeightsStreaming`.
103 Refit {
104 engine: Arc<TrtEngine>,
105 weights: Vec<RefitWeights>,
106 reply: RefitReply,
107 },
108
109 /// Phase 4.5++ — Run inference on a previously-loaded engine.
110 /// `bindings` is `(tensor_name, CUdeviceptr)` for every I/O
111 /// tensor on the engine; `stream` is the `Arc<CudaStream>` to
112 /// `enqueueV3` against (typically the device's primary stream
113 /// from `DeviceMsg::SnapshotStream`).
114 ///
115 /// The handler creates a fresh `IExecutionContext`, binds every
116 /// tensor address, then calls `enqueueV3`. Returns `Ok(())` on
117 /// successful submission (kernel still running on the GPU);
118 /// real completion is observed by `atomr-accel-cuda`'s
119 /// completion strategy on the shared stream.
120 ///
121 /// On builds without `tensorrt-link` the variant compiles but
122 /// the handler returns `TrtError::NotLinked`.
123 Execute {
124 engine: Arc<TrtEngine>,
125 bindings: Vec<(String, u64)>,
126 input_shapes: Vec<(String, Vec<i32>)>,
127 stream: Arc<cudarc::driver::CudaStream>,
128 reply: ExecuteReply,
129 },
130
131 /// Phase 4.5++ — Parse an ONNX model and build a serialised
132 /// engine plan. Gated on the upstream `tensorrt-onnx` feature
133 /// (and transitively on `tensorrt-link`). Without those the
134 /// handler returns `TrtError::NotLinked`.
135 BuildFromOnnx {
136 onnx_bytes: Vec<u8>,
137 config: Box<IBuilderConfig>,
138 reply: BuildFromOnnxReply,
139 },
140}
141
142/// `TrtActor` — owns nothing across messages besides the FFI
143/// runtime/builder handles, all engines/contexts ride the messages.
144///
145/// The actor itself is intentionally minimal: most of the heavy
146/// state lives in `Arc<TrtEngine>` values that the caller threads
147/// through. This mirrors `DeviceActor`'s design where per-context
148/// state lives in the `ContextActor` but engines live with the
149/// caller.
150pub struct TrtActor {
151 /// Cached runtime; lazily created on first `Deserialize`. Held
152 /// behind a `parking_lot::Mutex` because the actor mailbox
153 /// already serialises but interior mutability avoids a redundant
154 /// `&mut self` thread through every method.
155 runtime: parking_lot::Mutex<Option<crate::runtime::TrtRuntime>>,
156}
157
158impl TrtActor {
159 pub fn new() -> Self {
160 Self {
161 runtime: parking_lot::Mutex::new(None),
162 }
163 }
164
165 /// Get-or-create the cached runtime. Without `tensorrt-link` the
166 /// inner constructor returns `NotLinked`.
167 pub fn ensure_runtime(&self) -> Result<(), TrtError> {
168 let mut guard = self.runtime.lock();
169 if guard.is_none() {
170 *guard = Some(crate::runtime::TrtRuntime::new()?);
171 }
172 Ok(())
173 }
174
175 /// Phase 4.5++ — synchronous helper that drives the
176 /// `TrtMsg::Execute` semantics (creates an `IExecutionContext`,
177 /// binds tensor addresses, calls `enqueueV3`).
178 ///
179 /// Without `tensorrt-link` this returns `TrtError::NotLinked`
180 /// without ever touching libnvinfer. With the feature on, the
181 /// actor performs the full FFI sequence under the supplied
182 /// `Arc<CudaStream>`. The function returns once the launch
183 /// has been submitted — real GPU completion is observed
184 /// downstream (the typical caller pairs this with an
185 /// `atomr-accel-cuda` completion strategy on the same stream).
186 pub fn execute(
187 &self,
188 engine: &Arc<TrtEngine>,
189 bindings: &[(String, u64)],
190 input_shapes: &[(String, Vec<i32>)],
191 _stream: &Arc<cudarc::driver::CudaStream>,
192 ) -> Result<(), TrtError> {
193 #[cfg(feature = "tensorrt-link")]
194 {
195 use std::ffi::CString;
196 unsafe {
197 let ctx_ptr = crate::sys::atomr_trt_engine_create_execution_context(engine.raw());
198 if ctx_ptr.is_null() {
199 return Err(TrtError::Execution(
200 "createExecutionContext returned null".into(),
201 ));
202 }
203 // Apply input shapes first (TensorRT requires shapes
204 // before set_tensor_address on dynamic tensors).
205 for (name, dims) in input_shapes {
206 if dims.len() > 8 {
207 crate::sys::atomr_trt_context_destroy(ctx_ptr);
208 return Err(TrtError::InvalidArg(format!(
209 "tensor {name:?}: TensorRT supports at most 8 dims (got {})",
210 dims.len()
211 )));
212 }
213 let cname = match CString::new(name.clone()) {
214 Ok(c) => c,
215 Err(e) => {
216 crate::sys::atomr_trt_context_destroy(ctx_ptr);
217 return Err(TrtError::InvalidArg(format!(
218 "tensor name contains NUL: {e}"
219 )));
220 }
221 };
222 let mut d = [0i32; 8];
223 for (i, v) in dims.iter().enumerate() {
224 d[i] = *v;
225 }
226 let dims_struct = crate::sys::Dims {
227 nb_dims: dims.len() as std::os::raw::c_int,
228 d,
229 };
230 let rc = crate::sys::atomr_trt_context_set_input_shape(
231 ctx_ptr,
232 cname.as_ptr(),
233 &dims_struct as *const crate::sys::Dims,
234 );
235 if rc != 0 {
236 crate::sys::atomr_trt_context_destroy(ctx_ptr);
237 return Err(TrtError::Execution(format!(
238 "set_input_shape({name}) returned {rc}"
239 )));
240 }
241 }
242 // Bind every tensor address.
243 for (name, addr) in bindings {
244 let cname = match CString::new(name.clone()) {
245 Ok(c) => c,
246 Err(e) => {
247 crate::sys::atomr_trt_context_destroy(ctx_ptr);
248 return Err(TrtError::InvalidArg(format!(
249 "tensor name contains NUL: {e}"
250 )));
251 }
252 };
253 let rc = crate::sys::atomr_trt_context_set_tensor_address(
254 ctx_ptr,
255 cname.as_ptr(),
256 *addr as *mut std::os::raw::c_void,
257 );
258 if rc != 0 {
259 crate::sys::atomr_trt_context_destroy(ctx_ptr);
260 return Err(TrtError::Execution(format!(
261 "set_tensor_address({name}) returned {rc}"
262 )));
263 }
264 }
265 // Cudarc's `CudaStream` exposes the raw stream via
266 // `cu_stream()` — but the field is `pub(crate)`. We
267 // pass through cudarc's `DevicePtr`-style accessor by
268 // using `cuStream` symbol from `cudarc::driver::sys`
269 // — which is what other call sites in atomr-accel-cuda
270 // do. The shim takes `*mut c_void` (any CUstream).
271 let stream_raw = _stream.cu_stream() as *mut std::os::raw::c_void;
272 let rc = crate::sys::atomr_trt_context_enqueue_v3(ctx_ptr, stream_raw);
273 let result = if rc != 0 {
274 Err(TrtError::Execution(format!("enqueueV3 returned {rc}")))
275 } else {
276 Ok(())
277 };
278 crate::sys::atomr_trt_context_destroy(ctx_ptr);
279 result
280 }
281 }
282 #[cfg(not(feature = "tensorrt-link"))]
283 {
284 let _ = (engine, bindings, input_shapes, _stream);
285 Err(TrtError::NotLinked(
286 "TrtActor::execute requires the `tensorrt-link` feature",
287 ))
288 }
289 }
290
291 /// Phase 4.5++ — synchronous helper that drives the
292 /// `TrtMsg::BuildFromOnnx` semantics. Parses an ONNX model and
293 /// returns a serialised plan blob ready for `TrtRuntime::deserialize`.
294 /// Gated on `tensorrt-onnx` (transitively `tensorrt-link`).
295 pub fn build_from_onnx(
296 &self,
297 _onnx_bytes: &[u8],
298 _config: &IBuilderConfig,
299 ) -> Result<EnginePlan, TrtError> {
300 #[cfg(all(feature = "tensorrt-link", feature = "tensorrt-onnx"))]
301 {
302 use crate::builder::BuilderFlags;
303 unsafe {
304 let builder = crate::sys::atomr_trt_builder_create(0);
305 if builder.is_null() {
306 return Err(TrtError::Build("builder_create returned null".into()));
307 }
308 // EXPLICIT_BATCH (1 << 0) is required for ONNX import.
309 let network = crate::sys::atomr_trt_builder_create_network(builder, 1u32 << 0);
310 if network.is_null() {
311 crate::sys::atomr_trt_builder_destroy(builder);
312 return Err(TrtError::Build("create_network returned null".into()));
313 }
314 let parser = crate::sys::atomr_trt_onnx_parser_create(network, 0);
315 if parser.is_null() {
316 crate::sys::atomr_trt_builder_destroy(builder);
317 return Err(TrtError::Onnx("onnx_parser_create returned null".into()));
318 }
319 let parse_rc = crate::sys::atomr_trt_onnx_parser_parse(
320 parser,
321 _onnx_bytes.as_ptr(),
322 _onnx_bytes.len(),
323 std::ptr::null(),
324 );
325 if parse_rc == 0 {
326 let nerr = crate::sys::atomr_trt_onnx_parser_num_errors(parser);
327 crate::sys::atomr_trt_onnx_parser_destroy(parser);
328 crate::sys::atomr_trt_builder_destroy(builder);
329 return Err(TrtError::Onnx(format!(
330 "onnx parse failed (rc={parse_rc}, errors={nerr})"
331 )));
332 }
333
334 let cfg_ptr = crate::sys::atomr_trt_builder_create_config(builder);
335 if cfg_ptr.is_null() {
336 crate::sys::atomr_trt_onnx_parser_destroy(parser);
337 crate::sys::atomr_trt_builder_destroy(builder);
338 return Err(TrtError::Build(
339 "builder_create_config returned null".into(),
340 ));
341 }
342 // Replay caller-requested flags onto the C++ config.
343 let flags = _config.effective_flags();
344 for flag in [
345 (BuilderFlags::FP16, crate::sys::BuilderFlag::kFP16 as u32),
346 (BuilderFlags::INT8, crate::sys::BuilderFlag::kINT8 as u32),
347 (BuilderFlags::TF32, crate::sys::BuilderFlag::kTF32 as u32),
348 (BuilderFlags::BF16, crate::sys::BuilderFlag::kBF16 as u32),
349 (BuilderFlags::FP8, crate::sys::BuilderFlag::kFP8 as u32),
350 (BuilderFlags::REFIT, crate::sys::BuilderFlag::kREFIT as u32),
351 (
352 BuilderFlags::SPARSE_WEIGHTS,
353 crate::sys::BuilderFlag::kSPARSE_WEIGHTS as u32,
354 ),
355 (
356 BuilderFlags::STRIP_PLAN,
357 crate::sys::BuilderFlag::kSTRIP_PLAN as u32,
358 ),
359 ] {
360 if flags.contains(flag.0) {
361 crate::sys::atomr_trt_config_set_flag(cfg_ptr, flag.1, 1);
362 }
363 }
364 if _config.workspace_bytes > 0 {
365 crate::sys::atomr_trt_config_set_memory_pool_limit(
366 cfg_ptr,
367 0, // kWORKSPACE
368 _config.workspace_bytes,
369 );
370 }
371
372 let host_mem =
373 crate::sys::atomr_trt_builder_build_serialized(builder, network, cfg_ptr);
374 let cleanup = || {
375 crate::sys::atomr_trt_config_destroy(cfg_ptr);
376 crate::sys::atomr_trt_onnx_parser_destroy(parser);
377 crate::sys::atomr_trt_builder_destroy(builder);
378 };
379 if host_mem.is_null() {
380 cleanup();
381 return Err(TrtError::Build(
382 "buildSerializedNetwork returned null".into(),
383 ));
384 }
385 let data_ptr = crate::sys::atomr_trt_host_memory_data(host_mem);
386 let data_len = crate::sys::atomr_trt_host_memory_size(host_mem);
387 let bytes = if data_ptr.is_null() || data_len == 0 {
388 Vec::new()
389 } else {
390 std::slice::from_raw_parts(data_ptr, data_len).to_vec()
391 };
392 crate::sys::atomr_trt_host_memory_destroy(host_mem);
393 cleanup();
394 if bytes.is_empty() {
395 return Err(TrtError::Build("serialised plan was empty".into()));
396 }
397 Ok(EnginePlan::new(bytes))
398 }
399 }
400 #[cfg(not(all(feature = "tensorrt-link", feature = "tensorrt-onnx")))]
401 {
402 Err(TrtError::NotLinked(
403 "TrtActor::build_from_onnx requires the `tensorrt-link` + `tensorrt-onnx` features",
404 ))
405 }
406 }
407}
408
409impl Default for TrtActor {
410 fn default() -> Self {
411 Self::new()
412 }
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::Precision;

    #[test]
    fn trt_msg_constructs() {
        // Construct every variant that can be built without a CUDA stream
        // or execution context. `EnqueueOnStream` and `Execute` need an
        // `Arc<CudaStream>`, which this unit test does not create; they are
        // left to GPU integration tests.
        let (b_tx, _b_rx) = oneshot::channel();
        let _build = TrtMsg::Build {
            source: NetworkSource::SerializedPlan(vec![1, 2, 3]),
            config: Box::new(IBuilderConfig::new().with_precision(Precision::Fp16)),
            reply: b_tx,
        };

        let (d_tx, _d_rx) = oneshot::channel();
        let _deser = TrtMsg::Deserialize {
            plan: EnginePlan::new(vec![0xAA; 8]),
            reply: d_tx,
        };

        let engine = Arc::new(TrtEngine::for_test());
        let (c_tx, _c_rx) = oneshot::channel();
        let _ctx = TrtMsg::CreateContext {
            engine: engine.clone(),
            reply: c_tx,
        };

        let (r_tx, _r_rx) = oneshot::channel();
        let _refit = TrtMsg::Refit {
            engine: engine.clone(),
            weights: vec![RefitWeights {
                name: "fc.weight".into(),
                bytes: vec![0; 16],
                dtype: crate::sys::DataType::kHALF,
            }],
            reply: r_tx,
        };

        let (o_tx, _o_rx) = oneshot::channel();
        let _onnx = TrtMsg::BuildFromOnnx {
            onnx_bytes: vec![0x08, 0x01],
            config: Box::new(IBuilderConfig::new()),
            reply: o_tx,
        };

        // The actor itself must be `Send` so it can live inside an
        // `atomr_core::actor::Actor`.
        fn assert_send<T: Send>() {}
        assert_send::<TrtActor>();
    }

    #[test]
    fn actor_runtime_lazy_init() {
        let actor = TrtActor::new();
        // Without the link feature this should error cleanly, never
        // panic.
        #[cfg(not(feature = "tensorrt-link"))]
        {
            let r = actor.ensure_runtime();
            assert!(matches!(r, Err(TrtError::NotLinked(_))));
        }
        #[cfg(feature = "tensorrt-link")]
        {
            // Real link path is exercised by integration tests with a
            // GPU host.
            let _ = actor;
        }
    }

    #[cfg(not(all(feature = "tensorrt-link", feature = "tensorrt-onnx")))]
    #[test]
    fn build_from_onnx_reports_not_linked() {
        // Without both features the helper must fail cleanly and never
        // touch libnvinfer.
        let actor = TrtActor::new();
        let r = actor.build_from_onnx(&[], &IBuilderConfig::new());
        assert!(matches!(r, Err(TrtError::NotLinked(_))));
    }
}