Skip to main content

edgefirst_tflite/
error.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 Au-Zone Technologies. All Rights Reserved.
3
//! Error types for the `edgefirst-tflite` crate.
//!
//! This module follows the canonical error struct pattern: a public [`Error`]
//! struct wrapping a private `ErrorKind` enum. Callers inspect errors through
//! [`Error::is_library_error`], [`Error::is_delegate_error`],
//! [`Error::is_null_pointer`], [`Error::is_invalid_argument`], and
//! [`Error::status_code`] rather than matching on variants directly.
11
12use std::fmt;
13
14// ---------------------------------------------------------------------------
15// StatusCode
16// ---------------------------------------------------------------------------
17
/// Status codes returned by the TensorFlow Lite C API.
///
/// Each variant maps to a `kTfLite*` constant defined in the C header
/// `common.h`. The numeric value is accessible via `as u32`. There is no
/// variant for `kTfLiteOk` (0): success is not an error and is represented by
/// `Ok(())` instead.
///
/// `Hash` is derived alongside `Eq` so status codes can be used as keys in
/// hash-based collections (e.g. counting failures per status).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StatusCode {
    /// Generic runtime error (`kTfLiteError = 1`).
    RuntimeError = 1,
    /// Delegate returned an error (`kTfLiteDelegateError = 2`).
    DelegateError = 2,
    /// Application-level error (`kTfLiteApplicationError = 3`).
    ApplicationError = 3,
    /// Delegate data not found (`kTfLiteDelegateDataNotFound = 4`).
    DelegateDataNotFound = 4,
    /// Delegate data write error (`kTfLiteDelegateDataWriteError = 5`).
    DelegateDataWriteError = 5,
    /// Delegate data read error (`kTfLiteDelegateDataReadError = 6`).
    DelegateDataReadError = 6,
    /// Model contains unresolved ops (`kTfLiteUnresolvedOps = 7`).
    UnresolvedOps = 7,
    /// Operation was cancelled (`kTfLiteCancelled = 8`).
    Cancelled = 8,
    /// Output tensor shape is not yet known (`kTfLiteOutputShapeNotKnown = 9`).
    OutputShapeNotKnown = 9,
}

impl StatusCode {
    /// Attempt to convert a raw C API status value into a `StatusCode`.
    ///
    /// The mapping mirrors the enum discriminants above, so
    /// `StatusCode::from_raw(code as u32) == Some(code)` for every variant.
    ///
    /// Returns `None` for `kTfLiteOk` (0) or any value this crate does not
    /// know about.
    #[must_use]
    fn from_raw(value: u32) -> Option<Self> {
        match value {
            1 => Some(Self::RuntimeError),
            2 => Some(Self::DelegateError),
            3 => Some(Self::ApplicationError),
            4 => Some(Self::DelegateDataNotFound),
            5 => Some(Self::DelegateDataWriteError),
            6 => Some(Self::DelegateDataReadError),
            7 => Some(Self::UnresolvedOps),
            8 => Some(Self::Cancelled),
            9 => Some(Self::OutputShapeNotKnown),
            _ => None,
        }
    }
}

impl fmt::Display for StatusCode {
    /// Formats the status as a short lower-case phrase, e.g. `runtime error`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::RuntimeError => f.write_str("runtime error"),
            Self::DelegateError => f.write_str("delegate error"),
            Self::ApplicationError => f.write_str("application error"),
            Self::DelegateDataNotFound => f.write_str("delegate data not found"),
            Self::DelegateDataWriteError => f.write_str("delegate data write error"),
            Self::DelegateDataReadError => f.write_str("delegate data read error"),
            Self::UnresolvedOps => f.write_str("unresolved ops"),
            Self::Cancelled => f.write_str("cancelled"),
            Self::OutputShapeNotKnown => f.write_str("output shape not known"),
        }
    }
}
79
80// ---------------------------------------------------------------------------
81// ErrorKind (private)
82// ---------------------------------------------------------------------------
83
84/// Internal error classification. Not exposed to consumers.
85#[derive(Debug)]
86enum ErrorKind {
87    /// The TensorFlow Lite C API returned a non-OK status.
88    Status(StatusCode),
89    /// A C API function returned a null pointer.
90    NullPointer,
91    /// Library loading or symbol resolution failed.
92    Library(libloading::Error),
93    /// An invalid argument was passed to the API.
94    InvalidArgument(String),
95}
96
97// ---------------------------------------------------------------------------
98// Error
99// ---------------------------------------------------------------------------
100
101/// The error type for all fallible operations in `edgefirst-tflite`.
102///
103/// `Error` wraps a private `ErrorKind` enum so that the set of failure modes
104/// can grow without breaking callers. Use the `is_*()` inspection methods and
105/// [`Error::status_code`] to classify an error programmatically.
106#[derive(Debug)]
107pub struct Error {
108    kind: ErrorKind,
109    context: Option<String>,
110}
111
112// -- Public inspection API --------------------------------------------------
113
114impl Error {
115    /// Returns `true` if this error originated from library loading or symbol
116    /// resolution (i.e. a [`libloading::Error`]).
117    #[must_use]
118    pub fn is_library_error(&self) -> bool {
119        matches!(self.kind, ErrorKind::Library(_))
120    }
121
122    /// Returns `true` if the underlying `TFLite` status is one of the delegate
123    /// error codes: [`StatusCode::DelegateError`],
124    /// [`StatusCode::DelegateDataNotFound`],
125    /// [`StatusCode::DelegateDataWriteError`], or
126    /// [`StatusCode::DelegateDataReadError`].
127    #[must_use]
128    pub fn is_delegate_error(&self) -> bool {
129        matches!(
130            self.kind,
131            ErrorKind::Status(
132                StatusCode::DelegateError
133                    | StatusCode::DelegateDataNotFound
134                    | StatusCode::DelegateDataWriteError
135                    | StatusCode::DelegateDataReadError
136            )
137        )
138    }
139
140    /// Returns `true` if a C API call returned a null pointer.
141    #[must_use]
142    pub fn is_null_pointer(&self) -> bool {
143        matches!(self.kind, ErrorKind::NullPointer)
144    }
145
146    /// Returns `true` if this error is an invalid-argument error.
147    #[must_use]
148    pub fn is_invalid_argument(&self) -> bool {
149        matches!(self.kind, ErrorKind::InvalidArgument(_))
150    }
151
152    /// Returns the `TFLite` [`StatusCode`] when the error originated from a
153    /// non-OK C API status, or `None` otherwise.
154    #[must_use]
155    pub fn status_code(&self) -> Option<StatusCode> {
156        if let ErrorKind::Status(code) = self.kind {
157            Some(code)
158        } else {
159            None
160        }
161    }
162
163    /// Attach additional human-readable context to this error.
164    ///
165    /// The context string is appended in parentheses when the error is
166    /// displayed.
167    #[must_use]
168    pub fn with_context(mut self, context: impl Into<String>) -> Self {
169        self.context = Some(context.into());
170        self
171    }
172}
173
174// -- Crate-internal constructors --------------------------------------------
175
176impl Error {
177    /// Create an error from a `TFLite` [`StatusCode`].
178    #[must_use]
179    pub(crate) fn status(code: StatusCode) -> Self {
180        Self {
181            kind: ErrorKind::Status(code),
182            context: None,
183        }
184    }
185
186    /// Create a null-pointer error with a description of which pointer was
187    /// null.
188    #[must_use]
189    pub(crate) fn null_pointer(context: impl Into<String>) -> Self {
190        Self {
191            kind: ErrorKind::NullPointer,
192            context: Some(context.into()),
193        }
194    }
195
196    /// Create an invalid-argument error.
197    #[must_use]
198    pub(crate) fn invalid_argument(msg: impl Into<String>) -> Self {
199        Self {
200            kind: ErrorKind::InvalidArgument(msg.into()),
201            context: None,
202        }
203    }
204}
205
206// -- Display ----------------------------------------------------------------
207
208impl fmt::Display for Error {
209    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210        match &self.kind {
211            ErrorKind::Status(code) => write!(f, "TFLite status: {code}")?,
212            ErrorKind::NullPointer => f.write_str("null pointer from C API")?,
213            ErrorKind::Library(inner) => write!(f, "library loading error: {inner}")?,
214            ErrorKind::InvalidArgument(msg) => write!(f, "invalid argument: {msg}")?,
215        }
216        if let Some(ctx) = &self.context {
217            write!(f, " ({ctx})")?;
218        }
219        Ok(())
220    }
221}
222
223// -- std::error::Error ------------------------------------------------------
224
225impl std::error::Error for Error {
226    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
227        match &self.kind {
228            ErrorKind::Library(inner) => Some(inner),
229            _ => None,
230        }
231    }
232}
233
234// -- From conversions -------------------------------------------------------
235
236impl From<libloading::Error> for Error {
237    fn from(err: libloading::Error) -> Self {
238        Self {
239            kind: ErrorKind::Library(err),
240            context: None,
241        }
242    }
243}
244
245// ---------------------------------------------------------------------------
246// status_to_result
247// ---------------------------------------------------------------------------
248
249/// Convert a raw `TFLite` C API status code to a [`Result`].
250///
251/// `kTfLiteOk` (0) maps to `Ok(())`. All other known values map to the
252/// corresponding [`StatusCode`]. Unknown non-zero values are treated as
253/// [`StatusCode::RuntimeError`].
254pub(crate) fn status_to_result(status: u32) -> Result<()> {
255    if status == 0 {
256        return Ok(());
257    }
258    let code = StatusCode::from_raw(status).unwrap_or(StatusCode::RuntimeError);
259    Err(Error::status(code))
260}
261
262// ---------------------------------------------------------------------------
263// Result type alias
264// ---------------------------------------------------------------------------
265
266/// A [`Result`](std::result::Result) type alias using [`Error`] as the error
267/// variant.
268pub type Result<T> = std::result::Result<T, Error>;
269
270// ---------------------------------------------------------------------------
271// Tests
272// ---------------------------------------------------------------------------
273
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn status_ok_is_ok() {
        assert!(status_to_result(0).is_ok());
    }

    #[test]
    fn status_error_maps_correctly() {
        let err = status_to_result(1).unwrap_err();
        assert_eq!(err.status_code(), Some(StatusCode::RuntimeError));
    }

    #[test]
    fn status_delegate_codes() {
        for (raw, expected) in [
            (2, StatusCode::DelegateError),
            (4, StatusCode::DelegateDataNotFound),
            (5, StatusCode::DelegateDataWriteError),
            (6, StatusCode::DelegateDataReadError),
        ] {
            let err = status_to_result(raw).unwrap_err();
            assert_eq!(err.status_code(), Some(expected));
            assert!(err.is_delegate_error());
        }
    }

    #[test]
    fn status_all_known_codes() {
        // Each known raw value must map to the variant with the same
        // discriminant, not merely to *some* status code.
        for raw in 1..=9 {
            let err = status_to_result(raw).unwrap_err();
            assert_eq!(err.status_code().map(|code| code as u32), Some(raw));
        }
    }

    #[test]
    fn unknown_status_falls_back_to_runtime_error() {
        let err = status_to_result(42).unwrap_err();
        assert_eq!(err.status_code(), Some(StatusCode::RuntimeError));
        assert!(!err.is_delegate_error());
    }

    #[test]
    fn null_pointer_error() {
        let err = Error::null_pointer("TfLiteModelCreate");
        assert!(err.is_null_pointer());
        assert!(!err.is_library_error());
        assert!(!err.is_delegate_error());
        assert!(!err.is_invalid_argument());
        assert!(err.status_code().is_none());
        assert!(err.to_string().contains("null pointer"));
        assert!(err.to_string().contains("TfLiteModelCreate"));
    }

    #[test]
    fn invalid_argument_error() {
        let err = Error::invalid_argument("tensor index out of range");
        assert!(err.is_invalid_argument());
        assert!(!err.is_null_pointer());
        assert!(!err.is_library_error());
        assert!(!err.is_delegate_error());
        assert!(err.status_code().is_none());
        let msg = err.to_string();
        assert!(msg.contains("invalid argument"));
        assert!(msg.contains("tensor index out of range"));
    }

    #[test]
    fn with_context_appends_message() {
        let err = Error::status(StatusCode::RuntimeError).with_context("during AllocateTensors");
        // Attaching context must not change how the error is classified.
        assert_eq!(err.status_code(), Some(StatusCode::RuntimeError));
        let msg = err.to_string();
        assert!(msg.contains("runtime error"));
        assert!(msg.contains("during AllocateTensors"));
    }

    #[test]
    fn from_libloading_error() {
        // Attempt to load a library that does not exist to obtain a
        // `libloading::Error`.
        // SAFETY: `Library::new` is unsafe because a successfully loaded
        // library may run initialisation code. The path does not exist, so
        // loading fails and no foreign code is executed.
        let lib_err = unsafe { libloading::Library::new("__nonexistent__.so") }.unwrap_err();
        let err = Error::from(lib_err);
        assert!(err.is_library_error());
        assert!(!err.is_null_pointer());
        assert!(!err.is_invalid_argument());
        assert!(err.status_code().is_none());
        assert!(std::error::Error::source(&err).is_some());
        assert!(err.to_string().contains("library loading error"));
    }

    #[test]
    fn display_includes_status_code_name() {
        let err = Error::status(StatusCode::Cancelled);
        assert!(err.to_string().contains("cancelled"));
    }

    #[test]
    fn non_delegate_status_is_not_delegate_error() {
        for code in [
            StatusCode::RuntimeError,
            StatusCode::ApplicationError,
            StatusCode::UnresolvedOps,
            StatusCode::Cancelled,
            StatusCode::OutputShapeNotKnown,
        ] {
            assert!(
                !Error::status(code).is_delegate_error(),
                "{code} was classified as a delegate error"
            );
        }
    }

    #[test]
    fn status_code_discriminant_values() {
        assert_eq!(StatusCode::RuntimeError as u32, 1);
        assert_eq!(StatusCode::DelegateError as u32, 2);
        assert_eq!(StatusCode::ApplicationError as u32, 3);
        assert_eq!(StatusCode::DelegateDataNotFound as u32, 4);
        assert_eq!(StatusCode::DelegateDataWriteError as u32, 5);
        assert_eq!(StatusCode::DelegateDataReadError as u32, 6);
        assert_eq!(StatusCode::UnresolvedOps as u32, 7);
        assert_eq!(StatusCode::Cancelled as u32, 8);
        assert_eq!(StatusCode::OutputShapeNotKnown as u32, 9);
    }

    #[test]
    fn status_code_display_all_variants() {
        let cases = [
            (StatusCode::RuntimeError, "runtime error"),
            (StatusCode::DelegateError, "delegate error"),
            (StatusCode::ApplicationError, "application error"),
            (StatusCode::DelegateDataNotFound, "delegate data not found"),
            (
                StatusCode::DelegateDataWriteError,
                "delegate data write error",
            ),
            (
                StatusCode::DelegateDataReadError,
                "delegate data read error",
            ),
            (StatusCode::UnresolvedOps, "unresolved ops"),
            (StatusCode::Cancelled, "cancelled"),
            (StatusCode::OutputShapeNotKnown, "output shape not known"),
        ];
        for (code, expected) in cases {
            assert_eq!(code.to_string(), expected);
        }
    }

    #[test]
    fn error_debug_format() {
        let err = Error::status(StatusCode::RuntimeError);
        let debug = format!("{err:?}");
        assert!(debug.contains("Error"));
        assert!(debug.contains("Status"));
    }
}