onnxruntime_ng/
environment.rs1use std::{
4 ffi::CString,
5 sync::{atomic::AtomicPtr, Arc, Mutex},
6};
7
8use tracing::{debug, error, warn};
9
10use lazy_static::lazy_static;
11
12use onnxruntime_sys_ng as sys;
13
14use crate::{
15 error::{status_to_result, OrtError, Result},
16 g_ort,
17 onnxruntime::custom_logger,
18 session::SessionBuilder,
19 LoggingLevel,
20};
21
22lazy_static! {
23 static ref G_ENV: Arc<Mutex<EnvironmentSingleton>> =
24 Arc::new(Mutex::new(EnvironmentSingleton {
25 name: String::from("uninitialized"),
26 env_ptr: AtomicPtr::new(std::ptr::null_mut()),
27 }));
28}
29
30#[derive(Debug)]
31struct EnvironmentSingleton {
32 name: String,
33 env_ptr: AtomicPtr<sys::OrtEnv>,
34}
35
36#[derive(Debug, Clone)]
62pub struct Environment {
63 env: Arc<Mutex<EnvironmentSingleton>>,
64}
65
66impl Environment {
67 pub fn builder() -> EnvBuilder {
70 EnvBuilder {
71 name: "default".into(),
72 log_level: LoggingLevel::Warning,
73 }
74 }
75
76 pub fn name(&self) -> String {
78 self.env.lock().unwrap().name.to_string()
79 }
80
81 pub(crate) fn env_ptr(&self) -> *const sys::OrtEnv {
82 *self.env.lock().unwrap().env_ptr.get_mut()
83 }
84
85 #[tracing::instrument]
86 fn new(name: String, log_level: LoggingLevel) -> Result<Environment> {
87 let mut environment_guard = G_ENV
92 .lock()
93 .expect("Failed to acquire lock: another thread panicked?");
94 let g_env_ptr = environment_guard.env_ptr.get_mut();
95 if g_env_ptr.is_null() {
96 debug!("Environment not yet initialized, creating a new one.");
97
98 let mut env_ptr: *mut sys::OrtEnv = std::ptr::null_mut();
99
100 let logging_function: sys::OrtLoggingFunction = Some(custom_logger);
101 let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
103
104 let cname = CString::new(name.clone()).unwrap();
105
106 let create_env_with_custom_logger = g_ort().CreateEnvWithCustomLogger.unwrap();
107 let status = {
108 unsafe {
109 create_env_with_custom_logger(
110 logging_function,
111 logger_param,
112 log_level.into(),
113 cname.as_ptr(),
114 &mut env_ptr,
115 )
116 }
117 };
118
119 status_to_result(status).map_err(OrtError::Environment)?;
120
121 debug!(
122 env_ptr = format!("{:?}", env_ptr).as_str(),
123 "Environment created."
124 );
125
126 *g_env_ptr = env_ptr;
127 environment_guard.name = name;
128
129 Ok(Environment { env: G_ENV.clone() })
135 } else {
136 warn!(
137 name = environment_guard.name.as_str(),
138 env_ptr = format!("{:?}", environment_guard.env_ptr).as_str(),
139 "Environment already initialized, reusing it.",
140 );
141
142 Ok(Environment { env: G_ENV.clone() })
148 }
149 }
150
151 pub fn new_session_builder(&self) -> Result<SessionBuilder> {
154 SessionBuilder::new(self)
155 }
156}
157
158impl Drop for Environment {
159 #[tracing::instrument]
160 fn drop(&mut self) {
161 debug!(
162 global_arc_count = Arc::strong_count(&G_ENV),
163 "Dropping the Environment.",
164 );
165
166 let mut environment_guard = self
167 .env
168 .lock()
169 .expect("Failed to acquire lock: another thread panicked?");
170
171 if Arc::strong_count(&G_ENV) == 2 {
177 let release_env = g_ort().ReleaseEnv.unwrap();
178 let env_ptr: *mut sys::OrtEnv = *environment_guard.env_ptr.get_mut();
179
180 debug!(
181 global_arc_count = Arc::strong_count(&G_ENV),
182 "Releasing the Environment.",
183 );
184
185 assert_ne!(env_ptr, std::ptr::null_mut());
186 if env_ptr.is_null() {
187 error!("Environment pointer is null, not dropping!");
188 } else {
189 unsafe { release_env(env_ptr) };
190 }
191
192 environment_guard.env_ptr = AtomicPtr::new(std::ptr::null_mut());
193 environment_guard.name = String::from("uninitialized");
194 }
195 }
196}
197
198pub struct EnvBuilder {
207 name: String,
208 log_level: LoggingLevel,
209}
210
211impl EnvBuilder {
212 pub fn with_name<S>(mut self, name: S) -> EnvBuilder
219 where
220 S: Into<String>,
221 {
222 self.name = name.into();
223 self
224 }
225
226 pub fn with_log_level(mut self, log_level: LoggingLevel) -> EnvBuilder {
233 self.log_level = log_level;
234 self
235 }
236
237 pub fn build(self) -> Result<Environment> {
239 Environment::new(self.name, self.log_level)
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{PoisonError, RwLock, RwLockWriteGuard};
    use test_env_log::test;

    impl G_ENV {
        /// True while at least one `Environment` handle exists.
        ///
        /// `G_ENV` itself always holds one strong reference.
        fn is_initialized(&self) -> bool {
            Arc::strong_count(self) >= 2
        }

        /// Raw `OrtEnv` pointer currently stored in the singleton (null when released).
        fn env_ptr(&self) -> *const sys::OrtEnv {
            *self
                .lock()
                .unwrap_or_else(PoisonError::into_inner)
                .env_ptr
                .get_mut()
        }
    }

    /// Serializes the tests in this module.
    ///
    /// They all observe and mutate the single process-wide `G_ENV`, so they must not run
    /// in parallel.
    struct ConcurrentTestRun {
        lock: Arc<RwLock<()>>,
    }

    lazy_static! {
        static ref CONCURRENT_TEST_RUN: ConcurrentTestRun = ConcurrentTestRun {
            lock: Arc::new(RwLock::new(()))
        };
    }

    impl CONCURRENT_TEST_RUN {
        /// Block until no other test in this module is running.
        ///
        /// A test that panics while holding the guard poisons the lock. The poison is
        /// ignored so that one failing test does not make every later test fail too.
        fn single_test_run(&self) -> RwLockWriteGuard<'_, ()> {
            self.lock.write().unwrap_or_else(PoisonError::into_inner)
        }
    }

    #[test]
    fn env_is_initialized() {
        let _run_lock = CONCURRENT_TEST_RUN.single_test_run();

        assert!(!G_ENV.is_initialized());
        assert_eq!(G_ENV.env_ptr(), std::ptr::null_mut());

        let env = Environment::builder()
            .with_name("env_is_initialized")
            .with_log_level(LoggingLevel::Warning)
            .build()
            .unwrap();
        assert!(G_ENV.is_initialized());
        assert_ne!(G_ENV.env_ptr(), std::ptr::null_mut());

        std::mem::drop(env);
        assert!(!G_ENV.is_initialized());
        assert_eq!(G_ENV.env_ptr(), std::ptr::null_mut());
    }

    // NOTE(review): each iteration releases the environment before creating the next.
    // The allocator may hand back the same address, so the `assert_ne!` below can be
    // flaky. That is presumably why the test is ignored.
    #[ignore]
    #[test]
    fn sequential_environment_creation() {
        let _concurrent_run_lock_guard = CONCURRENT_TEST_RUN.single_test_run();

        let mut prev_env_ptr = G_ENV.env_ptr();

        for i in 0..10 {
            let name = format!("sequential_environment_creation: {}", i);
            let env = Environment::builder()
                .with_name(name.clone())
                .with_log_level(LoggingLevel::Warning)
                .build()
                .unwrap();
            let next_env_ptr = G_ENV.env_ptr();
            assert_ne!(next_env_ptr, prev_env_ptr);
            prev_env_ptr = next_env_ptr;

            assert_eq!(env.name(), name);
        }
    }

    #[test]
    fn concurrent_environment_creations() {
        let _concurrent_run_lock_guard = CONCURRENT_TEST_RUN.single_test_run();

        let initial_name = String::from("concurrent_environment_creation");
        let main_env = Environment::new(initial_name.clone(), LoggingLevel::Warning).unwrap();
        // Raw pointers are not `Send`; compare addresses across threads as integers.
        let main_env_ptr = main_env.env_ptr() as usize;

        let children: Vec<_> = (0..10)
            .map(|t| {
                let initial_name_cloned = initial_name.clone();
                std::thread::spawn(move || {
                    let name = format!("concurrent_environment_creation: {}", t);
                    let env = Environment::builder()
                        .with_name(name)
                        .with_log_level(LoggingLevel::Warning)
                        .build()
                        .unwrap();

                    // `main_env` is still alive, so every thread must share it.
                    assert_eq!(env.name(), initial_name_cloned);
                    assert_eq!(env.env_ptr() as usize, main_env_ptr);
                })
            })
            .collect();

        assert_eq!(main_env.name(), initial_name);
        assert_eq!(main_env.env_ptr() as usize, main_env_ptr);

        // Join every thread (collecting first) before checking, so none is left detached.
        let results: Vec<std::thread::Result<()>> =
            children.into_iter().map(|child| child.join()).collect();
        assert!(results.iter().all(|r| r.is_ok()));
    }
}