async_tensorrt/engine.rs
use async_cuda::runtime::Future;
use async_cuda::{DeviceBuffer, Stream};

use crate::ffi::memory::HostBuffer;
use crate::ffi::sync::engine::Engine as InnerEngine;
use crate::ffi::sync::engine::ExecutionContext as InnerExecutionContext;

pub use crate::ffi::sync::engine::TensorIoMode;

type Result<T> = std::result::Result<T, crate::error::Error>;

/// Engine for executing inference on a built network.
///
/// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html)
pub struct Engine {
    inner: InnerEngine,
}

impl Engine {
    /// Create [`Engine`] from its inner object.
    pub fn from_inner(inner: InnerEngine) -> Self {
        Self { inner }
    }

    /// Serialize the network.
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#ab42c2fde3292f557ed17aae6f332e571)
    ///
    /// # Return value
    ///
    /// A [`HostBuffer`] that contains the serialized engine.
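    ///
    /// # Examples
    ///
    /// A minimal sketch (not compiled as a doctest); it assumes `engine` was already built or
    /// deserialized elsewhere, and the output path is hypothetical:
    ///
    /// ```ignore
    /// // Serialize the engine and write the raw plan bytes to disk.
    /// let serialized = engine.serialize()?;
    /// std::fs::write("model.plan", serialized.as_bytes())?;
    /// ```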
    #[inline(always)]
    pub fn serialize(&self) -> Result<HostBuffer> {
        self.inner.serialize()
    }

    /// Get the number of IO tensors.
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#af2018924cbea2fa84808040e60c58405)
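    ///
    /// # Examples
    ///
    /// A sketch of enumerating all IO tensors together with [`Engine::io_tensor_name`],
    /// [`Engine::tensor_io_mode`] and [`Engine::tensor_shape`] (not compiled as a doctest;
    /// assumes an existing `engine`):
    ///
    /// ```ignore
    /// for index in 0..engine.num_io_tensors() {
    ///     let name = engine.io_tensor_name(index);
    ///     let mode = engine.tensor_io_mode(&name);
    ///     let shape = engine.tensor_shape(&name);
    ///     println!("{name}: {mode:?} {shape:?}");
    /// }
    /// ```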
    #[inline(always)]
    pub fn num_io_tensors(&self) -> usize {
        self.inner.num_io_tensors()
    }

    /// Retrieve the name of an IO tensor.
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#a0b1e9e3f82724be40f0ab74742deaf92)
    ///
    /// # Arguments
    ///
    /// * `io_tensor_index` - IO tensor index.
    #[inline(always)]
    pub fn io_tensor_name(&self, io_tensor_index: usize) -> String {
        self.inner.io_tensor_name(io_tensor_index)
    }

    /// Get the shape of a tensor.
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#af96a2ee402ab47b7e0b7f0becb63d693)
    ///
    /// # Arguments
    ///
    /// * `tensor_name` - Tensor name.
    #[inline(always)]
    pub fn tensor_shape(&self, tensor_name: &str) -> Vec<usize> {
        self.inner.tensor_shape(tensor_name)
    }

    /// Get the IO mode of a tensor.
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#ae236a14178df506070cd39a9ef3775e7)
    ///
    /// # Arguments
    ///
    /// * `tensor_name` - Tensor name.
    #[inline(always)]
    pub fn tensor_io_mode(&self, tensor_name: &str) -> TensorIoMode {
        self.inner.tensor_io_mode(tensor_name)
    }
}

/// Context for executing inference using an engine.
///
/// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_execution_context.html)
pub struct ExecutionContext<'engine> {
    inner: InnerExecutionContext<'engine>,
}

impl ExecutionContext<'static> {
    /// Create an execution context from an [`Engine`].
    ///
    /// This is the owned version of [`ExecutionContext::new()`]. It consumes the engine. In
    /// exchange, it produces an execution context with a `'static` lifetime.
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#ac7a34cf3b59aa633a35f66f07f22a617)
    ///
    /// # Arguments
    ///
    /// * `engine` - Parent engine.
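    ///
    /// # Examples
    ///
    /// A sketch (not compiled as a doctest; assumes an existing `engine`):
    ///
    /// ```ignore
    /// // Consume the engine to obtain a context that is not tied to a borrow of it.
    /// let context = ExecutionContext::from_engine(engine).await?;
    /// ```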
    pub async fn from_engine(engine: Engine) -> Result<Self> {
        Future::new(move || {
            InnerExecutionContext::from_engine(engine.inner).map(ExecutionContext::from_inner_owned)
        })
        .await
    }

    /// Create multiple execution contexts from an [`Engine`].
    ///
    /// This is the owned version of [`ExecutionContext::new()`]. It consumes the engine. In
    /// exchange, it produces a set of execution contexts with a `'static` lifetime.
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#ac7a34cf3b59aa633a35f66f07f22a617)
    ///
    /// # Arguments
    ///
    /// * `engine` - Parent engine.
    /// * `num` - Number of execution contexts to produce.
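    ///
    /// # Examples
    ///
    /// A sketch of creating one context per CUDA stream (not compiled as a doctest; assumes
    /// an existing `engine` and that two streams will be used):
    ///
    /// ```ignore
    /// let contexts = ExecutionContext::from_engine_many(engine, 2).await?;
    /// ```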
    pub async fn from_engine_many(engine: Engine, num: usize) -> Result<Vec<Self>> {
        Future::new(move || {
            Ok(InnerExecutionContext::from_engine_many(engine.inner, num)?
                .into_iter()
                .map(Self::from_inner_owned)
                .collect())
        })
        .await
    }

    /// Create [`ExecutionContext`] from its inner object.
    fn from_inner_owned(inner: InnerExecutionContext<'static>) -> Self {
        Self { inner }
    }
}

impl<'engine> ExecutionContext<'engine> {
    /// Create [`ExecutionContext`] from its inner object.
    fn from_inner(inner: InnerExecutionContext<'engine>) -> Self {
        Self { inner }
    }

    /// Create an execution context from an [`Engine`].
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#ac7a34cf3b59aa633a35f66f07f22a617)
    ///
    /// # Arguments
    ///
    /// * `engine` - Parent engine.
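    ///
    /// # Examples
    ///
    /// A sketch (not compiled as a doctest; assumes an existing mutable `engine`); the
    /// context borrows the engine for its lifetime:
    ///
    /// ```ignore
    /// let context = ExecutionContext::new(&mut engine).await?;
    /// ```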
    pub async fn new(engine: &mut Engine) -> Result<ExecutionContext> {
        Future::new(move || {
            InnerExecutionContext::new(&mut engine.inner).map(ExecutionContext::from_inner)
        })
        .await
    }

    /// Asynchronously execute inference.
    ///
    /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_execution_context.html#a63cd95430852038ce864e17c670e0b36)
    ///
    /// # Stream ordered semantics
    ///
    /// This function exhibits stream ordered semantics. This means that it is only guaranteed to
    /// complete serially with respect to other operations on the same stream.
    ///
    /// # Thread-safety
    ///
    /// Calling this function from the same context with a different CUDA stream concurrently
    /// results in undefined behavior. To perform inference concurrently in multiple streams, use
    /// one execution context per stream.
    ///
    /// # Arguments
    ///
    /// * `io_buffers` - Input and output buffers.
    /// * `stream` - CUDA stream to execute on.
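    ///
    /// # Examples
    ///
    /// A sketch (not compiled as a doctest); `input` and `output` are assumed to be
    /// pre-allocated [`DeviceBuffer`]s, and `"X"`/`"Y"` are placeholder tensor names:
    ///
    /// ```ignore
    /// let mut io_buffers = std::collections::HashMap::from([
    ///     ("X", &mut input),
    ///     ("Y", &mut output),
    /// ]);
    /// context.enqueue(&mut io_buffers, &stream).await?;
    /// ```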
    pub async fn enqueue<T: Copy>(
        &mut self,
        io_buffers: &mut std::collections::HashMap<&str, &mut DeviceBuffer<T>>,
        stream: &Stream,
    ) -> Result<()> {
        let mut io_buffers_inner = io_buffers
            .iter_mut()
            .map(|(name, buffer)| (*name, buffer.inner_mut()))
            .collect::<std::collections::HashMap<_, _>>();
        Future::new(move || self.inner.enqueue(&mut io_buffers_inner, stream.inner())).await
    }
}

#[cfg(test)]
mod tests {
    use crate::tests::memory::*;
    use crate::tests::utils::*;

    use super::*;

    #[tokio::test]
    async fn test_engine_serialize() {
        let engine = simple_engine!();
        let serialized_engine = engine.serialize().unwrap();
        let serialized_engine_bytes = serialized_engine.as_bytes();
        assert!(!serialized_engine_bytes.is_empty());
        assert_eq!(
            &serialized_engine_bytes[..8],
            &[102_u8, 116_u8, 114_u8, 116_u8, 0_u8, 0_u8, 0_u8, 0_u8],
        );
    }

    #[tokio::test]
    async fn test_engine_tensor_info() {
        let engine = simple_engine!();
        assert_eq!(engine.num_io_tensors(), 2);
        assert_eq!(engine.io_tensor_name(0), "X");
        assert_eq!(engine.io_tensor_name(1), "Y");
        assert_eq!(engine.tensor_io_mode("X"), TensorIoMode::Input);
        assert_eq!(engine.tensor_io_mode("Y"), TensorIoMode::Output);
        assert_eq!(engine.tensor_shape("X"), &[1, 2]);
        assert_eq!(engine.tensor_shape("Y"), &[2, 3]);
    }

    #[tokio::test]
    async fn test_execution_context_new() {
        let mut engine = simple_engine!();
        assert!(ExecutionContext::new(&mut engine).await.is_ok());
        assert!(ExecutionContext::new(&mut engine).await.is_ok());
    }

    #[tokio::test]
    async fn test_execution_context_enqueue() {
        let stream = Stream::new().await.unwrap();
        let mut engine = simple_engine!();
        let mut context = ExecutionContext::new(&mut engine).await.unwrap();
        let mut io_buffers = std::collections::HashMap::from([
            ("X", to_device!(&[2.0, 4.0], &stream)),
            ("Y", to_device!(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &stream)),
        ]);
        let mut io_buffers_ref = io_buffers
            .iter_mut()
            .map(|(name, buffer)| (*name, buffer))
            .collect();
        context.enqueue(&mut io_buffers_ref, &stream).await.unwrap();
        let output = to_host!(io_buffers["Y"], &stream);
        assert_eq!(&output, &[2.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
    }
}
241}