Skip to main content

atomr_accel_tensorrt/
engine.rs

1//! Safe wrapper around `nvinfer1::ICudaEngine`.
2//!
3//! The C++ object is `*mut sys::ICudaEngine`; we wrap it in a newtype
4//! that owns the pointer and is `Send + Sync` because the engine is
5//! immutable post-build (multiple `IExecutionContext`s share it
6//! safely). The `Drop` impl calls the FFI destroy shim under the
7//! `tensorrt-link` feature; without the feature the pointer is always
8//! null and `Drop` is a no-op (so unit tests construct engines without
9//! libnvinfer).
10
11use std::sync::Arc;
12
13use crate::error::TrtError;
14use crate::sys;
15
16/// Owned, immutable TensorRT engine.
17///
18/// Built either from a serialised plan via [`TrtRuntime::deserialize`]
19/// or from a fresh build via [`crate::builder::IBuilderConfig`] +
20/// `TrtActor::Build`.
21pub struct TrtEngine {
22    raw: *mut sys::ICudaEngine,
23    /// Cached number of I/O tensors; populated under the link feature.
24    num_io: usize,
25}
26
27// SAFETY: post-build engines are immutable and the C++ runtime is
28// thread-safe for concurrent reads / `IExecutionContext` creation.
29unsafe impl Send for TrtEngine {}
30unsafe impl Sync for TrtEngine {}
31
32impl TrtEngine {
33    /// Construct a wrapper from a raw pointer obtained from the FFI
34    /// shim. Returns `Err` if the pointer is null.
35    ///
36    /// # Safety
37    /// Caller must ensure `raw` was returned by a TensorRT runtime /
38    /// builder shim and has not been destroyed.
39    pub unsafe fn from_raw(raw: *mut sys::ICudaEngine, num_io: usize) -> Result<Self, TrtError> {
40        if raw.is_null() {
41            Err(TrtError::NullEngine)
42        } else {
43            Ok(Self { raw, num_io })
44        }
45    }
46
47    /// Test-only constructor (no FFI). Used by the unit tests to
48    /// exercise the Send/Sync newtype on hosts without libnvinfer.
49    #[allow(dead_code)]
50    pub(crate) fn for_test() -> Self {
51        Self {
52            raw: std::ptr::null_mut(),
53            num_io: 0,
54        }
55    }
56
57    pub fn raw(&self) -> *mut sys::ICudaEngine {
58        self.raw
59    }
60
61    pub fn num_io_tensors(&self) -> usize {
62        self.num_io
63    }
64
65    /// Wrap the engine in an `Arc<TrtEngine>` so multiple
66    /// `ExecutionContext`s can share it.
67    pub fn into_shared(self) -> Arc<TrtEngine> {
68        Arc::new(self)
69    }
70}
71
72impl Drop for TrtEngine {
73    fn drop(&mut self) {
74        #[cfg(feature = "tensorrt-link")]
75        unsafe {
76            if !self.raw.is_null() {
77                sys::atomr_trt_engine_destroy(self.raw);
78            }
79        }
80        // Without `tensorrt-link`: pointer is null (test-only path),
81        // nothing to free.
82    }
83}
84
/// A serialised engine ("plan") held as owned bytes.
///
/// The bytes live in a `Vec<u8>` instead of TensorRT's `IHostMemory*`,
/// which lets the plan outlive shim teardown and be journaled or
/// written to disk.
#[derive(Debug, Clone)]
pub struct EnginePlan(pub Vec<u8>);

impl EnginePlan {
    /// Takes ownership of `bytes` as the plan contents.
    pub fn new(bytes: Vec<u8>) -> Self {
        EnginePlan(bytes)
    }

    /// Borrows the plan contents as a byte slice.
    pub fn as_slice(&self) -> &[u8] {
        let EnginePlan(bytes) = self;
        bytes.as_slice()
    }
}
101
102/// Refit handle — holds an `IRefitter*` for in-place engine weight
103/// updates.
104pub struct TrtRefitter {
105    raw: *mut sys::IRefitter,
106}
107
108unsafe impl Send for TrtRefitter {}
109unsafe impl Sync for TrtRefitter {}
110
111impl TrtRefitter {
112    /// # Safety
113    /// `raw` must be a valid pointer returned by the refitter shim.
114    pub unsafe fn from_raw(raw: *mut sys::IRefitter) -> Result<Self, TrtError> {
115        if raw.is_null() {
116            Err(TrtError::Refit("null refitter".into()))
117        } else {
118            Ok(Self { raw })
119        }
120    }
121
122    #[allow(dead_code)]
123    pub(crate) fn for_test() -> Self {
124        Self {
125            raw: std::ptr::null_mut(),
126        }
127    }
128
129    pub fn raw(&self) -> *mut sys::IRefitter {
130        self.raw
131    }
132}
133
134impl Drop for TrtRefitter {
135    fn drop(&mut self) {
136        #[cfg(feature = "tensorrt-link")]
137        unsafe {
138            if !self.raw.is_null() {
139                sys::atomr_trt_refitter_destroy(self.raw);
140            }
141        }
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only when `T` is both `Send` and `Sync`; calling it is
    /// a compile-time check with no runtime effect.
    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn engine_handle_send_sync() {
        // Newtype must be Send + Sync so it can live inside Arc<...>
        // and ride actor messages across tokio threads.
        assert_send_sync::<TrtEngine>();
        assert_send_sync::<Arc<TrtEngine>>();
        assert_send_sync::<TrtRefitter>();

        let e = TrtEngine::for_test();
        assert_eq!(e.num_io_tensors(), 0);
        assert!(e.raw().is_null());
        let shared: Arc<TrtEngine> = e.into_shared();
        // A freshly created Arc has exactly one strong reference; the
        // previous `>= 1` check could never fail.
        assert_eq!(Arc::strong_count(&shared), 1);
    }

    #[test]
    fn from_raw_rejects_null_pointers() {
        // SAFETY: a null pointer is rejected before it is stored or
        // dereferenced, so no FFI call is made.
        let engine = unsafe { TrtEngine::from_raw(std::ptr::null_mut(), 3) };
        assert!(matches!(engine, Err(TrtError::NullEngine)));

        // SAFETY: same as above — null is rejected up front.
        let refitter = unsafe { TrtRefitter::from_raw(std::ptr::null_mut()) };
        assert!(matches!(refitter, Err(TrtError::Refit(_))));
    }

    #[test]
    fn test_handles_drop_without_ffi() {
        // Dropping a null-pointer handle must be a no-op whether or not
        // the link feature is enabled.
        let refitter = TrtRefitter::for_test();
        assert!(refitter.raw().is_null());
        drop(refitter);
        drop(TrtEngine::for_test());
    }

    #[test]
    fn engine_plan_round_trip() {
        let plan = EnginePlan::new(vec![0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(plan.as_slice(), &[0xDE, 0xAD, 0xBE, 0xEF]);
        // The clone derive must yield an independent, equal copy.
        assert_eq!(plan.clone().as_slice(), plan.as_slice());
    }
}