actix_error_derive/
lib.rs

1use darling::FromVariant;
2use syn::{parse_macro_input, DeriveInput};
3use proc_macro::TokenStream;
4use quote::{quote, format_ident};
5use convert_case::{Case, Casing};
6
7#[derive(FromVariant, Default)] 
8#[darling(default, attributes(error))]
9struct Opts {
10    code: Option<u16>,
11    status: Option<String>,
12    kind: Option<String>,
13    msg: Option<String>,
14    ignore: bool,
15    group: bool,
16}
17
18
19/// This derive macro is used to convert an enum into an ApiError.  
20/// You can use it by adding the ```#[derive(AsApiError)]``` attribute to your enum.  
21/// By default, the kind is ```snake case```.  
22/// ```#[error(kind = "your_message_id")]``` attribute to the variant.  
23/// You can also add a custom code to the error by adding the ```#[error(code = 400)]``` attribute to the variant.  
24/// The following status are available and return the corresponding status code: 
25/// ``` rust
26/// fn get_status_code(error_kind: &str) -> u16 {
27///     match error_kind {
28///         "BadRequest" => 400,
29///         "Unauthorized" => 401,
30///         "Forbidden" => 403,
31///         "NotFound" => 404,
32///         "MethodNotAllowed" => 405,
33///         "Conflict" => 409,
34///         "Gone" => 410,
35///         "PayloadTooLarge" => 413,
36///         "UnsupportedMediaType" => 415,
37///         "UnprocessableEntity" => 422,
38///         "TooManyRequests" => 429,
39///         "InternalServerError" => 500,
40///         "NotImplemented" => 501,
41///         "BadGateway" => 502,
42///         "ServiceUnavailable" => 503,
43///         "GatewayTimeout" => 504,
44///         _ => 0, // Or some other default/error handling
45///     }
46/// }
47///
48/// // Example usage:
49/// let code = get_status_code("NotFound");
50/// assert_eq!(code, 404);
51/// let default_code = get_status_code("SomeOtherError");
52/// assert_eq!(default_code, 0);
53/// ```
54#[proc_macro_derive(AsApiError, attributes(error))]
55pub fn derive(input: TokenStream) -> TokenStream {
56    // Parse the input tokens into a syntax tree
57    let ast = parse_macro_input!(input as DeriveInput); 
58    let ident_name = &ast.ident;
59
60    // Get the variants
61    let enum_data = match &ast.data {
62        syn::Data::Enum(data) => data,
63        _ => {
64            return syn::Error::new_spanned(
65                &ast, "AsApiError can only be derived for enums"
66            ).to_compile_error().into();
67        }
68    };
69    let variants_data = &enum_data.variants;
70
71    // Generate the match arms for the as_api_error method
72    let match_arms_results: Vec<Result<proc_macro2::TokenStream, syn::Error>> = variants_data.iter().map(|v| {
73        let variant_ident = &v.ident;
74        
75        // Determine the pattern for matching fields
76        let field_pats = match &v.fields {
77            syn::Fields::Unnamed(f) => {
78                let idents = f.unnamed.iter().enumerate().map(|(i, _)| format_ident!("a{}", i));
79                quote! { ( #( #idents ),* ) }
80            }
81            syn::Fields::Named(f) => {
82                let idents = f.named.iter().map(|field| field.ident.as_ref().unwrap());
83                quote! { { #( #idents ),* } }
84            }
85            syn::Fields::Unit => quote! {},
86        };
87
88        let opts = match Opts::from_variant(&v) {
89            Ok(opts) => opts,
90            Err(e) => return Err(e.into()), // darling::Error can be converted to syn::Error then to_compile_error
91        };
92            
93        let status_code_val = if let Some(code) = opts.code {
94            code
95        } else if let Some(ref error_kind_str) = opts.status {
96            match error_kind_str.as_str() {
97                "BadRequest" => 400,
98                "Unauthorized" => 401,
99                "Forbidden" => 403,
100                "NotFound" => 404,
101                "MethodNotAllowed" => 405,
102                "Conflict" => 409,
103                "Gone" => 410,
104                "PayloadTooLarge" => 413,
105                "UnsupportedMediaType" => 415,
106                "UnprocessableEntity" => 422,
107                "TooManyRequests" => 429,
108                "InternalServerError" => 500,
109                "NotImplemented" => 501,
110                "BadGateway" => 502,
111                "ServiceUnavailable" => 503,
112                "GatewayTimeout" => 504,
113                _ => {
114                    return Err(syn::Error::new_spanned(
115                        &v.ident, // Span to the variant identifier
116                        format!("Invalid status attribute \"{}\" for variant {}", error_kind_str, variant_ident), // Corrected string escaping
117                    ));
118                }
119            }
120        } else {
121            500 // Default status code
122        };
123        
124        // Validate status code
125        if let Err(e) = actix_web::http::StatusCode::from_u16(status_code_val) {
126             return Err(syn::Error::new_spanned(
127                 &v.ident, 
128                 format!("Invalid status code {} for variant {}: {}", status_code_val, variant_ident, e)
129            ));
130        }
131        
132        let kind_str = opts.kind.unwrap_or_else(|| variant_ident.to_string().to_case(Case::Snake));
133
134        // Generate the message expression
135        let message_expr = match opts.msg {
136            Some(ref msg_s) => {
137                if opts.ignore {
138                    quote! { #msg_s.to_owned() }
139                } else if let syn::Fields::Unnamed(f) = &v.fields {
140                    // For tuple variants, interpolate fields named a0, a1, ...
141                    let field_vars = f.unnamed.iter().enumerate().map(|(i, _)| format_ident!("a{}", i));
142                    quote! { format!(#msg_s, #( #field_vars ),*) }
143                } else if let syn::Fields::Named(_) = &v.fields {
144                    // For struct variants, msg is a format string. Fields are interpolated from local scope.
145                    quote! { format!(#msg_s) } 
146                } else { // Unit variants
147                    quote! { #msg_s.to_owned() }
148                }
149            }
150            None => quote! { String::new() }, // Default empty message
151        };
152        
153        let mut details_expr = quote! { None };
154
155        if !opts.group {
156            // Check if the variant has exactly one unnamed field
157            if let syn::Fields::Unnamed(fields_unnamed) = &v.fields {
158                if fields_unnamed.unnamed.len() == 1 {
159                    let first_field = fields_unnamed.unnamed.first().unwrap();
160                    // Get the type of the first field
161                    let field_ty = &first_field.ty;
162                    // Convert the type to a string for comparison
163                    let type_string = quote!(#field_ty).to_string();
164                    
165                    let field_ident = format_ident!("a0"); // Identifier for the first unnamed field
166
167                    // Check if the type is serde_json::Value
168                    if type_string == "serde_json :: Value" {
169                        details_expr = quote! { Some(#field_ident.clone()) };
170                    }
171                    // Check if the type is Option<serde_json::Value>
172                    else if type_string == "Option < serde_json :: Value >" || type_string == "std :: option :: Option < serde_json :: Value >" {
173                        details_expr = quote! { #field_ident.clone() };
174                    }
175                }
176            }
177        }
178        
179        // Generate the ApiError construction call
180        let api_error_call = if opts.group {
181            // Assumes the first field of a tuple variant is 'a0' if 'group' is true
182            let group_var = format_ident!("a0"); 
183            quote! { #group_var.as_api_error() }
184        } else {
185            quote! { ApiError::new(#status_code_val, #kind_str, #message_expr, #details_expr) } 
186        };
187
188        Ok(quote! {
189            #ident_name::#variant_ident #field_pats => {
190                #api_error_call
191            }
192        })
193    }).collect();
194
195    // Handle any errors that occurred during match arm generation
196    let mut compiled_match_arms = Vec::new();
197    for result in match_arms_results {
198        match result {
199            Ok(ts) => compiled_match_arms.push(ts),
200            Err(e) => return TokenStream::from(e.to_compile_error()),
201        }
202    }
203
204    // Generate the final implementations
205    let expanded = quote! {
206        impl AsApiErrorTrait for #ident_name {
207            fn as_api_error(&self) -> ApiError {
208                match self {
209                    #( #compiled_match_arms ),*
210                }
211            }
212        }
213    
214        impl std::fmt::Debug for #ident_name {
215            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216                // Use the generated as_api_error method
217                let api_error = self.as_api_error();
218                write!(f, "{:?}", api_error)
219            }
220        }
221    
222        impl std::fmt::Display for #ident_name {
223            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224                // Use the generated as_api_error method
225                let api_error = self.as_api_error();
226                write!(f, "{}", api_error)
227            }
228        }
229    
230        impl actix_web::ResponseError for #ident_name {
231            fn status_code(&self) -> actix_web::http::StatusCode {
232                let api_error = self.as_api_error();
233                actix_web::http::StatusCode::from_u16(api_error.code)
234                    .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR)
235            }
236        
237            fn error_response(&self) -> actix_web::HttpResponse {
238                let api_error = self.as_api_error();
239                let status = actix_web::http::StatusCode::from_u16(api_error.code)
240                    .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
241                actix_web::HttpResponse::build(status).json(api_error)
242            }
243        }
244    };
245
246    TokenStream::from(expanded)
247}