#![warn(missing_docs)]
#![cfg_attr(
    feature = "model-fetching",
    doc = r##"
A model can be fetched directly from the [ONNX Model Zoo](https://github.com/onnx/models) using
the [`with_model_downloaded()`](session/struct.SessionBuilder.html#method.with_model_downloaded) method
(requires the `model-fetching` feature).
```no_run
# use std::error::Error;
# use onnxruntime::{environment::Environment, download::vision::ImageClassification, LoggingLevel, GraphOptimizationLevel};
# fn main() -> Result<(), Box<dyn Error>> {
# let environment = Environment::builder()
# .with_name("test")
# .with_log_level(LoggingLevel::Verbose)
# .build()?;
let mut session = environment
.new_session_builder()?
.with_optimization_level(GraphOptimizationLevel::Basic)?
.with_number_threads(1)?
.with_model_downloaded(ImageClassification::SqueezeNet)?;
# Ok(())
# }
```
See [`AvailableOnnxModel`](download/enum.AvailableOnnxModel.html) for the different models available
to download.
"##
)]
use std::sync::{atomic::AtomicPtr, Arc, Mutex};
use lazy_static::lazy_static;
use onnxruntime_sys as sys;
pub mod download;
pub mod environment;
pub mod error;
mod memory;
pub mod session;
pub mod tensor;
pub use error::{OrtApiError, OrtError, Result};
lazy_static! {
    /// Process-wide pointer to the ONNX Runtime C API function table, resolved on first access.
    ///
    /// # Panics
    ///
    /// Initialization panics if `OrtGetApiBase()` returns null, if the base table has no
    /// `GetApi` entry, or if `GetApi(sys::ORT_API_VERSION)` returns null (which the ONNX Runtime
    /// C API documents as "this API version is not supported by the loaded library").
    static ref G_ORT_API: Arc<Mutex<AtomicPtr<sys::OrtApi>>> = {
        // SAFETY: `OrtGetApiBase` is an argument-less entry point exported by the ONNX Runtime
        // library; its result is null-checked before use.
        let base: *const sys::OrtApiBase = unsafe { sys::OrtGetApiBase() };
        assert_ne!(base, std::ptr::null(), "OrtGetApiBase() returned a null pointer");

        // SAFETY: `base` is non-null (checked above); reading the `GetApi` field only copies an
        // `Option<fn>` out of the table.
        let get_api: unsafe extern "C" fn(u32) -> *const sys::OrtApi = unsafe { (*base).GetApi }
            .expect("OrtApiBase::GetApi function pointer is null");

        // SAFETY: `get_api` was taken straight from the ONNX Runtime base table.
        let api: *const sys::OrtApi = unsafe { get_api(sys::ORT_API_VERSION) };
        // Fail here with a clear message instead of storing a null table that `g_ort` would
        // later reject with an anonymous assertion.
        assert_ne!(
            api,
            std::ptr::null(),
            "the loaded ONNX Runtime library does not support API version {}",
            sys::ORT_API_VERSION
        );

        Arc::new(Mutex::new(AtomicPtr::new(api as *mut sys::OrtApi)))
    };
}
/// Returns a by-value copy of the ONNX Runtime C API function table stored in `G_ORT_API`.
///
/// # Panics
///
/// Panics if the `G_ORT_API` mutex is poisoned, or if the stored table pointer is null.
fn g_ort() -> sys::OrtApi {
    // The guard stays alive until the end of the function, so the null check and the copy
    // below both happen while the lock is held.
    let mut guard = G_ORT_API
        .lock()
        .expect("Failed to acquire lock: another thread panicked?");
    let table: *mut sys::OrtApi = *guard.get_mut();

    assert_ne!(table, std::ptr::null_mut());

    // SAFETY: `table` is non-null (checked above) and is the pointer stored by `G_ORT_API`'s
    // initializer from the result of `GetApi`.
    unsafe { *table }
}
/// Copies a NUL-terminated C string into an owned Rust `String`.
///
/// The bytes are copied: this function does NOT take ownership of `raw`, so whoever allocated
/// the string (e.g. an ONNX Runtime allocator) remains responsible for freeing it. Taking
/// ownership with `CString::from_raw` would be undefined behaviour for memory that was not
/// produced by `CString::into_raw`.
///
/// # Errors
///
/// Returns `OrtError::StringConversion(OrtApiError::IntoStringError(_))` when the bytes are not
/// valid UTF-8.
///
/// # Panics
///
/// Panics if `raw` is null.
fn char_p_to_string(raw: *const i8) -> Result<String> {
    assert_ne!(raw, std::ptr::null(), "char_p_to_string received a null pointer");

    // SAFETY: `raw` is non-null (checked above). The caller must pass a pointer to a valid
    // NUL-terminated string that stays alive for the duration of this call. `CStr::from_ptr`
    // only borrows that memory; `to_owned` copies it into a Rust-allocated `CString`.
    let c_string = unsafe { std::ffi::CStr::from_ptr(raw as *const std::os::raw::c_char) }.to_owned();

    c_string
        .into_string()
        .map_err(OrtApiError::IntoStringError)
        .map_err(OrtError::StringConversion)
}
mod onnxruntime {
    use std::borrow::Cow;
    use std::ffi::CStr;
    use std::os::raw::c_char;

    use tracing::{debug, error, info, span, trace, warn, Level};

    /// Source location attached to an ONNX Runtime log message.
    ///
    /// Parsed from a `"file:line function"` string. A missing function or line number is
    /// replaced by an `<unknown ...>` placeholder.
    #[derive(Debug)]
    struct CodeLocation<'a> {
        file: &'a str,
        line_number: &'a str,
        function: &'a str,
    }

    impl<'a> From<&'a str> for CodeLocation<'a> {
        fn from(code_location: &'a str) -> Self {
            // Split on the FIRST space only. A function name that itself contains spaces is
            // then kept whole instead of being truncated after its first word.
            let (file_and_line_number, function) = match code_location.find(' ') {
                Some(idx) => (&code_location[..idx], &code_location[idx + 1..]),
                None => (code_location, "<unknown module>"),
            };
            // Split on the LAST colon. A file name that contains ':' then does not swallow
            // the line number.
            let (file, line_number) = match file_and_line_number.rfind(':') {
                Some(idx) => (
                    &file_and_line_number[..idx],
                    &file_and_line_number[idx + 1..],
                ),
                None => (file_and_line_number, "<unknown line number>"),
            };
            CodeLocation {
                file,
                line_number,
                function,
            }
        }
    }

    /// Reads a NUL-terminated C string as text, or returns `fallback` when `ptr` is null.
    ///
    /// Invalid UTF-8 sequences are replaced with U+FFFD rather than discarding the whole string.
    ///
    /// # Safety
    ///
    /// `ptr` must be null or point to a NUL-terminated string that stays valid and unmodified
    /// for the lifetime `'a`.
    unsafe fn c_str_or<'a>(ptr: *const i8, fallback: &'a str) -> Cow<'a, str> {
        if ptr.is_null() {
            Cow::Borrowed(fallback)
        } else {
            // SAFETY: `ptr` is non-null; the caller guarantees NUL termination and validity.
            unsafe { CStr::from_ptr(ptr as *const c_char) }.to_string_lossy()
        }
    }

    /// Logging callback handed to ONNX Runtime that forwards its messages to `tracing`.
    ///
    /// Each message is emitted inside an `onnxruntime` span. The span records the category,
    /// the parsed source location (file, line number, function) and the logger id.
    ///
    /// Null string arguments are replaced with placeholders instead of being asserted on. This
    /// keeps the callback from panicking while it is being called from C code.
    pub(crate) extern "C" fn custom_logger(
        _params: *mut std::ffi::c_void,
        severity: u32,
        category: *const i8,
        logid: *const i8,
        code_location: *const i8,
        message: *const i8,
    ) {
        // SAFETY: assumes ONNX Runtime passes either null or NUL-terminated strings that stay
        // valid for the duration of this callback.
        let (category, logid, code_location, message) = unsafe {
            (
                c_str_or(category, "<unknown>"),
                c_str_or(logid, "<unknown>"),
                c_str_or(code_location, "unknown"),
                c_str_or(message, "<null message>"),
            )
        };
        let location = CodeLocation::from(&*code_location);

        let span = span!(
            Level::TRACE,
            "onnxruntime",
            category = &*category,
            file = location.file,
            line_number = location.line_number,
            function = location.function,
            logid = &*logid,
        );
        let _enter = span.enter();

        // The message is logged via `Display`, so it is not wrapped in quotes or escaped.
        // NOTE(review): assuming `OrtLoggingLevel` numbers VERBOSE..FATAL as 0..4, each ORT
        // severity is emitted one tracing level below its name (INFO -> DEBUG,
        // WARNING -> INFO, ERROR -> WARN, FATAL -> ERROR). This mapping is kept unchanged;
        // confirm the downshift is intended.
        match severity {
            0 => trace!("{}", message),
            1 => debug!("{}", message),
            2 => info!("{}", message),
            3 => warn!("{}", message),
            _ => error!("{}", message),
        }
    }
}
/// Logging level used by the ONNX Runtime environment
/// (see `Environment::builder().with_log_level(..)`).
///
/// Each variant's discriminant is the corresponding `OrtLoggingLevel` constant from
/// `onnxruntime_sys`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum LoggingLevel {
    /// Verbose logging (`ORT_LOGGING_LEVEL_VERBOSE`).
    Verbose = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_VERBOSE,
    /// Informational logging (`ORT_LOGGING_LEVEL_INFO`).
    Info = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_INFO,
    /// Warnings only (`ORT_LOGGING_LEVEL_WARNING`).
    Warning = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_WARNING,
    /// Errors only (`ORT_LOGGING_LEVEL_ERROR`).
    Error = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_ERROR,
    /// Fatal errors only (`ORT_LOGGING_LEVEL_FATAL`).
    Fatal = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_FATAL,
}
/// Graph optimization level requested when building a session
/// (see `SessionBuilder::with_optimization_level`).
///
/// Each variant's discriminant is the corresponding `GraphOptimizationLevel` constant from
/// `onnxruntime_sys`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum GraphOptimizationLevel {
    /// Disable all graph optimizations (`ORT_DISABLE_ALL`).
    DisableAll = sys::GraphOptimizationLevel_ORT_DISABLE_ALL,
    /// Enable basic optimizations (`ORT_ENABLE_BASIC`).
    Basic = sys::GraphOptimizationLevel_ORT_ENABLE_BASIC,
    /// Enable extended optimizations (`ORT_ENABLE_EXTENDED`).
    Extended = sys::GraphOptimizationLevel_ORT_ENABLE_EXTENDED,
    /// Enable all optimizations (`ORT_ENABLE_ALL`).
    All = sys::GraphOptimizationLevel_ORT_ENABLE_ALL,
}
/// Element type of an ONNX tensor.
///
/// Each variant's discriminant is the corresponding `ONNXTensorElementDataType` constant from
/// `onnxruntime_sys`. The Rust scalar type shown for each variant is the one mapped to it
/// through `TypeToTensorElementDataType`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum TensorElementDataType {
    /// 32-bit floating point (`f32`).
    Float = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
    /// Unsigned 8-bit integer (`u8`).
    Uint8 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
    /// Signed 8-bit integer (`i8`).
    Int8 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
    /// Unsigned 16-bit integer (`u16`).
    Uint16 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
    /// Signed 16-bit integer (`i16`).
    Int16 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
    /// Signed 32-bit integer (`i32`).
    Int32 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
    /// Signed 64-bit integer (`i64`).
    Int64 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
    /// 64-bit floating point (`f64`).
    Double = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
    /// Unsigned 32-bit integer (`u32`).
    Uint32 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
    /// Unsigned 64-bit integer (`u64`).
    Uint64 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
}
/// Associates a Rust scalar type with the ONNX tensor element type it is stored as.
///
/// This crate implements it for `f32`, `f64`, `u8`, `u16`, `u32`, `u64`, `i8`, `i16`, `i32`
/// and `i64`.
pub trait TypeToTensorElementDataType {
    /// Returns the [`TensorElementDataType`] that corresponds to `Self`.
    fn tensor_element_data_type() -> TensorElementDataType;
}
/// Implements [`TypeToTensorElementDataType`] for `$rust_type`, returning the
/// `TensorElementDataType::$onnx_variant` variant.
macro_rules! impl_type_trait {
    ($rust_type:ty, $onnx_variant:ident) => {
        impl TypeToTensorElementDataType for $rust_type {
            fn tensor_element_data_type() -> TensorElementDataType {
                TensorElementDataType::$onnx_variant
            }
        }
    };
}

// Floating-point scalars.
impl_type_trait!(f32, Float);
impl_type_trait!(f64, Double);

// Unsigned integer scalars.
impl_type_trait!(u8, Uint8);
impl_type_trait!(u16, Uint16);
impl_type_trait!(u32, Uint32);
impl_type_trait!(u64, Uint64);

// Signed integer scalars.
impl_type_trait!(i8, Int8);
impl_type_trait!(i16, Int16);
impl_type_trait!(i32, Int32);
impl_type_trait!(i64, Int64);
/// Kind of allocator.
///
/// Each variant's discriminant is the corresponding `OrtAllocatorType` constant from
/// `onnxruntime_sys`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum AllocatorType {
    /// Device allocator (`OrtDeviceAllocator`).
    Device = sys::OrtAllocatorType_OrtDeviceAllocator,
    /// Arena allocator (`OrtArenaAllocator`).
    Arena = sys::OrtAllocatorType_OrtArenaAllocator,
}
/// Memory type.
///
/// Each variant's discriminant is the corresponding `OrtMemType` constant from
/// `onnxruntime_sys`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum MemType {
    /// Default memory type (`OrtMemTypeDefault`).
    Default = sys::OrtMemType_OrtMemTypeDefault,
}