use crate::dtype::DType;
use thiserror::Error;
pub type Result<T> = std::result::Result<T, Error>;
/// Crate-wide error type; the `Result<T>` alias in this module uses it as its
/// error type.
///
/// The `Display` text of every variant comes from its `#[error(...)]`
/// attribute (generated by `thiserror`). The enum is `#[non_exhaustive]`, so
/// `match` expressions in other crates must include a wildcard arm, and new
/// variants can be added without a breaking change.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum Error {
    /// A shape differed from the shape an operation required.
    #[error("Shape mismatch: expected {expected:?}, got {got:?}")]
    ShapeMismatch {
        /// The shape that was required.
        expected: Vec<usize>,
        /// The shape that was actually supplied.
        got: Vec<usize>,
    },
    /// Two shapes could not be broadcast together.
    #[error("Cannot broadcast shapes {lhs:?} and {rhs:?}")]
    BroadcastError {
        /// Shape of the left-hand operand.
        lhs: Vec<usize>,
        /// Shape of the right-hand operand.
        rhs: Vec<usize>,
    },
    /// A dimension index was not valid for a tensor with `ndim` dimensions.
    #[error("Invalid dimension {dim} for tensor with {ndim} dimensions")]
    InvalidDimension {
        /// The requested dimension, as given by the caller. It is signed,
        /// presumably so negative (from-the-end) indices can be reported
        /// unchanged — confirm with callers.
        dim: isize,
        /// Number of dimensions of the tensor.
        ndim: usize,
    },
    /// An operation does not support the given dtype.
    #[error("Unsupported dtype {dtype:?} for operation '{op}'")]
    UnsupportedDType {
        /// The dtype that was rejected.
        dtype: DType,
        /// Name of the operation that rejected it.
        op: &'static str,
    },
    /// Two operands had different dtypes.
    #[error("DType mismatch: {lhs:?} vs {rhs:?}")]
    DTypeMismatch {
        /// DType of the left-hand operand.
        lhs: DType,
        /// DType of the right-hand operand.
        rhs: DType,
    },
    /// The tensors passed to an operation were not all on the same device.
    #[error("Device mismatch: tensors must be on the same device")]
    DeviceMismatch,
    /// An allocation failed.
    #[error("Out of memory: failed to allocate {size} bytes")]
    OutOfMemory {
        /// Size of the failed allocation, in bytes.
        size: usize,
    },
    /// An index was out of bounds for a dimension of length `size`.
    #[error("Index {index} out of bounds for dimension of size {size}")]
    IndexOutOfBounds {
        /// The offending index.
        index: usize,
        /// Size of the dimension being indexed.
        size: usize,
    },
    /// A named argument had an invalid value.
    #[error("Invalid argument '{arg}': {reason}")]
    InvalidArgument {
        /// Name of the offending argument.
        arg: &'static str,
        /// Human-readable explanation of why the value was rejected.
        reason: String,
    },
    /// The operation requires a contiguous tensor.
    #[error("Operation requires contiguous tensor")]
    NotContiguous,
    /// A gradient was needed for a tensor but none was available.
    #[error("Missing gradient for tensor")]
    MissingGradient,
    /// Free-form backend failure, displayed as `Backend error: <msg>`.
    #[error("Backend error: {0}")]
    Backend(String),
    /// A backend cannot perform an operation because of a known limitation.
    #[error("{backend} limitation: {operation} - {reason}")]
    BackendLimitation {
        /// Name of the backend.
        backend: &'static str,
        /// Name of the operation that hit the limitation.
        operation: &'static str,
        /// Explanation of the limitation.
        reason: String,
    },
    /// Error from the CUDA driver. This variant exists only when the `cuda`
    /// feature is enabled.
    ///
    /// `#[from]` generates `From<cudarc::driver::DriverError>`, so `?`
    /// converts driver errors automatically. It also exposes the wrapped error
    /// as this error's `source()`.
    ///
    /// NOTE(review): the Display text also interpolates the wrapped error
    /// (`{0}`), so reporters that print the whole `source()` chain may show
    /// the driver message twice.
    #[cfg(feature = "cuda")]
    #[error("CUDA error: {0}")]
    Cuda(#[from] cudarc::driver::DriverError),
    /// Free-form message, displayed verbatim with no prefix.
    #[error("{0}")]
    Msg(String),
    /// Internal error, displayed as `Internal error: <msg>`.
    #[error("Internal error: {0}")]
    Internal(String),
    /// Requested functionality is not implemented.
    #[error("Not implemented: {feature}")]
    NotImplemented {
        /// Description of the missing functionality.
        feature: &'static str,
    },
    /// A dtype needs a Cargo feature that is not enabled. The message tells
    /// the user which `--features` flag to build with.
    #[error(
        "{dtype:?} requires the \"{feature}\" feature. Enable it with: cargo build --features {feature}"
    )]
    FeatureRequired {
        /// The dtype that needs the feature.
        dtype: DType,
        /// Name of the required Cargo feature.
        feature: &'static str,
    },
    /// The allocator could not proceed because allocations are still active.
    /// Which allocator operation reports this is not visible in this module.
    #[error("Allocator busy: {active_allocations} allocations still active")]
    AllocatorBusy {
        /// Number of allocations that were still active.
        active_allocations: usize,
    },
    /// An allocation was rejected because the allocator is frozen.
    #[error("Allocator frozen: allocation rejected while frozen")]
    AllocatorFrozen,
}
/// Convenience constructors for the [`Error`] variants that are built most often.
impl Error {
    /// Builds [`Error::ShapeMismatch`], copying both shapes into owned vectors.
    pub fn shape_mismatch(expected: &[usize], got: &[usize]) -> Self {
        let (expected, got) = (expected.to_owned(), got.to_owned());
        Self::ShapeMismatch { expected, got }
    }

    /// Builds [`Error::BroadcastError`] for two shapes that cannot be broadcast together.
    pub fn broadcast(lhs: &[usize], rhs: &[usize]) -> Self {
        Self::BroadcastError {
            lhs: Vec::from(lhs),
            rhs: Vec::from(rhs),
        }
    }

    /// Builds [`Error::UnsupportedDType`] from the rejected `dtype` and the operation name `op`.
    pub fn unsupported_dtype(dtype: DType, op: &'static str) -> Self {
        Self::UnsupportedDType { op, dtype }
    }

    /// Builds [`Error::BackendLimitation`]; `reason` accepts anything convertible into a `String`.
    pub fn backend_limitation(
        backend: &'static str,
        operation: &'static str,
        reason: impl Into<String>,
    ) -> Self {
        let reason: String = reason.into();
        Self::BackendLimitation {
            backend,
            operation,
            reason,
        }
    }
}