1use std::{io, path::PathBuf};
4
5use thiserror::Error;
6
7use onnxruntime_sys as sys;
8
9use crate::{char_p_to_string, g_ort};
10
11pub type Result<T> = std::result::Result<T, OrtError>;
13
14#[non_exhaustive]
16#[derive(Error, Debug)]
17pub enum OrtError {
18 #[error("Failed to construct String")]
21 StringConversion(OrtApiError),
22 #[error("Failed to create environment: {0}")]
25 Environment(OrtApiError),
26 #[error("Failed to create session options: {0}")]
28 SessionOptions(OrtApiError),
29 #[error("Failed to create session: {0}")]
31 Session(OrtApiError),
32 #[error("Failed to get allocator: {0}")]
34 Allocator(OrtApiError),
35 #[error("Failed to get input or output count: {0}")]
37 InOutCount(OrtApiError),
38 #[error("Failed to get input name: {0}")]
40 InputName(OrtApiError),
41 #[error("Failed to get type info: {0}")]
43 GetTypeInfo(OrtApiError),
44 #[error("Failed to cast type info to tensor info: {0}")]
46 CastTypeInfoToTensorInfo(OrtApiError),
47 #[error("Failed to get tensor element type: {0}")]
49 TensorElementType(OrtApiError),
50 #[error("Failed to get dimensions count: {0}")]
52 GetDimensionsCount(OrtApiError),
53 #[error("Failed to get dimensions: {0}")]
55 GetDimensions(OrtApiError),
56 #[error("Failed to get dimensions: {0}")]
58 CreateCpuMemoryInfo(OrtApiError),
59 #[error("Failed to create tensor: {0}")]
61 CreateTensor(OrtApiError),
62 #[error("Failed to create tensor with data: {0}")]
64 CreateTensorWithData(OrtApiError),
65 #[error("Failed to fill string tensor: {0}")]
67 FillStringTensor(OrtApiError),
68 #[error("Failed to check if tensor: {0}")]
70 IsTensor(OrtApiError),
71 #[error("Failed to get tensor type and shape: {0}")]
73 GetTensorTypeAndShape(OrtApiError),
74 #[error("Failed to run: {0}")]
76 Run(OrtApiError),
77 #[error("Failed to get tensor data: {0}")]
79 GetTensorMutableData(OrtApiError),
80
81 #[error("Failed to download ONNX model: {0}")]
83 DownloadError(#[from] OrtDownloadError),
84
85 #[error("Dimensions do not match: {0:?}")]
87 NonMatchingDimensions(NonMatchingDimensionsError),
88 #[error("File {filename:?} does not exists")]
90 FileDoesNotExists {
91 filename: PathBuf,
93 },
94 #[error("Path {path:?} cannot be converted to UTF-8")]
96 NonUtf8Path {
97 path: PathBuf,
99 },
100 #[error("Failed to build CString when original contains null: {0}")]
102 CStringNulError(#[from] std::ffi::NulError),
103}
104
/// Describes how the inputs given for inference disagree with the inputs the
/// model declares.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum NonMatchingDimensionsError {
    /// The number of inputs supplied for inference differs from the number of
    /// inputs the model declares.
    #[error("Non-matching number of inputs: {inference_input_count:?} for input vs {model_input_count:?} for model (inputs: {inference_input:?}, model: {model_input:?})")]
    InputsCount {
        /// Number of inputs supplied for inference.
        inference_input_count: usize,
        /// Number of inputs declared by the model.
        model_input_count: usize,
        /// Dimensions of each input supplied for inference.
        inference_input: Vec<Vec<usize>>,
        /// Dimensions of each model input.
        // NOTE(review): `None` presumably marks a dimension the model leaves
        // unspecified/dynamic — confirm against where this is constructed.
        model_input: Vec<Vec<Option<u32>>>,
    },
    /// The dimensions of the inputs supplied for inference differ from the
    /// model's input dimensions.
    #[error("Different input lengths: Expected Input: {model_input:?} vs Received Input: {inference_input:?}")]
    InputsLength {
        /// Dimensions of each input supplied for inference.
        inference_input: Vec<Vec<usize>>,
        /// Dimensions of each model input (see note on `InputsCount::model_input`).
        model_input: Vec<Vec<Option<u32>>>,
    },
}
131
132#[non_exhaustive]
134#[derive(Error, Debug)]
135pub enum OrtApiError {
136 #[error("Error calling ONNX Runtime C function: {0}")]
138 Msg(String),
139 #[error("Error calling ONNX Runtime C function and failed to convert error message to UTF-8")]
141 IntoStringError(std::ffi::IntoStringError),
142}
143
144#[non_exhaustive]
146#[derive(Error, Debug)]
147pub enum OrtDownloadError {
148 #[error("Error downloading data to file: {0}")]
150 IoError(#[from] io::Error),
151 #[cfg(feature = "model-fetching")]
152 #[error("Error downloading data to file: {0}")]
154 UreqError(#[from] Box<ureq::Error>),
155 #[error("Error getting content-length")]
157 ContentLengthError,
158 #[error("Error copying data to file: expected {expected} length, received {io}")]
160 CopyError {
161 expected: u64,
163 io: u64,
165 },
166}
167
/// Thin wrapper around a raw `OrtStatus` pointer returned by an ONNX Runtime
/// C API call. A null pointer means the call succeeded; a non-null pointer
/// carries an error message.
pub struct OrtStatusWrapper(*const sys::OrtStatus);
172
173impl From<*const sys::OrtStatus> for OrtStatusWrapper {
174 fn from(status: *const sys::OrtStatus) -> Self {
175 OrtStatusWrapper(status)
176 }
177}
178
179impl From<OrtStatusWrapper> for std::result::Result<(), OrtApiError> {
180 fn from(status: OrtStatusWrapper) -> Self {
181 if status.0.is_null() {
182 Ok(())
183 } else {
184 let raw: *const i8 = unsafe { g_ort().GetErrorMessage.unwrap()(status.0) };
185 match char_p_to_string(raw) {
186 Ok(msg) => Err(OrtApiError::Msg(msg)),
187 Err(err) => match err {
188 OrtError::StringConversion(OrtApiError::IntoStringError(e)) => {
189 Err(OrtApiError::IntoStringError(e))
190 }
191 _ => unreachable!(),
192 },
193 }
194 }
195 }
196}
197
198pub(crate) fn status_to_result(
199 status: *const sys::OrtStatus,
200) -> std::result::Result<(), OrtApiError> {
201 let status_wrapper: OrtStatusWrapper = status.into();
202 status_wrapper.into()
203}
204
205pub(crate) unsafe fn call_ort<F>(mut f: F) -> std::result::Result<(), OrtApiError>
207where
208 F: FnMut(sys::OrtApi) -> *const sys::OrtStatus,
209{
210 status_to_result(f(g_ort()))
211}