// actix_error_derive/lib.rs
use convert_case::{Case, Casing};
use darling::FromVariant;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, DeriveInput};

7#[derive(FromVariant, Default)]
8#[darling(default, attributes(error))]
9struct Opts {
10 code: Option<u16>,
11 status: Option<String>,
12 kind: Option<String>,
13 msg: Option<String>,
14 ignore: bool,
15 group: bool,
16}
17
18
19#[proc_macro_derive(AsApiError, attributes(error))]
55pub fn derive(input: TokenStream) -> TokenStream {
56 let ast = parse_macro_input!(input as DeriveInput);
58 let ident_name = &ast.ident;
59
60 let enum_data = match &ast.data {
62 syn::Data::Enum(data) => data,
63 _ => {
64 return syn::Error::new_spanned(
65 &ast, "AsApiError can only be derived for enums"
66 ).to_compile_error().into();
67 }
68 };
69 let variants_data = &enum_data.variants;
70
71 let match_arms_results: Vec<Result<proc_macro2::TokenStream, syn::Error>> = variants_data.iter().map(|v| {
73 let variant_ident = &v.ident;
74
75 let field_pats = match &v.fields {
77 syn::Fields::Unnamed(f) => {
78 let idents = f.unnamed.iter().enumerate().map(|(i, _)| format_ident!("a{}", i));
79 quote! { ( #( #idents ),* ) }
80 }
81 syn::Fields::Named(f) => {
82 let idents = f.named.iter().map(|field| field.ident.as_ref().unwrap());
83 quote! { { #( #idents ),* } }
84 }
85 syn::Fields::Unit => quote! {},
86 };
87
88 let opts = match Opts::from_variant(&v) {
89 Ok(opts) => opts,
90 Err(e) => return Err(e.into()), };
92
93 let status_code_val = if let Some(code) = opts.code {
94 code
95 } else if let Some(ref error_kind_str) = opts.status {
96 match error_kind_str.as_str() {
97 "BadRequest" => 400,
98 "Unauthorized" => 401,
99 "Forbidden" => 403,
100 "NotFound" => 404,
101 "MethodNotAllowed" => 405,
102 "Conflict" => 409,
103 "Gone" => 410,
104 "PayloadTooLarge" => 413,
105 "UnsupportedMediaType" => 415,
106 "UnprocessableEntity" => 422,
107 "TooManyRequests" => 429,
108 "InternalServerError" => 500,
109 "NotImplemented" => 501,
110 "BadGateway" => 502,
111 "ServiceUnavailable" => 503,
112 "GatewayTimeout" => 504,
113 _ => {
114 return Err(syn::Error::new_spanned(
115 &v.ident, format!("Invalid status attribute \"{}\" for variant {}", error_kind_str, variant_ident), ));
118 }
119 }
120 } else {
121 500 };
123
124 if let Err(e) = actix_web::http::StatusCode::from_u16(status_code_val) {
126 return Err(syn::Error::new_spanned(
127 &v.ident,
128 format!("Invalid status code {} for variant {}: {}", status_code_val, variant_ident, e)
129 ));
130 }
131
132 let kind_str = opts.kind.unwrap_or_else(|| variant_ident.to_string().to_case(Case::Snake));
133
134 let message_expr = match opts.msg {
136 Some(ref msg_s) => {
137 if opts.ignore {
138 quote! { #msg_s.to_owned() }
139 } else if let syn::Fields::Unnamed(f) = &v.fields {
140 let field_vars = f.unnamed.iter().enumerate().map(|(i, _)| format_ident!("a{}", i));
142 quote! { format!(#msg_s, #( #field_vars ),*) }
143 } else if let syn::Fields::Named(_) = &v.fields {
144 quote! { format!(#msg_s) }
146 } else { quote! { #msg_s.to_owned() }
148 }
149 }
150 None => quote! { String::new() }, };
152
153 let mut details_expr = quote! { None };
154
155 if !opts.group {
156 if let syn::Fields::Unnamed(fields_unnamed) = &v.fields {
158 if fields_unnamed.unnamed.len() == 1 {
159 let first_field = fields_unnamed.unnamed.first().unwrap();
160 let field_ty = &first_field.ty;
162 let type_string = quote!(#field_ty).to_string();
164
165 let field_ident = format_ident!("a0"); if type_string == "serde_json :: Value" {
169 details_expr = quote! { Some(#field_ident.clone()) };
170 }
171 else if type_string == "Option < serde_json :: Value >" || type_string == "std :: option :: Option < serde_json :: Value >" {
173 details_expr = quote! { #field_ident.clone() };
174 }
175 }
176 }
177 }
178
179 let api_error_call = if opts.group {
181 let group_var = format_ident!("a0");
183 quote! { #group_var.as_api_error() }
184 } else {
185 quote! { ApiError::new(#status_code_val, #kind_str, #message_expr, #details_expr) }
186 };
187
188 Ok(quote! {
189 #ident_name::#variant_ident #field_pats => {
190 #api_error_call
191 }
192 })
193 }).collect();
194
195 let mut compiled_match_arms = Vec::new();
197 for result in match_arms_results {
198 match result {
199 Ok(ts) => compiled_match_arms.push(ts),
200 Err(e) => return TokenStream::from(e.to_compile_error()),
201 }
202 }
203
204 let expanded = quote! {
206 impl AsApiErrorTrait for #ident_name {
207 fn as_api_error(&self) -> ApiError {
208 match self {
209 #( #compiled_match_arms ),*
210 }
211 }
212 }
213
214 impl std::fmt::Debug for #ident_name {
215 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216 let api_error = self.as_api_error();
218 write!(f, "{:?}", api_error)
219 }
220 }
221
222 impl std::fmt::Display for #ident_name {
223 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224 let api_error = self.as_api_error();
226 write!(f, "{}", api_error)
227 }
228 }
229
230 impl actix_web::ResponseError for #ident_name {
231 fn status_code(&self) -> actix_web::http::StatusCode {
232 let api_error = self.as_api_error();
233 actix_web::http::StatusCode::from_u16(api_error.code)
234 .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR)
235 }
236
237 fn error_response(&self) -> actix_web::HttpResponse {
238 let api_error = self.as_api_error();
239 let status = actix_web::http::StatusCode::from_u16(api_error.code)
240 .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
241 actix_web::HttpResponse::build(status).json(api_error)
242 }
243 }
244 };
245
246 TokenStream::from(expanded)
247}