use thiserror::Error;
/// Convenience alias for results produced by this crate.
///
/// The error type defaults to [`CoreError`], so `Result<T>` is shorthand for
/// `std::result::Result<T, CoreError>`. A different error type can still be
/// given explicitly (`Result<T, E>`), which keeps the alias usable in modules
/// where it shadows the prelude `Result`.
pub type Result<T, E = CoreError> = std::result::Result<T, E>;
/// Error type shared by the core operations of this crate.
///
/// Each variant's `Display` text comes from its `#[error(...)]` attribute
/// (generated by `thiserror`). The enum is `#[non_exhaustive]`, so code in
/// other crates must include a wildcard arm when matching on it.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum CoreError {
/// A configuration value was rejected; the payload is the reason text.
#[error("invalid configuration: {0}")]
InvalidConfig(String),
/// Two shapes did not agree; both are rendered with `Debug` in the message.
#[error("shape mismatch: expected {expected:?}, got {actual:?}")]
ShapeMismatch {
/// The shape that was expected.
expected: Vec<usize>,
/// The shape that was actually encountered.
actual: Vec<usize>,
},
/// Dimension-related mismatch described by a free-form message.
#[error("dimension mismatch: {message}")]
DimensionMismatch {
/// Human-readable description of the mismatch.
message: String,
},
/// The named device could not be used.
#[error("device not available: {device}")]
DeviceNotAvailable {
/// Identifier of the requested device, as supplied by the caller.
device: String,
},
/// Operands were not all on the same device.
#[error("device mismatch: tensors must be on the same device")]
DeviceMismatch,
/// Out-of-memory condition; `message` carries the details.
#[error("out of memory: {message}")]
OutOfMemory {
/// Human-readable details about the failure.
message: String,
},
/// Failure reported from kernel code; `message` carries the details.
#[error("kernel error: {message}")]
KernelError {
/// Human-readable details about the failure.
message: String,
},
/// The requested feature has no implementation.
#[error("not implemented: {feature}")]
NotImplemented {
/// Name or description of the missing feature.
feature: String,
},
/// I/O failure, stored only as its rendered message.
///
/// The `From<std::io::Error>` conversion uses `to_string()`, so the original
/// error's `ErrorKind` and source chain are not kept.
#[error("I/O error: {0}")]
Io(String),
/// Error propagated from `candle_core`.
///
/// `#[from]` generates the `From` impl, so `?` converts it automatically, and
/// exposes the wrapped error as this error's `source()`.
#[error("candle error: {0}")]
Candle(#[from] candle_core::Error),
}
impl CoreError {
pub fn invalid_config(msg: impl Into<String>) -> Self {
Self::InvalidConfig(msg.into())
}
pub fn shape_mismatch(expected: impl Into<Vec<usize>>, actual: impl Into<Vec<usize>>) -> Self {
Self::ShapeMismatch {
expected: expected.into(),
actual: actual.into(),
}
}
pub fn dim_mismatch(msg: impl Into<String>) -> Self {
Self::DimensionMismatch {
message: msg.into(),
}
}
pub fn device_not_available(device: impl Into<String>) -> Self {
Self::DeviceNotAvailable {
device: device.into(),
}
}
pub fn oom(msg: impl Into<String>) -> Self {
Self::OutOfMemory {
message: msg.into(),
}
}
pub fn kernel(msg: impl Into<String>) -> Self {
Self::KernelError {
message: msg.into(),
}
}
pub fn not_implemented(feature: impl Into<String>) -> Self {
Self::NotImplemented {
feature: feature.into(),
}
}
pub fn io(msg: impl Into<String>) -> Self {
Self::Io(msg.into())
}
}
impl From<std::io::Error> for CoreError {
fn from(err: std::io::Error) -> Self {
Self::Io(err.to_string())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the exact `Display` output of the most commonly used constructors.
    #[test]
    fn test_error_display() {
        let err = CoreError::invalid_config("rank must be positive");
        assert_eq!(
            err.to_string(),
            "invalid configuration: rank must be positive"
        );

        // Exact match so a change in how shapes are rendered is caught.
        let err = CoreError::shape_mismatch(vec![2, 3], vec![3, 2]);
        assert_eq!(
            err.to_string(),
            "shape mismatch: expected [2, 3], got [3, 2]"
        );

        let err = CoreError::device_not_available("CUDA:5");
        assert_eq!(err.to_string(), "device not available: CUDA:5");
    }

    /// Covers the constructors and unit variant not exercised above.
    #[test]
    fn test_remaining_variants_display() {
        assert_eq!(
            CoreError::dim_mismatch("rank 2 vs 3").to_string(),
            "dimension mismatch: rank 2 vs 3"
        );
        assert_eq!(
            CoreError::oom("needed 4 GiB").to_string(),
            "out of memory: needed 4 GiB"
        );
        assert_eq!(
            CoreError::kernel("launch failed").to_string(),
            "kernel error: launch failed"
        );
        assert_eq!(
            CoreError::not_implemented("sparse matmul").to_string(),
            "not implemented: sparse matmul"
        );
        assert_eq!(
            CoreError::io("disk full").to_string(),
            "I/O error: disk full"
        );
        assert_eq!(
            CoreError::DeviceMismatch.to_string(),
            "device mismatch: tensors must be on the same device"
        );
    }

    /// `shape_mismatch` takes anything `Into<Vec<usize>>`, including arrays and slices.
    #[test]
    fn test_shape_mismatch_accepts_arrays_and_slices() {
        match CoreError::shape_mismatch([4usize, 5], &[5usize, 4][..]) {
            CoreError::ShapeMismatch { expected, actual } => {
                assert_eq!(expected, vec![4, 5]);
                assert_eq!(actual, vec![5, 4]);
            }
            other => panic!("unexpected variant: {:?}", other),
        }
    }

    /// Converting an `io::Error` yields `Io` and keeps its message text.
    #[test]
    fn test_error_conversion() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let core_err: CoreError = io_err.into();
        assert!(matches!(core_err, CoreError::Io(_)));
        assert_eq!(core_err.to_string(), "I/O error: file not found");
    }
}