1#![forbid(unsafe_code, clippy::unwrap_used)]
53use proc_macro2::TokenStream;
54use quote::quote;
55
56#[proc_macro_derive(ResponseError, attributes(status_code))]
58pub fn derive_response_error(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
59 derive(syn::parse_macro_input!(input))
60 .unwrap_or_else(|e| e.to_compile_error())
61 .into()
62}
63
64fn derive(input: syn::DeriveInput) -> syn::Result<TokenStream> {
66 let ident = input.ident;
67
68 match input.data {
69 syn::Data::Enum(e) => {
70 let response_variants = e
71 .variants
72 .into_iter()
73 .filter_map(ResponseVariant::from_variant)
74 .collect::<Result<Vec<_>, _>>()?;
75
76 let status_code_arms = response_variants
77 .iter()
78 .map(ResponseVariant::to_status_code_arm);
79
80 let error_response_patterns = response_variants
81 .iter()
82 .map(ResponseVariant::to_error_response_pattern);
83
84 let error_response_arm = if response_variants.len() != 0 {
85 quote! {
86 #(#error_response_patterns)|* => {
87 ::actix_web::HttpResponse::build(self.status_code()).json(::serde_json::json!({
88 "error": self,
89 "message": self.to_string(),
90 }))
91 }
92 }
93 } else {
94 quote! {}
95 };
96
97 Ok(quote! {
98 impl ::actix_web::ResponseError for #ident {
99 fn status_code(&self) -> ::actix_web::http::StatusCode {
100 match self {
101 #(#status_code_arms,)*
102 _ => ::actix_web::http::StatusCode::INTERNAL_SERVER_ERROR,
103 }
104 }
105
106 fn error_response(&self) -> ::actix_web::HttpResponse {
107 match self {
108 #error_response_arm
109 _ => ::actix_web::HttpResponse::InternalServerError().finish(),
110 }
111 }
112 }
113 })
114 }
115 syn::Data::Struct(_) => match get_status_code(&input.attrs) {
116 Some(Ok(status_code)) => Ok(quote! {
117 impl ::actix_web::ResponseError for #ident {
118 fn status_code(&self) -> ::actix_web::http::StatusCode {
119 ::actix_web::http::StatusCode::#status_code
120 }
121
122 fn error_response(&self) -> ::actix_web::HttpResponse {
123 ::actix_web::HttpResponse::build(self.status_code()).json(::serde_json::json!({
124 "error": self,
125 "message": self.to_string(),
126 }))
127 }
128 }
129 }),
130 None => Ok(quote! {
131 impl ::actix_web::ResponseError for #ident {
132 fn error_response(&self) -> ::actix_web::HttpResponse {
133 ::actix_web::HttpResponse::InternalServerError().finish()
134 }
135 }
136 }),
137 Some(Err(e)) => Err(e),
138 },
139 syn::Data::Union(_) => Err(syn::Error::new_spanned(
140 ident,
141 "ResponseError derive cannot be applied to unions",
142 )),
143 }
144}
145
146fn get_status_code(attrs: &[syn::Attribute]) -> Option<syn::Result<syn::Ident>> {
148 let response_attrs: Vec<_> = attrs
149 .iter()
150 .filter(|attr| attr.path.is_ident("status_code"))
151 .collect();
152
153 match response_attrs.len() {
154 1 => Some(response_attrs[0].parse_args()),
155 0 => None,
156 _ => Some(Err(syn::Error::new_spanned(
157 response_attrs[1],
158 "only one #[status_code(...)] attribute is allowed",
159 ))),
160 }
161}
162
163struct ResponseVariant {
165 pub status_code: syn::Ident,
166 pub variant: syn::Ident,
167}
168
169impl ResponseVariant {
170 pub fn from_variant(variant: syn::Variant) -> Option<syn::Result<Self>> {
172 let ident = variant.ident;
173
174 get_status_code(&variant.attrs).map(|r| {
175 r.map(|status_code| Self {
176 status_code,
177 variant: ident,
178 })
179 })
180 }
181
182 pub fn to_status_code_arm(&self) -> TokenStream {
184 let Self {
185 status_code,
186 variant,
187 } = self;
188
189 quote! {
190 Self::#variant { .. } => ::actix_web::http::StatusCode::#status_code
191 }
192 }
193
194 pub fn to_error_response_pattern(&self) -> TokenStream {
196 let variant = &self.variant;
197
198 quote! {
199 Self::#variant { .. }
200 }
201 }
202}