axum_derive_error/
lib.rs

1//! ![Crates.io](https://img.shields.io/crates/l/axum-derive-error) ![Crates.io](https://img.shields.io/crates/v/axum-derive-error)
2//!
3//! Proc macro to derive IntoResponse for error types for use with axum.
4//!
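//! Your error type only needs to implement `Error` (Snafu or thiserror can help here); `IntoResponse` and `Debug` are derived for you.
//! By default every variant returns a 500 response; a different status can be set per variant with the `#[status(...)]` attribute.
//! The response body is JSON of the form `{"code": <status>, "error": <message>}`. The generated code refers to
//! `::axum`, `::serde_json` and `::tracing`, so the crate using the derive must depend on all three.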
7//!
8//! ## Example:
9//! ```rust
10//! use std::{error::Error, fmt::Display};
11//! use axum_derive_error::ErrorResponse;
12//! use axum::http::StatusCode;
13//!
14//! #[derive(ErrorResponse)]
15//! pub enum CreateUserError {
16//!     /// No status provided, so this will return a 500 error.
//!     /// 5xx errors never show their message to the user ("Internal server error" is sent instead); the real message is logged with `tracing::error!`.
18//!     InsertUserToDb(sqlx::Error),
19//!
//!     /// Returns 422; the `Display` implementation is used as the message shown to the user.
21//!     #[status(StatusCode::UNPROCESSABLE_ENTITY)]
22//!     InvalidBody(String),
23//! }
24//!
25//! impl Error for CreateUserError {
26//!     fn source(&self) -> Option<&(dyn Error + 'static)> {
27//!         match self {
28//!             Self::InsertUserToDb(source) => Some(source),
29//!             Self::InvalidBody(_) => None,
30//!         }
31//!     }
32//! }
33//!
34//! impl Display for CreateUserError {
35//!     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36//!         match self {
37//!             Self::InsertUserToDb(source) => write!(f, "failed to insert user into the database"),
38//!             Self::InvalidBody(message) => write!(f, "body is invalid: {message}"),
39//!         }
40//!     }
41//! }
42//! ```
43//!
44//! ## License
45//!
46//! Licensed under either of
47//!
48//!  * Apache License, Version 2.0
49//!    ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0)
50//!  * MIT license
51//!    ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT)
52//!
53//! at your option.
54//!
55//! ## Contribution
56//!
57//! Unless you explicitly state otherwise, any contribution intentionally submitted
58//! for inclusion in the work by you, as defined in the Apache-2.0 license, shall be
59//! dual licensed as above, without any additional terms or conditions.
60
61use proc_macro2::{Ident, TokenStream};
62use quote::quote;
63use syn::{parse_macro_input, Data, DataEnum, DeriveInput};
64
65#[proc_macro_derive(ErrorResponse, attributes(status))]
66pub fn derive_error_response(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
67    let input = parse_macro_input!(tokens as DeriveInput);
68    let ident = input.ident;
69
70    match input.data {
71        Data::Union(_) => panic!("cannot derive ErrorResponse for unions"),
72        Data::Struct(_) => panic!("cannot derive ErrorResponse for structs yet"),
73        Data::Enum(enum_data) => derive_error_response_for_enum(ident, enum_data).into(),
74    }
75}
76
77fn derive_error_response_for_enum(ident: Ident, enum_data: DataEnum) -> TokenStream {
78    let status_codes = enum_data.variants.into_iter().map(|variant| {
79        let variant_name = variant.ident;
80        let attr = variant.attrs.into_iter().find(|attr| {
81            attr.path
82                .get_ident()
83                .map(|ident| *ident == "status")
84                .unwrap_or_default()
85        });
86        let match_fields = match variant.fields {
87            syn::Fields::Named(_) => quote!({..}),
88            syn::Fields::Unnamed(fields) => {
89                let fields = fields.unnamed.into_iter().map(|_| quote!(_));
90                quote!{
91                    (#(#fields,)*)
92                }
93            },
94            syn::Fields::Unit => quote! {},
95        };
96        match attr {
97            Some(attr) => {
98                let status = attr.tokens;
99                quote! {
100                    Self::#variant_name #match_fields => {
101                        #[allow(unused_parens)]
102                        #status
103                    }
104                }
105            },
106            None => {
107                quote! { Self::#variant_name #match_fields => ::axum::http::StatusCode::INTERNAL_SERVER_ERROR }
108            }
109        }
110    });
111
112    quote! {
113        impl #ident {
114            fn status_code(&self) -> ::axum::http::StatusCode {
115                match self {
116                    #(#status_codes,)*
117                }
118            }
119        }
120
121        impl ::axum::response::IntoResponse for #ident {
122            fn into_response(self) -> ::axum::response::Response {
123                let status = self.status_code();
124                let mut error_message = self.to_string();
125
126                if status.is_server_error() {
127                    ::tracing::error!(error_message, error_details = ?self, "internal server error");
128                    error_message = "Internal server error".to_string()
129                }
130
131                let body = ::axum::Json(::serde_json::json!({
132                    "code": status.as_u16(),
133                    "error": error_message,
134                }));
135
136                ::axum::response::IntoResponse::into_response((status, body))
137            }
138        }
139
140        impl ::std::fmt::Debug for #ident {
141            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142                writeln!(f, "{}\n", self)?;
143                let mut current = ::std::error::Error::source(self);
144                while let Some(cause) = current {
145                    writeln!(f, "Caused by:\n\t{}", cause)?;
146                    current = ::std::error::Error::source(cause);
147                }
148                Ok(())
149            }
150        }
151    }
152}