Skip to main content

baracuda_tensorrt/
lib.rs

1//! Safe-ish TensorRT runtime-side bindings.
2//!
3//! Scope is inference only: load a serialized engine (e.g. produced by
4//! `trtexec` or the TRT Python API), build an execution context, bind tensor
5//! addresses, and enqueue execution on a CUDA stream. Engine construction
6//! (the builder / network definition API) is C++-only and is not wrapped here.
7
8#![warn(missing_debug_implementations, rust_2018_idioms)]
9
10use std::ffi::{CStr, CString};
11use std::marker::PhantomData;
12use std::os::raw::c_void;
13
14use baracuda_cuda_sys::runtime::cudaStream_t;
15use baracuda_tensorrt_sys as sys;
16
/// Errors surfaced by the TensorRT wrappers in this file.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// A failure reported by the `baracuda_core` loader; converted
    /// automatically through `From` so it can be propagated with `?`.
    #[error("TensorRT loader: {0}")]
    Loader(#[from] baracuda_core::LoaderError),
    /// A TensorRT call that should have produced a handle or pointer returned
    /// null instead. `op` names the operation that was attempted.
    #[error("TensorRT returned null for {op}")]
    NullHandle { op: &'static str },
    /// A TensorRT call reported failure. `op` names the operation that was
    /// attempted.
    #[error("TensorRT call failed: {op}")]
    Call { op: &'static str },
    /// A Rust string could not be converted into a C string.
    ///
    /// Despite the variant name, the wrapped error is a
    /// [`std::ffi::NulError`], i.e. the string contained an interior NUL
    /// byte; it is not a UTF-8 decoding failure.
    #[error("invalid C string: {0}")]
    Utf8(#[from] std::ffi::NulError),
}
28
/// Shorthand for `std::result::Result` with this file's [`Error`] as the
/// error type.
pub type Result<T> = std::result::Result<T, Error>;
30
// Re-export the raw enum types from the `sys` crate under shorter names so
// callers of the wrappers below do not have to depend on `sys` directly.
pub use sys::{
    trtDataType_t as DataType, trtExecutionContextAllocationStrategy_t as AllocStrategy,
    trtSeverity_t as Severity, trtTensorIOMode_t as IoMode,
};
35
36/// TensorRT library version, encoded as `MAJOR * 1000 + MINOR * 100 + PATCH`.
37pub fn version() -> Result<i32> {
38    let t = sys::tensorrt()?;
39    Ok(unsafe { t.get_infer_lib_version()?() })
40}
41
42/// A dimension list up to 8 axes.
43#[derive(Copy, Clone, Debug, PartialEq, Eq)]
44pub struct Dims {
45    pub dims: [i64; sys::TRT_MAX_DIMS],
46    pub rank: usize,
47}
48
49impl Dims {
50    pub fn new(dims: &[i64]) -> Self {
51        let mut out = Dims {
52            dims: [0; sys::TRT_MAX_DIMS],
53            rank: dims.len().min(sys::TRT_MAX_DIMS),
54        };
55        out.dims[..out.rank].copy_from_slice(&dims[..out.rank]);
56        out
57    }
58    pub fn as_slice(&self) -> &[i64] {
59        &self.dims[..self.rank]
60    }
61    fn to_raw(self) -> sys::trtDims_t {
62        sys::trtDims_t {
63            nb_dims: self.rank as i32,
64            d: self.dims,
65        }
66    }
67    fn from_raw(raw: sys::trtDims_t) -> Self {
68        let mut out = Dims {
69            dims: [0; sys::TRT_MAX_DIMS],
70            rank: raw.nb_dims.max(0) as usize,
71        };
72        for i in 0..out.rank {
73            out.dims[i] = raw.d[i];
74        }
75        out
76    }
77}
78
79/// Owned TensorRT runtime. Created around a user-supplied logger (the logger
80/// pointer is passed verbatim; safety responsibility is on the caller).
81#[derive(Debug)]
82pub struct Runtime {
83    raw: sys::trtIRuntime_t,
84}
85
86impl Runtime {
87    /// # Safety
88    /// `logger` must be a valid `nvinfer1::ILogger*` (typically obtained from
89    /// C++ or passed through a thin shim). Use [`Runtime::with_null_logger`]
90    /// if no logging is desired (TRT allows `nullptr` in recent versions).
91    pub unsafe fn new(logger: sys::trtILogger_t) -> Result<Self> { unsafe {
92        let t = sys::tensorrt()?;
93        let raw = (t.create_infer_runtime()?)(logger);
94        if raw.is_null() {
95            return Err(Error::NullHandle {
96                op: "createInferRuntime",
97            });
98        }
99        Ok(Self { raw })
100    }}
101
102    /// Construct without a logger. Supported on recent TensorRT; older
103    /// versions may refuse and return null.
104    pub fn with_null_logger() -> Result<Self> {
105        unsafe { Self::new(core::ptr::null_mut()) }
106    }
107
108    pub fn deserialize(&self, blob: &[u8]) -> Result<Engine<'_>> {
109        let t = sys::tensorrt()?;
110        let raw = unsafe {
111            (t.deserialize_cuda_engine()?)(self.raw, blob.as_ptr() as *const c_void, blob.len())
112        };
113        if raw.is_null() {
114            return Err(Error::NullHandle {
115                op: "deserializeCudaEngine",
116            });
117        }
118        Ok(Engine {
119            raw,
120            _owner: PhantomData,
121        })
122    }
123
124    pub fn as_raw(&self) -> sys::trtIRuntime_t {
125        self.raw
126    }
127}
128
129impl Drop for Runtime {
130    fn drop(&mut self) {
131        if let Ok(t) = sys::tensorrt() {
132            if let Ok(f) = t.destroy_infer_runtime() {
133                unsafe { f(self.raw) };
134            }
135        }
136    }
137}
138
139/// Represents a deserialized TensorRT engine borrowed from its parent runtime.
140#[derive(Debug)]
141pub struct Engine<'rt> {
142    raw: sys::trtICudaEngine_t,
143    _owner: PhantomData<&'rt Runtime>,
144}
145
146impl Engine<'_> {
147    pub fn as_raw(&self) -> sys::trtICudaEngine_t {
148        self.raw
149    }
150
151    pub fn num_io_tensors(&self) -> Result<i32> {
152        let t = sys::tensorrt()?;
153        Ok(unsafe { (t.engine_get_nb_io_tensors()?)(self.raw) })
154    }
155
156    pub fn io_tensor_name(&self, index: i32) -> Result<String> {
157        let t = sys::tensorrt()?;
158        let cstr = unsafe { (t.engine_get_io_tensor_name()?)(self.raw, index) };
159        if cstr.is_null() {
160            return Err(Error::NullHandle {
161                op: "getIOTensorName",
162            });
163        }
164        Ok(unsafe { CStr::from_ptr(cstr) }.to_string_lossy().into_owned())
165    }
166
167    pub fn tensor_io_mode(&self, name: &str) -> Result<IoMode> {
168        let t = sys::tensorrt()?;
169        let c = CString::new(name)?;
170        Ok(unsafe { (t.engine_get_tensor_io_mode()?)(self.raw, c.as_ptr()) })
171    }
172
173    pub fn tensor_data_type(&self, name: &str) -> Result<DataType> {
174        let t = sys::tensorrt()?;
175        let c = CString::new(name)?;
176        Ok(unsafe { (t.engine_get_tensor_data_type()?)(self.raw, c.as_ptr()) })
177    }
178
179    pub fn tensor_shape(&self, name: &str) -> Result<Dims> {
180        let t = sys::tensorrt()?;
181        let c = CString::new(name)?;
182        let raw = unsafe { (t.engine_get_tensor_shape()?)(self.raw, c.as_ptr()) };
183        Ok(Dims::from_raw(raw))
184    }
185
186    pub fn create_execution_context(&self) -> Result<ExecutionContext<'_>> {
187        let t = sys::tensorrt()?;
188        let raw = unsafe { (t.engine_create_execution_context()?)(self.raw) };
189        if raw.is_null() {
190            return Err(Error::NullHandle {
191                op: "createExecutionContext",
192            });
193        }
194        Ok(ExecutionContext {
195            raw,
196            _owner: PhantomData,
197        })
198    }
199
200    /// Create an execution context with a user-chosen allocation strategy.
201    /// Use [`AllocStrategy::UserManaged`] when you intend to supply a
202    /// scratch-allocator yourself; [`AllocStrategy::Static`] preallocates
203    /// the maximum workspace at context-creation time.
204    pub fn create_execution_context_with_strategy(
205        &self,
206        strategy: AllocStrategy,
207    ) -> Result<ExecutionContext<'_>> {
208        let t = sys::tensorrt()?;
209        let raw = unsafe {
210            (t.engine_create_execution_context_with_strategy()?)(self.raw, strategy)
211        };
212        if raw.is_null() {
213            return Err(Error::NullHandle {
214                op: "createExecutionContextWithStrategy",
215            });
216        }
217        Ok(ExecutionContext {
218            raw,
219            _owner: PhantomData,
220        })
221    }
222
223    /// Engine name as set in the TensorRT builder.
224    pub fn name(&self) -> Result<String> {
225        let t = sys::tensorrt()?;
226        let cstr = unsafe { (t.engine_get_name()?)(self.raw) };
227        if cstr.is_null() {
228            return Err(Error::NullHandle { op: "engineGetName" });
229        }
230        Ok(unsafe { CStr::from_ptr(cstr) }
231            .to_string_lossy()
232            .into_owned())
233    }
234
235    /// Number of optimization profiles that were baked into the engine.
236    pub fn num_optimization_profiles(&self) -> Result<i32> {
237        let t = sys::tensorrt()?;
238        Ok(unsafe { (t.engine_get_nb_optimization_profiles()?)(self.raw) })
239    }
240
241    /// Serialize this engine back into a byte blob you can round-trip to
242    /// disk. The returned [`HostMemory`] owns TensorRT-allocated storage;
243    /// use [`HostMemory::as_slice`] to copy.
244    pub fn serialize(&self) -> Result<HostMemory> {
245        let t = sys::tensorrt()?;
246        let raw = unsafe { (t.engine_serialize()?)(self.raw) };
247        if raw.is_null() {
248            return Err(Error::NullHandle {
249                op: "engineSerialize",
250            });
251        }
252        Ok(HostMemory { raw })
253    }
254}
255
256/// TensorRT-owned host buffer (as returned by [`Engine::serialize`]).
257#[derive(Debug)]
258pub struct HostMemory {
259    raw: sys::trtIHostMemory_t,
260}
261
262impl HostMemory {
263    pub fn len(&self) -> Result<usize> {
264        let t = sys::tensorrt()?;
265        Ok(unsafe { (t.host_memory_size()?)(self.raw) })
266    }
267
268    pub fn is_empty(&self) -> Result<bool> {
269        Ok(self.len()? == 0)
270    }
271
272    pub fn as_slice(&self) -> Result<&[u8]> {
273        let t = sys::tensorrt()?;
274        let ptr = unsafe { (t.host_memory_data()?)(self.raw) };
275        let len = self.len()?;
276        if ptr.is_null() || len == 0 {
277            return Ok(&[]);
278        }
279        Ok(unsafe { core::slice::from_raw_parts(ptr as *const u8, len) })
280    }
281}
282
283impl Drop for HostMemory {
284    fn drop(&mut self) {
285        if let Ok(t) = sys::tensorrt() {
286            if let Ok(d) = t.host_memory_destroy() {
287                unsafe { d(self.raw) };
288            }
289        }
290    }
291}
292
293impl Drop for Engine<'_> {
294    fn drop(&mut self) {
295        if let Ok(t) = sys::tensorrt() {
296            if let Ok(f) = t.destroy_cuda_engine() {
297                unsafe { f(self.raw) };
298            }
299        }
300    }
301}
302
303#[derive(Debug)]
304pub struct ExecutionContext<'e> {
305    raw: sys::trtIExecutionContext_t,
306    _owner: PhantomData<&'e Engine<'e>>,
307}
308
309impl ExecutionContext<'_> {
310    pub fn as_raw(&self) -> sys::trtIExecutionContext_t {
311        self.raw
312    }
313
314    pub fn set_input_shape(&self, name: &str, dims: Dims) -> Result<()> {
315        let t = sys::tensorrt()?;
316        let c = CString::new(name)?;
317        let raw_dims = dims.to_raw();
318        let ok = unsafe { (t.context_set_input_shape()?)(self.raw, c.as_ptr(), &raw_dims) };
319        if !ok {
320            return Err(Error::Call {
321                op: "setInputShape",
322            });
323        }
324        Ok(())
325    }
326
327    /// Bind a device pointer to a named input/output tensor on this
328    /// execution context. The pointer is forwarded to TensorRT's
329    /// `setTensorAddress`, which uses it during `enqueueV3` execution.
330    ///
331    /// # Safety
332    ///
333    /// `addr` must point to device memory that:
334    /// - is large enough for the tensor's bound shape and data type, and
335    /// - remains valid (not freed, not unmapped) for the duration of any
336    ///   `enqueueV3` call that runs after this binding and before the
337    ///   stream completes.
338    pub unsafe fn set_tensor_address(&self, name: &str, addr: *mut c_void) -> Result<()> {
339        let t = sys::tensorrt()?;
340        let c = CString::new(name)?;
341        let ok = unsafe { (t.context_set_tensor_address()?)(self.raw, c.as_ptr(), addr) };
342        if !ok {
343            return Err(Error::Call {
344                op: "setTensorAddress",
345            });
346        }
347        Ok(())
348    }
349
350    pub fn tensor_shape(&self, name: &str) -> Result<Dims> {
351        let t = sys::tensorrt()?;
352        let c = CString::new(name)?;
353        let raw = unsafe { (t.context_get_tensor_shape()?)(self.raw, c.as_ptr()) };
354        Ok(Dims::from_raw(raw))
355    }
356
357    /// Read the current bound device address for a tensor (null if unset).
358    pub fn tensor_address(&self, name: &str) -> Result<*mut c_void> {
359        let t = sys::tensorrt()?;
360        let c = CString::new(name)?;
361        Ok(unsafe { (t.context_get_tensor_address()?)(self.raw, c.as_ptr()) })
362    }
363
364    /// Enqueue the inference on the given CUDA stream. Returns Ok if TRT
365    /// reports success; the stream is still responsible for ordering, and the
366    /// caller must ensure all tensor addresses have been set.
367    ///
368    /// # Safety
369    /// `stream` must be a valid `cudaStream_t` that outlives the enqueue.
370    pub unsafe fn enqueue_v3(&self, stream: cudaStream_t) -> Result<()> { unsafe {
371        let t = sys::tensorrt()?;
372        let ok = (t.context_enqueue_v3()?)(self.raw, stream);
373        if !ok {
374            return Err(Error::Call { op: "enqueueV3" });
375        }
376        Ok(())
377    }}
378}
379
380impl Drop for ExecutionContext<'_> {
381    fn drop(&mut self) {
382        if let Ok(t) = sys::tensorrt() {
383            if let Ok(f) = t.destroy_execution_context() {
384                unsafe { f(self.raw) };
385            }
386        }
387    }
388}