ds_r1_rs/utils/
error.rs

//! # Error Handling
//!
//! Comprehensive error types and handling for the DeepSeek R1 implementation.
4
5use thiserror::Error;
6
7/// Main error type for the DeepSeek R1 implementation
8#[derive(Debug, Error)]
9pub enum ModelError {
10    #[error("Configuration error: {0}")]
11    Config(String),
12
13    #[error("Tokenization error: {0}")]
14    Tokenization(String),
15
16    #[error("Forward pass error: {0}")]
17    Forward(String),
18
19    #[error("Training error: {0}")]
20    Training(String),
21
22    #[error("Inference error: {0}")]
23    Inference(String),
24
25    #[error("Evaluation error: {0}")]
26    Evaluation(String),
27
28    #[error("Mathematical operation error: {0}")]
29    Math(String),
30
31    #[error("IO error: {0}")]
32    Io(#[from] std::io::Error),
33
34    #[error("JSON serialization error: {0}")]
35    Json(#[from] serde_json::Error),
36
37    #[error("Invalid input: {0}")]
38    InvalidInput(String),
39
40    #[error("Not implemented: {0}")]
41    NotImplemented(String),
42}
43
44/// Result type alias for convenience
45pub type Result<T> = std::result::Result<T, ModelError>;
46
47impl ModelError {
48    /// Create a configuration error
49    pub fn config<S: Into<String>>(msg: S) -> Self {
50        Self::Config(msg.into())
51    }
52
53    /// Create a forward pass error
54    pub fn forward<S: Into<String>>(msg: S) -> Self {
55        Self::Forward(msg.into())
56    }
57
58    /// Create a training error
59    pub fn training<S: Into<String>>(msg: S) -> Self {
60        Self::Training(msg.into())
61    }
62
63    /// Create an inference error
64    pub fn inference<S: Into<String>>(msg: S) -> Self {
65        Self::Inference(msg.into())
66    }
67
68    /// Create a not implemented error
69    pub fn not_implemented<S: Into<String>>(msg: S) -> Self {
70        Self::NotImplemented(msg.into())
71    }
72
73    /// Check if error is recoverable
74    pub fn is_recoverable(&self) -> bool {
75        match self {
76            Self::Config(_) | Self::InvalidInput(_) => false,
77            Self::Io(_) | Self::Json(_) => false,
78            Self::Tokenization(_)
79            | Self::Forward(_)
80            | Self::Training(_)
81            | Self::Inference(_)
82            | Self::Evaluation(_)
83            | Self::Math(_) => true,
84            Self::NotImplemented(_) => false,
85        }
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// The convenience constructors produce the matching variant and keep the message.
    #[test]
    fn test_error_creation() {
        let config_err = ModelError::config("test config error");
        assert!(matches!(config_err, ModelError::Config(_)));

        let forward_err = ModelError::forward("test forward error");
        assert!(matches!(forward_err, ModelError::Forward(_)));

        let training_err = ModelError::training("test training error");
        assert!(matches!(training_err, ModelError::Training(_)));

        let inference_err = ModelError::inference("test inference error");
        assert!(matches!(inference_err, ModelError::Inference(_)));

        // Constructors accept owned `String`s as well as `&str`.
        let not_impl_err = ModelError::not_implemented(String::from("owned message"));
        assert!(matches!(
            not_impl_err,
            ModelError::NotImplemented(ref msg) if msg == "owned message"
        ));
    }

    /// Every variant is classified by `is_recoverable`.
    #[test]
    fn test_error_recoverability() {
        let recoverable = vec![
            ModelError::Tokenization("test".into()),
            ModelError::forward("test"),
            ModelError::training("test"),
            ModelError::inference("test"),
            ModelError::Evaluation("test".into()),
            ModelError::Math("test".into()),
        ];
        for err in &recoverable {
            assert!(err.is_recoverable(), "{:?} should be recoverable", err);
        }

        let fatal = vec![
            ModelError::config("test"),
            ModelError::InvalidInput("test".into()),
            ModelError::not_implemented("test"),
            ModelError::from(std::io::Error::new(std::io::ErrorKind::Other, "test")),
        ];
        for err in &fatal {
            assert!(!err.is_recoverable(), "{:?} should not be recoverable", err);
        }
    }

    /// `Display` output comes from the `#[error(...)]` attribute.
    #[test]
    fn test_error_display() {
        let err = ModelError::config("test message");
        let display = err.to_string();
        assert!(display.contains("Configuration error"));
        assert!(display.contains("test message"));
    }

    /// The `#[from]` conversions wrap library errors in the right variant.
    #[test]
    fn test_from_conversions() {
        let io_err: ModelError =
            std::io::Error::new(std::io::ErrorKind::NotFound, "missing file").into();
        assert!(matches!(io_err, ModelError::Io(_)));
        assert_eq!(io_err.to_string(), "IO error: missing file");

        let json_source = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
        let json_err: ModelError = json_source.into();
        assert!(matches!(json_err, ModelError::Json(_)));
        assert!(!json_err.is_recoverable());
    }
}