atomr_accel_tensorrt/
engine.rs1use std::sync::Arc;
12
13use crate::error::TrtError;
14use crate::sys;
15
16pub struct TrtEngine {
22 raw: *mut sys::ICudaEngine,
23 num_io: usize,
25}
26
27unsafe impl Send for TrtEngine {}
30unsafe impl Sync for TrtEngine {}
31
32impl TrtEngine {
33 pub unsafe fn from_raw(raw: *mut sys::ICudaEngine, num_io: usize) -> Result<Self, TrtError> {
40 if raw.is_null() {
41 Err(TrtError::NullEngine)
42 } else {
43 Ok(Self { raw, num_io })
44 }
45 }
46
47 #[allow(dead_code)]
50 pub(crate) fn for_test() -> Self {
51 Self {
52 raw: std::ptr::null_mut(),
53 num_io: 0,
54 }
55 }
56
57 pub fn raw(&self) -> *mut sys::ICudaEngine {
58 self.raw
59 }
60
61 pub fn num_io_tensors(&self) -> usize {
62 self.num_io
63 }
64
65 pub fn io_tensor_name(&self, _idx: usize) -> Option<String> {
72 #[cfg(feature = "tensorrt-link")]
73 {
74 if _idx >= self.num_io {
75 return None;
76 }
77 unsafe {
78 let p = sys::atomr_trt_engine_io_tensor_name(self.raw, _idx as i32);
79 if p.is_null() {
80 return None;
81 }
82 let cstr = std::ffi::CStr::from_ptr(p);
83 cstr.to_str().ok().map(|s| s.to_string())
84 }
85 }
86 #[cfg(not(feature = "tensorrt-link"))]
87 {
88 None
89 }
90 }
91
92 pub fn into_shared(self) -> Arc<TrtEngine> {
95 Arc::new(self)
96 }
97}
98
99impl Drop for TrtEngine {
100 fn drop(&mut self) {
101 #[cfg(feature = "tensorrt-link")]
102 unsafe {
103 if !self.raw.is_null() {
104 sys::atomr_trt_engine_destroy(self.raw);
105 }
106 }
107 }
110}
111
/// Owned byte buffer holding an engine plan.
///
/// NOTE(review): presumably a serialized TensorRT engine; nothing in this
/// file inspects or validates the bytes.
#[derive(Debug, Clone)]
pub struct EnginePlan(pub Vec<u8>);

impl EnginePlan {
    /// Takes ownership of `bytes` as-is, without copying or validation.
    pub fn new(bytes: Vec<u8>) -> Self {
        EnginePlan(bytes)
    }

    /// Borrows the plan bytes as a slice.
    pub fn as_slice(&self) -> &[u8] {
        let EnginePlan(bytes) = self;
        bytes.as_slice()
    }
}
128
129pub struct TrtRefitter {
132 raw: *mut sys::IRefitter,
133}
134
135unsafe impl Send for TrtRefitter {}
136unsafe impl Sync for TrtRefitter {}
137
138impl TrtRefitter {
139 pub unsafe fn from_raw(raw: *mut sys::IRefitter) -> Result<Self, TrtError> {
142 if raw.is_null() {
143 Err(TrtError::Refit("null refitter".into()))
144 } else {
145 Ok(Self { raw })
146 }
147 }
148
149 #[allow(dead_code)]
150 pub(crate) fn for_test() -> Self {
151 Self {
152 raw: std::ptr::null_mut(),
153 }
154 }
155
156 pub fn raw(&self) -> *mut sys::IRefitter {
157 self.raw
158 }
159}
160
161impl Drop for TrtRefitter {
162 fn drop(&mut self) {
163 #[cfg(feature = "tensorrt-link")]
164 unsafe {
165 if !self.raw.is_null() {
166 sys::atomr_trt_refitter_destroy(self.raw);
167 }
168 }
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only when `T` is both `Send` and `Sync`; does nothing at runtime.
    fn assert_send_sync<T>()
    where
        T: Send + Sync,
    {
    }

    /// The handles must be shareable across threads, and the test-only engine
    /// must survive being moved into an `Arc`.
    #[test]
    fn engine_handle_send_sync() {
        assert_send_sync::<TrtEngine>();
        assert_send_sync::<Arc<TrtEngine>>();
        assert_send_sync::<TrtRefitter>();

        let engine = TrtEngine::for_test();
        assert_eq!(engine.num_io_tensors(), 0);

        let shared: Arc<TrtEngine> = engine.into_shared();
        assert!(Arc::strong_count(&shared) >= 1);
    }

    /// Bytes handed to `EnginePlan::new` come back unchanged from `as_slice`.
    #[test]
    fn engine_plan_round_trip() {
        let bytes = vec![0xDE, 0xAD, 0xBE, 0xEF];
        let plan = EnginePlan::new(bytes);
        assert_eq!(plan.as_slice(), &[0xDE, 0xAD, 0xBE, 0xEF]);
    }
}