Skip to main content

edgefirst_tflite/
error.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 Au-Zone Technologies. All Rights Reserved.
3
4//! Error types for the `edgefirst-tflite` crate.
5//!
//! This module follows the canonical error struct pattern: a public [`Error`]
//! struct wrapping a private `ErrorKind` enum. Callers inspect errors through
//! [`Error::is_library_error`], [`Error::is_delegate_error`],
//! [`Error::is_null_pointer`], [`Error::is_invalid_argument`], and
//! [`Error::status_code`] rather than matching on variants directly.
11
12use std::fmt;
13
14// ---------------------------------------------------------------------------
15// StatusCode
16// ---------------------------------------------------------------------------
17
/// Status codes returned by the TensorFlow Lite C API.
///
/// Each variant maps to a `kTfLite*` constant defined in the C header
/// `common.h`. `kTfLiteOk` (0) has no variant because it signals success
/// rather than an error. The numeric value is accessible via `as u32`.
///
/// `StatusCode` is `Copy`, comparable and hashable, so it can be used
/// directly as a key in a `HashMap` or `HashSet` (for example to tally
/// failures by kind).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StatusCode {
    /// Generic runtime error (`kTfLiteError = 1`).
    RuntimeError = 1,
    /// Delegate returned an error (`kTfLiteDelegateError = 2`).
    DelegateError = 2,
    /// Application-level error (`kTfLiteApplicationError = 3`).
    ApplicationError = 3,
    /// Delegate data not found (`kTfLiteDelegateDataNotFound = 4`).
    DelegateDataNotFound = 4,
    /// Delegate data write error (`kTfLiteDelegateDataWriteError = 5`).
    DelegateDataWriteError = 5,
    /// Delegate data read error (`kTfLiteDelegateDataReadError = 6`).
    DelegateDataReadError = 6,
    /// Model contains unresolved ops (`kTfLiteUnresolvedOps = 7`).
    UnresolvedOps = 7,
    /// Operation was cancelled (`kTfLiteCancelled = 8`).
    Cancelled = 8,
    /// Output tensor shape is not yet known (`kTfLiteOutputShapeNotKnown = 9`).
    OutputShapeNotKnown = 9,
}
43
44impl StatusCode {
45    /// Attempt to convert a raw C API status value into a `StatusCode`.
46    ///
47    /// Returns `None` for `kTfLiteOk` (0) or any unknown value.
48    fn from_raw(value: u32) -> Option<Self> {
49        match value {
50            1 => Some(Self::RuntimeError),
51            2 => Some(Self::DelegateError),
52            3 => Some(Self::ApplicationError),
53            4 => Some(Self::DelegateDataNotFound),
54            5 => Some(Self::DelegateDataWriteError),
55            6 => Some(Self::DelegateDataReadError),
56            7 => Some(Self::UnresolvedOps),
57            8 => Some(Self::Cancelled),
58            9 => Some(Self::OutputShapeNotKnown),
59            _ => None,
60        }
61    }
62}
63
64impl fmt::Display for StatusCode {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        match self {
67            Self::RuntimeError => f.write_str("runtime error"),
68            Self::DelegateError => f.write_str("delegate error"),
69            Self::ApplicationError => f.write_str("application error"),
70            Self::DelegateDataNotFound => f.write_str("delegate data not found"),
71            Self::DelegateDataWriteError => f.write_str("delegate data write error"),
72            Self::DelegateDataReadError => f.write_str("delegate data read error"),
73            Self::UnresolvedOps => f.write_str("unresolved ops"),
74            Self::Cancelled => f.write_str("cancelled"),
75            Self::OutputShapeNotKnown => f.write_str("output shape not known"),
76        }
77    }
78}
79
80// ---------------------------------------------------------------------------
81// ErrorKind (private)
82// ---------------------------------------------------------------------------
83
84/// Internal error classification. Not exposed to consumers.
85#[derive(Debug)]
86enum ErrorKind {
87    /// The TensorFlow Lite C API returned a non-OK status.
88    Status(StatusCode),
89    /// A C API function returned a null pointer.
90    NullPointer,
91    /// Library loading or symbol resolution failed.
92    Library(libloading::Error),
93    /// An invalid argument was passed to the API.
94    InvalidArgument(String),
95}
96
97// ---------------------------------------------------------------------------
98// Error
99// ---------------------------------------------------------------------------
100
/// The error type for all fallible operations in `edgefirst-tflite`.
///
/// `Error` wraps a private `ErrorKind` enum so that the set of failure modes
/// can grow without breaking callers. Use the `is_*()` inspection methods and
/// [`Error::status_code`] to classify an error programmatically.
///
/// An error may also carry a free-form context string. The `Display`
/// implementation shows it in parentheses after the main message.
#[derive(Debug)]
pub struct Error {
    // What went wrong. Kept private so variants can be added compatibly.
    kind: ErrorKind,
    // Optional human-readable detail appended to the displayed message.
    context: Option<String>,
}
111
112// -- Public inspection API --------------------------------------------------
113
114impl Error {
115    /// Returns `true` if this error originated from library loading or symbol
116    /// resolution (i.e. a [`libloading::Error`]).
117    #[must_use]
118    pub fn is_library_error(&self) -> bool {
119        matches!(self.kind, ErrorKind::Library(_))
120    }
121
122    /// Returns `true` if the underlying `TFLite` status is one of the delegate
123    /// error codes: [`StatusCode::DelegateError`],
124    /// [`StatusCode::DelegateDataNotFound`],
125    /// [`StatusCode::DelegateDataWriteError`], or
126    /// [`StatusCode::DelegateDataReadError`].
127    #[must_use]
128    pub fn is_delegate_error(&self) -> bool {
129        matches!(
130            self.kind,
131            ErrorKind::Status(
132                StatusCode::DelegateError
133                    | StatusCode::DelegateDataNotFound
134                    | StatusCode::DelegateDataWriteError
135                    | StatusCode::DelegateDataReadError
136            )
137        )
138    }
139
140    /// Returns `true` if a C API call returned a null pointer.
141    #[must_use]
142    pub fn is_null_pointer(&self) -> bool {
143        matches!(self.kind, ErrorKind::NullPointer)
144    }
145
146    /// Returns `true` if this error is an invalid-argument error.
147    #[must_use]
148    pub fn is_invalid_argument(&self) -> bool {
149        matches!(self.kind, ErrorKind::InvalidArgument(_))
150    }
151
152    /// Returns the `TFLite` [`StatusCode`] when the error originated from a
153    /// non-OK C API status, or `None` otherwise.
154    #[must_use]
155    pub fn status_code(&self) -> Option<StatusCode> {
156        if let ErrorKind::Status(code) = self.kind {
157            Some(code)
158        } else {
159            None
160        }
161    }
162
163    /// Attach additional human-readable context to this error.
164    ///
165    /// The context string is appended in parentheses when the error is
166    /// displayed.
167    #[must_use]
168    pub fn with_context(mut self, context: impl Into<String>) -> Self {
169        self.context = Some(context.into());
170        self
171    }
172}
173
174// -- Crate-internal constructors --------------------------------------------
175
176impl Error {
177    /// Create an error from a `TFLite` [`StatusCode`].
178    #[must_use]
179    pub(crate) fn status(code: StatusCode) -> Self {
180        Self {
181            kind: ErrorKind::Status(code),
182            context: None,
183        }
184    }
185
186    /// Create a null-pointer error with a description of which pointer was
187    /// null.
188    #[must_use]
189    pub(crate) fn null_pointer(context: impl Into<String>) -> Self {
190        Self {
191            kind: ErrorKind::NullPointer,
192            context: Some(context.into()),
193        }
194    }
195
196    /// Create an invalid-argument error.
197    #[must_use]
198    pub(crate) fn invalid_argument(msg: impl Into<String>) -> Self {
199        Self {
200            kind: ErrorKind::InvalidArgument(msg.into()),
201            context: None,
202        }
203    }
204}
205
206// -- Display ----------------------------------------------------------------
207
208impl fmt::Display for Error {
209    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210        match &self.kind {
211            ErrorKind::Status(code) => write!(f, "TFLite status: {code}")?,
212            ErrorKind::NullPointer => f.write_str("null pointer from C API")?,
213            ErrorKind::Library(inner) => write!(f, "library loading error: {inner}")?,
214            ErrorKind::InvalidArgument(msg) => write!(f, "invalid argument: {msg}")?,
215        }
216        if let Some(ctx) = &self.context {
217            write!(f, " ({ctx})")?;
218        }
219        Ok(())
220    }
221}
222
223// -- std::error::Error ------------------------------------------------------
224
225impl std::error::Error for Error {
226    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
227        match &self.kind {
228            ErrorKind::Library(inner) => Some(inner),
229            _ => None,
230        }
231    }
232}
233
234// -- From conversions -------------------------------------------------------
235
236impl From<libloading::Error> for Error {
237    fn from(err: libloading::Error) -> Self {
238        Self {
239            kind: ErrorKind::Library(err),
240            context: None,
241        }
242    }
243}
244
245// ---------------------------------------------------------------------------
246// hal_to_result
247// ---------------------------------------------------------------------------
248
249/// Convert a HAL DMA-BUF return code to a [`Result`].
250///
251/// HAL functions return `0` on success and `-1` on error (with `errno` set).
252/// On failure, the errno value is captured via [`std::io::Error::last_os_error`]
253/// and included in the error context.
254pub(crate) fn hal_to_result(ret: std::ffi::c_int, context: &str) -> Result<()> {
255    if ret == 0 {
256        return Ok(());
257    }
258    let os_err = std::io::Error::last_os_error();
259    Err(Error::status(StatusCode::DelegateError).with_context(format!("{context}: {os_err}")))
260}
261
262// ---------------------------------------------------------------------------
263// status_to_result
264// ---------------------------------------------------------------------------
265
266/// Convert a raw `TFLite` C API status code to a [`Result`].
267///
268/// `kTfLiteOk` (0) maps to `Ok(())`. All other known values map to the
269/// corresponding [`StatusCode`]. Unknown non-zero values are treated as
270/// [`StatusCode::RuntimeError`].
271pub(crate) fn status_to_result(status: u32) -> Result<()> {
272    if status == 0 {
273        return Ok(());
274    }
275    let code = StatusCode::from_raw(status).unwrap_or(StatusCode::RuntimeError);
276    Err(Error::status(code))
277}
278
279// ---------------------------------------------------------------------------
280// Result type alias
281// ---------------------------------------------------------------------------
282
283/// A [`Result`](std::result::Result) type alias using [`Error`] as the error
284/// variant.
285pub type Result<T> = std::result::Result<T, Error>;
286
287// ---------------------------------------------------------------------------
288// Tests
289// ---------------------------------------------------------------------------
290
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn status_ok_is_ok() {
        assert!(status_to_result(0).is_ok());
    }

    #[test]
    fn status_error_maps_correctly() {
        let err = status_to_result(1).unwrap_err();
        assert_eq!(err.status_code(), Some(StatusCode::RuntimeError));
    }

    #[test]
    fn status_delegate_codes() {
        for (raw, expected) in [
            (2, StatusCode::DelegateError),
            (4, StatusCode::DelegateDataNotFound),
            (5, StatusCode::DelegateDataWriteError),
            (6, StatusCode::DelegateDataReadError),
        ] {
            let err = status_to_result(raw).unwrap_err();
            assert_eq!(err.status_code(), Some(expected));
            assert!(err.is_delegate_error());
        }
    }

    #[test]
    fn status_all_known_codes() {
        // Every known raw value must round-trip back to its own discriminant.
        for raw in 1..=9 {
            let err = status_to_result(raw).unwrap_err();
            assert_eq!(err.status_code().map(|code| code as u32), Some(raw));
        }
    }

    #[test]
    fn unknown_status_falls_back_to_runtime_error() {
        let err = status_to_result(42).unwrap_err();
        assert_eq!(err.status_code(), Some(StatusCode::RuntimeError));
    }

    #[test]
    fn null_pointer_error() {
        let err = Error::null_pointer("TfLiteModelCreate");
        assert!(err.is_null_pointer());
        assert!(!err.is_library_error());
        assert!(!err.is_delegate_error());
        assert!(!err.is_invalid_argument());
        assert!(err.status_code().is_none());
        assert!(err.to_string().contains("null pointer"));
        assert!(err.to_string().contains("TfLiteModelCreate"));
    }

    #[test]
    fn invalid_argument_error() {
        let err = Error::invalid_argument("tensor index out of range");
        assert!(err.is_invalid_argument());
        assert!(!err.is_null_pointer());
        assert!(!err.is_library_error());
        assert!(!err.is_delegate_error());
        assert!(err.status_code().is_none());
        let msg = err.to_string();
        assert!(msg.contains("invalid argument"));
        assert!(msg.contains("tensor index out of range"));
    }

    #[test]
    fn with_context_appends_message() {
        let err = Error::status(StatusCode::RuntimeError).with_context("during AllocateTensors");
        let msg = err.to_string();
        assert!(msg.contains("runtime error"));
        assert!(msg.contains("during AllocateTensors"));
    }

    #[test]
    fn hal_to_result_success_and_failure() {
        assert!(hal_to_result(0, "DMA_BUF_IOCTL_SYNC").is_ok());
        let err = hal_to_result(-1, "DMA_BUF_IOCTL_SYNC").unwrap_err();
        assert_eq!(err.status_code(), Some(StatusCode::DelegateError));
        assert!(err.to_string().contains("DMA_BUF_IOCTL_SYNC"));
    }

    #[test]
    fn from_libloading_error() {
        // Attempt to load a library that does not exist to obtain a
        // `libloading::Error`.
        // SAFETY: the path does not exist, so loading fails before any
        // library initialisation code could run.
        let lib_err = unsafe { libloading::Library::new("__nonexistent__.so") }.unwrap_err();
        let err = Error::from(lib_err);
        assert!(err.is_library_error());
        assert!(!err.is_null_pointer());
        assert!(err.status_code().is_none());
        assert!(std::error::Error::source(&err).is_some());
        assert!(err.to_string().contains("library loading error"));
    }

    #[test]
    fn display_includes_status_code_name() {
        let err = Error::status(StatusCode::Cancelled);
        assert!(err.to_string().contains("cancelled"));
    }

    #[test]
    fn non_delegate_status_is_not_delegate_error() {
        let err = Error::status(StatusCode::RuntimeError);
        assert!(!err.is_delegate_error());
    }

    #[test]
    fn status_code_discriminant_values() {
        assert_eq!(StatusCode::RuntimeError as u32, 1);
        assert_eq!(StatusCode::DelegateError as u32, 2);
        assert_eq!(StatusCode::ApplicationError as u32, 3);
        assert_eq!(StatusCode::DelegateDataNotFound as u32, 4);
        assert_eq!(StatusCode::DelegateDataWriteError as u32, 5);
        assert_eq!(StatusCode::DelegateDataReadError as u32, 6);
        assert_eq!(StatusCode::UnresolvedOps as u32, 7);
        assert_eq!(StatusCode::Cancelled as u32, 8);
        assert_eq!(StatusCode::OutputShapeNotKnown as u32, 9);
    }

    #[test]
    fn status_code_display_all_variants() {
        let cases = [
            (StatusCode::RuntimeError, "runtime error"),
            (StatusCode::DelegateError, "delegate error"),
            (StatusCode::ApplicationError, "application error"),
            (StatusCode::DelegateDataNotFound, "delegate data not found"),
            (
                StatusCode::DelegateDataWriteError,
                "delegate data write error",
            ),
            (
                StatusCode::DelegateDataReadError,
                "delegate data read error",
            ),
            (StatusCode::UnresolvedOps, "unresolved ops"),
            (StatusCode::Cancelled, "cancelled"),
            (StatusCode::OutputShapeNotKnown, "output shape not known"),
        ];
        for (code, expected) in cases {
            assert_eq!(code.to_string(), expected);
        }
    }

    #[test]
    fn error_debug_format() {
        let err = Error::status(StatusCode::RuntimeError);
        let debug = format!("{err:?}");
        assert!(debug.contains("Error"));
        assert!(debug.contains("Status"));
    }
}