Skip to main content

atomr_accel_tensorrt/
engine.rs

1//! Safe wrapper around `nvinfer1::ICudaEngine`.
2//!
3//! The C++ object is `*mut sys::ICudaEngine`; we wrap it in a newtype
4//! that owns the pointer and is `Send + Sync` because the engine is
5//! immutable post-build (multiple `IExecutionContext`s share it
6//! safely). The `Drop` impl calls the FFI destroy shim under the
7//! `tensorrt-link` feature; without the feature the pointer is always
8//! null and `Drop` is a no-op (so unit tests construct engines without
9//! libnvinfer).
10
11use std::sync::Arc;
12
13use crate::error::TrtError;
14use crate::sys;
15
/// Owned, immutable TensorRT engine.
///
/// Built either from a serialised plan via [`TrtRuntime::deserialize`]
/// or from a fresh build via [`crate::builder::IBuilderConfig`] +
/// `TrtActor::Build`.
///
/// The wrapper owns its pointer: under the `tensorrt-link` feature the
/// `Drop` impl hands a non-null `raw` to the FFI destroy shim.
pub struct TrtEngine {
    /// Pointer to the C++ engine object. `from_raw` rejects null, so
    /// this is null only for engines produced by `for_test`.
    raw: *mut sys::ICudaEngine,
    /// Cached number of I/O tensors; populated under the link feature.
    /// Supplied by the caller of `from_raw` and used as the upper bound
    /// when looking up tensor names by index.
    num_io: usize,
}

// SAFETY: post-build engines are immutable and the C++ runtime is
// thread-safe for concurrent reads / `IExecutionContext` creation.
// NOTE(review): `raw()` hands the `*mut` out through `&self`; any caller
// that performs mutating FFI calls through it must provide its own
// synchronisation — this wrapper does not.
unsafe impl Send for TrtEngine {}
unsafe impl Sync for TrtEngine {}
31
32impl TrtEngine {
33    /// Construct a wrapper from a raw pointer obtained from the FFI
34    /// shim. Returns `Err` if the pointer is null.
35    ///
36    /// # Safety
37    /// Caller must ensure `raw` was returned by a TensorRT runtime /
38    /// builder shim and has not been destroyed.
39    pub unsafe fn from_raw(raw: *mut sys::ICudaEngine, num_io: usize) -> Result<Self, TrtError> {
40        if raw.is_null() {
41            Err(TrtError::NullEngine)
42        } else {
43            Ok(Self { raw, num_io })
44        }
45    }
46
47    /// Test-only constructor (no FFI). Used by the unit tests to
48    /// exercise the Send/Sync newtype on hosts without libnvinfer.
49    #[allow(dead_code)]
50    pub(crate) fn for_test() -> Self {
51        Self {
52            raw: std::ptr::null_mut(),
53            num_io: 0,
54        }
55    }
56
57    pub fn raw(&self) -> *mut sys::ICudaEngine {
58        self.raw
59    }
60
61    pub fn num_io_tensors(&self) -> usize {
62        self.num_io
63    }
64
65    /// Phase 4.5++ — name of the I/O tensor at index `idx`. Returns
66    /// `None` if `idx >= num_io_tensors()` or if the upstream FFI
67    /// shim returns a null/invalid string.
68    ///
69    /// Without `tensorrt-link` this always returns `None` — there's no
70    /// linked libnvinfer to query.
71    pub fn io_tensor_name(&self, _idx: usize) -> Option<String> {
72        #[cfg(feature = "tensorrt-link")]
73        {
74            if _idx >= self.num_io {
75                return None;
76            }
77            unsafe {
78                let p = sys::atomr_trt_engine_io_tensor_name(self.raw, _idx as i32);
79                if p.is_null() {
80                    return None;
81                }
82                let cstr = std::ffi::CStr::from_ptr(p);
83                cstr.to_str().ok().map(|s| s.to_string())
84            }
85        }
86        #[cfg(not(feature = "tensorrt-link"))]
87        {
88            None
89        }
90    }
91
92    /// Wrap the engine in an `Arc<TrtEngine>` so multiple
93    /// `ExecutionContext`s can share it.
94    pub fn into_shared(self) -> Arc<TrtEngine> {
95        Arc::new(self)
96    }
97}
98
99impl Drop for TrtEngine {
100    fn drop(&mut self) {
101        #[cfg(feature = "tensorrt-link")]
102        unsafe {
103            if !self.raw.is_null() {
104                sys::atomr_trt_engine_destroy(self.raw);
105            }
106        }
107        // Without `tensorrt-link`: pointer is null (test-only path),
108        // nothing to free.
109    }
110}
111
/// Owned plan blob (serialised engine).
///
/// Kept as a plain `Vec<u8>` instead of the TensorRT `IHostMemory*` so
/// the bytes outlive shim teardown and can be journaled / written to
/// disk.
#[derive(Debug, Clone)]
pub struct EnginePlan(pub Vec<u8>);

impl EnginePlan {
    /// Take ownership of `bytes` as a serialised plan.
    pub fn new(bytes: Vec<u8>) -> Self {
        EnginePlan(bytes)
    }

    /// Borrow the plan bytes.
    pub fn as_slice(&self) -> &[u8] {
        let EnginePlan(bytes) = self;
        bytes.as_slice()
    }
}
128
/// Refit handle — holds an `IRefitter*` for in-place engine weight
/// updates.
///
/// The handle owns its pointer: under the `tensorrt-link` feature the
/// `Drop` impl hands a non-null `raw` to the refitter destroy shim.
pub struct TrtRefitter {
    /// Pointer to the C++ refitter. `from_raw` rejects null, so this is
    /// null only for handles produced by `for_test`.
    raw: *mut sys::IRefitter,
}

// SAFETY: NOTE(review): no justification is recorded in this file. These
// impls presumably rely on the refitter pointer being usable from any
// thread; `Sync` additionally requires that shared `&TrtRefitter` access
// is sound — this type only exposes `raw()` through `&self`, so callers
// doing mutating FFI calls must serialise them. TODO confirm against
// TensorRT's `IRefitter` threading guarantees.
unsafe impl Send for TrtRefitter {}
unsafe impl Sync for TrtRefitter {}
137
138impl TrtRefitter {
139    /// # Safety
140    /// `raw` must be a valid pointer returned by the refitter shim.
141    pub unsafe fn from_raw(raw: *mut sys::IRefitter) -> Result<Self, TrtError> {
142        if raw.is_null() {
143            Err(TrtError::Refit("null refitter".into()))
144        } else {
145            Ok(Self { raw })
146        }
147    }
148
149    #[allow(dead_code)]
150    pub(crate) fn for_test() -> Self {
151        Self {
152            raw: std::ptr::null_mut(),
153        }
154    }
155
156    pub fn raw(&self) -> *mut sys::IRefitter {
157        self.raw
158    }
159}
160
161impl Drop for TrtRefitter {
162    fn drop(&mut self) {
163        #[cfg(feature = "tensorrt-link")]
164        unsafe {
165            if !self.raw.is_null() {
166                sys::atomr_trt_refitter_destroy(self.raw);
167            }
168        }
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time probe: only instantiates for `Send + Sync` types.
    fn assert_send_sync<T>()
    where
        T: Send + Sync,
    {
    }

    #[test]
    fn engine_handle_send_sync() {
        // The handles have to be Send + Sync to sit inside Arc<...> and
        // travel in actor messages between tokio threads.
        assert_send_sync::<TrtEngine>();
        assert_send_sync::<Arc<TrtEngine>>();
        assert_send_sync::<TrtRefitter>();

        let engine = TrtEngine::for_test();
        assert_eq!(engine.num_io_tensors(), 0);
        let handle: Arc<TrtEngine> = engine.into_shared();
        assert!(Arc::strong_count(&handle) >= 1);
    }

    #[test]
    fn engine_plan_round_trip() {
        let bytes = vec![0xDE, 0xAD, 0xBE, 0xEF];
        let plan = EnginePlan::new(bytes);
        assert_eq!(plan.as_slice(), &[0xDE, 0xAD, 0xBE, 0xEF]);
    }
}
197}