// actix_error_derive/lib.rs

use convert_case::{Case, Casing};
use darling::FromVariant;
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{parse_macro_input, Data, DeriveInput, Fields};

7#[derive(FromVariant, Default)] 
8#[darling(default, attributes(api_error))]
9struct Opts {
10    code: Option<u16>,
11    status: Option<String>,
12    kind: Option<String>,
13    msg: Option<String>,
14    ignore: bool,
15    group: bool,
16}
17
18
19/// Derives the `AsApiErrorTrait` for an enum, allowing it to be converted into an `ApiError`
20/// suitable for Actix-Web responses. It also conditionally implements `std::fmt::Display`.
21///
22/// ## Attributes
23///
24/// Attributes are placed on enum variants using `#[api_error(...)]`:
25///
26/// - `code = <u16>`: Specifies a raw HTTP status code (e.g., `code = 404`).
27///   If both `code` and `status` are provided, `code` takes precedence.
28///
29/// - `status = "<StatusCodeString>"`: Specifies the HTTP status using a predefined string.
30///   (e.g., `status = "NotFound"`). See below for a list of supported strings.
31///   If neither `code` nor `status` is provided, defaults to `500` (Internal Server Error).
32///
33/// - `kind = "<string>"`: Sets the `kind` field in the `ApiError`.
34///   Defaults to the `snake_case` version of the variant name (e.g., `MyVariant` becomes `"my_variant"`).
35///
36/// - `msg = "<string>"`: Provides a custom error message.
37///   - For variants with named fields: `msg = "Error for {field_name}"`.
38///   - For variants with unnamed (tuple) fields: `msg = "Error with value {0} and {1}"`.
39///   - If `msg` is not provided, the message is generated based on the `Display` trait:
40///     - If this macro generates `Display` (see "Conditional `std::fmt::Display` Implementation" below), 
41///       it will be the variant name or a simple format derived from it.
42///     - If the user provides `Display` (e.g., via `thiserror`), that implementation is used (`self.to_string()`).
43///
44/// - `ignore = <bool>`: (Default: `false`)
45///   - If `true`, `msg` is *not* provided, and the macro does *not* generate `Display`,
46///     the message will be the variant name, and fields will not be automatically formatted into the message.
47///   - This attribute does *not* prevent field interpolation if a `msg` attribute *is* provided
48///     (e.g., `#[api_error(msg = "Value: {0}", ignore)] MyVariant(i32)` will still print the value).
49///   - Its primary use is to simplify the message to just the variant name when no `msg` is given
50///     and `Display` is not generated by this macro, overriding default field formatting.
51///
52/// - `group = <bool>`: (Default: `false`)
53///   - If `true`, the variant is expected to hold a single field that itself implements `AsApiErrorTrait`.
54///     The `as_api_error()` method of this inner error will be called.
55///     Other attributes like `code`, `status`, `msg`, `kind` on the group variant are ignored.
56///
57/// ## Automatic `details` Field Population
58///
59/// If a variant is *not* a `group` and contains a single field of type `serde_json::Value`
60/// or `Option<serde_json::Value>`, this field's value will automatically populate the
61/// `details` field of the generated `ApiError`.
62///
63/// ## Conditional `std::fmt::Display` Implementation
64///
65/// The `std::fmt::Display` trait is implemented for the enum by this macro *if and only if*
66/// at least one variant has an explicit `#[api_error(msg = "...")]` attribute.
67/// - If implemented by the macro:
68///   - Variants with `msg` will use that formatted message for their `Display` output.
69///   - Variants without `msg` will display as their variant name (e.g., `MyEnum::VariantName` displays as "VariantName").
70///
71/// If no variants use `#[api_error(msg = "...")]`, you are expected to provide your own
72/// `Display` implementation (e.g., using the `thiserror` crate or manually).
73/// The `as_api_error` method will then use `self.to_string()` for the `ApiError` message if `msg` is not set on the variant.
74///
75/// ## Supported `status` Strings and Their Codes
76///
77/// ```rust
78/// // "BadRequest" => 400
79/// // "Unauthorized" => 401
80/// // "Forbidden" => 403
81/// // "NotFound" => 404
82/// // "MethodNotAllowed" => 405
83/// // "Conflict" => 409
84/// // "Gone" => 410
85/// // "PayloadTooLarge" => 413
86/// // "UnsupportedMediaType" => 415
87/// // "UnprocessableEntity" => 422
88/// // "TooManyRequests" => 429
89/// // "InternalServerError" => 500 (Default if no code/status is specified)
90/// // "NotImplemented" => 501
91/// // "BadGateway" => 502
92/// // "ServiceUnavailable" => 503
93/// // "GatewayTimeout" => 504
94/// ```
95/// Using an unsupported string in `status` will result in a compile-time error.
96///
97/// ## Example
98///
99/// ```rust
100/// use actix_error_derive::AsApiError;
101/// // Ensure ApiError and AsApiErrorTrait are in scope, typically via:
102/// // use actix_error::{ApiError, AsApiErrorTrait}; 
103/// use serde_json::json;
104///
105/// // Dummy AnotherErrorType for the group example
106/// #[derive(Debug)]
107/// pub struct AnotherErrorType;
108/// impl actix_error::AsApiErrorTrait for AnotherErrorType {
109///     fn as_api_error(&self) -> actix_error::ApiError {
110///         actix_error::ApiError::new(401, "auth_failure", "Authentication failed".to_string(), None)
111///     }
112/// }
113/// impl std::fmt::Display for AnotherErrorType { 
114///     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115///         write!(f, "AnotherErrorType: Authentication Failed")
116///     }
117/// }
118///
119/// #[derive(Debug, AsApiError)]
120/// pub enum MyError {
121///     #[api_error(status = "NotFound", msg = "Resource not found.")]
122///     NotFound, // Display will be "Resource not found."
123///
124///     // No msg, so if Display is macro-generated, it's "InvalidInput".
125///     // If user provides Display (e.g. with thiserror), that's used for ApiError.message.
126///     #[api_error(code = 400, kind = "input_validation")]
127///     InvalidInput { field: String, reason: String }, 
128///
129///     #[api_error(status = "UnprocessableEntity", msg = "Cannot process item: {0}")]
130///     Unprocessable(String), // Display will be "Cannot process item: <value>"
131///
132///     // 'details' will be auto-populated from the serde_json::Value field.
133///     // msg is present, so Display is "Detailed error occurred."
134///     #[api_error(status = "BadRequest", msg = "Detailed error occurred.")] 
135///     DetailedError(serde_json::Value),
136///
137///     #[api_error(group)]
138///     AuthError(AnotherErrorType), // Delegates to AnotherErrorType's AsApiErrorTrait
139/// }
140///
141/// // Since MyError has variants with `msg`, `Display` is generated by AsApiError.
142/// // If no variants had `msg`, you would need to implement `Display` manually or with `thiserror`:
143/// //
144/// // #[derive(Debug, AsApiError, thiserror::Error)] // Example with thiserror
145/// // pub enum MyErrorWithoutMacroDisplay {
146/// //     #[error("Item {0} was not found")] // thiserror message
147/// //     #[api_error(status = "NotFound")]
148/// //     NotFound(String),
149/// //
150/// //     #[error("Input is invalid: {reason}")]
151/// //     #[api_error(code = 400, kind = "bad_input")]
152/// //     InvalidInput { reason: String }
153/// // }
154/// ```
155#[proc_macro_derive(AsApiError, attributes(api_error))]
156pub fn derive(input: TokenStream) -> TokenStream {
157    // Parse the input tokens into a syntax tree
158    let ast = parse_macro_input!(input as DeriveInput); 
159    let ident_name = &ast.ident;
160
161    // Get the variants
162    let enum_data = match &ast.data {
163        syn::Data::Enum(data) => data,
164        _ => {
165            return syn::Error::new_spanned(
166                &ast, "AsApiError can only be derived for enums"
167            ).to_compile_error().into();
168        }
169    };
170    let variants_data = &enum_data.variants;
171
172    // Determine if any variant has an explicit 'msg' attribute.
173    // This will decide if a Display impl should be generated by this macro.
174    let mut any_variant_has_explicit_msg = false;
175    for v in variants_data.iter() {
176        match Opts::from_variant(v) {
177            Ok(opts) => {
178                if opts.msg.is_some() {
179                    any_variant_has_explicit_msg = true;
180                    break;
181                }
182            }
183            Err(e) => return TokenStream::from(e.write_errors()), // Propagate error from Opts parsing
184        }
185    }
186
187    // Generate the match arms for the as_api_error method
188    let match_arms_results: Vec<Result<proc_macro2::TokenStream, syn::Error>> = variants_data.iter().map(|v| {
189        let variant_ident = &v.ident;
190        
191        // Determine the pattern for matching fields
192        let field_pats = match &v.fields {
193            syn::Fields::Unnamed(f) => {
194                let idents = f.unnamed.iter().enumerate().map(|(i, _)| format_ident!("a{}", i));
195                quote! { ( #( #idents ),* ) }
196            }
197            syn::Fields::Named(f) => {
198                let idents = f.named.iter().map(|field| field.ident.as_ref().unwrap());
199                quote! { { #( #idents ),* } }
200            }
201            syn::Fields::Unit => quote! {},
202        };
203
204        let opts = match Opts::from_variant(&v) {
205            Ok(opts) => opts,
206            Err(e) => return Err(e.into()),
207        };
208            
209        let status_code_val = if let Some(code) = opts.code {
210            code
211        } else if let Some(ref error_kind_str) = opts.status {
212            match error_kind_str.as_str() {
213                "BadRequest" => 400,
214                "Unauthorized" => 401,
215                "Forbidden" => 403,
216                "NotFound" => 404,
217                "MethodNotAllowed" => 405,
218                "Conflict" => 409,
219                "Gone" => 410,
220                "PayloadTooLarge" => 413,
221                "UnsupportedMediaType" => 415,
222                "UnprocessableEntity" => 422,
223                "TooManyRequests" => 429,
224                "InternalServerError" => 500,
225                "NotImplemented" => 501,
226                "BadGateway" => 502,
227                "ServiceUnavailable" => 503,
228                "GatewayTimeout" => 504,
229                _ => {
230                    // Handle unknown status string
231                    return Err(syn::Error::new_spanned(
232                        // Span to where 'status = "..."' would be, or the variant if not directly available
233                        v, // Spanning to the variant is a good approximation
234                        format!("Invalid status attribute \"{}\" for variant {}. Supported values are: BadRequest, Unauthorized, etc.", error_kind_str, variant_ident),
235                    ));
236                }
237            }
238        } else {
239            500 // Default status code
240        };
241        
242        // Validate status code
243        if let Err(e) = actix_web::http::StatusCode::from_u16(status_code_val) {
244             return Err(syn::Error::new_spanned(
245                 &v.ident, 
246                 format!("Invalid status code {} for variant {}: {}", status_code_val, variant_ident, e)
247            )); // Removed .into() as to_compile_error is not needed here
248        }
249        
250        let kind_str = opts.kind.unwrap_or_else(|| variant_ident.to_string().to_case(Case::Snake));
251
252        // Generate the message expression
253        let message_expr = match opts.msg {
254            Some(ref msg_s) => {
255                match &v.fields {
256                    syn::Fields::Unnamed(f) => {
257                        // For unnamed fields, format if msg_s contains placeholders and there are fields.
258                        // The 'ignore' attribute does not prevent formatting for unnamed fields here.
259                        if f.unnamed.is_empty() || !msg_s.contains('{') { // Heuristic: check for presence of '{'
260                            quote! { #msg_s.to_owned() } // Treat as literal
261                        } else {
262                            let field_vars_for_format = f.unnamed.iter().enumerate().map(|(i, _)| format_ident!("a{}", i));
263                            quote! { format!(#msg_s, #( #field_vars_for_format ),*) }
264                        }
265                    }
266                    syn::Fields::Named(f) => {
267                        // For named fields, format only if 'ignore' is false, msg_s has placeholders, and there are fields.
268                        if opts.ignore || f.named.is_empty() || !msg_s.contains('{') { // Heuristic: check for presence of '{'
269                            quote! { #msg_s.to_owned() } // Treat as literal
270                        } else {
271                            let named_field_idents = f.named.iter().map(|field| field.ident.as_ref().unwrap());
272                            let format_assignments = named_field_idents.map(|ident| quote! { #ident = #ident }).collect::<Vec<_>>();
273                            quote! { format!(#msg_s, #( #format_assignments ),*) }
274                        }
275                    }
276                    syn::Fields::Unit => {
277                        // For unit variants, msg_s is always used as a literal string.
278                        quote! { #msg_s.to_owned() }
279                    }
280                }
281            }
282            None => {
283                // If no `msg` attribute is provided in `api_error`:
284                if any_variant_has_explicit_msg {
285                    // If the macro is generating a Display impl for this enum (because some other variant has a msg),
286                    // we default to the variant's name to avoid recursion with the macro-generated Display.
287                    // This matches test expectations for variants like ErrorEn::MissingMessageVariant.
288                    let variant_name_str = variant_ident.to_string();
289                    quote! { #variant_name_str.to_owned() }
290                } else {
291                    // If the macro is NOT generating a Display impl (no variant has any msg attribute),
292                    // we delegate to self.to_string() to allow using an external Display (e.g., from thiserror).
293                    // This matches test expectations for enums like ErrorWithThiserrorDisplay.
294                    quote! { self.to_string() }
295                }
296            }
297        };
298        
299        let mut details_expr = quote! { None };
300
301        // Automatic detection of a field to be used for 'details'.
302        // This logic applies if the variant is not a 'group' error.
303        if !opts.group {
304            match &v.fields {
305                syn::Fields::Named(fields_named) => {
306                    for field in &fields_named.named {
307                        if let Some(field_ident) = &field.ident {
308                            let field_ty = &field.ty;
309                            let type_string = quote!(#field_ty).to_string().replace(" ", ""); // Normalize spaces
310
311                            if type_string == "Option<serde_json::Value>" || type_string == "std::option::Option<serde_json::Value>" {
312                                details_expr = quote! { #field_ident.clone() };
313                                break; // Use the first found Option<serde_json::Value> field
314                            } else if type_string == "serde_json::Value" {
315                                details_expr = quote! { Some(#field_ident.clone()) };
316                                break; // Use the first found serde_json::Value field
317                            }
318                        }
319                    }
320                }
321                syn::Fields::Unnamed(fields_unnamed) => {
322                    for (i, field) in fields_unnamed.unnamed.iter().enumerate() {
323                        let field_ty = &field.ty;
324                        let field_pat_ident = format_ident!("a{}", i); // Field pattern is a0, a1, etc.
325                        let type_string = quote!(#field_ty).to_string().replace(" ", ""); // Normalize spaces
326
327                        if type_string == "Option<serde_json::Value>" || type_string == "std::option::Option<serde_json::Value>" {
328                            details_expr = quote! { #field_pat_ident.clone() };
329                            break; // Use the first found Option<serde_json::Value> field
330                        } else if type_string == "serde_json::Value" {
331                            details_expr = quote! { Some(#field_pat_ident.clone()) };
332                            break; // Use the first found serde_json::Value field
333                        }
334                    }
335                }
336                syn::Fields::Unit => {
337                    // Unit variants cannot have details fields.
338                }
339            }
340        }
341        
342        // Generate the ApiError construction call
343        let api_error_call = if opts.group {
344            // Assumes the first field of a tuple variant is 'a0' if 'group' is true
345            let group_var = format_ident!("a0"); 
346            quote! { #group_var.as_api_error() }
347        } else {
348            quote! { ApiError::new(#status_code_val, #kind_str, #message_expr, #details_expr) } 
349        };
350
351        // If fields are destructured by field_pats but not necessarily used directly in api_error_call
352        // (e.g. if message comes from self.to_string() or variant_name),
353        // this dummy assignment helps to silence "unused variable" warnings.
354        let dummy_field_usage = match (opts.msg.is_none(), &v.fields) {
355            (true, syn::Fields::Unnamed(f)) if !f.unnamed.is_empty() && !opts.group => {
356                let idents = f.unnamed.iter().enumerate().map(|(i, _)| format_ident!("a{}", i));
357                quote! { let _ = (#( #idents ),*); }
358            }
359            (true, syn::Fields::Named(f)) if !f.named.is_empty() && !opts.group => {
360                let idents = f.named.iter().map(|field| field.ident.as_ref().unwrap());
361                quote! { let _ = (#( #idents ),*); }
362            }
363            _ => quote! {}, // No dummy usage needed if msg is Some, or it's a unit variant, or a group error
364        };
365
366        Ok(quote! {
367            #ident_name::#variant_ident #field_pats => {
368                #dummy_field_usage
369                #api_error_call
370            }
371        })
372    }).collect();
373
374    // Handle any errors that occurred during match arm generation
375    let mut compiled_match_arms = Vec::new();
376    for result in match_arms_results {
377        match result {
378            Ok(ts) => compiled_match_arms.push(ts),
379            Err(e) => return TokenStream::from(e.to_compile_error()),
380        }
381    }
382
383    // Conditionally generate Display implementation for the enum.
384    // It's generated if any variant has an explicit 'msg' attribute.
385    // Otherwise, the user is expected to provide Display (e.g., via thiserror).
386    let display_impl_block = if any_variant_has_explicit_msg {
387        quote! {
388            impl std::fmt::Display for #ident_name {
389                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390                    // The message for display should be consistent with ApiError's message.
391                    // This message is constructed within the as_api_error method for each variant,
392                    // which itself might call self.to_string() if a variant has no 'msg' attribute.
393                    write!(f, "{}", self.as_api_error().message)
394                }
395            }
396        }
397    } else {
398        quote! {} // Empty if no variant has an explicit 'msg' attribute.
399    };
400
401    // Generate the final implementations
402    let expanded = quote! {
403        impl AsApiErrorTrait for #ident_name {
404            fn as_api_error(&self) -> ApiError {
405                match self {
406                    #(#compiled_match_arms)*
407                }
408            }
409        }
410
411        #display_impl_block // Include Display impl only if any_variant_has_explicit_msg is true
412
413        // The user is expected to provide Debug, e.g., via #[derive(Debug)]
414        // No Debug impl generated by this macro.
415    
416        impl actix_web::ResponseError for #ident_name {
417            fn status_code(&self) -> actix_web::http::StatusCode {
418                // Delegate to the status_code method of the ApiError generated from this enum variant.
419                self.as_api_error().status_code()
420            }
421        
422            fn error_response(&self) -> actix_web::HttpResponse {
423                // Delegate to the error_response method of the ApiError generated from this enum variant.
424                // This will ensure the ApiError struct (with kind, message, details) is serialized.
425                self.as_api_error().error_response()
426            }
427        }
428    };
429
430    TokenStream::from(expanded)
431}