1use std::fmt;
4
/// Broad category of a CoreML error.
///
/// Each variant renders as a short lowercase label via its `Display` impl
/// (e.g. `ModelLoad` -> `model load`). The enum carries no data, so it is
/// `Copy` and `Hash` and can be passed by value or used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorKind {
    /// Model-loading failure; displays as `model load`.
    ModelLoad,
    /// Tensor-creation failure; displays as `tensor create`.
    TensorCreate,
    /// Prediction failure; displays as `prediction`.
    Prediction,
    /// Introspection failure; displays as `introspection`.
    Introspection,
    /// Shape mismatch or invalid shape; displays as `invalid shape`.
    InvalidShape,
    /// Unsupported platform; displays as `unsupported platform`.
    /// NOTE(review): presumably used on non-Apple targets — confirm at call sites.
    UnsupportedPlatform,
}
14
15impl fmt::Display for ErrorKind {
16 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17 match self {
18 Self::ModelLoad => write!(f, "model load"),
19 Self::TensorCreate => write!(f, "tensor create"),
20 Self::Prediction => write!(f, "prediction"),
21 Self::Introspection => write!(f, "introspection"),
22 Self::InvalidShape => write!(f, "invalid shape"),
23 Self::UnsupportedPlatform => write!(f, "unsupported platform"),
24 }
25 }
26}
27
28#[derive(Debug, Clone)]
29pub struct Error {
30 kind: ErrorKind,
31 message: String,
32}
33
34impl Error {
35 pub fn new(kind: ErrorKind, message: impl Into<String>) -> Self {
36 Self { kind, message: message.into() }
37 }
38
39 pub fn kind(&self) -> &ErrorKind { &self.kind }
40 pub fn message(&self) -> &str { &self.message }
41}
42
43impl fmt::Display for Error {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 write!(f, "coreml {}: {}", self.kind, self.message)
46 }
47}
48
49impl std::error::Error for Error {}
50
51pub type Result<T> = std::result::Result<T, Error>;
52
#[cfg(target_vendor = "apple")]
impl Error {
    /// Builds an `Error` of category `kind` from an Objective-C `NSError`.
    ///
    /// The message is the error's `localizedDescription`, converted to a Rust
    /// `String`.
    pub(crate) fn from_nserror(kind: ErrorKind, err: &objc2_foundation::NSError) -> Self {
        let message = err.localizedDescription().to_string();
        Self::new(kind, message)
    }
}
60
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn error_display() {
        let err = Error::new(ErrorKind::ModelLoad, "file not found");
        let s = format!("{err}");
        assert!(s.contains("model load"));
        assert!(s.contains("file not found"));
        // Pin the full layout, not just its pieces.
        assert_eq!(s, "coreml model load: file not found");
    }

    #[test]
    fn error_implements_std_error() {
        let err = Error::new(ErrorKind::Prediction, "fail");
        let as_dyn: &dyn std::error::Error = &err;
        // Leaf error: no underlying cause is exposed.
        assert!(as_dyn.source().is_none());
    }

    #[test]
    fn error_kind_accessor() {
        let err = Error::new(ErrorKind::InvalidShape, "mismatch");
        assert_eq!(err.kind(), &ErrorKind::InvalidShape);
    }

    #[test]
    fn error_message_accessor() {
        let err = Error::new(ErrorKind::TensorCreate, String::from("bad dtype"));
        assert_eq!(err.message(), "bad dtype");
    }

    #[test]
    fn error_kind_display_labels() {
        let cases = [
            (ErrorKind::ModelLoad, "model load"),
            (ErrorKind::TensorCreate, "tensor create"),
            (ErrorKind::Prediction, "prediction"),
            (ErrorKind::Introspection, "introspection"),
            (ErrorKind::InvalidShape, "invalid shape"),
            (ErrorKind::UnsupportedPlatform, "unsupported platform"),
        ];
        for (kind, label) in cases.iter() {
            assert_eq!(kind.to_string(), *label);
        }
    }

    #[test]
    fn all_error_kinds_distinct() {
        let kinds = [
            ErrorKind::ModelLoad,
            ErrorKind::TensorCreate,
            ErrorKind::Prediction,
            ErrorKind::Introspection,
            ErrorKind::InvalidShape,
            ErrorKind::UnsupportedPlatform,
        ];
        for (i, a) in kinds.iter().enumerate() {
            for (j, b) in kinds.iter().enumerate() {
                assert_eq!(i == j, a == b);
            }
        }
    }
}