1use core::fmt;
6
/// Error type returned by fallible tensor operations.
///
/// Every variant is struct-like; its named fields carry the values that
/// describe what went wrong.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TensorError {
    /// A tensor's shape did not match the shape an operation required.
    ShapeMismatch {
        /// The required shape.
        expected: Vec<usize>,
        /// The shape that was actually supplied.
        got: Vec<usize>,
    },

    /// An index fell outside the valid range of one dimension.
    IndexOutOfBounds {
        /// The offending index.
        index: usize,
        /// The dimension that was being indexed.
        dim: usize,
        /// The size of that dimension.
        size: usize,
    },

    /// A tensor had a different number of dimensions than required.
    DimensionMismatch {
        /// The required number of dimensions.
        expected: usize,
        /// The number of dimensions actually supplied.
        got: usize,
    },

    /// An operation does not support the given dtype.
    UnsupportedDType {
        /// Name of the rejected dtype.
        dtype: String,
        /// Name of the operation that rejected it.
        operation: String,
    },

    /// The requested device is not supported.
    UnsupportedDevice {
        /// Name of the device.
        device: String,
    },

    /// A memory request could not be satisfied.
    OutOfMemory {
        /// Size of the failed request, in bytes.
        requested_bytes: usize,
    },

    /// A BLAS call failed.
    BlasError {
        /// Numeric error code (presumably the status/`info` value reported
        /// by the BLAS backend — confirm against the call sites).
        code: i32,
        /// Human-readable explanation of the failure.
        description: String,
    },

    /// Converting between sparse storage formats failed.
    SparseFormatError {
        /// Name of the source format.
        from: String,
        /// Name of the target format.
        to: String,
        /// Why the conversion failed.
        description: String,
    },

    /// Two shapes could not be broadcast together.
    BroadcastError {
        /// Shape of the first operand.
        shape1: Vec<usize>,
        /// Shape of the second operand.
        shape2: Vec<usize>,
    },

    /// A slicing request was rejected.
    SliceError {
        /// Why the slice was rejected.
        description: String,
    },

    /// Moving data between devices failed.
    DeviceTransferError {
        /// Name of the source device.
        from: String,
        /// Name of the destination device.
        to: String,
        /// Why the transfer failed.
        description: String,
    },
}

99impl fmt::Display for TensorError {
100 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101 match self {
102 TensorError::ShapeMismatch { expected, got } => {
103 write!(f, "Shape mismatch: expected {:?}, got {:?}", expected, got)
104 }
105 TensorError::IndexOutOfBounds { index, dim, size } => {
106 write!(f, "Index {} out of bounds for dimension {} (size {})", index, dim, size)
107 }
108 TensorError::DimensionMismatch { expected, got } => {
109 write!(f, "Dimension mismatch: expected {}D, got {}D", expected, got)
110 }
111 TensorError::UnsupportedDType { dtype, operation } => {
112 write!(f, "Unsupported dtype {} for operation {}", dtype, operation)
113 }
114 TensorError::UnsupportedDevice { device } => {
115 write!(f, "Unsupported device: {}", device)
116 }
117 TensorError::OutOfMemory { requested_bytes } => {
118 write!(f, "Out of memory: requested {} bytes", requested_bytes)
119 }
120 TensorError::BlasError { code, description } => {
121 write!(f, "BLAS error (code {}): {}", code, description)
122 }
123 TensorError::SparseFormatError { from, to, description } => {
124 write!(f, "Sparse format error converting {} to {}: {}", from, to, description)
125 }
126 TensorError::BroadcastError { shape1, shape2 } => {
127 write!(f, "Cannot broadcast shapes {:?} and {:?}", shape1, shape2)
128 }
129 TensorError::SliceError { description } => {
130 write!(f, "Slice error: {}", description)
131 }
132 TensorError::DeviceTransferError { from, to, description } => {
133 write!(f, "Device transfer error from {} to {}: {}", from, to, description)
134 }
135 }
136 }
137}
138
/// Makes `TensorError` usable as a standard error (for example boxed as
/// `Box<dyn std::error::Error>`) when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for TensorError {
    // No variant wraps another error value, so there is never an
    // underlying cause to report; this matches the trait's default.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        None
    }
}

142pub type TensorResult<T> = Result<T, TensorError>;
144
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_display() {
        let err = TensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![3, 2],
        };
        assert_eq!(
            format!("{}", err),
            "Shape mismatch: expected [2, 3], got [3, 2]"
        );

        let err = TensorError::IndexOutOfBounds {
            index: 5,
            dim: 0,
            size: 3,
        };
        assert_eq!(
            format!("{}", err),
            "Index 5 out of bounds for dimension 0 (size 3)"
        );
    }

    // Pins the `Display` text of every variant not covered above, so an
    // accidental change to a user-facing message fails a test.
    #[test]
    fn test_error_display_remaining_variants() {
        let cases = [
            (
                TensorError::DimensionMismatch { expected: 2, got: 3 },
                "Dimension mismatch: expected 2D, got 3D",
            ),
            (
                TensorError::UnsupportedDType {
                    dtype: "f16".into(),
                    operation: "matmul".into(),
                },
                "Unsupported dtype f16 for operation matmul",
            ),
            (
                TensorError::UnsupportedDevice { device: "tpu".into() },
                "Unsupported device: tpu",
            ),
            (
                TensorError::OutOfMemory { requested_bytes: 1024 },
                "Out of memory: requested 1024 bytes",
            ),
            (
                TensorError::BlasError {
                    code: -3,
                    description: "bad lda".into(),
                },
                "BLAS error (code -3): bad lda",
            ),
            (
                TensorError::SparseFormatError {
                    from: "csr".into(),
                    to: "coo".into(),
                    description: "unsorted indices".into(),
                },
                "Sparse format error converting csr to coo: unsorted indices",
            ),
            (
                TensorError::BroadcastError {
                    shape1: vec![2, 3],
                    shape2: vec![4],
                },
                "Cannot broadcast shapes [2, 3] and [4]",
            ),
            (
                TensorError::SliceError {
                    description: "step must be non-zero".into(),
                },
                "Slice error: step must be non-zero",
            ),
            (
                TensorError::DeviceTransferError {
                    from: "cpu".into(),
                    to: "cuda:0".into(),
                    description: "driver not loaded".into(),
                },
                "Device transfer error from cpu to cuda:0: driver not loaded",
            ),
        ];
        for (err, expected) in cases.iter() {
            assert_eq!(format!("{}", err), *expected);
        }
    }

    #[test]
    fn test_error_equality() {
        let err1 = TensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![3, 2],
        };
        let err2 = TensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![3, 2],
        };
        assert_eq!(err1, err2);

        let err3 = TensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![2, 3],
        };
        assert_ne!(err1, err3);

        // Different variants never compare equal.
        assert_ne!(
            TensorError::OutOfMemory { requested_bytes: 1 },
            TensorError::SliceError { description: "x".into() }
        );
    }

    #[test]
    fn test_error_clone_and_result_alias() {
        // Small helper returning the crate's result alias.
        fn checked_index(index: usize, size: usize) -> TensorResult<usize> {
            if index < size {
                Ok(index)
            } else {
                Err(TensorError::IndexOutOfBounds { index, dim: 0, size })
            }
        }

        let err = TensorError::BlasError {
            code: -4,
            description: "invalid lda".into(),
        };
        assert_eq!(err.clone(), err);

        assert_eq!(checked_index(1, 3), Ok(1));
        assert_eq!(
            checked_index(3, 3),
            Err(TensorError::IndexOutOfBounds {
                index: 3,
                dim: 0,
                size: 3,
            })
        );
    }

    // The `std::error::Error` impl only exists with the `std` feature.
    #[cfg(feature = "std")]
    #[test]
    fn test_error_trait_object() {
        let err: Box<dyn std::error::Error> =
            Box::new(TensorError::OutOfMemory { requested_bytes: 8 });
        assert_eq!(format!("{}", err), "Out of memory: requested 8 bytes");
        assert!(err.source().is_none());
    }
}