Skip to main content

atomr_accel_tensorrt/
runtime.rs

1//! Safe wrapper around `nvinfer1::IExecutionContext` and the
2//! `IRuntime` deserialiser. This is the inference-time hot path.
3//!
4//! The runtime actor (`TrtActor`, see `actor.rs`) drives an
5//! `ExecutionContext` per inference, calling `enqueueV3` on a CUDA
6//! stream provided by `atomr-accel-cuda::DeviceActor`. The actor
7//! never blocks on the GPU; completion is signalled via the same
8//! host-fn-completion mechanism the BLAS/cuDNN actors use.
9
10#![allow(dead_code)]
11
12use std::collections::HashMap;
13use std::sync::Arc;
14
15use tokio::sync::oneshot;
16
17use crate::engine::TrtEngine;
18use crate::error::TrtError;
19use crate::sys;
20
/// Runtime shape of a tensor with at most 8 dimensions.
///
/// Modelled on `nvinfer1::Dims` (which also caps rank at 8), but this is a
/// plain Rust struct rather than a `#[repr(C)]` FFI mirror: `nb_dims` is a
/// `usize`, and the unused tail of `dims` is zero-filled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TensorShape {
    pub nb_dims: usize,
    pub dims: [i32; 8],
}

impl TensorShape {
    /// Packs `dims` into a fixed 8-slot array, zero-padding the remainder.
    ///
    /// An empty slice produces a rank-0 (scalar) shape.
    ///
    /// # Panics
    /// Panics if `dims` has more than 8 entries.
    pub fn new(dims: &[i32]) -> Self {
        let rank = dims.len();
        assert!(rank <= 8, "TensorRT supports at most 8 dimensions");

        let mut packed = [0i32; 8];
        for (slot, &extent) in packed.iter_mut().zip(dims) {
            *slot = extent;
        }

        Self {
            nb_dims: rank,
            dims: packed,
        }
    }

    /// Returns only the meaningful leading `nb_dims` extents.
    pub fn as_slice(&self) -> &[i32] {
        let rank = self.nb_dims;
        &self.dims[..rank]
    }
}
44
45/// Per-call inputs/outputs: tensor name → device pointer.
46/// Pointers are raw `u64`s (CUDA device addresses) so the message is
47/// `Send + Sync` without lifetimes from `Arc<CudaSlice<T>>`.
48#[derive(Debug, Clone, Default)]
49pub struct ExecutionBindings {
50    pub addresses: HashMap<String, u64>,
51    pub input_shapes: HashMap<String, TensorShape>,
52}
53
54impl ExecutionBindings {
55    pub fn new() -> Self {
56        Self::default()
57    }
58
59    pub fn bind(&mut self, name: impl Into<String>, device_ptr: u64) -> &mut Self {
60        self.addresses.insert(name.into(), device_ptr);
61        self
62    }
63
64    pub fn set_shape(&mut self, name: impl Into<String>, shape: TensorShape) -> &mut Self {
65        self.input_shapes.insert(name.into(), shape);
66        self
67    }
68}
69
/// Owning wrapper around a TensorRT `IExecutionContext*`.
///
/// Intended to be held for life by the `TrtActor`, which serialises all
/// access through its mailbox. Despite that single-owner design, the type is
/// declared `Send + Sync` below so the actor can hold it. The soundness of
/// those impls rests on the access discipline described in the SAFETY note,
/// not on the type system.
pub struct ExecutionContext {
    // Owned native handle. Null only for instances built by `for_test`.
    raw: *mut sys::IExecutionContext,
    // Shared engine handle. Holding the `Arc` keeps the engine alive at least
    // as long as this context.
    engine: Arc<TrtEngine>,
}

// SAFETY: The `IExecutionContext` is only ever touched from inside the
// owning actor (single-thread access serialised by the mailbox). The
// underlying TensorRT runtime is thread-safe for concurrent
// `enqueueV3` calls *across distinct* contexts that share an engine.
// NOTE(review): `Sync` is not enforced by this type. `raw()` hands the
// pointer out through `&self`, so any code that shares an
// `&ExecutionContext` across threads must itself guarantee the context is
// never used concurrently. Confirm all callers uphold this.
unsafe impl Send for ExecutionContext {}
unsafe impl Sync for ExecutionContext {}
84
85impl ExecutionContext {
86    /// # Safety
87    /// `raw` must be a valid pointer returned by
88    /// `IcudaEngine::createExecutionContext`.
89    pub unsafe fn from_raw(
90        raw: *mut sys::IExecutionContext,
91        engine: Arc<TrtEngine>,
92    ) -> Result<Self, TrtError> {
93        if raw.is_null() {
94            Err(TrtError::Execution("null execution context".into()))
95        } else {
96            Ok(Self { raw, engine })
97        }
98    }
99
100    pub(crate) fn for_test(engine: Arc<TrtEngine>) -> Self {
101        Self {
102            raw: std::ptr::null_mut(),
103            engine,
104        }
105    }
106
107    pub fn raw(&self) -> *mut sys::IExecutionContext {
108        self.raw
109    }
110
111    pub fn engine(&self) -> &Arc<TrtEngine> {
112        &self.engine
113    }
114}
115
116impl Drop for ExecutionContext {
117    fn drop(&mut self) {
118        #[cfg(feature = "tensorrt-link")]
119        unsafe {
120            if !self.raw.is_null() {
121                sys::atomr_trt_context_destroy(self.raw);
122            }
123        }
124    }
125}
126
/// Owning wrapper around a TensorRT `IRuntime*`, used to deserialise plan
/// blobs into [`TrtEngine`]s.
///
/// The handle is destroyed on drop when the `tensorrt-link` feature is
/// enabled. Instances built by `for_test` hold a null handle.
pub struct TrtRuntime {
    raw: *mut sys::IRuntime,
}

// SAFETY: `raw` is only read through `&self` (by `deserialize`) and is
// destroyed exactly once in `Drop`, which requires exclusive access.
// NOTE(review): `Sync` additionally assumes TensorRT permits concurrent
// deserialisation on one `IRuntime`. Confirm this against the TensorRT
// thread-safety documentation, or serialise access in the caller.
unsafe impl Send for TrtRuntime {}
unsafe impl Sync for TrtRuntime {}
134
135impl TrtRuntime {
136    /// Construct a runtime. Without the `tensorrt-link` feature this
137    /// returns `Err(NotLinked)`.
138    pub fn new() -> Result<Self, TrtError> {
139        #[cfg(feature = "tensorrt-link")]
140        {
141            let raw = unsafe { sys::atomr_trt_runtime_create(0) };
142            if raw.is_null() {
143                Err(TrtError::Runtime("runtime create returned null".into()))
144            } else {
145                Ok(Self { raw })
146            }
147        }
148        #[cfg(not(feature = "tensorrt-link"))]
149        {
150            Err(TrtError::NotLinked(
151                "TrtRuntime requires the `tensorrt-link` feature",
152            ))
153        }
154    }
155
156    pub(crate) fn for_test() -> Self {
157        Self {
158            raw: std::ptr::null_mut(),
159        }
160    }
161
162    /// Deserialise a plan blob. Without the link feature this is an
163    /// error.
164    pub fn deserialize(&self, _plan: &[u8]) -> Result<TrtEngine, TrtError> {
165        #[cfg(feature = "tensorrt-link")]
166        {
167            let raw = unsafe {
168                sys::atomr_trt_runtime_deserialize(self.raw, _plan.as_ptr(), _plan.len())
169            };
170            if raw.is_null() {
171                Err(TrtError::Runtime("deserialize returned null".into()))
172            } else {
173                let num_io = unsafe { sys::atomr_trt_engine_num_io_tensors(raw) } as usize;
174                unsafe { TrtEngine::from_raw(raw, num_io) }
175            }
176        }
177        #[cfg(not(feature = "tensorrt-link"))]
178        {
179            Err(TrtError::NotLinked(
180                "TrtRuntime::deserialize requires the `tensorrt-link` feature",
181            ))
182        }
183    }
184}
185
186impl Drop for TrtRuntime {
187    fn drop(&mut self) {
188        #[cfg(feature = "tensorrt-link")]
189        unsafe {
190            if !self.raw.is_null() {
191                sys::atomr_trt_runtime_destroy(self.raw);
192            }
193        }
194    }
195}
196
/// Reply payload for an enqueue request.
///
/// `Ok(())` means the work was submitted to the CUDA stream successfully; the
/// kernels may still be running on the GPU. The caller awaits real completion
/// via the shared CUDA stream completion sentinel. `Err` reports a
/// [`TrtError`].
pub type EnqueueReply = Result<(), TrtError>;
201
/// Standalone enqueue request type, embedded into the `TrtActor`'s
/// message enum.
///
/// It is exposed here so that the message dispatcher in `actor.rs` and tests
/// can construct it without crossing module boundaries.
pub struct EnqueueRequest {
    /// Device addresses and dynamic input shapes for this inference.
    pub bindings: ExecutionBindings,
    /// CUDA stream on which the inference is enqueued.
    pub stream: Arc<cudarc::driver::CudaStream>,
    /// One-shot channel that receives the [`EnqueueReply`] for this request.
    pub reply: oneshot::Sender<EnqueueReply>,
}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `T` is `Send + Sync`.
    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn execution_context_msg_round_trip() {
        let engine = Arc::new(TrtEngine::for_test());
        let ctx = ExecutionContext::for_test(engine.clone());
        assert!(Arc::ptr_eq(ctx.engine(), &engine));

        let mut bindings = ExecutionBindings::new();
        bindings
            .bind("input", 0xDEADBEEF)
            .set_shape("input", TensorShape::new(&[1, 3, 224, 224]));
        assert_eq!(bindings.addresses.get("input").copied(), Some(0xDEADBEEF));
        assert_eq!(
            bindings.input_shapes.get("input").map(|s| s.as_slice()),
            Some(&[1i32, 3, 224, 224][..])
        );

        assert_send_sync::<ExecutionBindings>();
        assert_send_sync::<TrtRuntime>();
        assert_send_sync::<ExecutionContext>();
    }

    #[test]
    fn shape_round_trip() {
        let s = TensorShape::new(&[2, 4, 8]);
        assert_eq!(s.nb_dims, 3);
        assert_eq!(s.as_slice(), &[2, 4, 8]);
        // Unused trailing slots are zero-filled.
        assert_eq!(&s.dims[3..], &[0; 5]);
    }

    #[test]
    fn shape_accepts_scalar_and_max_rank() {
        assert!(TensorShape::new(&[]).as_slice().is_empty());
        let full = [1, 2, 3, 4, 5, 6, 7, 8];
        assert_eq!(TensorShape::new(&full).as_slice(), &full);
    }

    #[test]
    #[should_panic(expected = "at most 8 dimensions")]
    fn shape_rejects_more_than_eight_dims() {
        let _ = TensorShape::new(&[1; 9]);
    }

    // Gate the whole test rather than its body. Otherwise linked builds would
    // run an empty test that silently "passes".
    #[cfg(not(feature = "tensorrt-link"))]
    #[test]
    fn runtime_unlinked_returns_not_linked() {
        // Without the link feature, TrtRuntime::new must surface a
        // clean error instead of panicking.
        let r = TrtRuntime::new();
        assert!(matches!(r, Err(TrtError::NotLinked(_))));
    }

    #[cfg(not(feature = "tensorrt-link"))]
    #[test]
    fn deserialize_unlinked_returns_not_linked() {
        let rt = TrtRuntime::for_test();
        assert!(matches!(
            rt.deserialize(b"plan"),
            Err(TrtError::NotLinked(_))
        ));
    }
}