atomr_accel_tensorrt/actor.rs
1//! `TrtActor` — sibling of `atomr_accel_cuda::DeviceActor`.
2//!
3//! Lifecycle:
//! - On `Build` it consumes a [`NetworkSource`] (ONNX bytes when
//!   `tensorrt-onnx` is enabled, or an already-serialised plan) plus an
//!   [`IBuilderConfig`], drives `IBuilder::buildSerializedNetwork` and
//!   returns an [`EnginePlan`].
8//! - On `Deserialize` it loads a previously built plan into an
9//! [`TrtEngine`].
10//! - On `CreateContext` it creates a fresh [`ExecutionContext`].
//! - On `EnqueueOnStream { stream, context, bindings, reply }` it
//!   submits the inference on the supplied
//!   `Arc<cudarc::driver::CudaStream>` — the same stream type carried by
//!   `DeviceActor` so the two actors share one CUDA execution timeline.
15//! - On `Refit` it patches engine weights via [`TrtRefitter`].
16//!
//! Engines travel through messages and replies as `Arc<TrtEngine>`, so
//! multiple `ExecutionContext`s can share one engine without the actor
//! itself holding it.
19
20use std::sync::Arc;
21
22use tokio::sync::oneshot;
23
24use crate::builder::IBuilderConfig;
25use crate::engine::{EnginePlan, TrtEngine};
26use crate::error::TrtError;
27use crate::runtime::{ExecutionBindings, ExecutionContext};
28
/// Network description for `TrtMsg::Build`.
///
/// The TensorRT builder API has many entry points; for now we accept
/// either a serialised ONNX blob (under `tensorrt-onnx`) or a
/// precompiled TensorRT plan to import.
///
/// Both variants are plain byte buffers, so equality and hashing are
/// derived: callers can compare sources or de-duplicate build requests
/// (e.g. key a cache on the source) without unpacking them.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum NetworkSource {
    /// Raw ONNX bytes. Requires the `tensorrt-onnx` feature.
    Onnx(Vec<u8>),
    /// A previously serialised TensorRT plan; just deserialise.
    SerializedPlan(Vec<u8>),
}
39
40/// Descriptor of a single weight blob to push into the engine via
41/// the refitter. The pointer / device pointer is **not** held inside
42/// the message; instead callers pass a host-side blob (refitter
43/// stages it). Future variants can add a `DevicePtr` tag if direct
44/// device-to-device refit is desired.
45pub struct RefitWeights {
46 pub name: String,
47 pub bytes: Vec<u8>,
48 pub dtype: crate::sys::DataType,
49}
50
51/// Reply types for each `TrtMsg` variant. Each is a `oneshot::Sender`
52/// so the actor never blocks on IO.
53pub type BuildReply = oneshot::Sender<Result<EnginePlan, TrtError>>;
54pub type DeserializeReply = oneshot::Sender<Result<Arc<TrtEngine>, TrtError>>;
55pub type CreateContextReply = oneshot::Sender<Result<ExecutionContext, TrtError>>;
56pub type EnqueueReply = oneshot::Sender<Result<(), TrtError>>;
57pub type RefitReply = oneshot::Sender<Result<(), TrtError>>;
58
59/// Public message surface for `TrtActor`.
60///
61/// The variant `EnqueueOnStream` accepts the `Arc<CudaStream>` from
62/// `atomr-accel-cuda::DeviceActor` so the TensorRT runtime shares
63/// the device's stream timeline (no cross-stream synchronisation,
64/// no extra event hops).
65pub enum TrtMsg {
66 /// Build a TensorRT engine from a network source + config.
67 /// Returns the serialised plan on success.
68 Build {
69 source: NetworkSource,
70 config: Box<IBuilderConfig>,
71 reply: BuildReply,
72 },
73
74 /// Deserialise a plan blob into a shared engine handle.
75 Deserialize {
76 plan: EnginePlan,
77 reply: DeserializeReply,
78 },
79
80 /// Create a fresh `IExecutionContext` from an existing engine.
81 /// Returns the new context (caller owns it).
82 CreateContext {
83 engine: Arc<TrtEngine>,
84 reply: CreateContextReply,
85 },
86
87 /// Submit `enqueueV3` on the supplied CUDA stream. The actor
88 /// returns immediately after submission; real GPU completion is
89 /// observed by `atomr-accel-cuda`'s completion strategy on the
90 /// shared stream.
91 EnqueueOnStream {
92 stream: Arc<cudarc::driver::CudaStream>,
93 context: ExecutionContext,
94 bindings: ExecutionBindings,
95 reply: EnqueueReply,
96 },
97
98 /// Refit a built engine in-place with new weights. Requires the
99 /// engine to have been built with `RefitPolicy::OnDemand` or
100 /// `WeightsStreaming`.
101 Refit {
102 engine: Arc<TrtEngine>,
103 weights: Vec<RefitWeights>,
104 reply: RefitReply,
105 },
106}
107
108/// `TrtActor` — owns nothing across messages besides the FFI
109/// runtime/builder handles, all engines/contexts ride the messages.
110///
111/// The actor itself is intentionally minimal: most of the heavy
112/// state lives in `Arc<TrtEngine>` values that the caller threads
113/// through. This mirrors `DeviceActor`'s design where per-context
114/// state lives in the `ContextActor` but engines live with the
115/// caller.
116pub struct TrtActor {
117 /// Cached runtime; lazily created on first `Deserialize`. Held
118 /// behind a `parking_lot::Mutex` because the actor mailbox
119 /// already serialises but interior mutability avoids a redundant
120 /// `&mut self` thread through every method.
121 runtime: parking_lot::Mutex<Option<crate::runtime::TrtRuntime>>,
122}
123
124impl TrtActor {
125 pub fn new() -> Self {
126 Self {
127 runtime: parking_lot::Mutex::new(None),
128 }
129 }
130
131 /// Get-or-create the cached runtime. Without `tensorrt-link` the
132 /// inner constructor returns `NotLinked`.
133 pub fn ensure_runtime(&self) -> Result<(), TrtError> {
134 let mut guard = self.runtime.lock();
135 if guard.is_none() {
136 *guard = Some(crate::runtime::TrtRuntime::new()?);
137 }
138 Ok(())
139 }
140}
141
142impl Default for TrtActor {
143 fn default() -> Self {
144 Self::new()
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::Precision;

    #[test]
    fn trt_msg_constructs() {
        // Construct every variant except `EnqueueOnStream`, which needs a
        // live `CudaStream` (i.e. a GPU host). This confirms the message
        // enum and its payload types build. Note that
        // `oneshot::Sender<T>` is only `Send` when `T: Send`, so
        // `Send`-ness is asserted explicitly below rather than assumed.
        let (b_tx, _b_rx) = oneshot::channel();
        let _build = TrtMsg::Build {
            source: NetworkSource::SerializedPlan(vec![1, 2, 3]),
            config: Box::new(IBuilderConfig::new().with_precision(Precision::Fp16)),
            reply: b_tx,
        };

        let (d_tx, _d_rx) = oneshot::channel();
        let _deser = TrtMsg::Deserialize {
            plan: EnginePlan::new(vec![0xAA; 8]),
            reply: d_tx,
        };

        let engine = Arc::new(TrtEngine::for_test());
        let (c_tx, _c_rx) = oneshot::channel();
        let _ctx = TrtMsg::CreateContext {
            engine: engine.clone(),
            reply: c_tx,
        };

        let (r_tx, _r_rx) = oneshot::channel();
        let _refit = TrtMsg::Refit {
            engine: engine.clone(),
            weights: vec![RefitWeights {
                name: "fc.weight".into(),
                bytes: vec![0; 16],
                dtype: crate::sys::DataType::kHALF,
            }],
            reply: r_tx,
        };

        fn assert_send<T: Send>() {}
        // The actor must be `Send` so it can live inside an
        // `atomr_core::actor::Actor`.
        assert_send::<TrtActor>();
        // Plain-data payload; pin it so a future field change cannot
        // silently break cross-thread messaging.
        assert_send::<NetworkSource>();
    }

    #[test]
    fn actor_runtime_lazy_init() {
        let actor = TrtActor::new();
        // Construction must not touch the FFI: the slot starts empty.
        assert!(actor.runtime.lock().is_none());

        // Without the link feature this should error cleanly, never
        // panic.
        #[cfg(not(feature = "tensorrt-link"))]
        {
            let r = actor.ensure_runtime();
            assert!(matches!(r, Err(TrtError::NotLinked(_))));
            // A failed init must not cache anything: the slot stays empty
            // and a retry reports the same error again.
            assert!(actor.runtime.lock().is_none());
            assert!(matches!(actor.ensure_runtime(), Err(TrtError::NotLinked(_))));
        }
        #[cfg(feature = "tensorrt-link")]
        {
            // Real link path is exercised by integration tests with a
            // GPU host.
            let _ = actor;
        }
    }
}