atomr_accel_tensorrt/
engine.rs1use std::sync::Arc;
12
13use crate::error::TrtError;
14use crate::sys;
15
16pub struct TrtEngine {
22 raw: *mut sys::ICudaEngine,
23 num_io: usize,
25}
26
27unsafe impl Send for TrtEngine {}
30unsafe impl Sync for TrtEngine {}
31
32impl TrtEngine {
33 pub unsafe fn from_raw(raw: *mut sys::ICudaEngine, num_io: usize) -> Result<Self, TrtError> {
40 if raw.is_null() {
41 Err(TrtError::NullEngine)
42 } else {
43 Ok(Self { raw, num_io })
44 }
45 }
46
47 #[allow(dead_code)]
50 pub(crate) fn for_test() -> Self {
51 Self {
52 raw: std::ptr::null_mut(),
53 num_io: 0,
54 }
55 }
56
57 pub fn raw(&self) -> *mut sys::ICudaEngine {
58 self.raw
59 }
60
61 pub fn num_io_tensors(&self) -> usize {
62 self.num_io
63 }
64
65 pub fn into_shared(self) -> Arc<TrtEngine> {
68 Arc::new(self)
69 }
70}
71
72impl Drop for TrtEngine {
73 fn drop(&mut self) {
74 #[cfg(feature = "tensorrt-link")]
75 unsafe {
76 if !self.raw.is_null() {
77 sys::atomr_trt_engine_destroy(self.raw);
78 }
79 }
80 }
83}
84
/// Byte buffer holding an engine plan.
///
/// The bytes are opaque here and never inspected; presumably they are a serialized
/// TensorRT engine — confirm with the producer of the plan.
#[derive(Debug, Clone)]
pub struct EnginePlan(pub Vec<u8>);

impl EnginePlan {
    /// Wraps the given plan bytes without copying them.
    pub fn new(bytes: Vec<u8>) -> Self {
        EnginePlan(bytes)
    }

    /// Borrows the plan bytes.
    pub fn as_slice(&self) -> &[u8] {
        self.0.as_slice()
    }
}
101
102pub struct TrtRefitter {
105 raw: *mut sys::IRefitter,
106}
107
108unsafe impl Send for TrtRefitter {}
109unsafe impl Sync for TrtRefitter {}
110
111impl TrtRefitter {
112 pub unsafe fn from_raw(raw: *mut sys::IRefitter) -> Result<Self, TrtError> {
115 if raw.is_null() {
116 Err(TrtError::Refit("null refitter".into()))
117 } else {
118 Ok(Self { raw })
119 }
120 }
121
122 #[allow(dead_code)]
123 pub(crate) fn for_test() -> Self {
124 Self {
125 raw: std::ptr::null_mut(),
126 }
127 }
128
129 pub fn raw(&self) -> *mut sys::IRefitter {
130 self.raw
131 }
132}
133
134impl Drop for TrtRefitter {
135 fn drop(&mut self) {
136 #[cfg(feature = "tensorrt-link")]
137 unsafe {
138 if !self.raw.is_null() {
139 sys::atomr_trt_refitter_destroy(self.raw);
140 }
141 }
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only if `T` is both `Send` and `Sync`.
    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn engine_handle_send_sync() {
        assert_send_sync::<TrtEngine>();
        // `Arc<T>` is `Send + Sync` only when `T` is, so this checks the shared form too.
        assert_send_sync::<Arc<TrtEngine>>();
        assert_send_sync::<TrtRefitter>();

        let e = TrtEngine::for_test();
        assert_eq!(e.num_io_tensors(), 0);
        assert!(e.raw().is_null());
        let shared: Arc<TrtEngine> = e.into_shared();
        // A freshly created `Arc` has exactly one strong reference; `>= 1` could never fail.
        assert_eq!(Arc::strong_count(&shared), 1);
    }

    #[test]
    fn refitter_for_test_is_null_and_drops_cleanly() {
        let r = TrtRefitter::for_test();
        assert!(r.raw().is_null());
        // Dropping a null handle must not call into TensorRT.
        drop(r);
    }

    #[test]
    fn engine_plan_round_trip() {
        let plan = EnginePlan::new(vec![0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(plan.as_slice(), &[0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(plan.clone().as_slice(), plan.as_slice());
    }

    #[test]
    fn empty_engine_plan() {
        assert!(EnginePlan::new(Vec::new()).as_slice().is_empty());
    }
}