pub mod download;
pub mod environment;
pub mod error;
pub mod execution_providers;
pub mod memory;
pub mod metadata;
pub mod session;
pub mod sys;
pub mod tensor;
use std::{
ffi::{self, CStr},
os::raw::c_char,
ptr,
sync::{atomic::AtomicPtr, Arc, Mutex}
};
pub use environment::Environment;
pub use error::{OrtApiError, OrtError, OrtResult};
pub use execution_providers::ExecutionProvider;
use lazy_static::lazy_static;
pub use session::{Session, SessionBuilder};
use self::sys::OnnxEnumInt;
/// Declares a function, or a function-pointer type, using ONNX Runtime's C API calling convention.
///
/// `onnxruntime_c_api.h` defines `ORT_API_CALL` as `_stdcall` on Windows and as nothing elsewhere.
/// Only on 32-bit x86 Windows does `stdcall` differ from the C convention. That target therefore
/// needs `extern "stdcall"`, both for callbacks handed to ONNX Runtime (such as the logger) and
/// for pointers read out of `OrtApiBase`/`OrtApi`. Every other target uses `extern "C"`.
///
/// Accepted forms: `[attrs] [vis] [unsafe] fn ...`, which covers both item definitions and bare
/// `fn(...) -> ...` types.
#[cfg(all(target_arch = "x86", target_os = "windows"))]
macro_rules! extern_system_fn {
	($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "stdcall" fn $($tt)*);
	($(#[$meta:meta])* $vis:vis fn $($tt:tt)*) => ($(#[$meta])* $vis extern "stdcall" fn $($tt)*);
	($(#[$meta:meta])* unsafe fn $($tt:tt)*) => ($(#[$meta])* unsafe extern "stdcall" fn $($tt)*);
	($(#[$meta:meta])* $vis:vis unsafe fn $($tt:tt)*) => ($(#[$meta])* $vis unsafe extern "stdcall" fn $($tt)*);
}

/// Declares a function, or a function-pointer type, using ONNX Runtime's C API calling convention.
///
/// On this target `ORT_API_CALL` is equivalent to the C calling convention, so `extern "C"` is used.
/// See the 32-bit x86 Windows variant for the `stdcall` case.
#[cfg(not(all(target_arch = "x86", target_os = "windows")))]
macro_rules! extern_system_fn {
	($(#[$meta:meta])* fn $($tt:tt)*) => ($(#[$meta])* extern "C" fn $($tt)*);
	($(#[$meta:meta])* $vis:vis fn $($tt:tt)*) => ($(#[$meta])* $vis extern "C" fn $($tt)*);
	($(#[$meta:meta])* unsafe fn $($tt:tt)*) => ($(#[$meta])* unsafe extern "C" fn $($tt)*);
	($(#[$meta:meta])* $vis:vis unsafe fn $($tt:tt)*) => ($(#[$meta])* $vis unsafe extern "C" fn $($tt)*);
}
pub(crate) use extern_system_fn;
lazy_static! {
	/// Pointer to the ONNX Runtime C API function table, resolved lazily on first access.
	///
	/// Initialization asks `OrtGetApiBase` for the API base, asserts it is non-null, and then calls
	/// its `GetApi` entry with `sys::ORT_API_VERSION`. The result of `GetApi` is stored unchecked;
	/// [`ort`] performs the null check when the table is used.
	pub(crate) static ref G_ORT_API: Arc<Mutex<AtomicPtr<sys::OrtApi>>> = {
		// SAFETY: argument-less FFI entry point; its result is null-checked immediately below.
		let api_base: *const sys::OrtApiBase = unsafe { sys::OrtGetApiBase() };
		assert_ne!(api_base, ptr::null());

		// SAFETY: `api_base` is non-null (asserted above); `GetApi` is invoked with the version this
		// crate's bindings were generated for.
		let api_table = unsafe {
			let get_api: extern_system_fn! { unsafe fn(u32) -> *const sys::OrtApi } = (*api_base).GetApi.unwrap();
			get_api(sys::ORT_API_VERSION)
		};

		Arc::new(Mutex::new(AtomicPtr::new(api_table as *mut sys::OrtApi)))
	};
}
/// Returns a copy of the ONNX Runtime C API function table (`OrtApi`).
///
/// The table pointer is read from [`G_ORT_API`] under its mutex. The guard is dropped *before*
/// the pointer is validated. A failed check therefore does not poison the mutex, which would make
/// every subsequent call fail with an unrelated "lock poisoned" message.
///
/// # Panics
///
/// - If the mutex guarding the table pointer is poisoned.
/// - If the stored pointer is null. `G_ORT_API` stores whatever `GetApi(ORT_API_VERSION)` returned,
///   and the C API documents a null return when the loaded runtime does not support that version.
pub fn ort() -> sys::OrtApi {
	let mut api_ref = G_ORT_API.lock().expect("failed to acquire OrtApi lock; another thread panicked?");
	let api_ptr: *mut sys::OrtApi = *api_ref.get_mut();
	// Release the lock before any check that can panic, so a failure here cannot poison it.
	drop(api_ref);

	assert_ne!(
		api_ptr,
		ptr::null_mut(),
		"ONNX Runtime returned a null OrtApi table; the loaded library probably does not support the requested API version"
	);
	// SAFETY: `api_ptr` is non-null (checked above) and came from `GetApi`. It is assumed to point to
	// a table owned by the loaded runtime library that remains valid while the library is loaded.
	unsafe { *api_ptr }
}
/// Calls a function from the ONNX Runtime C API table.
///
/// Every form fetches the table through `$crate::ort()`, which returns it by value. It then reads
/// the field named `$method` and `unwrap()`s it, so it panics if that entry is `None`.
///
/// A leading `unsafe` wraps the lookup/call in an `unsafe` block. Without it, the expansion is not
/// wrapped. Presumably the table's fields are `unsafe` FFI function pointers, so those forms must
/// be used inside an existing `unsafe` block (TODO confirm against `sys`).
///
/// Forms:
/// - `ortsys![Method]` evaluates to the unwrapped function pointer itself.
/// - `ortsys![Method(args...)]` calls it and evaluates to its raw return value.
/// - `ortsys![Method(args...); nonNull(p, ...)]` calls it and discards the return value. It then
///   checks each `p` with `error::assert_non_null_pointer` (tagged with the method name) and
///   propagates failures with `?`.
/// - `ortsys![Method(args...) -> map_err_fn]` converts the return value with
///   `error::status_to_result`, maps the error with `map_err_fn`, and propagates it with `?`.
/// - `ortsys![Method(args...) -> map_err_fn; nonNull(p, ...)]` does both of the above, the status
///   check first.
///
/// The forms that use `?` expand to statements. They must appear inside a function whose return
/// type can absorb the propagated error.
macro_rules! ortsys {
// Bare lookup: evaluates to the function pointer.
($method:tt) => {
$crate::ort().$method.unwrap()
};
(unsafe $method:tt) => {
unsafe { $crate::ort().$method.unwrap() }
};
// Plain call: evaluates to whatever the C function returns.
($method:tt($($n:expr),+ $(,)?)) => {
$crate::ort().$method.unwrap()($($n),+)
};
(unsafe $method:tt($($n:expr),+ $(,)?)) => {
unsafe { $crate::ort().$method.unwrap()($($n),+) }
};
// Call (result discarded), then null-check each listed out-pointer.
($method:tt($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {
$crate::ort().$method.unwrap()($($n),+);
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
};
(unsafe $method:tt($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {
unsafe { $crate::ort().$method.unwrap()($($n),+) };
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
};
// Call, convert the returned status to a Result, map its error with `$err`, and propagate it.
($method:tt($($n:expr),+ $(,)?) -> $err:expr$(;)?) => {
$crate::error::status_to_result($crate::ort().$method.unwrap()($($n),+)).map_err($err)?;
};
(unsafe $method:tt($($n:expr),+ $(,)?) -> $err:expr$(;)?) => {
$crate::error::status_to_result(unsafe { $crate::ort().$method.unwrap()($($n),+) }).map_err($err)?;
};
// Status check as above, followed by null checks on each listed out-pointer.
($method:tt($($n:expr),+ $(,)?) -> $err:expr; nonNull($($check:expr),+ $(,)?)$(;)?) => {
$crate::error::status_to_result($crate::ort().$method.unwrap()($($n),+)).map_err($err)?;
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
};
(unsafe $method:tt($($n:expr),+ $(,)?) -> $err:expr; nonNull($($check:expr),+ $(,)?)$(;)?) => {
$crate::error::status_to_result(unsafe { $crate::ort().$method.unwrap()($($n),+) }).map_err($err)?;
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
};
}
/// Releases `ptr` through the `Free` entry of the allocator that `allocator_ptr` points to.
///
/// `ortfree!(unsafe allocator_ptr, ptr)` wraps the call in an `unsafe` block. The form without
/// `unsafe` must be used inside an existing `unsafe` block, because it dereferences
/// `allocator_ptr`.
///
/// The allocator expression is evaluated exactly once and bound to a local, even though it is used
/// both for the `Free` lookup and as the first argument. This prevents side effects from running
/// twice. `ptr` may be any expression (e.g. `self.ptr`); it is cast to `*mut c_void`.
///
/// # Panics
///
/// Panics if the allocator's `Free` entry is `None`.
macro_rules! ortfree {
	(unsafe $allocator_ptr:expr, $ptr:expr) => {
		unsafe {
			let allocator = $allocator_ptr;
			(*allocator).Free.unwrap()(allocator, $ptr as *mut std::ffi::c_void)
		}
	};
	($allocator_ptr:expr, $ptr:expr) => {{
		let allocator = $allocator_ptr;
		(*allocator).Free.unwrap()(allocator, $ptr as *mut std::ffi::c_void)
	}};
}
pub(crate) use ortfree;
pub(crate) use ortsys;
/// Copies a NUL-terminated C string into an owned Rust `String`.
///
/// `raw` must point to a NUL-terminated string that stays valid for the duration of the call.
/// Callers are responsible for that; only the null case is checked here.
///
/// # Errors
///
/// Returns `OrtError::FfiStringConversion(OrtApiError::IntoStringError(_))` if the bytes are not
/// valid UTF-8.
///
/// # Panics
///
/// Panics if `raw` is null. Passing a null pointer to `CStr::from_ptr` would be undefined
/// behaviour, so it is rejected explicitly.
pub(crate) fn char_p_to_string(raw: *const c_char) -> OrtResult<String> {
	assert!(!raw.is_null(), "char_p_to_string received a null pointer");
	// SAFETY: `raw` is non-null (checked above); NUL termination and validity are the caller's contract.
	let c_string = unsafe { CStr::from_ptr(raw) }.to_owned();
	c_string
		.into_string()
		.map_err(|e| OrtError::FfiStringConversion(OrtApiError::IntoStringError(e)))
}
/// Source location attached to an ONNX Runtime log message, borrowed from the raw location string.
///
/// The string is parsed as `file:line function`. Missing parts are replaced with the placeholders
/// `<unknown line>` and `<unknown function>`.
#[derive(Debug)]
struct CodeLocation<'a> {
	file: &'a str,
	line: &'a str,
	function: &'a str
}

impl<'a> From<&'a str> for CodeLocation<'a> {
	fn from(code_location: &'a str) -> Self {
		// Everything after the first space is the function name. Names such as
		// `Foo::<lambda_1>::operator ()` contain spaces themselves, so the remainder is kept intact.
		let (file_and_line, function) = code_location.split_once(' ').unwrap_or((code_location, "<unknown function>"));
		// The line number follows the *last* colon, so paths like `C:\src\x.cc:42` keep their drive letter.
		let (file, line) = file_and_line.rsplit_once(':').unwrap_or((file_and_line, "<unknown line>"));
		CodeLocation { file, line, function }
	}
}
extern_system_fn! {
	/// Logging callback handed to ONNX Runtime that re-emits each message through `tracing`.
	///
	/// Severity mapping: VERBOSE → TRACE, INFO → DEBUG, WARNING → INFO, ERROR → WARN,
	/// FATAL → ERROR. Any other value is logged at TRACE. Each event is emitted inside an `ort`
	/// span that records the category, log id, and the parsed source location.
	///
	/// Native code invokes this function, so it must never panic: a panic escaping an `extern`
	/// function aborts the process (or is undefined behaviour on older compilers). Null pointers
	/// are therefore replaced with placeholder text, and invalid UTF-8 is decoded lossily.
	pub(crate) fn custom_logger(_params: *mut ffi::c_void, severity: sys::OrtLoggingLevel, category: *const c_char, log_id: *const c_char, code_location: *const c_char, message: *const c_char) {
		use std::borrow::Cow;
		use tracing::{span, Level, trace, debug, warn, info, error};

		/// Reads a possibly-null C string, substituting `fallback` for null and U+FFFD for invalid UTF-8.
		fn lossy_c_str<'a>(raw: *const c_char, fallback: &'static str) -> Cow<'a, str> {
			if raw.is_null() {
				Cow::Borrowed(fallback)
			} else {
				// SAFETY: `raw` is non-null. ONNX Runtime is assumed to pass NUL-terminated strings
				// that remain valid for the duration of this callback.
				unsafe { CStr::from_ptr(raw) }.to_string_lossy()
			}
		}

		let log_level = match severity {
			sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_VERBOSE => Level::TRACE,
			sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_INFO => Level::DEBUG,
			sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_WARNING => Level::INFO,
			sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_ERROR => Level::WARN,
			sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_FATAL => Level::ERROR,
			_ => Level::TRACE
		};

		let category = lossy_c_str(category, "<unknown>");
		let log_id = lossy_c_str(log_id, "<unknown>");
		let location_text = lossy_c_str(code_location, "unknown");
		let message = lossy_c_str(message, "<null message>");
		let code_location = CodeLocation::from(&*location_text);

		let span = span!(
			Level::TRACE,
			"ort",
			category = &*category,
			file = code_location.file,
			line = code_location.line,
			function = code_location.function,
			log_id = &*log_id
		);
		let _enter = span.enter();
		// `Display` rather than `Debug`, so the text is logged verbatim without quotes or `\xNN` escapes.
		match log_level {
			Level::TRACE => trace!("{}", message),
			Level::DEBUG => debug!("{}", message),
			Level::INFO => info!("{}", message),
			Level::WARN => warn!("{}", message),
			Level::ERROR => error!("{}", message)
		}
	}
}
/// Severity threshold for ONNX Runtime's own logging.
///
/// Each discriminant equals the corresponding `OrtLoggingLevel` constant from the C API. The
/// representation is `u32` on non-Windows targets and `i32` on Windows, presumably matching the
/// platform's C enum type (`OnnxEnumInt`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(not(windows), repr(u32))]
#[cfg_attr(windows, repr(i32))]
pub enum LoggingLevel {
	/// `ORT_LOGGING_LEVEL_VERBOSE`.
	Verbose = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_VERBOSE as OnnxEnumInt,
	/// `ORT_LOGGING_LEVEL_INFO`.
	Info = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_INFO as OnnxEnumInt,
	/// `ORT_LOGGING_LEVEL_WARNING`.
	Warning = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_WARNING as OnnxEnumInt,
	/// `ORT_LOGGING_LEVEL_ERROR`.
	Error = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_ERROR as OnnxEnumInt,
	/// `ORT_LOGGING_LEVEL_FATAL`.
	Fatal = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_FATAL as OnnxEnumInt
}

impl From<LoggingLevel> for sys::OrtLoggingLevel {
	/// Maps each variant to its `OrtLoggingLevel` C constant.
	fn from(logging_level: LoggingLevel) -> Self {
		match logging_level {
			LoggingLevel::Verbose => sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_VERBOSE,
			LoggingLevel::Info => sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_INFO,
			LoggingLevel::Warning => sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_WARNING,
			LoggingLevel::Error => sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_ERROR,
			LoggingLevel::Fatal => sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_FATAL
		}
	}
}
/// Amount of graph optimization ONNX Runtime applies when loading a model.
///
/// Each discriminant equals the corresponding `GraphOptimizationLevel` constant from the C API.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(not(windows), repr(u32))]
#[cfg_attr(windows, repr(i32))]
pub enum GraphOptimizationLevel {
	/// `ORT_DISABLE_ALL`.
	Disable = sys::GraphOptimizationLevel_ORT_DISABLE_ALL as OnnxEnumInt,
	/// `ORT_ENABLE_BASIC`.
	Level1 = sys::GraphOptimizationLevel_ORT_ENABLE_BASIC as OnnxEnumInt,
	/// `ORT_ENABLE_EXTENDED`.
	Level2 = sys::GraphOptimizationLevel_ORT_ENABLE_EXTENDED as OnnxEnumInt,
	/// `ORT_ENABLE_ALL`.
	Level3 = sys::GraphOptimizationLevel_ORT_ENABLE_ALL as OnnxEnumInt
}

impl From<GraphOptimizationLevel> for sys::GraphOptimizationLevel {
	/// Maps each variant to its `GraphOptimizationLevel` C constant.
	fn from(val: GraphOptimizationLevel) -> Self {
		match val {
			GraphOptimizationLevel::Disable => sys::GraphOptimizationLevel_ORT_DISABLE_ALL,
			GraphOptimizationLevel::Level1 => sys::GraphOptimizationLevel_ORT_ENABLE_BASIC,
			GraphOptimizationLevel::Level2 => sys::GraphOptimizationLevel_ORT_ENABLE_EXTENDED,
			GraphOptimizationLevel::Level3 => sys::GraphOptimizationLevel_ORT_ENABLE_ALL
		}
	}
}
/// Kind of ONNX Runtime allocator; each discriminant equals the matching `OrtAllocatorType` constant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(i32)]
pub enum AllocatorType {
	/// `OrtDeviceAllocator`.
	Device = sys::OrtAllocatorType_OrtDeviceAllocator,
	/// `OrtArenaAllocator`.
	Arena = sys::OrtAllocatorType_OrtArenaAllocator
}

impl From<AllocatorType> for sys::OrtAllocatorType {
	/// Maps each variant to its `OrtAllocatorType` C constant.
	fn from(val: AllocatorType) -> Self {
		match val {
			AllocatorType::Device => sys::OrtAllocatorType_OrtDeviceAllocator,
			AllocatorType::Arena => sys::OrtAllocatorType_OrtArenaAllocator
		}
	}
}
/// Memory type for ONNX Runtime memory info; each discriminant equals the matching `OrtMemType` constant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(i32)]
pub enum MemType {
	/// `OrtMemTypeCPUInput`.
	CPUInput = sys::OrtMemType_OrtMemTypeCPUInput,
	/// `OrtMemTypeCPUOutput`.
	CPUOutput = sys::OrtMemType_OrtMemTypeCPUOutput,
	/// `OrtMemTypeDefault`.
	Default = sys::OrtMemType_OrtMemTypeDefault
}

impl MemType {
	/// Alias for [`MemType::CPUOutput`], mirroring the C API's `OrtMemTypeCPU` alias.
	pub const CPU: MemType = MemType::CPUOutput;
}

impl From<MemType> for sys::OrtMemType {
	/// Maps each variant to its `OrtMemType` C constant.
	fn from(val: MemType) -> Self {
		match val {
			MemType::CPUInput => sys::OrtMemType_OrtMemTypeCPUInput,
			MemType::CPUOutput => sys::OrtMemType_OrtMemTypeCPUOutput,
			MemType::Default => sys::OrtMemType_OrtMemTypeDefault
		}
	}
}
#[cfg(test)]
mod test {
	use super::*;

	/// A Rust-owned NUL-terminated buffer stands in for a string returned by the C API.
	#[test]
	fn test_char_p_to_string() {
		let expected = "foo";
		let owned = ffi::CString::new(expected).unwrap();
		let converted = char_p_to_string(owned.as_ptr()).unwrap();
		assert_eq!(expected, converted);
	}
}