oar_ocr/core/config/onnx.rs
//! ONNX Runtime configuration types and utilities.
2
3use serde::{Deserialize, Serialize};
4
5/// Graph optimization levels for ONNX Runtime.
6///
7/// This enum represents the different levels of graph optimization that can be applied
8/// during ONNX Runtime session creation.
9#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
10pub enum OrtGraphOptimizationLevel {
11 /// Disable all optimizations.
12 DisableAll,
13 /// Enable basic optimizations.
14 Level1,
15 /// Enable extended optimizations.
16 Level2,
17 /// Enable all optimizations.
18 Level3,
19 /// Enable all optimizations (alias for Level3).
20 All,
21}
22
23impl Default for OrtGraphOptimizationLevel {
24 fn default() -> Self {
25 Self::Level1
26 }
27}
28
29/// Execution providers for ONNX Runtime.
30///
31/// This enum represents the different execution providers that can be used
32/// with ONNX Runtime for model inference.
33#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
34pub enum OrtExecutionProvider {
35 /// CPU execution provider (always available)
36 CPU,
37 /// NVIDIA CUDA execution provider
38 CUDA {
39 /// CUDA device ID (default: 0)
40 device_id: Option<i32>,
41 /// Memory limit in bytes (optional)
42 gpu_mem_limit: Option<usize>,
43 /// Whether to use arena allocator (default: true)
44 arena_extend_strategy: Option<String>,
45 /// CUDNN convolution algorithm search (default: "EXHAUSTIVE")
46 cudnn_conv_algo_search: Option<String>,
47 /// Whether to do copy in default stream (default: true)
48 do_copy_in_default_stream: Option<bool>,
49 /// CUDNN convolution use max workspace (default: true)
50 cudnn_conv_use_max_workspace: Option<bool>,
51 },
52 /// DirectML execution provider (Windows only)
53 DirectML {
54 /// DirectML device ID (default: 0)
55 device_id: Option<i32>,
56 },
57 /// OpenVINO execution provider
58 OpenVINO {
59 /// Device type (e.g., "CPU", "GPU", "MYRIAD")
60 device_type: Option<String>,
61 /// Number of threads (optional)
62 num_threads: Option<usize>,
63 },
64 /// TensorRT execution provider
65 TensorRT {
66 /// TensorRT device ID (default: 0)
67 device_id: Option<i32>,
68 /// Maximum workspace size in bytes
69 max_workspace_size: Option<usize>,
70 /// Maximum batch size
71 max_batch_size: Option<usize>,
72 /// Minimum subgraph size
73 min_subgraph_size: Option<usize>,
74 /// FP16 enable flag
75 fp16_enable: Option<bool>,
76 },
77 /// CoreML execution provider (macOS/iOS only)
78 CoreML {
79 /// Use Apple Neural Engine only
80 ane_only: Option<bool>,
81 /// Enable subgraphs
82 subgraphs: Option<bool>,
83 },
84 /// WebGPU execution provider
85 WebGPU,
86}
87
88impl Default for OrtExecutionProvider {
89 fn default() -> Self {
90 Self::CPU
91 }
92}
93
94/// Configuration for ONNX Runtime sessions.
95///
96/// This struct contains various configuration options for ONNX Runtime sessions,
97/// including threading, memory management, and optimization settings.
98#[derive(Debug, Clone, Default, Serialize, Deserialize)]
99pub struct OrtSessionConfig {
100 /// Number of threads used to parallelize execution within nodes
101 pub intra_threads: Option<usize>,
102 /// Number of threads used to parallelize execution across nodes
103 pub inter_threads: Option<usize>,
104 /// Enable parallel execution mode
105 pub parallel_execution: Option<bool>,
106 /// Graph optimization level
107 pub optimization_level: Option<OrtGraphOptimizationLevel>,
108 /// Execution providers in order of preference
109 pub execution_providers: Option<Vec<OrtExecutionProvider>>,
110 /// Enable memory pattern optimization
111 pub enable_mem_pattern: Option<bool>,
112 /// Enable CPU memory arena
113 pub enable_cpu_mem_arena: Option<bool>,
114 /// Memory arena extend strategy
115 pub memory_arena_extend_strategy: Option<String>,
116 /// Log severity level (0=Verbose, 1=Info, 2=Warning, 3=Error, 4=Fatal)
117 pub log_severity_level: Option<i32>,
118 /// Log verbosity level
119 pub log_verbosity_level: Option<i32>,
120 /// Session configuration entries (key-value pairs)
121 pub session_config_entries: Option<std::collections::HashMap<String, String>>,
122}
123
124impl OrtSessionConfig {
125 /// Creates a new OrtSessionConfig with default values.
126 pub fn new() -> Self {
127 Self::default()
128 }
129
130 /// Sets the number of intra-op threads.
131 ///
132 /// # Arguments
133 ///
134 /// * `threads` - Number of threads for intra-op parallelism.
135 ///
136 /// # Returns
137 ///
138 /// Self for method chaining.
139 pub fn with_intra_threads(mut self, threads: usize) -> Self {
140 self.intra_threads = Some(threads);
141 self
142 }
143
144 /// Sets the number of inter-op threads.
145 ///
146 /// # Arguments
147 ///
148 /// * `threads` - Number of threads for inter-op parallelism.
149 ///
150 /// # Returns
151 ///
152 /// Self for method chaining.
153 pub fn with_inter_threads(mut self, threads: usize) -> Self {
154 self.inter_threads = Some(threads);
155 self
156 }
157
158 /// Enables or disables parallel execution.
159 ///
160 /// # Arguments
161 ///
162 /// * `enabled` - Whether to enable parallel execution.
163 ///
164 /// # Returns
165 ///
166 /// Self for method chaining.
167 pub fn with_parallel_execution(mut self, enabled: bool) -> Self {
168 self.parallel_execution = Some(enabled);
169 self
170 }
171
172 /// Sets the graph optimization level.
173 ///
174 /// # Arguments
175 ///
176 /// * `level` - The optimization level to use.
177 ///
178 /// # Returns
179 ///
180 /// Self for method chaining.
181 pub fn with_optimization_level(mut self, level: OrtGraphOptimizationLevel) -> Self {
182 self.optimization_level = Some(level);
183 self
184 }
185
186 /// Sets the execution providers.
187 ///
188 /// # Arguments
189 ///
190 /// * `providers` - Vector of execution providers in order of preference.
191 ///
192 /// # Returns
193 ///
194 /// Self for method chaining.
195 pub fn with_execution_providers(mut self, providers: Vec<OrtExecutionProvider>) -> Self {
196 self.execution_providers = Some(providers);
197 self
198 }
199
200 /// Adds a single execution provider.
201 ///
202 /// # Arguments
203 ///
204 /// * `provider` - The execution provider to add.
205 ///
206 /// # Returns
207 ///
208 /// Self for method chaining.
209 pub fn add_execution_provider(mut self, provider: OrtExecutionProvider) -> Self {
210 if let Some(ref mut providers) = self.execution_providers {
211 providers.push(provider);
212 } else {
213 self.execution_providers = Some(vec![provider]);
214 }
215 self
216 }
217
218 /// Enables or disables memory pattern optimization.
219 ///
220 /// # Arguments
221 ///
222 /// * `enable` - Whether to enable memory pattern optimization.
223 ///
224 /// # Returns
225 ///
226 /// Self for method chaining.
227 pub fn with_memory_pattern(mut self, enable: bool) -> Self {
228 self.enable_mem_pattern = Some(enable);
229 self
230 }
231
232 /// Enables or disables CPU memory arena.
233 ///
234 /// # Arguments
235 ///
236 /// * `enable` - Whether to enable CPU memory arena.
237 ///
238 /// # Returns
239 ///
240 /// Self for method chaining.
241 pub fn with_cpu_memory_arena(mut self, enable: bool) -> Self {
242 self.enable_cpu_mem_arena = Some(enable);
243 self
244 }
245
246 /// Sets the log severity level.
247 ///
248 /// # Arguments
249 ///
250 /// * `level` - Log severity level (0=Verbose, 1=Info, 2=Warning, 3=Error, 4=Fatal).
251 ///
252 /// # Returns
253 ///
254 /// Self for method chaining.
255 pub fn with_log_severity_level(mut self, level: i32) -> Self {
256 self.log_severity_level = Some(level);
257 self
258 }
259
260 /// Sets the log verbosity level.
261 ///
262 /// # Arguments
263 ///
264 /// * `level` - Log verbosity level.
265 ///
266 /// # Returns
267 ///
268 /// Self for method chaining.
269 pub fn with_log_verbosity_level(mut self, level: i32) -> Self {
270 self.log_verbosity_level = Some(level);
271 self
272 }
273
274 /// Adds a session configuration entry.
275 ///
276 /// # Arguments
277 ///
278 /// * `key` - Configuration key.
279 /// * `value` - Configuration value.
280 ///
281 /// # Returns
282 ///
283 /// Self for method chaining.
284 pub fn add_config_entry<K: Into<String>, V: Into<String>>(mut self, key: K, value: V) -> Self {
285 if let Some(ref mut entries) = self.session_config_entries {
286 entries.insert(key.into(), value.into());
287 } else {
288 let mut entries = std::collections::HashMap::new();
289 entries.insert(key.into(), value.into());
290 self.session_config_entries = Some(entries);
291 }
292 self
293 }
294
295 /// Gets the effective number of intra-op threads.
296 ///
297 /// # Returns
298 ///
299 /// The number of intra-op threads, or a default value if not set.
300 pub fn get_intra_threads(&self) -> usize {
301 self.intra_threads.unwrap_or_else(|| {
302 std::thread::available_parallelism()
303 .map(|n| n.get())
304 .unwrap_or(1)
305 })
306 }
307
308 /// Gets the effective number of inter-op threads.
309 ///
310 /// # Returns
311 ///
312 /// The number of inter-op threads, or a default value if not set.
313 pub fn get_inter_threads(&self) -> usize {
314 self.inter_threads.unwrap_or(1)
315 }
316
317 /// Gets the effective graph optimization level.
318 ///
319 /// # Returns
320 ///
321 /// The graph optimization level, or a default value if not set.
322 pub fn get_optimization_level(&self) -> OrtGraphOptimizationLevel {
323 self.optimization_level.unwrap_or_default()
324 }
325
326 /// Gets the execution providers.
327 ///
328 /// # Returns
329 ///
330 /// A reference to the execution providers, or a default CPU provider if not set.
331 pub fn get_execution_providers(&self) -> Vec<OrtExecutionProvider> {
332 self.execution_providers
333 .clone()
334 .unwrap_or_else(|| vec![OrtExecutionProvider::CPU])
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ort_session_config_builder() {
        let config = OrtSessionConfig::new()
            .with_intra_threads(4)
            .with_inter_threads(2)
            .with_optimization_level(OrtGraphOptimizationLevel::Level2)
            .with_memory_pattern(true)
            .add_execution_provider(OrtExecutionProvider::CPU);

        assert_eq!(config.intra_threads, Some(4));
        assert_eq!(config.inter_threads, Some(2));
        assert!(matches!(
            config.optimization_level,
            Some(OrtGraphOptimizationLevel::Level2)
        ));
        assert_eq!(config.enable_mem_pattern, Some(true));
        // Pin the exact provider list rather than only checking it is set.
        assert_eq!(
            config.execution_providers,
            Some(vec![OrtExecutionProvider::CPU])
        );
    }

    #[test]
    fn test_ort_session_config_getters() {
        let config = OrtSessionConfig::new()
            .with_intra_threads(8)
            .with_inter_threads(4)
            .with_optimization_level(OrtGraphOptimizationLevel::All);

        assert_eq!(config.get_intra_threads(), 8);
        assert_eq!(config.get_inter_threads(), 4);
        assert!(matches!(
            config.get_optimization_level(),
            OrtGraphOptimizationLevel::All
        ));
    }

    #[test]
    fn test_ort_session_config_getter_fallbacks() {
        let config = OrtSessionConfig::new();

        // available_parallelism() is non-zero, and the fallback is 1.
        assert!(config.get_intra_threads() >= 1);
        assert_eq!(config.get_inter_threads(), 1);
        assert!(matches!(
            config.get_optimization_level(),
            OrtGraphOptimizationLevel::Level1
        ));
        assert_eq!(
            config.get_execution_providers(),
            vec![OrtExecutionProvider::CPU]
        );
    }

    #[test]
    fn test_add_execution_provider_appends_in_order() {
        let directml = OrtExecutionProvider::DirectML { device_id: Some(1) };
        let config = OrtSessionConfig::new()
            .add_execution_provider(directml.clone())
            .add_execution_provider(OrtExecutionProvider::CPU);

        assert_eq!(
            config.get_execution_providers(),
            vec![directml, OrtExecutionProvider::CPU]
        );
    }

    #[test]
    fn test_add_config_entry_inserts_and_overwrites() {
        let config = OrtSessionConfig::new()
            .add_config_entry("session.a", "1")
            .add_config_entry(String::from("session.b"), String::from("2"))
            .add_config_entry("session.a", "3");

        let entries = config
            .session_config_entries
            .expect("add_config_entry must create the entry map");
        assert_eq!(entries.len(), 2);
        assert_eq!(entries.get("session.a").map(String::as_str), Some("3"));
        assert_eq!(entries.get("session.b").map(String::as_str), Some("2"));
    }
}