use core::fmt;
/// Errors reported by tensor operations.
///
/// Each variant carries structured context; the `Display` implementation turns it into a
/// one-line human-readable message. The type derives `PartialEq`/`Eq`, so errors can be
/// compared directly (e.g. in tests).
///
/// NOTE(review): this enum is not `#[non_exhaustive]`, so adding a variant is a breaking
/// change for downstream code that matches on it exhaustively.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TensorError {
    /// A shape differed from the shape that was expected.
    ShapeMismatch {
        /// The expected shape (rendered with `{:?}`, e.g. `[2, 3]`).
        expected: Vec<usize>,
        /// The shape that was actually supplied.
        got: Vec<usize>,
    },
    /// An index fell outside the bounds of the dimension it addressed.
    IndexOutOfBounds {
        /// The offending index.
        index: usize,
        /// The dimension (axis) the index was applied to.
        dim: usize,
        /// The size of that dimension.
        size: usize,
    },
    /// The number of dimensions differed from what was expected; the message renders
    /// both values with a `D` suffix (e.g. `2D`).
    DimensionMismatch {
        /// The expected number of dimensions.
        expected: usize,
        /// The number of dimensions actually supplied.
        got: usize,
    },
    /// A data type is not supported by a particular operation.
    UnsupportedDType {
        /// Name of the data type.
        dtype: String,
        /// Name of the operation that does not support it.
        operation: String,
    },
    /// A device is not supported.
    UnsupportedDevice {
        /// Description of the unsupported device.
        device: String,
    },
    /// A memory request could not be satisfied.
    OutOfMemory {
        /// Number of bytes that were requested.
        requested_bytes: usize,
    },
    /// A BLAS call reported a failure.
    BlasError {
        /// Numeric status code; presumably the value reported by the BLAS routine — confirm
        /// at the call sites.
        code: i32,
        /// Human-readable explanation of the failure.
        description: String,
    },
    /// Converting between two sparse formats failed.
    SparseFormatError {
        /// Name of the source format.
        from: String,
        /// Name of the target format.
        to: String,
        /// Human-readable explanation of the failure.
        description: String,
    },
    /// Two shapes could not be broadcast together.
    BroadcastError {
        /// The first shape (rendered with `{:?}`).
        shape1: Vec<usize>,
        /// The second shape (rendered with `{:?}`).
        shape2: Vec<usize>,
    },
    /// A slicing request was invalid.
    SliceError {
        /// Human-readable explanation of the failure.
        description: String,
    },
    /// Moving data between two devices failed.
    DeviceTransferError {
        /// Description of the source device.
        from: String,
        /// Description of the destination device.
        to: String,
        /// Human-readable explanation of the failure.
        description: String,
    },
    /// An allocation failed; unlike `OutOfMemory`, no byte count is recorded.
    AllocationError {
        /// Free-form description of the failure.
        message: String,
    },
    /// A matrix-level operation failed.
    MatrixError {
        /// Free-form description of the failure.
        message: String,
    },
}
impl fmt::Display for TensorError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TensorError::ShapeMismatch { expected, got } => {
write!(f, "Shape mismatch: expected {:?}, got {:?}", expected, got)
}
TensorError::IndexOutOfBounds { index, dim, size } => {
write!(
f,
"Index {} out of bounds for dimension {} (size {})",
index, dim, size
)
}
TensorError::DimensionMismatch { expected, got } => {
write!(
f,
"Dimension mismatch: expected {}D, got {}D",
expected, got
)
}
TensorError::UnsupportedDType { dtype, operation } => {
write!(f, "Unsupported dtype {} for operation {}", dtype, operation)
}
TensorError::UnsupportedDevice { device } => {
write!(f, "Unsupported device: {}", device)
}
TensorError::OutOfMemory { requested_bytes } => {
write!(f, "Out of memory: requested {} bytes", requested_bytes)
}
TensorError::BlasError { code, description } => {
write!(f, "BLAS error (code {}): {}", code, description)
}
TensorError::SparseFormatError {
from,
to,
description,
} => {
write!(
f,
"Sparse format error converting {} to {}: {}",
from, to, description
)
}
TensorError::BroadcastError { shape1, shape2 } => {
write!(f, "Cannot broadcast shapes {:?} and {:?}", shape1, shape2)
}
TensorError::SliceError { description } => {
write!(f, "Slice error: {}", description)
}
TensorError::DeviceTransferError {
from,
to,
description,
} => {
write!(
f,
"Device transfer error from {} to {}: {}",
from, to, description
)
}
TensorError::AllocationError { message } => {
write!(f, "Allocation error: {}", message)
}
TensorError::MatrixError { message } => {
write!(f, "Matrix error: {}", message)
}
}
}
}
// Makes `TensorError` usable as a `std::error::Error` (e.g. behind `dyn Error` or with `?`
// into boxed errors) when the `std` feature is enabled. The impl is empty, so `source()`
// keeps its default of `None`; no variant wraps an inner error value.
#[cfg(feature = "std")]
impl std::error::Error for TensorError {}
/// Convenience alias for a `Result` whose error type is [`TensorError`].
pub type TensorResult<T> = Result<T, TensorError>;
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_display() {
        let err = TensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![3, 2],
        };
        assert_eq!(
            format!("{}", err),
            "Shape mismatch: expected [2, 3], got [3, 2]"
        );
        let err = TensorError::IndexOutOfBounds {
            index: 5,
            dim: 0,
            size: 3,
        };
        assert_eq!(
            format!("{}", err),
            "Index 5 out of bounds for dimension 0 (size 3)"
        );
    }

    /// Pins the rendered message of every remaining variant so that wording changes to the
    /// `Display` impl are deliberate rather than accidental.
    #[test]
    fn test_error_display_all_variants() {
        let cases = [
            (
                TensorError::DimensionMismatch {
                    expected: 2,
                    got: 3,
                },
                "Dimension mismatch: expected 2D, got 3D",
            ),
            (
                TensorError::UnsupportedDType {
                    dtype: "f16".into(),
                    operation: "matmul".into(),
                },
                "Unsupported dtype f16 for operation matmul",
            ),
            (
                TensorError::UnsupportedDevice {
                    device: "tpu".into(),
                },
                "Unsupported device: tpu",
            ),
            (
                TensorError::OutOfMemory {
                    requested_bytes: 1024,
                },
                "Out of memory: requested 1024 bytes",
            ),
            (
                TensorError::BlasError {
                    code: -3,
                    description: "bad lda".into(),
                },
                "BLAS error (code -3): bad lda",
            ),
            (
                TensorError::SparseFormatError {
                    from: "csr".into(),
                    to: "coo".into(),
                    description: "unsorted indices".into(),
                },
                "Sparse format error converting csr to coo: unsorted indices",
            ),
            (
                TensorError::BroadcastError {
                    shape1: vec![2, 3],
                    shape2: vec![4],
                },
                "Cannot broadcast shapes [2, 3] and [4]",
            ),
            (
                TensorError::SliceError {
                    description: "step is zero".into(),
                },
                "Slice error: step is zero",
            ),
            (
                TensorError::DeviceTransferError {
                    from: "cpu".into(),
                    to: "cuda:0".into(),
                    description: "driver unavailable".into(),
                },
                "Device transfer error from cpu to cuda:0: driver unavailable",
            ),
            (
                TensorError::AllocationError {
                    message: "misaligned".into(),
                },
                "Allocation error: misaligned",
            ),
            (
                TensorError::MatrixError {
                    message: "singular".into(),
                },
                "Matrix error: singular",
            ),
        ];
        for (err, expected) in cases.iter() {
            assert_eq!(format!("{}", err), *expected);
        }
    }

    #[test]
    fn test_error_equality() {
        let err1 = TensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![3, 2],
        };
        let err2 = TensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![3, 2],
        };
        assert_eq!(err1, err2);
        let err3 = TensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![2, 3],
        };
        assert_ne!(err1, err3);
    }

    /// Clones compare equal to their source, and identical payloads under different
    /// variants do not compare equal.
    #[test]
    fn test_error_clone_and_variant_inequality() {
        let err = TensorError::BlasError {
            code: -1,
            description: "illegal value".into(),
        };
        assert_eq!(err.clone(), err);

        let alloc = TensorError::AllocationError {
            message: "x".into(),
        };
        let matrix = TensorError::MatrixError {
            message: "x".into(),
        };
        assert_ne!(alloc, matrix);
    }

    /// The `std::error::Error` impl exposes the `Display` text and reports no source.
    #[cfg(feature = "std")]
    #[test]
    fn test_std_error_impl() {
        let err = TensorError::SliceError {
            description: "step is zero".into(),
        };
        let dyn_err: &dyn std::error::Error = &err;
        assert!(dyn_err.source().is_none());
        assert_eq!(format!("{}", dyn_err), "Slice error: step is zero");
    }
}