1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
use std::{io, path::PathBuf};
use thiserror::Error;
use onnxruntime_sys as sys;
use crate::{char_p_to_string, g_ort};
/// Crate-wide `Result` alias whose error type is [`OrtError`].
pub type Result<T> = ::std::result::Result<T, OrtError>;
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum OrtError {
#[error("Failed to construct String")]
StringConversion(OrtApiError),
#[error("Failed to create environment: {0}")]
Environment(OrtApiError),
#[error("Failed to create session options: {0}")]
SessionOptions(OrtApiError),
#[error("Failed to create session: {0}")]
Session(OrtApiError),
#[error("Failed to get allocator: {0}")]
Allocator(OrtApiError),
#[error("Failed to get input or output count: {0}")]
InOutCount(OrtApiError),
#[error("Failed to get input name: {0}")]
InputName(OrtApiError),
#[error("Failed to get type info: {0}")]
GetTypeInfo(OrtApiError),
#[error("Failed to cast type info to tensor info: {0}")]
CastTypeInfoToTensorInfo(OrtApiError),
#[error("Failed to get tensor element type: {0}")]
TensorElementType(OrtApiError),
#[error("Failed to get dimensions count: {0}")]
GetDimensionsCount(OrtApiError),
#[error("Failed to get dimensions: {0}")]
GetDimensions(OrtApiError),
#[error("Failed to get dimensions: {0}")]
CreateCpuMemoryInfo(OrtApiError),
#[error("Failed to create tensor with data: {0}")]
CreateTensorWithData(OrtApiError),
#[error("Failed to check if tensor: {0}")]
IsTensor(OrtApiError),
#[error("Failed to run: {0}")]
Run(OrtApiError),
#[error("Failed to get tensor data: {0}")]
GetTensorMutableData(OrtApiError),
#[error("Failed to download ONNX model: {0}")]
DownloadError(#[from] OrtDownloadError),
#[error("Dimensions do not match: {0:?}")]
NonMatchingDimensions(NonMatchingDimensionsError),
#[error("File {filename:?} does not exists")]
FileDoesNotExists {
filename: PathBuf,
},
#[error("Path {path:?} cannot be converted to UTF-8")]
NonUtf8Path {
path: PathBuf,
},
#[error("Failed to build CString when original contains null: {0}")]
CStringNulError(#[from] std::ffi::NulError),
}
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum NonMatchingDimensionsError {
#[error("Non-matching number of inputs: {inference_input_count:?} for input vs {model_input_count:?} for model (inputs: {inference_input:?}, model: {model_input:?})")]
InputsCount {
inference_input_count: usize,
model_input_count: usize,
inference_input: Vec<Vec<usize>>,
model_input: Vec<Vec<Option<u32>>>,
},
}
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum OrtApiError {
#[error("Error calling ONNX Runtime C function")]
Msg(String),
#[error("Error calling ONNX Runtime C function and failed to convert error message to UTF-8")]
IntoStringError(std::ffi::IntoStringError),
}
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum OrtDownloadError {
#[error("Error downloading data to file: {0}")]
IoError(#[from] io::Error),
#[error("Error getting content-length")]
ContentLengthError,
#[error("Error copying data to file: expected {expected} length, received {io}")]
CopyError {
expected: u64,
io: u64,
},
}
/// Newtype around a raw `OrtStatus` pointer returned by an ONNX Runtime C API call.
///
/// A null pointer denotes success when the wrapper is converted into a
/// `Result<(), OrtApiError>`.
pub struct OrtStatusWrapper(*const sys::OrtStatus);

impl From<*const sys::OrtStatus> for OrtStatusWrapper {
    /// Wraps the raw pointer as-is; nothing is dereferenced here.
    fn from(raw: *const sys::OrtStatus) -> Self {
        Self(raw)
    }
}
impl From<OrtStatusWrapper> for std::result::Result<(), OrtApiError> {
fn from(status: OrtStatusWrapper) -> Self {
if status.0.is_null() {
Ok(())
} else {
let raw: *const i8 = unsafe { g_ort().GetErrorMessage.unwrap()(status.0) };
match char_p_to_string(raw) {
Ok(msg) => Err(OrtApiError::Msg(msg)),
Err(err) => match err {
OrtError::StringConversion(e) => match e {
OrtApiError::IntoStringError(e) => Err(OrtApiError::IntoStringError(e)),
_ => unreachable!(),
},
_ => unreachable!(),
},
}
}
}
}
pub(crate) fn status_to_result(
status: *const sys::OrtStatus,
) -> std::result::Result<(), OrtApiError> {
let status_wrapper: OrtStatusWrapper = status.into();
status_wrapper.into()
}