birdnet_onnx/tensorrt_config.rs
//! `TensorRT` execution provider configuration
//!
//! This module provides fine-grained control over `TensorRT` optimization settings.
//! For most users, the default [`crate::ClassifierBuilder::with_tensorrt()`] method provides
//! optimal performance. Use [`TensorRTConfig`] when you need custom settings.
//!
//! # Performance Notes
//!
//! The default configuration enables:
//! - **FP16 precision**: ~2x faster inference on GPUs with tensor cores
//! - **CUDA graphs**: Reduced CPU launch overhead for models with many small layers
//! - **Engine caching**: Reduces session creation from minutes to seconds
//! - **Timing cache**: Accelerates future builds with similar layer configurations
//!
//! # Example
//!
//! ```no_run
//! use birdnet_onnx::{Classifier, TensorRTConfig};
//!
//! let config = TensorRTConfig::new()
//!     .with_fp16(false)
//!     .with_builder_optimization_level(5)
//!     .with_engine_cache_path("/tmp/trt_cache");
//!
//! let classifier = Classifier::builder()
//!     .model_path("model.onnx")
//!     .labels_path("labels.txt")
//!     .with_tensorrt_config(config)
//!     .build()?;
//! # Ok::<(), birdnet_onnx::Error>(())
//! ```

/// Configuration for the `TensorRT` execution provider
///
/// This struct provides fine-grained control over `TensorRT` optimization settings.
/// For most users, the default [`crate::ClassifierBuilder::with_tensorrt()`] method provides
/// optimal performance.
///
/// # Performance Notes
///
/// The default configuration enables:
/// - **FP16 precision**: ~2x faster inference on GPUs with tensor cores
/// - **CUDA graphs**: Reduced CPU launch overhead for models with many small layers
/// - **Engine caching**: Reduces session creation from minutes to seconds
/// - **Timing cache**: Accelerates future builds with similar layer configurations
///
/// # Example: Custom Configuration
///
/// ```no_run
/// use birdnet_onnx::TensorRTConfig;
///
/// let config = TensorRTConfig::new()
///     .with_fp16(false)
///     .with_builder_optimization_level(5)
///     .with_engine_cache_path("/tmp/trt_cache")
///     .with_device_id(1); // Use second GPU
/// # let _ = config;
/// ```
///
/// # Example: Disable Optimizations
///
/// ```no_run
/// use birdnet_onnx::TensorRTConfig;
///
/// let config = TensorRTConfig::new()
///     .with_fp16(false)
///     .with_cuda_graph(false)
///     .with_engine_cache(false)
///     .with_timing_cache(false);
/// # let _ = config;
/// ```
#[derive(Debug, Clone)]
pub struct TensorRTConfig {
    // Performance options
    fp16: Option<bool>,
    int8: Option<bool>,
    cuda_graph: Option<bool>,
    builder_optimization_level: Option<u8>,

    // Caching options
    engine_cache: Option<bool>,
    engine_cache_path: Option<String>,
    timing_cache: Option<bool>,
    timing_cache_path: Option<String>,

    // Hardware options
    device_id: Option<i32>,
    max_workspace_size: Option<usize>,

    // Advanced options
    min_subgraph_size: Option<usize>,
    layer_norm_fp32_fallback: Option<bool>,
}

impl Default for TensorRTConfig {
    fn default() -> Self {
        Self {
            // Performance defaults (same as with_tensorrt())
            fp16: Some(true),
            cuda_graph: Some(true),
            engine_cache: Some(true),
            timing_cache: Some(true),
            builder_optimization_level: Some(3),

            // Everything else None (uses TensorRT defaults)
            int8: None,
            engine_cache_path: None,
            timing_cache_path: None,
            device_id: None,
            max_workspace_size: None,
            min_subgraph_size: None,
            layer_norm_fp32_fallback: None,
        }
    }
}

impl TensorRTConfig {
    /// Create a new `TensorRT` configuration with optimized defaults
    #[must_use]
    pub const fn new() -> Self {
        Self {
            // Performance defaults
            fp16: Some(true),
            cuda_graph: Some(true),
            engine_cache: Some(true),
            timing_cache: Some(true),
            builder_optimization_level: Some(3),

            // Everything else None
            int8: None,
            engine_cache_path: None,
            timing_cache_path: None,
            device_id: None,
            max_workspace_size: None,
            min_subgraph_size: None,
            layer_norm_fp32_fallback: None,
        }
    }

    /// Enable or disable FP16 precision mode
    ///
    /// FP16 provides ~2x speedup on GPUs with tensor cores (Volta and newer).
    /// Disable it if you need full FP32 precision for accuracy-critical applications.
    ///
    /// Default: `true`
    #[must_use]
    pub const fn with_fp16(mut self, enable: bool) -> Self {
        self.fp16 = Some(enable);
        self
    }

    /// Enable or disable INT8 precision mode
    ///
    /// Requires calibration data. Provides additional speedup over FP16.
    /// See the `TensorRT` documentation for calibration requirements.
    ///
    /// Default: `false`
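    ///
    /// # Example
    ///
    /// A minimal sketch of opting in to INT8; it assumes the calibration
    /// requirements mentioned above are already met.
    ///
    /// ```no_run
    /// use birdnet_onnx::TensorRTConfig;
    ///
    /// // Keep the FP16 default and additionally request INT8 kernels.
    /// let config = TensorRTConfig::new().with_int8(true);
    /// # let _ = config;
    /// ```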
    #[must_use]
    pub const fn with_int8(mut self, enable: bool) -> Self {
        self.int8 = Some(enable);
        self
    }

    /// Enable or disable CUDA graph capture
    ///
    /// Reduces CPU launch overhead for models with many small layers by
    /// capturing the sequence of GPU operations and replaying it with a
    /// single launch. This can provide a significant speedup when per-kernel
    /// launch overhead dominates.
    ///
    /// Default: `true`
    #[must_use]
    pub const fn with_cuda_graph(mut self, enable: bool) -> Self {
        self.cuda_graph = Some(enable);
        self
    }

    /// Set builder optimization level (0-5)
    ///
    /// Higher values take longer to build but may produce faster engines:
    /// - Level 3 (default): Balanced optimization
    /// - Level 5: Maximum optimization (longer build time)
    /// - Levels 0-2: Faster builds, may sacrifice performance
    ///
    /// Default: `3`
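    ///
    /// # Example
    ///
    /// An illustrative sketch: a higher level for a deployed engine, a lower
    /// level while iterating locally. The specific levels chosen here are
    /// arbitrary picks within the documented 0-5 range.
    ///
    /// ```no_run
    /// use birdnet_onnx::TensorRTConfig;
    ///
    /// // Spend extra build time once, in exchange for a potentially faster engine.
    /// let deploy = TensorRTConfig::new().with_builder_optimization_level(5);
    ///
    /// // Rebuild quickly during development, accepting a possibly slower engine.
    /// let dev = TensorRTConfig::new().with_builder_optimization_level(1);
    /// # let _ = (deploy, dev);
    /// ```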
    #[must_use]
    pub const fn with_builder_optimization_level(mut self, level: u8) -> Self {
        self.builder_optimization_level = Some(level);
        self
    }

    /// Enable or disable engine caching
    ///
    /// Caches compiled `TensorRT` engines to disk, dramatically reducing
    /// session creation time on subsequent runs (384s → 9s in benchmarks).
    ///
    /// **Important**: Clear the cache when the model, ONNX Runtime, or
    /// `TensorRT` version changes.
    ///
    /// Default: `true`
    #[must_use]
    pub const fn with_engine_cache(mut self, enable: bool) -> Self {
        self.engine_cache = Some(enable);
        self
    }

    /// Set a custom path for the engine cache
    ///
    /// By default, `TensorRT` uses the system temp directory.
    /// Set a custom path for persistent caching across system restarts.
    ///
    /// Default: None (uses the `TensorRT` default)
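    ///
    /// # Example
    ///
    /// A sketch of pointing the cache at a persistent location; the directory
    /// shown is an arbitrary, hypothetical path that must be writable by the
    /// process.
    ///
    /// ```no_run
    /// use birdnet_onnx::TensorRTConfig;
    ///
    /// // Engines cached here survive restarts, unlike the system temp directory.
    /// let config = TensorRTConfig::new().with_engine_cache_path("/var/cache/birdnet/trt_engines");
    /// # let _ = config;
    /// ```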
    #[must_use]
    pub fn with_engine_cache_path(mut self, path: impl Into<String>) -> Self {
        self.engine_cache_path = Some(path.into());
        self
    }

    /// Enable or disable the timing cache
    ///
    /// Stores kernel timing data to accelerate future builds with similar
    /// layer configurations (34.6s → 7.7s in benchmarks).
    ///
    /// Default: `true`
    #[must_use]
    pub const fn with_timing_cache(mut self, enable: bool) -> Self {
        self.timing_cache = Some(enable);
        self
    }

    /// Set a custom path for the timing cache
    ///
    /// By default, `TensorRT` uses the system temp directory.
    ///
    /// Default: None (uses the `TensorRT` default)
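    ///
    /// # Example
    ///
    /// A sketch that keeps the timing cache next to the engine cache; both
    /// paths are arbitrary, hypothetical locations.
    ///
    /// ```no_run
    /// use birdnet_onnx::TensorRTConfig;
    ///
    /// let config = TensorRTConfig::new()
    ///     .with_engine_cache_path("/var/cache/birdnet/trt_engines")
    ///     .with_timing_cache_path("/var/cache/birdnet/trt_timing");
    /// # let _ = config;
    /// ```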
    #[must_use]
    pub fn with_timing_cache_path(mut self, path: impl Into<String>) -> Self {
        self.timing_cache_path = Some(path.into());
        self
    }

    /// Set the GPU device ID for multi-GPU systems
    ///
    /// Default: None (uses the default GPU)
    #[must_use]
    pub const fn with_device_id(mut self, device_id: i32) -> Self {
        self.device_id = Some(device_id);
        self
    }

    /// Set the maximum workspace size in bytes
    ///
    /// `TensorRT` may allocate up to this much GPU memory for optimization.
    /// Larger values may enable more optimizations but use more memory.
    ///
    /// Default: None (uses the `TensorRT` default)
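    ///
    /// # Example
    ///
    /// An illustrative sketch; the 2 GB limit below is an arbitrary value,
    /// not a recommendation for any particular GPU.
    ///
    /// ```no_run
    /// use birdnet_onnx::TensorRTConfig;
    ///
    /// // Cap TensorRT's optimization workspace at roughly 2 GB.
    /// let config = TensorRTConfig::new().with_max_workspace_size(2_000_000_000);
    /// # let _ = config;
    /// ```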
    #[must_use]
    pub const fn with_max_workspace_size(mut self, max_size: usize) -> Self {
        self.max_workspace_size = Some(max_size);
        self
    }

    /// Set the minimum subgraph size for `TensorRT` acceleration
    ///
    /// Subgraphs smaller than this (in node count) will not be accelerated by `TensorRT`.
    ///
    /// Default: None (uses the `TensorRT` default)
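    ///
    /// # Example
    ///
    /// A sketch mirroring the value used in this crate's tests; the threshold
    /// of 5 is illustrative, not a tuned recommendation.
    ///
    /// ```no_run
    /// use birdnet_onnx::TensorRTConfig;
    ///
    /// // Don't offload very small subgraphs to TensorRT.
    /// let config = TensorRTConfig::new().with_min_subgraph_size(5);
    /// # let _ = config;
    /// ```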
    #[must_use]
    pub const fn with_min_subgraph_size(mut self, min_size: usize) -> Self {
        self.min_subgraph_size = Some(min_size);
        self
    }

    /// Enable or disable FP32 fallback for layer normalization
    ///
    /// When enabled, layer norm operations use FP32 even in FP16 mode,
    /// improving accuracy at a slight performance cost.
    ///
    /// Default: None (uses the `TensorRT` default)
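    ///
    /// # Example
    ///
    /// A sketch of a common accuracy/speed compromise: keep FP16 globally but
    /// run layer norm in FP32. Whether this matters for your model is worth
    /// verifying against full-FP32 results.
    ///
    /// ```no_run
    /// use birdnet_onnx::TensorRTConfig;
    ///
    /// let config = TensorRTConfig::new()
    ///     .with_fp16(true)
    ///     .with_layer_norm_fp32_fallback(true);
    /// # let _ = config;
    /// ```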
    #[must_use]
    pub const fn with_layer_norm_fp32_fallback(mut self, enable: bool) -> Self {
        self.layer_norm_fp32_fallback = Some(enable);
        self
    }

    /// Apply configuration to a `TensorRT` execution provider
    ///
    /// This is an internal method used by `ClassifierBuilder::with_tensorrt_config()`.
    pub(crate) fn apply_to(
        self,
        provider: ort::execution_providers::TensorRTExecutionProvider,
    ) -> ort::execution_providers::TensorRTExecutionProvider {
        let mut p = provider;

        if let Some(v) = self.fp16 {
            p = p.with_fp16(v);
        }
        if let Some(v) = self.int8 {
            p = p.with_int8(v);
        }
        if let Some(v) = self.cuda_graph {
            p = p.with_cuda_graph(v);
        }
        if let Some(v) = self.builder_optimization_level {
            p = p.with_builder_optimization_level(v);
        }
        if let Some(v) = self.engine_cache {
            p = p.with_engine_cache(v);
        }
        if let Some(path) = self.engine_cache_path {
            p = p.with_engine_cache_path(path);
        }
        if let Some(v) = self.timing_cache {
            p = p.with_timing_cache(v);
        }
        if let Some(path) = self.timing_cache_path {
            p = p.with_timing_cache_path(path);
        }
        if let Some(id) = self.device_id {
            p = p.with_device_id(id);
        }
        if let Some(size) = self.max_workspace_size {
            p = p.with_max_workspace_size(size);
        }
        if let Some(size) = self.min_subgraph_size {
            p = p.with_min_subgraph_size(size);
        }
        if let Some(v) = self.layer_norm_fp32_fallback {
            p = p.with_layer_norm_fp32_fallback(v);
        }

        p
    }
}

#[cfg(test)]
mod tests {
    #![allow(clippy::disallowed_methods)]
    use super::*;

    #[test]
    fn test_tensorrt_config_default() {
        let config = TensorRTConfig::default();
        assert_eq!(config.fp16, Some(true));
        assert_eq!(config.cuda_graph, Some(true));
        assert_eq!(config.engine_cache, Some(true));
        assert_eq!(config.timing_cache, Some(true));
        assert_eq!(config.builder_optimization_level, Some(3));
        assert_eq!(config.int8, None);
        assert_eq!(config.engine_cache_path, None);
    }

    #[test]
    fn test_tensorrt_config_new() {
        let config = TensorRTConfig::new();
        assert_eq!(config.fp16, Some(true));
        assert_eq!(config.cuda_graph, Some(true));
        assert_eq!(config.engine_cache, Some(true));
        assert_eq!(config.timing_cache, Some(true));
        assert_eq!(config.builder_optimization_level, Some(3));
    }

    #[test]
    fn test_tensorrt_config_builder_pattern() {
        let config = TensorRTConfig::new()
            .with_fp16(false)
            .with_device_id(1)
            .with_max_workspace_size(1_000_000_000);

        assert_eq!(config.fp16, Some(false));
        assert_eq!(config.device_id, Some(1));
        assert_eq!(config.max_workspace_size, Some(1_000_000_000));
    }

    #[test]
    fn test_tensorrt_config_disable_all_optimizations() {
        let config = TensorRTConfig::new()
            .with_fp16(false)
            .with_cuda_graph(false)
            .with_engine_cache(false)
            .with_timing_cache(false);

        assert_eq!(config.fp16, Some(false));
        assert_eq!(config.cuda_graph, Some(false));
        assert_eq!(config.engine_cache, Some(false));
        assert_eq!(config.timing_cache, Some(false));
    }

    #[test]
    fn test_tensorrt_config_cache_paths() {
        let config = TensorRTConfig::new()
            .with_engine_cache_path("/tmp/engines")
            .with_timing_cache_path("/tmp/timing");

        assert_eq!(config.engine_cache_path, Some("/tmp/engines".to_string()));
        assert_eq!(config.timing_cache_path, Some("/tmp/timing".to_string()));
    }

    #[test]
    fn test_tensorrt_config_optimization_levels() {
        let config0 = TensorRTConfig::new().with_builder_optimization_level(0);
        let config5 = TensorRTConfig::new().with_builder_optimization_level(5);

        assert_eq!(config0.builder_optimization_level, Some(0));
        assert_eq!(config5.builder_optimization_level, Some(5));
    }

    #[test]
    fn test_tensorrt_config_int8() {
        let config = TensorRTConfig::new().with_int8(true);
        assert_eq!(config.int8, Some(true));
    }

    #[test]
    fn test_tensorrt_config_layer_norm_fallback() {
        let config = TensorRTConfig::new().with_layer_norm_fp32_fallback(true);
        assert_eq!(config.layer_norm_fp32_fallback, Some(true));
    }

    #[test]
    fn test_tensorrt_config_min_subgraph_size() {
        let config = TensorRTConfig::new().with_min_subgraph_size(5);
        assert_eq!(config.min_subgraph_size, Some(5));
    }
}