Skip to main content

atomr_accel_tensorrt/
actor.rs

//! `TrtActor` — sibling of `atomr_accel_cuda::DeviceActor`.
//!
//! Lifecycle:
//! - On `Build` it consumes a [`NetworkSource`] (ONNX bytes, which need
//!   the `tensorrt-onnx` feature, or an already serialised plan) plus an
//!   [`IBuilderConfig`], drives `IBuilder::buildSerializedNetwork` and
//!   returns an [`EnginePlan`].
//! - On `Deserialize` it loads a previously built plan into a
//!   [`TrtEngine`].
//! - On `CreateContext` it creates a fresh [`ExecutionContext`].
//! - On `EnqueueOnStream { stream, context, bindings, reply }` it submits
//!   the inference on the supplied `Arc<cudarc::driver::CudaStream>` —
//!   the same stream type carried by `DeviceActor`, so the two actors
//!   share one CUDA execution timeline.
//! - On `Refit` it patches engine weights via `TrtRefitter`.
//!
//! Engines are passed around as `Arc<TrtEngine>` so that multiple
//! `ExecutionContext`s can share one. The actor itself does not retain
//! engines between messages.
19
20use std::sync::Arc;
21
22use tokio::sync::oneshot;
23
24use crate::builder::IBuilderConfig;
25use crate::engine::{EnginePlan, TrtEngine};
26use crate::error::TrtError;
27use crate::runtime::{ExecutionBindings, ExecutionContext};
28
/// Where a `TrtMsg::Build` request takes its network from.
///
/// TensorRT's builder exposes many entry points. For now only two are
/// accepted: an ONNX model (usable only with the `tensorrt-onnx`
/// feature) and a TensorRT plan that was serialised earlier and merely
/// needs importing.
#[derive(Clone, Debug)]
pub enum NetworkSource {
    /// ONNX model bytes. Only usable when the `tensorrt-onnx` feature is
    /// enabled.
    Onnx(Vec<u8>),
    /// Bytes of an already serialised TensorRT plan. No build step is
    /// needed, only deserialisation.
    SerializedPlan(Vec<u8>),
}
39
40/// Descriptor of a single weight blob to push into the engine via
41/// the refitter. The pointer / device pointer is **not** held inside
42/// the message; instead callers pass a host-side blob (refitter
43/// stages it). Future variants can add a `DevicePtr` tag if direct
44/// device-to-device refit is desired.
45pub struct RefitWeights {
46    pub name: String,
47    pub bytes: Vec<u8>,
48    pub dtype: crate::sys::DataType,
49}
50
51/// Reply types for each `TrtMsg` variant. Each is a `oneshot::Sender`
52/// so the actor never blocks on IO.
53pub type BuildReply = oneshot::Sender<Result<EnginePlan, TrtError>>;
54pub type DeserializeReply = oneshot::Sender<Result<Arc<TrtEngine>, TrtError>>;
55pub type CreateContextReply = oneshot::Sender<Result<ExecutionContext, TrtError>>;
56pub type EnqueueReply = oneshot::Sender<Result<(), TrtError>>;
57pub type RefitReply = oneshot::Sender<Result<(), TrtError>>;
58
59/// Public message surface for `TrtActor`.
60///
61/// The variant `EnqueueOnStream` accepts the `Arc<CudaStream>` from
62/// `atomr-accel-cuda::DeviceActor` so the TensorRT runtime shares
63/// the device's stream timeline (no cross-stream synchronisation,
64/// no extra event hops).
65pub enum TrtMsg {
66    /// Build a TensorRT engine from a network source + config.
67    /// Returns the serialised plan on success.
68    Build {
69        source: NetworkSource,
70        config: Box<IBuilderConfig>,
71        reply: BuildReply,
72    },
73
74    /// Deserialise a plan blob into a shared engine handle.
75    Deserialize {
76        plan: EnginePlan,
77        reply: DeserializeReply,
78    },
79
80    /// Create a fresh `IExecutionContext` from an existing engine.
81    /// Returns the new context (caller owns it).
82    CreateContext {
83        engine: Arc<TrtEngine>,
84        reply: CreateContextReply,
85    },
86
87    /// Submit `enqueueV3` on the supplied CUDA stream. The actor
88    /// returns immediately after submission; real GPU completion is
89    /// observed by `atomr-accel-cuda`'s completion strategy on the
90    /// shared stream.
91    EnqueueOnStream {
92        stream: Arc<cudarc::driver::CudaStream>,
93        context: ExecutionContext,
94        bindings: ExecutionBindings,
95        reply: EnqueueReply,
96    },
97
98    /// Refit a built engine in-place with new weights. Requires the
99    /// engine to have been built with `RefitPolicy::OnDemand` or
100    /// `WeightsStreaming`.
101    Refit {
102        engine: Arc<TrtEngine>,
103        weights: Vec<RefitWeights>,
104        reply: RefitReply,
105    },
106}
107
108/// `TrtActor` — owns nothing across messages besides the FFI
109/// runtime/builder handles, all engines/contexts ride the messages.
110///
111/// The actor itself is intentionally minimal: most of the heavy
112/// state lives in `Arc<TrtEngine>` values that the caller threads
113/// through. This mirrors `DeviceActor`'s design where per-context
114/// state lives in the `ContextActor` but engines live with the
115/// caller.
116pub struct TrtActor {
117    /// Cached runtime; lazily created on first `Deserialize`. Held
118    /// behind a `parking_lot::Mutex` because the actor mailbox
119    /// already serialises but interior mutability avoids a redundant
120    /// `&mut self` thread through every method.
121    runtime: parking_lot::Mutex<Option<crate::runtime::TrtRuntime>>,
122}
123
124impl TrtActor {
125    pub fn new() -> Self {
126        Self {
127            runtime: parking_lot::Mutex::new(None),
128        }
129    }
130
131    /// Get-or-create the cached runtime. Without `tensorrt-link` the
132    /// inner constructor returns `NotLinked`.
133    pub fn ensure_runtime(&self) -> Result<(), TrtError> {
134        let mut guard = self.runtime.lock();
135        if guard.is_none() {
136            *guard = Some(crate::runtime::TrtRuntime::new()?);
137        }
138        Ok(())
139    }
140}
141
142impl Default for TrtActor {
143    fn default() -> Self {
144        Self::new()
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::Precision;

    #[test]
    fn trt_msg_constructs() {
        // Construct every variant that can be built without a GPU, to
        // confirm the fields and reply channels type-check.
        // `EnqueueOnStream` is skipped because it needs a live
        // `cudarc::driver::CudaStream`. This test does not assert that
        // `TrtMsg` itself is `Send`.
        let (b_tx, _b_rx) = oneshot::channel();
        let _build = TrtMsg::Build {
            source: NetworkSource::SerializedPlan(vec![1, 2, 3]),
            config: Box::new(IBuilderConfig::new().with_precision(Precision::Fp16)),
            reply: b_tx,
        };

        let (d_tx, _d_rx) = oneshot::channel();
        let _deser = TrtMsg::Deserialize {
            plan: EnginePlan::new(vec![0xAA; 8]),
            reply: d_tx,
        };

        let engine = Arc::new(TrtEngine::for_test());
        let (c_tx, _c_rx) = oneshot::channel();
        let _ctx = TrtMsg::CreateContext {
            engine: engine.clone(),
            reply: c_tx,
        };

        let (r_tx, _r_rx) = oneshot::channel();
        let _refit = TrtMsg::Refit {
            engine: engine.clone(),
            weights: vec![RefitWeights {
                name: "fc.weight".into(),
                bytes: vec![0; 16],
                dtype: crate::sys::DataType::kHALF,
            }],
            reply: r_tx,
        };

        // The actor itself must be `Send` so it can live inside an
        // `atomr_core::actor::Actor`.
        fn assert_send<T: Send>() {}
        assert_send::<TrtActor>();
    }

    #[test]
    fn actor_runtime_lazy_init() {
        let actor = TrtActor::new();
        // Without the link feature this must return an error cleanly,
        // never panic.
        #[cfg(not(feature = "tensorrt-link"))]
        {
            let r = actor.ensure_runtime();
            assert!(matches!(r, Err(TrtError::NotLinked(_))));
        }
        #[cfg(feature = "tensorrt-link")]
        {
            // The real link path is exercised by integration tests on a
            // GPU host.
            let _ = actor;
        }
    }
}