Skip to main content

coreml_native/
error.rs

//! Error types for the coreml crate.

use std::fmt;

/// Category of failure carried by an [`Error`].
///
/// Each variant renders (via [`fmt::Display`] or [`ErrorKind::as_str`]) as a
/// short lowercase phrase; the per-variant docs below list that phrase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorKind {
    /// Rendered as `model load`.
    ModelLoad,
    /// Rendered as `tensor create`.
    TensorCreate,
    /// Rendered as `prediction`.
    Prediction,
    /// Rendered as `introspection`.
    Introspection,
    /// Rendered as `invalid shape`.
    InvalidShape,
    /// Rendered as `unsupported platform`.
    UnsupportedPlatform,
}

impl ErrorKind {
    /// Returns the short, lowercase phrase used when displaying this kind.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::ModelLoad => "model load",
            Self::TensorCreate => "tensor create",
            Self::Prediction => "prediction",
            Self::Introspection => "introspection",
            Self::InvalidShape => "invalid shape",
            Self::UnsupportedPlatform => "unsupported platform",
        }
    }
}

impl fmt::Display for ErrorKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` (unlike `write!`) honours width/fill/alignment flags such as `{:>12}`;
        // with a plain `{}` the output is identical to writing the phrase directly.
        f.pad(self.as_str())
    }
}

28#[derive(Debug, Clone)]
29pub struct Error {
30    kind: ErrorKind,
31    message: String,
32}
33
34impl Error {
35    pub fn new(kind: ErrorKind, message: impl Into<String>) -> Self {
36        Self { kind, message: message.into() }
37    }
38
39    pub fn kind(&self) -> &ErrorKind { &self.kind }
40    pub fn message(&self) -> &str { &self.message }
41}
42
43impl fmt::Display for Error {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        write!(f, "coreml {}: {}", self.kind, self.message)
46    }
47}
48
49impl std::error::Error for Error {}
50
51pub type Result<T> = std::result::Result<T, Error>;
52
#[cfg(target_vendor = "apple")]
impl Error {
    /// Converts a Foundation `NSError` into an [`Error`] of category `kind`.
    ///
    /// Only the `NSError`'s `localizedDescription` is kept, as the message;
    /// no other fields of `err` are read.
    pub(crate) fn from_nserror(kind: ErrorKind, err: &objc2_foundation::NSError) -> Self {
        let message: String = err.localizedDescription().to_string();
        Error::new(kind, message)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ErrorKind` variant, for tests that must cover the whole enum.
    fn all_kinds() -> [ErrorKind; 6] {
        [
            ErrorKind::ModelLoad,
            ErrorKind::TensorCreate,
            ErrorKind::Prediction,
            ErrorKind::Introspection,
            ErrorKind::InvalidShape,
            ErrorKind::UnsupportedPlatform,
        ]
    }

    #[test]
    fn error_display() {
        let err = Error::new(ErrorKind::ModelLoad, "file not found");
        // Pin the full format so a change to the `coreml ` prefix or the `: `
        // separator is caught, not just the presence of the two parts.
        assert_eq!(err.to_string(), "coreml model load: file not found");
    }

    #[test]
    fn error_kind_display_strings() {
        let expected = [
            (ErrorKind::ModelLoad, "model load"),
            (ErrorKind::TensorCreate, "tensor create"),
            (ErrorKind::Prediction, "prediction"),
            (ErrorKind::Introspection, "introspection"),
            (ErrorKind::InvalidShape, "invalid shape"),
            (ErrorKind::UnsupportedPlatform, "unsupported platform"),
        ];
        for (kind, text) in expected.iter() {
            assert_eq!(kind.to_string(), *text);
        }
    }

    #[test]
    fn error_implements_std_error() {
        let err = Error::new(ErrorKind::Prediction, "fail");
        let dyn_err: &dyn std::error::Error = &err;
        // No wrapped cause is stored, so `source` is the trait default `None`.
        assert!(std::error::Error::source(dyn_err).is_none());
    }

    #[test]
    fn error_kind_accessor() {
        let err = Error::new(ErrorKind::InvalidShape, "mismatch");
        assert_eq!(err.kind(), &ErrorKind::InvalidShape);
    }

    #[test]
    fn error_message_accessor() {
        let err = Error::new(ErrorKind::TensorCreate, String::from("bad dtype"));
        assert_eq!(err.message(), "bad dtype");
    }

    #[test]
    fn clone_preserves_kind_and_message() {
        let err = Error::new(ErrorKind::Introspection, "no outputs");
        let copy = err.clone();
        assert_eq!(copy.kind(), err.kind());
        assert_eq!(copy.message(), err.message());
        assert_eq!(copy.to_string(), err.to_string());
    }

    #[test]
    fn all_error_kinds_distinct() {
        let kinds = all_kinds();
        for (i, a) in kinds.iter().enumerate() {
            for (j, b) in kinds.iter().enumerate() {
                assert_eq!(i == j, a == b);
            }
        }
    }
}