1use crate::shape::error::ShapeError;
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8use strum::{Display, EnumCount, EnumIs, EnumIter, EnumString, VariantNames};
9
10pub type TensorResult<T = ()> = std::result::Result<T, TensorError>;
11
12#[derive(
13 Clone, Debug, Display, EnumCount, EnumIs, Eq, Hash, Ord, PartialEq, PartialOrd, VariantNames,
14)]
15#[cfg_attr(
16 feature = "serde",
17 derive(Deserialize, Serialize),
18 serde(rename_all = "snake_case", untagged)
19)]
20#[repr(usize)]
21#[strum(serialize_all = "snake_case")]
22pub enum TensorError {
23 Arithmetic(ArithmeticError),
24 Shape(ShapeError),
25 Singular,
26 NotScalar,
27 Unknown(String),
28}
29
30unsafe impl Send for TensorError {}
31
32unsafe impl Sync for TensorError {}
33
34impl std::error::Error for TensorError {}
35
36impl From<&str> for TensorError {
37 fn from(error: &str) -> Self {
38 TensorError::Unknown(error.to_string())
39 }
40}
41
42impl From<String> for TensorError {
43 fn from(error: String) -> Self {
44 TensorError::Unknown(error)
45 }
46}
47
48#[derive(
49 Clone,
50 Copy,
51 Debug,
52 Display,
53 EnumCount,
54 EnumIs,
55 EnumIter,
56 EnumString,
57 Eq,
58 Hash,
59 Ord,
60 PartialEq,
61 PartialOrd,
62 VariantNames,
63)]
64#[cfg_attr(
65 feature = "serde",
66 derive(Deserialize, Serialize),
67 serde(rename_all = "snake_case", untagged)
68)]
69#[repr(usize)]
70#[strum(serialize_all = "snake_case")]
71pub enum ArithmeticError {
72 DivisionByZero,
73 Overflow,
74 Underflow,
75}
76
/// Generates the conversions between a sub-error type and the `TensorError`
/// variant that wraps it.
///
/// For each `(Error => Variant)` pair this implements:
/// * `From<Error> for TensorError`, which wraps the value in `TensorError::Variant`;
/// * `TryFrom<TensorError> for Error`, which unwraps `TensorError::Variant` and
///   returns any other variant unchanged as the `Err` value.
///
/// The macro accepts either a single bare `Error => Variant` pair or a
/// comma-separated list of parenthesized pairs. A trailing comma is allowed
/// in the list form.
macro_rules! into_tensor_error {
    ($(($error:ident => $kind:ident)),* $(,)?) => {
        $(into_tensor_error!($error => $kind);)*
    };
    ($error:ident => $kind:ident) => {
        impl From<$error> for TensorError {
            fn from(error: $error) -> Self {
                TensorError::$kind(error)
            }
        }

        // Fully qualified so the expansion does not depend on `TryFrom` being in the prelude.
        impl ::std::convert::TryFrom<TensorError> for $error {
            type Error = TensorError;

            fn try_from(error: TensorError) -> TensorResult<$error> {
                match error {
                    TensorError::$kind(inner) => Ok(inner),
                    // Hand back any other variant untouched so the caller keeps the original error.
                    other => Err(other),
                }
            }
        }
    };
}
100
101into_tensor_error!(
102 (ArithmeticError => Arithmetic),
103 (ShapeError => Shape)
104);