use std::{io, path::PathBuf, string};
use thiserror::Error;
use super::{char_p_to_string, ort, sys, tensor::TensorElementDataType};
pub type OrtResult<T> = std::result::Result<T, OrtError>;
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum OrtError {
#[error("Failed to construct Rust String")]
FfiStringConversion(OrtApiError),
#[error("Failed to create ONNX Runtime environment: {0}")]
CreateEnvironment(OrtApiError),
#[error("Failed to create ONNX Runtime session options: {0}")]
CreateSessionOptions(OrtApiError),
#[error("Failed to create ONNX Runtime session: {0}")]
CreateSession(OrtApiError),
#[error("Failed to get ONNX allocator: {0}")]
GetAllocator(OrtApiError),
#[error("Failed to get input or output count: {0}")]
GetInOutCount(OrtApiError),
#[error("Failed to get input name: {0}")]
GetInputName(OrtApiError),
#[error("Failed to get type info: {0}")]
GetTypeInfo(OrtApiError),
#[error("Failed to cast type info to tensor info: {0}")]
CastTypeInfoToTensorInfo(OrtApiError),
#[error("Failed to get tensor element type: {0}")]
GetTensorElementType(OrtApiError),
#[error("Failed to get dimensions count: {0}")]
GetDimensionsCount(OrtApiError),
#[error("Failed to get dimensions: {0}")]
GetDimensions(OrtApiError),
#[error("Failed to get string tensor length: {0}")]
GetStringTensorDataLength(OrtApiError),
#[error("Failed to get tensor element count: {0}")]
GetTensorShapeElementCount(OrtApiError),
#[error("Failed to create CPU memory info: {0}")]
CreateCpuMemoryInfo(OrtApiError),
#[error("Failed to create tensor: {0}")]
CreateTensor(OrtApiError),
#[error("Failed to create tensor with data: {0}")]
CreateTensorWithData(OrtApiError),
#[error("Failed to fill string tensor: {0}")]
FillStringTensor(OrtApiError),
#[error("Failed to check if tensor is a tensor or was properly initialized: {0}")]
FailedTensorCheck(OrtApiError),
#[error("Failed to get tensor type and shape: {0}")]
GetTensorTypeAndShape(OrtApiError),
#[error("Failed to run inference on model: {0}")]
SessionRun(OrtApiError),
#[error("Failed to get tensor data: {0}")]
GetTensorMutableData(OrtApiError),
#[error("Failed to get tensor string data: {0}")]
GetStringTensorContent(OrtApiError),
#[error("Data was not UTF-8: {0}")]
StringFromUtf8Error(#[from] string::FromUtf8Error),
#[error("Failed to download ONNX model: {0}")]
DownloadError(#[from] OrtDownloadError),
#[error("Dimensions do not match: {0:?}")]
NonMatchingDimensions(NonMatchingDimensionsError),
#[error("File `{filename:?}` does not exist")]
FileDoesNotExist {
filename: PathBuf
},
#[error("Path `{path:?}` cannot be converted to UTF-8")]
NonUtf8Path {
path: PathBuf
},
#[error("Failed to build CString when original contains null: {0}")]
FfiStringNull(#[from] std::ffi::NulError),
#[cfg(all(windows, feature = "profiling"))]
#[error("Failed to build CString when original contains null: {0}")]
WideFfiStringNull(#[from] widestring::error::ContainsNul<u16>),
#[error("{0} pointer should be null")]
PointerShouldBeNull(String),
#[error("{0} pointer should not be null")]
PointerShouldNotBeNull(String),
#[error("Undefined tensor element type")]
UndefinedTensorElementType,
#[error("Failed to retrieve model metadata: {0}")]
GetModelMetadata(OrtApiError),
#[error("Data type mismatch: was {:?}, tried to convert to {:?}", actual, requested)]
DataTypeMismatch {
actual: TensorElementDataType,
requested: TensorElementDataType
}
}
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum NonMatchingDimensionsError {
#[error(
"Non-matching number of inputs: {inference_input_count:?} provided vs {model_input_count:?} for model (inputs: {inference_input:?}, model: {model_input:?})"
)]
InputsCount {
inference_input_count: usize,
model_input_count: usize,
inference_input: Vec<Vec<usize>>,
model_input: Vec<Vec<Option<u32>>>
},
#[error("Different input lengths; expected input: {model_input:?}, received input: {inference_input:?}")]
InputsLength {
inference_input: Vec<Vec<usize>>,
model_input: Vec<Vec<Option<u32>>>
}
}
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum OrtApiError {
#[error("{0}")]
Msg(String),
#[error("an error occurred, but ort failed to convert the error message to UTF-8")]
IntoStringError(std::ffi::IntoStringError)
}
/// Errors that can occur while obtaining a model file.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum OrtDownloadError {
	/// A filesystem I/O operation failed.
	#[error("Error reading file: {0}")]
	IoError(#[from] io::Error),
	/// The HTTP request failed (only with the `fetch-models` feature).
	#[cfg(feature = "fetch-models")]
	#[error("Error downloading to file: {0}")]
	FetchError(#[from] Box<ureq::Error>),
	/// The HTTP response did not provide a usable `Content-Length`.
	#[error("Error getting Content-Length from HTTP GET")]
	ContentLengthError,
	/// The number of bytes written differs from the expected length.
	#[error("Error copying data to file: expected {expected} length, but got {io}")]
	CopyError {
		/// Expected number of bytes.
		expected: u64,
		/// Number of bytes actually copied.
		io: u64
	}
}
/// Owns a raw `OrtStatus` pointer returned by an ONNX Runtime API call.
///
/// A null pointer means success. The pointer is passed to `ReleaseStatus` when the
/// wrapper is dropped.
pub struct OrtStatusWrapper(*mut sys::OrtStatus);

impl From<*mut sys::OrtStatus> for OrtStatusWrapper {
	/// Takes ownership of `raw_status` without inspecting it.
	fn from(raw_status: *mut sys::OrtStatus) -> Self {
		Self(raw_status)
	}
}
/// Checks that `ptr` is null.
///
/// # Errors
/// Returns [`OrtError::PointerShouldBeNull`] carrying `name` when `ptr` is non-null.
pub(crate) fn assert_null_pointer<T>(ptr: *const T, name: &str) -> OrtResult<()> {
	if ptr.is_null() {
		Ok(())
	} else {
		Err(OrtError::PointerShouldBeNull(name.to_owned()))
	}
}
/// Checks that `ptr` is non-null.
///
/// # Errors
/// Returns [`OrtError::PointerShouldNotBeNull`] carrying `name` when `ptr` is null.
pub(crate) fn assert_non_null_pointer<T>(ptr: *const T, name: &str) -> OrtResult<()> {
	if ptr.is_null() {
		// Null is the failure case here, so report "<name> pointer should not be null".
		// (Previously this returned `PointerShouldBeNull`, whose message states the opposite.)
		Err(OrtError::PointerShouldNotBeNull(name.to_owned()))
	} else {
		Ok(())
	}
}
impl From<OrtStatusWrapper> for std::result::Result<(), OrtApiError> {
	/// Interprets a status: null is success; otherwise its error message becomes an
	/// [`OrtApiError`].
	///
	/// The wrapper is consumed and dropped on return, which releases the status.
	fn from(status: OrtStatusWrapper) -> Self {
		if status.0.is_null() {
			return Ok(());
		}

		// SAFETY: `status.0` is a non-null status owned by `status`, which stays alive
		// (and therefore unreleased) until the message has been copied below.
		let message_ptr: *const std::os::raw::c_char = unsafe { ort().GetErrorMessage.unwrap()(status.0) };

		match char_p_to_string(message_ptr) {
			Ok(text) => Err(OrtApiError::Msg(text)),
			Err(OrtError::FfiStringConversion(OrtApiError::IntoStringError(utf8_err))) => Err(OrtApiError::IntoStringError(utf8_err)),
			// `char_p_to_string` is only expected to fail with a UTF-8 conversion error.
			Err(_) => unreachable!()
		}
	}
}
impl Drop for OrtStatusWrapper {
	/// Hands the owned status pointer (possibly null) back to ONNX Runtime's `ReleaseStatus`.
	fn drop(&mut self) {
		// SAFETY: the wrapper owns `self.0` and this is the only place it is released.
		unsafe {
			let release_status = ort().ReleaseStatus.unwrap();
			release_status(self.0);
		}
	}
}
pub(crate) fn status_to_result(status: *mut sys::OrtStatus) -> std::result::Result<(), OrtApiError> {
let status_wrapper: OrtStatusWrapper = status.into();
status_wrapper.into()
}