atomr_accel_tensorrt/
runtime.rs1#![allow(dead_code)]
11
12use std::collections::HashMap;
13use std::sync::Arc;
14
15use tokio::sync::oneshot;
16
17use crate::engine::TrtEngine;
18use crate::error::TrtError;
19use crate::sys;
20
/// Maximum tensor rank TensorRT accepts; sizes the fixed `dims` array below.
const TRT_MAX_DIMS: usize = 8;

/// A tensor shape stored in a fixed-capacity array of up to
/// [`TensorShape::MAX_DIMS`] extents.
///
/// Only the first `nb_dims` entries of `dims` are meaningful. Shapes built via
/// [`TensorShape::new`] / [`TensorShape::try_new`] zero-fill the remaining
/// slots. Because the fields are public, code that writes them directly must
/// keep `nb_dims <= MAX_DIMS`, otherwise [`TensorShape::as_slice`] panics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TensorShape {
    /// Number of valid leading entries in `dims`.
    pub nb_dims: usize,
    /// Dimension extents; entries at index `nb_dims` and beyond are padding.
    pub dims: [i32; TRT_MAX_DIMS],
}

impl TensorShape {
    /// Maximum number of dimensions a `TensorShape` can hold.
    pub const MAX_DIMS: usize = TRT_MAX_DIMS;

    /// Builds a shape from `dims`, or returns `None` if it has more than
    /// [`Self::MAX_DIMS`] entries.
    ///
    /// Unused trailing slots are zero-filled.
    pub fn try_new(dims: &[i32]) -> Option<Self> {
        (dims.len() <= TRT_MAX_DIMS).then(|| {
            let mut padded = [0i32; TRT_MAX_DIMS];
            padded[..dims.len()].copy_from_slice(dims);
            Self {
                nb_dims: dims.len(),
                dims: padded,
            }
        })
    }

    /// Builds a shape from `dims`, zero-filling unused trailing slots.
    ///
    /// # Panics
    ///
    /// Panics if `dims.len()` exceeds [`Self::MAX_DIMS`]. Use
    /// [`Self::try_new`] to handle that case without panicking.
    #[track_caller]
    pub fn new(dims: &[i32]) -> Self {
        Self::try_new(dims).expect("TensorRT supports at most 8 dimensions")
    }

    /// Returns the meaningful dimensions (the first `nb_dims` entries).
    pub fn as_slice(&self) -> &[i32] {
        &self.dims[..self.nb_dims]
    }
}
44
45#[derive(Debug, Clone, Default)]
49pub struct ExecutionBindings {
50 pub addresses: HashMap<String, u64>,
51 pub input_shapes: HashMap<String, TensorShape>,
52}
53
54impl ExecutionBindings {
55 pub fn new() -> Self {
56 Self::default()
57 }
58
59 pub fn bind(&mut self, name: impl Into<String>, device_ptr: u64) -> &mut Self {
60 self.addresses.insert(name.into(), device_ptr);
61 self
62 }
63
64 pub fn set_shape(&mut self, name: impl Into<String>, shape: TensorShape) -> &mut Self {
65 self.input_shapes.insert(name.into(), shape);
66 self
67 }
68}
69
70pub struct ExecutionContext {
74 raw: *mut sys::IExecutionContext,
75 engine: Arc<TrtEngine>,
76}
77
78unsafe impl Send for ExecutionContext {}
83unsafe impl Sync for ExecutionContext {}
84
85impl ExecutionContext {
86 pub unsafe fn from_raw(
90 raw: *mut sys::IExecutionContext,
91 engine: Arc<TrtEngine>,
92 ) -> Result<Self, TrtError> {
93 if raw.is_null() {
94 Err(TrtError::Execution("null execution context".into()))
95 } else {
96 Ok(Self { raw, engine })
97 }
98 }
99
100 pub(crate) fn for_test(engine: Arc<TrtEngine>) -> Self {
101 Self {
102 raw: std::ptr::null_mut(),
103 engine,
104 }
105 }
106
107 pub fn raw(&self) -> *mut sys::IExecutionContext {
108 self.raw
109 }
110
111 pub fn engine(&self) -> &Arc<TrtEngine> {
112 &self.engine
113 }
114}
115
116impl Drop for ExecutionContext {
117 fn drop(&mut self) {
118 #[cfg(feature = "tensorrt-link")]
119 unsafe {
120 if !self.raw.is_null() {
121 sys::atomr_trt_context_destroy(self.raw);
122 }
123 }
124 }
125}
126
127pub struct TrtRuntime {
129 raw: *mut sys::IRuntime,
130}
131
132unsafe impl Send for TrtRuntime {}
133unsafe impl Sync for TrtRuntime {}
134
135impl TrtRuntime {
136 pub fn new() -> Result<Self, TrtError> {
139 #[cfg(feature = "tensorrt-link")]
140 {
141 let raw = unsafe { sys::atomr_trt_runtime_create(0) };
142 if raw.is_null() {
143 Err(TrtError::Runtime("runtime create returned null".into()))
144 } else {
145 Ok(Self { raw })
146 }
147 }
148 #[cfg(not(feature = "tensorrt-link"))]
149 {
150 Err(TrtError::NotLinked(
151 "TrtRuntime requires the `tensorrt-link` feature",
152 ))
153 }
154 }
155
156 pub(crate) fn for_test() -> Self {
157 Self {
158 raw: std::ptr::null_mut(),
159 }
160 }
161
162 pub fn deserialize(&self, _plan: &[u8]) -> Result<TrtEngine, TrtError> {
165 #[cfg(feature = "tensorrt-link")]
166 {
167 let raw = unsafe {
168 sys::atomr_trt_runtime_deserialize(self.raw, _plan.as_ptr(), _plan.len())
169 };
170 if raw.is_null() {
171 Err(TrtError::Runtime("deserialize returned null".into()))
172 } else {
173 let num_io = unsafe { sys::atomr_trt_engine_num_io_tensors(raw) } as usize;
174 unsafe { TrtEngine::from_raw(raw, num_io) }
175 }
176 }
177 #[cfg(not(feature = "tensorrt-link"))]
178 {
179 Err(TrtError::NotLinked(
180 "TrtRuntime::deserialize requires the `tensorrt-link` feature",
181 ))
182 }
183 }
184}
185
186impl Drop for TrtRuntime {
187 fn drop(&mut self) {
188 #[cfg(feature = "tensorrt-link")]
189 unsafe {
190 if !self.raw.is_null() {
191 sys::atomr_trt_runtime_destroy(self.raw);
192 }
193 }
194 }
195}
196
197pub type EnqueueReply = Result<(), TrtError>;
201
202pub struct EnqueueRequest {
207 pub bindings: ExecutionBindings,
208 pub stream: Arc<cudarc::driver::CudaStream>,
209 pub reply: oneshot::Sender<EnqueueReply>,
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only if `T` is both `Send` and `Sync`.
    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn execution_context_msg_round_trip() {
        let engine = Arc::new(TrtEngine::for_test());
        let ctx = ExecutionContext::for_test(engine.clone());
        assert!(Arc::ptr_eq(ctx.engine(), &engine));

        let mut bindings = ExecutionBindings::new();
        bindings
            .bind("input", 0xDEADBEEF)
            .set_shape("input", TensorShape::new(&[1, 3, 224, 224]));
        assert_eq!(bindings.addresses.get("input").copied(), Some(0xDEADBEEF));
        assert_eq!(
            bindings.input_shapes.get("input").map(|s| s.as_slice()),
            Some(&[1i32, 3, 224, 224][..])
        );

        assert_send_sync::<ExecutionBindings>();
        assert_send_sync::<TrtRuntime>();
        assert_send_sync::<ExecutionContext>();
    }

    #[test]
    fn shape_round_trip() {
        let s = TensorShape::new(&[2, 4, 8]);
        assert_eq!(s.nb_dims, 3);
        assert_eq!(s.as_slice(), &[2, 4, 8]);
    }

    #[test]
    fn shape_accepts_max_rank_and_zero_pads() {
        let full = TensorShape::new(&[1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(full.nb_dims, 8);
        assert_eq!(full.as_slice(), &[1, 2, 3, 4, 5, 6, 7, 8]);

        let partial = TensorShape::new(&[5, 6]);
        assert_eq!(&partial.dims[2..], &[0; 6]);

        let scalar = TensorShape::new(&[]);
        assert_eq!(scalar.nb_dims, 0);
        assert!(scalar.as_slice().is_empty());
    }

    #[test]
    #[should_panic(expected = "at most 8 dimensions")]
    fn shape_rejects_more_than_eight_dims() {
        let _ = TensorShape::new(&[1; 9]);
    }

    // Gated at the fn level so the test is not reported as passing (vacuously)
    // when the `tensorrt-link` feature is enabled.
    #[cfg(not(feature = "tensorrt-link"))]
    #[test]
    fn runtime_unlinked_returns_not_linked() {
        let r = TrtRuntime::new();
        assert!(matches!(r, Err(TrtError::NotLinked(_))));
    }

    #[cfg(not(feature = "tensorrt-link"))]
    #[test]
    fn deserialize_unlinked_returns_not_linked() {
        let rt = TrtRuntime::for_test();
        assert!(matches!(
            rt.deserialize(b"plan"),
            Err(TrtError::NotLinked(_))
        ));
    }
}