oar_ocr_core/core/config/onnx.rs
1//! ONNX Runtime configuration types and utilities.
2
3use serde::{Deserialize, Serialize};
4
5/// Graph optimization levels for ONNX Runtime.
6///
7/// This enum represents the different levels of graph optimization that can be applied
8/// during ONNX Runtime session creation.
9#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
10pub enum OrtGraphOptimizationLevel {
11 /// Disable all optimizations.
12 DisableAll,
13 /// Enable basic optimizations.
14 #[default]
15 Level1,
16 /// Enable extended optimizations.
17 Level2,
18 /// Enable all optimizations.
19 Level3,
20 /// Enable all optimizations (alias for Level3).
21 All,
22}
23
24/// Execution providers for ONNX Runtime.
25///
26/// This enum represents the different execution providers that can be used
27/// with ONNX Runtime for model inference.
28#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
29pub enum OrtExecutionProvider {
30 /// CPU execution provider (always available)
31 #[default]
32 CPU,
33 /// NVIDIA CUDA execution provider
34 CUDA {
35 /// CUDA device ID (default: 0)
36 device_id: Option<i32>,
37 /// Memory limit in bytes (optional)
38 gpu_mem_limit: Option<usize>,
39 /// Arena extend strategy: "NextPowerOfTwo" or "SameAsRequested"
40 arena_extend_strategy: Option<String>,
41 /// CUDNN convolution algorithm search: "Exhaustive", "Heuristic", or "Default"
42 cudnn_conv_algo_search: Option<String>,
43 /// CUDNN convolution use max workspace (default: true)
44 cudnn_conv_use_max_workspace: Option<bool>,
45 },
46 /// DirectML execution provider (Windows only)
47 DirectML {
48 /// DirectML device ID (default: 0)
49 device_id: Option<i32>,
50 },
51 /// OpenVINO execution provider
52 OpenVINO {
53 /// Device type (e.g., "CPU", "GPU", "MYRIAD")
54 device_type: Option<String>,
55 /// Number of threads (optional)
56 num_threads: Option<usize>,
57 },
58 /// TensorRT execution provider
59 TensorRT {
60 /// TensorRT device ID (default: 0)
61 device_id: Option<i32>,
62 /// Maximum workspace size in bytes
63 max_workspace_size: Option<usize>,
64 /// Minimum subgraph size for TensorRT acceleration
65 min_subgraph_size: Option<usize>,
66 /// FP16 enable flag
67 fp16_enable: Option<bool>,
68 /// Enable use of timing cache to speed up builds
69 timing_cache: Option<bool>,
70 /// Set path for storing timing cache
71 timing_cache_path: Option<String>,
72 /// Force use of timing cache regardless of GPU match
73 force_timing_cache: Option<bool>,
74 /// Enable caching of TensorRT engines
75 engine_cache: Option<bool>,
76 /// Set path to store cached TensorRT engines
77 engine_cache_path: Option<String>,
78 /// Dump ep context model
79 dump_ep_context_model: Option<bool>,
80 /// The path of an embedded engine model
81 ep_context_file_path: Option<String>,
82 },
83 /// CoreML execution provider (macOS/iOS only)
84 CoreML {
85 /// Use Apple Neural Engine only
86 ane_only: Option<bool>,
87 /// Enable subgraphs
88 subgraphs: Option<bool>,
89 },
90 /// WebGPU execution provider
91 WebGPU,
92}
93
/// Configuration for ONNX Runtime sessions.
///
/// Groups threading, memory, logging, graph-optimization and execution-provider
/// settings. Every field is an `Option`, where `None` means "not set". Some fields
/// have `get_*` helpers on this type that substitute a fallback value when unset.
/// Build instances with `OrtSessionConfig::new()` and the chained `with_*` / `add_*`
/// methods, or use struct-update syntax with `Default::default()`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OrtSessionConfig {
    /// Number of threads used to parallelize execution within nodes (intra-op).
    pub intra_threads: Option<usize>,
    /// Number of threads used to parallelize execution across nodes (inter-op).
    pub inter_threads: Option<usize>,
    /// Whether to enable parallel execution mode.
    pub parallel_execution: Option<bool>,
    /// Graph optimization level.
    pub optimization_level: Option<OrtGraphOptimizationLevel>,
    /// Execution providers, in order of preference.
    pub execution_providers: Option<Vec<OrtExecutionProvider>>,
    /// Whether to enable memory pattern optimization.
    pub enable_mem_pattern: Option<bool>,
    /// Log severity level (0=Verbose, 1=Info, 2=Warning, 3=Error, 4=Fatal).
    pub log_severity_level: Option<i32>,
    /// Log verbosity level.
    pub log_verbosity_level: Option<i32>,
    /// Additional session configuration entries, as key-value pairs.
    pub session_config_entries: Option<std::collections::HashMap<String, String>>,
}
119
120impl OrtSessionConfig {
121 /// Creates a new OrtSessionConfig with default values.
122 pub fn new() -> Self {
123 Self::default()
124 }
125
126 /// Sets the number of intra-op threads.
127 ///
128 /// # Arguments
129 ///
130 /// * `threads` - Number of threads for intra-op parallelism.
131 ///
132 /// # Returns
133 ///
134 /// Self for method chaining.
135 pub fn with_intra_threads(mut self, threads: usize) -> Self {
136 self.intra_threads = Some(threads);
137 self
138 }
139
140 /// Sets the number of inter-op threads.
141 ///
142 /// # Arguments
143 ///
144 /// * `threads` - Number of threads for inter-op parallelism.
145 ///
146 /// # Returns
147 ///
148 /// Self for method chaining.
149 pub fn with_inter_threads(mut self, threads: usize) -> Self {
150 self.inter_threads = Some(threads);
151 self
152 }
153
154 /// Enables or disables parallel execution.
155 ///
156 /// # Arguments
157 ///
158 /// * `enabled` - Whether to enable parallel execution.
159 ///
160 /// # Returns
161 ///
162 /// Self for method chaining.
163 pub fn with_parallel_execution(mut self, enabled: bool) -> Self {
164 self.parallel_execution = Some(enabled);
165 self
166 }
167
168 /// Sets the graph optimization level.
169 ///
170 /// # Arguments
171 ///
172 /// * `level` - The optimization level to use.
173 ///
174 /// # Returns
175 ///
176 /// Self for method chaining.
177 pub fn with_optimization_level(mut self, level: OrtGraphOptimizationLevel) -> Self {
178 self.optimization_level = Some(level);
179 self
180 }
181
182 /// Sets the execution providers.
183 ///
184 /// # Arguments
185 ///
186 /// * `providers` - Vector of execution providers in order of preference.
187 ///
188 /// # Returns
189 ///
190 /// Self for method chaining.
191 pub fn with_execution_providers(mut self, providers: Vec<OrtExecutionProvider>) -> Self {
192 self.execution_providers = Some(providers);
193 self
194 }
195
196 /// Adds a single execution provider.
197 ///
198 /// # Arguments
199 ///
200 /// * `provider` - The execution provider to add.
201 ///
202 /// # Returns
203 ///
204 /// Self for method chaining.
205 pub fn add_execution_provider(mut self, provider: OrtExecutionProvider) -> Self {
206 if let Some(ref mut providers) = self.execution_providers {
207 providers.push(provider);
208 } else {
209 self.execution_providers = Some(vec![provider]);
210 }
211 self
212 }
213
214 /// Enables or disables memory pattern optimization.
215 ///
216 /// # Arguments
217 ///
218 /// * `enable` - Whether to enable memory pattern optimization.
219 ///
220 /// # Returns
221 ///
222 /// Self for method chaining.
223 pub fn with_memory_pattern(mut self, enable: bool) -> Self {
224 self.enable_mem_pattern = Some(enable);
225 self
226 }
227
228 /// Sets the log severity level.
229 ///
230 /// # Arguments
231 ///
232 /// * `level` - Log severity level (0=Verbose, 1=Info, 2=Warning, 3=Error, 4=Fatal).
233 ///
234 /// # Returns
235 ///
236 /// Self for method chaining.
237 pub fn with_log_severity_level(mut self, level: i32) -> Self {
238 self.log_severity_level = Some(level);
239 self
240 }
241
242 /// Sets the log verbosity level.
243 ///
244 /// # Arguments
245 ///
246 /// * `level` - Log verbosity level.
247 ///
248 /// # Returns
249 ///
250 /// Self for method chaining.
251 pub fn with_log_verbosity_level(mut self, level: i32) -> Self {
252 self.log_verbosity_level = Some(level);
253 self
254 }
255
256 /// Adds a session configuration entry.
257 ///
258 /// # Arguments
259 ///
260 /// * `key` - Configuration key.
261 /// * `value` - Configuration value.
262 ///
263 /// # Returns
264 ///
265 /// Self for method chaining.
266 pub fn add_config_entry<K: Into<String>, V: Into<String>>(mut self, key: K, value: V) -> Self {
267 if let Some(ref mut entries) = self.session_config_entries {
268 entries.insert(key.into(), value.into());
269 } else {
270 let mut entries = std::collections::HashMap::new();
271 entries.insert(key.into(), value.into());
272 self.session_config_entries = Some(entries);
273 }
274 self
275 }
276
277 /// Gets the effective number of intra-op threads.
278 ///
279 /// # Returns
280 ///
281 /// The number of intra-op threads, or a default value if not set.
282 pub fn get_intra_threads(&self) -> usize {
283 self.intra_threads.unwrap_or_else(|| {
284 std::thread::available_parallelism()
285 .map(|n| n.get())
286 .unwrap_or(1)
287 })
288 }
289
290 /// Gets the effective number of inter-op threads.
291 ///
292 /// # Returns
293 ///
294 /// The number of inter-op threads, or a default value if not set.
295 pub fn get_inter_threads(&self) -> usize {
296 self.inter_threads.unwrap_or(1)
297 }
298
299 /// Gets the effective graph optimization level.
300 ///
301 /// # Returns
302 ///
303 /// The graph optimization level, or a default value if not set.
304 pub fn get_optimization_level(&self) -> OrtGraphOptimizationLevel {
305 self.optimization_level.unwrap_or_default()
306 }
307
308 /// Gets the execution providers.
309 ///
310 /// # Returns
311 ///
312 /// A reference to the execution providers, or a default CPU provider if not set.
313 pub fn get_execution_providers(&self) -> Vec<OrtExecutionProvider> {
314 self.execution_providers
315 .clone()
316 .unwrap_or_else(|| vec![OrtExecutionProvider::CPU])
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ort_session_config_builder() {
        let config = OrtSessionConfig::new()
            .with_intra_threads(4)
            .with_inter_threads(2)
            .with_optimization_level(OrtGraphOptimizationLevel::Level2)
            .with_memory_pattern(true)
            .add_execution_provider(OrtExecutionProvider::CPU);

        assert_eq!(config.intra_threads, Some(4));
        assert_eq!(config.inter_threads, Some(2));
        assert!(matches!(
            config.optimization_level,
            Some(OrtGraphOptimizationLevel::Level2)
        ));
        assert_eq!(config.enable_mem_pattern, Some(true));
        // Pin the exact provider list rather than only checking that one is present.
        assert_eq!(
            config.execution_providers,
            Some(vec![OrtExecutionProvider::CPU])
        );
    }

    #[test]
    fn test_ort_session_config_getters() {
        let config = OrtSessionConfig::new()
            .with_intra_threads(8)
            .with_inter_threads(4)
            .with_optimization_level(OrtGraphOptimizationLevel::All);

        assert_eq!(config.get_intra_threads(), 8);
        assert_eq!(config.get_inter_threads(), 4);
        assert!(matches!(
            config.get_optimization_level(),
            OrtGraphOptimizationLevel::All
        ));
    }

    #[test]
    fn test_ort_session_config_getter_fallbacks() {
        let config = OrtSessionConfig::new();

        // available_parallelism() is non-zero, and the final fallback is 1.
        assert!(config.get_intra_threads() >= 1);
        assert_eq!(config.get_inter_threads(), 1);
        assert!(matches!(
            config.get_optimization_level(),
            OrtGraphOptimizationLevel::Level1
        ));
        assert_eq!(
            config.get_execution_providers(),
            vec![OrtExecutionProvider::CPU]
        );
    }

    #[test]
    fn test_add_execution_provider_appends_in_order() {
        let config = OrtSessionConfig::new()
            .with_execution_providers(vec![OrtExecutionProvider::WebGPU])
            .add_execution_provider(OrtExecutionProvider::CPU);

        assert_eq!(
            config.get_execution_providers(),
            vec![OrtExecutionProvider::WebGPU, OrtExecutionProvider::CPU]
        );
    }

    #[test]
    fn test_add_config_entry_overwrites_duplicate_keys() {
        let config = OrtSessionConfig::new()
            .add_config_entry("key_a", "1")
            .add_config_entry(String::from("key_b"), "2")
            .add_config_entry("key_a", "3");

        let entries = config
            .session_config_entries
            .expect("add_config_entry should initialise the entry map");
        assert_eq!(entries.len(), 2);
        assert_eq!(entries.get("key_a").map(String::as_str), Some("3"));
        assert_eq!(entries.get("key_b").map(String::as_str), Some("2"));
    }

    #[test]
    fn test_ort_execution_provider_default_is_cpu() {
        assert_eq!(OrtExecutionProvider::default(), OrtExecutionProvider::CPU);
    }
}