Skip to main content

oar_ocr_core/core/config/
onnx.rs

1//! ONNX Runtime configuration types and utilities.
2
3use serde::{Deserialize, Serialize};
4
5/// Graph optimization levels for ONNX Runtime.
6///
7/// This enum represents the different levels of graph optimization that can be applied
8/// during ONNX Runtime session creation.
9#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
10pub enum OrtGraphOptimizationLevel {
11    /// Disable all optimizations.
12    DisableAll,
13    /// Enable basic optimizations.
14    #[default]
15    Level1,
16    /// Enable extended optimizations.
17    Level2,
18    /// Enable all optimizations.
19    Level3,
20    /// Enable all optimizations (alias for Level3).
21    All,
22}
23
24/// Execution providers for ONNX Runtime.
25///
26/// This enum represents the different execution providers that can be used
27/// with ONNX Runtime for model inference.
28#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
29pub enum OrtExecutionProvider {
30    /// CPU execution provider (always available)
31    #[default]
32    CPU,
33    /// NVIDIA CUDA execution provider
34    CUDA {
35        /// CUDA device ID (default: 0)
36        device_id: Option<i32>,
37        /// Memory limit in bytes (optional)
38        gpu_mem_limit: Option<usize>,
39        /// Arena extend strategy: "NextPowerOfTwo" or "SameAsRequested"
40        arena_extend_strategy: Option<String>,
41        /// CUDNN convolution algorithm search: "Exhaustive", "Heuristic", or "Default"
42        cudnn_conv_algo_search: Option<String>,
43        /// CUDNN convolution use max workspace (default: true)
44        cudnn_conv_use_max_workspace: Option<bool>,
45    },
46    /// DirectML execution provider (Windows only)
47    DirectML {
48        /// DirectML device ID (default: 0)
49        device_id: Option<i32>,
50    },
51    /// OpenVINO execution provider
52    OpenVINO {
53        /// Device type (e.g., "CPU", "GPU", "MYRIAD")
54        device_type: Option<String>,
55        /// Number of threads (optional)
56        num_threads: Option<usize>,
57    },
58    /// TensorRT execution provider
59    TensorRT {
60        /// TensorRT device ID (default: 0)
61        device_id: Option<i32>,
62        /// Maximum workspace size in bytes
63        max_workspace_size: Option<usize>,
64        /// Minimum subgraph size for TensorRT acceleration
65        min_subgraph_size: Option<usize>,
66        /// FP16 enable flag
67        fp16_enable: Option<bool>,
68    },
69    /// CoreML execution provider (macOS/iOS only)
70    CoreML {
71        /// Use Apple Neural Engine only
72        ane_only: Option<bool>,
73        /// Enable subgraphs
74        subgraphs: Option<bool>,
75    },
76    /// WebGPU execution provider
77    WebGPU,
78}
79
80/// Configuration for ONNX Runtime sessions.
81///
82/// This struct contains various configuration options for ONNX Runtime sessions,
83/// including threading, memory management, and optimization settings.
84#[derive(Debug, Clone, Default, Serialize, Deserialize)]
85pub struct OrtSessionConfig {
86    /// Number of threads used to parallelize execution within nodes
87    pub intra_threads: Option<usize>,
88    /// Number of threads used to parallelize execution across nodes
89    pub inter_threads: Option<usize>,
90    /// Enable parallel execution mode
91    pub parallel_execution: Option<bool>,
92    /// Graph optimization level
93    pub optimization_level: Option<OrtGraphOptimizationLevel>,
94    /// Execution providers in order of preference
95    pub execution_providers: Option<Vec<OrtExecutionProvider>>,
96    /// Enable memory pattern optimization
97    pub enable_mem_pattern: Option<bool>,
98    /// Log severity level (0=Verbose, 1=Info, 2=Warning, 3=Error, 4=Fatal)
99    pub log_severity_level: Option<i32>,
100    /// Log verbosity level
101    pub log_verbosity_level: Option<i32>,
102    /// Session configuration entries (key-value pairs)
103    pub session_config_entries: Option<std::collections::HashMap<String, String>>,
104}
105
106impl OrtSessionConfig {
107    /// Creates a new OrtSessionConfig with default values.
108    pub fn new() -> Self {
109        Self::default()
110    }
111
112    /// Sets the number of intra-op threads.
113    ///
114    /// # Arguments
115    ///
116    /// * `threads` - Number of threads for intra-op parallelism.
117    ///
118    /// # Returns
119    ///
120    /// Self for method chaining.
121    pub fn with_intra_threads(mut self, threads: usize) -> Self {
122        self.intra_threads = Some(threads);
123        self
124    }
125
126    /// Sets the number of inter-op threads.
127    ///
128    /// # Arguments
129    ///
130    /// * `threads` - Number of threads for inter-op parallelism.
131    ///
132    /// # Returns
133    ///
134    /// Self for method chaining.
135    pub fn with_inter_threads(mut self, threads: usize) -> Self {
136        self.inter_threads = Some(threads);
137        self
138    }
139
140    /// Enables or disables parallel execution.
141    ///
142    /// # Arguments
143    ///
144    /// * `enabled` - Whether to enable parallel execution.
145    ///
146    /// # Returns
147    ///
148    /// Self for method chaining.
149    pub fn with_parallel_execution(mut self, enabled: bool) -> Self {
150        self.parallel_execution = Some(enabled);
151        self
152    }
153
154    /// Sets the graph optimization level.
155    ///
156    /// # Arguments
157    ///
158    /// * `level` - The optimization level to use.
159    ///
160    /// # Returns
161    ///
162    /// Self for method chaining.
163    pub fn with_optimization_level(mut self, level: OrtGraphOptimizationLevel) -> Self {
164        self.optimization_level = Some(level);
165        self
166    }
167
168    /// Sets the execution providers.
169    ///
170    /// # Arguments
171    ///
172    /// * `providers` - Vector of execution providers in order of preference.
173    ///
174    /// # Returns
175    ///
176    /// Self for method chaining.
177    pub fn with_execution_providers(mut self, providers: Vec<OrtExecutionProvider>) -> Self {
178        self.execution_providers = Some(providers);
179        self
180    }
181
182    /// Adds a single execution provider.
183    ///
184    /// # Arguments
185    ///
186    /// * `provider` - The execution provider to add.
187    ///
188    /// # Returns
189    ///
190    /// Self for method chaining.
191    pub fn add_execution_provider(mut self, provider: OrtExecutionProvider) -> Self {
192        if let Some(ref mut providers) = self.execution_providers {
193            providers.push(provider);
194        } else {
195            self.execution_providers = Some(vec![provider]);
196        }
197        self
198    }
199
200    /// Enables or disables memory pattern optimization.
201    ///
202    /// # Arguments
203    ///
204    /// * `enable` - Whether to enable memory pattern optimization.
205    ///
206    /// # Returns
207    ///
208    /// Self for method chaining.
209    pub fn with_memory_pattern(mut self, enable: bool) -> Self {
210        self.enable_mem_pattern = Some(enable);
211        self
212    }
213
214    /// Sets the log severity level.
215    ///
216    /// # Arguments
217    ///
218    /// * `level` - Log severity level (0=Verbose, 1=Info, 2=Warning, 3=Error, 4=Fatal).
219    ///
220    /// # Returns
221    ///
222    /// Self for method chaining.
223    pub fn with_log_severity_level(mut self, level: i32) -> Self {
224        self.log_severity_level = Some(level);
225        self
226    }
227
228    /// Sets the log verbosity level.
229    ///
230    /// # Arguments
231    ///
232    /// * `level` - Log verbosity level.
233    ///
234    /// # Returns
235    ///
236    /// Self for method chaining.
237    pub fn with_log_verbosity_level(mut self, level: i32) -> Self {
238        self.log_verbosity_level = Some(level);
239        self
240    }
241
242    /// Adds a session configuration entry.
243    ///
244    /// # Arguments
245    ///
246    /// * `key` - Configuration key.
247    /// * `value` - Configuration value.
248    ///
249    /// # Returns
250    ///
251    /// Self for method chaining.
252    pub fn add_config_entry<K: Into<String>, V: Into<String>>(mut self, key: K, value: V) -> Self {
253        if let Some(ref mut entries) = self.session_config_entries {
254            entries.insert(key.into(), value.into());
255        } else {
256            let mut entries = std::collections::HashMap::new();
257            entries.insert(key.into(), value.into());
258            self.session_config_entries = Some(entries);
259        }
260        self
261    }
262
263    /// Gets the effective number of intra-op threads.
264    ///
265    /// # Returns
266    ///
267    /// The number of intra-op threads, or a default value if not set.
268    pub fn get_intra_threads(&self) -> usize {
269        self.intra_threads.unwrap_or_else(|| {
270            std::thread::available_parallelism()
271                .map(|n| n.get())
272                .unwrap_or(1)
273        })
274    }
275
276    /// Gets the effective number of inter-op threads.
277    ///
278    /// # Returns
279    ///
280    /// The number of inter-op threads, or a default value if not set.
281    pub fn get_inter_threads(&self) -> usize {
282        self.inter_threads.unwrap_or(1)
283    }
284
285    /// Gets the effective graph optimization level.
286    ///
287    /// # Returns
288    ///
289    /// The graph optimization level, or a default value if not set.
290    pub fn get_optimization_level(&self) -> OrtGraphOptimizationLevel {
291        self.optimization_level.unwrap_or_default()
292    }
293
294    /// Gets the execution providers.
295    ///
296    /// # Returns
297    ///
298    /// A reference to the execution providers, or a default CPU provider if not set.
299    pub fn get_execution_providers(&self) -> Vec<OrtExecutionProvider> {
300        self.execution_providers
301            .clone()
302            .unwrap_or_else(|| vec![OrtExecutionProvider::CPU])
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods store exactly the values they are given.
    #[test]
    fn test_ort_session_config_builder() {
        let config = OrtSessionConfig::new()
            .with_intra_threads(4)
            .with_inter_threads(2)
            .with_optimization_level(OrtGraphOptimizationLevel::Level2)
            .with_memory_pattern(true)
            .add_execution_provider(OrtExecutionProvider::CPU);

        assert_eq!(config.intra_threads, Some(4));
        assert_eq!(config.inter_threads, Some(2));
        assert!(matches!(
            config.optimization_level,
            Some(OrtGraphOptimizationLevel::Level2)
        ));
        assert_eq!(config.enable_mem_pattern, Some(true));
        // Check the stored contents, not merely that a list exists.
        assert_eq!(
            config.execution_providers,
            Some(vec![OrtExecutionProvider::CPU])
        );
    }

    /// Getters return explicitly configured values unchanged.
    #[test]
    fn test_ort_session_config_getters() {
        let config = OrtSessionConfig::new()
            .with_intra_threads(8)
            .with_inter_threads(4)
            .with_optimization_level(OrtGraphOptimizationLevel::All);

        assert_eq!(config.get_intra_threads(), 8);
        assert_eq!(config.get_inter_threads(), 4);
        assert!(matches!(
            config.get_optimization_level(),
            OrtGraphOptimizationLevel::All
        ));
    }

    /// Getters fall back to their documented defaults when nothing is set.
    #[test]
    fn test_ort_session_config_getter_defaults() {
        let config = OrtSessionConfig::new();

        // The intra-op fallback depends on the host, but is never zero.
        assert!(config.get_intra_threads() >= 1);
        assert_eq!(config.get_inter_threads(), 1);
        assert!(matches!(
            config.get_optimization_level(),
            OrtGraphOptimizationLevel::Level1
        ));
        assert_eq!(
            config.get_execution_providers(),
            vec![OrtExecutionProvider::CPU]
        );
    }

    /// Providers added one at a time accumulate in insertion order.
    #[test]
    fn test_add_execution_provider_preserves_order() {
        let config = OrtSessionConfig::new()
            .add_execution_provider(OrtExecutionProvider::WebGPU)
            .add_execution_provider(OrtExecutionProvider::CPU);

        assert_eq!(
            config.get_execution_providers(),
            vec![OrtExecutionProvider::WebGPU, OrtExecutionProvider::CPU]
        );
    }

    /// Config entries accumulate, and a repeated key overwrites its value.
    #[test]
    fn test_add_config_entry_inserts_and_overwrites() {
        let config = OrtSessionConfig::new()
            .add_config_entry("a", "1")
            .add_config_entry("b", "2")
            .add_config_entry("a", "3");

        let entries = config
            .session_config_entries
            .expect("entries map should be created by add_config_entry");
        assert_eq!(entries.len(), 2);
        assert_eq!(entries.get("a").map(String::as_str), Some("3"));
        assert_eq!(entries.get("b").map(String::as_str), Some("2"));
    }
}