metal_candle/
error.rs

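//! Error types for metal-candle.
//!
//! This module defines the error types used throughout the crate: a top-level
//! [`Error`] (together with the [`Result`] alias) that wraps one enum per
//! subsystem (model, training, inference, checkpoint, device, embedding) plus
//! errors propagated from I/O, `candle_core`, and `serde_json`. All of them are
//! derived with `thiserror`, the usual choice for typed errors in library code.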
5
6use thiserror::Error;
7
8/// Result type alias for metal-candle operations.
9pub type Result<T> = std::result::Result<T, Error>;
10
11/// The main error type for metal-candle operations.
12#[derive(Error, Debug)]
13pub enum Error {
14    /// Error related to model operations (loading, validation, etc.)
15    #[error("model error: {0}")]
16    Model(#[from] ModelError),
17
18    /// Error related to training operations
19    #[error("training error: {0}")]
20    Training(#[from] TrainingError),
21
22    /// Error related to inference/generation operations
23    #[error("inference error: {0}")]
24    Inference(#[from] InferenceError),
25
26    /// Error related to checkpoint operations
27    #[error("checkpoint error: {0}")]
28    Checkpoint(#[from] CheckpointError),
29
30    /// Error related to device/backend operations
31    #[error("device error: {0}")]
32    Device(#[from] DeviceError),
33
34    /// Error related to embedding operations
35    #[error("embedding error: {0}")]
36    Embedding(#[from] EmbeddingError),
37
38    /// IO errors
39    #[error("io error: {0}")]
40    Io(#[from] std::io::Error),
41
42    /// Candle framework errors
43    #[error("candle error: {0}")]
44    Candle(#[from] candle_core::Error),
45
46    /// JSON serialization/deserialization error
47    #[error("JSON error: {0}")]
48    Json(#[from] serde_json::Error),
49}
50
51/// Errors related to model operations.
52#[derive(Error, Debug)]
53pub enum ModelError {
54    /// Model file not found at the specified path
55    #[error("model file not found: {path}")]
56    FileNotFound {
57        /// The path that was not found
58        path: std::path::PathBuf,
59    },
60
61    /// Invalid model format or corrupted file
62    #[error("invalid model format: {reason}")]
63    InvalidFormat {
64        /// Description of the format issue
65        reason: String,
66    },
67
68    /// Model version incompatibility
69    #[error("incompatible model version: expected {expected}, found {found}")]
70    IncompatibleVersion {
71        /// Expected version
72        expected: String,
73        /// Found version
74        found: String,
75    },
76
77    /// Invalid model configuration
78    #[error("invalid model configuration: {reason}")]
79    InvalidConfig {
80        /// Description of the configuration issue
81        reason: String,
82    },
83
84    /// Tensor shape mismatch
85    #[error("tensor shape mismatch: expected {expected:?}, got {actual:?}")]
86    ShapeMismatch {
87        /// Expected shape
88        expected: Vec<usize>,
89        /// Actual shape
90        actual: Vec<usize>,
91    },
92}
93
94/// Errors related to training operations.
95#[derive(Error, Debug)]
96pub enum TrainingError {
97    /// Invalid `LoRA` configuration
98    #[error("invalid LoRA configuration: {reason}")]
99    InvalidLoRAConfig {
100        /// Description of the configuration issue
101        reason: String,
102    },
103
104    /// Invalid training configuration
105    #[error("invalid training configuration: {reason}")]
106    InvalidConfig {
107        /// Description of the configuration issue
108        reason: String,
109    },
110
111    /// Training failed
112    #[error("training failed: {reason}")]
113    Failed {
114        /// Description of the failure
115        reason: String,
116    },
117
118    /// Gradient computation error
119    #[error("gradient error: {reason}")]
120    Gradient {
121        /// Description of the gradient issue
122        reason: String,
123    },
124
125    /// Training state error
126    #[error("training state error: {reason}")]
127    StateError {
128        /// Description of the state issue
129        reason: String,
130    },
131}
132
133/// Errors related to inference operations.
134#[derive(Error, Debug)]
135pub enum InferenceError {
136    /// Invalid generation configuration
137    #[error("invalid generation configuration: {reason}")]
138    InvalidConfig {
139        /// Description of the configuration issue
140        reason: String,
141    },
142
143    /// Generation failed
144    #[error("generation failed: {reason}")]
145    Failed {
146        /// Description of the failure
147        reason: String,
148    },
149
150    /// Invalid sampling parameters
151    #[error("invalid sampling parameters: {reason}")]
152    InvalidSampling {
153        /// Description of the sampling issue
154        reason: String,
155    },
156
157    /// KV-cache is full
158    #[error("KV-cache full: position {position} exceeds max length {max_len}")]
159    CacheFull {
160        /// Current cache position
161        position: usize,
162        /// Maximum cache length
163        max_len: usize,
164    },
165
166    /// Token sampling failed
167    #[error("token sampling failed: {reason}")]
168    SamplingError {
169        /// Description of the sampling error
170        reason: String,
171    },
172
173    /// End-of-sequence reached
174    #[error("end of sequence reached at position {position}")]
175    EndOfSequence {
176        /// Position where EOS was reached
177        position: usize,
178    },
179}
180
181/// Errors related to checkpoint operations.
182#[derive(Error, Debug)]
183pub enum CheckpointError {
184    /// Failed to save checkpoint
185    #[error("failed to save checkpoint: {reason}")]
186    SaveFailed {
187        /// Description of the save failure
188        reason: String,
189    },
190
191    /// Failed to load checkpoint
192    #[error("failed to load checkpoint: {reason}")]
193    LoadFailed {
194        /// Description of the load failure
195        reason: String,
196    },
197
198    /// Checkpoint format is invalid or corrupted
199    #[error("invalid checkpoint format: {reason}")]
200    InvalidFormat {
201        /// Description of the format issue
202        reason: String,
203    },
204}
205
206/// Errors related to device/backend operations.
207#[derive(Error, Debug)]
208pub enum DeviceError {
209    /// Metal device not available
210    #[error("Metal device not available: {reason}")]
211    MetalUnavailable {
212        /// Description of why Metal is unavailable
213        reason: String,
214    },
215
216    /// Device initialization failed
217    #[error("device initialization failed: {reason}")]
218    InitializationFailed {
219        /// Description of the initialization failure
220        reason: String,
221    },
222
223    /// Memory allocation failed
224    #[error("memory allocation failed: requested {requested_bytes} bytes")]
225    AllocationFailed {
226        /// Number of bytes requested
227        requested_bytes: usize,
228    },
229
230    /// Tensor operation failed
231    #[error("tensor operation failed: {operation}")]
232    OperationFailed {
233        /// Name of the failed operation
234        operation: String,
235    },
236}
237
238/// Errors related to embedding operations.
239#[derive(Error, Debug)]
240pub enum EmbeddingError {
241    /// Failed to download model from `HuggingFace` Hub
242    #[error("model download failed: {reason}")]
243    DownloadFailed {
244        /// Description of the download failure
245        reason: String,
246    },
247
248    /// Model not found in cache or on `HuggingFace` Hub
249    #[error("model not found: {model_id}")]
250    ModelNotFound {
251        /// `HuggingFace` model ID
252        model_id: String,
253    },
254
255    /// Failed to load tokenizer
256    #[error("tokenizer loading failed: {reason}")]
257    TokenizerFailed {
258        /// Description of the tokenizer failure
259        reason: String,
260    },
261
262    /// Tokenization failed
263    #[error("tokenization failed: {reason}")]
264    TokenizationFailed {
265        /// Description of the tokenization error
266        reason: String,
267    },
268
269    /// Empty input provided to encoding
270    #[error("cannot encode empty text array")]
271    EmptyInput,
272
273    /// Invalid embedding configuration
274    #[error("invalid embedding configuration: {reason}")]
275    InvalidConfig {
276        /// Description of the configuration issue
277        reason: String,
278    },
279
280    /// Embedding dimension mismatch
281    #[error("embedding dimension mismatch: expected {expected}, got {actual}")]
282    DimensionMismatch {
283        /// Expected dimension
284        expected: usize,
285        /// Actual dimension
286        actual: usize,
287    },
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_display() {
        let err = Error::Model(ModelError::FileNotFound {
            path: std::path::PathBuf::from("/path/to/model.safetensors"),
        });
        assert!(err.to_string().contains("model file not found"));
        // The wrapper adds its own prefix and inlines the inner error's message.
        assert_eq!(
            err.to_string(),
            "model error: model file not found: /path/to/model.safetensors"
        );
    }

    #[test]
    fn test_error_from_io() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let err: Error = io_err.into();
        assert!(matches!(err, Error::Io(_)));
        assert_eq!(err.to_string(), "io error: file not found");
        // `#[from]` also registers the wrapped error as the `source()`.
        assert!(std::error::Error::source(&err).is_some());
    }

    #[test]
    fn test_error_from_json() {
        let json_err = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
        let err: Error = json_err.into();
        assert!(matches!(err, Error::Json(_)));
        assert!(err.to_string().starts_with("JSON error: "));
    }

    #[test]
    fn test_sub_error_conversions() {
        let err: Error = ModelError::InvalidConfig {
            reason: "hidden size must be > 0".to_string(),
        }
        .into();
        assert!(matches!(err, Error::Model(ModelError::InvalidConfig { .. })));

        let err: Error = TrainingError::Failed {
            reason: "nan loss".to_string(),
        }
        .into();
        assert!(matches!(err, Error::Training(TrainingError::Failed { .. })));

        let err: Error = InferenceError::EndOfSequence { position: 7 }.into();
        assert!(matches!(
            err,
            Error::Inference(InferenceError::EndOfSequence { position: 7 })
        ));

        let err: Error = CheckpointError::LoadFailed {
            reason: "corrupt".to_string(),
        }
        .into();
        assert!(matches!(err, Error::Checkpoint(CheckpointError::LoadFailed { .. })));

        let err: Error = DeviceError::InitializationFailed {
            reason: "no device".to_string(),
        }
        .into();
        assert!(matches!(err, Error::Device(DeviceError::InitializationFailed { .. })));

        let err: Error = EmbeddingError::EmptyInput.into();
        assert!(matches!(err, Error::Embedding(EmbeddingError::EmptyInput)));
    }

    #[test]
    fn test_model_error_types() {
        let err = ModelError::ShapeMismatch {
            expected: vec![1, 2, 3],
            actual: vec![1, 2, 4],
        };
        assert!(err.to_string().contains("shape mismatch"));
        assert_eq!(
            err.to_string(),
            "tensor shape mismatch: expected [1, 2, 3], got [1, 2, 4]"
        );
    }

    #[test]
    fn test_training_error_types() {
        let err = TrainingError::InvalidLoRAConfig {
            reason: "rank must be > 0".to_string(),
        };
        assert!(err.to_string().contains("invalid LoRA configuration"));

        let err = TrainingError::InvalidConfig {
            reason: "invalid".to_string(),
        };
        assert!(err.to_string().contains("invalid training configuration"));

        let err = TrainingError::StateError {
            reason: "state error".to_string(),
        };
        assert!(err.to_string().contains("training state error"));
    }

    #[test]
    fn test_inference_error_types() {
        let err = InferenceError::CacheFull {
            position: 2048,
            max_len: 2048,
        };
        assert_eq!(
            err.to_string(),
            "KV-cache full: position 2048 exceeds max length 2048"
        );

        let err = InferenceError::EndOfSequence { position: 7 };
        assert_eq!(err.to_string(), "end of sequence reached at position 7");
    }

    #[test]
    fn test_checkpoint_error_types() {
        let err = CheckpointError::SaveFailed {
            reason: "disk full".to_string(),
        };
        assert_eq!(err.to_string(), "failed to save checkpoint: disk full");

        let err = CheckpointError::InvalidFormat {
            reason: "missing header".to_string(),
        };
        assert_eq!(err.to_string(), "invalid checkpoint format: missing header");
    }

    #[test]
    fn test_device_error_types() {
        let err = DeviceError::MetalUnavailable {
            reason: "not running on Apple Silicon".to_string(),
        };
        assert!(err.to_string().contains("Metal device not available"));

        let err = DeviceError::AllocationFailed {
            requested_bytes: 1024,
        };
        assert_eq!(
            err.to_string(),
            "memory allocation failed: requested 1024 bytes"
        );
    }

    #[test]
    fn test_embedding_error_types() {
        let err = EmbeddingError::ModelNotFound {
            model_id: "intfloat/e5-small-v2".to_string(),
        };
        assert!(err.to_string().contains("model not found"));

        let err = EmbeddingError::EmptyInput;
        assert!(err.to_string().contains("cannot encode empty text"));

        let err = EmbeddingError::DimensionMismatch {
            expected: 384,
            actual: 768,
        };
        assert!(err.to_string().contains("dimension mismatch"));
        assert_eq!(
            err.to_string(),
            "embedding dimension mismatch: expected 384, got 768"
        );
    }
}