onnxruntime/
error.rs

1//! Module containing error definitions.
2
3use std::{io, path::PathBuf};
4
5use thiserror::Error;
6
7use onnxruntime_sys as sys;
8
9use crate::{char_p_to_string, g_ort};
10
11/// Type alias for the `Result`
12pub type Result<T> = std::result::Result<T, OrtError>;
13
14/// Error type centralizing all possible errors
15#[non_exhaustive]
16#[derive(Error, Debug)]
17pub enum OrtError {
18    /// The C API can message to the caller using a C `char *` which needs to be converted
19    /// to Rust's `String`. This operation can fail.
20    #[error("Failed to construct String")]
21    StringConversion(OrtApiError),
22    // FIXME: Move these to another enum (they are C API calls errors)
23    /// An error occurred when creating an ONNX environment
24    #[error("Failed to create environment: {0}")]
25    Environment(OrtApiError),
26    /// Error occurred when creating an ONNX session options
27    #[error("Failed to create session options: {0}")]
28    SessionOptions(OrtApiError),
29    /// Error occurred when creating an ONNX session
30    #[error("Failed to create session: {0}")]
31    Session(OrtApiError),
32    /// Error occurred when creating an ONNX allocator
33    #[error("Failed to get allocator: {0}")]
34    Allocator(OrtApiError),
35    /// Error occurred when counting ONNX input or output count
36    #[error("Failed to get input or output count: {0}")]
37    InOutCount(OrtApiError),
38    /// Error occurred when getting ONNX input name
39    #[error("Failed to get input name: {0}")]
40    InputName(OrtApiError),
41    /// Error occurred when getting ONNX type information
42    #[error("Failed to get type info: {0}")]
43    GetTypeInfo(OrtApiError),
44    /// Error occurred when casting ONNX type information to tensor information
45    #[error("Failed to cast type info to tensor info: {0}")]
46    CastTypeInfoToTensorInfo(OrtApiError),
47    /// Error occurred when getting tensor elements type
48    #[error("Failed to get tensor element type: {0}")]
49    TensorElementType(OrtApiError),
50    /// Error occurred when getting ONNX dimensions count
51    #[error("Failed to get dimensions count: {0}")]
52    GetDimensionsCount(OrtApiError),
53    /// Error occurred when getting ONNX dimensions
54    #[error("Failed to get dimensions: {0}")]
55    GetDimensions(OrtApiError),
56    /// Error occurred when creating CPU memory information
57    #[error("Failed to get dimensions: {0}")]
58    CreateCpuMemoryInfo(OrtApiError),
59    /// Error occurred when creating ONNX tensor
60    #[error("Failed to create tensor: {0}")]
61    CreateTensor(OrtApiError),
62    /// Error occurred when creating ONNX tensor with specific data
63    #[error("Failed to create tensor with data: {0}")]
64    CreateTensorWithData(OrtApiError),
65    /// Error occurred when filling a tensor with string data
66    #[error("Failed to fill string tensor: {0}")]
67    FillStringTensor(OrtApiError),
68    /// Error occurred when checking if ONNX tensor was properly initialized
69    #[error("Failed to check if tensor: {0}")]
70    IsTensor(OrtApiError),
71    /// Error occurred when getting tensor type and shape
72    #[error("Failed to get tensor type and shape: {0}")]
73    GetTensorTypeAndShape(OrtApiError),
74    /// Error occurred when ONNX inference operation was called
75    #[error("Failed to run: {0}")]
76    Run(OrtApiError),
77    /// Error occurred when extracting data from an ONNX tensor into an C array to be used as an `ndarray::ArrayView`
78    #[error("Failed to get tensor data: {0}")]
79    GetTensorMutableData(OrtApiError),
80
81    /// Error occurred when downloading a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models)
82    #[error("Failed to download ONNX model: {0}")]
83    DownloadError(#[from] OrtDownloadError),
84
85    /// Dimensions of input data and ONNX model loaded from file do not match
86    #[error("Dimensions do not match: {0:?}")]
87    NonMatchingDimensions(NonMatchingDimensionsError),
88    /// File does not exists
89    #[error("File {filename:?} does not exists")]
90    FileDoesNotExists {
91        /// Path which does not exists
92        filename: PathBuf,
93    },
94    /// Path is an invalid UTF-8
95    #[error("Path {path:?} cannot be converted to UTF-8")]
96    NonUtf8Path {
97        /// Path with invalid UTF-8
98        path: PathBuf,
99    },
100    /// Attempt to build a Rust `CString` from a null pointer
101    #[error("Failed to build CString when original contains null: {0}")]
102    CStringNulError(#[from] std::ffi::NulError),
103}
104
105/// Error used when dimensions of input (from model and from inference call)
106/// do not match (as they should).
107#[non_exhaustive]
108#[derive(Error, Debug)]
109pub enum NonMatchingDimensionsError {
110    /// Number of inputs from model does not match number of inputs from inference call
111    #[error("Non-matching number of inputs: {inference_input_count:?} for input vs {model_input_count:?} for model (inputs: {inference_input:?}, model: {model_input:?})")]
112    InputsCount {
113        /// Number of input dimensions used by inference call
114        inference_input_count: usize,
115        /// Number of input dimensions defined in model
116        model_input_count: usize,
117        /// Input dimensions used by inference call
118        inference_input: Vec<Vec<usize>>,
119        /// Input dimensions defined in model
120        model_input: Vec<Vec<Option<u32>>>,
121    },
122    /// Inputs length from model does not match the expected input from inference call
123    #[error("Different input lengths: Expected Input: {model_input:?} vs Received Input: {inference_input:?}")]
124    InputsLength {
125        /// Input dimensions used by inference call
126        inference_input: Vec<Vec<usize>>,
127        /// Input dimensions defined in model
128        model_input: Vec<Vec<Option<u32>>>,
129    },
130}
131
132/// Error details when ONNX C API fail
133#[non_exhaustive]
134#[derive(Error, Debug)]
135pub enum OrtApiError {
136    /// Details as reported by the ONNX C API in case of error
137    #[error("Error calling ONNX Runtime C function: {0}")]
138    Msg(String),
139    /// Details as reported by the ONNX C API in case of error cannot be converted to UTF-8
140    #[error("Error calling ONNX Runtime C function and failed to convert error message to UTF-8")]
141    IntoStringError(std::ffi::IntoStringError),
142}
143
144/// Error from downloading pre-trained model from the [ONNX Model Zoo](https://github.com/onnx/models).
145#[non_exhaustive]
146#[derive(Error, Debug)]
147pub enum OrtDownloadError {
148    /// Generic input/output error
149    #[error("Error downloading data to file: {0}")]
150    IoError(#[from] io::Error),
151    #[cfg(feature = "model-fetching")]
152    /// Download error by ureq
153    #[error("Error downloading data to file: {0}")]
154    UreqError(#[from] Box<ureq::Error>),
155    /// Error getting content-length from an HTTP GET request
156    #[error("Error getting content-length")]
157    ContentLengthError,
158    /// Mismatch between amount of downloaded and expected bytes
159    #[error("Error copying data to file: expected {expected} length, received {io}")]
160    CopyError {
161        /// Expected amount of bytes to download
162        expected: u64,
163        /// Number of bytes read from network and written to file
164        io: u64,
165    },
166}
167
168/// Wrapper type around a ONNX C API's `OrtStatus` pointer
169///
170/// This wrapper exists to facilitate conversion from C raw pointers to Rust error types
171pub struct OrtStatusWrapper(*const sys::OrtStatus);
172
173impl From<*const sys::OrtStatus> for OrtStatusWrapper {
174    fn from(status: *const sys::OrtStatus) -> Self {
175        OrtStatusWrapper(status)
176    }
177}
178
179impl From<OrtStatusWrapper> for std::result::Result<(), OrtApiError> {
180    fn from(status: OrtStatusWrapper) -> Self {
181        if status.0.is_null() {
182            Ok(())
183        } else {
184            let raw: *const i8 = unsafe { g_ort().GetErrorMessage.unwrap()(status.0) };
185            match char_p_to_string(raw) {
186                Ok(msg) => Err(OrtApiError::Msg(msg)),
187                Err(err) => match err {
188                    OrtError::StringConversion(OrtApiError::IntoStringError(e)) => {
189                        Err(OrtApiError::IntoStringError(e))
190                    }
191                    _ => unreachable!(),
192                },
193            }
194        }
195    }
196}
197
198pub(crate) fn status_to_result(
199    status: *const sys::OrtStatus,
200) -> std::result::Result<(), OrtApiError> {
201    let status_wrapper: OrtStatusWrapper = status.into();
202    status_wrapper.into()
203}
204
205/// A wrapper around a function on OrtApi that maps the status code into [OrtApiError]
206pub(crate) unsafe fn call_ort<F>(mut f: F) -> std::result::Result<(), OrtApiError>
207where
208    F: FnMut(sys::OrtApi) -> *const sys::OrtStatus,
209{
210    status_to_result(f(g_ort()))
211}