l2/
errors.rs

1use std::error;
2use std::fmt;
3
/// Errors reported by tensor construction, shape manipulation, tensor operations, and
/// gradient computation (`.backward()`).
///
/// Every variant is a unit variant: no extra context (offending shape, dimension index,
/// underlying cause) is attached. The human-readable text for each variant comes from
/// this type's `Display` implementation.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare errors directly,
/// e.g. `assert_eq!(result, Err(TensorError::ShapeError))`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TensorError {
    /// A tensor has more dimensions than the library supports (currently at most 4).
    MaxDimsError,
    /// The parameters supplied for a tensor are invalid.
    InvalidTensor,
    /// A slice is invalid for the tensor it was applied to.
    SliceError,
    /// A requested view shape is invalid for the tensor.
    ViewError,
    /// The shapes involved are not broadcastable.
    BroadcastError,
    /// The given tensors cannot be operated on together.
    OpError,
    /// The tensors cannot be operated on over the requested dimension.
    DimError,
    /// Matmul operands are not shape-compatible: each must have at least two dimensions,
    /// and their shapes must agree as described by the `Display` message.
    MatmulShapeError,
    /// Tensors differ in shape in a dimension other than the last one.
    ShapeError,
    /// An error occurred while computing `.backward()`.
    GradError,
}
17
18impl fmt::Display for TensorError {
19    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
20        match self {
21            TensorError::MaxDimsError => write!(f, "L2 currently only supports tensors with up to 4 dimensions"),
22            TensorError::InvalidTensor => write!(f, "Invalid parameters for Tensor"),
23            TensorError::SliceError => write!(f, "Invalid slice for Tensor"),
24            TensorError::ViewError => write!(f, "Invalid view shape for Tensor"),
25            TensorError::BroadcastError => write!(f, "Shapes are not broadcastable"),
26            TensorError::OpError => write!(f, "Tensors cannot be operated on"),
27            TensorError::DimError => write!(f, "Tensors cannot be operated on over the given dimension"),
28            TensorError::MatmulShapeError => write!(
29                f,
30                "Tensors must have at least two dimensions and have same shape in all dims except the last dimension"
31            ),
32            TensorError::ShapeError => write!(f, "Tensors must have the same shape in all dims except the last dimension"),
33            TensorError::GradError => write!(f, "Error while computing .backward()"),
34        }
35    }
36}
37
38// This is important for other errors to wrap this one.
39impl error::Error for TensorError {
40    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
41        // Generic error, underlying cause isn't tracked.
42        None
43    }
44}