Skip to main content

axum_reject_macro/
lib.rs

1use darling::ast::{self, Fields};
2use darling::{util, FromDeriveInput, FromVariant};
3use proc_macro2::TokenStream;
4use quote::{quote, ToTokens};
5use syn::{Ident, Type};
6
7#[derive(Debug, FromVariant)]
8#[darling(attributes(http_error))]
9struct HttpErrorVariant {
10    ident: Ident,
11    fields: Fields<Type>,
12    status: Ident,
13    message: String,
14}
15
16#[derive(Debug, FromDeriveInput)]
17#[darling(attributes(http_error))]
18struct HttpError {
19    ident: Ident,
20    generics: syn::Generics,
21    data: ast::Data<HttpErrorVariant, util::Ignored>,
22}
23
24impl ToTokens for HttpError {
25    fn to_tokens(&self, tokens: &mut TokenStream) {
26        let ident = &self.ident;
27        let generics = &self.generics;
28        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
29
30        let match_arms = self.data.as_ref()
31            .take_enum()
32            .expect("Should be an enum")
33            .into_iter()
34            .map(|variant| {
35                let ident = &variant.ident;
36                let status = &variant.status;
37                let message = &variant.message;
38                let field = variant.fields.iter().map(|_| quote! { _ }).collect::<Vec<_>>();
39
40                if field.is_empty() {
41                    quote! {
42                        Self::#ident => (axum::http::StatusCode::#status, format!(r#"{{"error": "{}"}}"#, #message).to_string()).into_response()
43                    }
44                } else {
45                    quote! {
46                        Self::#ident(#(#field),*) => (axum::http::StatusCode::#status, format!(r#"{{"error": "{}"}}"#, #message).to_string()).into_response()
47                    }
48                }
49            });
50
51        tokens.extend(quote! {
52            impl #impl_generics axum::response::IntoResponse for #ident #ty_generics #where_clause {
53                fn into_response(self) -> axum::response::Response {
54                    match self {
55                        #(#match_arms),*
56                    }
57                }
58            }
59        });
60    }
61}
62
63#[proc_macro_derive(HttpError, attributes(http_error))]
64pub fn derive_http_error(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
65    let input = syn::parse_macro_input!(input as syn::DeriveInput);
66    let http_error = HttpError::from_derive_input(&input).unwrap();
67    http_error.into_token_stream().into()
68}