async_tensorrt/engine.rs

use async_cuda::runtime::Future;
use async_cuda::{DeviceBuffer, Stream};

use crate::ffi::memory::HostBuffer;
use crate::ffi::sync::engine::Engine as InnerEngine;
use crate::ffi::sync::engine::ExecutionContext as InnerExecutionContext;

pub use crate::ffi::sync::engine::TensorIoMode;

type Result<T> = std::result::Result<T, crate::error::Error>;

/// Engine for executing inference on a built network.
///
/// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html)
pub struct Engine {
    inner: InnerEngine,
}

impl Engine {
    /// Create [`Engine`] from its inner object.
    pub fn from_inner(inner: InnerEngine) -> Self {
        Self { inner }
    }

    /// Serialize the network.
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#ab42c2fde3292f557ed17aae6f332e571)
    ///
    /// # Return value
    ///
    /// A [`HostBuffer`] that contains the serialized engine.
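    ///
    /// # Example
    ///
    /// A minimal sketch of persisting the serialized engine to disk. It assumes `engine` is an
    /// already-built [`Engine`], that the crate root re-exports `Engine`, and that the output
    /// path is arbitrary.
    ///
    /// ```no_run
    /// # fn save(engine: &async_tensorrt::Engine) {
    /// let serialized = engine.serialize().unwrap();
    /// // The host buffer exposes the serialized plan as raw bytes.
    /// std::fs::write("engine.plan", serialized.as_bytes()).unwrap();
    /// # }
    /// ```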
    #[inline(always)]
    pub fn serialize(&self) -> Result<HostBuffer> {
        self.inner.serialize()
    }

    /// Get the number of IO tensors.
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#af2018924cbea2fa84808040e60c58405)
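    ///
    /// # Example
    ///
    /// A sketch of enumerating all IO tensors and printing their metadata, assuming `engine` is
    /// an already-built [`Engine`] (see also [`Engine::io_tensor_name`], [`Engine::tensor_shape`]
    /// and [`Engine::tensor_io_mode`] below).
    ///
    /// ```no_run
    /// # fn inspect(engine: &async_tensorrt::Engine) {
    /// for index in 0..engine.num_io_tensors() {
    ///     // Tensor name is looked up by index; shape and IO mode by name.
    ///     let name = engine.io_tensor_name(index);
    ///     let shape = engine.tensor_shape(&name);
    ///     let io_mode = engine.tensor_io_mode(&name);
    ///     println!("{name}: {shape:?} ({io_mode:?})");
    /// }
    /// # }
    /// ```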
    #[inline(always)]
    pub fn num_io_tensors(&self) -> usize {
        self.inner.num_io_tensors()
    }

    /// Retrieve the name of an IO tensor.
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#a0b1e9e3f82724be40f0ab74742deaf92)
    ///
    /// # Arguments
    ///
    /// * `io_tensor_index` - IO tensor index.
    #[inline(always)]
    pub fn io_tensor_name(&self, io_tensor_index: usize) -> String {
        self.inner.io_tensor_name(io_tensor_index)
    }

    /// Get the shape of a tensor.
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#af96a2ee402ab47b7e0b7f0becb63d693)
    ///
    /// # Arguments
    ///
    /// * `tensor_name` - Tensor name.
    #[inline(always)]
    pub fn tensor_shape(&self, tensor_name: &str) -> Vec<usize> {
        self.inner.tensor_shape(tensor_name)
    }

    /// Get the IO mode of a tensor.
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#ae236a14178df506070cd39a9ef3775e7)
    ///
    /// # Arguments
    ///
    /// * `tensor_name` - Tensor name.
    #[inline(always)]
    pub fn tensor_io_mode(&self, tensor_name: &str) -> TensorIoMode {
        self.inner.tensor_io_mode(tensor_name)
    }
}

/// Context for executing inference using an engine.
///
/// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_execution_context.html)
pub struct ExecutionContext<'engine> {
    inner: InnerExecutionContext<'engine>,
}

impl ExecutionContext<'static> {
    /// Create an execution context from an [`Engine`].
    ///
    /// This is the owned version of [`ExecutionContext::new()`]. It consumes the engine. In
    /// exchange, it produces an execution context with a `'static` lifetime.
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#ac7a34cf3b59aa633a35f66f07f22a617)
    ///
    /// # Arguments
    ///
    /// * `engine` - Parent engine.
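    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `engine` is an already-built [`Engine`] and that the crate
    /// root re-exports `ExecutionContext`. Because the context owns the engine, it can be moved
    /// around freely, e.g. into a spawned task.
    ///
    /// ```no_run
    /// # async fn example(engine: async_tensorrt::Engine) {
    /// use async_tensorrt::ExecutionContext;
    /// // Consumes the engine and yields a context with a `'static` lifetime.
    /// let context = ExecutionContext::from_engine(engine).await.unwrap();
    /// # }
    /// ```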
    pub async fn from_engine(engine: Engine) -> Result<Self> {
        Future::new(move || {
            InnerExecutionContext::from_engine(engine.inner).map(ExecutionContext::from_inner_owned)
        })
        .await
    }

    /// Create multiple execution contexts from an [`Engine`].
    ///
    /// This is the owned version of [`ExecutionContext::new()`]. It consumes the engine. In
    /// exchange, it produces a set of execution contexts with a `'static` lifetime.
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#ac7a34cf3b59aa633a35f66f07f22a617)
    ///
    /// # Arguments
    ///
    /// * `engine` - Parent engine.
    /// * `num` - Number of execution contexts to produce.
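    ///
    /// # Example
    ///
    /// A sketch of pairing each context with its own CUDA stream, as required for concurrent
    /// inference (see [`ExecutionContext::enqueue`]). It assumes `engine` is an already-built
    /// [`Engine`].
    ///
    /// ```no_run
    /// # async fn example(engine: async_tensorrt::Engine) {
    /// use async_cuda::Stream;
    /// use async_tensorrt::ExecutionContext;
    /// let contexts = ExecutionContext::from_engine_many(engine, 4).await.unwrap();
    /// // One stream per context: sharing a context between streams is undefined behavior.
    /// let mut streams = Vec::new();
    /// for _ in 0..contexts.len() {
    ///     streams.push(Stream::new().await.unwrap());
    /// }
    /// # }
    /// ```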
    pub async fn from_engine_many(engine: Engine, num: usize) -> Result<Vec<Self>> {
        Future::new(move || {
            Ok(InnerExecutionContext::from_engine_many(engine.inner, num)?
                .into_iter()
                .map(Self::from_inner_owned)
                .collect())
        })
        .await
    }

    /// Create [`ExecutionContext`] from its inner object.
    fn from_inner_owned(inner: InnerExecutionContext<'static>) -> Self {
        Self { inner }
    }
}

impl<'engine> ExecutionContext<'engine> {
    /// Create [`ExecutionContext`] from its inner object.
    fn from_inner(inner: InnerExecutionContext<'engine>) -> Self {
        Self { inner }
    }

    /// Create an execution context from an [`Engine`].
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#ac7a34cf3b59aa633a35f66f07f22a617)
    ///
    /// # Arguments
    ///
    /// * `engine` - Parent engine.
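    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `engine` is an already-built [`Engine`]. The context borrows
    /// the engine mutably, so the engine must outlive it.
    ///
    /// ```no_run
    /// # async fn example(engine: &mut async_tensorrt::Engine) {
    /// use async_tensorrt::ExecutionContext;
    /// let context = ExecutionContext::new(engine).await.unwrap();
    /// # }
    /// ```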
    pub async fn new(engine: &mut Engine) -> Result<ExecutionContext> {
        Future::new(move || {
            InnerExecutionContext::new(&mut engine.inner).map(ExecutionContext::from_inner)
        })
        .await
    }

    /// Asynchronously execute inference.
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_execution_context.html#a63cd95430852038ce864e17c670e0b36)
    ///
    /// # Stream ordered semantics
    ///
    /// This function exhibits stream ordered semantics. This means that it is only guaranteed to
    /// complete serially with respect to other operations on the same stream.
    ///
    /// # Thread-safety
    ///
    /// Calling this function from the same context with a different CUDA stream concurrently
    /// results in undefined behavior. To perform inference concurrently in multiple streams, use
    /// one execution context per stream.
    ///
    /// # Arguments
    ///
    /// * `io_buffers` - Input and output buffers.
    /// * `stream` - CUDA stream to execute on.
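    ///
    /// # Example
    ///
    /// A sketch of a single inference call. It assumes `context` is an [`ExecutionContext`] and
    /// that `input` and `output` are [`DeviceBuffer`]s already sized for IO tensors named `"X"`
    /// and `"Y"` (hypothetical names for illustration).
    ///
    /// ```no_run
    /// # async fn example(
    /// #     context: &mut async_tensorrt::ExecutionContext<'_>,
    /// #     input: &mut async_cuda::DeviceBuffer<f32>,
    /// #     output: &mut async_cuda::DeviceBuffer<f32>,
    /// # ) {
    /// use async_cuda::Stream;
    /// let stream = Stream::new().await.unwrap();
    /// let mut io_buffers = std::collections::HashMap::from([("X", input), ("Y", output)]);
    /// // Stream-ordered: completion is only ordered relative to other work on `stream`.
    /// context.enqueue(&mut io_buffers, &stream).await.unwrap();
    /// # }
    /// ```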
    pub async fn enqueue<T: Copy>(
        &mut self,
        io_buffers: &mut std::collections::HashMap<&str, &mut DeviceBuffer<T>>,
        stream: &Stream,
    ) -> Result<()> {
        let mut io_buffers_inner = io_buffers
            .iter_mut()
            .map(|(name, buffer)| (*name, buffer.inner_mut()))
            .collect::<std::collections::HashMap<_, _>>();
        Future::new(move || self.inner.enqueue(&mut io_buffers_inner, stream.inner())).await
    }
}

#[cfg(test)]
mod tests {
    use crate::tests::memory::*;
    use crate::tests::utils::*;

    use super::*;

    #[tokio::test]
    async fn test_engine_serialize() {
        let engine = simple_engine!();
        let serialized_engine = engine.serialize().unwrap();
        let serialized_engine_bytes = serialized_engine.as_bytes();
        assert!(!serialized_engine_bytes.is_empty());
        // The expected prefix decodes to the ASCII bytes `ftrt` followed by four zero bytes.
        assert_eq!(
            &serialized_engine_bytes[..8],
            &[102_u8, 116_u8, 114_u8, 116_u8, 0_u8, 0_u8, 0_u8, 0_u8],
        );
    }

    #[tokio::test]
    async fn test_engine_tensor_info() {
        let engine = simple_engine!();
        assert_eq!(engine.num_io_tensors(), 2);
        assert_eq!(engine.io_tensor_name(0), "X");
        assert_eq!(engine.io_tensor_name(1), "Y");
        assert_eq!(engine.tensor_io_mode("X"), TensorIoMode::Input);
        assert_eq!(engine.tensor_io_mode("Y"), TensorIoMode::Output);
        assert_eq!(engine.tensor_shape("X"), &[1, 2]);
        assert_eq!(engine.tensor_shape("Y"), &[2, 3]);
    }

    #[tokio::test]
    async fn test_execution_context_new() {
        let mut engine = simple_engine!();
        assert!(ExecutionContext::new(&mut engine).await.is_ok());
        assert!(ExecutionContext::new(&mut engine).await.is_ok());
    }

    #[tokio::test]
    async fn test_execution_context_enqueue() {
        let stream = Stream::new().await.unwrap();
        let mut engine = simple_engine!();
        let mut context = ExecutionContext::new(&mut engine).await.unwrap();
        let mut io_buffers = std::collections::HashMap::from([
            ("X", to_device!(&[2.0, 4.0], &stream)),
            ("Y", to_device!(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &stream)),
        ]);
        let mut io_buffers_ref = io_buffers
            .iter_mut()
            .map(|(name, buffer)| (*name, buffer))
            .collect();
        context.enqueue(&mut io_buffers_ref, &stream).await.unwrap();
        let output = to_host!(io_buffers["Y"], &stream);
        assert_eq!(&output, &[2.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
    }
}