Skip to main content

atomr_accel_tensorrt/
actor.rs

1//! `TrtActor` — sibling of `atomr_accel_cuda::DeviceActor`.
2//!
3//! Lifecycle:
//! - On `Build` it consumes a [`NetworkSource`] (ONNX bytes when
//!   `tensorrt-onnx` is enabled, or an already-serialised plan) plus
//!   an [`IBuilderConfig`], drives `IBuilder::buildSerializedNetwork`
//!   and returns an [`EnginePlan`].
8//! - On `Deserialize` it loads a previously built plan into an
9//!   [`TrtEngine`].
10//! - On `CreateContext` it creates a fresh [`ExecutionContext`].
//! - On `EnqueueOnStream { stream, context, bindings, reply }` it
//!   submits the inference on the supplied
//!   `Arc<cudarc::driver::CudaStream>` — the same stream type carried
//!   by `DeviceActor`, so the two actors share one CUDA execution
//!   timeline.
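//! - On `Refit` it patches engine weights via [`TrtRefitter`].
//! - On `Execute` it creates a fresh execution context, binds the
//!   supplied tensor addresses and calls `enqueueV3` on the supplied
//!   stream.
//! - On `BuildFromOnnx` it parses ONNX bytes and returns a serialised
//!   [`EnginePlan`].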
16//!
//! Engines are passed around as `Arc<TrtEngine>` so multiple
//! `ExecutionContext`s can share one; the actor itself only caches
//! the runtime.
19
20use std::sync::Arc;
21
22use tokio::sync::oneshot;
23
24use crate::builder::IBuilderConfig;
25use crate::engine::{EnginePlan, TrtEngine};
26use crate::error::TrtError;
27use crate::runtime::{ExecutionBindings, ExecutionContext};
28
/// Where the network for `TrtMsg::Build` comes from.
///
/// The builder API has many entry points; for now only two are
/// accepted: a serialised ONNX model (under `tensorrt-onnx`) or a
/// previously serialised TensorRT plan to import.
///
/// Both variants own plain bytes, so values can be cloned and compared
/// structurally (handy in tests and when de-duplicating requests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NetworkSource {
    /// Raw ONNX bytes. Requires the `tensorrt-onnx` feature.
    Onnx(Vec<u8>),
    /// A previously serialised TensorRT plan; just deserialise.
    SerializedPlan(Vec<u8>),
}
39
40/// Descriptor of a single weight blob to push into the engine via
41/// the refitter. The pointer / device pointer is **not** held inside
42/// the message; instead callers pass a host-side blob (refitter
43/// stages it). Future variants can add a `DevicePtr` tag if direct
44/// device-to-device refit is desired.
45pub struct RefitWeights {
46    pub name: String,
47    pub bytes: Vec<u8>,
48    pub dtype: crate::sys::DataType,
49}
50
51/// Reply types for each `TrtMsg` variant. Each is a `oneshot::Sender`
52/// so the actor never blocks on IO.
53pub type BuildReply = oneshot::Sender<Result<EnginePlan, TrtError>>;
54pub type DeserializeReply = oneshot::Sender<Result<Arc<TrtEngine>, TrtError>>;
55pub type CreateContextReply = oneshot::Sender<Result<ExecutionContext, TrtError>>;
56pub type EnqueueReply = oneshot::Sender<Result<(), TrtError>>;
57pub type RefitReply = oneshot::Sender<Result<(), TrtError>>;
58pub type ExecuteReply = oneshot::Sender<Result<(), TrtError>>;
59pub type BuildFromOnnxReply = oneshot::Sender<Result<EnginePlan, TrtError>>;
60
61/// Public message surface for `TrtActor`.
62///
63/// The variant `EnqueueOnStream` accepts the `Arc<CudaStream>` from
64/// `atomr-accel-cuda::DeviceActor` so the TensorRT runtime shares
65/// the device's stream timeline (no cross-stream synchronisation,
66/// no extra event hops).
67pub enum TrtMsg {
68    /// Build a TensorRT engine from a network source + config.
69    /// Returns the serialised plan on success.
70    Build {
71        source: NetworkSource,
72        config: Box<IBuilderConfig>,
73        reply: BuildReply,
74    },
75
76    /// Deserialise a plan blob into a shared engine handle.
77    Deserialize {
78        plan: EnginePlan,
79        reply: DeserializeReply,
80    },
81
82    /// Create a fresh `IExecutionContext` from an existing engine.
83    /// Returns the new context (caller owns it).
84    CreateContext {
85        engine: Arc<TrtEngine>,
86        reply: CreateContextReply,
87    },
88
89    /// Submit `enqueueV3` on the supplied CUDA stream. The actor
90    /// returns immediately after submission; real GPU completion is
91    /// observed by `atomr-accel-cuda`'s completion strategy on the
92    /// shared stream.
93    EnqueueOnStream {
94        stream: Arc<cudarc::driver::CudaStream>,
95        context: ExecutionContext,
96        bindings: ExecutionBindings,
97        reply: EnqueueReply,
98    },
99
100    /// Refit a built engine in-place with new weights. Requires the
101    /// engine to have been built with `RefitPolicy::OnDemand` or
102    /// `WeightsStreaming`.
103    Refit {
104        engine: Arc<TrtEngine>,
105        weights: Vec<RefitWeights>,
106        reply: RefitReply,
107    },
108
109    /// Phase 4.5++ — Run inference on a previously-loaded engine.
110    /// `bindings` is `(tensor_name, CUdeviceptr)` for every I/O
111    /// tensor on the engine; `stream` is the `Arc<CudaStream>` to
112    /// `enqueueV3` against (typically the device's primary stream
113    /// from `DeviceMsg::SnapshotStream`).
114    ///
115    /// The handler creates a fresh `IExecutionContext`, binds every
116    /// tensor address, then calls `enqueueV3`. Returns `Ok(())` on
117    /// successful submission (kernel still running on the GPU);
118    /// real completion is observed by `atomr-accel-cuda`'s
119    /// completion strategy on the shared stream.
120    ///
121    /// On builds without `tensorrt-link` the variant compiles but
122    /// the handler returns `TrtError::NotLinked`.
123    Execute {
124        engine: Arc<TrtEngine>,
125        bindings: Vec<(String, u64)>,
126        input_shapes: Vec<(String, Vec<i32>)>,
127        stream: Arc<cudarc::driver::CudaStream>,
128        reply: ExecuteReply,
129    },
130
131    /// Phase 4.5++ — Parse an ONNX model and build a serialised
132    /// engine plan. Gated on the upstream `tensorrt-onnx` feature
133    /// (and transitively on `tensorrt-link`). Without those the
134    /// handler returns `TrtError::NotLinked`.
135    BuildFromOnnx {
136        onnx_bytes: Vec<u8>,
137        config: Box<IBuilderConfig>,
138        reply: BuildFromOnnxReply,
139    },
140}
141
142/// `TrtActor` — owns nothing across messages besides the FFI
143/// runtime/builder handles, all engines/contexts ride the messages.
144///
145/// The actor itself is intentionally minimal: most of the heavy
146/// state lives in `Arc<TrtEngine>` values that the caller threads
147/// through. This mirrors `DeviceActor`'s design where per-context
148/// state lives in the `ContextActor` but engines live with the
149/// caller.
150pub struct TrtActor {
151    /// Cached runtime; lazily created on first `Deserialize`. Held
152    /// behind a `parking_lot::Mutex` because the actor mailbox
153    /// already serialises but interior mutability avoids a redundant
154    /// `&mut self` thread through every method.
155    runtime: parking_lot::Mutex<Option<crate::runtime::TrtRuntime>>,
156}
157
158impl TrtActor {
159    pub fn new() -> Self {
160        Self {
161            runtime: parking_lot::Mutex::new(None),
162        }
163    }
164
165    /// Get-or-create the cached runtime. Without `tensorrt-link` the
166    /// inner constructor returns `NotLinked`.
167    pub fn ensure_runtime(&self) -> Result<(), TrtError> {
168        let mut guard = self.runtime.lock();
169        if guard.is_none() {
170            *guard = Some(crate::runtime::TrtRuntime::new()?);
171        }
172        Ok(())
173    }
174
175    /// Phase 4.5++ — synchronous helper that drives the
176    /// `TrtMsg::Execute` semantics (creates an `IExecutionContext`,
177    /// binds tensor addresses, calls `enqueueV3`).
178    ///
179    /// Without `tensorrt-link` this returns `TrtError::NotLinked`
180    /// without ever touching libnvinfer. With the feature on, the
181    /// actor performs the full FFI sequence under the supplied
182    /// `Arc<CudaStream>`. The function returns once the launch
183    /// has been submitted — real GPU completion is observed
184    /// downstream (the typical caller pairs this with an
185    /// `atomr-accel-cuda` completion strategy on the same stream).
186    pub fn execute(
187        &self,
188        engine: &Arc<TrtEngine>,
189        bindings: &[(String, u64)],
190        input_shapes: &[(String, Vec<i32>)],
191        _stream: &Arc<cudarc::driver::CudaStream>,
192    ) -> Result<(), TrtError> {
193        #[cfg(feature = "tensorrt-link")]
194        {
195            use std::ffi::CString;
196            unsafe {
197                let ctx_ptr = crate::sys::atomr_trt_engine_create_execution_context(engine.raw());
198                if ctx_ptr.is_null() {
199                    return Err(TrtError::Execution(
200                        "createExecutionContext returned null".into(),
201                    ));
202                }
203                // Apply input shapes first (TensorRT requires shapes
204                // before set_tensor_address on dynamic tensors).
205                for (name, dims) in input_shapes {
206                    if dims.len() > 8 {
207                        crate::sys::atomr_trt_context_destroy(ctx_ptr);
208                        return Err(TrtError::InvalidArg(format!(
209                            "tensor {name:?}: TensorRT supports at most 8 dims (got {})",
210                            dims.len()
211                        )));
212                    }
213                    let cname = match CString::new(name.clone()) {
214                        Ok(c) => c,
215                        Err(e) => {
216                            crate::sys::atomr_trt_context_destroy(ctx_ptr);
217                            return Err(TrtError::InvalidArg(format!(
218                                "tensor name contains NUL: {e}"
219                            )));
220                        }
221                    };
222                    let mut d = [0i32; 8];
223                    for (i, v) in dims.iter().enumerate() {
224                        d[i] = *v;
225                    }
226                    let dims_struct = crate::sys::Dims {
227                        nb_dims: dims.len() as std::os::raw::c_int,
228                        d,
229                    };
230                    let rc = crate::sys::atomr_trt_context_set_input_shape(
231                        ctx_ptr,
232                        cname.as_ptr(),
233                        &dims_struct as *const crate::sys::Dims,
234                    );
235                    if rc != 0 {
236                        crate::sys::atomr_trt_context_destroy(ctx_ptr);
237                        return Err(TrtError::Execution(format!(
238                            "set_input_shape({name}) returned {rc}"
239                        )));
240                    }
241                }
242                // Bind every tensor address.
243                for (name, addr) in bindings {
244                    let cname = match CString::new(name.clone()) {
245                        Ok(c) => c,
246                        Err(e) => {
247                            crate::sys::atomr_trt_context_destroy(ctx_ptr);
248                            return Err(TrtError::InvalidArg(format!(
249                                "tensor name contains NUL: {e}"
250                            )));
251                        }
252                    };
253                    let rc = crate::sys::atomr_trt_context_set_tensor_address(
254                        ctx_ptr,
255                        cname.as_ptr(),
256                        *addr as *mut std::os::raw::c_void,
257                    );
258                    if rc != 0 {
259                        crate::sys::atomr_trt_context_destroy(ctx_ptr);
260                        return Err(TrtError::Execution(format!(
261                            "set_tensor_address({name}) returned {rc}"
262                        )));
263                    }
264                }
265                // Cudarc's `CudaStream` exposes the raw stream via
266                // `cu_stream()` — but the field is `pub(crate)`. We
267                // pass through cudarc's `DevicePtr`-style accessor by
268                // using `cuStream` symbol from `cudarc::driver::sys`
269                // — which is what other call sites in atomr-accel-cuda
270                // do. The shim takes `*mut c_void` (any CUstream).
271                let stream_raw = _stream.cu_stream() as *mut std::os::raw::c_void;
272                let rc = crate::sys::atomr_trt_context_enqueue_v3(ctx_ptr, stream_raw);
273                let result = if rc != 0 {
274                    Err(TrtError::Execution(format!("enqueueV3 returned {rc}")))
275                } else {
276                    Ok(())
277                };
278                crate::sys::atomr_trt_context_destroy(ctx_ptr);
279                result
280            }
281        }
282        #[cfg(not(feature = "tensorrt-link"))]
283        {
284            let _ = (engine, bindings, input_shapes, _stream);
285            Err(TrtError::NotLinked(
286                "TrtActor::execute requires the `tensorrt-link` feature",
287            ))
288        }
289    }
290
291    /// Phase 4.5++ — synchronous helper that drives the
292    /// `TrtMsg::BuildFromOnnx` semantics. Parses an ONNX model and
293    /// returns a serialised plan blob ready for `TrtRuntime::deserialize`.
294    /// Gated on `tensorrt-onnx` (transitively `tensorrt-link`).
295    pub fn build_from_onnx(
296        &self,
297        _onnx_bytes: &[u8],
298        _config: &IBuilderConfig,
299    ) -> Result<EnginePlan, TrtError> {
300        #[cfg(all(feature = "tensorrt-link", feature = "tensorrt-onnx"))]
301        {
302            use crate::builder::BuilderFlags;
303            unsafe {
304                let builder = crate::sys::atomr_trt_builder_create(0);
305                if builder.is_null() {
306                    return Err(TrtError::Build("builder_create returned null".into()));
307                }
308                // EXPLICIT_BATCH (1 << 0) is required for ONNX import.
309                let network = crate::sys::atomr_trt_builder_create_network(builder, 1u32 << 0);
310                if network.is_null() {
311                    crate::sys::atomr_trt_builder_destroy(builder);
312                    return Err(TrtError::Build("create_network returned null".into()));
313                }
314                let parser = crate::sys::atomr_trt_onnx_parser_create(network, 0);
315                if parser.is_null() {
316                    crate::sys::atomr_trt_builder_destroy(builder);
317                    return Err(TrtError::Onnx("onnx_parser_create returned null".into()));
318                }
319                let parse_rc = crate::sys::atomr_trt_onnx_parser_parse(
320                    parser,
321                    _onnx_bytes.as_ptr(),
322                    _onnx_bytes.len(),
323                    std::ptr::null(),
324                );
325                if parse_rc == 0 {
326                    let nerr = crate::sys::atomr_trt_onnx_parser_num_errors(parser);
327                    crate::sys::atomr_trt_onnx_parser_destroy(parser);
328                    crate::sys::atomr_trt_builder_destroy(builder);
329                    return Err(TrtError::Onnx(format!(
330                        "onnx parse failed (rc={parse_rc}, errors={nerr})"
331                    )));
332                }
333
334                let cfg_ptr = crate::sys::atomr_trt_builder_create_config(builder);
335                if cfg_ptr.is_null() {
336                    crate::sys::atomr_trt_onnx_parser_destroy(parser);
337                    crate::sys::atomr_trt_builder_destroy(builder);
338                    return Err(TrtError::Build(
339                        "builder_create_config returned null".into(),
340                    ));
341                }
342                // Replay caller-requested flags onto the C++ config.
343                let flags = _config.effective_flags();
344                for flag in [
345                    (BuilderFlags::FP16, crate::sys::BuilderFlag::kFP16 as u32),
346                    (BuilderFlags::INT8, crate::sys::BuilderFlag::kINT8 as u32),
347                    (BuilderFlags::TF32, crate::sys::BuilderFlag::kTF32 as u32),
348                    (BuilderFlags::BF16, crate::sys::BuilderFlag::kBF16 as u32),
349                    (BuilderFlags::FP8, crate::sys::BuilderFlag::kFP8 as u32),
350                    (BuilderFlags::REFIT, crate::sys::BuilderFlag::kREFIT as u32),
351                    (
352                        BuilderFlags::SPARSE_WEIGHTS,
353                        crate::sys::BuilderFlag::kSPARSE_WEIGHTS as u32,
354                    ),
355                    (
356                        BuilderFlags::STRIP_PLAN,
357                        crate::sys::BuilderFlag::kSTRIP_PLAN as u32,
358                    ),
359                ] {
360                    if flags.contains(flag.0) {
361                        crate::sys::atomr_trt_config_set_flag(cfg_ptr, flag.1, 1);
362                    }
363                }
364                if _config.workspace_bytes > 0 {
365                    crate::sys::atomr_trt_config_set_memory_pool_limit(
366                        cfg_ptr,
367                        0, // kWORKSPACE
368                        _config.workspace_bytes,
369                    );
370                }
371
372                let host_mem =
373                    crate::sys::atomr_trt_builder_build_serialized(builder, network, cfg_ptr);
374                let cleanup = || {
375                    crate::sys::atomr_trt_config_destroy(cfg_ptr);
376                    crate::sys::atomr_trt_onnx_parser_destroy(parser);
377                    crate::sys::atomr_trt_builder_destroy(builder);
378                };
379                if host_mem.is_null() {
380                    cleanup();
381                    return Err(TrtError::Build(
382                        "buildSerializedNetwork returned null".into(),
383                    ));
384                }
385                let data_ptr = crate::sys::atomr_trt_host_memory_data(host_mem);
386                let data_len = crate::sys::atomr_trt_host_memory_size(host_mem);
387                let bytes = if data_ptr.is_null() || data_len == 0 {
388                    Vec::new()
389                } else {
390                    std::slice::from_raw_parts(data_ptr, data_len).to_vec()
391                };
392                crate::sys::atomr_trt_host_memory_destroy(host_mem);
393                cleanup();
394                if bytes.is_empty() {
395                    return Err(TrtError::Build("serialised plan was empty".into()));
396                }
397                Ok(EnginePlan::new(bytes))
398            }
399        }
400        #[cfg(not(all(feature = "tensorrt-link", feature = "tensorrt-onnx")))]
401        {
402            Err(TrtError::NotLinked(
403                "TrtActor::build_from_onnx requires the `tensorrt-link` + `tensorrt-onnx` features",
404            ))
405        }
406    }
407}
408
409impl Default for TrtActor {
410    fn default() -> Self {
411        Self::new()
412    }
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::Precision;

    /// Builds the `TrtMsg` variants this test can construct directly.
    /// `EnqueueOnStream` and `Execute` are left out: both carry an
    /// `Arc<CudaStream>`, which these unit tests do not construct.
    #[test]
    fn trt_msg_constructs() {
        let (b_tx, _b_rx) = oneshot::channel();
        let _build = TrtMsg::Build {
            source: NetworkSource::SerializedPlan(vec![1, 2, 3]),
            config: Box::new(IBuilderConfig::new().with_precision(Precision::Fp16)),
            reply: b_tx,
        };

        let (d_tx, _d_rx) = oneshot::channel();
        let _deser = TrtMsg::Deserialize {
            plan: EnginePlan::new(vec![0xAA; 8]),
            reply: d_tx,
        };

        let engine = Arc::new(TrtEngine::for_test());
        let (c_tx, _c_rx) = oneshot::channel();
        let _ctx = TrtMsg::CreateContext {
            engine: engine.clone(),
            reply: c_tx,
        };

        let (r_tx, _r_rx) = oneshot::channel();
        let _refit = TrtMsg::Refit {
            engine: engine.clone(),
            weights: vec![RefitWeights {
                name: "fc.weight".into(),
                bytes: vec![0; 16],
                dtype: crate::sys::DataType::kHALF,
            }],
            reply: r_tx,
        };

        let (o_tx, _o_rx) = oneshot::channel();
        let _onnx = TrtMsg::BuildFromOnnx {
            onnx_bytes: vec![0x08, 0x01],
            config: Box::new(IBuilderConfig::new()),
            reply: o_tx,
        };

        // Verify the actor itself is Send so it can live inside an
        // `atomr_core::actor::Actor`. (Only the actor is checked here,
        // not the message enum.)
        fn assert_send<T: Send>() {}
        assert_send::<TrtActor>();
    }

    #[test]
    fn actor_runtime_lazy_init() {
        let actor = TrtActor::new();
        // Without the link feature this should error cleanly, never
        // panic.
        #[cfg(not(feature = "tensorrt-link"))]
        {
            let r = actor.ensure_runtime();
            assert!(matches!(r, Err(TrtError::NotLinked(_))));
        }
        #[cfg(feature = "tensorrt-link")]
        {
            // Real link path is exercised by integration tests with a
            // GPU host.
            let _ = actor;
        }
    }

    /// Without both ONNX-related features the helper must report
    /// `NotLinked` rather than touch any FFI.
    #[cfg(not(all(feature = "tensorrt-link", feature = "tensorrt-onnx")))]
    #[test]
    fn build_from_onnx_without_features_is_not_linked() {
        let actor = TrtActor::new();
        let r = actor.build_from_onnx(&[], &IBuilderConfig::new());
        assert!(matches!(r, Err(TrtError::NotLinked(_))));
    }
}