use std::ffi::{CStr, CString};
use std::sync::atomic::AtomicUsize;
use std::{ffi::c_void, ptr::null_mut};
use ort2_sys::{self as ffi, OrtLoggingFunction, OrtLoggingLevel};
use smart_default::SmartDefault;
use tracing::*;
use crate::api::{api, ok};
use crate::error::Result;
#[allow(improper_ctypes_definitions)]
pub type LoggingFunction = unsafe extern "C" fn(
param: *mut ::std::os::raw::c_void,
severity: OrtLoggingLevel,
category: *const ::std::os::raw::c_char,
logid: *const ::std::os::raw::c_char,
code_location: *const ::std::os::raw::c_char,
message: *const ::std::os::raw::c_char,
);
/// Number of live [`Environment`] handles in the whole process.
///
/// Incremented by every successful [`EnvironmentBuilder::build`] and every `clone`, and
/// decremented by every drop. The drop that takes it from 1 to 0 calls `ReleaseEnv`.
static ENV_REF_COUNTER: AtomicUsize = AtomicUsize::new(0);

/// Shared handle to an ONNX Runtime `OrtEnv`.
///
/// Cloning copies the raw pointer and bumps the process-wide `ENV_REF_COUNTER`. The
/// environment is released when the last handle in the process is dropped.
///
/// NOTE(review): the counter is process-global, not per-environment. `build` calls
/// `CreateEnvWithCustomLogger` every time it runs, but `ReleaseEnv` is called only once, on the
/// `inner` of whichever handle happens to be dropped last. If several environments are built,
/// confirm against ONNX Runtime's ownership rules that this does not leak.
#[derive(Debug)]
pub struct Environment {
    /// Raw environment pointer shared by all clones of this handle.
    pub inner: *mut ffi::OrtEnv,
}

impl Clone for Environment {
    fn clone(&self) -> Self {
        use std::sync::atomic::Ordering;
        // Each clone is one more owner that will decrement the counter when it is dropped.
        ENV_REF_COUNTER.fetch_add(1, Ordering::SeqCst);
        Environment { inner: self.inner }
    }
}

impl Environment {
    /// Returns a builder preloaded with the default environment settings.
    pub fn builder() -> EnvironmentBuilder {
        Default::default()
    }

    /// Returns the raw `OrtEnv` pointer. Ownership is not transferred.
    pub fn inner(&self) -> *mut ffi::OrtEnv {
        let Self { inner } = self;
        *inner
    }
}

impl Drop for Environment {
    fn drop(&mut self) {
        use std::sync::atomic::Ordering;
        // `fetch_sub` yields the value *before* the decrement, so 1 means this was the last handle.
        let handles_before = ENV_REF_COUNTER.fetch_sub(1, Ordering::SeqCst);
        if handles_before != 1 {
            return;
        }
        api!(ReleaseEnv, self.inner);
    }
}

// SAFETY: within this file the `OrtEnv` pointer is only copied and handed to ONNX Runtime
// (`ReleaseEnv`); it is never dereferenced from Rust.
// NOTE(review): sharing it across threads relies on ONNX Runtime's env being thread-safe.
// Confirm that guarantee against the ONNX Runtime C API documentation.
unsafe impl Send for Environment {}
unsafe impl Sync for Environment {}
#[allow(improper_ctypes_definitions)]
unsafe extern "C" fn default_logging_function(
param: *mut ::std::os::raw::c_void,
severity: OrtLoggingLevel,
category: *const ::std::os::raw::c_char,
logid: *const ::std::os::raw::c_char,
code_location: *const ::std::os::raw::c_char,
message: *const ::std::os::raw::c_char,
) {
let category = CStr::from_ptr(category);
let logid = CStr::from_ptr(logid);
let code_location = CStr::from_ptr(code_location);
let message = CStr::from_ptr(message);
match severity {
OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => trace!(
?param,
?severity,
?category,
?logid,
?code_location,
?message
),
OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => trace!(
?param,
?severity,
?category,
?logid,
?code_location,
?message
),
OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING => warn!(
?param,
?severity,
?category,
?logid,
?code_location,
?message
),
OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR => error!(
?param,
?severity,
?category,
?logid,
?code_location,
?message
),
OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => error!(
?param,
?severity,
?category,
?logid,
?code_location,
?message
),
_ => todo!(),
}
}
/// Builder for [`Environment`], obtained from [`Environment::builder`].
///
/// Defaults:
/// - log id `"default"`
/// - level `ORT_LOGGING_LEVEL_ERROR`
/// - no custom logging function, so `build` installs the crate's `tracing`-based logger
/// - a null callback parameter
#[derive(SmartDefault)]
pub struct EnvironmentBuilder {
    /// Log id handed to ONNX Runtime when the environment is created.
    #[default(String::from("default"))]
    logid: String,
    /// Logging level passed to `CreateEnvWithCustomLogger`.
    #[default(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR)]
    level: OrtLoggingLevel,
    /// Custom logging callback. `None` means "use the default logger".
    logging_function: OrtLoggingFunction,
    /// User pointer passed to `CreateEnvWithCustomLogger` together with the logging function.
    #[default(std::ptr::null_mut())]
    params: *mut c_void,
}
impl EnvironmentBuilder {
    /// Creates the ONNX Runtime environment described by this builder.
    ///
    /// If no logging function was set, `default_logging_function` is installed so that ORT
    /// log messages are forwarded to `tracing`.
    ///
    /// # Errors
    ///
    /// Fails if the log id contains an interior NUL byte (it cannot become a C string), or if
    /// `CreateEnvWithCustomLogger` reports an error.
    pub fn build(self) -> Result<Environment> {
        let Self {
            logid,
            level,
            logging_function,
            params,
        } = self;
        let logging_function = logging_function.or(Some(default_logging_function));
        trace!(%logid, ?level, ?logging_function, ?params, "create Environment");
        let mut inner = null_mut();
        // `logid` lives until the end of this function, which covers the FFI call below.
        let logid = CString::new(logid)?;
        ok!(
            CreateEnvWithCustomLogger,
            logging_function,
            params,
            level,
            logid.as_ptr(),
            &mut inner
        )?;
        // Count the new handle only after ORT has successfully produced an environment, so a
        // failed build never leaves a phantom reference behind.
        ENV_REF_COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        Ok(Environment { inner })
    }

    /// Installs a custom logging callback in place of the default `tracing` logger.
    #[must_use = "builder methods consume `self` and return the updated builder"]
    pub fn with_logging_function(mut self, logging_function: LoggingFunction) -> Self {
        self.logging_function = Some(logging_function);
        self
    }

    /// Sets the log id reported to ONNX Runtime (default: `"default"`).
    ///
    /// The id must not contain NUL bytes, otherwise [`EnvironmentBuilder::build`] fails.
    #[must_use = "builder methods consume `self` and return the updated builder"]
    pub fn with_logid(mut self, logid: impl AsRef<str>) -> Self {
        self.logid = logid.as_ref().to_owned();
        self
    }

    /// Sets the logging level passed to ONNX Runtime (default: `ORT_LOGGING_LEVEL_ERROR`).
    #[must_use = "builder methods consume `self` and return the updated builder"]
    pub fn with_level(mut self, level: OrtLoggingLevel) -> Self {
        self.level = level;
        self
    }

    /// Sets the user pointer passed to `CreateEnvWithCustomLogger` alongside the logging
    /// function (default: null).
    ///
    /// NOTE(review): presumably ONNX Runtime hands this pointer back to the callback as `param`
    /// and keeps it for the environment's lifetime. If so, the caller must keep the pointee
    /// valid for that long. Confirm against the ORT C API docs.
    #[must_use = "builder methods consume `self` and return the updated builder"]
    pub fn with_params(mut self, params: *mut c_void) -> Self {
        self.params = params;
        self
    }
}
#[cfg(test)]
#[test]
fn test_environment_ok() -> Result<()> {
    let env = Environment::builder().build()?;
    assert!(
        !env.inner().is_null(),
        "build() returned an Environment with a null OrtEnv"
    );

    // A clone must share the same underlying environment. Dropping it exercises the
    // non-final `Drop` path; the final drop of `env` at the end of the test releases the env.
    let cloned = env.clone();
    assert_eq!(cloned.inner(), env.inner());
    drop(cloned);

    Ok(())
}