// oar_ocr_core/core/config/onnx.rs
//! ONNX Runtime configuration types and utilities.

use serde::{Deserialize, Serialize};

5/// Graph optimization levels for ONNX Runtime.
6///
7/// This enum represents the different levels of graph optimization that can be applied
8/// during ONNX Runtime session creation.
9#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
10pub enum OrtGraphOptimizationLevel {
11 /// Disable all optimizations.
12 DisableAll,
13 /// Enable basic optimizations.
14 #[default]
15 Level1,
16 /// Enable extended optimizations.
17 Level2,
18 /// Enable all optimizations.
19 Level3,
20 /// Enable all optimizations (alias for Level3).
21 All,
22}
23
24/// Execution providers for ONNX Runtime.
25///
26/// This enum represents the different execution providers that can be used
27/// with ONNX Runtime for model inference.
28#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
29pub enum OrtExecutionProvider {
30 /// CPU execution provider (always available)
31 #[default]
32 CPU,
33 /// NVIDIA CUDA execution provider
34 CUDA {
35 /// CUDA device ID (default: 0)
36 device_id: Option<i32>,
37 /// Memory limit in bytes (optional)
38 gpu_mem_limit: Option<usize>,
39 /// Arena extend strategy: "NextPowerOfTwo" or "SameAsRequested"
40 arena_extend_strategy: Option<String>,
41 /// CUDNN convolution algorithm search: "Exhaustive", "Heuristic", or "Default"
42 cudnn_conv_algo_search: Option<String>,
43 /// CUDNN convolution use max workspace (default: true)
44 cudnn_conv_use_max_workspace: Option<bool>,
45 },
46 /// DirectML execution provider (Windows only)
47 DirectML {
48 /// DirectML device ID (default: 0)
49 device_id: Option<i32>,
50 },
51 /// OpenVINO execution provider
52 OpenVINO {
53 /// Device type (e.g., "CPU", "GPU", "MYRIAD")
54 device_type: Option<String>,
55 /// Number of threads (optional)
56 num_threads: Option<usize>,
57 },
58 /// TensorRT execution provider
59 TensorRT {
60 /// TensorRT device ID (default: 0)
61 device_id: Option<i32>,
62 /// Maximum workspace size in bytes
63 max_workspace_size: Option<usize>,
64 /// Minimum subgraph size for TensorRT acceleration
65 min_subgraph_size: Option<usize>,
66 /// FP16 enable flag
67 fp16_enable: Option<bool>,
68 },
69 /// CoreML execution provider (macOS/iOS only)
70 CoreML {
71 /// Use Apple Neural Engine only
72 ane_only: Option<bool>,
73 /// Enable subgraphs
74 subgraphs: Option<bool>,
75 },
76 /// WebGPU execution provider
77 WebGPU,
78}
79
/// Configuration for ONNX Runtime sessions.
///
/// This struct contains various configuration options for ONNX Runtime sessions,
/// including threading, memory management, and optimization settings.
///
/// Every field is an `Option`: `None` (the `Default`) means "leave this setting
/// at the runtime's own default". Populate fields via the builder methods on
/// the `impl` block below; effective values can be read with the `get_*`
/// accessors, which substitute sensible fallbacks for unset fields.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OrtSessionConfig {
    /// Number of threads used to parallelize execution within nodes
    pub intra_threads: Option<usize>,
    /// Number of threads used to parallelize execution across nodes
    pub inter_threads: Option<usize>,
    /// Enable parallel execution mode
    pub parallel_execution: Option<bool>,
    /// Graph optimization level
    pub optimization_level: Option<OrtGraphOptimizationLevel>,
    /// Execution providers in order of preference (first match wins)
    pub execution_providers: Option<Vec<OrtExecutionProvider>>,
    /// Enable memory pattern optimization
    pub enable_mem_pattern: Option<bool>,
    /// Log severity level (0=Verbose, 1=Info, 2=Warning, 3=Error, 4=Fatal)
    pub log_severity_level: Option<i32>,
    /// Log verbosity level
    pub log_verbosity_level: Option<i32>,
    /// Session configuration entries (key-value pairs) passed through to the
    /// runtime as free-form string options
    pub session_config_entries: Option<std::collections::HashMap<String, String>>,
}
105
106impl OrtSessionConfig {
107 /// Creates a new OrtSessionConfig with default values.
108 pub fn new() -> Self {
109 Self::default()
110 }
111
112 /// Sets the number of intra-op threads.
113 ///
114 /// # Arguments
115 ///
116 /// * `threads` - Number of threads for intra-op parallelism.
117 ///
118 /// # Returns
119 ///
120 /// Self for method chaining.
121 pub fn with_intra_threads(mut self, threads: usize) -> Self {
122 self.intra_threads = Some(threads);
123 self
124 }
125
126 /// Sets the number of inter-op threads.
127 ///
128 /// # Arguments
129 ///
130 /// * `threads` - Number of threads for inter-op parallelism.
131 ///
132 /// # Returns
133 ///
134 /// Self for method chaining.
135 pub fn with_inter_threads(mut self, threads: usize) -> Self {
136 self.inter_threads = Some(threads);
137 self
138 }
139
140 /// Enables or disables parallel execution.
141 ///
142 /// # Arguments
143 ///
144 /// * `enabled` - Whether to enable parallel execution.
145 ///
146 /// # Returns
147 ///
148 /// Self for method chaining.
149 pub fn with_parallel_execution(mut self, enabled: bool) -> Self {
150 self.parallel_execution = Some(enabled);
151 self
152 }
153
154 /// Sets the graph optimization level.
155 ///
156 /// # Arguments
157 ///
158 /// * `level` - The optimization level to use.
159 ///
160 /// # Returns
161 ///
162 /// Self for method chaining.
163 pub fn with_optimization_level(mut self, level: OrtGraphOptimizationLevel) -> Self {
164 self.optimization_level = Some(level);
165 self
166 }
167
168 /// Sets the execution providers.
169 ///
170 /// # Arguments
171 ///
172 /// * `providers` - Vector of execution providers in order of preference.
173 ///
174 /// # Returns
175 ///
176 /// Self for method chaining.
177 pub fn with_execution_providers(mut self, providers: Vec<OrtExecutionProvider>) -> Self {
178 self.execution_providers = Some(providers);
179 self
180 }
181
182 /// Adds a single execution provider.
183 ///
184 /// # Arguments
185 ///
186 /// * `provider` - The execution provider to add.
187 ///
188 /// # Returns
189 ///
190 /// Self for method chaining.
191 pub fn add_execution_provider(mut self, provider: OrtExecutionProvider) -> Self {
192 if let Some(ref mut providers) = self.execution_providers {
193 providers.push(provider);
194 } else {
195 self.execution_providers = Some(vec![provider]);
196 }
197 self
198 }
199
200 /// Enables or disables memory pattern optimization.
201 ///
202 /// # Arguments
203 ///
204 /// * `enable` - Whether to enable memory pattern optimization.
205 ///
206 /// # Returns
207 ///
208 /// Self for method chaining.
209 pub fn with_memory_pattern(mut self, enable: bool) -> Self {
210 self.enable_mem_pattern = Some(enable);
211 self
212 }
213
214 /// Sets the log severity level.
215 ///
216 /// # Arguments
217 ///
218 /// * `level` - Log severity level (0=Verbose, 1=Info, 2=Warning, 3=Error, 4=Fatal).
219 ///
220 /// # Returns
221 ///
222 /// Self for method chaining.
223 pub fn with_log_severity_level(mut self, level: i32) -> Self {
224 self.log_severity_level = Some(level);
225 self
226 }
227
228 /// Sets the log verbosity level.
229 ///
230 /// # Arguments
231 ///
232 /// * `level` - Log verbosity level.
233 ///
234 /// # Returns
235 ///
236 /// Self for method chaining.
237 pub fn with_log_verbosity_level(mut self, level: i32) -> Self {
238 self.log_verbosity_level = Some(level);
239 self
240 }
241
242 /// Adds a session configuration entry.
243 ///
244 /// # Arguments
245 ///
246 /// * `key` - Configuration key.
247 /// * `value` - Configuration value.
248 ///
249 /// # Returns
250 ///
251 /// Self for method chaining.
252 pub fn add_config_entry<K: Into<String>, V: Into<String>>(mut self, key: K, value: V) -> Self {
253 if let Some(ref mut entries) = self.session_config_entries {
254 entries.insert(key.into(), value.into());
255 } else {
256 let mut entries = std::collections::HashMap::new();
257 entries.insert(key.into(), value.into());
258 self.session_config_entries = Some(entries);
259 }
260 self
261 }
262
263 /// Gets the effective number of intra-op threads.
264 ///
265 /// # Returns
266 ///
267 /// The number of intra-op threads, or a default value if not set.
268 pub fn get_intra_threads(&self) -> usize {
269 self.intra_threads.unwrap_or_else(|| {
270 std::thread::available_parallelism()
271 .map(|n| n.get())
272 .unwrap_or(1)
273 })
274 }
275
276 /// Gets the effective number of inter-op threads.
277 ///
278 /// # Returns
279 ///
280 /// The number of inter-op threads, or a default value if not set.
281 pub fn get_inter_threads(&self) -> usize {
282 self.inter_threads.unwrap_or(1)
283 }
284
285 /// Gets the effective graph optimization level.
286 ///
287 /// # Returns
288 ///
289 /// The graph optimization level, or a default value if not set.
290 pub fn get_optimization_level(&self) -> OrtGraphOptimizationLevel {
291 self.optimization_level.unwrap_or_default()
292 }
293
294 /// Gets the execution providers.
295 ///
296 /// # Returns
297 ///
298 /// A reference to the execution providers, or a default CPU provider if not set.
299 pub fn get_execution_providers(&self) -> Vec<OrtExecutionProvider> {
300 self.execution_providers
301 .clone()
302 .unwrap_or_else(|| vec![OrtExecutionProvider::CPU])
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Each builder method must populate exactly the field it targets.
    #[test]
    fn test_ort_session_config_builder() {
        let cfg = OrtSessionConfig::new()
            .add_execution_provider(OrtExecutionProvider::CPU)
            .with_memory_pattern(true)
            .with_optimization_level(OrtGraphOptimizationLevel::Level2)
            .with_inter_threads(2)
            .with_intra_threads(4);

        assert_eq!(cfg.intra_threads, Some(4));
        assert_eq!(cfg.inter_threads, Some(2));
        assert!(matches!(
            cfg.optimization_level,
            Some(OrtGraphOptimizationLevel::Level2)
        ));
        assert_eq!(cfg.enable_mem_pattern, Some(true));
        assert!(cfg.execution_providers.is_some());
    }

    /// The `get_*` accessors must return the explicitly configured values.
    #[test]
    fn test_ort_session_config_getters() {
        let cfg = OrtSessionConfig::new()
            .with_optimization_level(OrtGraphOptimizationLevel::All)
            .with_intra_threads(8)
            .with_inter_threads(4);

        assert_eq!(cfg.get_intra_threads(), 8);
        assert_eq!(cfg.get_inter_threads(), 4);
        assert!(matches!(
            cfg.get_optimization_level(),
            OrtGraphOptimizationLevel::All
        ));
    }
}
343}