// oar_ocr/core/config/onnx.rs

//! ONNX Runtime configuration types and utilities.

use serde::{Deserialize, Serialize};
5/// Graph optimization levels for ONNX Runtime.
6///
7/// This enum represents the different levels of graph optimization that can be applied
8/// during ONNX Runtime session creation.
9#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
10pub enum OrtGraphOptimizationLevel {
11    /// Disable all optimizations.
12    DisableAll,
13    /// Enable basic optimizations.
14    Level1,
15    /// Enable extended optimizations.
16    Level2,
17    /// Enable all optimizations.
18    Level3,
19    /// Enable all optimizations (alias for Level3).
20    All,
21}
22
23impl Default for OrtGraphOptimizationLevel {
24    fn default() -> Self {
25        Self::Level1
26    }
27}
29/// Execution providers for ONNX Runtime.
30///
31/// This enum represents the different execution providers that can be used
32/// with ONNX Runtime for model inference.
33#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
34pub enum OrtExecutionProvider {
35    /// CPU execution provider (always available)
36    CPU,
37    /// NVIDIA CUDA execution provider
38    CUDA {
39        /// CUDA device ID (default: 0)
40        device_id: Option<i32>,
41        /// Memory limit in bytes (optional)
42        gpu_mem_limit: Option<usize>,
43        /// Whether to use arena allocator (default: true)
44        arena_extend_strategy: Option<String>,
45        /// CUDNN convolution algorithm search (default: "EXHAUSTIVE")
46        cudnn_conv_algo_search: Option<String>,
47        /// Whether to do copy in default stream (default: true)
48        do_copy_in_default_stream: Option<bool>,
49        /// CUDNN convolution use max workspace (default: true)
50        cudnn_conv_use_max_workspace: Option<bool>,
51    },
52    /// DirectML execution provider (Windows only)
53    DirectML {
54        /// DirectML device ID (default: 0)
55        device_id: Option<i32>,
56    },
57    /// OpenVINO execution provider
58    OpenVINO {
59        /// Device type (e.g., "CPU", "GPU", "MYRIAD")
60        device_type: Option<String>,
61        /// Number of threads (optional)
62        num_threads: Option<usize>,
63    },
64    /// TensorRT execution provider
65    TensorRT {
66        /// TensorRT device ID (default: 0)
67        device_id: Option<i32>,
68        /// Maximum workspace size in bytes
69        max_workspace_size: Option<usize>,
70        /// Maximum batch size
71        max_batch_size: Option<usize>,
72        /// Minimum subgraph size
73        min_subgraph_size: Option<usize>,
74        /// FP16 enable flag
75        fp16_enable: Option<bool>,
76    },
77    /// CoreML execution provider (macOS/iOS only)
78    CoreML {
79        /// Use Apple Neural Engine only
80        ane_only: Option<bool>,
81        /// Enable subgraphs
82        subgraphs: Option<bool>,
83    },
84    /// WebGPU execution provider
85    WebGPU,
86}
87
88impl Default for OrtExecutionProvider {
89    fn default() -> Self {
90        Self::CPU
91    }
92}
94/// Configuration for ONNX Runtime sessions.
95///
96/// This struct contains various configuration options for ONNX Runtime sessions,
97/// including threading, memory management, and optimization settings.
98#[derive(Debug, Clone, Default, Serialize, Deserialize)]
99pub struct OrtSessionConfig {
100    /// Number of threads used to parallelize execution within nodes
101    pub intra_threads: Option<usize>,
102    /// Number of threads used to parallelize execution across nodes
103    pub inter_threads: Option<usize>,
104    /// Enable parallel execution mode
105    pub parallel_execution: Option<bool>,
106    /// Graph optimization level
107    pub optimization_level: Option<OrtGraphOptimizationLevel>,
108    /// Execution providers in order of preference
109    pub execution_providers: Option<Vec<OrtExecutionProvider>>,
110    /// Enable memory pattern optimization
111    pub enable_mem_pattern: Option<bool>,
112    /// Enable CPU memory arena
113    pub enable_cpu_mem_arena: Option<bool>,
114    /// Memory arena extend strategy
115    pub memory_arena_extend_strategy: Option<String>,
116    /// Log severity level (0=Verbose, 1=Info, 2=Warning, 3=Error, 4=Fatal)
117    pub log_severity_level: Option<i32>,
118    /// Log verbosity level
119    pub log_verbosity_level: Option<i32>,
120    /// Session configuration entries (key-value pairs)
121    pub session_config_entries: Option<std::collections::HashMap<String, String>>,
122}
124impl OrtSessionConfig {
125    /// Creates a new OrtSessionConfig with default values.
126    pub fn new() -> Self {
127        Self::default()
128    }
129
130    /// Sets the number of intra-op threads.
131    ///
132    /// # Arguments
133    ///
134    /// * `threads` - Number of threads for intra-op parallelism.
135    ///
136    /// # Returns
137    ///
138    /// Self for method chaining.
139    pub fn with_intra_threads(mut self, threads: usize) -> Self {
140        self.intra_threads = Some(threads);
141        self
142    }
143
144    /// Sets the number of inter-op threads.
145    ///
146    /// # Arguments
147    ///
148    /// * `threads` - Number of threads for inter-op parallelism.
149    ///
150    /// # Returns
151    ///
152    /// Self for method chaining.
153    pub fn with_inter_threads(mut self, threads: usize) -> Self {
154        self.inter_threads = Some(threads);
155        self
156    }
157
158    /// Enables or disables parallel execution.
159    ///
160    /// # Arguments
161    ///
162    /// * `enabled` - Whether to enable parallel execution.
163    ///
164    /// # Returns
165    ///
166    /// Self for method chaining.
167    pub fn with_parallel_execution(mut self, enabled: bool) -> Self {
168        self.parallel_execution = Some(enabled);
169        self
170    }
171
172    /// Sets the graph optimization level.
173    ///
174    /// # Arguments
175    ///
176    /// * `level` - The optimization level to use.
177    ///
178    /// # Returns
179    ///
180    /// Self for method chaining.
181    pub fn with_optimization_level(mut self, level: OrtGraphOptimizationLevel) -> Self {
182        self.optimization_level = Some(level);
183        self
184    }
185
186    /// Sets the execution providers.
187    ///
188    /// # Arguments
189    ///
190    /// * `providers` - Vector of execution providers in order of preference.
191    ///
192    /// # Returns
193    ///
194    /// Self for method chaining.
195    pub fn with_execution_providers(mut self, providers: Vec<OrtExecutionProvider>) -> Self {
196        self.execution_providers = Some(providers);
197        self
198    }
199
200    /// Adds a single execution provider.
201    ///
202    /// # Arguments
203    ///
204    /// * `provider` - The execution provider to add.
205    ///
206    /// # Returns
207    ///
208    /// Self for method chaining.
209    pub fn add_execution_provider(mut self, provider: OrtExecutionProvider) -> Self {
210        if let Some(ref mut providers) = self.execution_providers {
211            providers.push(provider);
212        } else {
213            self.execution_providers = Some(vec![provider]);
214        }
215        self
216    }
217
218    /// Enables or disables memory pattern optimization.
219    ///
220    /// # Arguments
221    ///
222    /// * `enable` - Whether to enable memory pattern optimization.
223    ///
224    /// # Returns
225    ///
226    /// Self for method chaining.
227    pub fn with_memory_pattern(mut self, enable: bool) -> Self {
228        self.enable_mem_pattern = Some(enable);
229        self
230    }
231
232    /// Enables or disables CPU memory arena.
233    ///
234    /// # Arguments
235    ///
236    /// * `enable` - Whether to enable CPU memory arena.
237    ///
238    /// # Returns
239    ///
240    /// Self for method chaining.
241    pub fn with_cpu_memory_arena(mut self, enable: bool) -> Self {
242        self.enable_cpu_mem_arena = Some(enable);
243        self
244    }
245
246    /// Sets the log severity level.
247    ///
248    /// # Arguments
249    ///
250    /// * `level` - Log severity level (0=Verbose, 1=Info, 2=Warning, 3=Error, 4=Fatal).
251    ///
252    /// # Returns
253    ///
254    /// Self for method chaining.
255    pub fn with_log_severity_level(mut self, level: i32) -> Self {
256        self.log_severity_level = Some(level);
257        self
258    }
259
260    /// Sets the log verbosity level.
261    ///
262    /// # Arguments
263    ///
264    /// * `level` - Log verbosity level.
265    ///
266    /// # Returns
267    ///
268    /// Self for method chaining.
269    pub fn with_log_verbosity_level(mut self, level: i32) -> Self {
270        self.log_verbosity_level = Some(level);
271        self
272    }
273
274    /// Adds a session configuration entry.
275    ///
276    /// # Arguments
277    ///
278    /// * `key` - Configuration key.
279    /// * `value` - Configuration value.
280    ///
281    /// # Returns
282    ///
283    /// Self for method chaining.
284    pub fn add_config_entry<K: Into<String>, V: Into<String>>(mut self, key: K, value: V) -> Self {
285        if let Some(ref mut entries) = self.session_config_entries {
286            entries.insert(key.into(), value.into());
287        } else {
288            let mut entries = std::collections::HashMap::new();
289            entries.insert(key.into(), value.into());
290            self.session_config_entries = Some(entries);
291        }
292        self
293    }
294
295    /// Gets the effective number of intra-op threads.
296    ///
297    /// # Returns
298    ///
299    /// The number of intra-op threads, or a default value if not set.
300    pub fn get_intra_threads(&self) -> usize {
301        self.intra_threads.unwrap_or_else(|| {
302            std::thread::available_parallelism()
303                .map(|n| n.get())
304                .unwrap_or(1)
305        })
306    }
307
308    /// Gets the effective number of inter-op threads.
309    ///
310    /// # Returns
311    ///
312    /// The number of inter-op threads, or a default value if not set.
313    pub fn get_inter_threads(&self) -> usize {
314        self.inter_threads.unwrap_or(1)
315    }
316
317    /// Gets the effective graph optimization level.
318    ///
319    /// # Returns
320    ///
321    /// The graph optimization level, or a default value if not set.
322    pub fn get_optimization_level(&self) -> OrtGraphOptimizationLevel {
323        self.optimization_level.unwrap_or_default()
324    }
325
326    /// Gets the execution providers.
327    ///
328    /// # Returns
329    ///
330    /// A reference to the execution providers, or a default CPU provider if not set.
331    pub fn get_execution_providers(&self) -> Vec<OrtExecutionProvider> {
332        self.execution_providers
333            .clone()
334            .unwrap_or_else(|| vec![OrtExecutionProvider::CPU])
335    }
336}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder method should record exactly the value it was handed.
    #[test]
    fn test_ort_session_config_builder() {
        let built = OrtSessionConfig::new()
            .with_intra_threads(4)
            .with_inter_threads(2)
            .with_optimization_level(OrtGraphOptimizationLevel::Level2)
            .with_memory_pattern(true)
            .add_execution_provider(OrtExecutionProvider::CPU);

        assert!(matches!(built.intra_threads, Some(4)));
        assert!(matches!(built.inter_threads, Some(2)));
        assert!(matches!(
            built.optimization_level,
            Some(OrtGraphOptimizationLevel::Level2)
        ));
        assert!(matches!(built.enable_mem_pattern, Some(true)));
        assert!(built.execution_providers.is_some());
    }

    /// The effective-value getters should surface explicitly set values unchanged.
    #[test]
    fn test_ort_session_config_getters() {
        let built = OrtSessionConfig::new()
            .with_intra_threads(8)
            .with_inter_threads(4)
            .with_optimization_level(OrtGraphOptimizationLevel::All);

        assert_eq!(built.get_intra_threads(), 8);
        assert_eq!(built.get_inter_threads(), 4);
        assert!(matches!(
            built.get_optimization_level(),
            OrtGraphOptimizationLevel::All
        ));
    }
}
375}