1pub type EtensorResult<T> = Result<T, EtensorError>;
8
/// Errors produced by etensor operations.
///
/// The `*Mismatch` variants record both what an operation required and what it
/// actually received; the remaining variants carry a free-form message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EtensorError {
    /// A tensor's shape differs from the shape the operation required.
    ShapeMismatch {
        /// Shape the operation required.
        expected: Vec<usize>,
        /// Shape that was actually supplied.
        got: Vec<usize>,
    },

    /// A tensor is not on the device the operation required
    /// (device names such as `"cpu"` or `"cuda:0"`).
    DeviceMismatch {
        /// Device the operation required.
        expected: String,
        /// Device the offending tensor was found on.
        got: String,
    },

    /// A tensor's dtype differs from the dtype the operation required.
    DTypeMismatch {
        /// Dtype name the operation required.
        expected: String,
        /// Dtype name that was actually found.
        got: String,
    },

    /// Autograd-related failure, described by a message.
    AutogradError(String),

    /// Backend-related failure, described by a message.
    BackendError(String),

    /// I/O failure, kept only as a message string
    /// (no underlying `std::io::Error` is stored).
    IoError(String),

    /// Unexpected internal condition, described by a message.
    InternalError(String),
}
37
38impl std::fmt::Display for EtensorError {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 match self {
41 EtensorError::ShapeMismatch { expected, got } => {
42 write!(
43 f,
44 "ShapeMismatch: Operation expected shape {:?}, but got {:?}",
45 expected, got
46 )
47 }
48 EtensorError::DeviceMismatch { expected, got } => {
49 write!(
50 f,
51 "DeviceMismatch: Operation expected tensors on '{}', but found tensor on '{}'",
52 expected, got
53 )
54 }
55 EtensorError::DTypeMismatch { expected, got } => {
56 write!(
57 f,
58 "DTypeMismatch: Operation expected dtype '{}', but found '{}'",
59 expected, got
60 )
61 }
62 EtensorError::AutogradError(msg) => write!(f, "AutogradError: {}", msg),
63 EtensorError::BackendError(msg) => write!(f, "BackendError: {}", msg),
64 EtensorError::IoError(msg) => write!(f, "IoError: {}", msg),
65 EtensorError::InternalError(msg) => write!(f, "InternalError: {}", msg),
66 }
67 }
68}
69
70impl std::error::Error for EtensorError {}
73
#[cfg(test)]
mod tests {
    use super::*;

    /// Device label the mock assigns to an operand.
    fn device_name(is_cpu: bool) -> &'static str {
        if is_cpu { "cpu" } else { "cuda" }
    }

    /// Mock binary op that succeeds only when both operands share a device.
    ///
    /// The left operand's device is treated as the expected one, so the
    /// reported error names both actual devices whichever side is on CUDA.
    fn mock_add(a_is_cpu: bool, b_is_cpu: bool) -> EtensorResult<f32> {
        if a_is_cpu != b_is_cpu {
            return Err(EtensorError::DeviceMismatch {
                expected: device_name(a_is_cpu).to_string(),
                got: device_name(b_is_cpu).to_string(),
            });
        }
        Ok(42.0)
    }

    /// Chains `mock_add` through `?` so propagation is actually exercised.
    fn mock_add_then_double(a_is_cpu: bool, b_is_cpu: bool) -> EtensorResult<f32> {
        let sum = mock_add(a_is_cpu, b_is_cpu)?;
        Ok(sum * 2.0)
    }

    #[test]
    fn test_shape_mismatch_formatting() {
        let err = EtensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![3, 2],
        };
        assert_eq!(
            err.to_string(),
            "ShapeMismatch: Operation expected shape [2, 3], but got [3, 2]"
        );
    }

    #[test]
    fn test_device_mismatch_formatting() {
        let err = EtensorError::DeviceMismatch {
            expected: "cuda:0".to_string(),
            got: "cpu".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "DeviceMismatch: Operation expected tensors on 'cuda:0', but found tensor on 'cpu'"
        );
    }

    #[test]
    fn test_dtype_mismatch_formatting() {
        let err = EtensorError::DTypeMismatch {
            expected: "f32".to_string(),
            got: "i64".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "DTypeMismatch: Operation expected dtype 'f32', but found 'i64'"
        );
    }

    #[test]
    fn test_message_variant_formatting() {
        let cases = [
            (
                EtensorError::AutogradError("no grad".to_string()),
                "AutogradError: no grad",
            ),
            (
                EtensorError::BackendError("kernel failed".to_string()),
                "BackendError: kernel failed",
            ),
            (
                EtensorError::IoError("file not found".to_string()),
                "IoError: file not found",
            ),
            (
                EtensorError::InternalError("invariant broken".to_string()),
                "InternalError: invariant broken",
            ),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }

    #[test]
    fn test_usable_as_std_error() {
        let boxed: Box<dyn std::error::Error> =
            Box::new(EtensorError::InternalError("oops".to_string()));
        // No variant wraps another error, so there is no underlying source.
        assert!(boxed.source().is_none());
        assert_eq!(boxed.to_string(), "InternalError: oops");
    }

    #[test]
    fn test_result_propagation() {
        assert_eq!(mock_add(true, true), Ok(42.0));

        let err = mock_add(true, false).unwrap_err();
        assert!(matches!(err, EtensorError::DeviceMismatch { .. }));
        assert_eq!(
            err,
            EtensorError::DeviceMismatch {
                expected: "cpu".to_string(),
                got: "cuda".to_string(),
            }
        );

        // Swapping the operands must swap the reported devices as well.
        assert_eq!(
            mock_add(false, true).unwrap_err(),
            EtensorError::DeviceMismatch {
                expected: "cuda".to_string(),
                got: "cpu".to_string(),
            }
        );

        // `?` lets successes continue and forwards errors unchanged.
        assert_eq!(mock_add_then_double(true, true), Ok(84.0));
        assert_eq!(mock_add_then_double(true, false), Err(err));
    }
}