1use core::fmt;
6
/// Errors produced by tensor operations.
///
/// Each variant carries the context needed to render a descriptive message
/// through its `Display` implementation. Variants hold only plain data
/// (shapes, indices, byte counts, codes and descriptive strings), so values
/// can be compared, cloned and hashed — e.g. collected into a `HashSet` to
/// de-duplicate reported errors.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TensorError {
    /// A tensor did not have the shape an operation required.
    ShapeMismatch {
        /// The shape the operation required.
        expected: Vec<usize>,
        /// The shape that was actually supplied.
        got: Vec<usize>,
    },

    /// An index fell outside the valid range of one dimension.
    IndexOutOfBounds {
        /// The offending index.
        index: usize,
        /// The dimension the index was applied to.
        dim: usize,
        /// The size of that dimension.
        size: usize,
    },

    /// A tensor had a different number of dimensions than expected.
    DimensionMismatch {
        /// The required number of dimensions.
        expected: usize,
        /// The actual number of dimensions.
        got: usize,
    },

    /// An operation does not support the given data type.
    UnsupportedDType {
        /// Name of the data type.
        dtype: String,
        /// Name of the operation that rejected it.
        operation: String,
    },

    /// The requested device is not supported.
    UnsupportedDevice {
        /// Name of the device.
        device: String,
    },

    /// A request for memory could not be satisfied.
    OutOfMemory {
        /// Size of the failed request, in bytes.
        requested_bytes: usize,
    },

    /// A BLAS routine reported a failure.
    BlasError {
        /// Numeric code associated with the failure.
        code: i32,
        /// Human-readable explanation of the failure.
        description: String,
    },

    /// Converting between sparse storage formats failed.
    SparseFormatError {
        /// Name of the source format.
        from: String,
        /// Name of the target format.
        to: String,
        /// Human-readable explanation of the failure.
        description: String,
    },

    /// Two shapes could not be broadcast together.
    BroadcastError {
        /// The first shape.
        shape1: Vec<usize>,
        /// The second shape.
        shape2: Vec<usize>,
    },

    /// A slicing operation failed.
    SliceError {
        /// Human-readable explanation of the failure.
        description: String,
    },

    /// Moving data between devices failed.
    DeviceTransferError {
        /// Name of the source device.
        from: String,
        /// Name of the destination device.
        to: String,
        /// Human-readable explanation of the failure.
        description: String,
    },

    /// An allocation failed; `message` explains why.
    AllocationError {
        /// Human-readable explanation of the failure.
        message: String,
    },

    /// A matrix operation failed; `message` explains why.
    MatrixError {
        /// Human-readable explanation of the failure.
        message: String,
    },
}
110
111impl fmt::Display for TensorError {
112 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113 match self {
114 TensorError::ShapeMismatch { expected, got } => {
115 write!(f, "Shape mismatch: expected {:?}, got {:?}", expected, got)
116 }
117 TensorError::IndexOutOfBounds { index, dim, size } => {
118 write!(
119 f,
120 "Index {} out of bounds for dimension {} (size {})",
121 index, dim, size
122 )
123 }
124 TensorError::DimensionMismatch { expected, got } => {
125 write!(
126 f,
127 "Dimension mismatch: expected {}D, got {}D",
128 expected, got
129 )
130 }
131 TensorError::UnsupportedDType { dtype, operation } => {
132 write!(f, "Unsupported dtype {} for operation {}", dtype, operation)
133 }
134 TensorError::UnsupportedDevice { device } => {
135 write!(f, "Unsupported device: {}", device)
136 }
137 TensorError::OutOfMemory { requested_bytes } => {
138 write!(f, "Out of memory: requested {} bytes", requested_bytes)
139 }
140 TensorError::BlasError { code, description } => {
141 write!(f, "BLAS error (code {}): {}", code, description)
142 }
143 TensorError::SparseFormatError {
144 from,
145 to,
146 description,
147 } => {
148 write!(
149 f,
150 "Sparse format error converting {} to {}: {}",
151 from, to, description
152 )
153 }
154 TensorError::BroadcastError { shape1, shape2 } => {
155 write!(f, "Cannot broadcast shapes {:?} and {:?}", shape1, shape2)
156 }
157 TensorError::SliceError { description } => {
158 write!(f, "Slice error: {}", description)
159 }
160 TensorError::DeviceTransferError {
161 from,
162 to,
163 description,
164 } => {
165 write!(
166 f,
167 "Device transfer error from {} to {}: {}",
168 from, to, description
169 )
170 }
171 TensorError::AllocationError { message } => {
172 write!(f, "Allocation error: {}", message)
173 }
174 TensorError::MatrixError { message } => {
175 write!(f, "Matrix error: {}", message)
176 }
177 }
178 }
179}
180
/// With the `std` feature enabled, `TensorError` implements the standard
/// error trait so it can be boxed as `Box<dyn std::error::Error>`.
///
/// No variant wraps another error, so there is never an underlying `source`.
#[cfg(feature = "std")]
impl std::error::Error for TensorError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        None
    }
}
183
184pub type TensorResult<T> = Result<T, TensorError>;
186
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_display() {
        let err = TensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![3, 2],
        };
        assert_eq!(
            format!("{}", err),
            "Shape mismatch: expected [2, 3], got [3, 2]"
        );

        let err = TensorError::IndexOutOfBounds {
            index: 5,
            dim: 0,
            size: 3,
        };
        assert_eq!(
            format!("{}", err),
            "Index 5 out of bounds for dimension 0 (size 3)"
        );
    }

    /// Pins the `Display` text of every remaining variant so that an
    /// accidental wording change is caught by the test suite.
    #[test]
    fn test_error_display_all_variants() {
        let cases = [
            (
                TensorError::DimensionMismatch {
                    expected: 2,
                    got: 3,
                },
                "Dimension mismatch: expected 2D, got 3D",
            ),
            (
                TensorError::UnsupportedDType {
                    dtype: "f16".to_string(),
                    operation: "matmul".to_string(),
                },
                "Unsupported dtype f16 for operation matmul",
            ),
            (
                TensorError::UnsupportedDevice {
                    device: "tpu".to_string(),
                },
                "Unsupported device: tpu",
            ),
            (
                TensorError::OutOfMemory {
                    requested_bytes: 1024,
                },
                "Out of memory: requested 1024 bytes",
            ),
            (
                TensorError::BlasError {
                    code: -3,
                    description: "illegal value".to_string(),
                },
                "BLAS error (code -3): illegal value",
            ),
            (
                TensorError::SparseFormatError {
                    from: "csr".to_string(),
                    to: "coo".to_string(),
                    description: "bad indptr".to_string(),
                },
                "Sparse format error converting csr to coo: bad indptr",
            ),
            (
                TensorError::BroadcastError {
                    shape1: vec![2, 3],
                    shape2: vec![4],
                },
                "Cannot broadcast shapes [2, 3] and [4]",
            ),
            (
                TensorError::SliceError {
                    description: "step is zero".to_string(),
                },
                "Slice error: step is zero",
            ),
            (
                TensorError::DeviceTransferError {
                    from: "cpu".to_string(),
                    to: "cuda:0".to_string(),
                    description: "no driver".to_string(),
                },
                "Device transfer error from cpu to cuda:0: no driver",
            ),
            (
                TensorError::AllocationError {
                    message: "size overflow".to_string(),
                },
                "Allocation error: size overflow",
            ),
            (
                TensorError::MatrixError {
                    message: "singular".to_string(),
                },
                "Matrix error: singular",
            ),
        ];
        for (err, expected) in &cases {
            assert_eq!(format!("{}", err), *expected);
        }
    }

    #[test]
    fn test_error_equality() {
        let err1 = TensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![3, 2],
        };
        let err2 = TensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![3, 2],
        };
        assert_eq!(err1, err2);

        let err3 = TensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![2, 3],
        };
        assert_ne!(err1, err3);
    }

    /// `TensorResult` must carry `TensorError` and work with `?` propagation.
    #[test]
    fn test_tensor_result_alias() {
        fn fail() -> TensorResult<u8> {
            Err(TensorError::OutOfMemory { requested_bytes: 1 })
        }
        fn forward() -> TensorResult<u8> {
            let value = fail()?;
            Ok(value + 1)
        }
        assert_eq!(
            forward(),
            Err(TensorError::OutOfMemory { requested_bytes: 1 })
        );
    }

    /// With `std`, the error is usable as a `dyn Error` and has no source.
    #[cfg(feature = "std")]
    #[test]
    fn test_error_trait_object() {
        let err: Box<dyn std::error::Error> = Box::new(TensorError::SliceError {
            description: "x".to_string(),
        });
        assert_eq!(err.to_string(), "Slice error: x");
        assert!(err.source().is_none());
    }
}