use thiserror::Error;
/// Shorthand for `Result<T, ModelError>`.
pub type ModelResult<T> = Result<T, ModelError>;
/// Error type carried by [`ModelResult`].
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute (generated by `thiserror`). [`ModelError::error_code`] maps every
/// variant to a fixed string code.
#[derive(Error, Debug)]
pub enum ModelError {
/// A tensor could not be found; `name` identifies the missing tensor.
#[error("missing tensor: {name}")]
MissingTensor { name: String },
/// The tensor `name` had shape `actual` where shape `expected` was required.
/// Both shapes are rendered with `Debug` formatting in the message.
#[error("shape mismatch for '{name}': expected {expected:?}, got {actual:?}")]
ShapeMismatch {
/// Name of the tensor whose shape did not match.
name: String,
/// The shape that was required.
expected: Vec<usize>,
/// The shape that was actually found.
actual: Vec<usize>,
},
/// A sequence length `seq_len` exceeded the maximum context `max_ctx`.
#[error("sequence length {seq_len} exceeds max context {max_ctx}")]
SequenceTooLong { seq_len: usize, max_ctx: usize },
/// Wraps an error from `oxibonsai_core`.
///
/// `#[from]` derives `From<BonsaiError>`, so `?` converts it automatically.
/// thiserror also reports the wrapped error as this error's `source()`.
#[error("core: {0}")]
Core(#[from] oxibonsai_core::error::BonsaiError),
/// Wraps an error from `oxibonsai_kernels`.
///
/// `#[from]` derives `From<KernelError>`, so `?` converts it automatically.
/// thiserror also reports the wrapped error as this error's `source()`.
#[error("kernel: {0}")]
Kernel(#[from] oxibonsai_kernels::error::KernelError),
/// Internal error carrying a free-form message.
#[error("internal: {0}")]
Internal(String),
}
impl ModelError {
    /// Returns a fixed, machine-readable code identifying this error's variant.
    ///
    /// Every code is a string literal, so the result is `&'static str` and does
    /// not borrow `self`. Callers can keep the code after the error has been
    /// moved or dropped. `&'static str` coerces to any shorter-lived `&str`, so
    /// callers written against the previous `&str` signature still compile.
    ///
    /// The wrapped variants (`Core`, `Kernel`) each map to a single code, so the
    /// inner error's details are not reflected here.
    #[must_use]
    pub fn error_code(&self) -> &'static str {
        // Deliberately no wildcard arm: adding a new variant forces a code to
        // be chosen here instead of silently falling into a default.
        match self {
            Self::MissingTensor { .. } => "MISSING_TENSOR",
            Self::ShapeMismatch { .. } => "SHAPE_MISMATCH",
            Self::SequenceTooLong { .. } => "SEQUENCE_TOO_LONG",
            Self::Core(_) => "CORE_ERROR",
            Self::Kernel(_) => "KERNEL_ERROR",
            Self::Internal(_) => "INTERNAL_ERROR",
        }
    }
}