error_http/
lib.rs

1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
2
3use proc_macro::Span;
4use proc_macro::TokenStream;
5use proc_macro_error::{abort, proc_macro_error};
6use quote::quote;
7use syn::{parse_macro_input, Attribute, Data, DataEnum, DeriveInput, Ident, Variant};
8
9#[proc_macro_error]
10#[proc_macro_derive(ToResponse, attributes(code, body))]
11pub fn to_http_error_code(item: TokenStream) -> TokenStream {
12    let ast = parse_macro_input!(item as DeriveInput);
13
14    let name = &ast.ident;
15    let Data::Enum(enum_data) = ast.data else {
16        abort!(Span::call_site(), "Only supported for enum");
17    };
18    impl_into_response(name, enum_data).into()
19}
20
21fn impl_into_response(_name: &Ident, enum_data: DataEnum) -> proc_macro2::TokenStream {
22    let _variants: Vec<proc_macro2::TokenStream> = enum_data
23        .variants
24        .iter()
25        .map(|v| make_enum_variant(v))
26        .collect();
27    cfg_if::cfg_if! {
28        if #[cfg(all(feature = "axum", not(feature = "rocket"), not(feature = "actix")))] {
29            quote! {
30                impl axum::response::IntoResponse for #_name {
31                    fn into_response(self) -> axum::response::Response {
32                        match &self {
33                            #(Self::#_variants,)*
34                        }
35                    }
36                }
37            }
38        } else if #[cfg(all(feature = "rocket", not(feature = "axum"), not(feature = "actix")))] {
39            quote! {
40                impl<'r, 'o: 'r> ::rocket::response::Responder<'r, 'o> for #_name {
41                    fn respond_to(self, request: &'r rocket::request::Request<'_>) -> rocket::response::Result<'o> {
42                        match &self {
43                            #(Self::#_variants,)*
44                        }
45                    }
46                }
47            }
48        } else if #[cfg(all(feature = "actix", not(feature = "axum"), not(feature = "rocket")))] {
49            quote! {
50                impl actix_web::ResponseError for #_name {
51                    fn error_response(&self) -> actix_web::HttpResponse {
52                        match &self {
53                            #(Self::#_variants,)*
54                        }
55                    }
56                }
57            }
58        } else {
59            abort!(Span::call_site(), "Use rocket OR axum OR actix feature!");
60        }
61    }
62}
63
64fn make_enum_variant(variant: &Variant) -> proc_macro2::TokenStream {
65    let _ident = &variant.ident;
66    let _fields = match &variant.fields {
67        syn::Fields::Unit => quote!(),
68        syn::Fields::Named(_) => quote!({ .. }),
69        syn::Fields::Unnamed(fields) => {
70            let unnamed = fields
71                .unnamed
72                .iter()
73                .map(|_| quote!(_))
74                .collect::<Vec<proc_macro2::TokenStream>>();
75            quote!((#(#unnamed),*))
76        }
77    };
78    let attrs: Vec<&Attribute> = variant
79        .attrs
80        .iter()
81        .filter(|attr| attr.path.is_ident("code"))
82        .collect();
83
84    // HTTP code
85    let code = if let Some(attr) = attrs.get(0) {
86        attr.tokens.clone().to_string()
87    } else {
88        quote! {(500)}.to_string()
89    };
90    //Trimming ( )
91    let _code: proc_macro2::TokenStream = code[1..code.len() - 1]
92        .parse()
93        .expect("Invalid token stream");
94
95    // Response body
96    let attrs: Vec<&Attribute> = variant
97        .attrs
98        .iter()
99        .filter(|attr| attr.path.is_ident("body"))
100        .collect();
101
102    let body = if let Some(attr) = attrs.get(0) {
103        attr.tokens.clone().to_string()
104    } else {
105        "({})".to_owned()
106    };
107    //Trimming ( )
108    let _body: proc_macro2::TokenStream = body[1..body.len() - 1]
109        .parse()
110        .expect("Invalid token stream");
111    cfg_if::cfg_if! {
112        if #[cfg(all(feature = "axum", not(feature = "rocket"), not(feature = "actix")))] {
113             quote! { #_ident #_fields => (axum::http::StatusCode::from_u16(#_code).unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR), #_body).into_response()}
114         } else if #[cfg(all(feature = "rocket", not(feature = "axum"), not(feature = "actix")))] {
115             quote! { #_ident #_fields =>
116             #_body.respond_to(request).map(|mut resp| {
117                     resp.set_status(rocket::http::Status::from_code(#_code).unwrap_or(rocket::http::Status::InternalServerError));
118                     resp
119                 })
120
121             }
122         } else if #[cfg(all(feature = "actix", not(feature = "axum"), not(feature = "rocket")))] {
123             quote! { #_ident #_fields =>  actix_web::HttpResponse::build(
124                 actix_web::http::StatusCode::from_u16(#_code)
125                 .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR))
126                 .body(#_body)
127             }
128         } else {
129            abort!(Span::call_site(), "Use rocket OR axum OR actix feature!");
130         }
131    }
132}