1use std::fmt;
13
/// Non-success status codes reported by the TFLite C API.
///
/// The explicit discriminants are the raw status values; `0` means success and
/// therefore has no variant here.
/// NOTE(review): these presumably mirror the `kTfLite*` constants of the C
/// `TfLiteStatus` enum — confirm against the C headers when upgrading TFLite.
///
/// `Hash` is derived alongside the comparison traits so codes can be used as
/// `HashMap` / `HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StatusCode {
    /// Generic runtime failure (raw value `1`).
    RuntimeError = 1,
    /// A delegate reported an error (raw value `2`).
    DelegateError = 2,
    /// Application-level error (raw value `3`).
    ApplicationError = 3,
    /// Delegate data was not found (raw value `4`).
    DelegateDataNotFound = 4,
    /// Writing delegate data failed (raw value `5`).
    DelegateDataWriteError = 5,
    /// Reading delegate data failed (raw value `6`).
    DelegateDataReadError = 6,
    /// The model contains ops that could not be resolved (raw value `7`).
    UnresolvedOps = 7,
    /// The operation was cancelled (raw value `8`).
    Cancelled = 8,
    /// An output tensor's shape is not known (raw value `9`).
    OutputShapeNotKnown = 9,
}
43
44impl StatusCode {
45 fn from_raw(value: u32) -> Option<Self> {
49 match value {
50 1 => Some(Self::RuntimeError),
51 2 => Some(Self::DelegateError),
52 3 => Some(Self::ApplicationError),
53 4 => Some(Self::DelegateDataNotFound),
54 5 => Some(Self::DelegateDataWriteError),
55 6 => Some(Self::DelegateDataReadError),
56 7 => Some(Self::UnresolvedOps),
57 8 => Some(Self::Cancelled),
58 9 => Some(Self::OutputShapeNotKnown),
59 _ => None,
60 }
61 }
62}
63
64impl fmt::Display for StatusCode {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 match self {
67 Self::RuntimeError => f.write_str("runtime error"),
68 Self::DelegateError => f.write_str("delegate error"),
69 Self::ApplicationError => f.write_str("application error"),
70 Self::DelegateDataNotFound => f.write_str("delegate data not found"),
71 Self::DelegateDataWriteError => f.write_str("delegate data write error"),
72 Self::DelegateDataReadError => f.write_str("delegate data read error"),
73 Self::UnresolvedOps => f.write_str("unresolved ops"),
74 Self::Cancelled => f.write_str("cancelled"),
75 Self::OutputShapeNotKnown => f.write_str("output shape not known"),
76 }
77 }
78}
79
80#[derive(Debug)]
86enum ErrorKind {
87 Status(StatusCode),
89 NullPointer,
91 Library(libloading::Error),
93 InvalidArgument(String),
95}
96
97#[derive(Debug)]
107pub struct Error {
108 kind: ErrorKind,
109 context: Option<String>,
110}
111
112impl Error {
115 #[must_use]
118 pub fn is_library_error(&self) -> bool {
119 matches!(self.kind, ErrorKind::Library(_))
120 }
121
122 #[must_use]
128 pub fn is_delegate_error(&self) -> bool {
129 matches!(
130 self.kind,
131 ErrorKind::Status(
132 StatusCode::DelegateError
133 | StatusCode::DelegateDataNotFound
134 | StatusCode::DelegateDataWriteError
135 | StatusCode::DelegateDataReadError
136 )
137 )
138 }
139
140 #[must_use]
142 pub fn is_null_pointer(&self) -> bool {
143 matches!(self.kind, ErrorKind::NullPointer)
144 }
145
146 #[must_use]
148 pub fn is_invalid_argument(&self) -> bool {
149 matches!(self.kind, ErrorKind::InvalidArgument(_))
150 }
151
152 #[must_use]
155 pub fn status_code(&self) -> Option<StatusCode> {
156 if let ErrorKind::Status(code) = self.kind {
157 Some(code)
158 } else {
159 None
160 }
161 }
162
163 #[must_use]
168 pub fn with_context(mut self, context: impl Into<String>) -> Self {
169 self.context = Some(context.into());
170 self
171 }
172}
173
174impl Error {
177 #[must_use]
179 pub(crate) fn status(code: StatusCode) -> Self {
180 Self {
181 kind: ErrorKind::Status(code),
182 context: None,
183 }
184 }
185
186 #[must_use]
189 pub(crate) fn null_pointer(context: impl Into<String>) -> Self {
190 Self {
191 kind: ErrorKind::NullPointer,
192 context: Some(context.into()),
193 }
194 }
195
196 #[must_use]
198 pub(crate) fn invalid_argument(msg: impl Into<String>) -> Self {
199 Self {
200 kind: ErrorKind::InvalidArgument(msg.into()),
201 context: None,
202 }
203 }
204}
205
206impl fmt::Display for Error {
209 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210 match &self.kind {
211 ErrorKind::Status(code) => write!(f, "TFLite status: {code}")?,
212 ErrorKind::NullPointer => f.write_str("null pointer from C API")?,
213 ErrorKind::Library(inner) => write!(f, "library loading error: {inner}")?,
214 ErrorKind::InvalidArgument(msg) => write!(f, "invalid argument: {msg}")?,
215 }
216 if let Some(ctx) = &self.context {
217 write!(f, " ({ctx})")?;
218 }
219 Ok(())
220 }
221}
222
223impl std::error::Error for Error {
226 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
227 match &self.kind {
228 ErrorKind::Library(inner) => Some(inner),
229 _ => None,
230 }
231 }
232}
233
234impl From<libloading::Error> for Error {
237 fn from(err: libloading::Error) -> Self {
238 Self {
239 kind: ErrorKind::Library(err),
240 context: None,
241 }
242 }
243}
244
245pub(crate) fn status_to_result(status: u32) -> Result<()> {
255 if status == 0 {
256 return Ok(());
257 }
258 let code = StatusCode::from_raw(status).unwrap_or(StatusCode::RuntimeError);
259 Err(Error::status(code))
260}
261
262pub type Result<T> = std::result::Result<T, Error>;
269
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn status_ok_is_ok() {
        assert!(status_to_result(0).is_ok());
    }

    #[test]
    fn status_error_maps_correctly() {
        let err = status_to_result(1).unwrap_err();
        assert_eq!(err.status_code(), Some(StatusCode::RuntimeError));
    }

    #[test]
    fn status_delegate_codes() {
        for (raw, expected) in [
            (2, StatusCode::DelegateError),
            (4, StatusCode::DelegateDataNotFound),
            (5, StatusCode::DelegateDataWriteError),
            (6, StatusCode::DelegateDataReadError),
        ] {
            let err = status_to_result(raw).unwrap_err();
            assert_eq!(err.status_code(), Some(expected));
            assert!(err.is_delegate_error());
        }
    }

    #[test]
    fn status_all_known_codes() {
        // `is_some()` alone would also pass if every value fell back to
        // `RuntimeError`, so require the exact variant for each raw value.
        for raw in 1..=9 {
            let err = status_to_result(raw).unwrap_err();
            let code = err
                .status_code()
                .expect("a known status must map to a status code");
            assert_eq!(code as u32, raw);
        }
    }

    #[test]
    fn from_raw_rejects_success_and_out_of_range_values() {
        assert_eq!(StatusCode::from_raw(0), None);
        assert_eq!(StatusCode::from_raw(10), None);
        assert_eq!(StatusCode::from_raw(u32::MAX), None);
    }

    #[test]
    fn unknown_status_falls_back_to_runtime_error() {
        let err = status_to_result(42).unwrap_err();
        assert_eq!(err.status_code(), Some(StatusCode::RuntimeError));
    }

    #[test]
    fn null_pointer_error() {
        let err = Error::null_pointer("TfLiteModelCreate");
        assert!(err.is_null_pointer());
        assert!(!err.is_library_error());
        assert!(!err.is_delegate_error());
        assert!(!err.is_invalid_argument());
        assert!(err.status_code().is_none());
        assert!(err.to_string().contains("null pointer"));
        assert!(err.to_string().contains("TfLiteModelCreate"));
    }

    #[test]
    fn invalid_argument_error() {
        let err = Error::invalid_argument("tensor index out of range");
        assert!(err.is_invalid_argument());
        assert!(!err.is_null_pointer());
        assert!(!err.is_library_error());
        assert!(!err.is_delegate_error());
        assert!(err.status_code().is_none());
        let msg = err.to_string();
        assert!(msg.contains("invalid argument"));
        assert!(msg.contains("tensor index out of range"));
    }

    #[test]
    fn with_context_appends_message() {
        let err = Error::status(StatusCode::RuntimeError).with_context("during AllocateTensors");
        let msg = err.to_string();
        assert!(msg.contains("runtime error"));
        assert!(msg.contains("during AllocateTensors"));
    }

    #[test]
    fn from_libloading_error() {
        // SAFETY: assumes no file named `__nonexistent__.so` is on the loader
        // search path, so no library is loaded and no initialisation routines
        // run; the call is expected to fail.
        let lib_err = unsafe { libloading::Library::new("__nonexistent__.so") }.unwrap_err();
        let err = Error::from(lib_err);
        assert!(err.is_library_error());
        assert!(!err.is_delegate_error());
        assert!(err.status_code().is_none());
        assert!(std::error::Error::source(&err).is_some());
    }

    #[test]
    fn display_includes_status_code_name() {
        let err = Error::status(StatusCode::Cancelled);
        assert!(err.to_string().contains("cancelled"));
    }

    #[test]
    fn non_delegate_status_is_not_delegate_error() {
        let err = Error::status(StatusCode::RuntimeError);
        assert!(!err.is_delegate_error());
    }

    #[test]
    fn status_code_discriminant_values() {
        assert_eq!(StatusCode::RuntimeError as u32, 1);
        assert_eq!(StatusCode::DelegateError as u32, 2);
        assert_eq!(StatusCode::ApplicationError as u32, 3);
        assert_eq!(StatusCode::DelegateDataNotFound as u32, 4);
        assert_eq!(StatusCode::DelegateDataWriteError as u32, 5);
        assert_eq!(StatusCode::DelegateDataReadError as u32, 6);
        assert_eq!(StatusCode::UnresolvedOps as u32, 7);
        assert_eq!(StatusCode::Cancelled as u32, 8);
        assert_eq!(StatusCode::OutputShapeNotKnown as u32, 9);
    }

    #[test]
    fn status_code_display_all_variants() {
        let cases = [
            (StatusCode::RuntimeError, "runtime error"),
            (StatusCode::DelegateError, "delegate error"),
            (StatusCode::ApplicationError, "application error"),
            (StatusCode::DelegateDataNotFound, "delegate data not found"),
            (StatusCode::DelegateDataWriteError, "delegate data write error"),
            (StatusCode::DelegateDataReadError, "delegate data read error"),
            (StatusCode::UnresolvedOps, "unresolved ops"),
            (StatusCode::Cancelled, "cancelled"),
            (StatusCode::OutputShapeNotKnown, "output shape not known"),
        ];
        for (code, expected) in cases {
            assert_eq!(code.to_string(), expected);
        }
    }

    #[test]
    fn error_debug_format() {
        let err = Error::status(StatusCode::RuntimeError);
        let debug = format!("{err:?}");
        assert!(debug.contains("Error"));
        assert!(debug.contains("Status"));
    }
}