1#![warn(missing_debug_implementations, rust_2018_idioms)]
9
10use std::ffi::{CStr, CString};
11use std::marker::PhantomData;
12use std::os::raw::c_void;
13
14use baracuda_cuda_sys::runtime::cudaStream_t;
15use baracuda_tensorrt_sys as sys;
16
17#[derive(Debug, thiserror::Error)]
18pub enum Error {
19 #[error("TensorRT loader: {0}")]
20 Loader(#[from] baracuda_core::LoaderError),
21 #[error("TensorRT returned null for {op}")]
22 NullHandle { op: &'static str },
23 #[error("TensorRT call failed: {op}")]
24 Call { op: &'static str },
25 #[error("invalid C string: {0}")]
26 Utf8(#[from] std::ffi::NulError),
27}
28
29pub type Result<T> = std::result::Result<T, Error>;
30
31pub use sys::{
32 trtDataType_t as DataType, trtExecutionContextAllocationStrategy_t as AllocStrategy,
33 trtSeverity_t as Severity, trtTensorIOMode_t as IoMode,
34};
35
36pub fn version() -> Result<i32> {
38 let t = sys::tensorrt()?;
39 Ok(unsafe { t.get_infer_lib_version()?() })
40}
41
42#[derive(Copy, Clone, Debug, PartialEq, Eq)]
44pub struct Dims {
45 pub dims: [i64; sys::TRT_MAX_DIMS],
46 pub rank: usize,
47}
48
49impl Dims {
50 pub fn new(dims: &[i64]) -> Self {
51 let mut out = Dims {
52 dims: [0; sys::TRT_MAX_DIMS],
53 rank: dims.len().min(sys::TRT_MAX_DIMS),
54 };
55 out.dims[..out.rank].copy_from_slice(&dims[..out.rank]);
56 out
57 }
58 pub fn as_slice(&self) -> &[i64] {
59 &self.dims[..self.rank]
60 }
61 fn to_raw(self) -> sys::trtDims_t {
62 sys::trtDims_t {
63 nb_dims: self.rank as i32,
64 d: self.dims,
65 }
66 }
67 fn from_raw(raw: sys::trtDims_t) -> Self {
68 let mut out = Dims {
69 dims: [0; sys::TRT_MAX_DIMS],
70 rank: raw.nb_dims.max(0) as usize,
71 };
72 for i in 0..out.rank {
73 out.dims[i] = raw.d[i];
74 }
75 out
76 }
77}
78
79#[derive(Debug)]
82pub struct Runtime {
83 raw: sys::trtIRuntime_t,
84}
85
86impl Runtime {
87 pub unsafe fn new(logger: sys::trtILogger_t) -> Result<Self> { unsafe {
92 let t = sys::tensorrt()?;
93 let raw = (t.create_infer_runtime()?)(logger);
94 if raw.is_null() {
95 return Err(Error::NullHandle {
96 op: "createInferRuntime",
97 });
98 }
99 Ok(Self { raw })
100 }}
101
102 pub fn with_null_logger() -> Result<Self> {
105 unsafe { Self::new(core::ptr::null_mut()) }
106 }
107
108 pub fn deserialize(&self, blob: &[u8]) -> Result<Engine<'_>> {
109 let t = sys::tensorrt()?;
110 let raw = unsafe {
111 (t.deserialize_cuda_engine()?)(self.raw, blob.as_ptr() as *const c_void, blob.len())
112 };
113 if raw.is_null() {
114 return Err(Error::NullHandle {
115 op: "deserializeCudaEngine",
116 });
117 }
118 Ok(Engine {
119 raw,
120 _owner: PhantomData,
121 })
122 }
123
124 pub fn as_raw(&self) -> sys::trtIRuntime_t {
125 self.raw
126 }
127}
128
129impl Drop for Runtime {
130 fn drop(&mut self) {
131 if let Ok(t) = sys::tensorrt() {
132 if let Ok(f) = t.destroy_infer_runtime() {
133 unsafe { f(self.raw) };
134 }
135 }
136 }
137}
138
139#[derive(Debug)]
141pub struct Engine<'rt> {
142 raw: sys::trtICudaEngine_t,
143 _owner: PhantomData<&'rt Runtime>,
144}
145
146impl Engine<'_> {
147 pub fn as_raw(&self) -> sys::trtICudaEngine_t {
148 self.raw
149 }
150
151 pub fn num_io_tensors(&self) -> Result<i32> {
152 let t = sys::tensorrt()?;
153 Ok(unsafe { (t.engine_get_nb_io_tensors()?)(self.raw) })
154 }
155
156 pub fn io_tensor_name(&self, index: i32) -> Result<String> {
157 let t = sys::tensorrt()?;
158 let cstr = unsafe { (t.engine_get_io_tensor_name()?)(self.raw, index) };
159 if cstr.is_null() {
160 return Err(Error::NullHandle {
161 op: "getIOTensorName",
162 });
163 }
164 Ok(unsafe { CStr::from_ptr(cstr) }.to_string_lossy().into_owned())
165 }
166
167 pub fn tensor_io_mode(&self, name: &str) -> Result<IoMode> {
168 let t = sys::tensorrt()?;
169 let c = CString::new(name)?;
170 Ok(unsafe { (t.engine_get_tensor_io_mode()?)(self.raw, c.as_ptr()) })
171 }
172
173 pub fn tensor_data_type(&self, name: &str) -> Result<DataType> {
174 let t = sys::tensorrt()?;
175 let c = CString::new(name)?;
176 Ok(unsafe { (t.engine_get_tensor_data_type()?)(self.raw, c.as_ptr()) })
177 }
178
179 pub fn tensor_shape(&self, name: &str) -> Result<Dims> {
180 let t = sys::tensorrt()?;
181 let c = CString::new(name)?;
182 let raw = unsafe { (t.engine_get_tensor_shape()?)(self.raw, c.as_ptr()) };
183 Ok(Dims::from_raw(raw))
184 }
185
186 pub fn create_execution_context(&self) -> Result<ExecutionContext<'_>> {
187 let t = sys::tensorrt()?;
188 let raw = unsafe { (t.engine_create_execution_context()?)(self.raw) };
189 if raw.is_null() {
190 return Err(Error::NullHandle {
191 op: "createExecutionContext",
192 });
193 }
194 Ok(ExecutionContext {
195 raw,
196 _owner: PhantomData,
197 })
198 }
199
200 pub fn create_execution_context_with_strategy(
205 &self,
206 strategy: AllocStrategy,
207 ) -> Result<ExecutionContext<'_>> {
208 let t = sys::tensorrt()?;
209 let raw = unsafe {
210 (t.engine_create_execution_context_with_strategy()?)(self.raw, strategy)
211 };
212 if raw.is_null() {
213 return Err(Error::NullHandle {
214 op: "createExecutionContextWithStrategy",
215 });
216 }
217 Ok(ExecutionContext {
218 raw,
219 _owner: PhantomData,
220 })
221 }
222
223 pub fn name(&self) -> Result<String> {
225 let t = sys::tensorrt()?;
226 let cstr = unsafe { (t.engine_get_name()?)(self.raw) };
227 if cstr.is_null() {
228 return Err(Error::NullHandle { op: "engineGetName" });
229 }
230 Ok(unsafe { CStr::from_ptr(cstr) }
231 .to_string_lossy()
232 .into_owned())
233 }
234
235 pub fn num_optimization_profiles(&self) -> Result<i32> {
237 let t = sys::tensorrt()?;
238 Ok(unsafe { (t.engine_get_nb_optimization_profiles()?)(self.raw) })
239 }
240
241 pub fn serialize(&self) -> Result<HostMemory> {
245 let t = sys::tensorrt()?;
246 let raw = unsafe { (t.engine_serialize()?)(self.raw) };
247 if raw.is_null() {
248 return Err(Error::NullHandle {
249 op: "engineSerialize",
250 });
251 }
252 Ok(HostMemory { raw })
253 }
254}
255
256#[derive(Debug)]
258pub struct HostMemory {
259 raw: sys::trtIHostMemory_t,
260}
261
262impl HostMemory {
263 pub fn len(&self) -> Result<usize> {
264 let t = sys::tensorrt()?;
265 Ok(unsafe { (t.host_memory_size()?)(self.raw) })
266 }
267
268 pub fn is_empty(&self) -> Result<bool> {
269 Ok(self.len()? == 0)
270 }
271
272 pub fn as_slice(&self) -> Result<&[u8]> {
273 let t = sys::tensorrt()?;
274 let ptr = unsafe { (t.host_memory_data()?)(self.raw) };
275 let len = self.len()?;
276 if ptr.is_null() || len == 0 {
277 return Ok(&[]);
278 }
279 Ok(unsafe { core::slice::from_raw_parts(ptr as *const u8, len) })
280 }
281}
282
283impl Drop for HostMemory {
284 fn drop(&mut self) {
285 if let Ok(t) = sys::tensorrt() {
286 if let Ok(d) = t.host_memory_destroy() {
287 unsafe { d(self.raw) };
288 }
289 }
290 }
291}
292
293impl Drop for Engine<'_> {
294 fn drop(&mut self) {
295 if let Ok(t) = sys::tensorrt() {
296 if let Ok(f) = t.destroy_cuda_engine() {
297 unsafe { f(self.raw) };
298 }
299 }
300 }
301}
302
303#[derive(Debug)]
304pub struct ExecutionContext<'e> {
305 raw: sys::trtIExecutionContext_t,
306 _owner: PhantomData<&'e Engine<'e>>,
307}
308
309impl ExecutionContext<'_> {
310 pub fn as_raw(&self) -> sys::trtIExecutionContext_t {
311 self.raw
312 }
313
314 pub fn set_input_shape(&self, name: &str, dims: Dims) -> Result<()> {
315 let t = sys::tensorrt()?;
316 let c = CString::new(name)?;
317 let raw_dims = dims.to_raw();
318 let ok = unsafe { (t.context_set_input_shape()?)(self.raw, c.as_ptr(), &raw_dims) };
319 if !ok {
320 return Err(Error::Call {
321 op: "setInputShape",
322 });
323 }
324 Ok(())
325 }
326
327 pub unsafe fn set_tensor_address(&self, name: &str, addr: *mut c_void) -> Result<()> {
339 let t = sys::tensorrt()?;
340 let c = CString::new(name)?;
341 let ok = unsafe { (t.context_set_tensor_address()?)(self.raw, c.as_ptr(), addr) };
342 if !ok {
343 return Err(Error::Call {
344 op: "setTensorAddress",
345 });
346 }
347 Ok(())
348 }
349
350 pub fn tensor_shape(&self, name: &str) -> Result<Dims> {
351 let t = sys::tensorrt()?;
352 let c = CString::new(name)?;
353 let raw = unsafe { (t.context_get_tensor_shape()?)(self.raw, c.as_ptr()) };
354 Ok(Dims::from_raw(raw))
355 }
356
357 pub fn tensor_address(&self, name: &str) -> Result<*mut c_void> {
359 let t = sys::tensorrt()?;
360 let c = CString::new(name)?;
361 Ok(unsafe { (t.context_get_tensor_address()?)(self.raw, c.as_ptr()) })
362 }
363
364 pub unsafe fn enqueue_v3(&self, stream: cudaStream_t) -> Result<()> { unsafe {
371 let t = sys::tensorrt()?;
372 let ok = (t.context_enqueue_v3()?)(self.raw, stream);
373 if !ok {
374 return Err(Error::Call { op: "enqueueV3" });
375 }
376 Ok(())
377 }}
378}
379
380impl Drop for ExecutionContext<'_> {
381 fn drop(&mut self) {
382 if let Ok(t) = sys::tensorrt() {
383 if let Ok(f) = t.destroy_execution_context() {
384 unsafe { f(self.raw) };
385 }
386 }
387 }
388}