#![warn(missing_docs)]
#![cfg_attr(
    feature = "model-fetching",
    doc = r##"
A model can be fetched directly from the [ONNX Model Zoo](https://github.com/onnx/models) using the
[`with_model_downloaded()`](session/struct.SessionBuilder.html#method.with_model_downloaded) method
(requires the `model-fetching` feature).
```no_run
# use std::error::Error;
# use onnxruntime::{environment::Environment, download::vision::ImageClassification, LoggingLevel, GraphOptimizationLevel};
# fn main() -> Result<(), Box<dyn Error>> {
# let environment = Environment::builder()
# .with_name("test")
# .with_log_level(LoggingLevel::Verbose)
# .build()?;
let mut session = environment
.new_session_builder()?
.with_optimization_level(GraphOptimizationLevel::Basic)?
.with_number_threads(1)?
.with_model_downloaded(ImageClassification::SqueezeNet)?;
# Ok(())
# }
```
See [`AvailableOnnxModel`](download/enum.AvailableOnnxModel.html) for the different models available
to download.
"##
)]
use std::sync::{atomic::AtomicPtr, Arc, Mutex};
use lazy_static::lazy_static;
use onnxruntime_sys as sys;
// `extern_system_fn!` takes `[attributes] [visibility] [unsafe] fn ...` — either a function
// item or a function-pointer type — and re-emits it with an explicit calling convention:
// `"stdcall"` on 32-bit Windows and `"C"` on every other target.
//
// The concrete ABI strings are spelled out per target (rather than using `"system"`) because
// `extern "system" fn` and `extern "C" fn` are distinct pointer types in Rust; presumably these
// must match the ABI the `onnxruntime_sys` bindings were generated with — confirm before changing.
//
// Each arm without a `$vis` precedes its `$vis:vis` twin, so inputs that start directly with
// `fn`/`unsafe fn` (e.g. type-position uses like `extern_system_fn!{ unsafe fn(u32) -> *const T }`)
// are matched by an arm that does not emit a visibility fragment. Keep this arm order.
#[cfg(all(target_os = "windows", target_arch = "x86"))]
macro_rules! extern_system_fn {
    // Plain `fn`, no visibility.
    ($(#[$attr:meta])* fn $($rest:tt)*) => {
        $(#[$attr])* extern "stdcall" fn $($rest)*
    };
    // `fn` with a visibility such as `pub(crate)`.
    ($(#[$attr:meta])* $visibility:vis fn $($rest:tt)*) => {
        $(#[$attr])* $visibility extern "stdcall" fn $($rest)*
    };
    // `unsafe fn`, no visibility.
    ($(#[$attr:meta])* unsafe fn $($rest:tt)*) => {
        $(#[$attr])* unsafe extern "stdcall" fn $($rest)*
    };
    // `unsafe fn` with a visibility.
    ($(#[$attr:meta])* $visibility:vis unsafe fn $($rest:tt)*) => {
        $(#[$attr])* $visibility unsafe extern "stdcall" fn $($rest)*
    };
}
#[cfg(not(all(target_os = "windows", target_arch = "x86")))]
macro_rules! extern_system_fn {
    // Plain `fn`, no visibility.
    ($(#[$attr:meta])* fn $($rest:tt)*) => {
        $(#[$attr])* extern "C" fn $($rest)*
    };
    // `fn` with a visibility such as `pub(crate)`.
    ($(#[$attr:meta])* $visibility:vis fn $($rest:tt)*) => {
        $(#[$attr])* $visibility extern "C" fn $($rest)*
    };
    // `unsafe fn`, no visibility.
    ($(#[$attr:meta])* unsafe fn $($rest:tt)*) => {
        $(#[$attr])* unsafe extern "C" fn $($rest)*
    };
    // `unsafe fn` with a visibility.
    ($(#[$attr:meta])* $visibility:vis unsafe fn $($rest:tt)*) => {
        $(#[$attr])* $visibility unsafe extern "C" fn $($rest)*
    };
}
pub mod download;
pub mod environment;
pub mod error;
mod memory;
pub mod session;
pub mod tensor;
pub use error::{OrtApiError, OrtError, Result};
use sys::OnnxEnumInt;
pub use ndarray;
lazy_static! {
    // Process-wide pointer to ONNX Runtime's C API function table (`OrtApi`), resolved on
    // first access and read through `g_ort()`.
    static ref G_ORT_API: Arc<Mutex<AtomicPtr<sys::OrtApi>>> = {
        // SAFETY: argument-less FFI call; the returned pointer is null-checked right below.
        let base: *const sys::OrtApiBase = unsafe { sys::OrtGetApiBase() };
        assert_ne!(base, std::ptr::null());
        // SAFETY: `base` was checked to be non-null just above; `GetApi` is copied out by value.
        let get_api: extern_system_fn!{ unsafe fn(u32) -> *const onnxruntime_sys::OrtApi } =
            unsafe { (*base).GetApi }.expect("OrtApiBase::GetApi function pointer is null");
        // SAFETY: `get_api` comes from ONNX Runtime's own `OrtApiBase` and only takes the
        // requested API version number.
        let api: *const sys::OrtApi = unsafe { get_api(sys::ORT_API_VERSION) };
        // Per the ONNX Runtime C API docs, `GetApi` returns null when the loaded library does
        // not support the requested version. Fail here with an actionable message instead of
        // storing null and tripping an opaque assertion on the first `g_ort()` call.
        assert!(
            !api.is_null(),
            "ONNX Runtime returned no API table for ORT_API_VERSION {}; the loaded onnxruntime \
             library is probably older than the one these bindings were generated for",
            sys::ORT_API_VERSION
        );
        Arc::new(Mutex::new(AtomicPtr::new(api as *mut sys::OrtApi)))
    };
}
/// Returns a by-value copy of the ONNX Runtime C API function table held in `G_ORT_API`.
///
/// # Panics
///
/// Panics if the mutex guarding `G_ORT_API` is poisoned, or if the stored pointer is null.
fn g_ort() -> sys::OrtApi {
    let mut guard = G_ORT_API
        .lock()
        .expect("Failed to acquire lock: another thread panicked?");
    // The guard grants exclusive access, so the raw pointer can be read through `get_mut`
    // without an atomic load.
    let api: *mut sys::OrtApi = *guard.get_mut();
    assert_ne!(api, std::ptr::null_mut());
    // SAFETY: `api` is non-null (asserted above) and presumably still points to the table that
    // ONNX Runtime returned from `GetApi` when `G_ORT_API` was initialized.
    unsafe { *api }
}
/// Copies a NUL-terminated C string (e.g. one returned by ONNX Runtime) into an owned `String`.
///
/// `raw` must point to a NUL-terminated string that stays valid for the duration of the call.
///
/// # Errors
///
/// Returns `OrtError::StringConversion(OrtApiError::IntoStringError(_))` if the bytes are not
/// valid UTF-8.
///
/// # Panics
///
/// Panics if `raw` is null.
fn char_p_to_string(raw: *const std::os::raw::c_char) -> Result<String> {
    // `CStr::from_ptr` on a null pointer is undefined behaviour; turn it into a clear panic.
    assert!(!raw.is_null(), "char_p_to_string() called with a null pointer");
    // SAFETY: `raw` is non-null (checked above) and, per this function's contract, points to a
    // NUL-terminated string that outlives this call. Taking `c_char` (instead of `i8`) keeps this
    // compiling on targets where `c_char` is unsigned.
    let c_string = unsafe { std::ffi::CStr::from_ptr(raw) }.to_owned();
    c_string
        .into_string()
        .map_err(OrtApiError::IntoStringError)
        .map_err(OrtError::StringConversion)
}
/// Glue that forwards ONNX Runtime's log messages to the `tracing` ecosystem.
mod onnxruntime {
    use std::ffi::CStr;
    use std::os::raw::c_char;

    use tracing::{debug, error, info, span, trace, warn, Level};

    use onnxruntime_sys as sys;

    /// Source location attached to an ONNX Runtime log message.
    ///
    /// Parsed from a `"<file>:<line> <function>"` string; any missing piece is replaced by an
    /// `<unknown …>` placeholder instead of failing.
    #[derive(Debug)]
    struct CodeLocation<'a> {
        file: &'a str,
        line_number: &'a str,
        function: &'a str,
    }

    impl<'a> From<&'a str> for CodeLocation<'a> {
        fn from(code_location: &'a str) -> Self {
            // Split only at the first space: everything after it is the function, so names
            // that themselves contain spaces are kept whole instead of cut to their first word.
            let mut splitter = code_location.splitn(2, ' ');
            let file_and_line_number = splitter.next().unwrap_or("<unknown file:line>");
            let function = splitter.next().unwrap_or("<unknown module>");
            let mut file_and_line_number_splitter = file_and_line_number.split(':');
            let file = file_and_line_number_splitter
                .next()
                .unwrap_or("<unknown file>");
            let line_number = file_and_line_number_splitter
                .next()
                .unwrap_or("<unknown line number>");
            CodeLocation {
                file,
                line_number,
                function,
            }
        }
    }

    /// Borrows a C string received from ONNX Runtime, mapping a null pointer to `None`.
    ///
    /// # Safety
    ///
    /// `ptr` must be null or point to a NUL-terminated string that stays valid for `'a`.
    unsafe fn c_str_or_none<'a>(ptr: *const c_char) -> Option<&'a CStr> {
        if ptr.is_null() {
            None
        } else {
            Some(CStr::from_ptr(ptr))
        }
    }

    /// Decodes an optional C string as UTF-8, returning `default` when it is absent or invalid.
    fn to_str_or<'a>(s: Option<&'a CStr>, default: &'a str) -> &'a str {
        s.and_then(|s| s.to_str().ok()).unwrap_or(default)
    }

    extern_system_fn! {
        /// Logging callback registered with ONNX Runtime; re-emits each message as a `tracing`
        /// event inside an `onnxruntime` span carrying the category, log id and code location.
        ///
        /// Severities are emitted one `tracing` level lower than their ONNX Runtime names:
        /// VERBOSE -> TRACE, INFO -> DEBUG, WARNING -> INFO, ERROR -> WARN, FATAL -> ERROR.
        ///
        /// Null pointers and non-UTF-8 strings are replaced by placeholders rather than causing
        /// a panic: this function is called from C, and a panic would have to unwind across the
        /// FFI boundary (undefined behaviour, or a process abort on newer Rust versions).
        pub(crate) fn custom_logger(
            _params: *mut std::ffi::c_void,
            severity: sys::OrtLoggingLevel,
            category: *const c_char,
            logid: *const c_char,
            code_location: *const c_char,
            message: *const c_char,
        ) {
            // SAFETY: ONNX Runtime is expected to pass NUL-terminated strings that stay valid for
            // the duration of this call (assumed from the C API contract); null pointers are
            // mapped to `None` and never reach `CStr::from_ptr`.
            let (category, logid, code_location, message) = unsafe {
                (
                    c_str_or_none(category),
                    c_str_or_none(logid),
                    c_str_or_none(code_location),
                    c_str_or_none(message),
                )
            };

            let code_location: CodeLocation = to_str_or(code_location, "unknown").into();
            let span = span!(
                Level::TRACE,
                "onnxruntime",
                category = to_str_or(category, "<unknown>"),
                file = code_location.file,
                line_number = code_location.line_number,
                function = code_location.function,
                logid = to_str_or(logid, "<unknown>"),
            );
            let _enter = span.enter();

            // `{:?}` on the `CStr` keeps the quoted/escaped rendering; a null message is shown
            // as a quoted placeholder in the same style.
            let message: &dyn std::fmt::Debug = match &message {
                Some(text) => text,
                None => &"<null message>",
            };

            // `tracing` event macros need a compile-time level, so dispatch on the ONNX Runtime
            // severity directly, one macro per arm.
            match severity {
                sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => trace!("{:?}", message),
                sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => debug!("{:?}", message),
                sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING => info!("{:?}", message),
                sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR => warn!("{:?}", message),
                sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => error!("{:?}", message),
            }
        }
    }
}
/// Log severity levels understood by ONNX Runtime, e.g. passed to the environment builder's
/// `with_log_level()`.
///
/// Each variant carries the discriminant of the matching `sys::OrtLoggingLevel` value and
/// converts into it with [`From`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(not(windows), repr(u32))]
#[cfg_attr(windows, repr(i32))]
pub enum LoggingLevel {
    /// Verbose messages (`ORT_LOGGING_LEVEL_VERBOSE`).
    Verbose = sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE as OnnxEnumInt,
    /// Informational messages (`ORT_LOGGING_LEVEL_INFO`).
    Info = sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO as OnnxEnumInt,
    /// Warnings (`ORT_LOGGING_LEVEL_WARNING`).
    Warning = sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING as OnnxEnumInt,
    /// Errors (`ORT_LOGGING_LEVEL_ERROR`).
    Error = sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR as OnnxEnumInt,
    /// Fatal errors (`ORT_LOGGING_LEVEL_FATAL`).
    Fatal = sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL as OnnxEnumInt,
}

impl From<LoggingLevel> for sys::OrtLoggingLevel {
    fn from(val: LoggingLevel) -> Self {
        match val {
            LoggingLevel::Verbose => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
            LoggingLevel::Info => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
            LoggingLevel::Warning => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING,
            LoggingLevel::Error => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR,
            LoggingLevel::Fatal => sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL,
        }
    }
}
/// Graph optimization level, e.g. passed to the session builder's `with_optimization_level()`.
///
/// Each variant carries the discriminant of the matching `sys::GraphOptimizationLevel` value
/// and converts into it with [`From`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(not(windows), repr(u32))]
#[cfg_attr(windows, repr(i32))]
pub enum GraphOptimizationLevel {
    /// Disable all graph optimizations (`ORT_DISABLE_ALL`).
    DisableAll = sys::GraphOptimizationLevel::ORT_DISABLE_ALL as OnnxEnumInt,
    /// Basic optimizations (`ORT_ENABLE_BASIC`).
    Basic = sys::GraphOptimizationLevel::ORT_ENABLE_BASIC as OnnxEnumInt,
    /// Extended optimizations (`ORT_ENABLE_EXTENDED`).
    Extended = sys::GraphOptimizationLevel::ORT_ENABLE_EXTENDED as OnnxEnumInt,
    /// All available optimizations (`ORT_ENABLE_ALL`).
    All = sys::GraphOptimizationLevel::ORT_ENABLE_ALL as OnnxEnumInt,
}

impl From<GraphOptimizationLevel> for sys::GraphOptimizationLevel {
    fn from(val: GraphOptimizationLevel) -> Self {
        use GraphOptimizationLevel::*;
        match val {
            DisableAll => sys::GraphOptimizationLevel::ORT_DISABLE_ALL,
            Basic => sys::GraphOptimizationLevel::ORT_ENABLE_BASIC,
            Extended => sys::GraphOptimizationLevel::ORT_ENABLE_EXTENDED,
            All => sys::GraphOptimizationLevel::ORT_ENABLE_ALL,
        }
    }
}
/// Element type of an ONNX tensor.
///
/// Each variant carries the discriminant of the matching `sys::ONNXTensorElementDataType`
/// value and converts into it with [`From`]. See [`TypeToTensorElementDataType`] for the Rust
/// types that map to each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(not(windows), repr(u32))]
#[cfg_attr(windows, repr(i32))]
pub enum TensorElementDataType {
    /// 32-bit floating point (`f32`).
    Float = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt,
    /// Unsigned 8-bit integer (`u8`).
    Uint8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt,
    /// Signed 8-bit integer (`i8`).
    Int8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as OnnxEnumInt,
    /// Unsigned 16-bit integer (`u16`).
    Uint16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt,
    /// Signed 16-bit integer (`i16`).
    Int16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt,
    /// Signed 32-bit integer (`i32`).
    Int32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt,
    /// Signed 64-bit integer (`i64`).
    Int64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt,
    /// UTF-8 string (any [`Utf8Data`] type, such as `String` or `&str`).
    String = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt,
    /// 64-bit floating point (`f64`).
    Double = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt,
    /// Unsigned 32-bit integer (`u32`).
    Uint32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt,
    /// Unsigned 64-bit integer (`u64`).
    Uint64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt,
}

impl From<TensorElementDataType> for sys::ONNXTensorElementDataType {
    fn from(val: TensorElementDataType) -> Self {
        use TensorElementDataType::*;
        match val {
            Float => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
            Uint8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
            Int8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
            Uint16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
            Int16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
            Int32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
            Int64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
            String => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
            Double => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
            Uint32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
            Uint64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
        }
    }
}
/// Maps a Rust element type to the ONNX tensor element type used to represent it.
///
/// Implemented in this crate for `f32`, `u8`, `i8`, `u16`, `i16`, `i32`, `i64`, `f64`, `u32`,
/// `u64`, and for every [`Utf8Data`] type.
pub trait TypeToTensorElementDataType {
    /// Returns the ONNX tensor element type that corresponds to `Self`.
    fn tensor_element_data_type() -> TensorElementDataType;

    /// Returns the UTF-8 bytes of `self` for string element types, or `None` otherwise.
    fn try_utf8_bytes(&self) -> Option<&[u8]>;
}
macro_rules! impl_type_trait {
($type_:ty, $variant:ident) => {
impl TypeToTensorElementDataType for $type_ {
fn tensor_element_data_type() -> TensorElementDataType {
TensorElementDataType::$variant
}
fn try_utf8_bytes(&self) -> Option<&[u8]> {
None
}
}
};
}
impl_type_trait!(f32, Float);
impl_type_trait!(u8, Uint8);
impl_type_trait!(i8, Int8);
impl_type_trait!(u16, Uint16);
impl_type_trait!(i16, Int16);
impl_type_trait!(i32, Int32);
impl_type_trait!(i64, Int64);
impl_type_trait!(f64, Double);
impl_type_trait!(u32, Uint32);
impl_type_trait!(u64, Uint64);
/// String-like element types that map to the ONNX `String` tensor element type.
///
/// Every implementor automatically implements [`TypeToTensorElementDataType`], reporting
/// [`TensorElementDataType::String`] and exposing its bytes through `try_utf8_bytes()`.
pub trait Utf8Data {
    /// Returns the string's contents as UTF-8 encoded bytes.
    fn utf8_bytes(&self) -> &[u8];
}

impl Utf8Data for String {
    fn utf8_bytes(&self) -> &[u8] {
        self.as_bytes()
    }
}

impl<'a> Utf8Data for &'a str {
    fn utf8_bytes(&self) -> &[u8] {
        self.as_bytes()
    }
}

/// Every [`Utf8Data`] type is a string tensor element.
impl<T: Utf8Data> TypeToTensorElementDataType for T {
    fn tensor_element_data_type() -> TensorElementDataType {
        TensorElementDataType::String
    }

    fn try_utf8_bytes(&self) -> Option<&[u8]> {
        Some(self.utf8_bytes())
    }
}
/// Allocator kinds, mirroring `sys::OrtAllocatorType`.
///
/// Each variant carries the discriminant of the matching `sys::OrtAllocatorType` value and
/// converts into it with [`From`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum AllocatorType {
    /// Device allocator (`OrtDeviceAllocator`).
    Device = sys::OrtAllocatorType::OrtDeviceAllocator as i32,
    /// Arena allocator (`OrtArenaAllocator`).
    Arena = sys::OrtAllocatorType::OrtArenaAllocator as i32,
}

impl From<AllocatorType> for sys::OrtAllocatorType {
    fn from(val: AllocatorType) -> Self {
        use AllocatorType::*;
        match val {
            Device => sys::OrtAllocatorType::OrtDeviceAllocator,
            Arena => sys::OrtAllocatorType::OrtArenaAllocator,
        }
    }
}
/// Memory types, mirroring `sys::OrtMemType`.
///
/// Only the default memory type is currently exposed; it converts into
/// `sys::OrtMemType::OrtMemTypeDefault` with [`From`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum MemType {
    /// Default memory type (`OrtMemTypeDefault`).
    Default = sys::OrtMemType::OrtMemTypeDefault as i32,
}

impl From<MemType> for sys::OrtMemType {
    fn from(val: MemType) -> Self {
        use MemType::*;
        match val {
            Default => sys::OrtMemType::OrtMemTypeDefault,
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_char_p_to_string() {
        let s = std::ffi::CString::new("foo").unwrap();
        let ptr = s.as_c_str().as_ptr();
        assert_eq!("foo", char_p_to_string(ptr).unwrap());
    }

    #[test]
    fn test_char_p_to_string_empty() {
        let s = std::ffi::CString::new("").unwrap();
        assert_eq!("", char_p_to_string(s.as_ptr()).unwrap());
    }

    #[test]
    fn test_char_p_to_string_invalid_utf8() {
        // 0xFF can never occur in well-formed UTF-8, so conversion must fail.
        let s = std::ffi::CString::new(vec![b'a', 0xFF, b'b']).unwrap();
        match char_p_to_string(s.as_ptr()) {
            Err(OrtError::StringConversion(OrtApiError::IntoStringError(_))) => {}
            other => panic!("expected a UTF-8 conversion error, got {:?}", other),
        }
    }
}