1use thiserror::Error;
6
7#[derive(Debug, Error)]
9pub enum ModelError {
10 #[error("Configuration error: {0}")]
11 Config(String),
12
13 #[error("Tokenization error: {0}")]
14 Tokenization(String),
15
16 #[error("Forward pass error: {0}")]
17 Forward(String),
18
19 #[error("Training error: {0}")]
20 Training(String),
21
22 #[error("Inference error: {0}")]
23 Inference(String),
24
25 #[error("Evaluation error: {0}")]
26 Evaluation(String),
27
28 #[error("Mathematical operation error: {0}")]
29 Math(String),
30
31 #[error("IO error: {0}")]
32 Io(#[from] std::io::Error),
33
34 #[error("JSON serialization error: {0}")]
35 Json(#[from] serde_json::Error),
36
37 #[error("Invalid input: {0}")]
38 InvalidInput(String),
39
40 #[error("Not implemented: {0}")]
41 NotImplemented(String),
42}
43
44pub type Result<T> = std::result::Result<T, ModelError>;
46
47impl ModelError {
48 pub fn config<S: Into<String>>(msg: S) -> Self {
50 Self::Config(msg.into())
51 }
52
53 pub fn forward<S: Into<String>>(msg: S) -> Self {
55 Self::Forward(msg.into())
56 }
57
58 pub fn training<S: Into<String>>(msg: S) -> Self {
60 Self::Training(msg.into())
61 }
62
63 pub fn inference<S: Into<String>>(msg: S) -> Self {
65 Self::Inference(msg.into())
66 }
67
68 pub fn not_implemented<S: Into<String>>(msg: S) -> Self {
70 Self::NotImplemented(msg.into())
71 }
72
73 pub fn is_recoverable(&self) -> bool {
75 match self {
76 Self::Config(_) | Self::InvalidInput(_) => false,
77 Self::Io(_) | Self::Json(_) => false,
78 Self::Tokenization(_)
79 | Self::Forward(_)
80 | Self::Training(_)
81 | Self::Inference(_)
82 | Self::Evaluation(_)
83 | Self::Math(_) => true,
84 Self::NotImplemented(_) => false,
85 }
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// A real `std::io::Error`, used to exercise the `#[from]` conversion.
    fn sample_io_error() -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::NotFound, "missing file")
    }

    /// A real `serde_json::Error`, produced by parsing truncated JSON.
    fn sample_json_error() -> serde_json::Error {
        serde_json::from_str::<serde_json::Value>("{").unwrap_err()
    }

    #[test]
    fn test_error_creation() {
        let config_err = ModelError::config("test config error");
        assert!(matches!(config_err, ModelError::Config(_)));

        let forward_err = ModelError::forward("test forward error");
        assert!(matches!(forward_err, ModelError::Forward(_)));

        let training_err = ModelError::training("test training error");
        assert!(matches!(training_err, ModelError::Training(_)));

        let inference_err = ModelError::inference("test inference error");
        assert!(matches!(inference_err, ModelError::Inference(_)));

        let not_impl_err = ModelError::not_implemented("test not implemented");
        assert!(matches!(not_impl_err, ModelError::NotImplemented(_)));
    }

    #[test]
    fn test_error_recoverability() {
        let recoverable = [
            ModelError::Tokenization("test".into()),
            ModelError::forward("test"),
            ModelError::training("test"),
            ModelError::inference("test"),
            ModelError::Evaluation("test".into()),
            ModelError::Math("test".into()),
        ];
        for err in &recoverable {
            assert!(err.is_recoverable(), "expected recoverable: {err:?}");
        }

        let non_recoverable = [
            ModelError::config("test"),
            ModelError::InvalidInput("test".into()),
            ModelError::not_implemented("test"),
            ModelError::from(sample_io_error()),
            ModelError::from(sample_json_error()),
        ];
        for err in &non_recoverable {
            assert!(!err.is_recoverable(), "expected non-recoverable: {err:?}");
        }
    }

    #[test]
    fn test_error_display() {
        let display = ModelError::config("test message").to_string();
        assert!(display.contains("Configuration error"));
        assert!(display.contains("test message"));

        let io_display = ModelError::from(sample_io_error()).to_string();
        assert!(io_display.starts_with("IO error: "));
        assert!(io_display.contains("missing file"));

        let json_display = ModelError::from(sample_json_error()).to_string();
        assert!(json_display.starts_with("JSON serialization error: "));
    }

    #[test]
    fn test_question_mark_conversions() {
        // `?` must route library errors through the `#[from]`-generated impls.
        fn parse(input: &str) -> Result<serde_json::Value> {
            Ok(serde_json::from_str(input)?)
        }
        assert!(matches!(parse("{"), Err(ModelError::Json(_))));
        assert!(parse("{}").is_ok());

        let io_err: ModelError = sample_io_error().into();
        assert!(matches!(io_err, ModelError::Io(_)));
    }
}