// birdnet_onnx/tensorrt_config.rs

1//! `TensorRT` execution provider configuration
2//!
3//! This module provides fine-grained control over `TensorRT` optimization settings.
4//! For most users, the default [`crate::ClassifierBuilder::with_tensorrt()`] method provides
5//! optimal performance. Use [`TensorRTConfig`] when you need custom settings.
6//!
7//! # Performance Notes
8//!
9//! The default configuration enables:
10//! - **FP16 precision**: 2x faster inference on GPUs with tensor cores
11//! - **CUDA graphs**: Reduced CPU launch overhead for models with many small layers
12//! - **Engine caching**: Reduces session creation from minutes to seconds
13//! - **Timing cache**: Accelerates future builds with similar layer configurations
14//!
15//! # Example
16//!
17//! ```no_run
18//! use birdnet_onnx::{Classifier, TensorRTConfig};
19//!
20//! let config = TensorRTConfig::new()
21//!     .with_fp16(false)
22//!     .with_builder_optimization_level(5)
23//!     .with_engine_cache_path("/tmp/trt_cache");
24//!
25//! let classifier = Classifier::builder()
26//!     .model_path("model.onnx")
27//!     .labels_path("labels.txt")
28//!     .with_tensorrt_config(config)
29//!     .build()?;
30//! # Ok::<(), birdnet_onnx::Error>(())
31//! ```
32
/// Configuration for `TensorRT` execution provider
///
/// This struct provides fine-grained control over `TensorRT` optimization settings.
/// For most users, the default [`crate::ClassifierBuilder::with_tensorrt()`] method provides
/// optimal performance.
///
/// All fields are `Option`s: `Some(_)` values are explicitly forwarded to the
/// execution provider, while `None` fields are left untouched so `TensorRT`'s
/// own defaults apply (see `apply_to`).
///
/// # Performance Notes
///
/// The default configuration enables:
/// - **FP16 precision**: 2x faster inference on GPUs with tensor cores
/// - **CUDA graphs**: Reduced CPU launch overhead for models with many small layers
/// - **Engine caching**: Reduces session creation from minutes to seconds
/// - **Timing cache**: Accelerates future builds with similar layer configurations
///
/// # Example: Custom Configuration
///
/// ```no_run
/// use birdnet_onnx::TensorRTConfig;
///
/// let config = TensorRTConfig::new()
///     .with_fp16(false)
///     .with_builder_optimization_level(5)
///     .with_engine_cache_path("/tmp/trt_cache")
///     .with_device_id(1);  // Use second GPU
/// # let _ = config;
/// ```
///
/// # Example: Disable Optimizations
///
/// ```no_run
/// use birdnet_onnx::TensorRTConfig;
///
/// let config = TensorRTConfig::new()
///     .with_fp16(false)
///     .with_cuda_graph(false)
///     .with_engine_cache(false)
///     .with_timing_cache(false);
/// # let _ = config;
/// ```
#[derive(Debug, Clone)]
pub struct TensorRTConfig {
    // Performance options
    /// FP16 precision mode; ~2x speedup on tensor-core GPUs. Default: `Some(true)`.
    fp16: Option<bool>,
    /// INT8 precision mode; requires calibration data. Default: `None`.
    int8: Option<bool>,
    /// CUDA graph capture to reduce kernel launch overhead. Default: `Some(true)`.
    cuda_graph: Option<bool>,
    /// Builder optimization level, 0-5. Default: `Some(3)`.
    builder_optimization_level: Option<u8>,

    // Caching options
    /// Cache compiled engines to disk across runs. Default: `Some(true)`.
    engine_cache: Option<bool>,
    /// Custom engine-cache directory; `None` uses `TensorRT`'s temp location.
    engine_cache_path: Option<String>,
    /// Persist kernel timing data to speed up future builds. Default: `Some(true)`.
    timing_cache: Option<bool>,
    /// Custom timing-cache directory; `None` uses `TensorRT`'s temp location.
    timing_cache_path: Option<String>,

    // Hardware options
    /// GPU device index for multi-GPU systems; `None` uses the default GPU.
    device_id: Option<i32>,
    /// Upper bound (bytes) on GPU memory used during engine optimization.
    max_workspace_size: Option<usize>,

    // Advanced options
    /// Minimum subgraph node count eligible for `TensorRT` acceleration.
    min_subgraph_size: Option<usize>,
    /// Run layer-norm ops in FP32 even under FP16 for better accuracy.
    layer_norm_fp32_fallback: Option<bool>,
}
94
95impl Default for TensorRTConfig {
96    fn default() -> Self {
97        Self {
98            // Performance defaults (same as with_tensorrt())
99            fp16: Some(true),
100            cuda_graph: Some(true),
101            engine_cache: Some(true),
102            timing_cache: Some(true),
103            builder_optimization_level: Some(3),
104
105            // Everything else None (uses TensorRT defaults)
106            int8: None,
107            engine_cache_path: None,
108            timing_cache_path: None,
109            device_id: None,
110            max_workspace_size: None,
111            min_subgraph_size: None,
112            layer_norm_fp32_fallback: None,
113        }
114    }
115}
116
117impl TensorRTConfig {
118    /// Create a new `TensorRT` configuration with optimized defaults
119    #[must_use]
120    pub const fn new() -> Self {
121        Self {
122            // Performance defaults
123            fp16: Some(true),
124            cuda_graph: Some(true),
125            engine_cache: Some(true),
126            timing_cache: Some(true),
127            builder_optimization_level: Some(3),
128
129            // Everything else None
130            int8: None,
131            engine_cache_path: None,
132            timing_cache_path: None,
133            device_id: None,
134            max_workspace_size: None,
135            min_subgraph_size: None,
136            layer_norm_fp32_fallback: None,
137        }
138    }
139
140    /// Enable or disable FP16 precision mode
141    ///
142    /// FP16 provides ~2x speedup on GPUs with tensor cores (Volta and newer).
143    /// Disable if you need full FP32 precision for accuracy-critical applications.
144    ///
145    /// Default: `true`
146    #[must_use]
147    pub const fn with_fp16(mut self, enable: bool) -> Self {
148        self.fp16 = Some(enable);
149        self
150    }
151
152    /// Enable or disable INT8 precision mode
153    ///
154    /// Requires calibration data. Provides additional speedup over FP16.
155    /// See `TensorRT` documentation for calibration requirements.
156    ///
157    /// Default: `false` (not enabled by default)
158    #[must_use]
159    pub const fn with_int8(mut self, enable: bool) -> Self {
160        self.int8 = Some(enable);
161        self
162    }
163
164    /// Enable or disable CUDA graph capture
165    ///
166    /// Reduces CPU launch overhead for models with many small layers.
167    /// Provides significant speedup by batching GPU operations.
168    ///
169    /// Default: `true`
170    #[must_use]
171    pub const fn with_cuda_graph(mut self, enable: bool) -> Self {
172        self.cuda_graph = Some(enable);
173        self
174    }
175
176    /// Set builder optimization level (0-5)
177    ///
178    /// Higher values take longer to build but may produce faster engines.
179    /// - Level 3 (default): Balanced optimization
180    /// - Level 5: Maximum optimization (longer build time)
181    /// - Level 0-2: Faster builds, may sacrifice performance
182    ///
183    /// Default: `3`
184    #[must_use]
185    pub const fn with_builder_optimization_level(mut self, level: u8) -> Self {
186        self.builder_optimization_level = Some(level);
187        self
188    }
189
190    /// Enable or disable engine caching
191    ///
192    /// Caches compiled `TensorRT` engines to disk, dramatically reducing
193    /// session creation time on subsequent runs (384s → 9s in benchmarks).
194    ///
195    /// **Important**: Clear cache when model, ONNX Runtime, or `TensorRT` version changes.
196    ///
197    /// Default: `true`
198    #[must_use]
199    pub const fn with_engine_cache(mut self, enable: bool) -> Self {
200        self.engine_cache = Some(enable);
201        self
202    }
203
204    /// Set custom path for engine cache
205    ///
206    /// By default, `TensorRT` uses system temp directory.
207    /// Set a custom path for persistent caching across system restarts.
208    ///
209    /// Default: None (uses `TensorRT` default)
210    #[must_use]
211    pub fn with_engine_cache_path(mut self, path: impl Into<String>) -> Self {
212        self.engine_cache_path = Some(path.into());
213        self
214    }
215
216    /// Enable or disable timing cache
217    ///
218    /// Stores kernel timing data to accelerate future builds with similar
219    /// layer configurations (34.6s → 7.7s in benchmarks).
220    ///
221    /// Default: `true`
222    #[must_use]
223    pub const fn with_timing_cache(mut self, enable: bool) -> Self {
224        self.timing_cache = Some(enable);
225        self
226    }
227
228    /// Set custom path for timing cache
229    ///
230    /// By default, `TensorRT` uses system temp directory.
231    ///
232    /// Default: None (uses `TensorRT` default)
233    #[must_use]
234    pub fn with_timing_cache_path(mut self, path: impl Into<String>) -> Self {
235        self.timing_cache_path = Some(path.into());
236        self
237    }
238
239    /// Set GPU device ID for multi-GPU systems
240    ///
241    /// Default: None (uses default GPU)
242    #[must_use]
243    pub const fn with_device_id(mut self, device_id: i32) -> Self {
244        self.device_id = Some(device_id);
245        self
246    }
247
248    /// Set maximum workspace size in bytes
249    ///
250    /// `TensorRT` may allocate up to this much GPU memory for optimization.
251    /// Larger values may enable more optimizations but use more memory.
252    ///
253    /// Default: None (uses `TensorRT` default)
254    #[must_use]
255    pub const fn with_max_workspace_size(mut self, max_size: usize) -> Self {
256        self.max_workspace_size = Some(max_size);
257        self
258    }
259
260    /// Set minimum subgraph size for `TensorRT` acceleration
261    ///
262    /// Subgraphs smaller than this will not be accelerated by `TensorRT`.
263    ///
264    /// Default: None (uses `TensorRT` default)
265    #[must_use]
266    pub const fn with_min_subgraph_size(mut self, min_size: usize) -> Self {
267        self.min_subgraph_size = Some(min_size);
268        self
269    }
270
271    /// Enable or disable FP32 fallback for layer normalization
272    ///
273    /// When enabled, layer norm operations use FP32 even in FP16 mode,
274    /// improving accuracy at slight performance cost.
275    ///
276    /// Default: None (uses `TensorRT` default)
277    #[must_use]
278    pub const fn with_layer_norm_fp32_fallback(mut self, enable: bool) -> Self {
279        self.layer_norm_fp32_fallback = Some(enable);
280        self
281    }
282
283    /// Apply configuration to a `TensorRT` execution provider
284    ///
285    /// This is an internal method used by `ClassifierBuilder::with_tensorrt_config()`.
286    pub(crate) fn apply_to(
287        self,
288        provider: ort::execution_providers::TensorRTExecutionProvider,
289    ) -> ort::execution_providers::TensorRTExecutionProvider {
290        let mut p = provider;
291
292        if let Some(v) = self.fp16 {
293            p = p.with_fp16(v);
294        }
295        if let Some(v) = self.int8 {
296            p = p.with_int8(v);
297        }
298        if let Some(v) = self.cuda_graph {
299            p = p.with_cuda_graph(v);
300        }
301        if let Some(v) = self.builder_optimization_level {
302            p = p.with_builder_optimization_level(v);
303        }
304        if let Some(v) = self.engine_cache {
305            p = p.with_engine_cache(v);
306        }
307        if let Some(path) = self.engine_cache_path {
308            p = p.with_engine_cache_path(path);
309        }
310        if let Some(v) = self.timing_cache {
311            p = p.with_timing_cache(v);
312        }
313        if let Some(path) = self.timing_cache_path {
314            p = p.with_timing_cache_path(path);
315        }
316        if let Some(id) = self.device_id {
317            p = p.with_device_id(id);
318        }
319        if let Some(size) = self.max_workspace_size {
320            p = p.with_max_workspace_size(size);
321        }
322        if let Some(size) = self.min_subgraph_size {
323            p = p.with_min_subgraph_size(size);
324        }
325        if let Some(v) = self.layer_norm_fp32_fallback {
326            p = p.with_layer_norm_fp32_fallback(v);
327        }
328
329        p
330    }
331}
332
#[cfg(test)]
mod tests {
    #![allow(clippy::disallowed_methods)]
    use super::*;

    #[test]
    fn test_tensorrt_config_default() {
        let cfg = TensorRTConfig::default();
        // Optimizations are on out of the box...
        assert_eq!(cfg.fp16, Some(true));
        assert_eq!(cfg.cuda_graph, Some(true));
        assert_eq!(cfg.engine_cache, Some(true));
        assert_eq!(cfg.timing_cache, Some(true));
        assert_eq!(cfg.builder_optimization_level, Some(3));
        // ...while opt-in settings stay unset.
        assert_eq!(cfg.int8, None);
        assert_eq!(cfg.engine_cache_path, None);
    }

    #[test]
    fn test_tensorrt_config_new() {
        // `new()` must hand out the same optimized defaults as `default()`.
        let cfg = TensorRTConfig::new();
        assert_eq!(cfg.fp16, Some(true));
        assert_eq!(cfg.cuda_graph, Some(true));
        assert_eq!(cfg.engine_cache, Some(true));
        assert_eq!(cfg.timing_cache, Some(true));
        assert_eq!(cfg.builder_optimization_level, Some(3));
    }

    #[test]
    fn test_tensorrt_config_builder_pattern() {
        // Chained setters should each record their value.
        let cfg = TensorRTConfig::new()
            .with_fp16(false)
            .with_device_id(1)
            .with_max_workspace_size(1_000_000_000);

        assert_eq!(cfg.max_workspace_size, Some(1_000_000_000));
        assert_eq!(cfg.device_id, Some(1));
        assert_eq!(cfg.fp16, Some(false));
    }

    #[test]
    fn test_tensorrt_config_disable_all_optimizations() {
        // Every default-on optimization can be explicitly switched off.
        let cfg = TensorRTConfig::new()
            .with_fp16(false)
            .with_cuda_graph(false)
            .with_engine_cache(false)
            .with_timing_cache(false);

        assert_eq!(cfg.timing_cache, Some(false));
        assert_eq!(cfg.engine_cache, Some(false));
        assert_eq!(cfg.cuda_graph, Some(false));
        assert_eq!(cfg.fp16, Some(false));
    }

    #[test]
    fn test_tensorrt_config_cache_paths() {
        let cfg = TensorRTConfig::new()
            .with_engine_cache_path("/tmp/engines")
            .with_timing_cache_path("/tmp/timing");

        assert_eq!(cfg.engine_cache_path.as_deref(), Some("/tmp/engines"));
        assert_eq!(cfg.timing_cache_path.as_deref(), Some("/tmp/timing"));
    }

    #[test]
    fn test_tensorrt_config_optimization_levels() {
        // Both ends of the documented 0-5 range are stored verbatim.
        let lowest = TensorRTConfig::new().with_builder_optimization_level(0);
        let highest = TensorRTConfig::new().with_builder_optimization_level(5);

        assert_eq!(lowest.builder_optimization_level, Some(0));
        assert_eq!(highest.builder_optimization_level, Some(5));
    }

    #[test]
    fn test_tensorrt_config_int8() {
        let cfg = TensorRTConfig::new().with_int8(true);
        assert_eq!(cfg.int8, Some(true));
    }

    #[test]
    fn test_tensorrt_config_layer_norm_fallback() {
        let cfg = TensorRTConfig::new().with_layer_norm_fp32_fallback(true);
        assert_eq!(cfg.layer_norm_fp32_fallback, Some(true));
    }

    #[test]
    fn test_tensorrt_config_min_subgraph_size() {
        let cfg = TensorRTConfig::new().with_min_subgraph_size(5);
        assert_eq!(cfg.min_subgraph_size, Some(5));
    }
}