/// Shorthand for a [`std::result::Result`] whose error type is [`EtensorError`].
pub type EtensorResult<T> = std::result::Result<T, EtensorError>;

/// Error type shared by etensor operations.
///
/// The three `*Mismatch` variants pair the value an operation asked for
/// (`expected`) with the value it actually received (`got`). The remaining
/// variants carry a free-form message that is displayed verbatim after the
/// variant's name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EtensorError {
    /// Shape disagreement; both shapes are rendered with `Debug` formatting.
    ShapeMismatch {
        /// Shape the operation asked for.
        expected: Vec<usize>,
        /// Shape that was actually supplied.
        got: Vec<usize>,
    },
    /// Device disagreement, with both sides given as device names.
    DeviceMismatch {
        /// Device name the operation asked for.
        expected: String,
        /// Device name of the tensor that was actually found.
        got: String,
    },
    /// Dtype disagreement, with both sides given as dtype names.
    DTypeMismatch {
        /// Dtype name the operation asked for.
        expected: String,
        /// Dtype name that was actually found.
        got: String,
    },
    /// Autograd-category error; displayed as `AutogradError: <message>`.
    AutogradError(String),
    /// Backend-category error; displayed as `BackendError: <message>`.
    BackendError(String),
    /// I/O-category error stored as text; displayed as `IoError: <message>`.
    IoError(String),
    /// Internal-category error; displayed as `InternalError: <message>`.
    InternalError(String),
}

impl std::fmt::Display for EtensorError {
    /// Writes the error as `<VariantName>: <details>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The structured variants each have a bespoke sentence and write it
        // directly. The message-only variants all share the same
        // "<label>: <message>" layout, so only their label is resolved here.
        let (label, message) = match self {
            Self::ShapeMismatch { expected, got } => {
                return write!(
                    f,
                    "ShapeMismatch: Operation expected shape {:?}, but got {:?}",
                    expected, got
                );
            }
            Self::DeviceMismatch { expected, got } => {
                return write!(
                    f,
                    "DeviceMismatch: Operation expected tensors on '{}', but found tensor on '{}'",
                    expected, got
                );
            }
            Self::DTypeMismatch { expected, got } => {
                return write!(
                    f,
                    "DTypeMismatch: Operation expected dtype '{}', but found '{}'",
                    expected, got
                );
            }
            Self::AutogradError(msg) => ("AutogradError", msg),
            Self::BackendError(msg) => ("BackendError", msg),
            Self::IoError(msg) => ("IoError", msg),
            Self::InternalError(msg) => ("InternalError", msg),
        };
        write!(f, "{}: {}", label, message)
    }
}

/// Marker impl with no `source()` override: every variant holds only owned
/// text or shape data, never a wrapped inner error.
impl std::error::Error for EtensorError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shapes are rendered with `Debug` formatting, i.e. as bracketed lists.
    #[test]
    fn test_shape_mismatch_formatting() {
        let err = EtensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![3, 2],
        };
        assert_eq!(
            err.to_string(),
            "ShapeMismatch: Operation expected shape [2, 3], but got [3, 2]"
        );
    }

    /// An empty shape vector still formats cleanly as `[]`.
    #[test]
    fn test_shape_mismatch_formatting_empty_shape() {
        let err = EtensorError::ShapeMismatch {
            expected: vec![],
            got: vec![1],
        };
        assert_eq!(
            err.to_string(),
            "ShapeMismatch: Operation expected shape [], but got [1]"
        );
    }

    /// Device names are quoted verbatim in the message.
    #[test]
    fn test_device_mismatch_formatting() {
        let err = EtensorError::DeviceMismatch {
            expected: "cuda:0".to_string(),
            got: "cpu".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "DeviceMismatch: Operation expected tensors on 'cuda:0', but found tensor on 'cpu'"
        );
    }

    /// Dtype names are quoted verbatim in the message.
    #[test]
    fn test_dtype_mismatch_formatting() {
        let err = EtensorError::DTypeMismatch {
            expected: "f32".to_string(),
            got: "i64".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "DTypeMismatch: Operation expected dtype 'f32', but found 'i64'"
        );
    }

    /// Every message-carrying variant prints `<VariantName>: <message>`.
    #[test]
    fn test_message_variant_formatting() {
        let cases = [
            (
                EtensorError::AutogradError("missing grad".to_string()),
                "AutogradError: missing grad",
            ),
            (
                EtensorError::BackendError("kernel failed".to_string()),
                "BackendError: kernel failed",
            ),
            (
                EtensorError::IoError("unexpected eof".to_string()),
                "IoError: unexpected eof",
            ),
            (
                EtensorError::InternalError("bad state".to_string()),
                "InternalError: bad state",
            ),
        ];
        for (err, expected) in cases.iter() {
            assert_eq!(err.to_string(), *expected);
        }
    }

    /// An error produced inside a call chain surfaces unchanged through `?`.
    #[test]
    fn test_result_propagation() {
        fn device_name(is_cpu: bool) -> &'static str {
            if is_cpu {
                "cpu"
            } else {
                "cuda"
            }
        }

        // Reports the first operand's device as `expected` and the second's as
        // `got`, so the error reflects the actual arguments rather than a fixed
        // pair (the previous mock reported "cpu"/"cuda" even when reversed).
        fn mock_add(a_is_cpu: bool, b_is_cpu: bool) -> EtensorResult<f32> {
            if a_is_cpu != b_is_cpu {
                return Err(EtensorError::DeviceMismatch {
                    expected: device_name(a_is_cpu).to_string(),
                    got: device_name(b_is_cpu).to_string(),
                });
            }
            Ok(42.0)
        }

        // Calls `mock_add` through `?` so propagation is actually exercised.
        fn mock_add_then_double(a_is_cpu: bool, b_is_cpu: bool) -> EtensorResult<f32> {
            let sum = mock_add(a_is_cpu, b_is_cpu)?;
            Ok(sum * 2.0)
        }

        assert!(mock_add(true, true).is_ok());
        assert_eq!(mock_add_then_double(true, true), Ok(84.0));
        assert_eq!(mock_add_then_double(false, false), Ok(84.0));

        let err = mock_add_then_double(true, false).unwrap_err();
        assert!(matches!(err, EtensorError::DeviceMismatch { .. }));
        assert_eq!(
            err,
            EtensorError::DeviceMismatch {
                expected: "cpu".to_string(),
                got: "cuda".to_string(),
            }
        );

        // The reversed argument order must report the devices the other way round.
        let err = mock_add_then_double(false, true).unwrap_err();
        assert_eq!(
            err.to_string(),
            "DeviceMismatch: Operation expected tensors on 'cuda', but found tensor on 'cpu'"
        );
    }

    /// Because `EtensorError` implements `std::error::Error`, `?` can box it,
    /// and the concrete error can be recovered by downcasting.
    #[test]
    fn test_boxed_error_conversion() {
        fn backend_op() -> EtensorResult<()> {
            Err(EtensorError::BackendError("kernel launch failed".to_string()))
        }

        fn caller() -> Result<(), Box<dyn std::error::Error>> {
            backend_op()?;
            Ok(())
        }

        let err = caller().unwrap_err();
        assert_eq!(err.to_string(), "BackendError: kernel launch failed");
        assert_eq!(
            err.downcast_ref::<EtensorError>(),
            Some(&EtensorError::BackendError(
                "kernel launch failed".to_string()
            ))
        );
    }
}