use std::{io, path::PathBuf};
use thiserror::Error;
use onnxruntime_sys as sys;
use crate::{char_p_to_string, g_ort};
pub type Result<T> = std::result::Result<T, OrtError>;
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum OrtError {
#[error("Failed to construct String")]
StringConversion(OrtApiError),
#[error("Failed to create environment: {0}")]
Environment(OrtApiError),
#[error("Failed to create session options: {0}")]
SessionOptions(OrtApiError),
#[error("Failed to create session: {0}")]
Session(OrtApiError),
#[error("Failed to get allocator: {0}")]
Allocator(OrtApiError),
#[error("Failed to get input or output count: {0}")]
InOutCount(OrtApiError),
#[error("Failed to get input name: {0}")]
InputName(OrtApiError),
#[error("Failed to get type info: {0}")]
GetTypeInfo(OrtApiError),
#[error("Failed to cast type info to tensor info: {0}")]
CastTypeInfoToTensorInfo(OrtApiError),
#[error("Failed to get tensor element type: {0}")]
TensorElementType(OrtApiError),
#[error("Failed to get dimensions count: {0}")]
GetDimensionsCount(OrtApiError),
#[error("Failed to get dimensions: {0}")]
GetDimensions(OrtApiError),
#[error("Failed to get dimensions: {0}")]
CreateCpuMemoryInfo(OrtApiError),
#[error("Failed to create tensor: {0}")]
CreateTensor(OrtApiError),
#[error("Failed to create tensor with data: {0}")]
CreateTensorWithData(OrtApiError),
#[error("Failed to fill string tensor: {0}")]
FillStringTensor(OrtApiError),
#[error("Failed to check if tensor: {0}")]
IsTensor(OrtApiError),
#[error("Failed to get tensor type and shape: {0}")]
GetTensorTypeAndShape(OrtApiError),
#[error("Failed to run: {0}")]
Run(OrtApiError),
#[error("Failed to get tensor data: {0}")]
GetTensorMutableData(OrtApiError),
#[error("Failed to download ONNX model: {0}")]
DownloadError(#[from] OrtDownloadError),
#[error("Dimensions do not match: {0:?}")]
NonMatchingDimensions(NonMatchingDimensionsError),
#[error("File {filename:?} does not exists")]
FileDoesNotExists {
filename: PathBuf,
},
#[error("Path {path:?} cannot be converted to UTF-8")]
NonUtf8Path {
path: PathBuf,
},
#[error("Failed to build CString when original contains null: {0}")]
CStringNulError(#[from] std::ffi::NulError),
}
/// Details about why the shapes supplied for inference do not match the model's inputs.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum NonMatchingDimensionsError {
/// The number of inference inputs differs from the number of model inputs.
#[error("Non-matching number of inputs: {inference_input_count:?} for input vs {model_input_count:?} for model (inputs: {inference_input:?}, model: {model_input:?})")]
InputsCount {
/// Number of inputs supplied for inference.
inference_input_count: usize,
/// Number of inputs declared by the model.
model_input_count: usize,
/// Shapes of the supplied inputs.
inference_input: Vec<Vec<usize>>,
/// Shapes declared by the model; `None` presumably marks a dynamic dimension (TODO confirm).
model_input: Vec<Vec<Option<u32>>>,
},
/// The input counts agree but the input shapes differ.
#[error("Different input lengths: Expected Input: {model_input:?} vs Received Input: {inference_input:?}")]
InputsLength {
/// Shapes of the supplied inputs.
inference_input: Vec<Vec<usize>>,
/// Shapes declared by the model; `None` presumably marks a dynamic dimension (TODO confirm).
model_input: Vec<Vec<Option<u32>>>,
},
}
/// Error reported by an ONNX Runtime C API call.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum OrtApiError {
/// The call failed, and its error message was successfully converted to a `String`.
#[error("Error calling ONNX Runtime C function: {0}")]
Msg(String),
/// The call failed, but its error message could not be converted to UTF-8.
#[error("Error calling ONNX Runtime C function and failed to convert error message to UTF-8")]
IntoStringError(std::ffi::IntoStringError),
}
/// Errors that can occur while downloading a model to a file.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum OrtDownloadError {
/// An I/O error occurred.
#[error("Error downloading data to file: {0}")]
IoError(#[from] io::Error),
/// The HTTP client (`ureq`) reported an error; only present with the `model-fetching` feature.
/// Boxed, presumably to keep the enum small (TODO confirm).
#[cfg(feature = "model-fetching")]
#[error("Error downloading data to file: {0}")]
UreqError(#[from] Box<ureq::Error>),
/// The content length of the download could not be determined.
#[error("Error getting content-length")]
ContentLengthError,
/// The number of bytes copied to the file differs from the expected length.
#[error("Error copying data to file: expected {expected} length, received {io}")]
CopyError {
/// Expected number of bytes.
expected: u64,
/// Number of bytes actually copied.
io: u64,
},
}
/// Newtype around a raw `OrtStatus` pointer returned by the ONNX Runtime C API.
///
/// A null pointer stands for success; a non-null pointer describes a failure. Wrapping
/// does not inspect or dereference the pointer.
pub struct OrtStatusWrapper(*const sys::OrtStatus);

impl From<*const sys::OrtStatus> for OrtStatusWrapper {
    /// Stores `raw_status` as-is.
    fn from(raw_status: *const sys::OrtStatus) -> Self {
        Self(raw_status)
    }
}
impl From<OrtStatusWrapper> for std::result::Result<(), OrtApiError> {
fn from(status: OrtStatusWrapper) -> Self {
if status.0.is_null() {
Ok(())
} else {
let raw: *const i8 = unsafe { g_ort().GetErrorMessage.unwrap()(status.0) };
match char_p_to_string(raw) {
Ok(msg) => Err(OrtApiError::Msg(msg)),
Err(err) => match err {
OrtError::StringConversion(OrtApiError::IntoStringError(e)) => {
Err(OrtApiError::IntoStringError(e))
}
_ => unreachable!(),
},
}
}
}
}
pub(crate) fn status_to_result(
status: *const sys::OrtStatus,
) -> std::result::Result<(), OrtApiError> {
let status_wrapper: OrtStatusWrapper = status.into();
status_wrapper.into()
}
/// Invokes `f` with the global `OrtApi` table and converts the status it returns.
///
/// # Safety
///
/// The pointer returned by `f` must be either null or a valid `OrtStatus` produced by
/// ONNX Runtime. It is passed to `status_to_result`, which reads the error message from
/// a non-null status through the C API.
pub(crate) unsafe fn call_ort<F>(mut f: F) -> std::result::Result<(), OrtApiError>
where
    F: FnMut(sys::OrtApi) -> *const sys::OrtStatus,
{
    let api = g_ort();
    let status = f(api);
    status_to_result(status)
}