1use std::{io, path::PathBuf};
4
5use thiserror::Error;
6
7use onnxruntime_sys_ng as sys;
8
9use crate::{char_p_to_string, g_ort};
10
11pub type Result<T> = std::result::Result<T, OrtError>;
13
14#[non_exhaustive]
16#[derive(Error, Debug)]
17pub enum OrtError {
18 #[error("Failed to construct String")]
21 StringConversion(OrtApiError),
22 #[error("Failed to create environment: {0}")]
25 Environment(OrtApiError),
26 #[error("Failed to create session options: {0}")]
28 SessionOptions(OrtApiError),
29 #[error("Failed to create session: {0}")]
31 Session(OrtApiError),
32 #[error("Failed to get allocator: {0}")]
34 Allocator(OrtApiError),
35 #[error("Failed to get input or output count: {0}")]
37 InOutCount(OrtApiError),
38 #[error("Failed to get input name: {0}")]
40 InputName(OrtApiError),
41 #[error("Failed to get type info: {0}")]
43 GetTypeInfo(OrtApiError),
44 #[error("Failed to cast type info to tensor info: {0}")]
46 CastTypeInfoToTensorInfo(OrtApiError),
47 #[error("Failed to get tensor element type: {0}")]
49 TensorElementType(OrtApiError),
50 #[error("Failed to get dimensions count: {0}")]
52 GetDimensionsCount(OrtApiError),
53 #[error("Failed to get dimensions: {0}")]
55 GetDimensions(OrtApiError),
56 #[error("Failed to get dimensions: {0}")]
58 CreateCpuMemoryInfo(OrtApiError),
59 #[error("Failed to create tensor: {0}")]
61 CreateTensor(OrtApiError),
62 #[error("Failed to create tensor with data: {0}")]
64 CreateTensorWithData(OrtApiError),
65 #[error("Failed to fill string tensor: {0}")]
67 FillStringTensor(OrtApiError),
68 #[error("Failed to check if tensor: {0}")]
70 IsTensor(OrtApiError),
71 #[error("Failed to get tensor type and shape: {0}")]
73 GetTensorTypeAndShape(OrtApiError),
74 #[error("Failed to run: {0}")]
76 Run(OrtApiError),
77 #[error("Failed to get tensor data: {0}")]
79 GetTensorMutableData(OrtApiError),
80
81 #[error("Failed to download ONNX model: {0}")]
83 DownloadError(#[from] OrtDownloadError),
84
85 #[error("Dimensions do not match: {0:?}")]
87 NonMatchingDimensions(NonMatchingDimensionsError),
88 #[error("File {filename:?} does not exists")]
90 FileDoesNotExists {
91 filename: PathBuf,
93 },
94 #[error("Path {path:?} cannot be converted to UTF-8")]
96 NonUtf8Path {
97 path: PathBuf,
99 },
100 #[error("Failed to build CString when original contains null: {0}")]
102 CStringNulError(#[from] std::ffi::NulError),
103 #[error("{0} pointer should be null")]
104 PointerShouldBeNull(String),
106 #[error("{0} pointer should not be null")]
108 PointerShouldNotBeNull(String),
109 #[error("Invalid dimensions")]
111 InvalidDimensions,
112 #[error("Undefined Tensor Element Type")]
114 UndefinedTensorElementType,
115 #[error("Failed to check if tensor")]
117 IsTensorCheck,
118}
119
/// Describes how the inputs given to inference disagree with the inputs the model declares.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum NonMatchingDimensionsError {
    /// The number of inputs given to inference differs from the model's input count.
    #[error("Non-matching number of inputs: {inference_input_count:?} for input vs {model_input_count:?} for model (inputs: {inference_input:?}, model: {model_input:?})")]
    InputsCount {
        /// Number of inputs given to inference.
        inference_input_count: usize,
        /// Number of inputs declared by the model.
        model_input_count: usize,
        /// Dimensions of the inputs given to inference
        /// (presumably one inner `Vec` per input — confirm at the construction site).
        inference_input: Vec<Vec<usize>>,
        /// Dimensions declared by the model; `None` presumably marks a dynamic or
        /// unknown dimension — TODO confirm.
        model_input: Vec<Vec<Option<u32>>>,
    },
    /// The dimensions of the inputs given to inference differ from what the model expects.
    #[error("Different input lengths: Expected Input: {model_input:?} vs Received Input: {inference_input:?}")]
    InputsLength {
        /// Dimensions of the inputs given to inference.
        inference_input: Vec<Vec<usize>>,
        /// Dimensions declared by the model (see `InputsCount::model_input`).
        model_input: Vec<Vec<Option<u32>>>,
    },
}
146
147#[non_exhaustive]
149#[derive(Error, Debug)]
150pub enum OrtApiError {
151 #[error("Error calling ONNX Runtime C function: {0}")]
153 Msg(String),
154 #[error("Error calling ONNX Runtime C function and failed to convert error message to UTF-8")]
156 IntoStringError(std::ffi::IntoStringError),
157}
158
/// Errors that can occur while downloading an ONNX model.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum OrtDownloadError {
    /// An I/O error occurred during the download (wraps [`std::io::Error`]).
    #[error("Error downloading data to file: {0}")]
    IoError(#[from] io::Error),
    /// The `ureq` HTTP request failed.
    ///
    /// Only present when the `model-fetching` feature is enabled.
    #[cfg(feature = "model-fetching")]
    #[error("Error downloading data to file: {0}")]
    UreqError(#[from] Box<ureq::Error>),
    /// The content length of the download could not be obtained.
    #[error("Error getting content-length")]
    ContentLengthError,
    /// The number of bytes copied to the file differs from the expected length.
    #[error("Error copying data to file: expected {expected} length, received {io}")]
    CopyError {
        /// Expected length.
        expected: u64,
        /// Length actually received.
        io: u64,
    },
}
182
/// Thin wrapper around a raw `OrtStatus` pointer returned by an ONNX Runtime C call.
///
/// A null pointer represents success; a non-null pointer describes a failure. The
/// wrapper presumably exists so that a `From` conversion into
/// `Result<(), OrtApiError>` can be written (orphan rules forbid implementing it for
/// the foreign raw-pointer type directly).
pub struct OrtStatusWrapper(*const sys::OrtStatus);
187
188impl From<*const sys::OrtStatus> for OrtStatusWrapper {
189 fn from(status: *const sys::OrtStatus) -> Self {
190 OrtStatusWrapper(status)
191 }
192}
193
194pub(crate) fn assert_null_pointer<T>(ptr: *const T, name: &str) -> Result<()> {
195 ptr.is_null()
196 .then(|| ())
197 .ok_or_else(|| OrtError::PointerShouldBeNull(name.to_owned()))
198}
199
200pub(crate) fn assert_not_null_pointer<T>(ptr: *const T, name: &str) -> Result<()> {
201 (!ptr.is_null())
202 .then(|| ())
203 .ok_or_else(|| OrtError::PointerShouldBeNull(name.to_owned()))
204}
205
206impl From<OrtStatusWrapper> for std::result::Result<(), OrtApiError> {
207 fn from(status: OrtStatusWrapper) -> Self {
208 if status.0.is_null() {
209 Ok(())
210 } else {
211 let raw: *const i8 = unsafe { g_ort().GetErrorMessage.unwrap()(status.0) };
212 match char_p_to_string(raw) {
213 Ok(msg) => Err(OrtApiError::Msg(msg)),
214 Err(err) => match err {
215 OrtError::StringConversion(OrtApiError::IntoStringError(e)) => {
216 Err(OrtApiError::IntoStringError(e))
217 }
218 _ => unreachable!(),
219 },
220 }
221 }
222 }
223}
224
225pub(crate) fn status_to_result(
226 status: *const sys::OrtStatus,
227) -> std::result::Result<(), OrtApiError> {
228 let status_wrapper: OrtStatusWrapper = status.into();
229 status_wrapper.into()
230}
231
232pub(crate) unsafe fn call_ort<F>(mut f: F) -> std::result::Result<(), OrtApiError>
234where
235 F: FnMut(sys::OrtApi) -> *const sys::OrtStatus,
236{
237 status_to_result(f(g_ort()))
238}