1use thiserror::Error;
7
8pub type Result<T> = std::result::Result<T, Error>;
10
11#[derive(Error, Debug)]
13pub enum Error {
14 #[error("model error: {0}")]
16 Model(#[from] ModelError),
17
18 #[error("training error: {0}")]
20 Training(#[from] TrainingError),
21
22 #[error("inference error: {0}")]
24 Inference(#[from] InferenceError),
25
26 #[error("checkpoint error: {0}")]
28 Checkpoint(#[from] CheckpointError),
29
30 #[error("device error: {0}")]
32 Device(#[from] DeviceError),
33
34 #[error("embedding error: {0}")]
36 Embedding(#[from] EmbeddingError),
37
38 #[error("io error: {0}")]
40 Io(#[from] std::io::Error),
41
42 #[error("candle error: {0}")]
44 Candle(#[from] candle_core::Error),
45
46 #[error("JSON error: {0}")]
48 Json(#[from] serde_json::Error),
49}
50
51#[derive(Error, Debug)]
53pub enum ModelError {
54 #[error("model file not found: {path}")]
56 FileNotFound {
57 path: std::path::PathBuf,
59 },
60
61 #[error("invalid model format: {reason}")]
63 InvalidFormat {
64 reason: String,
66 },
67
68 #[error("incompatible model version: expected {expected}, found {found}")]
70 IncompatibleVersion {
71 expected: String,
73 found: String,
75 },
76
77 #[error("invalid model configuration: {reason}")]
79 InvalidConfig {
80 reason: String,
82 },
83
84 #[error("tensor shape mismatch: expected {expected:?}, got {actual:?}")]
86 ShapeMismatch {
87 expected: Vec<usize>,
89 actual: Vec<usize>,
91 },
92}
93
94#[derive(Error, Debug)]
96pub enum TrainingError {
97 #[error("invalid LoRA configuration: {reason}")]
99 InvalidLoRAConfig {
100 reason: String,
102 },
103
104 #[error("invalid training configuration: {reason}")]
106 InvalidConfig {
107 reason: String,
109 },
110
111 #[error("training failed: {reason}")]
113 Failed {
114 reason: String,
116 },
117
118 #[error("gradient error: {reason}")]
120 Gradient {
121 reason: String,
123 },
124
125 #[error("training state error: {reason}")]
127 StateError {
128 reason: String,
130 },
131}
132
133#[derive(Error, Debug)]
135pub enum InferenceError {
136 #[error("invalid generation configuration: {reason}")]
138 InvalidConfig {
139 reason: String,
141 },
142
143 #[error("generation failed: {reason}")]
145 Failed {
146 reason: String,
148 },
149
150 #[error("invalid sampling parameters: {reason}")]
152 InvalidSampling {
153 reason: String,
155 },
156
157 #[error("KV-cache full: position {position} exceeds max length {max_len}")]
159 CacheFull {
160 position: usize,
162 max_len: usize,
164 },
165
166 #[error("token sampling failed: {reason}")]
168 SamplingError {
169 reason: String,
171 },
172
173 #[error("end of sequence reached at position {position}")]
175 EndOfSequence {
176 position: usize,
178 },
179}
180
181#[derive(Error, Debug)]
183pub enum CheckpointError {
184 #[error("failed to save checkpoint: {reason}")]
186 SaveFailed {
187 reason: String,
189 },
190
191 #[error("failed to load checkpoint: {reason}")]
193 LoadFailed {
194 reason: String,
196 },
197
198 #[error("invalid checkpoint format: {reason}")]
200 InvalidFormat {
201 reason: String,
203 },
204}
205
206#[derive(Error, Debug)]
208pub enum DeviceError {
209 #[error("Metal device not available: {reason}")]
211 MetalUnavailable {
212 reason: String,
214 },
215
216 #[error("device initialization failed: {reason}")]
218 InitializationFailed {
219 reason: String,
221 },
222
223 #[error("memory allocation failed: requested {requested_bytes} bytes")]
225 AllocationFailed {
226 requested_bytes: usize,
228 },
229
230 #[error("tensor operation failed: {operation}")]
232 OperationFailed {
233 operation: String,
235 },
236}
237
238#[derive(Error, Debug)]
240pub enum EmbeddingError {
241 #[error("model download failed: {reason}")]
243 DownloadFailed {
244 reason: String,
246 },
247
248 #[error("model not found: {model_id}")]
250 ModelNotFound {
251 model_id: String,
253 },
254
255 #[error("tokenizer loading failed: {reason}")]
257 TokenizerFailed {
258 reason: String,
260 },
261
262 #[error("tokenization failed: {reason}")]
264 TokenizationFailed {
265 reason: String,
267 },
268
269 #[error("cannot encode empty text array")]
271 EmptyInput,
272
273 #[error("invalid embedding configuration: {reason}")]
275 InvalidConfig {
276 reason: String,
278 },
279
280 #[error("embedding dimension mismatch: expected {expected}, got {actual}")]
282 DimensionMismatch {
283 expected: usize,
285 actual: usize,
287 },
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    // Messages are compared exactly, not by prefix, so the tests also catch
    // a field that silently stops being interpolated into the message.

    #[test]
    fn test_error_display() {
        let err = Error::Model(ModelError::FileNotFound {
            path: std::path::PathBuf::from("/path/to/model.safetensors"),
        });
        // The wrapper's prefix is followed by the full inner message, so the
        // path must survive both layers of formatting.
        assert_eq!(
            err.to_string(),
            "model error: model file not found: /path/to/model.safetensors"
        );
    }

    #[test]
    fn test_error_from_io() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let err: Error = io_err.into();
        assert!(matches!(err, Error::Io(_)));
        assert_eq!(err.to_string(), "io error: file not found");

        // `#[from]` also exposes the wrapped error through `source()`, which
        // error reporters use to print the cause chain.
        let source = std::error::Error::source(&err).expect("io error should be the source");
        assert_eq!(source.to_string(), "file not found");
    }

    #[test]
    fn test_subsystem_errors_convert_into_error() {
        let reason = || "boom".to_string();

        let err = Error::from(ModelError::InvalidFormat { reason: reason() });
        assert!(matches!(
            err,
            Error::Model(ModelError::InvalidFormat { .. })
        ));
        let source = std::error::Error::source(&err).expect("model error should be the source");
        assert_eq!(source.to_string(), "invalid model format: boom");

        let err = Error::from(TrainingError::Failed { reason: reason() });
        assert!(matches!(
            err,
            Error::Training(TrainingError::Failed { .. })
        ));
        assert_eq!(err.to_string(), "training error: training failed: boom");

        let err = Error::from(DeviceError::AllocationFailed {
            requested_bytes: 64,
        });
        assert!(matches!(
            err,
            Error::Device(DeviceError::AllocationFailed {
                requested_bytes: 64
            })
        ));

        let err = Error::from(EmbeddingError::EmptyInput);
        assert!(matches!(err, Error::Embedding(EmbeddingError::EmptyInput)));
        assert_eq!(
            err.to_string(),
            "embedding error: cannot encode empty text array"
        );
    }

    #[test]
    fn test_model_error_types() {
        let err = ModelError::ShapeMismatch {
            expected: vec![1, 2, 3],
            actual: vec![1, 2, 4],
        };
        // Shapes are rendered with `{:?}`, i.e. as bracketed lists.
        assert_eq!(
            err.to_string(),
            "tensor shape mismatch: expected [1, 2, 3], got [1, 2, 4]"
        );

        let err = ModelError::IncompatibleVersion {
            expected: "2.0".to_string(),
            found: "1.0".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "incompatible model version: expected 2.0, found 1.0"
        );
    }

    #[test]
    fn test_training_error_types() {
        let err = TrainingError::InvalidLoRAConfig {
            reason: "rank must be > 0".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "invalid LoRA configuration: rank must be > 0"
        );

        let err = TrainingError::InvalidConfig {
            reason: "invalid".to_string(),
        };
        assert_eq!(err.to_string(), "invalid training configuration: invalid");

        let err = TrainingError::StateError {
            reason: "state error".to_string(),
        };
        assert_eq!(err.to_string(), "training state error: state error");
    }

    #[test]
    fn test_device_error_types() {
        let err = DeviceError::MetalUnavailable {
            reason: "not running on Apple Silicon".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Metal device not available: not running on Apple Silicon"
        );

        let err = DeviceError::AllocationFailed {
            requested_bytes: 1024,
        };
        assert_eq!(
            err.to_string(),
            "memory allocation failed: requested 1024 bytes"
        );
    }

    #[test]
    fn test_embedding_error_types() {
        let err = EmbeddingError::ModelNotFound {
            model_id: "intfloat/e5-small-v2".to_string(),
        };
        assert_eq!(err.to_string(), "model not found: intfloat/e5-small-v2");

        let err = EmbeddingError::EmptyInput;
        assert_eq!(err.to_string(), "cannot encode empty text array");

        let err = EmbeddingError::DimensionMismatch {
            expected: 384,
            actual: 768,
        };
        assert_eq!(
            err.to_string(),
            "embedding dimension mismatch: expected 384, got 768"
        );
    }
}