#![doc = include_str!("../README.md")]
pub mod download;
pub mod environment;
pub mod error;
pub mod execution_providers;
pub mod io_binding;
pub mod memory;
pub mod metadata;
pub mod session;
pub mod sys;
pub mod tensor;
pub mod value;
use std::{
ffi::{self, CStr},
os::raw::c_char,
ptr,
sync::{atomic::AtomicPtr, Arc, Mutex}
};
use lazy_static::lazy_static;
use tracing::warn;
pub use self::environment::Environment;
pub use self::error::{OrtApiError, OrtError, OrtResult};
pub use self::execution_providers::ExecutionProvider;
pub use self::io_binding::IoBinding;
pub use self::memory::{AllocationDevice, MemoryInfo};
pub use self::session::{InMemorySession, Session, SessionBuilder};
pub use self::tensor::NdArrayExtensions;
pub use self::value::Value;
// Wraps a function definition (or, in type position, a function-pointer type) in the calling
// convention this crate uses for ONNX Runtime callbacks and API entries: `extern "C"` on every
// target except 32-bit x86 Windows, where `extern "stdcall"` is selected instead.
//
// Keep the arm order: the visibility-free arms come before the `$vis:vis` arms, and type-position
// uses such as `extern_system_fn! { unsafe fn() -> T }` match through them.
// NOTE(review): presumably this avoids substituting an empty `$vis` where no visibility is allowed.
#[cfg(not(all(target_arch = "x86", target_os = "windows")))]
macro_rules! extern_system_fn {
	($(#[$attr:meta])* fn $($rest:tt)*) => {
		$(#[$attr])* extern "C" fn $($rest)*
	};
	($(#[$attr:meta])* $vis:vis fn $($rest:tt)*) => {
		$(#[$attr])* $vis extern "C" fn $($rest)*
	};
	($(#[$attr:meta])* unsafe fn $($rest:tt)*) => {
		$(#[$attr])* unsafe extern "C" fn $($rest)*
	};
	($(#[$attr:meta])* $vis:vis unsafe fn $($rest:tt)*) => {
		$(#[$attr])* $vis unsafe extern "C" fn $($rest)*
	};
}
// 32-bit x86 Windows variant: identical arms, but the ABI is `stdcall` (presumably to match the
// calling convention ONNX Runtime's C API uses on that target — confirm against its headers).
#[cfg(all(target_arch = "x86", target_os = "windows"))]
macro_rules! extern_system_fn {
	($(#[$attr:meta])* fn $($rest:tt)*) => {
		$(#[$attr])* extern "stdcall" fn $($rest)*
	};
	($(#[$attr:meta])* $vis:vis fn $($rest:tt)*) => {
		$(#[$attr])* $vis extern "stdcall" fn $($rest)*
	};
	($(#[$attr:meta])* unsafe fn $($rest:tt)*) => {
		$(#[$attr])* unsafe extern "stdcall" fn $($rest)*
	};
	($(#[$attr:meta])* $vis:vis unsafe fn $($rest:tt)*) => {
		$(#[$attr])* $vis unsafe extern "stdcall" fn $($rest)*
	};
}
pub(crate) use extern_system_fn;
#[cfg(feature = "load-dynamic")]
lazy_static! {
	/// Path (or bare file name) of the ONNX Runtime dynamic library to load.
	///
	/// Taken from the `ORT_DYLIB_PATH` environment variable when it is set, valid Unicode and
	/// non-empty. Otherwise it is the platform's conventional library file name for `onnxruntime`
	/// (`onnxruntime.dll` on Windows, `libonnxruntime.so` on Linux/Android, `libonnxruntime.dylib`
	/// on macOS, and the matching convention on any other target).
	pub(crate) static ref G_ORT_DYLIB_PATH: Arc<String> = {
		let path = match std::env::var("ORT_DYLIB_PATH") {
			Ok(s) if !s.is_empty() => s,
			// `DLL_PREFIX`/`DLL_SUFFIX` reproduce the previously hard-coded Windows/Linux/Android/macOS
			// names and also give a sensible default on every other target instead of a
			// non-exhaustive `match`.
			_ => format!("{}onnxruntime{}", std::env::consts::DLL_PREFIX, std::env::consts::DLL_SUFFIX)
		};
		Arc::new(path)
	};

	/// Handle to the loaded ONNX Runtime library.
	///
	/// An absolute [`G_ORT_DYLIB_PATH`] is loaded as-is. A relative one is first looked up next to
	/// the current executable. If no such file exists there, or the executable's location cannot be
	/// determined, the relative path is handed to the system loader unchanged.
	///
	/// # Panics
	/// Panics (on first access) if the library cannot be loaded.
	pub(crate) static ref G_ORT_LIB: Arc<Mutex<AtomicPtr<libloading::Library>>> = {
		let path = std::path::PathBuf::from(G_ORT_DYLIB_PATH.as_str());
		let absolute_path = if path.is_absolute() {
			path
		} else {
			let beside_exe = std::env::current_exe()
				.ok()
				.and_then(|exe| exe.parent().map(|dir| dir.join(&path)));
			match beside_exe {
				Some(candidate) if candidate.exists() => candidate,
				_ => path
			}
		};
		// SAFETY: loading a library runs its initialisation code; we rely on the file at this path
		// (chosen by the user via `ORT_DYLIB_PATH`, or the default name) being a genuine ONNX Runtime build.
		let lib = unsafe { libloading::Library::new(&absolute_path) }
			.unwrap_or_else(|e| panic!("could not load the library at `{}`: {e:?}", absolute_path.display()));
		// Deliberately leaked so the library is never unloaded; pointers obtained from it
		// (e.g. the `OrtApi` table in `G_ORT_API`) are kept in other process-wide globals.
		Arc::new(Mutex::new(AtomicPtr::new(Box::leak(Box::new(lib)) as *mut _)))
	};
}
lazy_static! {
	/// Process-wide pointer to the ONNX Runtime C API function table, read through [`ort`].
	///
	/// With the `load-dynamic` feature the table comes from the library held in `G_ORT_LIB`. The
	/// minor component of that library's version string is compared with 16: a lower value panics,
	/// a higher one logs a warning. Without the feature it comes from the linked-in
	/// `sys::OrtGetApiBase`. In both cases the table is requested for `sys::ORT_API_VERSION`.
	///
	/// # Panics
	/// Panics on first access if the library lacks `OrtGetApiBase`, returns a null API base, or
	/// reports an older minor version than 16.
	pub(crate) static ref G_ORT_API: Arc<Mutex<AtomicPtr<sys::OrtApi>>> = {
		#[cfg(feature = "load-dynamic")]
		// SAFETY: every pointer used below comes from the leaked `G_ORT_LIB` library (never unloaded)
		// or from ONNX Runtime's own API-base functions; `base` is checked for null before use.
		unsafe {
			let dylib = *G_ORT_LIB
				.lock()
				.expect("failed to acquire ONNX Runtime dylib lock; another thread panicked?")
				.get_mut();
			// A missing entry point usually means the file is not an ONNX Runtime build; report which
			// file was searched and the loader error instead of an empty panic message.
			let base_getter: libloading::Symbol<unsafe extern "C" fn() -> *const sys::OrtApiBase> = (*dylib)
				.get(b"OrtGetApiBase")
				.unwrap_or_else(|e| panic!("could not find the `OrtGetApiBase` symbol in the ONNX Runtime binary at `{}`: {e:?}", **G_ORT_DYLIB_PATH));
			let base: *const sys::OrtApiBase = base_getter();
			assert_ne!(base, ptr::null());
			let get_version_string: extern_system_fn! { unsafe fn () -> *const ffi::c_char } = (*base).GetVersionString.unwrap();
			let version_string = get_version_string();
			let version_string = CStr::from_ptr(version_string).to_string_lossy();
			// Only the minor component ("16" in "1.16.3") is compared; an unparsable version counts as 0.
			let lib_minor_version = version_string.split('.').nth(1).map(|x| x.parse::<u32>().unwrap_or(0)).unwrap_or(0);
			match lib_minor_version.cmp(&16) {
				std::cmp::Ordering::Less => panic!(
					"ort 1.16 is not compatible with the ONNX Runtime binary found at `{}`; expected GetVersionString to return '1.16.x', but got '{version_string}'",
					**G_ORT_DYLIB_PATH
				),
				std::cmp::Ordering::Greater => warn!(
					"ort 1.16 may have compatibility issues with the ONNX Runtime binary found at `{}`; expected GetVersionString to return '1.16.x', but got '{version_string}'",
					**G_ORT_DYLIB_PATH
				),
				std::cmp::Ordering::Equal => {}
			};
			let get_api: extern_system_fn! { unsafe fn(u32) -> *const sys::OrtApi } = (*base).GetApi.unwrap();
			let api: *const sys::OrtApi = get_api(sys::ORT_API_VERSION);
			Arc::new(Mutex::new(AtomicPtr::new(api as *mut sys::OrtApi)))
		}
		#[cfg(not(feature = "load-dynamic"))]
		{
			// SAFETY: `OrtGetApiBase` is ONNX Runtime's exported entry point and takes no arguments.
			let base: *const sys::OrtApiBase = unsafe { sys::OrtGetApiBase() };
			assert_ne!(base, ptr::null());
			// SAFETY: `base` was checked for null just above.
			let get_api: extern_system_fn! { unsafe fn(u32) -> *const sys::OrtApi } = unsafe { (*base).GetApi.unwrap() };
			let api: *const sys::OrtApi = unsafe { get_api(sys::ORT_API_VERSION) };
			Arc::new(Mutex::new(AtomicPtr::new(api as *mut sys::OrtApi)))
		}
	};
}
/// Returns a copy of the global ONNX Runtime API function table stored in `G_ORT_API`.
///
/// Every call locks `G_ORT_API` and copies the whole `sys::OrtApi` struct out by value.
///
/// # Panics
/// Panics if the `G_ORT_API` mutex is poisoned or if the stored table pointer is null.
pub fn ort() -> sys::OrtApi {
	let mut guard = G_ORT_API.lock().expect("failed to acquire OrtApi lock; another thread panicked?");
	let table: *mut sys::OrtApi = *guard.get_mut();
	assert_ne!(table, ptr::null_mut());
	// SAFETY: `table` is non-null (checked above). It was obtained from ONNX Runtime's `GetApi` when
	// `G_ORT_API` was initialised and is assumed to stay valid for the rest of the process.
	unsafe { *table }
}
/// Calls into the global ONNX Runtime API table returned by [`ort`].
///
/// Accepted forms; each may be prefixed with `unsafe`, which wraps the call in an `unsafe` block:
/// - `ortsys![Method]` evaluates to the function pointer itself (`ort().Method.unwrap()`).
/// - `ortsys![Method(args...)]` calls it and evaluates to its raw return value.
/// - `ortsys![Method(args...); nonNull(ptrs...)]` calls it and discards the result. It then runs
///   `crate::error::assert_non_null_pointer(ptr, "Method")?` for each listed pointer.
/// - `ortsys![Method(args...) -> err]` passes the returned status to
///   `crate::error::status_to_result`, maps the error with `err`, and propagates it with `?`.
/// - `ortsys![Method(args...) -> err; nonNull(ptrs...)]` does both of the above.
///
/// The `nonNull` and `->` forms expand to statements ending in `?;`. Use them in statement
/// position, inside a function whose `Result` error type is compatible.
/// Every form calls `ort()` (one table copy per use) and `.unwrap()`s the function pointer, so it
/// panics if that table entry is missing.
macro_rules! ortsys {
// `ortsys![Method]`: the function pointer.
($method:tt) => {
$crate::ort().$method.unwrap()
};
(unsafe $method:tt) => {
unsafe { $crate::ort().$method.unwrap() }
};
// `ortsys![Method(args...)]`: call and return the raw result.
($method:tt($($n:expr),+ $(,)?)) => {
$crate::ort().$method.unwrap()($($n),+)
};
(unsafe $method:tt($($n:expr),+ $(,)?)) => {
unsafe { $crate::ort().$method.unwrap()($($n),+) }
};
// `ortsys![Method(args...); nonNull(ptrs...)]`: call, ignore the result, then null-check outputs.
($method:tt($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {
$crate::ort().$method.unwrap()($($n),+);
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
};
(unsafe $method:tt($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {
unsafe { $crate::ort().$method.unwrap()($($n),+) };
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
};
// `ortsys![Method(args...) -> err]`: convert the returned status, map the error, propagate with `?`.
($method:tt($($n:expr),+ $(,)?) -> $err:expr$(;)?) => {
$crate::error::status_to_result($crate::ort().$method.unwrap()($($n),+)).map_err($err)?;
};
(unsafe $method:tt($($n:expr),+ $(,)?) -> $err:expr$(;)?) => {
$crate::error::status_to_result(unsafe { $crate::ort().$method.unwrap()($($n),+) }).map_err($err)?;
};
// `ortsys![Method(args...) -> err; nonNull(ptrs...)]`: status check first, then null-check outputs.
($method:tt($($n:expr),+ $(,)?) -> $err:expr; nonNull($($check:expr),+ $(,)?)$(;)?) => {
$crate::error::status_to_result($crate::ort().$method.unwrap()($($n),+)).map_err($err)?;
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
};
(unsafe $method:tt($($n:expr),+ $(,)?) -> $err:expr; nonNull($($check:expr),+ $(,)?)$(;)?) => {
$crate::error::status_to_result(unsafe { $crate::ort().$method.unwrap()($($n),+) }).map_err($err)?;
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
};
}
/// Releases `ptr` through the `Free` entry of the allocator that `allocator_ptr` points to
/// (presumably a `sys::OrtAllocator` — confirm at call sites).
///
/// - `ortfree!(unsafe allocator_ptr, ptr)` wraps the call in an `unsafe` block.
/// - `ortfree!(allocator_ptr, ptr)` dereferences the raw allocator pointer, so it must be used
///   inside an `unsafe` context.
///
/// `ptr` must be a single token tree and is cast to `*mut c_void`. The allocator expression is
/// evaluated exactly once, even though it is needed both to look up `Free` and as that function's
/// first argument. Expressions with side effects are therefore safe to pass.
///
/// # Panics
/// Panics if the allocator's `Free` entry is `None`.
macro_rules! ortfree {
	(unsafe $allocator_ptr:expr, $ptr:tt) => {
		unsafe {
			let allocator = $allocator_ptr;
			(*allocator).Free.unwrap()(allocator, $ptr as *mut std::ffi::c_void)
		}
	};
	($allocator_ptr:expr, $ptr:tt) => {{
		let allocator = $allocator_ptr;
		(*allocator).Free.unwrap()(allocator, $ptr as *mut std::ffi::c_void)
	}};
}
pub(crate) use ortfree;
pub(crate) use ortsys;
/// Copies a NUL-terminated C string into an owned Rust [`String`].
///
/// `raw` is dereferenced unconditionally. It must be non-null and point to a valid NUL-terminated
/// string; this is not checked.
///
/// # Errors
/// Returns [`OrtError::FfiStringConversion`] wrapping [`OrtApiError::IntoStringError`] when the bytes
/// are not valid UTF-8.
pub(crate) fn char_p_to_string(raw: *const c_char) -> OrtResult<String> {
	// SAFETY: the caller must pass a valid, non-null, NUL-terminated string (see above).
	let owned = unsafe { CStr::from_ptr(raw) }.to_owned();
	owned
		.into_string()
		.map_err(|err| OrtError::FfiStringConversion(OrtApiError::IntoStringError(err)))
}
/// Source location of an ONNX Runtime log record, parsed from the `code_location` string passed to
/// the logging callback. Expected shape: `<file>:<line> <function>`. All fields borrow from that string.
#[derive(Debug)]
struct CodeLocation<'a> {
	file: &'a str,
	line: &'a str,
	function: &'a str
}

impl<'a> From<&'a str> for CodeLocation<'a> {
	/// Splits `code_location` into file, line and function.
	///
	/// - **Function:** everything after the *first* space. Signatures that contain spaces (e.g.
	///   `onnxruntime::common::Status onnxruntime::InferenceSession::Initialize()`) stay whole.
	/// - **Line:** what follows the *last* `:` before that space. A `:` inside the file path, such as
	///   a Windows drive letter in `C:\build\x.cc`, therefore stays part of the file.
	/// - **Limitation:** file paths that contain spaces are still split incorrectly.
	/// - **Missing parts:** a missing function becomes `"<unknown function>"` and a missing line
	///   becomes `"<unknown line>"`.
	fn from(code_location: &'a str) -> Self {
		let (file_and_line, function) = code_location.split_once(' ').unwrap_or((code_location, "<unknown function>"));
		let (file, line) = file_and_line.rsplit_once(':').unwrap_or((file_and_line, "<unknown line>"));
		CodeLocation { file, line, function }
	}
}
extern_system_fn! {
	/// Logging callback handed to ONNX Runtime; forwards every record to `tracing`.
	///
	/// Severities are shifted down one `tracing` level:
	/// - VERBOSE → TRACE
	/// - INFO → DEBUG
	/// - WARNING → INFO
	/// - ERROR → WARN
	/// - FATAL → ERROR
	///
	/// Each event is emitted inside an `ort` span (created at TRACE level) that records the category,
	/// source file, line, function and log id.
	///
	/// This function is invoked from C, so it must not panic (a panic would unwind across the FFI
	/// boundary). A null or non-UTF-8 string argument is therefore replaced with placeholder text
	/// instead of being asserted on.
	pub(crate) fn custom_logger(_params: *mut ffi::c_void, severity: sys::OrtLoggingLevel, category: *const c_char, log_id: *const c_char, code_location: *const c_char, message: *const c_char) {
		use tracing::{debug, error, info, span, trace, warn, Level};

		// Borrows a NUL-terminated C string, or returns `fallback` if the pointer is null or the
		// bytes are not valid UTF-8.
		fn c_str_or<'a>(ptr: *const c_char, fallback: &'a str) -> &'a str {
			if ptr.is_null() {
				return fallback;
			}
			// SAFETY: `ptr` is non-null. ONNX Runtime is assumed to pass NUL-terminated strings that stay
			// valid for the duration of this callback; nothing borrowed here outlives the callback.
			unsafe { CStr::from_ptr(ptr) }.to_str().unwrap_or(fallback)
		}

		let category = c_str_or(category, "<unknown>");
		let code_location = CodeLocation::from(c_str_or(code_location, "unknown"));
		let message = c_str_or(message, "<invalid>");
		let log_id = c_str_or(log_id, "<unknown>");

		let span = span!(
			Level::TRACE,
			"ort",
			category = category,
			file = code_location.file,
			line = code_location.line,
			function = code_location.function,
			log_id = log_id
		);
		let _enter = span.enter();
		match severity {
			sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => trace!("{}", message),
			sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => debug!("{}", message),
			sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING => info!("{}", message),
			sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR => warn!("{}", message),
			sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => error!("{}", message)
		}
	}
}
/// Logging severity levels understood by ONNX Runtime.
///
/// Each variant converts into the [`sys::OrtLoggingLevel`] of the same name. `Copy`/`Clone` match
/// the other fieldless option enums in this module (`AllocatorType`, `MemType`).
#[derive(Debug, Copy, Clone)]
pub enum LoggingLevel {
	/// `ORT_LOGGING_LEVEL_VERBOSE`
	Verbose,
	/// `ORT_LOGGING_LEVEL_INFO`
	Info,
	/// `ORT_LOGGING_LEVEL_WARNING`
	Warning,
	/// `ORT_LOGGING_LEVEL_ERROR`
	Error,
	/// `ORT_LOGGING_LEVEL_FATAL`
	Fatal
}

impl From<LoggingLevel> for sys::OrtLoggingLevel {
	fn from(logging_level: LoggingLevel) -> Self {
		match logging_level {
			LoggingLevel::Verbose => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
			LoggingLevel::Info => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
			LoggingLevel::Warning => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING,
			LoggingLevel::Error => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR,
			LoggingLevel::Fatal => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL
		}
	}
}
/// Graph optimization level to request from ONNX Runtime.
///
/// `Copy`/`Clone` match the other fieldless option enums in this module (`AllocatorType`, `MemType`).
#[derive(Debug, Copy, Clone)]
pub enum GraphOptimizationLevel {
	/// Disables graph optimizations (`ORT_DISABLE_ALL`).
	Disable,
	/// Basic optimizations (`ORT_ENABLE_BASIC`).
	Level1,
	/// Extended optimizations (`ORT_ENABLE_EXTENDED`).
	Level2,
	/// All optimizations (`ORT_ENABLE_ALL`).
	Level3
}

impl From<GraphOptimizationLevel> for sys::GraphOptimizationLevel {
	fn from(val: GraphOptimizationLevel) -> Self {
		match val {
			GraphOptimizationLevel::Disable => sys::GraphOptimizationLevel::ORT_DISABLE_ALL,
			GraphOptimizationLevel::Level1 => sys::GraphOptimizationLevel::ORT_ENABLE_BASIC,
			GraphOptimizationLevel::Level2 => sys::GraphOptimizationLevel::ORT_ENABLE_EXTENDED,
			GraphOptimizationLevel::Level3 => sys::GraphOptimizationLevel::ORT_ENABLE_ALL
		}
	}
}
#[derive(Debug, Copy, Clone)]
pub enum AllocatorType {
Device,
Arena
}
impl From<AllocatorType> for sys::OrtAllocatorType {
fn from(val: AllocatorType) -> Self {
match val {
AllocatorType::Device => sys::OrtAllocatorType::OrtDeviceAllocator,
AllocatorType::Arena => sys::OrtAllocatorType::OrtArenaAllocator
}
}
}
#[derive(Debug, Copy, Clone)]
pub enum MemType {
CPUInput,
CPUOutput,
Default
}
impl MemType {
pub const CPU: MemType = MemType::CPUOutput;
}
impl From<MemType> for sys::OrtMemType {
fn from(val: MemType) -> Self {
match val {
MemType::CPUInput => sys::OrtMemType::OrtMemTypeCPUInput,
MemType::CPUOutput => sys::OrtMemType::OrtMemTypeCPUOutput,
MemType::Default => sys::OrtMemType::OrtMemTypeDefault
}
}
}
#[cfg(test)]
mod test {
	use super::*;

	/// Round-trips an ASCII string through a C string pointer.
	#[test]
	fn test_char_p_to_string() {
		let owned = ffi::CString::new("foo").unwrap();
		let converted = char_p_to_string(owned.as_ptr()).unwrap();
		assert_eq!(converted, "foo");
	}
}