Skip to main content

oar_ocr_core/core/config/
onnx.rs

1//! ONNX Runtime configuration types and utilities.
2
3use serde::{Deserialize, Serialize};
4
5/// Graph optimization levels for ONNX Runtime.
6///
7/// This enum represents the different levels of graph optimization that can be applied
8/// during ONNX Runtime session creation.
9#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
10pub enum OrtGraphOptimizationLevel {
11    /// Disable all optimizations.
12    DisableAll,
13    /// Enable basic optimizations.
14    #[default]
15    Level1,
16    /// Enable extended optimizations.
17    Level2,
18    /// Enable all optimizations.
19    Level3,
20    /// Enable all optimizations (alias for Level3).
21    All,
22}
23
24/// Execution providers for ONNX Runtime.
25///
26/// This enum represents the different execution providers that can be used
27/// with ONNX Runtime for model inference.
28#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
29pub enum OrtExecutionProvider {
30    /// CPU execution provider (always available)
31    #[default]
32    CPU,
33    /// NVIDIA CUDA execution provider
34    CUDA {
35        /// CUDA device ID (default: 0)
36        device_id: Option<i32>,
37        /// Memory limit in bytes (optional)
38        gpu_mem_limit: Option<usize>,
39        /// Arena extend strategy: "NextPowerOfTwo" or "SameAsRequested"
40        arena_extend_strategy: Option<String>,
41        /// CUDNN convolution algorithm search: "Exhaustive", "Heuristic", or "Default"
42        cudnn_conv_algo_search: Option<String>,
43        /// CUDNN convolution use max workspace (default: true)
44        cudnn_conv_use_max_workspace: Option<bool>,
45    },
46    /// DirectML execution provider (Windows only)
47    DirectML {
48        /// DirectML device ID (default: 0)
49        device_id: Option<i32>,
50    },
51    /// OpenVINO execution provider
52    OpenVINO {
53        /// Device type (e.g., "CPU", "GPU", "MYRIAD")
54        device_type: Option<String>,
55        /// Number of threads (optional)
56        num_threads: Option<usize>,
57    },
58    /// TensorRT execution provider
59    TensorRT {
60        /// TensorRT device ID (default: 0)
61        device_id: Option<i32>,
62        /// Maximum workspace size in bytes
63        max_workspace_size: Option<usize>,
64        /// Minimum subgraph size for TensorRT acceleration
65        min_subgraph_size: Option<usize>,
66        /// FP16 enable flag
67        fp16_enable: Option<bool>,
68        /// Enable use of timing cache to speed up builds
69        timing_cache: Option<bool>,
70        /// Set path for storing timing cache
71        timing_cache_path: Option<String>,
72        /// Force use of timing cache regardless of GPU match
73        force_timing_cache: Option<bool>,
74        /// Enable caching of TensorRT engines
75        engine_cache: Option<bool>,
76        /// Set path to store cached TensorRT engines
77        engine_cache_path: Option<String>,
78        /// Dump ep context model
79        dump_ep_context_model: Option<bool>,
80        /// The path of an embedded engine model
81        ep_context_file_path: Option<String>,
82    },
83    /// CoreML execution provider (macOS/iOS only)
84    CoreML {
85        /// Use Apple Neural Engine only
86        ane_only: Option<bool>,
87        /// Enable subgraphs
88        subgraphs: Option<bool>,
89    },
90    /// WebGPU execution provider
91    WebGPU,
92}
93
94/// Configuration for ONNX Runtime sessions.
95///
96/// This struct contains various configuration options for ONNX Runtime sessions,
97/// including threading, memory management, and optimization settings.
98#[derive(Debug, Clone, Default, Serialize, Deserialize)]
99pub struct OrtSessionConfig {
100    /// Number of threads used to parallelize execution within nodes
101    pub intra_threads: Option<usize>,
102    /// Number of threads used to parallelize execution across nodes
103    pub inter_threads: Option<usize>,
104    /// Enable parallel execution mode
105    pub parallel_execution: Option<bool>,
106    /// Graph optimization level
107    pub optimization_level: Option<OrtGraphOptimizationLevel>,
108    /// Execution providers in order of preference
109    pub execution_providers: Option<Vec<OrtExecutionProvider>>,
110    /// Enable memory pattern optimization
111    pub enable_mem_pattern: Option<bool>,
112    /// Log severity level (0=Verbose, 1=Info, 2=Warning, 3=Error, 4=Fatal)
113    pub log_severity_level: Option<i32>,
114    /// Log verbosity level
115    pub log_verbosity_level: Option<i32>,
116    /// Session configuration entries (key-value pairs)
117    pub session_config_entries: Option<std::collections::HashMap<String, String>>,
118}
119
120impl OrtSessionConfig {
121    /// Creates a new OrtSessionConfig with default values.
122    pub fn new() -> Self {
123        Self::default()
124    }
125
126    /// Sets the number of intra-op threads.
127    ///
128    /// # Arguments
129    ///
130    /// * `threads` - Number of threads for intra-op parallelism.
131    ///
132    /// # Returns
133    ///
134    /// Self for method chaining.
135    pub fn with_intra_threads(mut self, threads: usize) -> Self {
136        self.intra_threads = Some(threads);
137        self
138    }
139
140    /// Sets the number of inter-op threads.
141    ///
142    /// # Arguments
143    ///
144    /// * `threads` - Number of threads for inter-op parallelism.
145    ///
146    /// # Returns
147    ///
148    /// Self for method chaining.
149    pub fn with_inter_threads(mut self, threads: usize) -> Self {
150        self.inter_threads = Some(threads);
151        self
152    }
153
154    /// Enables or disables parallel execution.
155    ///
156    /// # Arguments
157    ///
158    /// * `enabled` - Whether to enable parallel execution.
159    ///
160    /// # Returns
161    ///
162    /// Self for method chaining.
163    pub fn with_parallel_execution(mut self, enabled: bool) -> Self {
164        self.parallel_execution = Some(enabled);
165        self
166    }
167
168    /// Sets the graph optimization level.
169    ///
170    /// # Arguments
171    ///
172    /// * `level` - The optimization level to use.
173    ///
174    /// # Returns
175    ///
176    /// Self for method chaining.
177    pub fn with_optimization_level(mut self, level: OrtGraphOptimizationLevel) -> Self {
178        self.optimization_level = Some(level);
179        self
180    }
181
182    /// Sets the execution providers.
183    ///
184    /// # Arguments
185    ///
186    /// * `providers` - Vector of execution providers in order of preference.
187    ///
188    /// # Returns
189    ///
190    /// Self for method chaining.
191    pub fn with_execution_providers(mut self, providers: Vec<OrtExecutionProvider>) -> Self {
192        self.execution_providers = Some(providers);
193        self
194    }
195
196    /// Adds a single execution provider.
197    ///
198    /// # Arguments
199    ///
200    /// * `provider` - The execution provider to add.
201    ///
202    /// # Returns
203    ///
204    /// Self for method chaining.
205    pub fn add_execution_provider(mut self, provider: OrtExecutionProvider) -> Self {
206        if let Some(ref mut providers) = self.execution_providers {
207            providers.push(provider);
208        } else {
209            self.execution_providers = Some(vec![provider]);
210        }
211        self
212    }
213
214    /// Enables or disables memory pattern optimization.
215    ///
216    /// # Arguments
217    ///
218    /// * `enable` - Whether to enable memory pattern optimization.
219    ///
220    /// # Returns
221    ///
222    /// Self for method chaining.
223    pub fn with_memory_pattern(mut self, enable: bool) -> Self {
224        self.enable_mem_pattern = Some(enable);
225        self
226    }
227
228    /// Sets the log severity level.
229    ///
230    /// # Arguments
231    ///
232    /// * `level` - Log severity level (0=Verbose, 1=Info, 2=Warning, 3=Error, 4=Fatal).
233    ///
234    /// # Returns
235    ///
236    /// Self for method chaining.
237    pub fn with_log_severity_level(mut self, level: i32) -> Self {
238        self.log_severity_level = Some(level);
239        self
240    }
241
242    /// Sets the log verbosity level.
243    ///
244    /// # Arguments
245    ///
246    /// * `level` - Log verbosity level.
247    ///
248    /// # Returns
249    ///
250    /// Self for method chaining.
251    pub fn with_log_verbosity_level(mut self, level: i32) -> Self {
252        self.log_verbosity_level = Some(level);
253        self
254    }
255
256    /// Adds a session configuration entry.
257    ///
258    /// # Arguments
259    ///
260    /// * `key` - Configuration key.
261    /// * `value` - Configuration value.
262    ///
263    /// # Returns
264    ///
265    /// Self for method chaining.
266    pub fn add_config_entry<K: Into<String>, V: Into<String>>(mut self, key: K, value: V) -> Self {
267        if let Some(ref mut entries) = self.session_config_entries {
268            entries.insert(key.into(), value.into());
269        } else {
270            let mut entries = std::collections::HashMap::new();
271            entries.insert(key.into(), value.into());
272            self.session_config_entries = Some(entries);
273        }
274        self
275    }
276
277    /// Gets the effective number of intra-op threads.
278    ///
279    /// # Returns
280    ///
281    /// The number of intra-op threads, or a default value if not set.
282    pub fn get_intra_threads(&self) -> usize {
283        self.intra_threads.unwrap_or_else(|| {
284            std::thread::available_parallelism()
285                .map(|n| n.get())
286                .unwrap_or(1)
287        })
288    }
289
290    /// Gets the effective number of inter-op threads.
291    ///
292    /// # Returns
293    ///
294    /// The number of inter-op threads, or a default value if not set.
295    pub fn get_inter_threads(&self) -> usize {
296        self.inter_threads.unwrap_or(1)
297    }
298
299    /// Gets the effective graph optimization level.
300    ///
301    /// # Returns
302    ///
303    /// The graph optimization level, or a default value if not set.
304    pub fn get_optimization_level(&self) -> OrtGraphOptimizationLevel {
305        self.optimization_level.unwrap_or_default()
306    }
307
308    /// Gets the execution providers.
309    ///
310    /// # Returns
311    ///
312    /// A reference to the execution providers, or a default CPU provider if not set.
313    pub fn get_execution_providers(&self) -> Vec<OrtExecutionProvider> {
314        self.execution_providers
315            .clone()
316            .unwrap_or_else(|| vec![OrtExecutionProvider::CPU])
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods store exactly the values they are given.
    #[test]
    fn test_ort_session_config_builder() {
        let config = OrtSessionConfig::new()
            .with_intra_threads(4)
            .with_inter_threads(2)
            .with_parallel_execution(true)
            .with_optimization_level(OrtGraphOptimizationLevel::Level2)
            .with_memory_pattern(true)
            .with_log_severity_level(3)
            .with_log_verbosity_level(0)
            .add_execution_provider(OrtExecutionProvider::CPU);

        assert_eq!(config.intra_threads, Some(4));
        assert_eq!(config.inter_threads, Some(2));
        assert_eq!(config.parallel_execution, Some(true));
        assert!(matches!(
            config.optimization_level,
            Some(OrtGraphOptimizationLevel::Level2)
        ));
        assert_eq!(config.enable_mem_pattern, Some(true));
        assert_eq!(config.log_severity_level, Some(3));
        assert_eq!(config.log_verbosity_level, Some(0));
        assert_eq!(
            config.execution_providers,
            Some(vec![OrtExecutionProvider::CPU])
        );
    }

    /// Getters return the configured values when they are set.
    #[test]
    fn test_ort_session_config_getters() {
        let config = OrtSessionConfig::new()
            .with_intra_threads(8)
            .with_inter_threads(4)
            .with_optimization_level(OrtGraphOptimizationLevel::All);

        assert_eq!(config.get_intra_threads(), 8);
        assert_eq!(config.get_inter_threads(), 4);
        assert!(matches!(
            config.get_optimization_level(),
            OrtGraphOptimizationLevel::All
        ));
    }

    /// Getters fall back to their documented defaults when nothing is set.
    #[test]
    fn test_ort_session_config_defaults() {
        let config = OrtSessionConfig::new();

        assert!(config.get_intra_threads() >= 1);
        assert_eq!(config.get_inter_threads(), 1);
        assert!(matches!(
            config.get_optimization_level(),
            OrtGraphOptimizationLevel::Level1
        ));
        assert_eq!(
            config.get_execution_providers(),
            vec![OrtExecutionProvider::CPU]
        );
        assert!(config.session_config_entries.is_none());
    }

    /// `add_execution_provider` appends after providers set via `with_execution_providers`.
    #[test]
    fn test_add_execution_provider_appends_in_order() {
        let cuda = OrtExecutionProvider::CUDA {
            device_id: Some(0),
            gpu_mem_limit: None,
            arena_extend_strategy: None,
            cudnn_conv_algo_search: None,
            cudnn_conv_use_max_workspace: None,
        };
        let config = OrtSessionConfig::new()
            .with_execution_providers(vec![cuda.clone()])
            .add_execution_provider(OrtExecutionProvider::CPU);

        assert_eq!(
            config.get_execution_providers(),
            vec![cuda, OrtExecutionProvider::CPU]
        );
    }

    /// `add_config_entry` creates the map lazily and overwrites duplicate keys.
    #[test]
    fn test_add_config_entry_inserts_and_overwrites() {
        let config = OrtSessionConfig::new()
            .add_config_entry("session.a", "1")
            .add_config_entry("session.b", String::from("2"))
            .add_config_entry("session.a", "3");

        let entries = config
            .session_config_entries
            .expect("entries map should be created by add_config_entry");
        assert_eq!(entries.len(), 2);
        assert_eq!(entries.get("session.a").map(String::as_str), Some("3"));
        assert_eq!(entries.get("session.b").map(String::as_str), Some("2"));
    }
}