use serde::{Deserialize, Serialize};
/// Graph optimization level requested for a session.
///
/// [`Level1`](Self::Level1) is the [`Default`] value, and it is what
/// `OrtSessionConfig::get_optimization_level` returns when no level has been set.
///
/// The type is a plain fieldless enum, so it is cheap to copy and can be compared
/// with `==` or used as a hash-map key.
///
/// NOTE(review): the variants presumably mirror ONNX Runtime's graph optimization
/// levels — confirm the exact mapping at the point where this value is handed to
/// the runtime.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum OrtGraphOptimizationLevel {
    /// Request that graph optimizations be disabled.
    DisableAll,
    /// First optimization tier; the default.
    #[default]
    Level1,
    /// Second optimization tier.
    Level2,
    /// Third optimization tier.
    Level3,
    /// Request every available optimization.
    All,
}
/// Execution provider selection together with its per-provider options.
///
/// Every option is an `Option`; `None` means the option was not specified here.
/// `CPU` is the `Default` variant, and it is also the fallback returned by
/// `OrtSessionConfig::get_execution_providers` when no providers are configured.
///
/// NOTE(review): the field names look like ONNX Runtime provider option keys. How
/// each one is applied (units, accepted string values, defaults when `None`) is
/// decided by the code that consumes this enum, which is not visible here —
/// confirm there before relying on the per-field notes below.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub enum OrtExecutionProvider {
/// CPU execution; carries no options. This is the default variant.
#[default]
CPU,
/// CUDA execution.
CUDA {
/// Device index (presumably the GPU ordinal — confirm).
device_id: Option<i32>,
/// GPU memory limit (presumably in bytes — confirm).
gpu_mem_limit: Option<usize>,
/// Arena extension strategy, passed as a free-form string.
arena_extend_strategy: Option<String>,
/// cuDNN convolution algorithm search mode, passed as a free-form string.
cudnn_conv_algo_search: Option<String>,
/// Toggle for the cuDNN "use max workspace" option.
cudnn_conv_use_max_workspace: Option<bool>,
},
/// DirectML execution.
DirectML {
/// Device index (presumably the adapter ordinal — confirm).
device_id: Option<i32>,
},
/// OpenVINO execution.
OpenVINO {
/// Target device type, passed as a free-form string.
device_type: Option<String>,
/// Thread count for the provider.
num_threads: Option<usize>,
},
/// TensorRT execution.
TensorRT {
/// Device index (presumably the GPU ordinal — confirm).
device_id: Option<i32>,
/// Maximum workspace size (presumably in bytes — confirm).
max_workspace_size: Option<usize>,
/// Minimum subgraph size.
min_subgraph_size: Option<usize>,
/// Toggle for FP16 mode.
fp16_enable: Option<bool>,
/// Toggle for the timing cache.
timing_cache: Option<bool>,
/// Path used for the timing cache.
timing_cache_path: Option<String>,
/// Toggle for forcing use of the timing cache.
force_timing_cache: Option<bool>,
/// Toggle for the engine cache.
engine_cache: Option<bool>,
/// Path used for the engine cache.
engine_cache_path: Option<String>,
/// Toggle for dumping the EP context model.
dump_ep_context_model: Option<bool>,
/// File path used for the EP context.
ep_context_file_path: Option<String>,
},
/// CoreML execution.
CoreML {
/// Presumably restricts execution to the Apple Neural Engine — confirm.
ane_only: Option<bool>,
/// Toggle for the CoreML subgraph option.
subgraphs: Option<bool>,
},
/// WebGPU execution; carries no options.
WebGPU,
}
/// Optional settings for a session.
///
/// Every field is an `Option`; the derived `Default` leaves all of them `None`.
/// The `get_*` accessors on this type supply fallback values for the fields that
/// have one. The remaining fields have no fallback in this file.
///
/// NOTE(review): the field names suggest ONNX Runtime session options; how they
/// are applied to an actual session is not visible here.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OrtSessionConfig {
/// Intra-op thread count. When `None`, `get_intra_threads` falls back to
/// `std::thread::available_parallelism()`, or `1` if that query fails.
pub intra_threads: Option<usize>,
/// Inter-op thread count. When `None`, `get_inter_threads` returns `1`.
pub inter_threads: Option<usize>,
/// Toggle for parallel execution.
pub parallel_execution: Option<bool>,
/// Graph optimization level. When `None`, `get_optimization_level` returns the
/// enum's `Default` (`Level1`).
pub optimization_level: Option<OrtGraphOptimizationLevel>,
/// Execution providers in insertion order. When `None`,
/// `get_execution_providers` returns a single `CPU` entry.
pub execution_providers: Option<Vec<OrtExecutionProvider>>,
/// Toggle for the memory-pattern option (set through `with_memory_pattern`).
pub enable_mem_pattern: Option<bool>,
/// Log severity level (numeric meaning is defined by the consumer — confirm).
pub log_severity_level: Option<i32>,
/// Log verbosity level (numeric meaning is defined by the consumer — confirm).
pub log_verbosity_level: Option<i32>,
/// Free-form key/value session configuration entries.
pub session_config_entries: Option<std::collections::HashMap<String, String>>,
}
impl OrtSessionConfig {
pub fn new() -> Self {
Self::default()
}
pub fn with_intra_threads(mut self, threads: usize) -> Self {
self.intra_threads = Some(threads);
self
}
pub fn with_inter_threads(mut self, threads: usize) -> Self {
self.inter_threads = Some(threads);
self
}
pub fn with_parallel_execution(mut self, enabled: bool) -> Self {
self.parallel_execution = Some(enabled);
self
}
pub fn with_optimization_level(mut self, level: OrtGraphOptimizationLevel) -> Self {
self.optimization_level = Some(level);
self
}
pub fn with_execution_providers(mut self, providers: Vec<OrtExecutionProvider>) -> Self {
self.execution_providers = Some(providers);
self
}
pub fn add_execution_provider(mut self, provider: OrtExecutionProvider) -> Self {
if let Some(ref mut providers) = self.execution_providers {
providers.push(provider);
} else {
self.execution_providers = Some(vec![provider]);
}
self
}
pub fn with_memory_pattern(mut self, enable: bool) -> Self {
self.enable_mem_pattern = Some(enable);
self
}
pub fn with_log_severity_level(mut self, level: i32) -> Self {
self.log_severity_level = Some(level);
self
}
pub fn with_log_verbosity_level(mut self, level: i32) -> Self {
self.log_verbosity_level = Some(level);
self
}
pub fn add_config_entry<K: Into<String>, V: Into<String>>(mut self, key: K, value: V) -> Self {
if let Some(ref mut entries) = self.session_config_entries {
entries.insert(key.into(), value.into());
} else {
let mut entries = std::collections::HashMap::new();
entries.insert(key.into(), value.into());
self.session_config_entries = Some(entries);
}
self
}
pub fn get_intra_threads(&self) -> usize {
self.intra_threads.unwrap_or_else(|| {
std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1)
})
}
pub fn get_inter_threads(&self) -> usize {
self.inter_threads.unwrap_or(1)
}
pub fn get_optimization_level(&self) -> OrtGraphOptimizationLevel {
self.optimization_level.unwrap_or_default()
}
pub fn get_execution_providers(&self) -> Vec<OrtExecutionProvider> {
self.execution_providers
.clone()
.unwrap_or_else(|| vec![OrtExecutionProvider::CPU])
}
}
/// Unit tests for the `OrtSessionConfig` builder methods and getters.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods store exactly the values they are given.
    #[test]
    fn test_ort_session_config_builder() {
        let config = OrtSessionConfig::new()
            .with_intra_threads(4)
            .with_inter_threads(2)
            .with_optimization_level(OrtGraphOptimizationLevel::Level2)
            .with_memory_pattern(true)
            .add_execution_provider(OrtExecutionProvider::CPU);
        assert_eq!(config.intra_threads, Some(4));
        assert_eq!(config.inter_threads, Some(2));
        // `matches!` keeps this check independent of whether the enum implements `PartialEq`.
        assert!(matches!(
            config.optimization_level,
            Some(OrtGraphOptimizationLevel::Level2)
        ));
        assert_eq!(config.enable_mem_pattern, Some(true));
        // Check the list contents, not merely that a list exists.
        assert_eq!(
            config.execution_providers,
            Some(vec![OrtExecutionProvider::CPU])
        );
    }

    /// Getters return explicitly configured values unchanged.
    #[test]
    fn test_ort_session_config_getters() {
        let config = OrtSessionConfig::new()
            .with_intra_threads(8)
            .with_inter_threads(4)
            .with_optimization_level(OrtGraphOptimizationLevel::All);
        assert_eq!(config.get_intra_threads(), 8);
        assert_eq!(config.get_inter_threads(), 4);
        assert!(matches!(
            config.get_optimization_level(),
            OrtGraphOptimizationLevel::All
        ));
    }

    /// Getters fall back to their documented defaults when nothing is configured.
    #[test]
    fn test_ort_session_config_getter_defaults() {
        let config = OrtSessionConfig::new();
        // The exact value is machine-dependent, but it can never be zero.
        assert!(config.get_intra_threads() >= 1);
        assert_eq!(config.get_inter_threads(), 1);
        assert!(matches!(
            config.get_optimization_level(),
            OrtGraphOptimizationLevel::Level1
        ));
        assert_eq!(
            config.get_execution_providers(),
            vec![OrtExecutionProvider::CPU]
        );
    }

    /// `add_execution_provider` appends to an existing list and preserves order.
    #[test]
    fn test_ort_session_config_add_execution_provider_appends() {
        let config = OrtSessionConfig::new()
            .with_execution_providers(vec![OrtExecutionProvider::CPU])
            .add_execution_provider(OrtExecutionProvider::WebGPU)
            .add_execution_provider(OrtExecutionProvider::DirectML { device_id: Some(0) });
        assert_eq!(
            config.get_execution_providers(),
            vec![
                OrtExecutionProvider::CPU,
                OrtExecutionProvider::WebGPU,
                OrtExecutionProvider::DirectML { device_id: Some(0) },
            ]
        );
    }

    /// `add_config_entry` creates the map on first use and overwrites duplicate keys.
    #[test]
    fn test_ort_session_config_add_config_entry() {
        let config = OrtSessionConfig::new()
            .add_config_entry("a", "1")
            .add_config_entry("b", "2")
            .add_config_entry("a", "3");
        let entries = config
            .session_config_entries
            .expect("config entries should be set after add_config_entry");
        assert_eq!(entries.len(), 2);
        assert_eq!(entries.get("a").map(String::as_str), Some("3"));
        assert_eq!(entries.get("b").map(String::as_str), Some("2"));
    }
}