1use std::error;
2use std::fmt;
3
/// Errors produced by tensor construction and tensor operations.
///
/// Every variant is a unit variant; the human-readable description of each
/// one is provided by this type's `Display` implementation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TensorError {
    /// The tensor has more dimensions than the library supports
    /// (the `Display` message states a limit of 4).
    MaxDimsError,
    /// The parameters supplied for a tensor are invalid.
    InvalidTensor,
    /// The requested slice is not valid for the tensor.
    SliceError,
    /// The requested view shape is not valid for the tensor.
    ViewError,
    /// The shapes involved cannot be broadcast together.
    BroadcastError,
    /// The tensors cannot be operated on.
    OpError,
    /// The tensors cannot be operated on over the given dimension.
    DimError,
    /// The operand shapes are not valid for a matrix multiplication.
    MatmulShapeError,
    /// The tensors do not have the same shape in all dims except the last one.
    ShapeError,
    /// Something went wrong while computing `.backward()`.
    GradError,
}
17
18impl fmt::Display for TensorError {
19 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
20 match self {
21 TensorError::MaxDimsError => write!(f, "L2 currently only supports tensors with up to 4 dimensions"),
22 TensorError::InvalidTensor => write!(f, "Invalid parameters for Tensor"),
23 TensorError::SliceError => write!(f, "Invalid slice for Tensor"),
24 TensorError::ViewError => write!(f, "Invalid view shape for Tensor"),
25 TensorError::BroadcastError => write!(f, "Shapes are not broadcastable"),
26 TensorError::OpError => write!(f, "Tensors cannot be operated on"),
27 TensorError::DimError => write!(f, "Tensors cannot be operated on over the given dimension"),
28 TensorError::MatmulShapeError => write!(
29 f,
30 "Tensors must have at least two dimensions and have same shape in all dims except the last dimension"
31 ),
32 TensorError::ShapeError => write!(f, "Tensors must have the same shape in all dims except the last dimension"),
33 TensorError::GradError => write!(f, "Error while computing .backward()"),
34 }
35 }
36}
37
38impl error::Error for TensorError {
40 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
41 None
43 }
44}