actix_error_derive/
lib.rs

1use darling::FromVariant;
2use syn::{parse_macro_input, DeriveInput};
3use proc_macro::TokenStream;
4use quote::{quote, format_ident};
5use convert_case::{Case, Casing};
6
7#[derive(FromVariant, Default)] 
8#[darling(default, attributes(error))]
9struct Opts {
10    code: Option<u16>,
11    status: Option<String>,
12    kind: Option<String>,
13    msg: Option<String>,
14    ignore: bool,
15    group: bool,
16}
17
18
19/// This derive macro is used to convert an enum into an ApiError.  
20/// You can use it by adding the ```#[derive(AsApiError)]``` attribute to your enum.  
21/// By default, the kind is ```snake case```.  
22/// ```#[error(kind = "your_message_id")]``` attribute to the variant.  
23/// You can also add a custom code to the error by adding the ```#[error(code = 400)]``` attribute to the variant.  
24/// The following status are available and return the corresponding status code: 
25/// ``` rust
26/// fn get_status_code(error_kind: &str) -> u16 {
27///     match error_kind {
28///         "BadRequest" => 400,
29///         "Unauthorized" => 401,
30///         "Forbidden" => 403,
31///         "NotFound" => 404,
32///         "MethodNotAllowed" => 405,
33///         "Conflict" => 409,
34///         "Gone" => 410,
35///         "PayloadTooLarge" => 413,
36///         "UnsupportedMediaType" => 415,
37///         "UnprocessableEntity" => 422,
38///         "TooManyRequests" => 429,
39///         "InternalServerError" => 500,
40///         "NotImplemented" => 501,
41///         "BadGateway" => 502,
42///         "ServiceUnavailable" => 503,
43///         "GatewayTimeout" => 504,
44///         _ => 0, // Or some other default/error handling
45///     }
46/// }
47///
48/// // Example usage:
49/// let code = get_status_code("NotFound");
50/// assert_eq!(code, 404);
51/// let default_code = get_status_code("SomeOtherError");
52/// assert_eq!(default_code, 0);
53/// ```
54#[proc_macro_derive(AsApiError, attributes(error))]
55pub fn derive(input: TokenStream) -> TokenStream {
56    // Parse the input tokens into a syntax tree
57    let ast = parse_macro_input!(input as DeriveInput); 
58    let ident_name = &ast.ident;
59
60    // Get the variants
61    let enum_data = match &ast.data {
62        syn::Data::Enum(data) => data,
63        _ => {
64            return syn::Error::new_spanned(
65                &ast, "AsApiError can only be derived for enums"
66            ).to_compile_error().into();
67        }
68    };
69    let variants_data = &enum_data.variants;
70
71    // Generate the match arms for the as_api_error method
72    let match_arms_results: Vec<Result<proc_macro2::TokenStream, syn::Error>> = variants_data.iter().map(|v| {
73        let variant_ident = &v.ident;
74        
75        // Determine the pattern for matching fields
76        let field_pats = match &v.fields {
77            syn::Fields::Unnamed(f) => {
78                let idents = f.unnamed.iter().enumerate().map(|(i, _)| format_ident!("a{}", i));
79                quote! { ( #( #idents ),* ) }
80            }
81            syn::Fields::Named(f) => {
82                let idents = f.named.iter().map(|field| field.ident.as_ref().unwrap());
83                quote! { { #( #idents ),* } }
84            }
85            syn::Fields::Unit => quote! {},
86        };
87
88        let opts = match Opts::from_variant(&v) {
89            Ok(opts) => opts,
90            Err(e) => return Err(e.into()), // darling::Error can be converted to syn::Error then to_compile_error
91        };
92            
93        let status_code_val = if let Some(code) = opts.code {
94            code
95        } else if let Some(ref error_kind_str) = opts.status {
96            match error_kind_str.as_str() {
97                "BadRequest" => 400,
98                "Unauthorized" => 401,
99                "Forbidden" => 403,
100                "NotFound" => 404,
101                "MethodNotAllowed" => 405,
102                "Conflict" => 409,
103                "Gone" => 410,
104                "PayloadTooLarge" => 413,
105                "UnsupportedMediaType" => 415,
106                "UnprocessableEntity" => 422,
107                "TooManyRequests" => 429,
108                "InternalServerError" => 500,
109                "NotImplemented" => 501,
110                "BadGateway" => 502,
111                "ServiceUnavailable" => 503,
112                "GatewayTimeout" => 504,
113                _ => {
114                    return Err(syn::Error::new_spanned(
115                        &v.ident, // Span to the variant identifier
116                        format!("Invalid status attribute \"{}\" for variant {}", error_kind_str, variant_ident), // Corrected string escaping
117                    ));
118                }
119            }
120        } else {
121            500 // Default status code
122        };
123        
124        // Validate status code
125        if let Err(e) = actix_web::http::StatusCode::from_u16(status_code_val) {
126             return Err(syn::Error::new_spanned(
127                 &v.ident, 
128                 format!("Invalid status code {} for variant {}: {}", status_code_val, variant_ident, e)
129            ));
130        }
131        
132        let kind_str = opts.kind.unwrap_or_else(|| variant_ident.to_string().to_case(Case::Snake));
133
134        // Generate the message expression
135        let message_expr = match opts.msg {
136            Some(ref msg_s) => {
137                if opts.ignore {
138                    quote! { #msg_s.to_owned() }
139                } else if let syn::Fields::Unnamed(f) = &v.fields {
140                    // For tuple variants, interpolate fields named a0, a1, ...
141                    let field_vars = f.unnamed.iter().enumerate().map(|(i, _)| format_ident!("a{}", i));
142                    quote! { format!(#msg_s, #( #field_vars ),*) }
143                } else if let syn::Fields::Named(_) = &v.fields {
144                    // For struct variants, msg is a format string, but no fields are interpolated by default by the macro
145                    // If specific field interpolation is needed for named fields, it would require a more complex parsing of msg_s
146                    // or a different attribute syntax.
147                    quote! { format!(#msg_s) } 
148                } else { // Unit variants
149                    quote! { #msg_s.to_owned() }
150                }
151            }
152            None => quote! { String::new() }, // Default empty message
153        };
154        
155        // Generate the ApiError construction call
156        let api_error_call = if opts.group {
157            // Assumes the first field of a tuple variant is 'a0' if 'group' is true
158            let group_var = format_ident!("a0"); 
159            quote! { #group_var.as_api_error() }
160        } else {
161            quote! { ApiError::new(#status_code_val, #kind_str, #message_expr, None) } // Pass None for the details argument
162        };
163
164        Ok(quote! {
165            #ident_name::#variant_ident #field_pats => {
166                #api_error_call
167            }
168        })
169    }).collect();
170
171    // Handle any errors that occurred during match arm generation
172    let mut compiled_match_arms = Vec::new();
173    for result in match_arms_results {
174        match result {
175            Ok(ts) => compiled_match_arms.push(ts),
176            Err(e) => return TokenStream::from(e.to_compile_error()),
177        }
178    }
179
180    // Generate the final implementations
181    let expanded = quote! {
182        impl AsApiErrorTrait for #ident_name {
183            fn as_api_error(&self) -> ApiError {
184                match self {
185                    #( #compiled_match_arms ),*
186                }
187            }
188        }
189    
190        impl std::fmt::Debug for #ident_name {
191            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192                // Use the generated as_api_error method
193                let api_error = self.as_api_error();
194                write!(f, "{:?}", api_error)
195            }
196        }
197    
198        impl std::fmt::Display for #ident_name {
199            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200                // Use the generated as_api_error method
201                let api_error = self.as_api_error();
202                write!(f, "{}", api_error)
203            }
204        }
205    
206        impl actix_web::ResponseError for #ident_name {
207            fn status_code(&self) -> actix_web::http::StatusCode {
208                let api_error = self.as_api_error();
209                actix_web::http::StatusCode::from_u16(api_error.code)
210                    .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR)
211            }
212        
213            fn error_response(&self) -> actix_web::HttpResponse {
214                let api_error = self.as_api_error();
215                let status = actix_web::http::StatusCode::from_u16(api_error.code)
216                    .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
217                actix_web::HttpResponse::build(status).json(api_error)
218            }
219        }
220    };
221
222    TokenStream::from(expanded)
223}