use thiserror::Error;
use crate::device::Device;
use crate::dtype::DType;
/// Error type returned by the crate's fallible operations.
///
/// Each variant's `Display` text is generated by `thiserror` from its
/// `#[error(...)]` attribute. The type derives `Clone` and `PartialEq`, so
/// errors can be duplicated and compared directly (e.g. with `assert_eq!`).
#[derive(Error, Debug, Clone, PartialEq)]
pub enum Error {
    /// A shape did not match the shape that was expected.
    #[error("Shape mismatch: expected {expected:?}, got {actual:?}")]
    ShapeMismatch {
        /// The shape that was required.
        expected: Vec<usize>,
        /// The shape that was actually supplied.
        actual: Vec<usize>,
    },
    /// A `DType` did not match the one that was expected.
    #[error("DType mismatch: expected {expected:?}, got {actual:?}")]
    DTypeMismatch {
        /// The `DType` that was required.
        expected: DType,
        /// The `DType` that was actually supplied.
        actual: DType,
    },
    /// A `Device` did not match the one that was expected.
    #[error("Device mismatch: expected {expected:?}, got {actual:?}")]
    DeviceMismatch {
        /// The `Device` that was required.
        expected: Device,
        /// The `Device` that was actually supplied.
        actual: Device,
    },
    /// A dimension index is not valid for a tensor with `ndim` dimensions.
    #[error("Invalid dimension: index {index} for tensor with {ndim} dimensions")]
    InvalidDimension {
        /// The offending dimension index. It is signed, presumably so that
        /// negative (count-from-the-end) indices can be reported — confirm
        /// against callers.
        index: i64,
        /// Number of dimensions of the tensor the index was applied to.
        ndim: usize,
    },
    /// An index fell outside a dimension of length `size`.
    #[error("Index out of bounds: index {index} for dimension of size {size}")]
    IndexOutOfBounds {
        /// The offending index.
        index: usize,
        /// Length of the dimension being indexed.
        size: usize,
    },
    /// A memory allocation of `size` bytes on `device` failed.
    #[error("Memory allocation failed: requested {size} bytes on {device:?}")]
    AllocationFailed {
        /// Number of bytes that were requested.
        size: usize,
        /// Device the allocation was attempted on.
        device: Device,
    },
    /// The given `device` is not available.
    #[error("Device not available: {device:?}")]
    DeviceNotAvailable {
        /// The unavailable device.
        device: Device,
    },
    /// An operation was invalid; `message` carries the details.
    /// Can be built with `Error::invalid_operation`.
    #[error("Invalid operation: {message}")]
    InvalidOperation {
        /// Human-readable description of what was invalid.
        message: String,
    },
    /// Two shapes could not be broadcast together.
    #[error("Cannot broadcast shapes {shape1:?} and {shape2:?}")]
    BroadcastError {
        /// First shape involved in the broadcast.
        shape1: Vec<usize>,
        /// Second shape involved in the broadcast.
        shape2: Vec<usize>,
    },
    /// The operation does not support empty tensors.
    #[error("Operation not supported on empty tensor")]
    EmptyTensor,
    /// The operation requires a contiguous tensor.
    #[error("Operation requires contiguous tensor")]
    NotContiguous,
    /// Gradient-related failure described by a free-form `message`.
    #[error("Gradient error: {message}")]
    GradientError {
        /// Human-readable description of the failure.
        message: String,
    },
    /// Serialization failure described by a free-form `message`.
    #[error("Serialization error: {message}")]
    SerializationError {
        /// Human-readable description of the failure.
        message: String,
    },
    /// Internal failure described by a free-form `message`.
    /// Can be built with `Error::internal`.
    #[error("Internal error: {message}")]
    InternalError {
        /// Human-readable description of the failure.
        message: String,
    },
}
pub type Result<T> = core::result::Result<T, Error>;
impl Error {
    /// Builds an [`Error::ShapeMismatch`] from borrowed shapes.
    ///
    /// Both slices are copied into owned `Vec`s, so the returned error does
    /// not borrow from the caller.
    #[must_use]
    pub fn shape_mismatch(expected: &[usize], actual: &[usize]) -> Self {
        Self::ShapeMismatch {
            expected: expected.to_vec(),
            actual: actual.to_vec(),
        }
    }

    /// Builds an [`Error::BroadcastError`] for two shapes that could not be
    /// broadcast together.
    ///
    /// Like [`Error::shape_mismatch`], both slices are copied into owned
    /// `Vec`s.
    #[must_use]
    pub fn broadcast(shape1: &[usize], shape2: &[usize]) -> Self {
        Self::BroadcastError {
            shape1: shape1.to_vec(),
            shape2: shape2.to_vec(),
        }
    }

    /// Builds an [`Error::InvalidOperation`] carrying `message`.
    #[must_use]
    pub fn invalid_operation(message: impl Into<String>) -> Self {
        Self::InvalidOperation {
            message: message.into(),
        }
    }

    /// Builds an [`Error::GradientError`] carrying `message`.
    #[must_use]
    pub fn gradient(message: impl Into<String>) -> Self {
        Self::GradientError {
            message: message.into(),
        }
    }

    /// Builds an [`Error::SerializationError`] carrying `message`.
    #[must_use]
    pub fn serialization(message: impl Into<String>) -> Self {
        Self::SerializationError {
            message: message.into(),
        }
    }

    /// Builds an [`Error::InternalError`] carrying `message`.
    #[must_use]
    pub fn internal(message: impl Into<String>) -> Self {
        Self::InternalError {
            message: message.into(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_display() {
        // Pin the whole message, not just its prefix, so a broken format
        // string (e.g. shapes not rendered) is caught.
        let err = Error::shape_mismatch(&[2, 3], &[2, 4]);
        assert_eq!(
            err.to_string(),
            "Shape mismatch: expected [2, 3], got [2, 4]"
        );
    }

    #[test]
    fn test_error_equality() {
        let err1 = Error::EmptyTensor;
        let err2 = Error::EmptyTensor;
        assert_eq!(err1, err2);
        // Distinct variants must not compare equal.
        assert_ne!(Error::EmptyTensor, Error::NotContiguous);
    }

    #[test]
    fn test_message_constructors() {
        // The helper must produce exactly the struct variant it names.
        assert_eq!(
            Error::invalid_operation("bad op"),
            Error::InvalidOperation {
                message: "bad op".to_string(),
            }
        );
        assert_eq!(
            Error::invalid_operation("bad op").to_string(),
            "Invalid operation: bad op"
        );
        // `impl Into<String>` accepts owned strings as well as `&str`.
        assert_eq!(
            Error::internal(String::from("oops")).to_string(),
            "Internal error: oops"
        );
    }

    #[test]
    fn test_unit_variant_messages() {
        assert_eq!(
            Error::EmptyTensor.to_string(),
            "Operation not supported on empty tensor"
        );
        assert_eq!(
            Error::NotContiguous.to_string(),
            "Operation requires contiguous tensor"
        );
    }

    #[test]
    fn test_clone_preserves_equality() {
        let err = Error::shape_mismatch(&[1], &[]);
        assert_eq!(err.clone(), err);
    }
}