use core::fmt;
use crate::math::scalar::ScalarCastError;
/// Errors produced by tensor shape, rank, and scalar-conversion checks in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TensorError {
    /// A shape was rejected by `validate_shape`: it has no axes or contains a zero-length axis.
    InvalidShape {
        /// The rejected shape.
        shape: Vec<usize>,
    },
    /// Multiplying the axis lengths of a shape together overflowed `usize`.
    ShapeProductOverflow {
        /// The shape whose element count could not be represented.
        shape: Vec<usize>,
    },
    /// Two shapes that were required to be identical differ.
    ShapeMismatch {
        /// Left-hand shape of the comparison.
        lhs: Vec<usize>,
        /// Right-hand shape of the comparison.
        rhs: Vec<usize>,
    },
    /// The number of index components does not equal the rank of the shape being indexed.
    RankMismatch {
        /// The shape that was being indexed.
        shape: Vec<usize>,
        /// The number of components in the offending index.
        index_rank: usize,
    },
    /// A named operation required an exact rank that the tensor does not have.
    ExpectedRank {
        /// Name of the operation that imposed the requirement.
        operation: &'static str,
        /// Rank the operation requires.
        expected: usize,
        /// Rank that was actually supplied.
        actual: usize,
    },
    /// Wraps a failed scalar conversion.
    ScalarCast(ScalarCastError),
}

/// Shorthand for `Result`s whose error type is [`TensorError`].
pub type TensorResult<T> = Result<T, TensorError>;
impl From<ScalarCastError> for TensorError {
#[inline]
fn from(value: ScalarCastError) -> Self {
Self::ScalarCast(value)
}
}
impl fmt::Display for TensorError {
    /// Writes a single-line, human-readable description of the error.
    ///
    /// Shapes are rendered with their `Debug` form (for example `[2, 3]`). A wrapped
    /// `ScalarCast` error is rendered through its own `Display` implementation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Mirrors `validate_shape`, which rejects both an empty shape and any shape that
            // contains a zero-length axis, so both requirements are stated.
            Self::InvalidShape { shape } => write!(
                f,
                "tensor shape must be non-empty with every axis nonzero; got {shape:?}"
            ),
            Self::ShapeProductOverflow { shape } => {
                write!(f, "tensor shape product overflowed usize; got {shape:?}")
            }
            Self::ShapeMismatch { lhs, rhs } => {
                write!(f, "tensor shape mismatch: lhs={lhs:?}, rhs={rhs:?}")
            }
            // Only the rank (length) of the shape is reported, not its axis lengths.
            Self::RankMismatch { shape, index_rank } => write!(
                f,
                "tensor index rank mismatch: shape rank={}, index rank={index_rank}",
                shape.len()
            ),
            Self::ExpectedRank {
                operation,
                expected,
                actual,
            } => write!(
                f,
                "{operation} requires rank {expected}, but tensor rank is {actual}"
            ),
            Self::ScalarCast(error) => write!(f, "tensor scalar cast failed: {error}"),
        }
    }
}

/// Marks `TensorError` as a standard error type. It uses the default `source()`, which
/// returns `None`.
// NOTE(review): if `ScalarCastError` implements `std::error::Error + 'static`, overriding
// `source()` to return it for the `ScalarCast` variant would expose the cause chain. Confirm
// the bound holds before adding it.
impl std::error::Error for TensorError {}
/// Checks that `shape` has at least one axis and that no axis has length zero.
///
/// # Errors
///
/// Returns [`TensorError::InvalidShape`] carrying a copy of `shape` when `shape` is empty or
/// contains a `0` entry.
#[inline]
pub fn validate_shape(shape: &[usize]) -> TensorResult<()> {
    if !shape.is_empty() && !shape.contains(&0) {
        Ok(())
    } else {
        Err(TensorError::InvalidShape {
            shape: shape.to_vec(),
        })
    }
}
/// Computes the total element count of `shape` (the product of its axis lengths).
///
/// # Errors
///
/// - Propagates [`TensorError::InvalidShape`] from `validate_shape` when `shape` is empty or
///   contains a zero-length axis.
/// - Returns [`TensorError::ShapeProductOverflow`] carrying a copy of `shape` when the product
///   does not fit in `usize`.
#[inline]
pub fn checked_num_elements(shape: &[usize]) -> TensorResult<usize> {
    validate_shape(shape)?;
    let mut count: usize = 1;
    for &axis_len in shape {
        count = match count.checked_mul(axis_len) {
            Some(next) => next,
            None => {
                return Err(TensorError::ShapeProductOverflow {
                    shape: shape.to_vec(),
                });
            }
        };
    }
    Ok(count)
}
/// Succeeds only when `lhs` and `rhs` have the same rank and identical axis lengths.
///
/// # Errors
///
/// Returns [`TensorError::ShapeMismatch`] carrying copies of both shapes when they differ.
#[inline]
pub fn ensure_same_shape(lhs: &[usize], rhs: &[usize]) -> TensorResult<()> {
    if lhs == rhs {
        Ok(())
    } else {
        Err(TensorError::ShapeMismatch {
            lhs: lhs.to_vec(),
            rhs: rhs.to_vec(),
        })
    }
}
/// Checks that an index with `index_rank` components can address a tensor of `shape`, meaning
/// the index rank equals `shape.len()`.
///
/// # Errors
///
/// Returns [`TensorError::RankMismatch`] carrying a copy of `shape` and the offending
/// `index_rank` when the two ranks differ.
#[inline]
pub fn ensure_index_rank(shape: &[usize], index_rank: usize) -> TensorResult<()> {
    let shape_rank = shape.len();
    if shape_rank == index_rank {
        return Ok(());
    }
    Err(TensorError::RankMismatch {
        shape: shape.to_vec(),
        index_rank,
    })
}
/// Checks that the `operation` named by the caller receives a tensor of exactly `expected` rank.
///
/// Note the argument order: `actual` comes before `expected`.
///
/// # Errors
///
/// Returns [`TensorError::ExpectedRank`] carrying `operation`, `expected`, and `actual` when
/// `actual != expected`.
#[inline]
pub fn ensure_rank(operation: &'static str, actual: usize, expected: usize) -> TensorResult<()> {
    if expected == actual {
        Ok(())
    } else {
        Err(TensorError::ExpectedRank {
            operation,
            expected,
            actual,
        })
    }
}