onnxruntime_ng/
environment.rs

1//! Module containing environment types
2
3use std::{
4    ffi::CString,
5    sync::{atomic::AtomicPtr, Arc, Mutex},
6};
7
8use tracing::{debug, error, warn};
9
10use lazy_static::lazy_static;
11
12use onnxruntime_sys_ng as sys;
13
14use crate::{
15    error::{status_to_result, OrtError, Result},
16    g_ort,
17    onnxruntime::custom_logger,
18    session::SessionBuilder,
19    LoggingLevel,
20};
21
22lazy_static! {
23    static ref G_ENV: Arc<Mutex<EnvironmentSingleton>> =
24        Arc::new(Mutex::new(EnvironmentSingleton {
25            name: String::from("uninitialized"),
26            env_ptr: AtomicPtr::new(std::ptr::null_mut()),
27        }));
28}
29
/// Shared state behind every [`Environment`] handle.
///
/// Within this module it is only read or written through the `G_ENV` mutex.
#[derive(Debug)]
struct EnvironmentSingleton {
    /// Name the current `OrtEnv` was created with, or `"uninitialized"` when none exists.
    name: String,
    /// The ONNX Runtime environment; null while no environment has been created
    /// (or after the last handle released it).
    env_ptr: AtomicPtr<sys::OrtEnv>,
}
35
/// An [`Environment`] is the main entry point of the ONNX Runtime.
///
/// Only one ONNX environment can be created per process. The `onnxruntime` crate
/// uses a singleton (through `lazy_static!()`) to enforce this.
///
/// Once an environment is created, use [`Environment::new_session_builder()`] to
/// configure and build a [`Session`](../session/struct.Session.html) from it.
///
/// Cloning an `Environment` is cheap: every handle holds an `Arc` to the same
/// process-wide singleton, so all clones share one ONNX Runtime environment.
///
/// **NOTE**: While [`EnvBuilder::with_name()`] lets you name the environment, only
/// the name given when the environment is first created is kept. Building another
/// `Environment` while one is still alive reuses the existing environment and
/// ignores the new name (and log level).
///
/// # Example
///
/// ```no_run
/// # use std::error::Error;
/// # use onnxruntime::{environment::Environment, LoggingLevel};
/// # fn main() -> Result<(), Box<dyn Error>> {
/// let environment = Environment::builder()
///     .with_name("test")
///     .with_log_level(LoggingLevel::Verbose)
///     .build()?;
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct Environment {
    /// Shared handle to the global singleton (a clone of `G_ENV`).
    env: Arc<Mutex<EnvironmentSingleton>>,
}
65
66impl Environment {
67    /// Create a new environment builder using default values
68    /// (name: `default`, log level: [LoggingLevel::Warning](../enum.LoggingLevel.html#variant.Warning))
69    pub fn builder() -> EnvBuilder {
70        EnvBuilder {
71            name: "default".into(),
72            log_level: LoggingLevel::Warning,
73        }
74    }
75
76    /// Return the name of the current environment
77    pub fn name(&self) -> String {
78        self.env.lock().unwrap().name.to_string()
79    }
80
81    pub(crate) fn env_ptr(&self) -> *const sys::OrtEnv {
82        *self.env.lock().unwrap().env_ptr.get_mut()
83    }
84
85    #[tracing::instrument]
86    fn new(name: String, log_level: LoggingLevel) -> Result<Environment> {
87        // NOTE: Because 'G_ENV' is a lazy_static, locking it will, initially, create
88        //      a new Arc<Mutex<EnvironmentSingleton>> with a strong count of 1.
89        //      Cloning it to embed it inside the 'Environment' to return
90        //      will thus increase the strong count to 2.
91        let mut environment_guard = G_ENV
92            .lock()
93            .expect("Failed to acquire lock: another thread panicked?");
94        let g_env_ptr = environment_guard.env_ptr.get_mut();
95        if g_env_ptr.is_null() {
96            debug!("Environment not yet initialized, creating a new one.");
97
98            let mut env_ptr: *mut sys::OrtEnv = std::ptr::null_mut();
99
100            let logging_function: sys::OrtLoggingFunction = Some(custom_logger);
101            // FIXME: What should go here?
102            let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
103
104            let cname = CString::new(name.clone()).unwrap();
105
106            let create_env_with_custom_logger = g_ort().CreateEnvWithCustomLogger.unwrap();
107            let status = {
108                unsafe {
109                    create_env_with_custom_logger(
110                        logging_function,
111                        logger_param,
112                        log_level.into(),
113                        cname.as_ptr(),
114                        &mut env_ptr,
115                    )
116                }
117            };
118
119            status_to_result(status).map_err(OrtError::Environment)?;
120
121            debug!(
122                env_ptr = format!("{:?}", env_ptr).as_str(),
123                "Environment created."
124            );
125
126            *g_env_ptr = env_ptr;
127            environment_guard.name = name;
128
129            // NOTE: Cloning the lazy_static 'G_ENV' will increase its strong count by one.
130            //       If this 'Environment' is the only one in the process, the strong count
131            //       will be 2:
132            //          * one lazy_static 'G_ENV'
133            //          * one inside the 'Environment' returned
134            Ok(Environment { env: G_ENV.clone() })
135        } else {
136            warn!(
137                name = environment_guard.name.as_str(),
138                env_ptr = format!("{:?}", environment_guard.env_ptr).as_str(),
139                "Environment already initialized, reusing it.",
140            );
141
142            // NOTE: Cloning the lazy_static 'G_ENV' will increase its strong count by one.
143            //       If this 'Environment' is the only one in the process, the strong count
144            //       will be 2:
145            //          * one lazy_static 'G_ENV'
146            //          * one inside the 'Environment' returned
147            Ok(Environment { env: G_ENV.clone() })
148        }
149    }
150
151    /// Create a new [`SessionBuilder`](../session/struct.SessionBuilder.html)
152    /// used to create a new ONNX session.
153    pub fn new_session_builder(&self) -> Result<SessionBuilder> {
154        SessionBuilder::new(self)
155    }
156}
157
158impl Drop for Environment {
159    #[tracing::instrument]
160    fn drop(&mut self) {
161        debug!(
162            global_arc_count = Arc::strong_count(&G_ENV),
163            "Dropping the Environment.",
164        );
165
166        let mut environment_guard = self
167            .env
168            .lock()
169            .expect("Failed to acquire lock: another thread panicked?");
170
171        // NOTE: If we drop an 'Environment' we (obviously) have _at least_
172        //       one 'G_ENV' strong count (the one in the 'env' member).
173        //       There is also the "original" 'G_ENV' which is a the lazy_static global.
174        //       If there is no other environment, the strong count should be two and we
175        //       can properly free the sys::OrtEnv pointer.
176        if Arc::strong_count(&G_ENV) == 2 {
177            let release_env = g_ort().ReleaseEnv.unwrap();
178            let env_ptr: *mut sys::OrtEnv = *environment_guard.env_ptr.get_mut();
179
180            debug!(
181                global_arc_count = Arc::strong_count(&G_ENV),
182                "Releasing the Environment.",
183            );
184
185            assert_ne!(env_ptr, std::ptr::null_mut());
186            if env_ptr.is_null() {
187                error!("Environment pointer is null, not dropping!");
188            } else {
189                unsafe { release_env(env_ptr) };
190            }
191
192            environment_guard.env_ptr = AtomicPtr::new(std::ptr::null_mut());
193            environment_guard.name = String::from("uninitialized");
194        }
195    }
196}
197
/// Builder used to configure and create an [`Environment`].
///
/// This is the crate's main entry point. An environment _must_ be created
/// as the first step. An [`Environment`] can only be built
/// using `EnvBuilder` to configure it; obtain one with [`Environment::builder()`].
///
/// **NOTE**: If the same configuration method (for example [`with_name()`](EnvBuilder::with_name))
/// is called multiple times, the last value will have precedence.
pub struct EnvBuilder {
    /// Name handed to ONNX Runtime if a new environment gets created.
    name: String,
    /// Logging level handed to ONNX Runtime if a new environment gets created.
    log_level: LoggingLevel,
}
210
211impl EnvBuilder {
212    /// Configure the environment with a given name
213    ///
214    /// **NOTE**: Since ONNX can only define one environment per process,
215    /// creating multiple environments using multiple `EnvBuilder` will
216    /// end up re-using the same environment internally; a new one will _not_
217    /// be created. New parameters will be ignored.
218    pub fn with_name<S>(mut self, name: S) -> EnvBuilder
219    where
220        S: Into<String>,
221    {
222        self.name = name.into();
223        self
224    }
225
226    /// Configure the environment with a given log level
227    ///
228    /// **NOTE**: Since ONNX can only define one environment per process,
229    /// creating multiple environments using multiple `EnvBuilder` will
230    /// end up re-using the same environment internally; a new one will _not_
231    /// be created. New parameters will be ignored.
232    pub fn with_log_level(mut self, log_level: LoggingLevel) -> EnvBuilder {
233        self.log_level = log_level;
234        self
235    }
236
237    /// Commit the configuration to a new [`Environment`](environment/struct.Environment.html)
238    pub fn build(self) -> Result<Environment> {
239        Environment::new(self.name, self.log_level)
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{PoisonError, RwLock, RwLockWriteGuard};
    use test_env_log::test;

    // Test-only helpers to inspect the global singleton directly.
    impl G_ENV {
        /// `true` while at least one `Environment` handle (on top of the static
        /// itself) holds a strong reference to the singleton.
        fn is_initialized(&self) -> bool {
            Arc::strong_count(self) >= 2
        }

        /// Raw `OrtEnv` pointer currently stored in the singleton.
        fn env_ptr(&self) -> *const sys::OrtEnv {
            *self.lock().unwrap().env_ptr.get_mut()
        }
    }

    /// Serializes the tests below. They all observe or mutate the process-wide
    /// `G_ENV`, so running them in parallel would make their assertions racy.
    struct ConcurrentTestRun {
        lock: RwLock<()>,
    }

    lazy_static! {
        static ref CONCURRENT_TEST_RUN: ConcurrentTestRun = ConcurrentTestRun {
            lock: RwLock::new(()),
        };
    }

    impl CONCURRENT_TEST_RUN {
        /// Take exclusive access to the global environment for one test.
        ///
        /// A poisoned lock only means an earlier test panicked while holding it.
        /// Recovering the guard keeps one failure from cascading into every later test.
        fn single_test_run(&self) -> RwLockWriteGuard<()> {
            self.lock.write().unwrap_or_else(PoisonError::into_inner)
        }
    }

    #[test]
    fn env_is_initialized() {
        let _run_lock = CONCURRENT_TEST_RUN.single_test_run();

        assert!(!G_ENV.is_initialized());
        assert_eq!(G_ENV.env_ptr(), std::ptr::null_mut());

        let env = Environment::builder()
            .with_name("env_is_initialized")
            .with_log_level(LoggingLevel::Warning)
            .build()
            .unwrap();
        assert!(G_ENV.is_initialized());
        assert_ne!(G_ENV.env_ptr(), std::ptr::null_mut());

        std::mem::drop(env);
        assert!(!G_ENV.is_initialized());
        assert_eq!(G_ENV.env_ptr(), std::ptr::null_mut());
    }

    // NOTE(review): ignored, presumably because ONNX Runtime may hand out the same
    // address for a freshly created 'OrtEnv' right after releasing the previous one,
    // which would trip the 'assert_ne!' below — TODO confirm.
    #[ignore]
    #[test]
    fn sequential_environment_creation() {
        let _concurrent_run_lock_guard = CONCURRENT_TEST_RUN.single_test_run();

        let mut prev_env_ptr = G_ENV.env_ptr();

        for i in 0..10 {
            let name = format!("sequential_environment_creation: {}", i);
            let env = Environment::builder()
                .with_name(name.clone())
                .with_log_level(LoggingLevel::Warning)
                .build()
                .unwrap();
            let next_env_ptr = G_ENV.env_ptr();
            assert_ne!(next_env_ptr, prev_env_ptr);
            prev_env_ptr = next_env_ptr;

            assert_eq!(env.name(), name);
        }
    }

    #[test]
    fn concurrent_environment_creations() {
        let _concurrent_run_lock_guard = CONCURRENT_TEST_RUN.single_test_run();

        let initial_name = String::from("concurrent_environment_creation");
        let main_env = Environment::new(initial_name.clone(), LoggingLevel::Warning).unwrap();
        let main_env_ptr = main_env.env_ptr() as usize;

        let children: Vec<_> = (0..10)
            .map(|t| {
                let initial_name_cloned = initial_name.clone();
                std::thread::spawn(move || {
                    let name = format!("concurrent_environment_creation: {}", t);
                    let env = Environment::builder()
                        .with_name(name)
                        .with_log_level(LoggingLevel::Warning)
                        .build()
                        .unwrap();

                    // 'main_env' is alive, so every thread must reuse it unchanged.
                    assert_eq!(env.name(), initial_name_cloned);
                    assert_eq!(env.env_ptr() as usize, main_env_ptr);
                })
            })
            .collect();

        assert_eq!(main_env.name(), initial_name);
        assert_eq!(main_env.env_ptr() as usize, main_env_ptr);

        // Join every child before asserting so that no thread outlives the test.
        let results: Vec<std::thread::Result<()>> =
            children.into_iter().map(|child| child.join()).collect();
        assert!(results.iter().all(|result| result.is_ok()));
    }
}