use alloc::sync::Arc;
use core::{
any::Any,
ffi::{c_int, c_void},
ptr
};
#[cfg(feature = "std")]
use std::path::Path;
use super::{BuilderResult, SessionBuilder};
#[cfg(feature = "std")]
use crate::util::path_to_os_char;
use crate::{
AsPointer, Error, ErrorCode,
environment::{self, ThreadManager},
ep::{ExecutionProviderDispatch, apply_execution_providers},
logging::{LogLevel, LoggerFunction},
memory::MemoryInfo,
operator::OperatorDomain,
ortsys,
util::with_cstr,
value::DynValue
};
impl SessionBuilder {
pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> BuilderResult {
match apply_execution_providers(&mut self, execution_providers.as_ref(), "session options") {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
pub fn with_intra_threads(mut self, num_threads: usize) -> BuilderResult {
match ortsys![@ort: unsafe SetIntraOpNumThreads(self.ptr_mut(), num_threads as _) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
pub fn with_inter_threads(mut self, num_threads: usize) -> BuilderResult {
match ortsys![@ort: unsafe SetInterOpNumThreads(self.ptr_mut(), num_threads as _) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
pub fn with_parallel_execution(mut self, parallel_execution: bool) -> BuilderResult {
let execution_mode = if parallel_execution {
ort_sys::ExecutionMode::ORT_PARALLEL
} else {
ort_sys::ExecutionMode::ORT_SEQUENTIAL
};
match ortsys![@ort: unsafe SetSessionExecutionMode(self.ptr_mut(), execution_mode) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
pub fn with_optimization_level(mut self, opt_level: GraphOptimizationLevel) -> BuilderResult {
match ortsys![@ort: unsafe SetSessionGraphOptimizationLevel(self.ptr_mut(), opt_level.into()) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub fn with_optimized_model_path<S: AsRef<Path>>(mut self, path: S) -> BuilderResult {
let path = crate::util::path_to_os_char(path);
match ortsys![@ort: unsafe SetOptimizedModelFilePath(self.ptr_mut(), path.as_ptr()) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub fn with_profiling<S: AsRef<Path>>(mut self, profiling_file: S) -> BuilderResult {
let profiling_file = crate::util::path_to_os_char(profiling_file);
match ortsys![@ort: unsafe EnableProfiling(self.ptr_mut(), profiling_file.as_ptr()) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
pub fn with_memory_pattern(mut self, enable: bool) -> BuilderResult {
let result = if enable {
ortsys![@ort: unsafe EnableMemPattern(self.ptr_mut()) as Result]
} else {
ortsys![@ort: unsafe DisableMemPattern(self.ptr_mut()) as Result]
};
match result {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
pub fn with_allocator(mut self, info: MemoryInfo) -> BuilderResult {
if !info.is_cpu_accessible() {
return Err(
Error::new_with_code(ErrorCode::InvalidArgument, "SessionBuilder::with_allocator may only use a CPU-accessible allocator").with_recover(self)
);
}
self.memory_info = Some(Arc::new(info));
Ok(self)
}
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub fn with_operator_library(mut self, lib_path: impl AsRef<Path>) -> BuilderResult {
let path_cstr = path_to_os_char(lib_path);
match ortsys![@ort: unsafe RegisterCustomOpsLibrary_V2(self.ptr_mut(), path_cstr.as_ptr()) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
pub fn with_extensions(mut self) -> BuilderResult {
match ortsys![@ort: unsafe EnableOrtCustomOps(self.ptr_mut()) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
pub fn with_operators(mut self, domain: impl Into<Arc<OperatorDomain>>) -> BuilderResult {
let domain: Arc<OperatorDomain> = domain.into();
match ortsys![@ort: unsafe AddCustomOpDomain(self.ptr_mut(), domain.ptr().cast_mut()) as Result] {
Ok(()) => {
self.operator_domains.push(domain);
Ok(self)
}
Err(e) => Err(e.with_recover(self))
}
}
pub fn with_deterministic_compute(mut self, enable: bool) -> BuilderResult {
match ortsys![@ort: unsafe SetDeterministicCompute(self.ptr_mut(), enable) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
pub fn with_initializer(mut self, name: impl AsRef<str>, value: impl Into<Arc<DynValue>>) -> BuilderResult {
let ptr = self.ptr_mut();
let value: Arc<DynValue> = value.into();
match with_cstr(name.as_ref().as_bytes(), &|name| ortsys![@ort: unsafe AddInitializer(ptr, name.as_ptr(), value.ptr()) as Result]) {
Ok(()) => {
self.initializers.push(value);
Ok(self)
}
Err(e) => Err(e.with_recover(self))
}
}
pub fn with_external_initializer(mut self, name: impl AsRef<str>, value: impl Into<Arc<DynValue>>) -> BuilderResult {
let ptr = self.ptr_mut();
let value: Arc<DynValue> = value.into();
match with_cstr(name.as_ref().as_bytes(), &|name| ortsys![@ort: unsafe AddExternalInitializers(ptr, &name.as_ptr(), &value.ptr(), 1) as Result]) {
Ok(()) => {
self.initializers.push(value);
Ok(self)
}
Err(e) => Err(e.with_recover(self))
}
}
#[cfg(all(feature = "std", feature = "api-18"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "std", feature = "api-18"))))]
pub fn with_external_initializer_file_in_memory(mut self, file_name: impl AsRef<Path>, buffer: alloc::borrow::Cow<'static, [u8]>) -> BuilderResult {
let file_name = path_to_os_char(file_name);
let sizes = [buffer.len()];
match ortsys![@ort:
unsafe AddExternalInitializersFromMemory(
self.ptr_mut(),
&file_name.as_ptr(),
&buffer.as_ptr().cast::<core::ffi::c_char>().cast_mut(),
sizes.as_ptr(),
1
) as Result
] {
Ok(()) => {
self.external_initializer_buffers.push(buffer);
Ok(self)
}
Err(e) => Err(e.with_recover(self))
}
}
pub fn with_log_id(mut self, id: impl AsRef<str>) -> BuilderResult {
let ptr = self.ptr_mut();
match with_cstr(id.as_ref().as_bytes(), &|id| ortsys![@ort: unsafe SetSessionLogId(ptr, id.as_ptr()) as Result]) {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
pub fn with_dimension_override(mut self, name: impl AsRef<str>, size: i64) -> BuilderResult {
let ptr = self.ptr_mut();
match with_cstr(name.as_ref().as_bytes(), &|name| ortsys![@ort: unsafe AddFreeDimensionOverrideByName(ptr, name.as_ptr(), size) as Result]) {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
pub fn with_dimension_override_by_denotation(mut self, denotation: impl AsRef<str>, size: i64) -> BuilderResult {
let ptr = self.ptr_mut();
match with_cstr(denotation.as_ref().as_bytes(), &|denotation| ortsys![@ort: unsafe AddFreeDimensionOverride(ptr, denotation.as_ptr(), size) as Result])
{
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
pub fn with_prepacked_weights(mut self, weights: &PrepackedWeights) -> BuilderResult {
self.prepacked_weights = Some(weights.clone());
Ok(self)
}
pub fn with_independent_thread_pool(mut self) -> BuilderResult {
self.no_global_thread_pool = true;
Ok(self)
}
pub fn with_no_environment_execution_providers(mut self) -> BuilderResult {
self.no_env_eps = true;
Ok(self)
}
pub fn with_thread_manager<T: ThreadManager + Any + 'static>(mut self, manager: T) -> BuilderResult {
let manager = Arc::new(manager);
let ptr = self.ptr_mut();
match ortsys![@ort: unsafe SessionOptionsSetCustomThreadCreationOptions(ptr, (&*manager as *const T) as *mut c_void) as Result]
.and_then(|()| ortsys![@ort: unsafe SessionOptionsSetCustomCreateThreadFn(ptr, Some(environment::thread_create::<T>)) as Result])
.and_then(|()| ortsys![@ort: unsafe SessionOptionsSetCustomJoinThreadFn(ptr, Some(environment::thread_join::<T>)) as Result])
{
Ok(()) => {
self.thread_manager = Some(manager as Arc<dyn Any>);
Ok(self)
}
Err(e) => Err(e.with_recover(self))
}
}
pub fn with_logger(mut self, logger: LoggerFunction) -> BuilderResult {
let logger = Arc::new(logger);
match ortsys![@ort: unsafe SetUserLoggingFunction(self.ptr_mut(), crate::logging::custom_logger, Arc::as_ptr(&logger) as *mut c_void) as Result] {
Ok(()) => {
self.logger = Some(logger);
Ok(self)
}
Err(e) => Err(e.with_recover(self))
}
}
pub fn with_log_level(mut self, level: LogLevel) -> BuilderResult {
match ortsys![@ort: unsafe SetSessionLogSeverityLevel(self.ptr_mut(), ort_sys::OrtLoggingLevel::from(level) as _) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
pub fn with_log_verbosity(mut self, verbosity: c_int) -> BuilderResult {
match ortsys![@ort: unsafe SetSessionLogVerbosityLevel(self.ptr_mut(), verbosity) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
#[cfg(feature = "api-22")]
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
pub fn with_auto_device(mut self, policy: AutoDevicePolicy) -> BuilderResult {
match ortsys![@ort: unsafe SessionOptionsSetEpSelectionPolicy(self.ptr_mut(), policy.into()) as Result] {
Ok(()) => Ok(self),
Err(e) => Err(e.with_recover(self))
}
}
#[cfg(feature = "api-22")]
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
pub fn with_devices<'e>(
mut self,
devices: impl IntoIterator<Item = crate::device::Device<'e>>,
options: Option<&[(alloc::string::String, alloc::string::String)]>
) -> BuilderResult {
use alloc::vec::Vec;
use smallvec::SmallVec;
use crate::util::{MiniMap, with_cstr_ptr_array};
#[derive(Default)]
struct DeviceGroup<'o> {
device_ptrs: SmallVec<[*const ort_sys::OrtEpDevice; 2]>,
option_keys: Vec<&'o str>,
option_values: Vec<&'o str>
}
let existing_devices: SmallVec<[_; 4]> = self.environment.devices().map(|x| x.ptr()).collect();
let mut device_groups = MiniMap::<&str, DeviceGroup<'_>>::new();
let mut group_prefix = [0u8; 128];
for device in devices {
let ptr = device.ptr();
if !existing_devices.contains(&ptr) {
return Err(Error::new("device comes from different environment").with_recover(self));
}
let group = device.ep().expect("invalid utf-8");
group_prefix[..group.len()].copy_from_slice(group.as_bytes());
group_prefix[group.len()] = b'.';
let group_prefix = unsafe { core::str::from_utf8_unchecked(core::slice::from_raw_parts(group_prefix.as_ptr(), group.len() + 1)) };
let group = device_groups.get_or_insert_with(group, DeviceGroup::default);
group.device_ptrs.push(ptr);
if let Some(options) = options {
for (key, value) in options.iter() {
if let Some(real_key) = key.strip_prefix(group_prefix) {
group.option_keys.push(real_key);
group.option_values.push(value.as_str());
}
}
}
}
for (_, group) in device_groups.iter() {
let ptr = self.ptr_mut();
let env_ptr = self.environment.ptr().cast_mut();
if let Err(e) = with_cstr_ptr_array(&group.option_keys, &|option_keys| {
with_cstr_ptr_array(&group.option_values, &|option_values| {
ortsys![unsafe SessionOptionsAppendExecutionProvider_V2(
ptr,
env_ptr,
group.device_ptrs.as_ptr(),
group.device_ptrs.len(),
option_keys.as_ptr(),
option_values.as_ptr(),
option_keys.len()
)?];
Ok(())
})
}) {
return Err(e.with_recover(self));
}
}
Ok(self)
}
}
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub enum GraphOptimizationLevel {
Disable,
Level1,
#[rustfmt::skip]
Level2,
Level3,
All
}
impl From<GraphOptimizationLevel> for ort_sys::GraphOptimizationLevel {
fn from(val: GraphOptimizationLevel) -> Self {
match val {
GraphOptimizationLevel::Disable => ort_sys::GraphOptimizationLevel::ORT_DISABLE_ALL,
GraphOptimizationLevel::Level1 => ort_sys::GraphOptimizationLevel::ORT_ENABLE_BASIC,
GraphOptimizationLevel::Level2 => ort_sys::GraphOptimizationLevel::ORT_ENABLE_EXTENDED,
GraphOptimizationLevel::Level3 => ort_sys::GraphOptimizationLevel::ORT_ENABLE_LAYOUT,
GraphOptimizationLevel::All => ort_sys::GraphOptimizationLevel::ORT_ENABLE_ALL
}
}
}
#[derive(Debug)]
pub(crate) struct PrepackedWeightsInner(*mut ort_sys::OrtPrepackedWeightsContainer);
unsafe impl Send for PrepackedWeightsInner {}
unsafe impl Sync for PrepackedWeightsInner {}
impl Drop for PrepackedWeightsInner {
fn drop(&mut self) {
ortsys![unsafe ReleasePrepackedWeightsContainer(self.0)];
crate::logging::drop!(PrepackedWeights, self.0);
}
}
#[derive(Debug, Clone)]
pub struct PrepackedWeights {
pub(crate) inner: Arc<PrepackedWeightsInner>
}
impl PrepackedWeights {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
let mut ptr: *mut ort_sys::OrtPrepackedWeightsContainer = ptr::null_mut();
ortsys![unsafe CreatePrepackedWeightsContainer(&mut ptr).expect("Failed to create prepacked weights container")];
crate::logging::create!(PrepackedWeights, ptr);
Self {
inner: Arc::new(PrepackedWeightsInner(ptr))
}
}
}
impl AsPointer for PrepackedWeights {
type Sys = ort_sys::OrtPrepackedWeightsContainer;
fn ptr(&self) -> *const Self::Sys {
self.inner.0
}
fn ptr_mut(&mut self) -> *mut Self::Sys {
self.inner.0
}
}
#[cfg(feature = "api-22")]
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AutoDevicePolicy {
#[default]
Default,
PreferCPU,
PreferNPU,
PreferGPU,
MaxPerformance,
MaxEfficiency,
MinPower
}
#[cfg(feature = "api-22")]
impl From<AutoDevicePolicy> for ort_sys::OrtExecutionProviderDevicePolicy {
fn from(val: AutoDevicePolicy) -> Self {
match val {
AutoDevicePolicy::Default => ort_sys::OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_DEFAULT,
AutoDevicePolicy::PreferCPU => ort_sys::OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_PREFER_CPU,
AutoDevicePolicy::PreferNPU => ort_sys::OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_PREFER_NPU,
AutoDevicePolicy::PreferGPU => ort_sys::OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_PREFER_GPU,
AutoDevicePolicy::MaxPerformance => ort_sys::OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_MAX_PERFORMANCE,
AutoDevicePolicy::MaxEfficiency => ort_sys::OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_MAX_EFFICIENCY,
AutoDevicePolicy::MinPower => ort_sys::OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER
}
}
}