use thiserror::Error;
/// Convenience alias for results whose error type is [`ModelError`].
pub type ModelResult<T> = Result<T, ModelError>;
/// Errors produced during model configuration, weight loading, and inference.
///
/// Display formatting comes from the `thiserror` `#[error(...)]` attributes;
/// the `#[from]` variants allow `?`-propagation from core, candle, and I/O errors.
#[derive(Error, Debug)]
pub enum ModelError {
/// The model configuration was rejected as invalid.
#[error("Invalid model configuration: {message}")]
InvalidConfig { message: String },
/// A tensor/dimension size did not match what the operation expected.
#[error("Dimension mismatch in {context}: expected {expected}, got {got}")]
DimensionMismatch {
context: String,
expected: usize,
got: usize,
},
/// An operation was attempted before the model was initialized.
#[error("Model not initialized: {details}")]
NotInitialized { details: String },
/// Loading a specific named tensor's weights failed.
#[error("Weight loading failed for tensor '{tensor_name}': {reason}")]
WeightLoadError { tensor_name: String, reason: String },
/// A tensor was looked up by name in a model but was not present.
#[error("Tensor not found: '{name}' in model '{model}'")]
TensorNotFound { name: String, model: String },
/// A general load failure, tagged with the context it occurred in.
#[error("Load error in {context}: {message}")]
LoadError { context: String, message: String },
/// The forward pass failed at a specific layer index.
#[error("Forward pass error at layer {layer_idx}: {message}")]
ForwardError { layer_idx: usize, message: String },
/// The number of per-layer states did not match the model's layer count.
#[error("State count mismatch for {model}: expected {expected} layers, got {got}")]
StateCountMismatch {
model: String,
expected: usize,
got: usize,
},
/// The supplied batch size did not match the expected one.
#[error("Invalid batch size: expected {expected}, got {got}")]
InvalidBatchSize { expected: usize, got: usize },
/// A numerical problem (e.g. NaN/Inf) was detected during an operation.
#[error("Numerical instability detected in {operation}: {details}")]
NumericalInstability { operation: String, details: String },
/// The requested operation is not supported for the given model type.
#[error("Unsupported operation: {operation} for model type {model_type}")]
UnsupportedOperation {
operation: String,
model_type: String,
},
/// A quantization/dequantization step failed.
#[error("Quantization error: {message}")]
QuantizationError { message: String },
/// A memory allocation of `bytes` for `purpose` could not be satisfied.
#[error("Memory allocation failed: requested {bytes} bytes for {purpose}")]
AllocationError { bytes: usize, purpose: String },
/// An index exceeded its valid limit within the given context.
#[error("Index out of bounds: index {index} exceeds limit {limit} in {context}")]
IndexOutOfBounds {
index: usize,
limit: usize,
context: String,
},
/// Wrapped error from the core crate (auto-converted via `From`).
#[error("Core error: {0}")]
CoreError(#[from] kizzasi_core::CoreError),
/// Wrapped error from the candle tensor library (auto-converted via `From`).
#[error("Candle error: {0}")]
CandleError(#[from] candle_core::Error),
/// Wrapped standard I/O error (auto-converted via `From`).
#[error("I/O error: {0}")]
IoError(#[from] std::io::Error),
}
impl ModelError {
    /// Builds an [`InvalidConfig`](Self::InvalidConfig) error.
    pub fn invalid_config(message: impl Into<String>) -> Self {
        Self::InvalidConfig {
            message: message.into(),
        }
    }

    /// Builds a [`DimensionMismatch`](Self::DimensionMismatch) error for `context`.
    pub fn dimension_mismatch(context: impl Into<String>, expected: usize, got: usize) -> Self {
        Self::DimensionMismatch {
            context: context.into(),
            expected,
            got,
        }
    }

    /// Builds a [`NotInitialized`](Self::NotInitialized) error.
    pub fn not_initialized(details: impl Into<String>) -> Self {
        Self::NotInitialized {
            details: details.into(),
        }
    }

    /// Builds a [`LoadError`](Self::LoadError) tagged with an explicit context.
    pub fn load_error(context: impl Into<String>, message: impl Into<String>) -> Self {
        Self::LoadError {
            context: context.into(),
            message: message.into(),
        }
    }

    /// Builds a [`LoadError`](Self::LoadError) with the default `"general"` context.
    pub fn simple_load_error(message: impl Into<String>) -> Self {
        Self::LoadError {
            context: "general".into(),
            message: message.into(),
        }
    }

    /// Builds a [`ForwardError`](Self::ForwardError) at the given layer index.
    pub fn forward_error(layer_idx: usize, message: impl Into<String>) -> Self {
        Self::ForwardError {
            layer_idx,
            message: message.into(),
        }
    }

    /// Builds a [`WeightLoadError`](Self::WeightLoadError) for a named tensor.
    pub fn weight_load_error(tensor_name: impl Into<String>, reason: impl Into<String>) -> Self {
        Self::WeightLoadError {
            tensor_name: tensor_name.into(),
            reason: reason.into(),
        }
    }

    /// Builds a [`TensorNotFound`](Self::TensorNotFound) error.
    pub fn tensor_not_found(name: impl Into<String>, model: impl Into<String>) -> Self {
        Self::TensorNotFound {
            name: name.into(),
            model: model.into(),
        }
    }

    /// Builds a [`StateCountMismatch`](Self::StateCountMismatch) error.
    pub fn state_count_mismatch(model: impl Into<String>, expected: usize, got: usize) -> Self {
        Self::StateCountMismatch {
            model: model.into(),
            expected,
            got,
        }
    }

    /// Builds a [`NumericalInstability`](Self::NumericalInstability) error.
    pub fn numerical_instability(operation: impl Into<String>, details: impl Into<String>) -> Self {
        Self::NumericalInstability {
            operation: operation.into(),
            details: details.into(),
        }
    }

    /// Builds an [`UnsupportedOperation`](Self::UnsupportedOperation) error.
    pub fn unsupported_operation(
        operation: impl Into<String>,
        model_type: impl Into<String>,
    ) -> Self {
        Self::UnsupportedOperation {
            operation: operation.into(),
            model_type: model_type.into(),
        }
    }

    /// Builds a [`QuantizationError`](Self::QuantizationError) error.
    pub fn quantization_error(message: impl Into<String>) -> Self {
        Self::QuantizationError {
            message: message.into(),
        }
    }

    /// Builds an [`InvalidBatchSize`](Self::InvalidBatchSize) error.
    ///
    /// Added for consistency: previously this variant had no constructor helper.
    pub fn invalid_batch_size(expected: usize, got: usize) -> Self {
        Self::InvalidBatchSize { expected, got }
    }

    /// Builds an [`AllocationError`](Self::AllocationError) error.
    ///
    /// Added for consistency: previously this variant had no constructor helper.
    pub fn allocation_error(bytes: usize, purpose: impl Into<String>) -> Self {
        Self::AllocationError {
            bytes,
            purpose: purpose.into(),
        }
    }

    /// Builds an [`IndexOutOfBounds`](Self::IndexOutOfBounds) error.
    ///
    /// Added for consistency: previously this variant had no constructor helper.
    pub fn index_out_of_bounds(index: usize, limit: usize, context: impl Into<String>) -> Self {
        Self::IndexOutOfBounds {
            index,
            limit,
            context: context.into(),
        }
    }
}