use thiserror::Error;
/// Convenience result alias for this crate.
///
/// The error type defaults to [`NnlError`], so `Result<T>` means
/// `std::result::Result<T, NnlError>`. A different error type can still be
/// named explicitly as `Result<T, E>` without importing `std::result::Result`
/// under another name.
pub type Result<T, E = NnlError> = std::result::Result<T, E>;
/// Error type carried by this crate's [`Result`] alias.
///
/// Each variant's `#[error(...)]` attribute defines its `Display` text.
/// Variants marked `#[from]` also get a generated `From` impl, so `?` converts
/// those foreign errors automatically.
#[derive(Error, Debug)]
pub enum NnlError {
    /// Tensor-related failure, displayed as `Tensor error: <message>`.
    #[error("Tensor error: {message}")]
    TensorError { message: String },

    /// Two shapes disagreed; both are rendered with their `Debug` form.
    #[error("Shape mismatch: expected {expected:?}, got {actual:?}")]
    ShapeMismatch {
        /// The shape that was required.
        expected: Vec<usize>,
        /// The shape that was actually supplied.
        actual: Vec<usize>,
    },

    /// Device-related failure, displayed as `Device error: <message>`.
    #[error("Device error: {message}")]
    DeviceError { message: String },

    /// Network-related failure, displayed as `Network error: <message>`.
    #[error("Network error: {message}")]
    NetworkError { message: String },

    /// Training-related failure, displayed as `Training error: <message>`.
    #[error("Training error: {message}")]
    TrainingError { message: String },

    /// Wrapped [`std::io::Error`]; converted automatically via `From`.
    #[error("I/O error: {0}")]
    IoError(#[from] std::io::Error),

    /// Wrapped `bincode` error; converted automatically via `From`.
    #[error("Serialization error: {0}")]
    SerializationError(#[from] bincode::Error),

    /// Wrapped `serde_json` error; converted automatically via `From`.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// GPU compute failure carrying a plain message (tuple variant).
    #[error("GPU compute error: {0}")]
    GpuError(String),

    /// Configuration problem, displayed as `Invalid configuration: <message>`.
    #[error("Invalid configuration: {message}")]
    ConfigError { message: String },

    /// Numerical failure, displayed as `Math error: <message>`.
    #[error("Math error: {message}")]
    MathError { message: String },

    /// Memory-related failure, displayed as `Memory error: <message>`.
    #[error("Memory error: {message}")]
    MemoryError { message: String },

    /// Operation not supported, displayed as `Unsupported operation: <message>`.
    #[error("Unsupported operation: {message}")]
    UnsupportedError { message: String },

    /// Caller-supplied input was rejected, displayed as `Invalid input: <message>`.
    #[error("Invalid input: {message}")]
    InvalidInputError { message: String },
}
/// Shorthand constructors, one per variant, accepting any string-like message.
impl NnlError {
    /// Builds a [`NnlError::TensorError`] from `message`.
    pub fn tensor<M>(message: M) -> Self
    where
        M: Into<String>,
    {
        Self::TensorError { message: message.into() }
    }

    /// Builds a [`NnlError::ShapeMismatch`], copying both shapes into owned vectors.
    pub fn shape_mismatch(expected: &[usize], actual: &[usize]) -> Self {
        Self::ShapeMismatch {
            expected: Vec::from(expected),
            actual: Vec::from(actual),
        }
    }

    /// Builds a [`NnlError::DeviceError`] from `message`.
    pub fn device<M>(message: M) -> Self
    where
        M: Into<String>,
    {
        Self::DeviceError { message: message.into() }
    }

    /// Builds a [`NnlError::NetworkError`] from `message`.
    pub fn network<M>(message: M) -> Self
    where
        M: Into<String>,
    {
        Self::NetworkError { message: message.into() }
    }

    /// Builds a [`NnlError::TrainingError`] from `message`.
    pub fn training<M>(message: M) -> Self
    where
        M: Into<String>,
    {
        Self::TrainingError { message: message.into() }
    }

    /// Builds a [`NnlError::ConfigError`] from `message`.
    pub fn config<M>(message: M) -> Self
    where
        M: Into<String>,
    {
        Self::ConfigError { message: message.into() }
    }

    /// Builds a [`NnlError::MathError`] from `message`.
    pub fn math<M>(message: M) -> Self
    where
        M: Into<String>,
    {
        Self::MathError { message: message.into() }
    }

    /// Builds a [`NnlError::MemoryError`] from `message`.
    pub fn memory<M>(message: M) -> Self
    where
        M: Into<String>,
    {
        Self::MemoryError { message: message.into() }
    }

    /// Builds a [`NnlError::UnsupportedError`] from `message`.
    pub fn unsupported<M>(message: M) -> Self
    where
        M: Into<String>,
    {
        Self::UnsupportedError { message: message.into() }
    }

    /// Builds a [`NnlError::InvalidInputError`] from `message`.
    pub fn invalid_input<M>(message: M) -> Self
    where
        M: Into<String>,
    {
        Self::InvalidInputError { message: message.into() }
    }

    /// Builds a [`NnlError::GpuError`] from `message`.
    pub fn gpu<M>(message: M) -> Self
    where
        M: Into<String>,
    {
        Self::GpuError(message.into())
    }

    /// Wraps an existing [`std::io::Error`] as [`NnlError::IoError`].
    pub fn io(error: std::io::Error) -> Self {
        Self::IoError(error)
    }
}
/// Extension trait for turning a foreign `Result` into this crate's [`Result`].
pub trait IntoNnlError<T> {
    /// Converts the error side of `self` into an [`NnlError`], passing `Ok` through.
    fn into_nnl_error(self) -> Result<T>;

    /// Like [`IntoNnlError::into_nnl_error`], but includes the text produced by
    /// `f` as context in the resulting error message.
    fn with_context<F: FnOnce() -> String>(self, f: F) -> Result<T>;
}
/// Blanket conversion for any `Result` whose error implements [`std::error::Error`].
///
/// Both methods stringify the original error and wrap it in
/// [`NnlError::DeviceError`]. Only the error's `Display` text survives; the
/// original error value is dropped.
impl<T, E> IntoNnlError<T> for std::result::Result<T, E>
where
    E: std::error::Error + Send + Sync + 'static,
{
    fn into_nnl_error(self) -> Result<T> {
        self.map_err(|source| NnlError::device(source.to_string()))
    }

    fn with_context<F>(self, f: F) -> Result<T>
    where
        F: FnOnce() -> String,
    {
        // `map_err` only runs the closure on `Err`, so `f` is never called on success.
        self.map_err(|source| {
            let context = f();
            NnlError::device(format!("{}: {}", context, source))
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_creation() {
        let err = NnlError::tensor("test message");
        assert!(matches!(err, NnlError::TensorError { .. }));
        assert_eq!(err.to_string(), "Tensor error: test message");
    }

    #[test]
    fn test_shape_mismatch() {
        let err = NnlError::shape_mismatch(&[2, 3], &[4, 5]);
        assert!(matches!(err, NnlError::ShapeMismatch { .. }));
        assert!(err.to_string().contains("expected [2, 3], got [4, 5]"));
    }

    #[test]
    fn test_error_chaining() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let nnl_err: NnlError = io_err.into();
        assert!(matches!(nnl_err, NnlError::IoError(_)));
    }

    #[test]
    fn test_gpu_and_io_constructors() {
        let gpu = NnlError::gpu("kernel launch failed");
        assert!(matches!(gpu, NnlError::GpuError(_)));
        assert_eq!(gpu.to_string(), "GPU compute error: kernel launch failed");

        let io_err = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "access denied");
        let err = NnlError::io(io_err);
        assert!(matches!(err, NnlError::IoError(_)));
        assert_eq!(err.to_string(), "I/O error: access denied");
    }

    // The blanket `IntoNnlError` impl maps every foreign error to `DeviceError`.
    #[test]
    fn test_into_nnl_error_maps_to_device() {
        let failing: std::result::Result<(), std::io::Error> = Err(std::io::Error::new(
            std::io::ErrorKind::NotFound,
            "file not found",
        ));
        let err = failing.into_nnl_error().unwrap_err();
        assert!(matches!(err, NnlError::DeviceError { .. }));
        assert_eq!(err.to_string(), "Device error: file not found");
    }

    #[test]
    fn test_with_context_prefixes_message() {
        let failing: std::result::Result<(), std::io::Error> = Err(std::io::Error::new(
            std::io::ErrorKind::NotFound,
            "file not found",
        ));
        let err = failing
            .with_context(|| "loading weights".to_string())
            .unwrap_err();
        assert!(matches!(err, NnlError::DeviceError { .. }));
        assert_eq!(err.to_string(), "Device error: loading weights: file not found");
    }

    // The context closure must not run (and must not allocate) on the success path.
    #[test]
    fn test_with_context_is_lazy_on_ok() {
        let ok: std::result::Result<i32, std::io::Error> = Ok(7);
        let value = ok
            .with_context(|| panic!("context must not be built on success"))
            .unwrap();
        assert_eq!(value, 7);
    }
}