onnxruntime_ng/
error.rs

1//! Module containing error definitions.
2
3use std::{io, path::PathBuf};
4
5use thiserror::Error;
6
7use onnxruntime_sys_ng as sys;
8
9use crate::{char_p_to_string, g_ort};
10
11/// Type alias for the `Result`
12pub type Result<T> = std::result::Result<T, OrtError>;
13
14/// Error type centralizing all possible errors
15#[non_exhaustive]
16#[derive(Error, Debug)]
17pub enum OrtError {
18    /// The C API can message to the caller using a C `char *` which needs to be converted
19    /// to Rust's `String`. This operation can fail.
20    #[error("Failed to construct String")]
21    StringConversion(OrtApiError),
22    // FIXME: Move these to another enum (they are C API calls errors)
23    /// An error occurred when creating an ONNX environment
24    #[error("Failed to create environment: {0}")]
25    Environment(OrtApiError),
26    /// Error occurred when creating an ONNX session options
27    #[error("Failed to create session options: {0}")]
28    SessionOptions(OrtApiError),
29    /// Error occurred when creating an ONNX session
30    #[error("Failed to create session: {0}")]
31    Session(OrtApiError),
32    /// Error occurred when creating an ONNX allocator
33    #[error("Failed to get allocator: {0}")]
34    Allocator(OrtApiError),
35    /// Error occurred when counting ONNX input or output count
36    #[error("Failed to get input or output count: {0}")]
37    InOutCount(OrtApiError),
38    /// Error occurred when getting ONNX input name
39    #[error("Failed to get input name: {0}")]
40    InputName(OrtApiError),
41    /// Error occurred when getting ONNX type information
42    #[error("Failed to get type info: {0}")]
43    GetTypeInfo(OrtApiError),
44    /// Error occurred when casting ONNX type information to tensor information
45    #[error("Failed to cast type info to tensor info: {0}")]
46    CastTypeInfoToTensorInfo(OrtApiError),
47    /// Error occurred when getting tensor elements type
48    #[error("Failed to get tensor element type: {0}")]
49    TensorElementType(OrtApiError),
50    /// Error occurred when getting ONNX dimensions count
51    #[error("Failed to get dimensions count: {0}")]
52    GetDimensionsCount(OrtApiError),
53    /// Error occurred when getting ONNX dimensions
54    #[error("Failed to get dimensions: {0}")]
55    GetDimensions(OrtApiError),
56    /// Error occurred when creating CPU memory information
57    #[error("Failed to get dimensions: {0}")]
58    CreateCpuMemoryInfo(OrtApiError),
59    /// Error occurred when creating ONNX tensor
60    #[error("Failed to create tensor: {0}")]
61    CreateTensor(OrtApiError),
62    /// Error occurred when creating ONNX tensor with specific data
63    #[error("Failed to create tensor with data: {0}")]
64    CreateTensorWithData(OrtApiError),
65    /// Error occurred when filling a tensor with string data
66    #[error("Failed to fill string tensor: {0}")]
67    FillStringTensor(OrtApiError),
68    /// Error occurred when checking if ONNX tensor was properly initialized
69    #[error("Failed to check if tensor: {0}")]
70    IsTensor(OrtApiError),
71    /// Error occurred when getting tensor type and shape
72    #[error("Failed to get tensor type and shape: {0}")]
73    GetTensorTypeAndShape(OrtApiError),
74    /// Error occurred when ONNX inference operation was called
75    #[error("Failed to run: {0}")]
76    Run(OrtApiError),
77    /// Error occurred when extracting data from an ONNX tensor into an C array to be used as an `ndarray::ArrayView`
78    #[error("Failed to get tensor data: {0}")]
79    GetTensorMutableData(OrtApiError),
80
81    /// Error occurred when downloading a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models)
82    #[error("Failed to download ONNX model: {0}")]
83    DownloadError(#[from] OrtDownloadError),
84
85    /// Dimensions of input data and ONNX model loaded from file do not match
86    #[error("Dimensions do not match: {0:?}")]
87    NonMatchingDimensions(NonMatchingDimensionsError),
88    /// File does not exists
89    #[error("File {filename:?} does not exists")]
90    FileDoesNotExists {
91        /// Path which does not exists
92        filename: PathBuf,
93    },
94    /// Path is an invalid UTF-8
95    #[error("Path {path:?} cannot be converted to UTF-8")]
96    NonUtf8Path {
97        /// Path with invalid UTF-8
98        path: PathBuf,
99    },
100    /// Attempt to build a Rust `CString` from a null pointer
101    #[error("Failed to build CString when original contains null: {0}")]
102    CStringNulError(#[from] std::ffi::NulError),
103    #[error("{0} pointer should be null")]
104    /// Ort Pointer should have been null
105    PointerShouldBeNull(String),
106    /// Ort pointer should not have been null
107    #[error("{0} pointer should not be null")]
108    PointerShouldNotBeNull(String),
109    /// ONNX Model has invalid dimensions
110    #[error("Invalid dimensions")]
111    InvalidDimensions,
112    /// The runtime type was undefined
113    #[error("Undefined Tensor Element Type")]
114    UndefinedTensorElementType,
115    /// Error occurred when checking if ONNX tensor was properly initialized
116    #[error("Failed to check if tensor")]
117    IsTensorCheck,
118}
119
/// Error used when dimensions of input (from model and from inference call)
/// do not match (as they should).
///
/// Carried by [`OrtError::NonMatchingDimensions`]; that variant formats it with `{:?}`,
/// so the `Debug` output (field order as declared here) is what users see.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum NonMatchingDimensionsError {
    /// Number of inputs from model does not match number of inputs from inference call
    #[error("Non-matching number of inputs: {inference_input_count:?} for input vs {model_input_count:?} for model (inputs: {inference_input:?}, model: {model_input:?})")]
    InputsCount {
        /// Number of input dimensions used by inference call
        inference_input_count: usize,
        /// Number of input dimensions defined in model
        model_input_count: usize,
        /// Input dimensions used by inference call
        inference_input: Vec<Vec<usize>>,
        /// Input dimensions defined in model; `None` presumably marks a dynamic
        /// (unspecified) dimension — TODO confirm against the model-loading code
        model_input: Vec<Vec<Option<u32>>>,
    },
    /// Inputs length from model does not match the expected input from inference call
    #[error("Different input lengths: Expected Input: {model_input:?} vs Received Input: {inference_input:?}")]
    InputsLength {
        /// Input dimensions used by inference call
        inference_input: Vec<Vec<usize>>,
        /// Input dimensions defined in model (see `InputsCount::model_input` for the
        /// meaning of `None`)
        model_input: Vec<Vec<Option<u32>>>,
    },
}
146
147/// Error details when ONNX C API fail
148#[non_exhaustive]
149#[derive(Error, Debug)]
150pub enum OrtApiError {
151    /// Details as reported by the ONNX C API in case of error
152    #[error("Error calling ONNX Runtime C function: {0}")]
153    Msg(String),
154    /// Details as reported by the ONNX C API in case of error cannot be converted to UTF-8
155    #[error("Error calling ONNX Runtime C function and failed to convert error message to UTF-8")]
156    IntoStringError(std::ffi::IntoStringError),
157}
158
/// Error from downloading pre-trained model from the [ONNX Model Zoo](https://github.com/onnx/models).
///
/// Converted into [`OrtError::DownloadError`] through that variant's `#[from]` attribute.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum OrtDownloadError {
    /// Generic input/output error
    #[error("Error downloading data to file: {0}")]
    IoError(#[from] io::Error),
    // This variant only exists when the crate is built with the `model-fetching` feature.
    #[cfg(feature = "model-fetching")]
    /// Download error by ureq (boxed, so the enum does not grow to the size of `ureq::Error`)
    #[error("Error downloading data to file: {0}")]
    UreqError(#[from] Box<ureq::Error>),
    /// Error getting content-length from an HTTP GET request
    #[error("Error getting content-length")]
    ContentLengthError,
    /// Mismatch between amount of downloaded and expected bytes
    #[error("Error copying data to file: expected {expected} length, received {io}")]
    CopyError {
        /// Expected amount of bytes to download
        expected: u64,
        /// Number of bytes read from network and written to file
        io: u64,
    },
}
182
183/// Wrapper type around a ONNX C API's `OrtStatus` pointer
184///
185/// This wrapper exists to facilitate conversion from C raw pointers to Rust error types
186pub struct OrtStatusWrapper(*const sys::OrtStatus);
187
188impl From<*const sys::OrtStatus> for OrtStatusWrapper {
189    fn from(status: *const sys::OrtStatus) -> Self {
190        OrtStatusWrapper(status)
191    }
192}
193
194pub(crate) fn assert_null_pointer<T>(ptr: *const T, name: &str) -> Result<()> {
195    ptr.is_null()
196        .then(|| ())
197        .ok_or_else(|| OrtError::PointerShouldBeNull(name.to_owned()))
198}
199
200pub(crate) fn assert_not_null_pointer<T>(ptr: *const T, name: &str) -> Result<()> {
201    (!ptr.is_null())
202        .then(|| ())
203        .ok_or_else(|| OrtError::PointerShouldBeNull(name.to_owned()))
204}
205
206impl From<OrtStatusWrapper> for std::result::Result<(), OrtApiError> {
207    fn from(status: OrtStatusWrapper) -> Self {
208        if status.0.is_null() {
209            Ok(())
210        } else {
211            let raw: *const i8 = unsafe { g_ort().GetErrorMessage.unwrap()(status.0) };
212            match char_p_to_string(raw) {
213                Ok(msg) => Err(OrtApiError::Msg(msg)),
214                Err(err) => match err {
215                    OrtError::StringConversion(OrtApiError::IntoStringError(e)) => {
216                        Err(OrtApiError::IntoStringError(e))
217                    }
218                    _ => unreachable!(),
219                },
220            }
221        }
222    }
223}
224
225pub(crate) fn status_to_result(
226    status: *const sys::OrtStatus,
227) -> std::result::Result<(), OrtApiError> {
228    let status_wrapper: OrtStatusWrapper = status.into();
229    status_wrapper.into()
230}
231
232/// A wrapper around a function on OrtApi that maps the status code into [OrtApiError]
233pub(crate) unsafe fn call_ort<F>(mut f: F) -> std::result::Result<(), OrtApiError>
234where
235    F: FnMut(sys::OrtApi) -> *const sys::OrtStatus,
236{
237    status_to_result(f(g_ort()))
238}