use std::fmt;
/// Non-success status codes that this module reports as `TFLite status` errors.
///
/// Every variant has an explicit discriminant, which is the raw integer that
/// `StatusCode::from_raw` maps to it. Raw value `0` has no variant here.
///
/// NOTE(review): the numbering presumably mirrors the non-OK values of the
/// TFLite C API `TfLiteStatus` enum; confirm against the C header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatusCode {
    /// Raw value `1`; displayed as "runtime error".
    RuntimeError = 1,
    /// Raw value `2`; displayed as "delegate error".
    DelegateError = 2,
    /// Raw value `3`; displayed as "application error".
    ApplicationError = 3,
    /// Raw value `4`; displayed as "delegate data not found".
    DelegateDataNotFound = 4,
    /// Raw value `5`; displayed as "delegate data write error".
    DelegateDataWriteError = 5,
    /// Raw value `6`; displayed as "delegate data read error".
    DelegateDataReadError = 6,
    /// Raw value `7`; displayed as "unresolved ops".
    UnresolvedOps = 7,
    /// Raw value `8`; displayed as "cancelled".
    Cancelled = 8,
    /// Raw value `9`; displayed as "output shape not known".
    OutputShapeNotKnown = 9,
}

impl StatusCode {
    /// Converts a raw status integer into the variant with that discriminant.
    ///
    /// Returns `None` for any value that is not the discriminant of a variant,
    /// which includes `0`.
    fn from_raw(value: u32) -> Option<Self> {
        // Table of every variant: the lookup compares against the
        // discriminants declared on the enum instead of repeating the numbers.
        const KNOWN: [StatusCode; 9] = [
            StatusCode::RuntimeError,
            StatusCode::DelegateError,
            StatusCode::ApplicationError,
            StatusCode::DelegateDataNotFound,
            StatusCode::DelegateDataWriteError,
            StatusCode::DelegateDataReadError,
            StatusCode::UnresolvedOps,
            StatusCode::Cancelled,
            StatusCode::OutputShapeNotKnown,
        ];
        KNOWN.iter().copied().find(|code| *code as u32 == value)
    }
}

impl fmt::Display for StatusCode {
    /// Writes the lower-case, human-readable name of the status.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self {
            Self::RuntimeError => "runtime error",
            Self::DelegateError => "delegate error",
            Self::ApplicationError => "application error",
            Self::DelegateDataNotFound => "delegate data not found",
            Self::DelegateDataWriteError => "delegate data write error",
            Self::DelegateDataReadError => "delegate data read error",
            Self::UnresolvedOps => "unresolved ops",
            Self::Cancelled => "cancelled",
            Self::OutputShapeNotKnown => "output shape not known",
        };
        f.write_str(text)
    }
}
/// Internal classification of an [`Error`]; private to this module.
#[derive(Debug)]
enum ErrorKind {
/// A non-success status, carried as a [`StatusCode`].
Status(StatusCode),
/// A null pointer was encountered; displayed as "null pointer from C API".
NullPointer,
/// A `libloading` failure, kept so it can be returned from `source()`.
Library(libloading::Error),
/// An invalid argument, with a message describing the problem.
InvalidArgument(String),
}
/// Error type used by this module's fallible operations.
///
/// Pairs an internal error kind with an optional free-form context string;
/// the `Display` output appends the context in parentheses when it is set.
#[derive(Debug)]
pub struct Error {
/// What went wrong.
kind: ErrorKind,
/// Optional extra detail, rendered as ` (context)` by `Display`.
context: Option<String>,
}
impl Error {
    /// Returns `true` when this error wraps a `libloading` failure.
    #[must_use]
    pub fn is_library_error(&self) -> bool {
        matches!(&self.kind, ErrorKind::Library(_))
    }

    /// Returns `true` when this error carries one of the delegate statuses:
    /// `DelegateError`, `DelegateDataNotFound`, `DelegateDataWriteError` or
    /// `DelegateDataReadError`.
    #[must_use]
    pub fn is_delegate_error(&self) -> bool {
        matches!(
            self.status_code(),
            Some(
                StatusCode::DelegateError
                    | StatusCode::DelegateDataNotFound
                    | StatusCode::DelegateDataWriteError
                    | StatusCode::DelegateDataReadError
            )
        )
    }

    /// Returns `true` when this error reports a null pointer.
    #[must_use]
    pub fn is_null_pointer(&self) -> bool {
        matches!(&self.kind, ErrorKind::NullPointer)
    }

    /// Returns `true` when this error reports an invalid argument.
    #[must_use]
    pub fn is_invalid_argument(&self) -> bool {
        matches!(&self.kind, ErrorKind::InvalidArgument(_))
    }

    /// Returns the status code carried by this error, or `None` when the
    /// error is not a status error.
    #[must_use]
    pub fn status_code(&self) -> Option<StatusCode> {
        match &self.kind {
            ErrorKind::Status(code) => Some(*code),
            ErrorKind::NullPointer | ErrorKind::Library(_) | ErrorKind::InvalidArgument(_) => None,
        }
    }

    /// Attaches `context` to the error and returns it.
    ///
    /// Any context already present is replaced, not appended to.
    #[must_use]
    pub fn with_context(self, context: impl Into<String>) -> Self {
        Self {
            context: Some(context.into()),
            ..self
        }
    }
}
impl Error {
    /// Builds a status error for `code`, with no context attached.
    #[must_use]
    pub(crate) fn status(code: StatusCode) -> Self {
        let kind = ErrorKind::Status(code);
        Self {
            kind,
            context: None,
        }
    }

    /// Builds a null-pointer error; `context` is stored as the error context.
    #[must_use]
    pub(crate) fn null_pointer(context: impl Into<String>) -> Self {
        let context = Some(context.into());
        Self {
            kind: ErrorKind::NullPointer,
            context,
        }
    }

    /// Builds an invalid-argument error carrying `msg`, with no context attached.
    #[must_use]
    pub(crate) fn invalid_argument(msg: impl Into<String>) -> Self {
        let kind = ErrorKind::InvalidArgument(msg.into());
        Self {
            kind,
            context: None,
        }
    }
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.kind {
ErrorKind::Status(code) => write!(f, "TFLite status: {code}")?,
ErrorKind::NullPointer => f.write_str("null pointer from C API")?,
ErrorKind::Library(inner) => write!(f, "library loading error: {inner}")?,
ErrorKind::InvalidArgument(msg) => write!(f, "invalid argument: {msg}")?,
}
if let Some(ctx) = &self.context {
write!(f, " ({ctx})")?;
}
Ok(())
}
}
impl std::error::Error for Error {
    /// Returns the wrapped `libloading` error as the underlying cause; every
    /// other kind has no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match &self.kind {
            ErrorKind::Library(cause) => Some(cause),
            ErrorKind::Status(_) | ErrorKind::NullPointer | ErrorKind::InvalidArgument(_) => None,
        }
    }
}
impl From<libloading::Error> for Error {
    /// Wraps a `libloading` error, with no context attached.
    fn from(err: libloading::Error) -> Self {
        let kind = ErrorKind::Library(err);
        Self {
            kind,
            context: None,
        }
    }
}
pub(crate) fn hal_to_result(ret: std::ffi::c_int, context: &str) -> Result<()> {
if ret == 0 {
return Ok(());
}
let os_err = std::io::Error::last_os_error();
Err(Error::status(StatusCode::DelegateError).with_context(format!("{context}: {os_err}")))
}
pub(crate) fn status_to_result(status: u32) -> Result<()> {
if status == 0 {
return Ok(());
}
let code = StatusCode::from_raw(status).unwrap_or(StatusCode::RuntimeError);
Err(Error::status(code))
}
/// Alias for `std::result::Result` with this module's [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
#[cfg(test)]
mod tests {
    use super::*;

    // Raw status 0 is the only value treated as success.
    #[test]
    fn status_ok_is_ok() {
        assert!(status_to_result(0).is_ok());
    }

    #[test]
    fn status_error_maps_correctly() {
        let err = status_to_result(1).unwrap_err();
        assert_eq!(err.status_code(), Some(StatusCode::RuntimeError));
    }

    #[test]
    fn status_delegate_codes() {
        for (raw, expected) in [
            (2, StatusCode::DelegateError),
            (4, StatusCode::DelegateDataNotFound),
            (5, StatusCode::DelegateDataWriteError),
            (6, StatusCode::DelegateDataReadError),
        ] {
            let err = status_to_result(raw).unwrap_err();
            assert_eq!(err.status_code(), Some(expected));
            assert!(err.is_delegate_error());
        }
    }

    // Every known raw value must map to the variant with that exact
    // discriminant, not merely to "some" variant.
    #[test]
    fn status_all_known_codes() {
        for raw in 1..=9_u32 {
            let err = status_to_result(raw).unwrap_err();
            let code = err
                .status_code()
                .expect("known raw status must map to a StatusCode");
            assert_eq!(code as u32, raw);
        }
    }

    #[test]
    fn unknown_status_falls_back_to_runtime_error() {
        let err = status_to_result(42).unwrap_err();
        assert_eq!(err.status_code(), Some(StatusCode::RuntimeError));
    }

    #[test]
    fn null_pointer_error() {
        let err = Error::null_pointer("TfLiteModelCreate");
        assert!(err.is_null_pointer());
        assert!(!err.is_library_error());
        assert!(!err.is_delegate_error());
        assert!(!err.is_invalid_argument());
        assert!(err.status_code().is_none());
        assert!(err.to_string().contains("null pointer"));
        assert!(err.to_string().contains("TfLiteModelCreate"));
    }

    #[test]
    fn invalid_argument_error() {
        let err = Error::invalid_argument("tensor index out of range");
        assert!(err.is_invalid_argument());
        assert!(!err.is_null_pointer());
        assert!(!err.is_library_error());
        assert!(!err.is_delegate_error());
        assert!(err.status_code().is_none());
        assert!(std::error::Error::source(&err).is_none());
        assert!(err.to_string().contains("invalid argument"));
        assert!(err.to_string().contains("tensor index out of range"));
    }

    #[test]
    fn with_context_appends_message() {
        let err = Error::status(StatusCode::RuntimeError).with_context("during AllocateTensors");
        let msg = err.to_string();
        assert!(msg.contains("runtime error"));
        assert!(msg.contains("during AllocateTensors"));
    }

    #[test]
    fn from_libloading_error() {
        // SAFETY: `Library::new` is unsafe because loading a library can run
        // its initialisation code. The path used here is not expected to
        // exist, so the call is expected to fail without loading anything.
        let lib_err = unsafe { libloading::Library::new("__nonexistent__.so") }.unwrap_err();
        let err = Error::from(lib_err);
        assert!(err.is_library_error());
        assert!(err.status_code().is_none());
        assert!(std::error::Error::source(&err).is_some());
    }

    #[test]
    fn display_includes_status_code_name() {
        let err = Error::status(StatusCode::Cancelled);
        assert!(err.to_string().contains("cancelled"));
    }

    #[test]
    fn non_delegate_status_is_not_delegate_error() {
        let err = Error::status(StatusCode::RuntimeError);
        assert!(!err.is_delegate_error());
    }

    #[test]
    fn status_code_discriminant_values() {
        assert_eq!(StatusCode::RuntimeError as u32, 1);
        assert_eq!(StatusCode::DelegateError as u32, 2);
        assert_eq!(StatusCode::ApplicationError as u32, 3);
        assert_eq!(StatusCode::DelegateDataNotFound as u32, 4);
        assert_eq!(StatusCode::DelegateDataWriteError as u32, 5);
        assert_eq!(StatusCode::DelegateDataReadError as u32, 6);
        assert_eq!(StatusCode::UnresolvedOps as u32, 7);
        assert_eq!(StatusCode::Cancelled as u32, 8);
        assert_eq!(StatusCode::OutputShapeNotKnown as u32, 9);
    }

    #[test]
    fn status_code_display_all_variants() {
        let cases = [
            (StatusCode::RuntimeError, "runtime error"),
            (StatusCode::DelegateError, "delegate error"),
            (StatusCode::ApplicationError, "application error"),
            (StatusCode::DelegateDataNotFound, "delegate data not found"),
            (
                StatusCode::DelegateDataWriteError,
                "delegate data write error",
            ),
            (
                StatusCode::DelegateDataReadError,
                "delegate data read error",
            ),
            (StatusCode::UnresolvedOps, "unresolved ops"),
            (StatusCode::Cancelled, "cancelled"),
            (StatusCode::OutputShapeNotKnown, "output shape not known"),
        ];
        for (code, expected) in cases {
            assert_eq!(code.to_string(), expected);
        }
    }

    #[test]
    fn error_debug_format() {
        let err = Error::status(StatusCode::RuntimeError);
        let debug = format!("{err:?}");
        assert!(debug.contains("Error"));
        assert!(debug.contains("Status"));
    }
}