use thiserror::Error;
/// Shorthand for a `Result` whose error side is [`CnnError`].
pub type CnnResult<T> = std::result::Result<T, CnnError>;
/// Error enum used as the error type of [`CnnResult`].
///
/// The user-facing text of each variant comes from its `#[error(...)]` attribute
/// (implemented by `thiserror`'s `Error` derive). Every payload is a `String` or `usize`,
/// which is what lets the enum derive `Clone`. I/O errors are stored as their rendered
/// message (see the `From<std::io::Error>` impl), presumably because `std::io::Error`
/// itself is not `Clone`.
///
/// NOTE(review): `DimensionMismatch`, `ShapeMismatch` and `InvalidShape` overlap in
/// meaning; callers may need to check more than one when matching shape problems.
#[derive(Error, Debug, Clone)]
pub enum CnnError {
    /// Input data was rejected; the payload is a free-form description.
    #[error("Invalid input: {0}")]
    InvalidInput(String),
    /// A configuration value was rejected; built by `CnnError::invalid_config`.
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),
    /// Model-level failure with a free-form message.
    #[error("Model error: {0}")]
    ModelError(String),
    /// Dimension mismatch; `CnnError::dim_mismatch` formats the payload as
    /// `expected {expected}, got {actual}`.
    #[error("Dimension mismatch: {0}")]
    DimensionMismatch(String),
    /// Error displayed with the "SIMD error" prefix; free-form message.
    #[error("SIMD error: {0}")]
    SimdError(String),
    /// Error displayed with the "Quantization error" prefix; free-form message.
    #[error("Quantization error: {0}")]
    QuantizationError(String),
    /// Shape did not match; both sides are stored as free-form strings
    /// (for example layout names). Built by `CnnError::invalid_shape`.
    #[error("Invalid shape: expected {expected}, got {got}")]
    InvalidShape {
        /// Description of the shape that was required.
        expected: String,
        /// Description of the shape that was actually received.
        got: String,
    },
    /// Shape mismatch with a free-form message; built by `CnnError::shape_mismatch`.
    #[error("Shape mismatch: {0}")]
    ShapeMismatch(String),
    /// A parameter was rejected; built by `CnnError::invalid_parameter`.
    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),
    /// Error displayed as a memory-allocation failure; free-form message.
    #[error("Memory allocation failed: {0}")]
    AllocationError(String),
    /// Channel count did not match.
    #[error("Invalid channel count: expected {expected}, got {actual}")]
    InvalidChannels {
        /// Number of channels that was required.
        expected: usize,
        /// Number of channels that was received.
        actual: usize,
    },
    /// Convolution parameters were rejected; free-form message.
    #[error("Invalid convolution parameters: {0}")]
    InvalidConvParams(String),
    /// Failure while loading weights; free-form message.
    #[error("Weight loading error: {0}")]
    WeightLoadError(String),
    /// An input was empty; the payload says which one.
    #[error("Empty input: {0}")]
    EmptyInput(String),
    /// Error displayed as numerical instability; free-form message.
    #[error("Numerical instability: {0}")]
    NumericalInstability(String),
    /// The requested backbone is not supported; the payload is the backbone name/description.
    #[error("Unsupported backbone: {0}")]
    UnsupportedBackbone(String),
    /// Error displayed as a batch-processing failure; free-form message.
    #[error("Batch processing error: {0}")]
    BatchError(String),
    /// Convolution failure; built by `CnnError::convolution_error`.
    #[error("Convolution error: {0}")]
    ConvolutionError(String),
    /// Pooling failure; built by `CnnError::pooling_error`.
    #[error("Pooling error: {0}")]
    PoolingError(String),
    /// Error displayed with the "Normalization error" prefix; free-form message.
    #[error("Normalization error: {0}")]
    NormalizationError(String),
    /// Kernel size is incompatible with the input's spatial dimensions.
    #[error("Invalid kernel: kernel_size={kernel_size}, but input spatial dims are ({height}, {width})")]
    InvalidKernel {
        /// The kernel size that was requested.
        kernel_size: usize,
        /// Input height.
        height: usize,
        /// Input width.
        width: usize,
    },
    /// I/O failure stored as the rendered `std::io::Error` message
    /// (produced by the `From<std::io::Error>` impl).
    #[error("IO error: {0}")]
    IoError(String),
    /// Error displayed with the "Image error" prefix; free-form message.
    #[error("Image error: {0}")]
    ImageError(String),
    /// An index was not below the collection size.
    #[error("Index out of bounds: {index} >= {size}")]
    IndexOutOfBounds {
        /// The offending index.
        index: usize,
        /// The size it was checked against.
        size: usize,
    },
    /// The requested operation is not supported; free-form message.
    #[error("Unsupported operation: {0}")]
    Unsupported(String),
}
impl From<std::io::Error> for CnnError {
fn from(err: std::io::Error) -> Self {
CnnError::IoError(err.to_string())
}
}
/// Convenience constructors for common [`CnnError`] variants.
///
/// Each constructor only builds the error value; discarding it is always a bug,
/// hence `#[must_use]` on every one of them.
impl CnnError {
    /// Builds [`CnnError::DimensionMismatch`] with the message
    /// `"expected {expected}, got {actual}"`.
    #[must_use]
    pub fn dim_mismatch(expected: usize, actual: usize) -> Self {
        Self::DimensionMismatch(format!("expected {expected}, got {actual}"))
    }

    /// Builds [`CnnError::InvalidShape`] from descriptions of the required and
    /// the received shape.
    #[must_use]
    pub fn invalid_shape(expected: impl Into<String>, got: impl Into<String>) -> Self {
        Self::InvalidShape {
            expected: expected.into(),
            got: got.into(),
        }
    }

    /// Builds [`CnnError::ShapeMismatch`] with a free-form message.
    #[must_use]
    pub fn shape_mismatch(msg: impl Into<String>) -> Self {
        Self::ShapeMismatch(msg.into())
    }

    /// Builds [`CnnError::InvalidParameter`] with a free-form message.
    #[must_use]
    pub fn invalid_parameter(msg: impl Into<String>) -> Self {
        Self::InvalidParameter(msg.into())
    }

    /// Builds [`CnnError::InvalidConfig`] with a free-form message.
    #[must_use]
    pub fn invalid_config(msg: impl Into<String>) -> Self {
        Self::InvalidConfig(msg.into())
    }

    /// Builds [`CnnError::ConvolutionError`] with a free-form message.
    #[must_use]
    pub fn convolution_error(msg: impl Into<String>) -> Self {
        Self::ConvolutionError(msg.into())
    }

    /// Builds [`CnnError::PoolingError`] with a free-form message.
    #[must_use]
    pub fn pooling_error(msg: impl Into<String>) -> Self {
        Self::PoolingError(msg.into())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_display() {
        let err = CnnError::DimensionMismatch("expected 64, got 32".to_string());
        assert!(err.to_string().contains("expected 64"));
        let err = CnnError::InvalidConfig("kernel_size must be positive".to_string());
        assert_eq!(
            err.to_string(),
            "Invalid configuration: kernel_size must be positive"
        );
    }

    #[test]
    fn test_error_clone() {
        let err = CnnError::ConvolutionError("test".to_string());
        let cloned = err.clone();
        assert_eq!(err.to_string(), cloned.to_string());
    }

    #[test]
    fn test_invalid_kernel_error() {
        let err = CnnError::InvalidKernel {
            kernel_size: 7,
            height: 3,
            width: 3,
        };
        assert!(err.to_string().contains("kernel_size=7"));
        assert!(err.to_string().contains("(3, 3)"));
        assert_eq!(
            err.to_string(),
            "Invalid kernel: kernel_size=7, but input spatial dims are (3, 3)"
        );
    }

    #[test]
    fn test_invalid_channels_error() {
        let err = CnnError::InvalidChannels {
            expected: 3,
            actual: 1,
        };
        assert!(err.to_string().contains("expected 3"));
        assert!(err.to_string().contains("got 1"));
    }

    #[test]
    fn test_helper_methods() {
        let err = CnnError::invalid_shape("NCHW", "NHWC");
        assert!(err.to_string().contains("NCHW"));
        assert!(err.to_string().contains("NHWC"));
        let err = CnnError::invalid_config("dropout must be in [0, 1]");
        assert!(err.to_string().contains("dropout"));
        let err = CnnError::dim_mismatch(64, 32);
        assert!(err.to_string().contains("64"));
        assert!(err.to_string().contains("32"));
    }

    /// Pins the exact rendered text of the formatted helpers, not just substrings.
    #[test]
    fn test_helper_exact_messages() {
        assert_eq!(
            CnnError::dim_mismatch(64, 32).to_string(),
            "Dimension mismatch: expected 64, got 32"
        );
        assert_eq!(
            CnnError::invalid_shape("NCHW", "NHWC").to_string(),
            "Invalid shape: expected NCHW, got NHWC"
        );
    }

    /// Covers the helper constructors the original tests never exercised.
    #[test]
    fn test_remaining_helper_methods() {
        let err = CnnError::shape_mismatch("lhs [2, 3] vs rhs [3, 2]");
        assert!(matches!(err, CnnError::ShapeMismatch(_)));
        assert_eq!(err.to_string(), "Shape mismatch: lhs [2, 3] vs rhs [3, 2]");

        let err = CnnError::invalid_parameter("stride must be > 0");
        assert!(matches!(err, CnnError::InvalidParameter(_)));
        assert_eq!(err.to_string(), "Invalid parameter: stride must be > 0");

        let err = CnnError::convolution_error("padding overflow");
        assert!(matches!(err, CnnError::ConvolutionError(_)));
        assert_eq!(err.to_string(), "Convolution error: padding overflow");

        let err = CnnError::pooling_error("window larger than input");
        assert!(matches!(err, CnnError::PoolingError(_)));
        assert_eq!(err.to_string(), "Pooling error: window larger than input");
    }

    #[test]
    fn test_index_out_of_bounds_error() {
        let err = CnnError::IndexOutOfBounds { index: 5, size: 4 };
        assert_eq!(err.to_string(), "Index out of bounds: 5 >= 4");
    }

    /// `From<std::io::Error>` keeps only the I/O error's display text.
    #[test]
    fn test_io_error_conversion() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "weights.bin missing");
        let err: CnnError = io_err.into();
        assert!(matches!(err, CnnError::IoError(ref msg) if msg == "weights.bin missing"));
        assert_eq!(err.to_string(), "IO error: weights.bin missing");
    }

    /// `?` inside a `CnnResult` function converts I/O errors automatically.
    #[test]
    fn test_question_mark_converts_io_error() {
        fn open_weights() -> CnnResult<()> {
            let failed: std::io::Result<()> = Err(std::io::Error::new(
                std::io::ErrorKind::PermissionDenied,
                "access denied",
            ));
            failed?;
            Ok(())
        }
        assert_eq!(
            open_weights().unwrap_err().to_string(),
            "IO error: access denied"
        );
    }
}