acme_tensor/
error.rs

/*
    Appellation: error <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use crate::shape::error::ShapeError;
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8use strum::{Display, EnumCount, EnumIs, EnumIter, EnumString, VariantNames};
9
10pub type TensorResult<T = ()> = std::result::Result<T, TensorError>;
11
12#[derive(
13    Clone, Debug, Display, EnumCount, EnumIs, Eq, Hash, Ord, PartialEq, PartialOrd, VariantNames,
14)]
15#[cfg_attr(
16    feature = "serde",
17    derive(Deserialize, Serialize),
18    serde(rename_all = "snake_case", untagged)
19)]
20#[repr(usize)]
21#[strum(serialize_all = "snake_case")]
22pub enum TensorError {
23    Arithmetic(ArithmeticError),
24    Shape(ShapeError),
25    Singular,
26    NotScalar,
27    Unknown(String),
28}
29
30unsafe impl Send for TensorError {}
31
32unsafe impl Sync for TensorError {}
33
34impl std::error::Error for TensorError {}
35
36impl From<&str> for TensorError {
37    fn from(error: &str) -> Self {
38        TensorError::Unknown(error.to_string())
39    }
40}
41
42impl From<String> for TensorError {
43    fn from(error: String) -> Self {
44        TensorError::Unknown(error)
45    }
46}
47
48#[derive(
49    Clone,
50    Copy,
51    Debug,
52    Display,
53    EnumCount,
54    EnumIs,
55    EnumIter,
56    EnumString,
57    Eq,
58    Hash,
59    Ord,
60    PartialEq,
61    PartialOrd,
62    VariantNames,
63)]
64#[cfg_attr(
65    feature = "serde",
66    derive(Deserialize, Serialize),
67    serde(rename_all = "snake_case", untagged)
68)]
69#[repr(usize)]
70#[strum(serialize_all = "snake_case")]
71pub enum ArithmeticError {
72    DivisionByZero,
73    Overflow,
74    Underflow,
75}
76
/// Generates the conversions between [`TensorError`] and the error types it wraps.
///
/// For every `(Error => Variant)` pair, this expands to:
/// - `impl From<Error> for TensorError`, which wraps the value in `TensorError::Variant`;
/// - `impl TryFrom<TensorError> for Error`, which unwraps `TensorError::Variant` and returns
///   any other variant unchanged as the `Err` value.
///
/// Accepted forms:
/// - a single `Error => Variant` pair;
/// - a comma-separated list of parenthesized pairs; a trailing comma is allowed.
///
/// `Error` may be any type, including a path such as `shape::ShapeError`. `TensorError` and
/// `TensorResult` must be in scope at the invocation site.
macro_rules! into_tensor_error {
    ($(($error:ty => $kind:ident)),* $(,)?) => {
        $(into_tensor_error!($error => $kind);)*
    };
    ($error:ty => $kind:ident) => {
        impl From<$error> for TensorError {
            fn from(error: $error) -> Self {
                TensorError::$kind(error)
            }
        }

        // Fully qualified so the expansion does not depend on `TryFrom` being in the
        // invoking module's prelude; it is only in the 2021+ edition prelude.
        impl ::std::convert::TryFrom<TensorError> for $error {
            type Error = TensorError;

            fn try_from(error: TensorError) -> TensorResult<Self> {
                match error {
                    TensorError::$kind(inner) => Ok(inner),
                    // Any other variant is handed back untouched so the caller keeps it.
                    other => Err(other),
                }
            }
        }
    };
}
100
101into_tensor_error!(
102    (ArithmeticError => Arithmetic),
103    (ShapeError => Shape)
104);