ort2 0.1.2

Rust wrapper around the ONNX Runtime C/C++ API
Documentation
use std::ffi::{CStr, CString};
use std::sync::atomic::AtomicUsize;
use std::{ffi::c_void, ptr::null_mut};

use ort2_sys::{self as ffi, OrtLoggingFunction, OrtLoggingLevel};
use smart_default::SmartDefault;
use tracing::*;

use crate::api::{api, ok};
use crate::error::Result;

/// Signature of a custom ONNX Runtime log sink, as accepted by
/// [`EnvironmentBuilder::with_logging_function`].
///
/// All string arguments arrive as raw C-string pointers, and `param` is an opaque
/// pointer; the function is `unsafe extern "C"` because the runtime calls it through
/// the C ABI.
#[allow(improper_ctypes_definitions)]
pub type LoggingFunction = unsafe extern "C" fn(
    param: *mut std::os::raw::c_void,
    severity: OrtLoggingLevel,
    category: *const std::os::raw::c_char,
    logid: *const std::os::raw::c_char,
    code_location: *const std::os::raw::c_char,
    message: *const std::os::raw::c_char,
);

/// Process-wide count of live [`Environment`] handles.
///
/// `EnvironmentBuilder::build` and `Clone::clone` increment it; `Drop` decrements it and
/// calls `ReleaseEnv` on the handle whose drop brings the count back to zero.
///
/// NOTE(review): the count is shared by every environment instead of being tracked per
/// `OrtEnv`. If several environments are built while earlier ones are still alive, only
/// the pointer held by the handle that brings the count to zero is passed to
/// `ReleaseEnv` — confirm callers only ever build one environment at a time.
static ENV_REF_COUNTER: AtomicUsize = AtomicUsize::new(0);

/// Handle to an ONNX Runtime `OrtEnv`.
///
/// Clones share the same raw pointer and each one bumps [`ENV_REF_COUNTER`]; the
/// runtime object is released when the last handle is dropped.
#[derive(Debug)]
pub struct Environment {
    /// Raw runtime handle, shared by all clones of this `Environment`.
    pub inner: *mut ffi::OrtEnv,
}

// NOTE(review): these impls assert that the raw `OrtEnv` pointer may be sent to and
// shared between threads; that relies on ONNX Runtime's own thread-safety guarantees
// for `OrtEnv`, which are not visible here — confirm.
unsafe impl Send for Environment {}
unsafe impl Sync for Environment {}

impl Environment {
    /// Returns a builder pre-populated with default settings.
    pub fn builder() -> EnvironmentBuilder {
        Default::default()
    }

    /// Returns the raw `OrtEnv` pointer for direct use with the C API.
    pub fn inner(&self) -> *mut ffi::OrtEnv {
        self.inner
    }
}

impl Clone for Environment {
    fn clone(&self) -> Self {
        let inner = self.inner;
        ENV_REF_COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        Environment { inner }
    }
}

impl Drop for Environment {
    fn drop(&mut self) {
        let previous = ENV_REF_COUNTER.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
        // `fetch_sub` yields the value *before* decrementing, so 1 means this was the
        // last live handle.
        if previous == 1 {
            api!(ReleaseEnv, self.inner);
        }
    }
}

#[allow(improper_ctypes_definitions)]
unsafe extern "C" fn default_logging_function(
    param: *mut ::std::os::raw::c_void,
    severity: OrtLoggingLevel,
    category: *const ::std::os::raw::c_char,
    logid: *const ::std::os::raw::c_char,
    code_location: *const ::std::os::raw::c_char,
    message: *const ::std::os::raw::c_char,
) {
    let category = CStr::from_ptr(category);
    let logid = CStr::from_ptr(logid);
    let code_location = CStr::from_ptr(code_location);
    let message = CStr::from_ptr(message);
    match severity {
        OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => trace!(
            ?param,
            ?severity,
            ?category,
            ?logid,
            ?code_location,
            ?message
        ),
        OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => trace!(
            ?param,
            ?severity,
            ?category,
            ?logid,
            ?code_location,
            ?message
        ),
        OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING => warn!(
            ?param,
            ?severity,
            ?category,
            ?logid,
            ?code_location,
            ?message
        ),
        OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR => error!(
            ?param,
            ?severity,
            ?category,
            ?logid,
            ?code_location,
            ?message
        ),
        OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => error!(
            ?param,
            ?severity,
            ?category,
            ?logid,
            ?code_location,
            ?message
        ),
        _ => todo!(),
    }
}

/// Builder for [`Environment`]; obtain one via [`Environment::builder`].
#[derive(Debug, SmartDefault)]
pub struct EnvironmentBuilder {
    /// Log identifier handed to the runtime. Defaults to `"default"`.
    #[default("default".into())]
    logid: String,
    /// Logging level passed to `CreateEnvWithCustomLogger`.
    /// Defaults to `ORT_LOGGING_LEVEL_ERROR`.
    #[default(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR)]
    level: OrtLoggingLevel,
    /// Custom log sink. When left as `None` (the default), `build` installs the
    /// `tracing`-based `default_logging_function`.
    logging_function: OrtLoggingFunction,
    /// Opaque pointer passed to `CreateEnvWithCustomLogger` together with the log sink.
    /// Defaults to null.
    #[default(null_mut())]
    params: *mut c_void,
}

impl EnvironmentBuilder {
    pub fn build(self) -> Result<Environment> {
        let Self {
            logid,
            level,
            mut logging_function,
            params,
        } = self;

        if logging_function.is_none() {
            logging_function = Some(default_logging_function);
        }

        trace!(%logid, ?level, ?logging_function, ?params, "create Environment");

        let mut inner = null_mut();

        let logid = CString::new(logid)?;

        ok!(
            CreateEnvWithCustomLogger,
            logging_function,
            params,
            level,
            logid.as_ptr(),
            &mut inner
        )?;

        ENV_REF_COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);

        Ok(Environment { inner })
    }

    pub fn with_logging_function(mut self, logging_function: LoggingFunction) -> Self {
        self.logging_function = Some(logging_function);
        self
    }

    pub fn with_logid(mut self, logid: impl AsRef<str>) -> Self {
        self.logid = logid.as_ref().to_owned();
        self
    }

    pub fn with_level(mut self, level: OrtLoggingLevel) -> Self {
        self.level = level;
        self
    }

    pub fn with_params(mut self, params: *mut c_void) -> Self {
        self.params = params;
        self
    }
}

#[cfg(test)]
#[test]
fn test_environment_ok() -> Result<()> {
    let env = Environment::builder().build()?;
    assert!(!env.inner().is_null());

    // A clone must share the same runtime handle; dropping both exercises the
    // reference-counted release path in `Drop`.
    let cloned = env.clone();
    assert_eq!(cloned.inner(), env.inner());
    drop(cloned);
    drop(env);

    Ok(())
}