use std::{
ffi::CString,
sync::{atomic::AtomicPtr, Arc, Mutex}
};
use lazy_static::lazy_static;
use tracing::{debug, error, warn};
use super::{
custom_logger,
error::{status_to_result, OrtError, OrtResult},
ort, ortsys, sys, ExecutionProvider, LoggingLevel
};
lazy_static! {
	/// Process-wide singleton shared by every [`Environment`] handle.
	///
	/// It starts with a null `OrtEnv` pointer and the placeholder name `"uninitialized"`. The
	/// pointer is filled in by the first handle that creates the environment and is reset to null
	/// when the last handle is dropped.
	static ref G_ENV: Arc<Mutex<EnvironmentSingleton>> = {
		// `AtomicPtr::default()` is a null pointer, i.e. "no environment yet".
		let initial_state = EnvironmentSingleton {
			name: "uninitialized".to_string(),
			env_ptr: AtomicPtr::default()
		};
		Arc::new(Mutex::new(initial_state))
	};
}
/// State guarded by the global `G_ENV` mutex.
///
/// A null `env_ptr` means no environment is currently alive. The next call that finds it null
/// creates a fresh one.
#[derive(Debug)]
struct EnvironmentSingleton {
	/// Name the live environment was created with, or `"uninitialized"` when there is none.
	name: String,
	/// Raw pointer produced by `CreateEnvWithCustomLogger`; reset to null after `ReleaseEnv`.
	env_ptr: AtomicPtr<sys::OrtEnv>
}
/// Handle to the process-wide runtime environment.
///
/// All handles point at the same global singleton; cloning a handle shares it as well. The
/// underlying `OrtEnv` is released when the last handle is dropped.
#[derive(Debug, Clone)]
pub struct Environment {
	/// Shared reference to the global singleton (always a clone of `G_ENV`).
	env: Arc<Mutex<EnvironmentSingleton>>,
	/// Execution providers requested when this particular handle was built.
	pub(crate) execution_providers: Vec<ExecutionProvider>
}
impl Environment {
	/// Returns an [`EnvBuilder`] preset to the name `"default"`, [`LoggingLevel::Warning`] and no
	/// execution providers.
	pub fn builder() -> EnvBuilder {
		EnvBuilder {
			name: "default".into(),
			log_level: LoggingLevel::Warning,
			execution_providers: Vec::new()
		}
	}

	/// Returns the name of the shared global environment.
	///
	/// Every handle shares one singleton, so this is the name the environment was first created
	/// with. It may differ from the name requested when this handle was built.
	///
	/// # Panics
	/// Panics if the global lock is poisoned.
	pub fn name(&self) -> String {
		self.env.lock().unwrap().name.clone()
	}

	/// Wraps this handle in an [`Arc`].
	pub fn into_arc(self) -> Arc<Environment> {
		Arc::new(self)
	}

	/// Returns the raw `OrtEnv` pointer held by the global singleton (null once released).
	///
	/// # Panics
	/// Panics if the global lock is poisoned.
	pub(crate) fn env_ptr(&self) -> *const sys::OrtEnv {
		*self.env.lock().unwrap().env_ptr.get_mut()
	}

	/// Creates the global environment if none is alive, otherwise reuses the existing one.
	///
	/// When an environment already exists, `name` and `log_level` are ignored (a warning is
	/// logged). The returned handle always carries the given `execution_providers`.
	///
	/// # Errors
	/// Returns [`OrtError::CreateEnvironment`] if `CreateEnvWithCustomLogger` reports a failure.
	///
	/// # Panics
	/// Panics if the global lock is poisoned, or if a new environment must be created and `name`
	/// contains an interior NUL byte. In the latter case the lock is released first so `G_ENV` is
	/// not poisoned.
	fn new(name: String, log_level: LoggingLevel, execution_providers: Vec<ExecutionProvider>) -> OrtResult<Environment> {
		let mut environment_guard = G_ENV.lock().expect("Failed to acquire lock: another thread panicked?");
		if environment_guard.env_ptr.get_mut().is_null() {
			debug!("Environment not yet initialized, creating a new one.");
			let cname = match CString::new(name.as_str()) {
				Ok(cname) => cname,
				Err(err) => {
					// Release the global lock before panicking. A panic while it is held would
					// poison `G_ENV` and break every later environment operation in the process.
					drop(environment_guard);
					panic!("Environment name must not contain interior NUL bytes: {}", err);
				}
			};
			let mut env_ptr: *mut sys::OrtEnv = std::ptr::null_mut();
			let logging_function: sys::OrtLoggingFunction = Some(custom_logger);
			let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
			let create_env_with_custom_logger = ortsys![CreateEnvWithCustomLogger];
			let status = unsafe { create_env_with_custom_logger(logging_function, logger_param, log_level.into(), cname.as_ptr(), &mut env_ptr) };
			// On failure `?` returns early and the guard is dropped normally (no poisoning).
			status_to_result(status).map_err(OrtError::CreateEnvironment)?;
			debug!(env_ptr = format!("{:?}", env_ptr).as_str(), "Environment created.");
			*environment_guard.env_ptr.get_mut() = env_ptr;
			environment_guard.name = name;
		} else {
			warn!(
				name = environment_guard.name.as_str(),
				env_ptr = format!("{:?}", environment_guard.env_ptr).as_str(),
				"Environment already initialized, reusing it.",
			);
		}
		// The clone happens while the guard is still held. A concurrent `Drop` therefore cannot
		// see a strong count of 2 and release the environment before this handle exists.
		Ok(Environment {
			env: G_ENV.clone(),
			execution_providers
		})
	}
}
impl Default for Environment {
fn default() -> Self {
let mut environment_guard = G_ENV.lock().expect("Failed to acquire lock: another thread panicked?");
let g_env_ptr = environment_guard.env_ptr.get_mut();
if g_env_ptr.is_null() {
debug!("Environment not yet initialized, creating a new one.");
let mut env_ptr: *mut sys::OrtEnv = std::ptr::null_mut();
let logging_function: sys::OrtLoggingFunction = Some(custom_logger);
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
let cname = CString::new("default".to_string()).unwrap();
let create_env_with_custom_logger = ortsys![CreateEnvWithCustomLogger];
let status = unsafe { create_env_with_custom_logger(logging_function, logger_param, LoggingLevel::Warning.into(), cname.as_ptr(), &mut env_ptr) };
status_to_result(status).map_err(OrtError::CreateEnvironment).unwrap();
debug!(env_ptr = format!("{:?}", env_ptr).as_str(), "Environment created.");
*g_env_ptr = env_ptr;
environment_guard.name = "default".to_string();
Environment {
env: G_ENV.clone(),
execution_providers: vec![]
}
} else {
Environment {
env: G_ENV.clone(),
execution_providers: vec![]
}
}
}
}
impl Drop for Environment {
	/// Releases the shared `OrtEnv` when this is the last live handle.
	///
	/// `G_ENV` itself holds one strong reference and `self.env` holds another. A strong count of
	/// exactly 2 (checked under the lock) therefore means no other handle exists. This method never
	/// panics on a null pointer or a poisoned lock: panicking inside `Drop` aborts the process when
	/// it happens during unwinding.
	#[tracing::instrument]
	fn drop(&mut self) {
		debug!(global_arc_count = Arc::strong_count(&G_ENV), "Dropping the Environment.",);
		// A poisoned lock only means another thread panicked while holding it. The singleton
		// still needs cleaning up, so recover the guard instead of panicking in drop.
		let mut environment_guard = self.env.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
		if Arc::strong_count(&G_ENV) == 2 {
			let release_env = ort().ReleaseEnv.unwrap();
			let env_ptr: *mut sys::OrtEnv = *environment_guard.env_ptr.get_mut();
			debug!(global_arc_count = Arc::strong_count(&G_ENV), "Releasing the Environment.",);
			// Previously an `assert_ne!` sat here, which made this graceful branch unreachable.
			if env_ptr.is_null() {
				error!("Environment pointer is null, not dropping!");
			} else {
				unsafe { release_env(env_ptr) };
			}
			// Mark the singleton as empty so the next `new`/`default` creates a fresh environment.
			*environment_guard.env_ptr.get_mut() = std::ptr::null_mut();
			environment_guard.name = String::from("uninitialized");
		}
	}
}
/// Builder for an [`Environment`], obtained from [`Environment::builder`].
pub struct EnvBuilder {
	/// Requested environment name (`"default"` unless overridden).
	name: String,
	/// Requested logging level (`LoggingLevel::Warning` unless overridden).
	log_level: LoggingLevel,
	/// Execution providers to attach to the built handle (empty unless overridden).
	execution_providers: Vec<ExecutionProvider>
}
impl EnvBuilder {
pub fn with_name<S>(mut self, name: S) -> EnvBuilder
where
S: Into<String>
{
self.name = name.into();
self
}
pub fn with_log_level(mut self, log_level: LoggingLevel) -> EnvBuilder {
self.log_level = log_level;
self
}
pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProvider]>) -> EnvBuilder {
self.execution_providers = execution_providers.as_ref().to_vec();
self
}
pub fn build(self) -> OrtResult<Environment> {
Environment::new(self.name, self.log_level, self.execution_providers)
}
}
#[cfg(test)]
mod tests {
	use std::sync::{RwLock, RwLockWriteGuard};

	use test_log::test;

	use super::*;

	impl G_ENV {
		/// True while at least one `Environment` handle (besides `G_ENV` itself) is alive.
		fn is_initialized(&self) -> bool {
			Arc::strong_count(self) >= 2
		}

		/// Raw pointer currently stored in the singleton.
		fn env_ptr(&self) -> *const sys::OrtEnv {
			*self.lock().unwrap().env_ptr.get_mut()
		}
	}

	/// Serializes tests that touch the global environment.
	struct ConcurrentTestRun {
		lock: Arc<RwLock<()>>
	}

	lazy_static! {
		static ref CONCURRENT_TEST_RUN: ConcurrentTestRun = ConcurrentTestRun { lock: Arc::new(RwLock::new(())) };
	}

	impl CONCURRENT_TEST_RUN {
		fn single_test_run(&self) -> RwLockWriteGuard<()> {
			self.lock.write().unwrap()
		}
	}

	#[test]
	fn env_is_initialized() {
		let _run_lock = CONCURRENT_TEST_RUN.single_test_run();
		assert!(!G_ENV.is_initialized());
		assert_eq!(G_ENV.env_ptr(), std::ptr::null_mut());
		let env = Environment::builder()
			.with_name("env_is_initialized")
			.with_log_level(LoggingLevel::Warning)
			.build()
			.unwrap();
		assert!(G_ENV.is_initialized());
		assert_ne!(G_ENV.env_ptr(), std::ptr::null_mut());
		std::mem::drop(env);
		assert!(!G_ENV.is_initialized());
		assert_eq!(G_ENV.env_ptr(), std::ptr::null_mut());
	}

	#[ignore]
	#[test]
	fn sequential_environment_creation() {
		let _concurrent_run_lock_guard = CONCURRENT_TEST_RUN.single_test_run();
		let mut prev_env_ptr = G_ENV.env_ptr();
		for i in 0..10 {
			let name = format!("sequential_environment_creation: {}", i);
			let env = Environment::builder()
				.with_name(name.clone())
				.with_log_level(LoggingLevel::Warning)
				.build()
				.unwrap();
			let next_env_ptr = G_ENV.env_ptr();
			assert_ne!(next_env_ptr, prev_env_ptr);
			prev_env_ptr = next_env_ptr;
			assert_eq!(env.name(), name);
		}
	}

	#[test]
	fn concurrent_environment_creations() {
		let _concurrent_run_lock_guard = CONCURRENT_TEST_RUN.single_test_run();
		let initial_name = String::from("concurrent_environment_creation");
		let main_env = Environment::new(initial_name.clone(), LoggingLevel::Warning, Vec::new()).unwrap();
		let main_env_ptr = main_env.env_ptr() as usize;
		assert_eq!(main_env.name(), initial_name);
		assert_eq!(main_env.env_ptr() as usize, main_env_ptr);

		// Spawn every thread before joining any of them. Spawning and joining inside one lazy
		// iterator chain would run the threads strictly one after another (and `.all` would stop
		// spawning after the first failure), so the test would never exercise concurrency.
		let handles: Vec<_> = (0..10)
			.map(|t| {
				let initial_name_cloned = initial_name.clone();
				std::thread::spawn(move || {
					let name = format!("concurrent_environment_creation: {}", t);
					let env = Environment::builder()
						.with_name(name)
						.with_log_level(LoggingLevel::Warning)
						.build()
						.unwrap();
					// The global environment already exists, so the requested name is ignored.
					assert_eq!(env.name(), initial_name_cloned);
					assert_eq!(env.env_ptr() as usize, main_env_ptr);
				})
			})
			.collect();
		let results: Vec<_> = handles.into_iter().map(std::thread::JoinHandle::join).collect();
		assert!(results.iter().all(std::result::Result::is_ok));
	}
}