actix_error_derive/
lib.rs1use darling::FromVariant;
2use syn::{parse_macro_input, DeriveInput};
3use proc_macro::TokenStream;
4use quote::{quote, format_ident};
5use convert_case::{Case, Casing};
6
7#[derive(FromVariant, Default)]
8#[darling(default, attributes(error))]
9struct Opts {
10 code: Option<u16>,
11 status: Option<String>,
12 kind: Option<String>,
13 msg: Option<String>,
14 ignore: bool,
15 group: bool,
16}
17
18
19#[proc_macro_derive(AsApiError, attributes(error))]
55pub fn derive(input: TokenStream) -> TokenStream {
56 let ast = parse_macro_input!(input as DeriveInput);
58 let ident_name = &ast.ident;
59
60 let enum_data = match &ast.data {
62 syn::Data::Enum(data) => data,
63 _ => {
64 return syn::Error::new_spanned(
65 &ast, "AsApiError can only be derived for enums"
66 ).to_compile_error().into();
67 }
68 };
69 let variants_data = &enum_data.variants;
70
71 let match_arms_results: Vec<Result<proc_macro2::TokenStream, syn::Error>> = variants_data.iter().map(|v| {
73 let variant_ident = &v.ident;
74
75 let field_pats = match &v.fields {
77 syn::Fields::Unnamed(f) => {
78 let idents = f.unnamed.iter().enumerate().map(|(i, _)| format_ident!("a{}", i));
79 quote! { ( #( #idents ),* ) }
80 }
81 syn::Fields::Named(f) => {
82 let idents = f.named.iter().map(|field| field.ident.as_ref().unwrap());
83 quote! { { #( #idents ),* } }
84 }
85 syn::Fields::Unit => quote! {},
86 };
87
88 let opts = match Opts::from_variant(&v) {
89 Ok(opts) => opts,
90 Err(e) => return Err(e.into()), };
92
93 let status_code_val = if let Some(code) = opts.code {
94 code
95 } else if let Some(ref error_kind_str) = opts.status {
96 match error_kind_str.as_str() {
97 "BadRequest" => 400,
98 "Unauthorized" => 401,
99 "Forbidden" => 403,
100 "NotFound" => 404,
101 "MethodNotAllowed" => 405,
102 "Conflict" => 409,
103 "Gone" => 410,
104 "PayloadTooLarge" => 413,
105 "UnsupportedMediaType" => 415,
106 "UnprocessableEntity" => 422,
107 "TooManyRequests" => 429,
108 "InternalServerError" => 500,
109 "NotImplemented" => 501,
110 "BadGateway" => 502,
111 "ServiceUnavailable" => 503,
112 "GatewayTimeout" => 504,
113 _ => {
114 return Err(syn::Error::new_spanned(
115 &v.ident, format!("Invalid status attribute \"{}\" for variant {}", error_kind_str, variant_ident), ));
118 }
119 }
120 } else {
121 500 };
123
124 if let Err(e) = actix_web::http::StatusCode::from_u16(status_code_val) {
126 return Err(syn::Error::new_spanned(
127 &v.ident,
128 format!("Invalid status code {} for variant {}: {}", status_code_val, variant_ident, e)
129 ));
130 }
131
132 let kind_str = opts.kind.unwrap_or_else(|| variant_ident.to_string().to_case(Case::Snake));
133
134 let message_expr = match opts.msg {
136 Some(ref msg_s) => {
137 if opts.ignore {
138 quote! { #msg_s.to_owned() }
139 } else if let syn::Fields::Unnamed(f) = &v.fields {
140 let field_vars = f.unnamed.iter().enumerate().map(|(i, _)| format_ident!("a{}", i));
142 quote! { format!(#msg_s, #( #field_vars ),*) }
143 } else if let syn::Fields::Named(_) = &v.fields {
144 quote! { format!(#msg_s) }
148 } else { quote! { #msg_s.to_owned() }
150 }
151 }
152 None => quote! { String::new() }, };
154
155 let api_error_call = if opts.group {
157 let group_var = format_ident!("a0");
159 quote! { #group_var.as_api_error() }
160 } else {
161 quote! { ApiError::new(#status_code_val, #kind_str, #message_expr, None) } };
163
164 Ok(quote! {
165 #ident_name::#variant_ident #field_pats => {
166 #api_error_call
167 }
168 })
169 }).collect();
170
171 let mut compiled_match_arms = Vec::new();
173 for result in match_arms_results {
174 match result {
175 Ok(ts) => compiled_match_arms.push(ts),
176 Err(e) => return TokenStream::from(e.to_compile_error()),
177 }
178 }
179
180 let expanded = quote! {
182 impl AsApiErrorTrait for #ident_name {
183 fn as_api_error(&self) -> ApiError {
184 match self {
185 #( #compiled_match_arms ),*
186 }
187 }
188 }
189
190 impl std::fmt::Debug for #ident_name {
191 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192 let api_error = self.as_api_error();
194 write!(f, "{:?}", api_error)
195 }
196 }
197
198 impl std::fmt::Display for #ident_name {
199 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200 let api_error = self.as_api_error();
202 write!(f, "{}", api_error)
203 }
204 }
205
206 impl actix_web::ResponseError for #ident_name {
207 fn status_code(&self) -> actix_web::http::StatusCode {
208 let api_error = self.as_api_error();
209 actix_web::http::StatusCode::from_u16(api_error.code)
210 .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR)
211 }
212
213 fn error_response(&self) -> actix_web::HttpResponse {
214 let api_error = self.as_api_error();
215 let status = actix_web::http::StatusCode::from_u16(api_error.code)
216 .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
217 actix_web::HttpResponse::build(status).json(api_error)
218 }
219 }
220 };
221
222 TokenStream::from(expanded)
223}