axum_derive_error/lib.rs
1//!  
2//!
3//! Proc macro to derive IntoResponse for error types for use with axum.
4//!
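//! Your error type only needs to implement `Display` and `std::error::Error` (snafu or thiserror can help with this); `IntoResponse` and `Debug` are derived for you.
//! The generated code refers to `axum`, `serde_json` and `tracing`, so your crate must depend on all three.
//! By default every variant returns a 500 response; set a different status per variant with the `#[status(...)]` attribute.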
7//!
8//! ## Example:
9//! ```rust
10//! use std::{error::Error, fmt::Display};
11//! use axum_derive_error::ErrorResponse;
12//! use axum::http::StatusCode;
13//!
14//! #[derive(ErrorResponse)]
15//! pub enum CreateUserError {
16//! /// No status provided, so this will return a 500 error.
//!     /// 5xx errors never expose their message to the user; instead they are logged with `tracing::error!`.
18//! InsertUserToDb(sqlx::Error),
19//!
//!     /// Returns a 422 response; the `Display` message is sent to the user.
21//! #[status(StatusCode::UNPROCESSABLE_ENTITY)]
22//! InvalidBody(String),
23//! }
24//!
25//! impl Error for CreateUserError {
26//! fn source(&self) -> Option<&(dyn Error + 'static)> {
27//! match self {
28//! Self::InsertUserToDb(source) => Some(source),
29//! Self::InvalidBody(_) => None,
30//! }
31//! }
32//! }
33//!
34//! impl Display for CreateUserError {
35//! fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36//! match self {
//!             Self::InsertUserToDb(_) => write!(f, "failed to insert user into the database"),
38//! Self::InvalidBody(message) => write!(f, "body is invalid: {message}"),
39//! }
40//! }
41//! }
42//! ```
43//!
44//! ## License
45//!
46//! Licensed under either of
47//!
48//! * Apache License, Version 2.0
49//! ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0)
50//! * MIT license
51//! ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT)
52//!
53//! at your option.
54//!
55//! ## Contribution
56//!
57//! Unless you explicitly state otherwise, any contribution intentionally submitted
58//! for inclusion in the work by you, as defined in the Apache-2.0 license, shall be
59//! dual licensed as above, without any additional terms or conditions.
60
61use proc_macro2::{Ident, TokenStream};
62use quote::quote;
63use syn::{parse_macro_input, Data, DataEnum, DeriveInput};
64
65#[proc_macro_derive(ErrorResponse, attributes(status))]
66pub fn derive_error_response(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
67 let input = parse_macro_input!(tokens as DeriveInput);
68 let ident = input.ident;
69
70 match input.data {
71 Data::Union(_) => panic!("cannot derive ErrorResponse for unions"),
72 Data::Struct(_) => panic!("cannot derive ErrorResponse for structs yet"),
73 Data::Enum(enum_data) => derive_error_response_for_enum(ident, enum_data).into(),
74 }
75}
76
77fn derive_error_response_for_enum(ident: Ident, enum_data: DataEnum) -> TokenStream {
78 let status_codes = enum_data.variants.into_iter().map(|variant| {
79 let variant_name = variant.ident;
80 let attr = variant.attrs.into_iter().find(|attr| {
81 attr.path
82 .get_ident()
83 .map(|ident| *ident == "status")
84 .unwrap_or_default()
85 });
86 let match_fields = match variant.fields {
87 syn::Fields::Named(_) => quote!({..}),
88 syn::Fields::Unnamed(fields) => {
89 let fields = fields.unnamed.into_iter().map(|_| quote!(_));
90 quote!{
91 (#(#fields,)*)
92 }
93 },
94 syn::Fields::Unit => quote! {},
95 };
96 match attr {
97 Some(attr) => {
98 let status = attr.tokens;
99 quote! {
100 Self::#variant_name #match_fields => {
101 #[allow(unused_parens)]
102 #status
103 }
104 }
105 },
106 None => {
107 quote! { Self::#variant_name #match_fields => ::axum::http::StatusCode::INTERNAL_SERVER_ERROR }
108 }
109 }
110 });
111
112 quote! {
113 impl #ident {
114 fn status_code(&self) -> ::axum::http::StatusCode {
115 match self {
116 #(#status_codes,)*
117 }
118 }
119 }
120
121 impl ::axum::response::IntoResponse for #ident {
122 fn into_response(self) -> ::axum::response::Response {
123 let status = self.status_code();
124 let mut error_message = self.to_string();
125
126 if status.is_server_error() {
127 ::tracing::error!(error_message, error_details = ?self, "internal server error");
128 error_message = "Internal server error".to_string()
129 }
130
131 let body = ::axum::Json(::serde_json::json!({
132 "code": status.as_u16(),
133 "error": error_message,
134 }));
135
136 ::axum::response::IntoResponse::into_response((status, body))
137 }
138 }
139
140 impl ::std::fmt::Debug for #ident {
141 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142 writeln!(f, "{}\n", self)?;
143 let mut current = ::std::error::Error::source(self);
144 while let Some(cause) = current {
145 writeln!(f, "Caused by:\n\t{}", cause)?;
146 current = ::std::error::Error::source(cause);
147 }
148 Ok(())
149 }
150 }
151 }
152}