use std::fmt;
/// Error type returned by this crate's fallible operations (see the `Result`
/// alias at the bottom of this file).
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it. When the `serde-support` feature is
/// enabled it also derives `serde::Serialize` and `serde::Deserialize`.
// NOTE(review): `InvalidParameter` stores `&'static str` fields; with the serde
// derive this presumably restricts `Deserialize` to `'static` input data —
// confirm that is intended (or whether `Cow<'static, str>` would be preferable).
#[cfg_attr(
feature = "serde-support",
derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum NnError {
/// An input's shape differed from the shape the operation required.
ShapeMismatch {
/// The shape the operation required.
expected: Vec<usize>,
/// The shape that was actually supplied.
got: Vec<usize>,
},
/// A gradient was requested but is not available.
NoGradient,
/// A parameter value was rejected.
InvalidParameter {
/// Name of the offending parameter.
name: &'static str,
/// Human-readable explanation of why the value was rejected.
reason: &'static str,
},
/// The operation received empty input data.
EmptyInput,
/// An index was outside the valid range `0..len`.
IndexOutOfBounds { index: usize, len: usize },
/// An ONNX-related failure, described by a message string
/// (presumably from model import/export — confirm).
OnnxError(String),
/// A serialization-related failure, described by a message string.
SerializeError(String),
/// Wraps an error from `scivex_core`; the `From<scivex_core::CoreError>`
/// impl in this file lets `?` perform this conversion.
CoreError(scivex_core::CoreError),
/// A GPU failure, stored as the GPU error's `Display` text.
/// Only present when the `gpu` feature is enabled.
#[cfg(feature = "gpu")]
GpuError(String),
}
impl fmt::Display for NnError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::ShapeMismatch { expected, got } => {
write!(f, "shape mismatch: expected {expected:?}, got {got:?}")
}
Self::NoGradient => write!(f, "gradient is not available"),
Self::InvalidParameter { name, reason } => {
write!(f, "invalid parameter `{name}`: {reason}")
}
Self::OnnxError(msg) => write!(f, "onnx: {msg}"),
Self::SerializeError(msg) => write!(f, "serialize: {msg}"),
Self::EmptyInput => write!(f, "input data is empty"),
Self::IndexOutOfBounds { index, len } => {
write!(f, "index {index} out of bounds for length {len}")
}
Self::CoreError(e) => write!(f, "core: {e}"),
#[cfg(feature = "gpu")]
Self::GpuError(e) => write!(f, "gpu: {e}"),
}
}
}
impl std::error::Error for NnError {}
impl From<scivex_core::CoreError> for NnError {
fn from(e: scivex_core::CoreError) -> Self {
Self::CoreError(e)
}
}
/// Converts a `scivex_gpu::GpuError` into [`NnError::GpuError`]
/// (only compiled with the `gpu` feature).
///
/// Only the GPU error's `Display` text is kept; the original value is dropped.
#[cfg(feature = "gpu")]
impl From<scivex_gpu::GpuError> for NnError {
    fn from(err: scivex_gpu::GpuError) -> Self {
        let message = err.to_string();
        NnError::GpuError(message)
    }
}
/// Convenience result alias whose error type defaults to [`NnError`].
///
/// `Result<T>` means `std::result::Result<T, NnError>`, exactly as before. The
/// defaulted `E` parameter additionally lets modules that glob-import this
/// alias (shadowing the prelude `Result`) still write `Result<T, OtherError>`
/// without spelling out `std::result::Result`.
pub type Result<T, E = NnError> = std::result::Result<T, E>;