1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
2
3use proc_macro::Span;
4use proc_macro::TokenStream;
5use proc_macro_error::{abort, proc_macro_error};
6use quote::quote;
7use syn::{parse_macro_input, Attribute, Data, DataEnum, DeriveInput, Ident, Variant};
8
9#[proc_macro_error]
10#[proc_macro_derive(ToResponse, attributes(code, body))]
11pub fn to_http_error_code(item: TokenStream) -> TokenStream {
12 let ast = parse_macro_input!(item as DeriveInput);
13
14 let name = &ast.ident;
15 let Data::Enum(enum_data) = ast.data else {
16 abort!(Span::call_site(), "Only supported for enum");
17 };
18 impl_into_response(name, enum_data).into()
19}
20
21fn impl_into_response(_name: &Ident, enum_data: DataEnum) -> proc_macro2::TokenStream {
22 let _variants: Vec<proc_macro2::TokenStream> = enum_data
23 .variants
24 .iter()
25 .map(|v| make_enum_variant(v))
26 .collect();
27 cfg_if::cfg_if! {
28 if #[cfg(all(feature = "axum", not(feature = "rocket"), not(feature = "actix")))] {
29 quote! {
30 impl axum::response::IntoResponse for #_name {
31 fn into_response(self) -> axum::response::Response {
32 match &self {
33 #(Self::#_variants,)*
34 }
35 }
36 }
37 }
38 } else if #[cfg(all(feature = "rocket", not(feature = "axum"), not(feature = "actix")))] {
39 quote! {
40 impl<'r, 'o: 'r> ::rocket::response::Responder<'r, 'o> for #_name {
41 fn respond_to(self, request: &'r rocket::request::Request<'_>) -> rocket::response::Result<'o> {
42 match &self {
43 #(Self::#_variants,)*
44 }
45 }
46 }
47 }
48 } else if #[cfg(all(feature = "actix", not(feature = "axum"), not(feature = "rocket")))] {
49 quote! {
50 impl actix_web::ResponseError for #_name {
51 fn error_response(&self) -> actix_web::HttpResponse {
52 match &self {
53 #(Self::#_variants,)*
54 }
55 }
56 }
57 }
58 } else {
59 abort!(Span::call_site(), "Use rocket OR axum OR actix feature!");
60 }
61 }
62}
63
64fn make_enum_variant(variant: &Variant) -> proc_macro2::TokenStream {
65 let _ident = &variant.ident;
66 let _fields = match &variant.fields {
67 syn::Fields::Unit => quote!(),
68 syn::Fields::Named(_) => quote!({ .. }),
69 syn::Fields::Unnamed(fields) => {
70 let unnamed = fields
71 .unnamed
72 .iter()
73 .map(|_| quote!(_))
74 .collect::<Vec<proc_macro2::TokenStream>>();
75 quote!((#(#unnamed),*))
76 }
77 };
78 let attrs: Vec<&Attribute> = variant
79 .attrs
80 .iter()
81 .filter(|attr| attr.path.is_ident("code"))
82 .collect();
83
84 let code = if let Some(attr) = attrs.get(0) {
86 attr.tokens.clone().to_string()
87 } else {
88 quote! {(500)}.to_string()
89 };
90 let _code: proc_macro2::TokenStream = code[1..code.len() - 1]
92 .parse()
93 .expect("Invalid token stream");
94
95 let attrs: Vec<&Attribute> = variant
97 .attrs
98 .iter()
99 .filter(|attr| attr.path.is_ident("body"))
100 .collect();
101
102 let body = if let Some(attr) = attrs.get(0) {
103 attr.tokens.clone().to_string()
104 } else {
105 "({})".to_owned()
106 };
107 let _body: proc_macro2::TokenStream = body[1..body.len() - 1]
109 .parse()
110 .expect("Invalid token stream");
111 cfg_if::cfg_if! {
112 if #[cfg(all(feature = "axum", not(feature = "rocket"), not(feature = "actix")))] {
113 quote! { #_ident #_fields => (axum::http::StatusCode::from_u16(#_code).unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR), #_body).into_response()}
114 } else if #[cfg(all(feature = "rocket", not(feature = "axum"), not(feature = "actix")))] {
115 quote! { #_ident #_fields =>
116 #_body.respond_to(request).map(|mut resp| {
117 resp.set_status(rocket::http::Status::from_code(#_code).unwrap_or(rocket::http::Status::InternalServerError));
118 resp
119 })
120
121 }
122 } else if #[cfg(all(feature = "actix", not(feature = "axum"), not(feature = "rocket")))] {
123 quote! { #_ident #_fields => actix_web::HttpResponse::build(
124 actix_web::http::StatusCode::from_u16(#_code)
125 .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR))
126 .body(#_body)
127 }
128 } else {
129 abort!(Span::call_site(), "Use rocket OR axum OR actix feature!");
130 }
131 }
132}