1use std::fmt;
13
/// Non-success status codes carried by a status [`Error`].
///
/// Each variant's discriminant is the raw integer that `StatusCode::from_raw`
/// maps back to it. Raw value `0` has no variant: `status_to_result` treats
/// it as success.
///
/// NOTE(review): the values appear to mirror the non-OK members of the TFLite
/// C API's `TfLiteStatus`; confirm against the C header before adding or
/// renumbering variants.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StatusCode {
    /// Runtime error (raw value `1`).
    RuntimeError = 1,
    /// Delegate error (raw value `2`).
    DelegateError = 2,
    /// Application error (raw value `3`).
    ApplicationError = 3,
    /// Delegate data not found (raw value `4`).
    DelegateDataNotFound = 4,
    /// Delegate data write error (raw value `5`).
    DelegateDataWriteError = 5,
    /// Delegate data read error (raw value `6`).
    DelegateDataReadError = 6,
    /// Unresolved ops (raw value `7`).
    UnresolvedOps = 7,
    /// Cancelled (raw value `8`).
    Cancelled = 8,
    /// Output shape not known (raw value `9`).
    OutputShapeNotKnown = 9,
}
43
44impl StatusCode {
45 fn from_raw(value: u32) -> Option<Self> {
49 match value {
50 1 => Some(Self::RuntimeError),
51 2 => Some(Self::DelegateError),
52 3 => Some(Self::ApplicationError),
53 4 => Some(Self::DelegateDataNotFound),
54 5 => Some(Self::DelegateDataWriteError),
55 6 => Some(Self::DelegateDataReadError),
56 7 => Some(Self::UnresolvedOps),
57 8 => Some(Self::Cancelled),
58 9 => Some(Self::OutputShapeNotKnown),
59 _ => None,
60 }
61 }
62}
63
64impl fmt::Display for StatusCode {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 match self {
67 Self::RuntimeError => f.write_str("runtime error"),
68 Self::DelegateError => f.write_str("delegate error"),
69 Self::ApplicationError => f.write_str("application error"),
70 Self::DelegateDataNotFound => f.write_str("delegate data not found"),
71 Self::DelegateDataWriteError => f.write_str("delegate data write error"),
72 Self::DelegateDataReadError => f.write_str("delegate data read error"),
73 Self::UnresolvedOps => f.write_str("unresolved ops"),
74 Self::Cancelled => f.write_str("cancelled"),
75 Self::OutputShapeNotKnown => f.write_str("output shape not known"),
76 }
77 }
78}
79
80#[derive(Debug)]
86enum ErrorKind {
87 Status(StatusCode),
89 NullPointer,
91 Library(libloading::Error),
93 InvalidArgument(String),
95}
96
97#[derive(Debug)]
107pub struct Error {
108 kind: ErrorKind,
109 context: Option<String>,
110}
111
112impl Error {
115 #[must_use]
118 pub fn is_library_error(&self) -> bool {
119 matches!(self.kind, ErrorKind::Library(_))
120 }
121
122 #[must_use]
128 pub fn is_delegate_error(&self) -> bool {
129 matches!(
130 self.kind,
131 ErrorKind::Status(
132 StatusCode::DelegateError
133 | StatusCode::DelegateDataNotFound
134 | StatusCode::DelegateDataWriteError
135 | StatusCode::DelegateDataReadError
136 )
137 )
138 }
139
140 #[must_use]
142 pub fn is_null_pointer(&self) -> bool {
143 matches!(self.kind, ErrorKind::NullPointer)
144 }
145
146 #[must_use]
148 pub fn is_invalid_argument(&self) -> bool {
149 matches!(self.kind, ErrorKind::InvalidArgument(_))
150 }
151
152 #[must_use]
155 pub fn status_code(&self) -> Option<StatusCode> {
156 if let ErrorKind::Status(code) = self.kind {
157 Some(code)
158 } else {
159 None
160 }
161 }
162
163 #[must_use]
168 pub fn with_context(mut self, context: impl Into<String>) -> Self {
169 self.context = Some(context.into());
170 self
171 }
172}
173
174impl Error {
177 #[must_use]
179 pub(crate) fn status(code: StatusCode) -> Self {
180 Self {
181 kind: ErrorKind::Status(code),
182 context: None,
183 }
184 }
185
186 #[must_use]
189 pub(crate) fn null_pointer(context: impl Into<String>) -> Self {
190 Self {
191 kind: ErrorKind::NullPointer,
192 context: Some(context.into()),
193 }
194 }
195
196 #[must_use]
198 pub(crate) fn invalid_argument(msg: impl Into<String>) -> Self {
199 Self {
200 kind: ErrorKind::InvalidArgument(msg.into()),
201 context: None,
202 }
203 }
204}
205
206impl fmt::Display for Error {
209 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210 match &self.kind {
211 ErrorKind::Status(code) => write!(f, "TFLite status: {code}")?,
212 ErrorKind::NullPointer => f.write_str("null pointer from C API")?,
213 ErrorKind::Library(inner) => write!(f, "library loading error: {inner}")?,
214 ErrorKind::InvalidArgument(msg) => write!(f, "invalid argument: {msg}")?,
215 }
216 if let Some(ctx) = &self.context {
217 write!(f, " ({ctx})")?;
218 }
219 Ok(())
220 }
221}
222
223impl std::error::Error for Error {
226 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
227 match &self.kind {
228 ErrorKind::Library(inner) => Some(inner),
229 _ => None,
230 }
231 }
232}
233
234impl From<libloading::Error> for Error {
237 fn from(err: libloading::Error) -> Self {
238 Self {
239 kind: ErrorKind::Library(err),
240 context: None,
241 }
242 }
243}
244
245pub(crate) fn hal_to_result(ret: std::ffi::c_int, context: &str) -> Result<()> {
255 if ret == 0 {
256 return Ok(());
257 }
258 let os_err = std::io::Error::last_os_error();
259 Err(Error::status(StatusCode::DelegateError).with_context(format!("{context}: {os_err}")))
260}
261
262pub(crate) fn status_to_result(status: u32) -> Result<()> {
272 if status == 0 {
273 return Ok(());
274 }
275 let code = StatusCode::from_raw(status).unwrap_or(StatusCode::RuntimeError);
276 Err(Error::status(code))
277}
278
279pub type Result<T> = std::result::Result<T, Error>;
286
/// Unit tests for status conversion, error predicates and formatting.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn status_ok_is_ok() {
        assert!(status_to_result(0).is_ok());
    }

    #[test]
    fn status_error_maps_correctly() {
        let err = status_to_result(1).unwrap_err();
        assert_eq!(err.status_code(), Some(StatusCode::RuntimeError));
    }

    #[test]
    fn status_delegate_codes() {
        for (raw, expected) in [
            (2, StatusCode::DelegateError),
            (4, StatusCode::DelegateDataNotFound),
            (5, StatusCode::DelegateDataWriteError),
            (6, StatusCode::DelegateDataReadError),
        ] {
            let err = status_to_result(raw).unwrap_err();
            assert_eq!(err.status_code(), Some(expected));
            assert!(err.is_delegate_error());
        }
    }

    #[test]
    fn status_all_known_codes() {
        // Every known raw value must round-trip to a variant with the same
        // discriminant. Checking only `is_some()` would pass for any status
        // error, including the fallback.
        for raw in 1..=9u32 {
            let err = status_to_result(raw).unwrap_err();
            let code = err
                .status_code()
                .expect("status errors always carry a status code");
            assert_eq!(code as u32, raw);
        }
    }

    #[test]
    fn from_raw_rejects_zero_and_unknown_values() {
        assert_eq!(StatusCode::from_raw(0), None);
        assert_eq!(StatusCode::from_raw(10), None);
        assert_eq!(StatusCode::from_raw(u32::MAX), None);
    }

    #[test]
    fn unknown_status_falls_back_to_runtime_error() {
        let err = status_to_result(42).unwrap_err();
        assert_eq!(err.status_code(), Some(StatusCode::RuntimeError));
    }

    #[test]
    fn null_pointer_error() {
        let err = Error::null_pointer("TfLiteModelCreate");
        assert!(err.is_null_pointer());
        assert!(!err.is_library_error());
        assert!(!err.is_delegate_error());
        assert!(err.status_code().is_none());
        assert!(err.to_string().contains("null pointer"));
        assert!(err.to_string().contains("TfLiteModelCreate"));
    }

    #[test]
    fn invalid_argument_error() {
        let err = Error::invalid_argument("tensor index out of range");
        assert!(err.is_invalid_argument());
        assert!(!err.is_null_pointer());
        assert!(err.status_code().is_none());
        assert!(err.to_string().contains("tensor index out of range"));
    }

    #[test]
    fn with_context_appends_message() {
        let err = Error::status(StatusCode::RuntimeError).with_context("during AllocateTensors");
        let msg = err.to_string();
        assert!(msg.contains("runtime error"));
        assert!(msg.contains("during AllocateTensors"));
    }

    #[test]
    fn from_libloading_error() {
        // SAFETY: `Library::new` is unsafe because loading a library may run
        // its initialisation code. The path used here is expected not to
        // exist, so the call should fail without loading anything.
        let lib_err = unsafe { libloading::Library::new("__nonexistent__.so") }.unwrap_err();
        let err = Error::from(lib_err);
        assert!(err.is_library_error());
        assert!(err.status_code().is_none());
        assert!(std::error::Error::source(&err).is_some());
    }

    #[test]
    fn display_includes_status_code_name() {
        let err = Error::status(StatusCode::Cancelled);
        assert!(err.to_string().contains("cancelled"));
    }

    #[test]
    fn non_delegate_status_is_not_delegate_error() {
        let err = Error::status(StatusCode::RuntimeError);
        assert!(!err.is_delegate_error());
    }

    #[test]
    fn status_code_discriminant_values() {
        assert_eq!(StatusCode::RuntimeError as u32, 1);
        assert_eq!(StatusCode::DelegateError as u32, 2);
        assert_eq!(StatusCode::ApplicationError as u32, 3);
        assert_eq!(StatusCode::DelegateDataNotFound as u32, 4);
        assert_eq!(StatusCode::DelegateDataWriteError as u32, 5);
        assert_eq!(StatusCode::DelegateDataReadError as u32, 6);
        assert_eq!(StatusCode::UnresolvedOps as u32, 7);
        assert_eq!(StatusCode::Cancelled as u32, 8);
        assert_eq!(StatusCode::OutputShapeNotKnown as u32, 9);
    }

    #[test]
    fn status_code_display_all_variants() {
        let cases = [
            (StatusCode::RuntimeError, "runtime error"),
            (StatusCode::DelegateError, "delegate error"),
            (StatusCode::ApplicationError, "application error"),
            (StatusCode::DelegateDataNotFound, "delegate data not found"),
            (
                StatusCode::DelegateDataWriteError,
                "delegate data write error",
            ),
            (
                StatusCode::DelegateDataReadError,
                "delegate data read error",
            ),
            (StatusCode::UnresolvedOps, "unresolved ops"),
            (StatusCode::Cancelled, "cancelled"),
            (StatusCode::OutputShapeNotKnown, "output shape not known"),
        ];
        for (code, expected) in cases {
            assert_eq!(code.to_string(), expected);
        }
    }

    #[test]
    fn error_debug_format() {
        let err = Error::status(StatusCode::RuntimeError);
        let debug = format!("{err:?}");
        assert!(debug.contains("Error"));
        assert!(debug.contains("Status"));
    }
}