use std::collections::HashMap;
use std::{
ffi::CString,
sync::{atomic::AtomicPtr, Arc, Mutex}
};
use lazy_static::lazy_static;
use tracing::{debug, error, warn};
use super::{
custom_logger,
error::{status_to_result, OrtError, OrtResult},
ort, ortsys, sys, ExecutionProvider, LoggingLevel
};
// Process-wide environment state shared by every `Environment` handle created in this file.
// It starts out with the placeholder name "uninitialized" and a null `OrtEnv` pointer; the mutex
// serializes every read and write of that state.
lazy_static! {
static ref G_ENV: Arc<Mutex<EnvironmentSingleton>> = Arc::new(Mutex::new(EnvironmentSingleton {
name: String::from("uninitialized"),
env_ptr: AtomicPtr::new(std::ptr::null_mut())
}));
}
/// State stored inside the global `G_ENV` mutex: the environment's name and the raw pointer to the
/// underlying `OrtEnv`.
#[derive(Debug)]
struct EnvironmentSingleton {
/// Name the environment was created with; `"uninitialized"` while no environment exists.
name: String,
/// Raw `OrtEnv` pointer; null while no environment exists.
env_ptr: AtomicPtr<sys::OrtEnv>
}
/// Signature shared by the environment constructors in this file: takes the environment name, the
/// logging level and a map of string options, and returns the resulting status together with the
/// environment pointer written by the runtime.
type CreateEnvFunction = fn(&str, LoggingLevel, HashMap<String, String>) -> (sys::OrtStatusPtr, *mut sys::OrtEnv);
/// Handle to the shared ONNX Runtime environment.
///
/// Every handle holds a clone of the global `Arc`, so the number of live handles is reflected in
/// that `Arc`'s strong count; the `Drop` implementation relies on this count to decide when to
/// release the underlying `OrtEnv`.
#[derive(Debug, Clone)]
pub struct Environment {
/// Shared singleton state (name and raw environment pointer).
env: Arc<Mutex<EnvironmentSingleton>>,
/// Execution providers supplied when this handle was built.
pub(crate) execution_providers: Vec<ExecutionProvider>
}
// NOTE(review): the fields visible in this file are an `Arc<Mutex<..>>` around an `AtomicPtr` and a
// `Vec<ExecutionProvider>`. Presumably these manual impls exist because `ExecutionProvider` is not
// automatically `Send`/`Sync` — confirm against its definition and record the safety argument here.
unsafe impl Send for Environment {}
unsafe impl Sync for Environment {}
impl Environment {
    /// Returns an [`EnvBuilder`] pre-populated with the defaults: name `"default"`, log level
    /// [`LoggingLevel::Warning`], no execution providers and no global thread pool options.
    pub fn builder() -> EnvBuilder {
        EnvBuilder {
            name: "default".into(),
            log_level: LoggingLevel::Warning,
            execution_providers: Vec::new(),
            global_thread_pool_options: Vec::new()
        }
    }

    /// Returns a copy of the name stored in the shared environment state.
    ///
    /// # Panics
    /// Panics if the environment mutex is poisoned.
    pub fn name(&self) -> String {
        self.env.lock().unwrap().name.clone()
    }

    /// Moves this handle into an [`Arc`].
    pub fn into_arc(self) -> Arc<Environment> {
        Arc::new(self)
    }

    /// Returns the raw `OrtEnv` pointer currently stored in the shared environment state.
    ///
    /// # Panics
    /// Panics if the environment mutex is poisoned.
    pub fn ptr(&self) -> *const sys::OrtEnv {
        *self.env.lock().unwrap().env_ptr.get_mut()
    }

    /// Creates an environment that logs through `custom_logger`.
    ///
    /// The option map is ignored. Returns the status reported by the runtime together with the
    /// environment pointer it wrote.
    ///
    /// # Panics
    /// Panics if `name` contains an interior NUL byte.
    fn create_custom_log_env(name: &str, log_level: LoggingLevel, _: HashMap<String, String>) -> (sys::OrtStatusPtr, *mut sys::OrtEnv) {
        let mut env_ptr: *mut sys::OrtEnv = std::ptr::null_mut();
        let logging_function: sys::OrtLoggingFunction = Some(custom_logger);
        let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
        let cname = CString::new(name).expect("environment name must not contain interior NUL bytes");
        let create_env_with_custom_logger = ortsys![CreateEnvWithCustomLogger];
        // SAFETY: `cname` and `env_ptr` outlive the call, and `env_ptr` is a valid out-pointer.
        let status = unsafe { create_env_with_custom_logger(logging_function, logger_param, log_level.into(), cname.as_ptr(), &mut env_ptr) };
        (status, env_ptr)
    }

    /// Removes `key` from `options` and parses its value as an `i32`.
    ///
    /// Returns `None` when the key is absent or the value is not a valid integer. A malformed
    /// value is logged as a warning and skipped rather than panicking on user-supplied input.
    fn take_i32_option(options: &mut HashMap<String, String>, key: &str) -> Option<i32> {
        let raw = options.remove(key)?;
        match raw.parse::<i32>() {
            Ok(value) => Some(value),
            Err(err) => {
                warn!("Ignoring invalid value {:?} for global thread pool option {}: {}", raw, key, err);
                None
            }
        }
    }

    /// Creates an environment that logs through `custom_logger` and uses global thread pools.
    ///
    /// Recognized option keys are `inter_op_parallelism`, `intra_op_parallelism`, `spin_control`
    /// (all parsed as `i32`) and `intra_op_thread_affinity` (passed through as a C string). Keys
    /// that are not recognized are reported with a warning.
    ///
    /// The first non-null status returned by any runtime call is handed back to the caller, and
    /// no further configuration or environment creation is attempted after it. The threading
    /// options object is always released before returning.
    ///
    /// # Panics
    /// Panics if `name` contains an interior NUL byte.
    fn create_global_thread_pool_env(name: &str, log_level: LoggingLevel, mut options: HashMap<String, String>) -> (sys::OrtStatusPtr, *mut sys::OrtEnv) {
        let mut env_ptr: *mut sys::OrtEnv = std::ptr::null_mut();
        let logging_function: sys::OrtLoggingFunction = Some(custom_logger);
        let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
        let mut thread_options: *mut sys::OrtThreadingOptions = std::ptr::null_mut();
        let cname = CString::new(name).expect("environment name must not contain interior NUL bytes");

        let create_thread_options = ortsys![CreateThreadingOptions];
        let release_thread_options = ortsys![ReleaseThreadingOptions];
        let create_env_with_global_thread_pool = ortsys![CreateEnvWithCustomLoggerAndGlobalThreadPools];
        let set_global_intra_op_thread_affinity = ortsys![SetGlobalIntraOpThreadAffinity];
        let set_global_intra_op_num_threads = ortsys![SetGlobalIntraOpNumThreads];
        let set_global_inter_op_num_threads = ortsys![SetGlobalInterOpNumThreads];
        let set_global_spin_control = ortsys![SetGlobalSpinControl];

        // SAFETY: `thread_options` is a valid out-pointer for the duration of the call.
        let mut status = unsafe { create_thread_options(&mut thread_options) };
        if !status.is_null() {
            // Nothing was allocated, so there is nothing to release.
            return (status, env_ptr);
        }

        // Each step below runs only while every previous runtime call succeeded, so the first
        // failing status is the one reported to the caller.
        if let Some(threads) = Self::take_i32_option(&mut options, "inter_op_parallelism") {
            // SAFETY: `thread_options` was successfully created above.
            status = unsafe { set_global_inter_op_num_threads(thread_options, threads) };
        }
        if status.is_null() {
            if let Some(threads) = Self::take_i32_option(&mut options, "intra_op_parallelism") {
                // SAFETY: `thread_options` was successfully created above.
                status = unsafe { set_global_intra_op_num_threads(thread_options, threads) };
            }
        }
        if status.is_null() {
            if let Some(spin) = Self::take_i32_option(&mut options, "spin_control") {
                // SAFETY: `thread_options` was successfully created above.
                status = unsafe { set_global_spin_control(thread_options, spin) };
            }
        }
        if status.is_null() {
            if let Some(affinity) = options.remove("intra_op_thread_affinity") {
                match CString::new(affinity) {
                    Ok(c_affinity) => {
                        // SAFETY: `c_affinity` is a valid NUL-terminated string that outlives the call.
                        status = unsafe { set_global_intra_op_thread_affinity(thread_options, c_affinity.as_ptr()) };
                    }
                    Err(err) => warn!("Ignoring invalid value for global thread pool option intra_op_thread_affinity: {}", err)
                }
            }
        }
        if status.is_null() {
            if !options.is_empty() {
                warn!("Unknown options passed to create_global_thread_pool_env: {:?}", options);
            }
            // SAFETY: `cname`, `thread_options` and `env_ptr` are valid for the duration of the call.
            status =
                unsafe { create_env_with_global_thread_pool(logging_function, logger_param, log_level.into(), cname.as_ptr(), thread_options, &mut env_ptr) };
        }

        // SAFETY: `thread_options` was created above and is not used after this point.
        unsafe {
            release_thread_options(thread_options);
        }
        (status, env_ptr)
    }

    /// Returns a handle to the shared environment, creating it through `create_env_fn` first if
    /// the global pointer is still null.
    ///
    /// When an environment already exists it is reused; `name`, `log_level` and `options` are
    /// then ignored and a warning is logged. `execution_providers` is always stored on the
    /// returned handle.
    ///
    /// # Errors
    /// Returns [`OrtError::CreateEnvironment`] when `create_env_fn` reports a failing status.
    ///
    /// # Panics
    /// Panics if the global environment mutex is poisoned.
    fn new(
        name: String,
        log_level: LoggingLevel,
        execution_providers: Vec<ExecutionProvider>,
        create_env_fn: CreateEnvFunction,
        options: HashMap<String, String>
    ) -> OrtResult<Environment> {
        // The guard is held for the whole function so that the null check and the initialization
        // happen atomically with respect to other callers.
        let mut environment_guard = G_ENV.lock().expect("Failed to acquire lock: another thread panicked?");
        if environment_guard.env_ptr.get_mut().is_null() {
            debug!("Environment not yet initialized, creating a new one");
            let (status, env_ptr) = create_env_fn(name.as_str(), log_level, options);
            status_to_result(status).map_err(OrtError::CreateEnvironment)?;
            debug!(env_ptr = format!("{:?}", env_ptr).as_str(), "Environment created");
            *environment_guard.env_ptr.get_mut() = env_ptr;
            environment_guard.name = name;
        } else {
            warn!(
                name = environment_guard.name.as_str(),
                env_ptr = format!("{:?}", environment_guard.env_ptr).as_str(),
                "Environment already initialized for this thread, reusing it",
            );
        }
        Ok(Environment {
            env: G_ENV.clone(),
            execution_providers
        })
    }
}
impl Default for Environment {
fn default() -> Self {
let mut environment_guard = G_ENV.lock().expect("Failed to acquire lock: another thread panicked?");
let g_env_ptr = environment_guard.env_ptr.get_mut();
if g_env_ptr.is_null() {
debug!("Environment not yet initialized, creating a new one");
let mut env_ptr: *mut sys::OrtEnv = std::ptr::null_mut();
let logging_function: sys::OrtLoggingFunction = Some(custom_logger);
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
let cname = CString::new("default".to_string()).unwrap();
let create_env_with_custom_logger = ortsys![CreateEnvWithCustomLogger];
let status = unsafe { create_env_with_custom_logger(logging_function, logger_param, LoggingLevel::Warning.into(), cname.as_ptr(), &mut env_ptr) };
status_to_result(status).map_err(OrtError::CreateEnvironment).unwrap();
debug!(env_ptr = format!("{:?}", env_ptr).as_str(), "Environment created");
*g_env_ptr = env_ptr;
environment_guard.name = "default".to_string();
Environment {
env: G_ENV.clone(),
execution_providers: vec![]
}
} else {
Environment {
env: G_ENV.clone(),
execution_providers: vec![]
}
}
}
}
impl Drop for Environment {
    /// Releases the underlying `OrtEnv` when the last handle goes away.
    ///
    /// A strong count of exactly 2 means the only remaining owners are the global static and the
    /// handle being dropped. In that case the environment is released and the shared state is
    /// reset to its initial values (null pointer, name `"uninitialized"`).
    ///
    /// A null pointer at this point is logged and skipped rather than asserted on: panicking
    /// inside `drop` while holding the guard would poison the global mutex and can abort the
    /// process when it happens during unwinding.
    ///
    /// # Panics
    /// Panics if the environment mutex is already poisoned.
    #[tracing::instrument]
    fn drop(&mut self) {
        debug!(global_arc_count = Arc::strong_count(&G_ENV), "Dropping environment");
        let mut environment_guard = self.env.lock().expect("Failed to acquire lock: another thread panicked?");
        if Arc::strong_count(&G_ENV) == 2 {
            let release_env = ort().ReleaseEnv.unwrap();
            let env_ptr: *mut sys::OrtEnv = *environment_guard.env_ptr.get_mut();
            debug!(global_arc_count = Arc::strong_count(&G_ENV), "Releasing environment");
            if env_ptr.is_null() {
                error!("Environment pointer is null, not dropping!");
            } else {
                // SAFETY: `env_ptr` is non-null and the guard is held, so no other handle can
                // read or release it concurrently through this mutex.
                unsafe { release_env(env_ptr) };
            }
            environment_guard.env_ptr = AtomicPtr::new(std::ptr::null_mut());
            environment_guard.name = String::from("uninitialized");
        }
    }
}
/// Collects the settings used to create an [`Environment`]; obtained from `Environment::builder()`
/// and consumed by `build()`.
pub struct EnvBuilder {
/// Name passed to the runtime when the environment is created.
name: String,
/// Logging level passed to the runtime when the environment is created.
log_level: LoggingLevel,
/// Execution providers stored on the resulting `Environment`.
execution_providers: Vec<ExecutionProvider>,
/// Key/value options for the global thread pool; when empty, no global thread pool is requested.
global_thread_pool_options: Vec<(String, String)>
}
impl EnvBuilder {
pub fn with_name<S>(mut self, name: S) -> EnvBuilder
where
S: Into<String>
{
self.name = name.into();
self
}
pub fn with_log_level(mut self, log_level: LoggingLevel) -> EnvBuilder {
self.log_level = log_level;
self
}
pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProvider]>) -> EnvBuilder {
self.execution_providers = execution_providers.as_ref().to_vec();
self
}
pub fn with_global_thread_pool(mut self, options: Vec<(String, String)>) -> EnvBuilder {
self.global_thread_pool_options = options;
self
}
pub fn build(self) -> OrtResult<Environment> {
if self.global_thread_pool_options.is_empty() {
Environment::new(self.name, self.log_level, self.execution_providers, Environment::create_custom_log_env, vec![].into_iter().collect())
} else {
Environment::new(
self.name,
self.log_level,
self.execution_providers,
Environment::create_global_thread_pool_env,
self.global_thread_pool_options.into_iter().collect()
)
}
}
}
#[cfg(test)]
mod tests {
    use std::sync::{RwLock, RwLockWriteGuard};

    use test_log::test;

    use super::*;

    // Test-only helpers for inspecting the global environment state.
    impl G_ENV {
        /// True when at least one `Environment` handle exists besides the global static.
        fn is_initialized(&self) -> bool {
            Arc::strong_count(self) >= 2
        }

        /// Returns the raw environment pointer currently stored in the global state.
        fn env_ptr(&self) -> *const sys::OrtEnv {
            *self.lock().unwrap().env_ptr.get_mut()
        }
    }

    /// Lock used to keep the tests in this module from running at the same time, since they all
    /// observe and mutate the same global environment.
    struct ConcurrentTestRun {
        lock: Arc<RwLock<()>>
    }

    lazy_static! {
        static ref CONCURRENT_TEST_RUN: ConcurrentTestRun = ConcurrentTestRun { lock: Arc::new(RwLock::new(())) };
    }

    impl CONCURRENT_TEST_RUN {
        /// Takes the exclusive lock for the duration of one test.
        ///
        /// A poisoned lock is recovered instead of unwrapped, so one failing test does not make
        /// every later test fail with an unrelated `PoisonError`.
        fn single_test_run(&self) -> RwLockWriteGuard<'_, ()> {
            self.lock.write().unwrap_or_else(|poisoned| poisoned.into_inner())
        }
    }

    #[test]
    fn env_is_initialized() {
        let _run_lock = CONCURRENT_TEST_RUN.single_test_run();

        assert!(!G_ENV.is_initialized());
        assert_eq!(G_ENV.env_ptr(), std::ptr::null_mut());

        let env = Environment::builder()
            .with_name("env_is_initialized")
            .with_log_level(LoggingLevel::Warning)
            .build()
            .unwrap();
        assert!(G_ENV.is_initialized());
        assert_ne!(G_ENV.env_ptr(), std::ptr::null_mut());

        // Dropping the only handle must release the environment and reset the global state.
        std::mem::drop(env);
        assert!(!G_ENV.is_initialized());
        assert_eq!(G_ENV.env_ptr(), std::ptr::null_mut());
    }

    #[ignore]
    #[test]
    fn sequential_environment_creation() {
        let _concurrent_run_lock_guard = CONCURRENT_TEST_RUN.single_test_run();

        let mut prev_env_ptr = G_ENV.env_ptr();

        for i in 0..10 {
            let name = format!("sequential_environment_creation: {}", i);
            let env = Environment::builder()
                .with_name(name.clone())
                .with_log_level(LoggingLevel::Warning)
                .build()
                .unwrap();
            let next_env_ptr = G_ENV.env_ptr();
            assert_ne!(next_env_ptr, prev_env_ptr);
            prev_env_ptr = next_env_ptr;

            assert_eq!(env.name(), name);
        }
    }

    #[test]
    fn concurrent_environment_creations() {
        let _concurrent_run_lock_guard = CONCURRENT_TEST_RUN.single_test_run();

        let initial_name = String::from("concurrent_environment_creation");
        let main_env =
            Environment::new(initial_name.clone(), LoggingLevel::Warning, Vec::new(), Environment::create_custom_log_env, vec![].into_iter().collect())
                .unwrap();
        // Raw pointers are not `Send`, so the address is carried into the threads as a `usize`.
        let main_env_ptr = main_env.ptr() as usize;

        assert_eq!(main_env.name(), initial_name);
        assert_eq!(main_env.ptr() as usize, main_env_ptr);

        // Spawn every worker before joining any of them. Iterator adaptors are lazy, so a single
        // `map(spawn).map(join)` chain would join each thread before spawning the next one and
        // the test would never run concurrently.
        let workers: Vec<_> = (0..10)
            .map(|t| {
                let initial_name_cloned = initial_name.clone();
                std::thread::spawn(move || {
                    let name = format!("concurrent_environment_creation: {}", t);
                    let env = Environment::builder()
                        .with_name(name)
                        .with_log_level(LoggingLevel::Warning)
                        .build()
                        .unwrap();

                    // Every worker must observe the environment created by the main thread.
                    assert_eq!(env.name(), initial_name_cloned);
                    assert_eq!(env.ptr() as usize, main_env_ptr);
                })
            })
            .collect();

        // Join all workers before asserting, so none is still running when `main_env` is dropped.
        let results: Vec<_> = workers.into_iter().map(|worker| worker.join()).collect();
        assert!(results.iter().all(|r| r.is_ok()));
    }
}