use crate::{api, check, sys, Result};
use std::ffi::CString;
use std::ptr;
/// Tri-state switch for the CPU memory arena.
///
/// `Default` makes `SessionOptions::build_handle` skip the arena call entirely, so
/// whatever the runtime already uses stays in effect; `Enabled` / `Disabled`
/// explicitly turn the arena on or off.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ArenaState {
    /// Leave the arena setting untouched (no runtime call is made).
    #[default]
    Default,
    /// Explicitly enable the CPU memory arena.
    Enabled,
    /// Explicitly disable the CPU memory arena.
    Disabled,
}
/// Tri-state switch for the memory-pattern optimization.
///
/// `Default` makes `SessionOptions::build_handle` skip the mem-pattern call
/// entirely, so whatever the runtime already uses stays in effect; `Enabled` /
/// `Disabled` explicitly turn it on or off.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum MemPatternState {
    /// Leave the mem-pattern setting untouched (no runtime call is made).
    #[default]
    Default,
    /// Explicitly enable memory patterns.
    Enabled,
    /// Explicitly disable memory patterns.
    Disabled,
}
/// Builder-style configuration for an inference session.
///
/// Settings are kept as plain Rust values and only turned into a native handle by
/// `build_handle(&self)`, so a `SessionOptions` can be cloned and reused, and (with
/// the `serde` feature) serialized, before any session exists.
///
/// With the `serde` feature, any field missing from the deserialized input is
/// taken from [`SessionOptions::default`]. Partial configs, and configs written
/// before a field existed, therefore still load instead of failing with
/// "missing field".
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(default))]
pub struct SessionOptions {
    /// Graph optimization level; always applied when the handle is built.
    #[cfg_attr(feature = "serde", serde(with = "crate::serde_support::graph_opt"))]
    pub(crate) opt_level: sys::GraphOptimizationLevel,
    /// Intra-op thread count; `None` means no call is made.
    pub(crate) intra_threads: Option<i32>,
    /// Inter-op thread count; `None` means no call is made.
    pub(crate) inter_threads: Option<i32>,
    /// Execution mode. Not serialized, so a deserialized value is always `None`.
    #[cfg_attr(feature = "serde", serde(skip))]
    pub(crate) execution_mode: Option<sys::ExecutionMode>,
    /// Session log identifier (NUL-free by construction).
    #[cfg_attr(feature = "serde", serde(with = "crate::serde_support::opt_cstr"))]
    pub(crate) log_id: Option<CString>,
    /// Log severity stored as the integer value of a `sys::LoggingLevel`.
    pub(crate) log_severity: Option<i32>,
    /// Log verbosity level; `None` means no call is made.
    pub(crate) log_verbosity: Option<i32>,
    /// CPU memory arena switch; `ArenaState::Default` means no call is made.
    pub(crate) cpu_mem_arena: ArenaState,
    /// Memory-pattern switch; `MemPatternState::Default` means no call is made.
    pub(crate) mem_pattern: MemPatternState,
    /// Whether the global thread pool should be used (defaults to `true`).
    /// NOTE(review): not read by `build_handle`; presumably consumed at session
    /// creation — confirm.
    pub(crate) use_global_thread_pool: bool,
    /// Profiling file prefix; `None` means profiling is explicitly disabled at
    /// build time. Not serialized.
    #[cfg_attr(feature = "serde", serde(skip))]
    pub(crate) profiling_prefix: Option<CString>,
    /// `(denotation, value)` free-dimension overrides, applied in insertion order.
    #[cfg_attr(feature = "serde", serde(with = "crate::serde_support::kv_i64_pairs"))]
    pub(crate) free_dimension_overrides: Vec<(CString, i64)>,
    /// `(name, value)` free-dimension overrides, applied in insertion order.
    #[cfg_attr(feature = "serde", serde(with = "crate::serde_support::kv_i64_pairs"))]
    pub(crate) free_dimension_overrides_by_name: Vec<(CString, i64)>,
    /// Raw `(key, value)` session config entries, applied in insertion order.
    #[cfg_attr(feature = "serde", serde(with = "crate::serde_support::kv_pairs"))]
    pub(crate) config_entries: Vec<(CString, CString)>,
    /// Generic execution-provider configs, applied via `crate::ep::apply`.
    #[cfg(feature = "ep")]
    pub(crate) ep_configs: Vec<crate::ep::EpConfig>,
    /// MIGraphX provider options, appended via `append_raw` at build time.
    #[cfg(feature = "ep")]
    pub(crate) migraphx: Vec<crate::ep::MigraphxOptions>,
    /// OpenVINO provider options, appended via `append_raw` at build time.
    #[cfg(feature = "ep")]
    pub(crate) openvino: Vec<crate::ep::OpenvinoOptions>,
    /// Device attachments for execution providers. Not serialized.
    /// NOTE(review): not read by `build_handle`; presumably consumed at session
    /// creation — confirm.
    #[cfg(feature = "ep")]
    #[cfg_attr(feature = "serde", serde(skip))]
    pub(crate) ep_device_attach: Vec<crate::ep_device::EpDeviceAttach>,
    /// Raw custom-op domain handles, added via `add_custom_op_domain` at build time.
    /// Cloning `SessionOptions` copies these pointers, not the domains themselves.
    /// NOTE(review): ownership/release of the domains is not visible here — confirm.
    #[cfg(feature = "custom-ops")]
    #[cfg_attr(feature = "serde", serde(skip))]
    pub(crate) custom_op_domains: Vec<*mut sys::CustomOpDomainHandle>,
}
impl Default for SessionOptions {
    /// Full graph optimization and the global thread pool. Every other setting
    /// starts out empty/unset.
    fn default() -> Self {
        Self {
            // The only two settings that differ from their type's own `Default`.
            opt_level: sys::GraphOptimizationLevel::All,
            use_global_thread_pool: true,
            // Everything below starts "unset": `None`, empty lists, or the
            // `Default` variant of the tri-state switches.
            intra_threads: Default::default(),
            inter_threads: Default::default(),
            execution_mode: Default::default(),
            log_id: Default::default(),
            log_severity: Default::default(),
            log_verbosity: Default::default(),
            cpu_mem_arena: Default::default(),
            mem_pattern: Default::default(),
            profiling_prefix: Default::default(),
            free_dimension_overrides: Default::default(),
            free_dimension_overrides_by_name: Default::default(),
            config_entries: Default::default(),
            #[cfg(feature = "ep")]
            ep_configs: Default::default(),
            #[cfg(feature = "ep")]
            migraphx: Default::default(),
            #[cfg(feature = "ep")]
            openvino: Default::default(),
            #[cfg(feature = "ep")]
            ep_device_attach: Default::default(),
            #[cfg(feature = "custom-ops")]
            custom_op_domains: Default::default(),
        }
    }
}
impl SessionOptions {
pub fn new() -> Self {
Self::default()
}
#[inline]
pub fn with_opt_level(mut self, level: sys::GraphOptimizationLevel) -> Self {
self.opt_level = level;
self
}
#[inline]
pub fn with_intra_threads(mut self, n: i32) -> Self {
self.intra_threads = Some(n);
self
}
#[inline]
pub fn with_inter_threads(mut self, n: i32) -> Self {
self.inter_threads = Some(n);
self
}
#[inline]
pub fn with_execution_mode(mut self, mode: sys::ExecutionMode) -> Self {
self.execution_mode = Some(mode);
self
}
#[inline]
pub fn with_sequential_execution(self) -> Self {
self.with_execution_mode(sys::ExecutionMode::Sequential)
}
#[inline]
pub fn with_parallel_execution(self) -> Self {
self.with_execution_mode(sys::ExecutionMode::Parallel)
}
pub fn with_log_id(mut self, id: &str) -> std::result::Result<Self, std::ffi::NulError> {
self.log_id = Some(CString::new(id)?);
Ok(self)
}
#[inline]
pub fn with_log_severity(mut self, level: sys::LoggingLevel) -> Self {
self.log_severity = Some(level as i32);
self
}
#[inline]
pub fn with_log_verbosity(mut self, level: i32) -> Self {
self.log_verbosity = Some(level);
self
}
#[inline]
pub fn use_global_thread_pool(mut self) -> Self {
self.use_global_thread_pool = true;
self
}
#[inline]
pub fn use_per_session_threads(mut self) -> Self {
self.use_global_thread_pool = false;
self
}
#[inline]
pub fn with_cpu_mem_arena(mut self, state: ArenaState) -> Self {
self.cpu_mem_arena = state;
self
}
#[inline]
pub fn enable_cpu_mem_arena(mut self) -> Self {
self.cpu_mem_arena = ArenaState::Enabled;
self
}
#[inline]
pub fn disable_cpu_mem_arena(mut self) -> Self {
self.cpu_mem_arena = ArenaState::Disabled;
self
}
#[inline]
pub fn with_mem_pattern(mut self, state: MemPatternState) -> Self {
self.mem_pattern = state;
self
}
#[inline]
pub fn enable_mem_pattern(mut self) -> Self {
self.mem_pattern = MemPatternState::Enabled;
self
}
#[inline]
pub fn disable_mem_pattern(mut self) -> Self {
self.mem_pattern = MemPatternState::Disabled;
self
}
pub fn enable_profiling(
mut self, profile_file_prefix: &str,
) -> std::result::Result<Self, std::ffi::NulError> {
self.profiling_prefix = Some(CString::new(profile_file_prefix)?);
Ok(self)
}
#[inline]
pub fn disable_profiling(mut self) -> Self {
self.profiling_prefix = None;
self
}
pub fn with_free_dimension_override(
mut self, dimension_denotation: &str, value: i64,
) -> std::result::Result<Self, std::ffi::NulError> {
self.free_dimension_overrides
.push((CString::new(dimension_denotation)?, value));
Ok(self)
}
pub fn with_free_dimension_override_by_name(
mut self, dimension_name: &str, value: i64,
) -> std::result::Result<Self, std::ffi::NulError> {
self.free_dimension_overrides_by_name
.push((CString::new(dimension_name)?, value));
Ok(self)
}
pub fn with_intra_op_spinning(
self, enable: bool,
) -> std::result::Result<Self, std::ffi::NulError> {
self.with_config_entry("session.intra_op.allow_spinning", bool_config_value(enable))
}
pub fn with_inter_op_spinning(
self, enable: bool,
) -> std::result::Result<Self, std::ffi::NulError> {
self.with_config_entry("session.inter_op.allow_spinning", bool_config_value(enable))
}
pub fn with_config_entry(
mut self, key: &str, value: &str,
) -> std::result::Result<Self, std::ffi::NulError> {
self.config_entries
.push((CString::new(key)?, CString::new(value)?));
Ok(self)
}
pub(crate) fn build_handle(&self) -> Result<*mut sys::SessionOptionsHandle> {
let api = api();
let mut opts: *mut sys::SessionOptionsHandle = ptr::null_mut();
check(unsafe { api.create_session_options()(&mut opts) })?;
let opts = crate::ensure_non_null(opts, "session options")?;
let result = (|| {
check(unsafe { api.set_session_graph_optimization_level()(opts, self.opt_level) })?;
if let Some(n) = self.intra_threads {
check(unsafe { api.set_intra_op_num_threads()(opts, n) })?;
}
if let Some(n) = self.inter_threads {
check(unsafe { api.set_inter_op_num_threads()(opts, n) })?;
}
if let Some(mode) = self.execution_mode {
check(unsafe { api.set_session_execution_mode()(opts, mode) })?;
}
if let Some(log_id) = &self.log_id {
check(unsafe { api.set_session_log_id()(opts, log_id.as_ptr()) })?;
}
if let Some(level) = self.log_severity {
check(unsafe { api.set_session_log_severity_level()(opts, level) })?;
}
if let Some(level) = self.log_verbosity {
check(unsafe { api.set_session_log_verbosity_level()(opts, level) })?;
}
match self.cpu_mem_arena {
ArenaState::Default => {},
ArenaState::Enabled => check(unsafe { api.enable_cpu_mem_arena()(opts) })?,
ArenaState::Disabled => check(unsafe { api.disable_cpu_mem_arena()(opts) })?,
}
match self.mem_pattern {
MemPatternState::Default => {},
MemPatternState::Enabled => check(unsafe { api.enable_mem_pattern()(opts) })?,
MemPatternState::Disabled => check(unsafe { api.disable_mem_pattern()(opts) })?,
}
if let Some(prefix) = &self.profiling_prefix {
check(unsafe { api.enable_profiling()(opts, prefix.as_ptr()) })?;
} else {
check(unsafe { api.disable_profiling()(opts) })?;
}
for (denotation, value) in &self.free_dimension_overrides {
check(unsafe {
api.add_free_dimension_override()(opts, denotation.as_ptr(), *value)
})?;
}
for (name, value) in &self.free_dimension_overrides_by_name {
check(unsafe {
api.add_free_dimension_override_by_name()(opts, name.as_ptr(), *value)
})?;
}
for (k, v) in &self.config_entries {
check(unsafe { api.add_session_config_entry()(opts, k.as_ptr(), v.as_ptr()) })?;
}
#[cfg(feature = "ep")]
for cfg in &self.ep_configs {
crate::ep::apply(opts, cfg)?;
}
#[cfg(feature = "ep")]
for m in &self.migraphx {
m.append_raw(opts)?;
}
#[cfg(feature = "ep")]
for o in &self.openvino {
o.append_raw(opts)?;
}
#[cfg(feature = "custom-ops")]
for domain in &self.custom_op_domains {
check(unsafe { api.add_custom_op_domain()(opts, *domain) })?;
}
Ok(opts)
})();
if result.is_err() {
unsafe { api.release_session_options()(opts) };
}
result
}
}
/// Encodes a flag as `"1"` (true) or `"0"` (false), the form used by the
/// spinning config entries in this file.
#[inline]
fn bool_config_value(enabled: bool) -> &'static str {
    match enabled {
        true => "1",
        false => "0",
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Releases a handle obtained from a successful `SessionOptions::build_handle`.
    fn release(handle: *mut sys::SessionOptionsHandle) {
        unsafe {
            api().release_session_options()(handle);
        }
    }
    /// Sets every simple builder knob and checks that the runtime accepts the result.
    #[test]
    fn advanced_options_build_handle() {
        let mut opts = SessionOptions::new()
            .with_opt_level(sys::GraphOptimizationLevel::All)
            .with_intra_threads(1)
            .with_inter_threads(1)
            .with_parallel_execution();
        opts = opts.with_log_id("advanced-options").expect("log id");
        opts = opts
            .with_log_severity(sys::LoggingLevel::Verbose)
            .with_log_verbosity(1);
        opts = opts
            .with_free_dimension_override("DATA_BATCH", 4)
            .expect("free dim denotation");
        opts = opts
            .with_free_dimension_override_by_name("batch", 4)
            .expect("free dim name");
        opts = opts.with_intra_op_spinning(false).expect("intra spin");
        opts = opts.with_inter_op_spinning(false).expect("inter spin");
        let handle = opts.build_handle().expect("advanced options handle");
        release(handle);
    }
}
#[cfg(all(test, feature = "serde"))]
mod serde_tests {
    use super::*;
    /// Releases a handle obtained from a successful `SessionOptions::build_handle`.
    fn release(handle: *mut sys::SessionOptionsHandle) {
        unsafe {
            api().release_session_options()(handle);
        }
    }
    /// Counts the entries in `pairs` whose (UTF-8) key equals `key`.
    fn count_key<V>(pairs: &[(CString, V)], key: &str) -> usize {
        pairs.iter().filter(|(k, _)| k.to_str() == Ok(key)).count()
    }
    /// Serializes a fully populated `SessionOptions` to JSON, deserializes it back,
    /// checks every serialized field survived, and builds a handle from the result.
    #[test]
    fn session_options_round_trip() {
        let opts = SessionOptions::new()
            .with_opt_level(sys::GraphOptimizationLevel::Extended)
            .with_intra_threads(4)
            .with_inter_threads(2)
            .with_parallel_execution()
            .with_log_id("serde-session")
            .expect("log id")
            .with_log_severity(sys::LoggingLevel::Warning)
            .with_log_verbosity(2)
            .disable_cpu_mem_arena()
            .disable_mem_pattern()
            .with_free_dimension_override("DATA_BATCH", 4)
            .expect("free dim denotation")
            .with_free_dimension_override_by_name("batch", 4)
            .expect("free dim name")
            .with_intra_op_spinning(false)
            .expect("intra spinning")
            .with_inter_op_spinning(false)
            .expect("inter spinning")
            .with_config_entry("session.run", "1")
            .expect("config entry");
        // No debug print here: every assertion message below already embeds the JSON.
        let json = serde_json::to_string(&opts).expect("serialize");
        assert!(
            json.contains("\"opt_level\":2"),
            "opt_level discriminant (Extended=2) present: {json}"
        );
        assert!(
            json.contains("\"session.run\""),
            "config key present: {json}"
        );
        assert!(json.contains("\"serde-session\""), "log id present: {json}");
        assert!(
            json.contains("\"log_severity\":2"),
            "log severity present: {json}"
        );
        assert!(
            json.contains("\"log_verbosity\":2"),
            "log verbosity present: {json}"
        );
        let back: SessionOptions = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back.opt_level, sys::GraphOptimizationLevel::Extended);
        assert_eq!(back.intra_threads, Some(4));
        assert_eq!(back.inter_threads, Some(2));
        // `execution_mode` is `serde(skip)`, so it does not survive the round trip.
        assert_eq!(back.execution_mode, None);
        assert_eq!(
            back.log_id.as_ref().and_then(|id| id.to_str().ok()),
            Some("serde-session")
        );
        assert_eq!(back.log_severity, Some(sys::LoggingLevel::Warning as i32));
        assert_eq!(back.log_verbosity, Some(2));
        assert_eq!(back.cpu_mem_arena, ArenaState::Disabled);
        assert_eq!(back.mem_pattern, MemPatternState::Disabled);
        assert_eq!(count_key(&back.free_dimension_overrides, "DATA_BATCH"), 1);
        assert_eq!(count_key(&back.free_dimension_overrides_by_name, "batch"), 1);
        assert_eq!(count_key(&back.config_entries, "session.run"), 1);
        release(
            back.build_handle()
                .expect("build handle from deserialized config"),
        );
    }
    /// Explicitly enabled arena / mem-pattern states are stored and accepted at build time.
    #[test]
    fn enabled_states_build_handle() {
        let enabled = SessionOptions::new()
            .with_cpu_mem_arena(ArenaState::Enabled)
            .with_mem_pattern(MemPatternState::Enabled);
        assert_eq!(enabled.cpu_mem_arena, ArenaState::Enabled);
        assert_eq!(enabled.mem_pattern, MemPatternState::Enabled);
        release(enabled.build_handle().expect("build enabled handle"));
    }
}