use alloc::{boxed::Box, string::String, sync::Arc};
use core::{
any::Any,
ffi::c_void,
fmt,
mem::forget,
ptr::{self, NonNull}
};
#[cfg(all(feature = "api-22", feature = "std"))]
use std::path::Path;
use smallvec::SmallVec;
#[cfg(feature = "api-22")]
use crate::ep::ExecutionProviderLibrary;
use crate::{
AsPointer,
ep::ExecutionProviderDispatch,
error::Result,
logging::{LogLevel, LoggerFunction},
ortsys,
util::{Mutex, OnceLock, STACK_EXECUTION_PROVIDERS, run_on_drop, with_cstr}
};
/// Process-wide ONNX Runtime environment shared by every caller of [`current`].
///
/// `None` until the first call to [`current`] creates it; taken back out by `release_env_on_exit` during shutdown.
static G_ENV: Mutex<Option<Arc<Environment>>> = Mutex::new(None);
/// Exit hook that drops the global [`Environment`], so `ReleaseEnv` runs before the process is torn down.
///
/// Hooked up through `.fini_array` on non-Windows/non-Apple/non-wasm targets, through a TLS callback on Windows, and
/// through `__cxa_atexit` on Apple targets. `__cxa_atexit` passes a context pointer, hence the extra unused parameter
/// there.
///
/// The environment is moved out of `G_ENV` while the lock is held, but it is dropped only *after* the guard has been
/// released. `Environment`'s destructor calls `ReleaseEnv` and drops the user-supplied logger and thread manager.
/// Running that code while `G_ENV` is locked could deadlock if any of it calls back into [`current`].
#[cfg_attr(any(target_os = "linux", target_os = "android"), unsafe(link_section = ".text.exit"))]
unsafe extern "C" fn release_env_on_exit(#[cfg(target_vendor = "apple")] _: *const ()) {
	// The guard is a temporary of this `let` statement, so it is released at the end of the statement.
	let env = G_ENV.lock().take();
	// Destructors (ReleaseEnv, logger, thread manager) now run without holding the lock.
	drop(env);
}
// On ELF-style targets (everything except Windows, Apple and wasm32), put a pointer to `release_env_on_exit` in the
// `.fini_array` section so it is invoked as a finalizer when the process exits. `#[used]` keeps this otherwise
// unreferenced static from being discarded by the compiler.
#[used]
#[cfg(all(not(windows), not(target_vendor = "apple"), not(target_arch = "wasm32")))]
#[unsafe(link_section = ".fini_array")]
static _ON_EXIT: unsafe extern "C" fn() = release_env_on_exit;
// On Windows, register a TLS callback (`.CRT$XLB`) that releases the global environment when the loader reports
// `DLL_PROCESS_DETACH`. `#[used]` keeps the otherwise unreferenced static from being discarded.
#[used]
#[cfg(windows)]
#[unsafe(link_section = ".CRT$XLB")]
static _ON_EXIT: unsafe extern "system" fn(module: *mut (), reason: u32, reserved: *mut ()) = {
	// Win32 notification code for "module is being unloaded / process is exiting".
	const DLL_PROCESS_DETACH: u32 = 0;

	unsafe extern "system" fn on_exit(_module: *mut (), reason: u32, _reserved: *mut ()) {
		// Ignore attach and thread notifications; only act on process detach.
		if reason != DLL_PROCESS_DETACH {
			return;
		}
		unsafe { release_env_on_exit() };
	}

	on_exit
};
/// Registers `release_env_on_exit` with `__cxa_atexit`, so the global environment is released at exit on Apple
/// targets, where the `.fini_array` hook is not used.
///
/// `__cxa_atexit` identifies the owning image by the *address* of `__dso_handle`. That address is what C++ compilers
/// pass for static destructors. Reading the symbol's *value* would instead pass the first bytes of the image header
/// reinterpreted as a pointer.
#[cfg(target_vendor = "apple")]
fn register_atexit() {
	unsafe extern "C" {
		static __dso_handle: *const ();
		fn __cxa_atexit(cb: unsafe extern "C" fn(_: *const ()), arg: *const (), dso_handle: *const ()) -> core::ffi::c_int;
	}
	// A non-zero return means registration failed. There is no caller to report that to; the only consequence is that
	// the environment is not explicitly released at exit.
	let _ = unsafe { __cxa_atexit(release_env_on_exit, ptr::null(), ptr::addr_of!(__dso_handle).cast::<()>()) };
}
/// Options used to build the global environment: stored by [`EnvironmentBuilder::commit`], or filled with defaults by
/// [`current`] if nothing was committed before the environment was first needed.
static G_ENV_OPTIONS: OnceLock<EnvironmentBuilder> = OnceLock::new();
/// An ONNX Runtime environment (`OrtEnv`) together with the Rust-side state that must live as long as it does.
///
/// Obtain the process-wide instance with [`Environment::current`]. It is built from the options committed with
/// [`EnvironmentBuilder::commit`], or from defaults if none were committed.
pub struct Environment {
/// Execution providers configured via [`EnvironmentBuilder::with_execution_providers`].
execution_providers: SmallVec<[ExecutionProviderDispatch; STACK_EXECUTION_PROVIDERS]>,
/// Owned `OrtEnv` handle; released in `Drop`.
ptr: NonNull<ort_sys::OrtEnv>,
/// Whether the environment was created with global thread pools.
has_global_threadpool: bool,
/// Keeps the custom [`ThreadManager`] (if any) alive; ONNX Runtime only holds a raw pointer to it.
// Field order matters: fields are dropped in declaration order, after `Drop::drop` has called `ReleaseEnv`.
_thread_manager: Option<Arc<dyn Any>>,
/// Clone of the builder's custom logger, if one was set.
// NOTE(review): presumably kept so the logger's state outlives the env; ORT itself is given a pointer to the
// builder's copy (see `create_environment`) — confirm.
_logger: Option<LoggerFunction>
}
// SAFETY (review): `Environment` holds an owned `OrtEnv` pointer plus Rust-side state, and its own methods only pass
// the pointer to ONNX Runtime. These impls assume ONNX Runtime allows one `OrtEnv` to be used from multiple threads.
// NOTE(review): `_thread_manager` (`Arc<dyn Any>`) and `_logger` are not statically required to be `Send`/`Sync`
// here — confirm this is sound for every type they can hold.
unsafe impl Send for Environment {}
unsafe impl Sync for Environment {}
impl Environment {
	/// Returns the process-wide environment, creating it on first use.
	///
	/// This is a convenience alias for the module-level [`current`] function.
	pub fn current() -> Result<Arc<Environment>> {
		self::current()
	}

	/// Changes the minimum log severity ONNX Runtime uses for this environment.
	pub fn set_log_level(&self, level: LogLevel) {
		let env = self.ptr().cast_mut();
		ortsys![unsafe UpdateEnvWithCustomLogLevel(env, level.into()).expect("infallible")];
	}

	/// The execution providers this environment was configured with.
	pub fn execution_providers(&self) -> &[ExecutionProviderDispatch] {
		self.execution_providers.as_slice()
	}

	/// Registers the execution-provider shared library at `path` with this environment under `name`.
	///
	/// # Errors
	/// Returns an error if ONNX Runtime fails to register the library.
	#[cfg(all(feature = "api-22", feature = "std"))]
	#[cfg_attr(docsrs, doc(cfg(all(feature = "api-22", feature = "std"))))]
	pub fn register_ep_library<P: AsRef<Path>>(self: &Arc<Self>, name: impl Into<String>, path: P) -> Result<ExecutionProviderLibrary> {
		let registration_name: String = name.into();
		let os_path = crate::util::path_to_os_char(path);
		let env = self.ptr().cast_mut();
		with_cstr(registration_name.as_bytes(), &|c_name| {
			ortsys![unsafe RegisterExecutionProviderLibrary(env, c_name.as_ptr(), os_path.as_ptr())?];
			Ok(())
		})?;
		Ok(ExecutionProviderLibrary::new(registration_name, self))
	}

	/// Iterates over the execution-provider devices ONNX Runtime reports for this environment.
	///
	/// This is best effort: if `GetEpDevices` fails, the error is discarded. The iterator is then empty, assuming ORT
	/// leaves the output count at its initial 0 on failure. Null entries are skipped.
	#[cfg(feature = "api-22")]
	#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
	pub fn devices(&self) -> impl DoubleEndedIterator<Item = crate::device::Device<'_>> + '_ {
		let mut device_ptrs = ptr::dangling();
		let mut device_count = 0;
		let _ = ortsys![@ort: unsafe GetEpDevices(self.ptr().cast_mut(), &mut device_ptrs, &mut device_count) as Result];
		let raw_devices = unsafe { core::slice::from_raw_parts(device_ptrs, device_count) };
		raw_devices
			.iter()
			.filter_map(|raw| NonNull::new(raw.cast_mut()))
			.map(crate::device::Device::new)
	}

	/// Whether this environment was created with global thread pools.
	#[inline]
	pub(crate) fn has_global_threadpool(&self) -> bool {
		self.has_global_threadpool
	}
}
impl fmt::Debug for Environment {
	/// Shows only the raw `OrtEnv` pointer; the remaining fields are elided (`..`).
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		let mut out = f.debug_struct("Environment");
		out.field("ptr", &self.ptr);
		out.finish_non_exhaustive()
	}
}
impl AsPointer for Environment {
	type Sys = ort_sys::OrtEnv;

	/// Raw `OrtEnv` handle owned by this environment.
	fn ptr(&self) -> *const Self::Sys {
		self.ptr.as_ptr().cast_const()
	}
}
impl Drop for Environment {
/// Releases the `OrtEnv` handle, then logs the drop.
///
/// The Rust-side fields (execution providers, thread manager, logger) are dropped only after this method returns,
/// i.e. after `ReleaseEnv` has finished.
fn drop(&mut self) {
ortsys![unsafe ReleaseEnv(self.ptr_mut())];
crate::logging::drop!(Environment, self.ptr());
}
}
/// Returns the process-wide [`Environment`], creating it on first use.
///
/// The environment is built from the options stored by [`EnvironmentBuilder::commit`]. If nothing was committed
/// before the first call, default options are stored in their place. The `G_ENV` lock is held for the whole
/// creation, so concurrent first calls create only one environment.
///
/// # Errors
/// Returns an error if the environment cannot be created.
pub fn current() -> Result<Arc<Environment>> {
	let mut slot = G_ENV.lock();
	match slot.as_ref() {
		Some(existing) => Ok(Arc::clone(existing)),
		None => {
			let builder = G_ENV_OPTIONS.get_or_init(EnvironmentBuilder::new);
			let created = Arc::new(builder.create_environment()?);
			*slot = Some(Arc::clone(&created));
			// Apple targets have no `.fini_array` hook; register the exit handler dynamically instead.
			#[cfg(target_vendor = "apple")]
			register_atexit();
			Ok(created)
		}
	}
}
/// Options for the environment-wide (global) inter-/intra-op thread pools, wrapping an `OrtThreadingOptions`.
///
/// Pass to [`EnvironmentBuilder::with_global_thread_pool`] to create the environment with global thread pools.
#[derive(Debug)]
pub struct GlobalThreadPoolOptions {
/// Owned `OrtThreadingOptions` handle; created in `Default`, released in `Drop`.
ptr: *mut ort_sys::OrtThreadingOptions,
/// Keeps the custom [`ThreadManager`] set via `with_thread_manager` alive; ONNX Runtime only holds a raw pointer to it.
thread_manager: Option<Arc<dyn Any>>
}
// SAFETY (review): every mutating method takes `self` by value, so a shared reference can only read the pointer.
// NOTE(review): `thread_manager` (`Arc<dyn Any>`) is not statically required to be `Send`/`Sync` — confirm this is
// sound for the manager types users may install.
unsafe impl Send for GlobalThreadPoolOptions {}
unsafe impl Sync for GlobalThreadPoolOptions {}
impl Default for GlobalThreadPoolOptions {
/// Creates fresh threading options via `CreateThreadingOptions`, with no custom thread manager.
///
/// # Panics
/// Panics if ONNX Runtime fails to create the threading options.
fn default() -> Self {
let mut ptr = ptr::null_mut();
ortsys![unsafe CreateThreadingOptions(&mut ptr).expect("failed to create threading options")];
crate::logging::create!(GlobalThreadPoolOptions, ptr);
Self { ptr, thread_manager: None }
}
}
impl GlobalThreadPoolOptions {
	/// Sets the number of threads in the global inter-op thread pool.
	///
	/// NOTE(review): `num_threads` is converted with `as`, so values outside the C `int` range are truncated.
	///
	/// # Errors
	/// Returns an error if ONNX Runtime rejects the value.
	pub fn with_inter_threads(mut self, num_threads: usize) -> Result<Self> {
		let options = self.ptr_mut();
		ortsys![unsafe SetGlobalInterOpNumThreads(options, num_threads as _)?];
		Ok(self)
	}

	/// Sets the number of threads in the global intra-op thread pool.
	///
	/// NOTE(review): `num_threads` is converted with `as`, so values outside the C `int` range are truncated.
	///
	/// # Errors
	/// Returns an error if ONNX Runtime rejects the value.
	pub fn with_intra_threads(mut self, num_threads: usize) -> Result<Self> {
		let options = self.ptr_mut();
		ortsys![unsafe SetGlobalIntraOpNumThreads(options, num_threads as _)?];
		Ok(self)
	}

	/// Enables or disables spinning of idle global pool threads.
	///
	/// # Errors
	/// Returns an error if ONNX Runtime rejects the setting.
	pub fn with_spin_control(mut self, spin_control: bool) -> Result<Self> {
		let options = self.ptr_mut();
		let spin_flag = if spin_control { 1 } else { 0 };
		ortsys![unsafe SetGlobalSpinControl(options, spin_flag)?];
		Ok(self)
	}

	/// Sets the thread-affinity specification string for the global intra-op pool.
	///
	/// # Errors
	/// Returns an error if the string cannot be passed to C or ONNX Runtime rejects it.
	pub fn with_intra_affinity(mut self, affinity: impl AsRef<str>) -> Result<Self> {
		let options = self.ptr_mut();
		let affinity_bytes = affinity.as_ref().as_bytes();
		with_cstr(affinity_bytes, &|c_affinity| {
			ortsys![unsafe SetGlobalIntraOpThreadAffinity(options, c_affinity.as_ptr())?];
			Ok(())
		})?;
		Ok(self)
	}

	/// Makes the global pool threads treat denormal floats as zero (`SetGlobalDenormalAsZero`).
	///
	/// # Errors
	/// Returns an error if ONNX Runtime rejects the setting.
	pub fn with_flush_to_zero(mut self) -> Result<Self> {
		let options = self.ptr_mut();
		ortsys![unsafe SetGlobalDenormalAsZero(options)?];
		Ok(self)
	}

	/// Spawns and joins the global pool's threads through `manager` instead of ONNX Runtime's default mechanism.
	///
	/// The manager is moved into an `Arc` that these options (and any environment built from them) keep alive. ONNX
	/// Runtime is only given a raw pointer to it. NOTE(review): `T` is not required to be `Send`/`Sync`, although the
	/// callbacks may run on threads other than the one that installed it — confirm.
	///
	/// # Errors
	/// Returns an error if any of the ONNX Runtime setters fails.
	pub fn with_thread_manager<T: ThreadManager + Any + 'static>(mut self, manager: T) -> Result<Self> {
		let shared = Arc::new(manager);
		let options = self.ptr_mut();
		ortsys![unsafe SetGlobalCustomThreadCreationOptions(options, Arc::as_ptr(&shared).cast_mut().cast())?];
		ortsys![unsafe SetGlobalCustomCreateThreadFn(options, Some(thread_create::<T>))?];
		ortsys![unsafe SetGlobalCustomJoinThreadFn(options, Some(thread_join::<T>))?];
		self.thread_manager = Some(shared as Arc<dyn Any>);
		Ok(self)
	}
}
impl AsPointer for GlobalThreadPoolOptions {
	type Sys = ort_sys::OrtThreadingOptions;

	/// Raw `OrtThreadingOptions` handle owned by these options.
	fn ptr(&self) -> *const Self::Sys {
		self.ptr.cast_const()
	}
}
impl Drop for GlobalThreadPoolOptions {
/// Releases the `OrtThreadingOptions` handle, then logs the drop. The `thread_manager` field is dropped afterwards.
fn drop(&mut self) {
ortsys![unsafe ReleaseThreadingOptions(self.ptr)];
crate::logging::drop!(GlobalThreadPoolOptions, self.ptr);
}
}
/// A user-supplied strategy for spawning and joining ONNX Runtime's global thread-pool threads.
///
/// Install it with [`GlobalThreadPoolOptions::with_thread_manager`].
pub trait ThreadManager {
/// Handle for a spawned thread, produced by `create` and later consumed by `join`.
type Thread;
/// Spawns a thread that runs `work`.
///
/// An `Err` (or a panic, when the `std` feature is enabled) is logged and reported to ONNX Runtime as a failed
/// thread creation (a null handle).
fn create(&self, work: impl FnOnce() + Send + 'static) -> crate::Result<Self::Thread>;
/// Waits for `thread` to finish. An `Err` is logged and is not propagated to ONNX Runtime.
fn join(thread: Self::Thread) -> crate::Result<()>;
}
pub(crate) unsafe extern "system" fn thread_create<T: ThreadManager + Any>(
ort_custom_thread_creation_options: *mut c_void,
ort_thread_worker_fn: ort_sys::OrtThreadWorkerFn,
ort_worker_fn_param: *mut c_void
) -> ort_sys::OrtCustomThreadHandle {
struct SendablePtr(*mut c_void);
unsafe impl Send for SendablePtr {}
let ort_worker_fn_param = SendablePtr(ort_worker_fn_param);
let runner = || {
let manager = unsafe { &mut *ort_custom_thread_creation_options.cast::<T>() };
<T as ThreadManager>::create(manager, move || {
let p = ort_worker_fn_param;
unsafe { (ort_thread_worker_fn)(p.0) }
})
};
#[cfg(not(feature = "std"))]
let res = Result::<_, crate::Error>::Ok(runner()); #[cfg(feature = "std")]
let res = std::panic::catch_unwind(runner);
match res {
Ok(Ok(thread)) => (Box::leak(Box::new(thread)) as *mut <T as ThreadManager>::Thread)
.cast_const()
.cast::<ort_sys::OrtCustomHandleType>(),
Ok(Err(e)) => {
crate::error!("Failed to create thread using manager: {e}");
let _ = e;
ptr::null()
}
Err(e) => {
crate::error!("Thread manager panicked: {e:?}");
let _ = e;
ptr::null()
}
}
}
/// `OrtCustomJoinThreadFn` trampoline that joins a thread previously started by [`thread_create`].
///
/// The handle is the `Box<T::Thread>` leaked by `thread_create`. It is reclaimed here and handed to
/// [`ThreadManager::join`]; join errors are logged.
///
/// `thread_create` returns a null handle when creation fails. A null handle is ignored rather than passed to
/// `Box::from_raw`, which would be undefined behaviour.
pub(crate) unsafe extern "system" fn thread_join<T: ThreadManager + Any>(ort_custom_thread_handle: ort_sys::OrtCustomThreadHandle) {
	if ort_custom_thread_handle.is_null() {
		return;
	}
	// SAFETY: non-null handles only come from `Box::into_raw(Box<T::Thread>)` in `thread_create::<T>`, and ONNX
	// Runtime joins each handle once.
	let handle = unsafe { Box::from_raw(ort_custom_thread_handle.cast_mut().cast::<<T as ThreadManager>::Thread>()) };
	if let Err(e) = <T as ThreadManager>::join(*handle) {
		crate::error!("Failed to join thread using manager: {e}");
		// Keeps `e` "used" when the logging macro compiles to nothing.
		let _ = e;
	}
}
/// Configuration for the global [`Environment`]; obtain one with [`init`] and apply it with
/// [`EnvironmentBuilder::commit`].
pub struct EnvironmentBuilder {
/// Name passed to ONNX Runtime when the environment is created.
name: String,
/// Whether ONNX Runtime telemetry events are enabled.
telemetry: bool,
/// Execution providers recorded on the created environment.
execution_providers: SmallVec<[ExecutionProviderDispatch; STACK_EXECUTION_PROVIDERS]>,
/// If set, the environment is created with global thread pools configured by these options.
global_thread_pool_options: Option<GlobalThreadPoolOptions>,
/// Optional custom logging callback that receives ONNX Runtime's log messages.
logger: Option<LoggerFunction>
}
impl EnvironmentBuilder {
	/// Creates a builder with the default settings: name `"default"`, telemetry enabled, no execution providers, no
	/// global thread pool and no custom logger.
	pub(crate) fn new() -> Self {
		EnvironmentBuilder {
			name: String::from("default"),
			telemetry: true,
			execution_providers: SmallVec::new(),
			global_thread_pool_options: None,
			logger: None
		}
	}

	/// Sets the name passed to ONNX Runtime when the environment is created.
	#[must_use = "commit() must be called in order for the environment to take effect"]
	pub fn with_name<S>(mut self, name: S) -> Self
	where
		S: Into<String>
	{
		self.name = name.into();
		self
	}

	/// Enables or disables ONNX Runtime telemetry events for the environment.
	#[must_use = "commit() must be called in order for the environment to take effect"]
	pub fn with_telemetry(mut self, enable: bool) -> Self {
		self.telemetry = enable;
		self
	}

	/// Sets the execution providers recorded on the environment, replacing any previously set.
	#[must_use = "commit() must be called in order for the environment to take effect"]
	pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> Self {
		self.execution_providers = execution_providers.as_ref().into();
		self
	}

	/// Creates the environment with global thread pools configured by `options`.
	#[must_use = "commit() must be called in order for the environment to take effect"]
	pub fn with_global_thread_pool(mut self, options: GlobalThreadPoolOptions) -> Self {
		self.global_thread_pool_options = Some(options);
		self
	}

	/// Routes ONNX Runtime's log messages to `logger`.
	///
	/// The environment is created with verbose severity when a custom logger is set, so every message is forwarded and
	/// any filtering is up to the logger.
	#[must_use = "commit() must be called in order for the environment to take effect"]
	pub fn with_logger(mut self, logger: LoggerFunction) -> Self {
		self.logger = Some(logger);
		self
	}

	/// Creates an [`Environment`] from these options.
	///
	/// One of `CreateEnv`, `CreateEnvWithCustomLogger`, `CreateEnvWithGlobalThreadPools` or
	/// `CreateEnvWithCustomLoggerAndGlobalThreadPools` is chosen, depending on whether a logger and global thread-pool
	/// options are configured. Telemetry is then enabled or disabled as requested.
	///
	/// NOTE: ONNX Runtime receives a pointer to `self.logger`, so `self` must outlive the environment. That holds for
	/// the builder stored in `G_ENV_OPTIONS`.
	///
	/// # Errors
	/// Returns an error if environment creation or the telemetry call fails. In the latter case the new `OrtEnv` is
	/// released.
	pub(crate) fn create_environment(&self) -> Result<Environment> {
		let logger = self
			.logger
			.as_ref()
			.map(|c| (crate::logging::custom_logger as ort_sys::OrtLoggingFunction, c as *const _ as *mut c_void));
		// Without a custom logger, the `tracing` feature routes ORT logs through `tracing`.
		#[cfg(feature = "tracing")]
		let logger = logger.or(Some((crate::logging::tracing_logger, ptr::null_mut())));
		let env_ptr = with_cstr(self.name.as_bytes(), &|name| {
			let mut env_ptr: *mut ort_sys::OrtEnv = ptr::null_mut();
			// Kept nested so the four `CreateEnv*` variants read as a 2x2 table (thread pool × logger).
			#[allow(clippy::collapsible_else_if)]
			if let Some(thread_pool_options) = self.global_thread_pool_options.as_ref() {
				if let Some((log_fn, log_ptr)) = logger {
					ortsys![
						unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools(
							log_fn,
							log_ptr,
							ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
							name.as_ptr(),
							thread_pool_options.ptr(),
							&mut env_ptr
						)?;
						nonNull(env_ptr)
					];
					Ok(env_ptr)
				} else {
					ortsys![
						unsafe CreateEnvWithGlobalThreadPools(
							crate::logging::default_log_level(),
							name.as_ptr(),
							thread_pool_options.ptr(),
							&mut env_ptr
						)?;
						nonNull(env_ptr)
					];
					Ok(env_ptr)
				}
			} else {
				if let Some((log_fn, log_ptr)) = logger {
					ortsys![
						unsafe CreateEnvWithCustomLogger(
							log_fn,
							log_ptr,
							ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
							name.as_ptr(),
							&mut env_ptr
						)?;
						nonNull(env_ptr)
					];
					Ok(env_ptr)
				} else {
					ortsys![
						unsafe CreateEnv(
							crate::logging::default_log_level(),
							name.as_ptr(),
							&mut env_ptr
						)?;
						nonNull(env_ptr)
					];
					Ok(env_ptr)
				}
			}
		})?;
		// Release the fresh env if the telemetry call below fails; disarmed with `forget` once it succeeds.
		let _guard = run_on_drop(|| ortsys![unsafe ReleaseEnv(env_ptr.as_ptr())]);
		if self.telemetry {
			ortsys![unsafe EnableTelemetryEvents(env_ptr.as_ptr())?];
		} else {
			ortsys![unsafe DisableTelemetryEvents(env_ptr.as_ptr())?];
		}
		forget(_guard);
		crate::logging::create!(Environment, env_ptr);
		Ok(Environment {
			execution_providers: self.execution_providers.clone(),
			ptr: env_ptr,
			has_global_threadpool: self.global_thread_pool_options.is_some(),
			_thread_manager: self
				.global_thread_pool_options
				.as_ref()
				.and_then(|options| options.thread_manager.clone()),
			_logger: self.logger.clone()
		})
	}

	/// Stores these options as the configuration for the global environment.
	///
	/// Returns the result of inserting into `G_ENV_OPTIONS`. This is presumably `false` when options were already
	/// stored, either by an earlier `commit` or by [`current`] defaulting them.
	pub fn commit(self) -> bool {
		G_ENV_OPTIONS.try_insert_with(|| self)
	}
}
/// Starts configuring the global environment with default options; call [`EnvironmentBuilder::commit`] to apply them.
#[must_use = "commit() must be called in order for the environment to take effect"]
pub fn init() -> EnvironmentBuilder {
EnvironmentBuilder::new()
}
/// Loads the ONNX Runtime shared library from `path`, then starts configuring the global environment with default
/// options, exactly like [`init`].
///
/// # Errors
/// Returns an error if the library cannot be loaded.
#[cfg(all(feature = "load-dynamic", not(target_arch = "wasm32")))]
#[cfg_attr(docsrs, doc(cfg(feature = "load-dynamic")))]
#[must_use = "commit() must be called in order for the environment to take effect"]
pub fn init_from<P: AsRef<std::path::Path>>(path: P) -> Result<EnvironmentBuilder> {
	let dylib_path: &std::path::Path = path.as_ref();
	crate::load_dylib_from_path(dylib_path)?;
	Ok(init())
}