1use std::fmt;
2
/// Error type for LLM client operations.
///
/// Every variant stores human-readable text (`String`s) rather than wrapping an
/// underlying error value. Variants are grouped below by where the failure arose.
#[derive(Debug)]
pub enum LLMError {
    // --- Transport and authentication -------------------------------------
    /// HTTP transport failure; `From<reqwest::Error>` produces this variant from
    /// the reqwest error's text.
    HttpError(String),
    /// Authentication failure (presumably a missing or rejected API key —
    /// confirm at construction sites).
    AuthError(String),

    // --- Problems with what we are about to send ----------------------------
    /// The request was judged invalid. NOTE(review): whether this is detected
    /// client-side or reported by the provider is not visible here.
    InvalidRequest(String),
    /// Tool definitions/configuration could not be used.
    ToolConfigError(String),

    // --- Problems with what came back ---------------------------------------
    /// Error attributed to the LLM provider (presumably a provider-side API
    /// error — confirm at construction sites).
    ProviderError(String),
    /// The response did not have the expected shape.
    ResponseFormatError {
        /// Description of what was wrong with the response.
        message: String,
        /// The unmodified response text, kept for diagnostics. It is included
        /// in full in the `Display` output.
        raw_response: String,
    },
    /// JSON (de)serialization failure; `From<serde_json::Error>` produces this
    /// variant.
    JsonError(String),

    // --- Everything else ----------------------------------------------------
    /// Catch-all for failures that fit no other variant.
    Generic(String),
}
26
27impl fmt::Display for LLMError {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 match self {
30 LLMError::HttpError(e) => write!(f, "HTTP Error: {e}"),
31 LLMError::AuthError(e) => write!(f, "Auth Error: {e}"),
32 LLMError::InvalidRequest(e) => write!(f, "Invalid Request: {e}"),
33 LLMError::ProviderError(e) => write!(f, "Provider Error: {e}"),
34 LLMError::Generic(e) => write!(f, "Generic Error : {e}"),
35 LLMError::ResponseFormatError {
36 message,
37 raw_response,
38 } => {
39 write!(
40 f,
41 "Response Format Error: {message}. Raw response: {raw_response}"
42 )
43 }
44 LLMError::JsonError(e) => write!(f, "JSON Parse Error: {e}"),
45 LLMError::ToolConfigError(e) => write!(f, "Tool Configuration Error: {e}"),
46 }
47 }
48}
49
50impl std::error::Error for LLMError {}
51
52impl From<reqwest::Error> for LLMError {
54 fn from(err: reqwest::Error) -> Self {
55 LLMError::HttpError(err.to_string())
56 }
57}
58
59impl From<serde_json::Error> for LLMError {
60 fn from(err: serde_json::Error) -> Self {
61 LLMError::JsonError(format!(
62 "{} at line {} column {}",
63 err,
64 err.line(),
65 err.column()
66 ))
67 }
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    // Brings `Error::source` into scope for `test_llm_error_is_error_trait`.
    use std::error::Error;

    // ---- Display text for each variant -----------------------------------

    #[test]
    fn test_llm_error_display_http_error() {
        let error = LLMError::HttpError("Connection failed".to_string());
        assert_eq!(error.to_string(), "HTTP Error: Connection failed");
    }

    #[test]
    fn test_llm_error_display_auth_error() {
        let error = LLMError::AuthError("Invalid API key".to_string());
        assert_eq!(error.to_string(), "Auth Error: Invalid API key");
    }

    #[test]
    fn test_llm_error_display_invalid_request() {
        let error = LLMError::InvalidRequest("Missing required parameter".to_string());
        assert_eq!(
            error.to_string(),
            "Invalid Request: Missing required parameter"
        );
    }

    #[test]
    fn test_llm_error_display_provider_error() {
        let error = LLMError::ProviderError("Model not found".to_string());
        assert_eq!(error.to_string(), "Provider Error: Model not found");
    }

    #[test]
    fn test_llm_error_display_generic_error() {
        let error = LLMError::Generic("Something went wrong".to_string());
        assert_eq!(error.to_string(), "Generic Error : Something went wrong");
    }

    #[test]
    fn test_llm_error_display_response_format_error() {
        let error = LLMError::ResponseFormatError {
            message: "Invalid JSON".to_string(),
            raw_response: "{invalid json}".to_string(),
        };
        assert_eq!(
            error.to_string(),
            "Response Format Error: Invalid JSON. Raw response: {invalid json}"
        );
    }

    #[test]
    fn test_llm_error_display_json_error() {
        let error = LLMError::JsonError("Parse error at line 5 column 10".to_string());
        assert_eq!(
            error.to_string(),
            "JSON Parse Error: Parse error at line 5 column 10"
        );
    }

    #[test]
    fn test_llm_error_display_tool_config_error() {
        let error = LLMError::ToolConfigError("Invalid tool configuration".to_string());
        assert_eq!(
            error.to_string(),
            "Tool Configuration Error: Invalid tool configuration"
        );
    }

    // ---- Trait implementations -------------------------------------------

    #[test]
    fn test_llm_error_is_error_trait() {
        let error = LLMError::Generic("test error".to_string());
        assert!(error.source().is_none());
    }

    #[test]
    fn test_llm_error_debug_format() {
        let error = LLMError::HttpError("test".to_string());
        let debug_str = format!("{error:?}");
        assert!(debug_str.contains("HttpError"));
        assert!(debug_str.contains("test"));
    }

    // ---- From conversions --------------------------------------------------

    #[test]
    fn test_from_reqwest_error() {
        // An unparseable URL makes `RequestBuilder::build` fail locally. This
        // yields a real `reqwest::Error` without DNS lookups, sockets, timeouts,
        // or an async runtime, so the test cannot flake on network conditions.
        let reqwest_error = reqwest::Client::new()
            .get("not a valid url")
            .build()
            .unwrap_err();

        let llm_error: LLMError = reqwest_error.into();

        match llm_error {
            LLMError::HttpError(msg) => assert!(!msg.is_empty()),
            other => panic!("Expected HttpError, got {other:?}"),
        }
    }

    #[test]
    fn test_from_serde_json_error() {
        let json_error =
            serde_json::from_str::<serde_json::Value>(r#"{"invalid": json}"#).unwrap_err();

        let llm_error: LLMError = json_error.into();

        match llm_error {
            LLMError::JsonError(msg) => {
                assert!(msg.contains("line"));
                assert!(msg.contains("column"));
            }
            other => panic!("Expected JsonError, got {other:?}"),
        }
    }

    // ---- Cross-variant checks ----------------------------------------------

    #[test]
    fn test_error_variants_equality() {
        let error1 = LLMError::HttpError("test".to_string());
        let error2 = LLMError::HttpError("test".to_string());
        let error3 = LLMError::HttpError("different".to_string());
        let error4 = LLMError::AuthError("test".to_string());

        // LLMError does not implement PartialEq, so compare rendered text.
        assert_eq!(error1.to_string(), error2.to_string());
        assert_ne!(error1.to_string(), error3.to_string());
        assert_ne!(error1.to_string(), error4.to_string());
    }

    #[test]
    fn test_response_format_error_fields() {
        let error = LLMError::ResponseFormatError {
            message: "Parse failed".to_string(),
            raw_response: "raw content".to_string(),
        };

        let display_str = error.to_string();
        assert!(display_str.contains("Parse failed"));
        assert!(display_str.contains("raw content"));
    }

    #[test]
    fn test_all_error_variants_have_display() {
        let errors = [
            LLMError::HttpError("http".to_string()),
            LLMError::AuthError("auth".to_string()),
            LLMError::InvalidRequest("invalid".to_string()),
            LLMError::ProviderError("provider".to_string()),
            LLMError::Generic("generic".to_string()),
            LLMError::ResponseFormatError {
                message: "format".to_string(),
                raw_response: "raw".to_string(),
            },
            LLMError::JsonError("json".to_string()),
            LLMError::ToolConfigError("tool".to_string()),
        ];

        for error in &errors {
            assert!(!error.to_string().is_empty(), "empty Display for {error:?}");
        }
    }

    #[test]
    fn test_error_type_classification() {
        let http_error = LLMError::HttpError("test".to_string());
        assert!(matches!(http_error, LLMError::HttpError(_)));

        let auth_error = LLMError::AuthError("test".to_string());
        assert!(matches!(auth_error, LLMError::AuthError(_)));

        let response_error = LLMError::ResponseFormatError {
            message: "test".to_string(),
            raw_response: "test".to_string(),
        };
        assert!(matches!(
            response_error,
            LLMError::ResponseFormatError { .. }
        ));
    }
}