1use axum::{
9 Json,
10 http::StatusCode,
11 response::{IntoResponse, Response},
12};
13use serde::Serialize;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
17#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
18pub enum ErrorCode {
19 ValidationError,
21 ParseError,
23 RequestError,
25 Unauthenticated,
27 Forbidden,
29 InternalServerError,
31 DatabaseError,
33 Timeout,
35 RateLimitExceeded,
37 NotFound,
39 Conflict,
41}
42
43impl ErrorCode {
44 #[must_use]
46 pub fn status_code(self) -> StatusCode {
47 match self {
48 Self::ValidationError | Self::ParseError | Self::RequestError => {
49 StatusCode::BAD_REQUEST
50 },
51 Self::Unauthenticated => StatusCode::UNAUTHORIZED,
52 Self::Forbidden => StatusCode::FORBIDDEN,
53 Self::NotFound => StatusCode::NOT_FOUND,
54 Self::Conflict => StatusCode::CONFLICT,
55 Self::RateLimitExceeded => StatusCode::TOO_MANY_REQUESTS,
56 Self::Timeout => StatusCode::REQUEST_TIMEOUT,
57 Self::InternalServerError | Self::DatabaseError => StatusCode::INTERNAL_SERVER_ERROR,
58 }
59 }
60}
61
62#[derive(Debug, Clone, Serialize)]
64pub struct ErrorLocation {
65 pub line: usize,
67 pub column: usize,
69}
70
71#[derive(Debug, Clone, Serialize)]
73pub struct GraphQLError {
74 pub message: String,
76
77 pub code: ErrorCode,
79
80 #[serde(skip_serializing_if = "Option::is_none")]
82 pub locations: Option<Vec<ErrorLocation>>,
83
84 #[serde(skip_serializing_if = "Option::is_none")]
86 pub path: Option<Vec<String>>,
87
88 #[serde(skip_serializing_if = "Option::is_none")]
90 pub extensions: Option<ErrorExtensions>,
91}
92
93#[derive(Debug, Clone, Serialize)]
95pub struct ErrorExtensions {
96 #[serde(skip_serializing_if = "Option::is_none")]
98 pub category: Option<String>,
99
100 #[serde(skip_serializing_if = "Option::is_none")]
102 pub status: Option<u16>,
103
104 #[serde(skip_serializing_if = "Option::is_none")]
106 pub request_id: Option<String>,
107}
108
109#[derive(Debug, Serialize)]
111pub struct ErrorResponse {
112 pub errors: Vec<GraphQLError>,
114}
115
116impl GraphQLError {
117 pub fn new(message: impl Into<String>, code: ErrorCode) -> Self {
119 Self {
120 message: message.into(),
121 code,
122 locations: None,
123 path: None,
124 extensions: None,
125 }
126 }
127
128 #[must_use]
130 pub fn with_location(mut self, line: usize, column: usize) -> Self {
131 self.locations = Some(vec![ErrorLocation { line, column }]);
132 self
133 }
134
135 #[must_use]
137 pub fn with_path(mut self, path: Vec<String>) -> Self {
138 self.path = Some(path);
139 self
140 }
141
142 #[must_use]
144 pub fn with_extensions(mut self, extensions: ErrorExtensions) -> Self {
145 self.extensions = Some(extensions);
146 self
147 }
148
149 #[must_use]
151 pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
152 let request_id = request_id.into();
153 let extensions = self.extensions.take().unwrap_or(ErrorExtensions {
154 category: None,
155 status: None,
156 request_id: None,
157 });
158
159 self.extensions = Some(ErrorExtensions {
160 request_id: Some(request_id),
161 ..extensions
162 });
163 self
164 }
165
166 pub fn validation(message: impl Into<String>) -> Self {
168 Self::new(message, ErrorCode::ValidationError)
169 }
170
171 pub fn parse(message: impl Into<String>) -> Self {
173 Self::new(message, ErrorCode::ParseError)
174 }
175
176 pub fn request(message: impl Into<String>) -> Self {
178 Self::new(message, ErrorCode::RequestError)
179 }
180
181 pub fn database(message: impl Into<String>) -> Self {
183 Self::new(message, ErrorCode::DatabaseError)
184 }
185
186 pub fn internal(message: impl Into<String>) -> Self {
188 Self::new(message, ErrorCode::InternalServerError)
189 }
190
191 #[must_use]
193 pub fn execution(message: &str) -> Self {
194 Self::new(message, ErrorCode::InternalServerError)
195 }
196
197 #[must_use]
199 pub fn unauthenticated() -> Self {
200 Self::new("Authentication required", ErrorCode::Unauthenticated)
201 }
202
203 #[must_use]
205 pub fn forbidden() -> Self {
206 Self::new("Access denied", ErrorCode::Forbidden)
207 }
208
209 pub fn not_found(message: impl Into<String>) -> Self {
211 Self::new(message, ErrorCode::NotFound)
212 }
213
214 pub fn timeout(operation: impl Into<String>) -> Self {
216 Self::new(format!("{} exceeded timeout", operation.into()), ErrorCode::Timeout)
217 }
218
219 pub fn rate_limited(message: impl Into<String>) -> Self {
221 Self::new(message, ErrorCode::RateLimitExceeded)
222 }
223}
224
225impl ErrorResponse {
226 #[must_use]
228 pub fn new(errors: Vec<GraphQLError>) -> Self {
229 Self { errors }
230 }
231
232 #[must_use]
234 pub fn from_error(error: GraphQLError) -> Self {
235 Self {
236 errors: vec![error],
237 }
238 }
239}
240
241impl IntoResponse for ErrorResponse {
242 fn into_response(self) -> Response {
243 let status = self
244 .errors
245 .first()
246 .map_or(StatusCode::INTERNAL_SERVER_ERROR, |e| e.code.status_code());
247
248 (status, Json(self)).into_response()
249 }
250}
251
252impl From<GraphQLError> for ErrorResponse {
253 fn from(error: GraphQLError) -> Self {
254 Self::from_error(error)
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_serialization() {
        let error = GraphQLError::validation("Invalid query")
            .with_location(1, 5)
            .with_path(vec!["user".to_string(), "id".to_string()]);

        let json = serde_json::to_string(&error).unwrap();
        assert!(json.contains("Invalid query"));
        assert!(json.contains("\"code\":\"VALIDATION_ERROR\""));
        assert!(json.contains("\"line\":1"));
        assert!(json.contains("\"column\":5"));
        assert!(json.contains("\"path\":[\"user\",\"id\"]"));
        // `extensions` was never set, so it must be skipped rather than emitted as null.
        assert!(!json.contains("extensions"));
    }

    #[test]
    fn test_optional_fields_are_omitted() {
        let json = serde_json::to_string(&GraphQLError::internal("boom")).unwrap();
        assert_eq!(json, r#"{"message":"boom","code":"INTERNAL_SERVER_ERROR"}"#);
    }

    #[test]
    fn test_error_response_serialization() {
        let response = ErrorResponse::new(vec![
            GraphQLError::validation("Field not found"),
            GraphQLError::database("Connection timeout"),
        ]);

        let json = serde_json::to_string(&response).unwrap();
        assert!(json.starts_with("{\"errors\":["));
        assert!(json.contains("Field not found"));
        assert!(json.contains("Connection timeout"));
    }

    #[test]
    fn test_error_code_status_codes() {
        assert_eq!(ErrorCode::ValidationError.status_code(), StatusCode::BAD_REQUEST);
        assert_eq!(ErrorCode::ParseError.status_code(), StatusCode::BAD_REQUEST);
        assert_eq!(ErrorCode::RequestError.status_code(), StatusCode::BAD_REQUEST);
        assert_eq!(ErrorCode::Unauthenticated.status_code(), StatusCode::UNAUTHORIZED);
        assert_eq!(ErrorCode::Forbidden.status_code(), StatusCode::FORBIDDEN);
        assert_eq!(ErrorCode::NotFound.status_code(), StatusCode::NOT_FOUND);
        assert_eq!(ErrorCode::Conflict.status_code(), StatusCode::CONFLICT);
        assert_eq!(ErrorCode::RateLimitExceeded.status_code(), StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(ErrorCode::Timeout.status_code(), StatusCode::REQUEST_TIMEOUT);
        assert_eq!(
            ErrorCode::InternalServerError.status_code(),
            StatusCode::INTERNAL_SERVER_ERROR
        );
        assert_eq!(ErrorCode::DatabaseError.status_code(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn test_error_extensions() {
        let extensions = ErrorExtensions {
            category: Some("VALIDATION".to_string()),
            status: Some(400),
            request_id: Some("req-123".to_string()),
        };

        let error = GraphQLError::validation("Invalid").with_extensions(extensions);
        let json = serde_json::to_string(&error).unwrap();
        // Match full key/value pairs: a bare "VALIDATION" substring would also be
        // satisfied by the "VALIDATION_ERROR" code and prove nothing about `category`.
        assert!(json.contains("\"category\":\"VALIDATION\""));
        assert!(json.contains("\"status\":400"));
        assert!(json.contains("\"request_id\":\"req-123\""));
    }

    #[test]
    fn test_with_request_id_preserves_existing_extensions() {
        let error = GraphQLError::validation("Invalid")
            .with_extensions(ErrorExtensions {
                category: Some("VALIDATION".to_string()),
                status: Some(400),
                request_id: None,
            })
            .with_request_id("req-9");

        let extensions = error.extensions.expect("with_request_id must set extensions");
        assert_eq!(extensions.category.as_deref(), Some("VALIDATION"));
        assert_eq!(extensions.status, Some(400));
        assert_eq!(extensions.request_id.as_deref(), Some("req-9"));
    }

    #[test]
    fn test_with_request_id_creates_extensions() {
        let error = GraphQLError::internal("boom").with_request_id(String::from("req-1"));

        let extensions = error.extensions.expect("with_request_id must set extensions");
        assert_eq!(extensions.category, None);
        assert_eq!(extensions.status, None);
        assert_eq!(extensions.request_id.as_deref(), Some("req-1"));
    }

    #[test]
    fn test_into_response_uses_first_error_status() {
        let response = ErrorResponse::new(vec![
            GraphQLError::not_found("missing"),
            GraphQLError::internal("boom"),
        ])
        .into_response();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);

        let single = ErrorResponse::from(GraphQLError::forbidden()).into_response();
        assert_eq!(single.status(), StatusCode::FORBIDDEN);

        let empty = ErrorResponse::new(Vec::new()).into_response();
        assert_eq!(empty.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn test_timeout_message() {
        let error = GraphQLError::timeout("query");
        assert_eq!(error.message, "query exceeded timeout");
        assert_eq!(error.code, ErrorCode::Timeout);
    }
}