1use thiserror::Error;
6
7pub type ModelResult<T> = Result<T, ModelError>;
9
10#[derive(Error, Debug)]
12pub enum ModelError {
13 #[error("Invalid model configuration: {message}")]
14 InvalidConfig { message: String },
15
16 #[error("Dimension mismatch in {context}: expected {expected}, got {got}")]
17 DimensionMismatch {
18 context: String,
19 expected: usize,
20 got: usize,
21 },
22
23 #[error("Model not initialized: {details}")]
24 NotInitialized { details: String },
25
26 #[error("Weight loading failed for tensor '{tensor_name}': {reason}")]
27 WeightLoadError { tensor_name: String, reason: String },
28
29 #[error("Tensor not found: '{name}' in model '{model}'")]
30 TensorNotFound { name: String, model: String },
31
32 #[error("Load error in {context}: {message}")]
33 LoadError { context: String, message: String },
34
35 #[error("Forward pass error at layer {layer_idx}: {message}")]
36 ForwardError { layer_idx: usize, message: String },
37
38 #[error("State count mismatch for {model}: expected {expected} layers, got {got}")]
39 StateCountMismatch {
40 model: String,
41 expected: usize,
42 got: usize,
43 },
44
45 #[error("Invalid batch size: expected {expected}, got {got}")]
46 InvalidBatchSize { expected: usize, got: usize },
47
48 #[error("Numerical instability detected in {operation}: {details}")]
49 NumericalInstability { operation: String, details: String },
50
51 #[error("Unsupported operation: {operation} for model type {model_type}")]
52 UnsupportedOperation {
53 operation: String,
54 model_type: String,
55 },
56
57 #[error("Quantization error: {message}")]
58 QuantizationError { message: String },
59
60 #[error("Memory allocation failed: requested {bytes} bytes for {purpose}")]
61 AllocationError { bytes: usize, purpose: String },
62
63 #[error("Index out of bounds: index {index} exceeds limit {limit} in {context}")]
64 IndexOutOfBounds {
65 index: usize,
66 limit: usize,
67 context: String,
68 },
69
70 #[error("Core error: {0}")]
71 CoreError(#[from] kizzasi_core::CoreError),
72
73 #[error("Candle error: {0}")]
74 CandleError(#[from] candle_core::Error),
75
76 #[error("I/O error: {0}")]
77 IoError(#[from] std::io::Error),
78}
79
80impl ModelError {
81 pub fn invalid_config(message: impl Into<String>) -> Self {
83 Self::InvalidConfig {
84 message: message.into(),
85 }
86 }
87
88 pub fn dimension_mismatch(context: impl Into<String>, expected: usize, got: usize) -> Self {
90 Self::DimensionMismatch {
91 context: context.into(),
92 expected,
93 got,
94 }
95 }
96
97 pub fn not_initialized(details: impl Into<String>) -> Self {
99 Self::NotInitialized {
100 details: details.into(),
101 }
102 }
103
104 pub fn load_error(context: impl Into<String>, message: impl Into<String>) -> Self {
106 Self::LoadError {
107 context: context.into(),
108 message: message.into(),
109 }
110 }
111
112 pub fn simple_load_error(message: impl Into<String>) -> Self {
114 Self::LoadError {
115 context: "general".into(),
116 message: message.into(),
117 }
118 }
119
120 pub fn forward_error(layer_idx: usize, message: impl Into<String>) -> Self {
122 Self::ForwardError {
123 layer_idx,
124 message: message.into(),
125 }
126 }
127
128 pub fn weight_load_error(tensor_name: impl Into<String>, reason: impl Into<String>) -> Self {
130 Self::WeightLoadError {
131 tensor_name: tensor_name.into(),
132 reason: reason.into(),
133 }
134 }
135
136 pub fn tensor_not_found(name: impl Into<String>, model: impl Into<String>) -> Self {
138 Self::TensorNotFound {
139 name: name.into(),
140 model: model.into(),
141 }
142 }
143
144 pub fn state_count_mismatch(model: impl Into<String>, expected: usize, got: usize) -> Self {
146 Self::StateCountMismatch {
147 model: model.into(),
148 expected,
149 got,
150 }
151 }
152
153 pub fn numerical_instability(operation: impl Into<String>, details: impl Into<String>) -> Self {
155 Self::NumericalInstability {
156 operation: operation.into(),
157 details: details.into(),
158 }
159 }
160
161 pub fn unsupported_operation(
163 operation: impl Into<String>,
164 model_type: impl Into<String>,
165 ) -> Self {
166 Self::UnsupportedOperation {
167 operation: operation.into(),
168 model_type: model_type.into(),
169 }
170 }
171
172 pub fn quantization_error(message: impl Into<String>) -> Self {
174 Self::QuantizationError {
175 message: message.into(),
176 }
177 }
178}