1use proc_macro2::{Ident, TokenStream};
2use quote::{quote, quote_spanned};
3use syn::{
4 parse_macro_input, punctuated::Punctuated, spanned::Spanned, token::Comma, Data, DeriveInput,
5 Expr, LitInt, Meta, Path, Variant,
6};
7
8#[proc_macro_derive(ErrorStatus, attributes(status))]
37pub fn derive_error_status(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
38 let ast: DeriveInput = parse_macro_input!(input);
39 let enum_ident = ast.ident;
40 let cases: Punctuated<TokenStream, Comma> = match ast.data {
41 Data::Enum(data) => data.variants,
42 _ => panic!(
43 "#[derive(ErrorStatus)] is only available for enums, other types are not supported."
44 ),
45 }
46 .iter()
47 .map(impl_enum_variant)
48 .collect();
49
50 quote! {
51 impl axum::response::IntoResponse for #enum_ident {
52 fn into_response(self) -> axum::response::Response {
53 match self {
54 #cases
55 }
56 }
57 }
58 }
59 .into()
60}
61
62fn impl_enum_variant(input: &Variant) -> TokenStream {
63 let status_code = find_status_code(input);
64 let case = if input.fields.is_empty() {
65 case_empty_fields(input)
66 } else if input.fields.iter().filter(|f| f.ident.is_none()).count() > 0 {
67 case_unnamed_fields(input)
68 } else {
69 case_named_fields(input)
70 };
71 quote! {
72 Self::#case => (#status_code, format!("{}", self)).into_response()
73 }
74}
75
76fn case_empty_fields(input: &Variant) -> TokenStream {
77 let ident = &input.ident;
78 quote!(#ident)
79}
80
81fn case_unnamed_fields(input: &Variant) -> TokenStream {
82 let ident = &input.ident;
83 quote!(#ident( .. ))
84}
85
86fn case_named_fields(input: &Variant) -> TokenStream {
87 let ident = &input.ident;
88 quote!(#ident { .. })
89}
90
91fn find_status_code(input: &Variant) -> TokenStream {
92 match input
93 .attrs
94 .iter()
95 .find(|attr| attr.path().is_ident("status"))
96 {
97 Some(attr) => match &attr.meta {
98 Meta::List(l) => {
99 if let Ok(number) = l.parse_args::<LitInt>() {
100 quote! {
101 axum::http::StatusCode::from_u16(#number as u16).unwrap()
102 }
103 } else if let Ok(expr) = l.parse_args::<Path>() {
104 quote! {
105 #expr
106 }
107 } else {
108 quote_spanned!(l.span() => compile_error!("Only #[status(StatusCode)] or #[status(u16)] syntaxe is supported"))
109 }
110 }
111 _ => {
112 quote_spanned! { attr.span() => compile_error!("Only #[status(StatusCode)] or #[status(u16)] syntaxe is supported") }
113 }
114 },
115 None => {
116 quote_spanned! { input.span() => compile_error!("Each enum variant should have a status code provided using the #[status()] attribute") }
117 }
118 }
119}