use axum::{
Json,
http::StatusCode,
response::{IntoResponse, Response},
};
use serde::Serialize;
/// Machine-readable classification carried by every [`GraphQLError`].
///
/// Variants serialize in `SCREAMING_SNAKE_CASE` (for example
/// `ValidationError` becomes `"VALIDATION_ERROR"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ErrorCode {
/// Serialized as `VALIDATION_ERROR`.
ValidationError,
/// Serialized as `PARSE_ERROR`.
ParseError,
/// Serialized as `REQUEST_ERROR`.
RequestError,
/// Serialized as `UNAUTHENTICATED`.
Unauthenticated,
/// Serialized as `FORBIDDEN`.
Forbidden,
/// Serialized as `INTERNAL_SERVER_ERROR`.
InternalServerError,
/// Serialized as `DATABASE_ERROR`.
DatabaseError,
/// Serialized as `TIMEOUT`.
Timeout,
/// Serialized as `RATE_LIMIT_EXCEEDED`.
RateLimitExceeded,
/// Serialized as `NOT_FOUND`.
NotFound,
/// Serialized as `CONFLICT`.
Conflict,
}
impl ErrorCode {
    /// Returns the HTTP status that a response carrying this code should use.
    ///
    /// | Code                                            | Status |
    /// |-------------------------------------------------|--------|
    /// | `ValidationError`, `ParseError`, `RequestError` | 400    |
    /// | `Unauthenticated`                               | 401    |
    /// | `Forbidden`                                     | 403    |
    /// | `NotFound`                                      | 404    |
    /// | `Timeout`                                       | 408    |
    /// | `Conflict`                                      | 409    |
    /// | `RateLimitExceeded`                             | 429    |
    /// | `InternalServerError`, `DatabaseError`          | 500    |
    #[must_use]
    pub fn status_code(self) -> StatusCode {
        // The match is deliberately exhaustive (no `_` arm) so that adding a
        // variant forces a decision about its status here.
        match self {
            // Server-side failures.
            Self::InternalServerError | Self::DatabaseError => StatusCode::INTERNAL_SERVER_ERROR,

            // NOTE(review): HTTP defines 408 as the *client* being too slow to
            // send its request; confirm this is the intended status for
            // server-side operation timeouts.
            Self::Timeout => StatusCode::REQUEST_TIMEOUT,
            Self::RateLimitExceeded => StatusCode::TOO_MANY_REQUESTS,
            Self::Conflict => StatusCode::CONFLICT,
            Self::NotFound => StatusCode::NOT_FOUND,

            // Authentication / authorization failures.
            Self::Forbidden => StatusCode::FORBIDDEN,
            Self::Unauthenticated => StatusCode::UNAUTHORIZED,

            // Malformed or invalid requests.
            Self::ValidationError | Self::ParseError | Self::RequestError => {
                StatusCode::BAD_REQUEST
            }
        }
    }
}
/// A line/column position, serialized as an entry of a [`GraphQLError`]'s
/// `locations` list.
// NOTE(review): presumably 1-based positions in the query document, as in the
// GraphQL specification — confirm against the code that produces them.
#[derive(Debug, Clone, Serialize)]
pub struct ErrorLocation {
/// Line number of the position.
pub line: usize,
/// Column number of the position.
pub column: usize,
}
/// A single error entry of an [`ErrorResponse`].
///
/// `message` and `code` are always serialized; `locations`, `path`, and
/// `extensions` are left out of the serialized output while they are `None`.
#[derive(Debug, Clone, Serialize)]
pub struct GraphQLError {
/// Human-readable description of the error.
pub message: String,
/// Machine-readable classification; also selects the HTTP status via
/// [`ErrorCode::status_code`].
pub code: ErrorCode,
/// Positions associated with the error; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub locations: Option<Vec<ErrorLocation>>,
/// Path segments associated with the error; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub path: Option<Vec<String>>,
/// Additional metadata; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub extensions: Option<ErrorExtensions>,
}
/// Optional metadata attached to a `GraphQLError`.
///
/// Every field is optional and is left out of the serialized output while it
/// is `None`. The type implements [`Default`] (all fields `None`), so callers
/// can set only the fields they care about:
/// `ErrorExtensions { status: Some(400), ..Default::default() }`.
#[derive(Debug, Clone, Default, Serialize)]
pub struct ErrorExtensions {
    /// Free-form category label for the error.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub category: Option<String>,
    /// Numeric status associated with the error.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub status: Option<u16>,
    /// Identifier of the request that produced the error.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_id: Option<String>,
}
/// Response body holding a list of errors, serialized as an object with a
/// single `errors` field.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
/// The errors to report, in order. The first entry determines the HTTP
/// status when this value is converted into a response.
pub errors: Vec<GraphQLError>,
}
impl GraphQLError {
    /// Creates an error with the given `message` and `code`.
    ///
    /// `locations`, `path`, and `extensions` start out unset; use the
    /// `with_*` builder methods to fill them in.
    pub fn new(message: impl Into<String>, code: ErrorCode) -> Self {
        Self {
            message: message.into(),
            code,
            locations: None,
            path: None,
            extensions: None,
        }
    }

    /// Sets `locations` to the single position `line`/`column`.
    ///
    /// Any previously set locations are replaced, not appended to.
    #[must_use]
    pub fn with_location(mut self, line: usize, column: usize) -> Self {
        self.locations = Some(vec![ErrorLocation { line, column }]);
        self
    }

    /// Sets `path`, replacing any previously set path.
    #[must_use]
    pub fn with_path(mut self, path: Vec<String>) -> Self {
        self.path = Some(path);
        self
    }

    /// Sets `extensions`, replacing any previously set extensions wholesale.
    #[must_use]
    pub fn with_extensions(mut self, extensions: ErrorExtensions) -> Self {
        self.extensions = Some(extensions);
        self
    }

    /// Records `request_id` in the error's extensions.
    ///
    /// Existing extensions are kept: only their `request_id` is overwritten.
    /// When no extensions are set yet, an otherwise empty set is created.
    #[must_use]
    pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
        // Update in place rather than taking the value out and rebuilding it.
        let extensions = self.extensions.get_or_insert_with(|| ErrorExtensions {
            category: None,
            status: None,
            request_id: None,
        });
        extensions.request_id = Some(request_id.into());
        self
    }

    /// Creates an error with code [`ErrorCode::ValidationError`].
    pub fn validation(message: impl Into<String>) -> Self {
        Self::new(message, ErrorCode::ValidationError)
    }

    /// Creates an error with code [`ErrorCode::ParseError`].
    pub fn parse(message: impl Into<String>) -> Self {
        Self::new(message, ErrorCode::ParseError)
    }

    /// Creates an error with code [`ErrorCode::RequestError`].
    pub fn request(message: impl Into<String>) -> Self {
        Self::new(message, ErrorCode::RequestError)
    }

    /// Creates an error with code [`ErrorCode::DatabaseError`].
    pub fn database(message: impl Into<String>) -> Self {
        Self::new(message, ErrorCode::DatabaseError)
    }

    /// Creates an error with code [`ErrorCode::InternalServerError`].
    pub fn internal(message: impl Into<String>) -> Self {
        Self::new(message, ErrorCode::InternalServerError)
    }

    /// Creates an error with code [`ErrorCode::InternalServerError`].
    ///
    /// Produces the same code as [`GraphQLError::internal`]; it differs only
    /// in taking the message as `&str`.
    #[must_use]
    pub fn execution(message: &str) -> Self {
        Self::new(message, ErrorCode::InternalServerError)
    }

    /// Creates an [`ErrorCode::Unauthenticated`] error with the fixed message
    /// `"Authentication required"`.
    #[must_use]
    pub fn unauthenticated() -> Self {
        Self::new("Authentication required", ErrorCode::Unauthenticated)
    }

    /// Creates an [`ErrorCode::Forbidden`] error with the fixed message
    /// `"Access denied"`.
    #[must_use]
    pub fn forbidden() -> Self {
        Self::new("Access denied", ErrorCode::Forbidden)
    }

    /// Creates an error with code [`ErrorCode::NotFound`].
    pub fn not_found(message: impl Into<String>) -> Self {
        Self::new(message, ErrorCode::NotFound)
    }

    /// Creates an error with code [`ErrorCode::Conflict`].
    pub fn conflict(message: impl Into<String>) -> Self {
        Self::new(message, ErrorCode::Conflict)
    }

    /// Creates an [`ErrorCode::Timeout`] error whose message is
    /// `"<operation> exceeded timeout"`.
    pub fn timeout(operation: impl Into<String>) -> Self {
        Self::new(format!("{} exceeded timeout", operation.into()), ErrorCode::Timeout)
    }

    /// Creates an error with code [`ErrorCode::RateLimitExceeded`].
    pub fn rate_limited(message: impl Into<String>) -> Self {
        Self::new(message, ErrorCode::RateLimitExceeded)
    }
}
impl ErrorResponse {
    /// Wraps an already collected list of errors.
    #[must_use]
    pub fn new(errors: Vec<GraphQLError>) -> Self {
        Self { errors }
    }

    /// Builds a response that carries exactly one error.
    #[must_use]
    pub fn from_error(error: GraphQLError) -> Self {
        Self::new(vec![error])
    }
}
impl IntoResponse for ErrorResponse {
    /// Serializes the error list as JSON.
    ///
    /// The HTTP status is taken from the first error's code; an empty error
    /// list is reported as `500 Internal Server Error`.
    fn into_response(self) -> Response {
        let status = match self.errors.first() {
            Some(primary) => primary.code.status_code(),
            None => StatusCode::INTERNAL_SERVER_ERROR,
        };
        let body = Json(self);
        (status, body).into_response()
    }
}
impl From<GraphQLError> for ErrorResponse {
    /// Wraps a single error in a one-element response.
    fn from(error: GraphQLError) -> Self {
        Self {
            errors: vec![error],
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Message, code, location, and path all appear in the serialized error.
    #[test]
    fn test_error_serialization() {
        let error = GraphQLError::validation("Invalid query")
            .with_location(1, 5)
            .with_path(vec!["user".to_string(), "id".to_string()]);
        let json = serde_json::to_string(&error).unwrap();
        assert!(json.contains("Invalid query"));
        assert!(json.contains("VALIDATION_ERROR"));
        assert!(json.contains("\"line\":1"));
        assert!(json.contains("\"column\":5"));
        assert!(json.contains("\"path\":[\"user\",\"id\"]"));
    }

    /// Optional fields that were never set are left out of the output.
    #[test]
    fn test_unset_optional_fields_are_omitted() {
        let json = serde_json::to_string(&GraphQLError::validation("Invalid")).unwrap();
        assert!(!json.contains("locations"));
        assert!(!json.contains("path"));
        assert!(!json.contains("extensions"));
    }

    /// Every error of a multi-error response is serialized.
    #[test]
    fn test_error_response_serialization() {
        let response = ErrorResponse::new(vec![
            GraphQLError::validation("Field not found"),
            GraphQLError::database("Connection timeout"),
        ]);
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("\"errors\""));
        assert!(json.contains("Field not found"));
        assert!(json.contains("Connection timeout"));
    }

    /// Pins the HTTP status of every error code.
    #[test]
    fn test_error_code_status_codes() {
        assert_eq!(ErrorCode::ValidationError.status_code(), StatusCode::BAD_REQUEST);
        assert_eq!(ErrorCode::ParseError.status_code(), StatusCode::BAD_REQUEST);
        assert_eq!(ErrorCode::RequestError.status_code(), StatusCode::BAD_REQUEST);
        assert_eq!(ErrorCode::Unauthenticated.status_code(), StatusCode::UNAUTHORIZED);
        assert_eq!(ErrorCode::Forbidden.status_code(), StatusCode::FORBIDDEN);
        assert_eq!(ErrorCode::NotFound.status_code(), StatusCode::NOT_FOUND);
        assert_eq!(ErrorCode::Conflict.status_code(), StatusCode::CONFLICT);
        assert_eq!(ErrorCode::RateLimitExceeded.status_code(), StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(ErrorCode::Timeout.status_code(), StatusCode::REQUEST_TIMEOUT);
        assert_eq!(
            ErrorCode::InternalServerError.status_code(),
            StatusCode::INTERNAL_SERVER_ERROR
        );
        assert_eq!(ErrorCode::DatabaseError.status_code(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    /// Explicitly supplied extensions are serialized with the error.
    #[test]
    fn test_error_extensions() {
        let extensions = ErrorExtensions {
            category: Some("VALIDATION".to_string()),
            status: Some(400),
            request_id: Some("req-123".to_string()),
        };
        let error = GraphQLError::validation("Invalid").with_extensions(extensions);
        let json = serde_json::to_string(&error).unwrap();
        assert!(json.contains("VALIDATION"));
        assert!(json.contains("req-123"));
    }

    /// `with_request_id` creates extensions when none exist yet.
    #[test]
    fn test_with_request_id_creates_extensions() {
        let error = GraphQLError::internal("boom").with_request_id("req-1");
        let extensions = error.extensions.unwrap();
        assert_eq!(extensions.request_id.as_deref(), Some("req-1"));
        assert!(extensions.category.is_none());
        assert!(extensions.status.is_none());
    }

    /// `with_request_id` overwrites only the request id of existing extensions.
    #[test]
    fn test_with_request_id_preserves_existing_extensions() {
        let error = GraphQLError::validation("Invalid")
            .with_extensions(ErrorExtensions {
                category: Some("VALIDATION".to_string()),
                status: Some(400),
                request_id: Some("old".to_string()),
            })
            .with_request_id("req-456");
        let extensions = error.extensions.unwrap();
        assert_eq!(extensions.request_id.as_deref(), Some("req-456"));
        assert_eq!(extensions.category.as_deref(), Some("VALIDATION"));
        assert_eq!(extensions.status, Some(400));
    }

    /// `timeout` formats its message from the operation name.
    #[test]
    fn test_timeout_message() {
        let error = GraphQLError::timeout("query");
        assert_eq!(error.message, "query exceeded timeout");
        assert_eq!(error.code, ErrorCode::Timeout);
    }

    /// The HTTP status comes from the first error; an empty list yields 500.
    #[test]
    fn test_into_response_status() {
        let response = ErrorResponse::new(vec![
            GraphQLError::not_found("missing"),
            GraphQLError::database("down"),
        ])
        .into_response();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);

        let empty = ErrorResponse::new(Vec::new()).into_response();
        assert_eq!(empty.status(), StatusCode::INTERNAL_SERVER_ERROR);

        let converted: ErrorResponse = GraphQLError::forbidden().into();
        assert_eq!(converted.into_response().status(), StatusCode::FORBIDDEN);
    }
}