use std::{
ffi::CString,
sync::{atomic::AtomicPtr, Arc, Mutex},
};
use lazy_static::lazy_static;
use tracing::{debug, warn};
use onnxruntime_sys as sys;
use crate::{
error::{status_to_result, OrtError, Result},
g_ort,
onnxruntime::custom_logger,
session::SessionBuilder,
LoggingLevel,
};
lazy_static! {
    /// Process-wide slot holding the single ONNX Runtime environment.
    ///
    /// It starts out with a null `env_ptr` and the placeholder name
    /// `"uninitialized"`. `Environment::new` fills it in, and dropping an
    /// `Environment` may reset it back to this placeholder state.
    static ref G_ENV: Arc<Mutex<EnvironmentSingleton>> = {
        let placeholder = EnvironmentSingleton {
            name: "uninitialized".to_owned(),
            env_ptr: AtomicPtr::new(std::ptr::null_mut()),
        };
        Arc::new(Mutex::new(placeholder))
    };
}
/// State kept in the global `G_ENV` slot: the environment's name and the raw
/// `OrtEnv` pointer.
#[derive(Debug)]
struct EnvironmentSingleton {
    /// Name of the current environment; `"uninitialized"` while no environment exists.
    name: String,
    /// Raw pointer to the ONNX Runtime environment; null while no environment exists.
    env_ptr: AtomicPtr<sys::OrtEnv>,
}
/// Handle to the process-wide ONNX Runtime environment.
///
/// Every handle created by `Environment::new` stores a clone of the global
/// `G_ENV` `Arc`, so all handles share one name and one `OrtEnv` pointer.
/// Build one with [`Environment::builder`].
#[derive(Debug, Clone)]
pub struct Environment {
    /// Shared reference to the singleton state (name + raw environment pointer).
    env: Arc<Mutex<EnvironmentSingleton>>,
}
impl Environment {
    /// Returns an [`EnvBuilder`] preset with the name `"default"` and
    /// [`LoggingLevel::Warning`].
    pub fn builder() -> EnvBuilder {
        EnvBuilder {
            log_level: LoggingLevel::Warning,
            name: String::from("default"),
        }
    }

    /// Returns a copy of the name stored in the shared singleton.
    ///
    /// # Panics
    ///
    /// Panics if the singleton mutex is poisoned.
    pub fn name(&self) -> String {
        let singleton = self.env.lock().unwrap();
        singleton.name.clone()
    }

    /// Returns the raw `OrtEnv` pointer currently stored in the shared singleton.
    ///
    /// # Panics
    ///
    /// Panics if the singleton mutex is poisoned.
    pub(crate) fn env_ptr(&self) -> *const sys::OrtEnv {
        let mut singleton = self.env.lock().unwrap();
        *singleton.env_ptr.get_mut()
    }

    /// Returns a handle to the global environment, creating it first if the
    /// global pointer is still null.
    ///
    /// When an environment already exists, `name` and `log_level` are not
    /// applied: a warning is logged and the existing environment is reused.
    ///
    /// # Errors
    ///
    /// Returns `OrtError::Environment` when `status_to_result` reports a failure
    /// for the status returned by `CreateEnvWithCustomLogger`.
    ///
    /// # Panics
    ///
    /// Panics if the global mutex is poisoned, if `name` contains an interior
    /// NUL byte, or if the `CreateEnvWithCustomLogger` API entry is `None`.
    #[tracing::instrument]
    fn new(name: String, log_level: LoggingLevel) -> Result<Environment> {
        // The guard is held until this function returns, so the null check, the
        // creation and the bookkeeping below all happen under the lock.
        let mut singleton = G_ENV
            .lock()
            .expect("Failed to acquire lock: another thread panicked?");

        if !singleton.env_ptr.get_mut().is_null() {
            warn!(
                name = singleton.name.as_str(),
                env_ptr = format!("{:?}", singleton.env_ptr).as_str(),
                "Environment already initialized, reusing it.",
            );
            return Ok(Environment { env: G_ENV.clone() });
        }

        debug!("Environment not yet initialized, creating a new one.");

        let mut raw_env: *mut sys::OrtEnv = std::ptr::null_mut();
        let logger: sys::OrtLoggingFunction = Some(custom_logger);
        let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
        let c_name = CString::new(name.as_str()).unwrap();
        let create_env = g_ort().CreateEnvWithCustomLogger.unwrap();

        // SAFETY: `c_name` and `raw_env` are live locals for the whole call, and
        // the function pointer comes from the API table returned by `g_ort()`.
        // NOTE(review): assumes the C API does not keep `c_name`'s pointer after
        // returning — confirm against the ONNX Runtime documentation.
        let status = unsafe {
            create_env(
                logger,
                logger_param,
                log_level.into(),
                c_name.as_ptr(),
                &mut raw_env,
            )
        };
        status_to_result(status).map_err(OrtError::Environment)?;

        debug!(
            env_ptr = format!("{:?}", raw_env).as_str(),
            "Environment created."
        );

        // Publish the new environment in the global slot.
        *singleton.env_ptr.get_mut() = raw_env;
        singleton.name = name;

        Ok(Environment { env: G_ENV.clone() })
    }

    /// Creates a [`SessionBuilder`] bound to this environment.
    ///
    /// # Errors
    ///
    /// Forwards whatever error `SessionBuilder::new` returns.
    pub fn new_session_builder(&self) -> Result<SessionBuilder> {
        SessionBuilder::new(self)
    }
}
impl Drop for Environment {
    /// Releases the underlying `OrtEnv` when this is the last live handle.
    ///
    /// "Last handle" is detected as a strong count of exactly 2 on `G_ENV`:
    /// one reference owned by the static itself and one owned by the handle
    /// being dropped. On release the global state is reset to a null pointer
    /// and the name `"uninitialized"`.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned, if the `ReleaseEnv` API entry is
    /// `None`, or if the stored pointer is unexpectedly null.
    #[tracing::instrument]
    fn drop(&mut self) {
        debug!(
            global_arc_count = Arc::strong_count(&G_ENV),
            "Dropping the Environment.",
        );

        let mut singleton = self
            .env
            .lock()
            .expect("Failed to acquire lock: another thread panicked?");

        // NOTE(review): `self.env` is only decremented after this function
        // returns, i.e. after the lock is released. Two handles dropped
        // concurrently could presumably both see a count above 2 here and both
        // skip the release — confirm whether that leak matters.
        if Arc::strong_count(&G_ENV) != 2 {
            // Other handles are still alive; leave the environment untouched.
            return;
        }

        let release_env = g_ort().ReleaseEnv.unwrap();
        let raw_env: *mut sys::OrtEnv = *singleton.env_ptr.get_mut();
        debug!(
            global_arc_count = Arc::strong_count(&G_ENV),
            "Releasing the Environment.",
        );
        assert_ne!(raw_env, std::ptr::null_mut());

        // SAFETY: `raw_env` was just checked to be non-null and is the pointer
        // stored in the global slot, which is written by `Environment::new`
        // after a successful `CreateEnvWithCustomLogger` call.
        unsafe { release_env(raw_env) };

        // Return the global slot to its "no environment" state.
        *singleton.env_ptr.get_mut() = std::ptr::null_mut();
        singleton.name = "uninitialized".to_owned();
    }
}
/// Builder for an [`Environment`]; obtain one from [`Environment::builder`].
pub struct EnvBuilder {
    /// Name passed to `Environment::new` when the builder is built.
    name: String,
    /// Logging level passed to `Environment::new` when the builder is built.
    log_level: LoggingLevel,
}
impl EnvBuilder {
pub fn with_name<S>(mut self, name: S) -> EnvBuilder
where
S: Into<String>,
{
self.name = name.into();
self
}
pub fn with_log_level(mut self, log_level: LoggingLevel) -> EnvBuilder {
self.log_level = log_level;
self
}
pub fn build(self) -> Result<Environment> {
Environment::new(self.name, self.log_level)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{PoisonError, RwLock, RwLockWriteGuard};
    use test_env_log::test;

    // Test-only helpers for inspecting the global singleton.
    impl G_ENV {
        /// True when at least one `Environment` handle exists besides the static itself.
        fn is_initialized(&self) -> bool {
            Arc::strong_count(self) >= 2
        }

        /// Raw environment pointer currently stored in the global slot.
        fn env_ptr(&self) -> *const sys::OrtEnv {
            *self.lock().unwrap().env_ptr.get_mut()
        }
    }

    /// Lock used to run the tests in this module one at a time, since they all
    /// observe and mutate the same global environment.
    struct ConcurrentTestRun {
        lock: Arc<RwLock<()>>,
    }

    lazy_static! {
        static ref CONCURRENT_TEST_RUN: ConcurrentTestRun = ConcurrentTestRun {
            lock: Arc::new(RwLock::new(()))
        };
    }

    impl CONCURRENT_TEST_RUN {
        /// Takes the exclusive test lock; hold the returned guard for the whole test.
        ///
        /// A test that panics while holding the guard poisons the lock. The lock
        /// protects no data (`()`), so the poison is ignored. Otherwise one
        /// failing test would make every later test fail with a `PoisonError`
        /// and hide the real failure.
        fn single_test_run(&self) -> RwLockWriteGuard<()> {
            self.lock.write().unwrap_or_else(PoisonError::into_inner)
        }
    }

    /// The global slot is empty before a build, populated while a handle is
    /// alive, and empty again once the only handle is dropped.
    #[test]
    fn env_is_initialized() {
        let _run_lock = CONCURRENT_TEST_RUN.single_test_run();

        assert!(!G_ENV.is_initialized());
        assert_eq!(G_ENV.env_ptr(), std::ptr::null_mut());

        let env = Environment::builder()
            .with_name("env_is_initialized")
            .with_log_level(LoggingLevel::Warning)
            .build()
            .unwrap();
        assert!(G_ENV.is_initialized());
        assert_ne!(G_ENV.env_ptr(), std::ptr::null_mut());

        std::mem::drop(env);
        assert!(!G_ENV.is_initialized());
        assert_eq!(G_ENV.env_ptr(), std::ptr::null_mut());
    }

    /// Creates and drops ten environments in sequence, expecting a different
    /// pointer and the requested name each time.
    // NOTE(review): presumably ignored because nothing guarantees that a newly
    // created environment lives at a different address than the one released
    // just before it — confirm.
    #[ignore]
    #[test]
    fn sequential_environment_creation() {
        let _concurrent_run_lock_guard = CONCURRENT_TEST_RUN.single_test_run();

        let mut prev_env_ptr = G_ENV.env_ptr();

        for i in 0..10 {
            let name = format!("sequential_environment_creation: {}", i);
            let env = Environment::builder()
                .with_name(name.clone())
                .with_log_level(LoggingLevel::Warning)
                .build()
                .unwrap();
            let next_env_ptr = G_ENV.env_ptr();
            assert_ne!(next_env_ptr, prev_env_ptr);
            prev_env_ptr = next_env_ptr;

            assert_eq!(env.name(), name);
            // `env` is dropped here, so the next iteration starts without a handle.
        }
    }

    /// While a main handle is alive, environments built on other threads must
    /// reuse it: same name and same pointer as the main handle.
    #[test]
    fn concurrent_environment_creations() {
        let _concurrent_run_lock_guard = CONCURRENT_TEST_RUN.single_test_run();

        let initial_name = String::from("concurrent_environment_creation");
        let main_env = Environment::new(initial_name.clone(), LoggingLevel::Warning).unwrap();
        // Stored as `usize` so the address can be moved into the spawned threads.
        let main_env_ptr = main_env.env_ptr() as usize;

        let children: Vec<_> = (0..10)
            .map(|t| {
                let initial_name_cloned = initial_name.clone();
                std::thread::spawn(move || {
                    let name = format!("concurrent_environment_creation: {}", t);
                    let env = Environment::builder()
                        .with_name(name)
                        .with_log_level(LoggingLevel::Warning)
                        .build()
                        .unwrap();

                    // The per-thread name is ignored because the environment already exists.
                    assert_eq!(env.name(), initial_name_cloned);
                    assert_eq!(env.env_ptr() as usize, main_env_ptr);
                })
            })
            .collect();

        assert_eq!(main_env.name(), initial_name);
        assert_eq!(main_env.env_ptr() as usize, main_env_ptr);

        // A failed assertion in a child surfaces as an `Err` from `join`.
        let res: Vec<std::thread::Result<_>> =
            children.into_iter().map(|child| child.join()).collect();
        assert!(res.into_iter().all(|r| std::result::Result::is_ok(&r)));
    }
}