actix_error_derive/lib.rs
1use darling::FromVariant;
2use syn::{parse_macro_input, DeriveInput};
3use proc_macro::TokenStream;
4use quote::{quote, format_ident};
5use convert_case::{Case, Casing};
6
7#[derive(FromVariant, Default)]
8#[darling(default, attributes(api_error))]
9struct Opts {
10 code: Option<u16>,
11 status: Option<String>,
12 kind: Option<String>,
13 msg: Option<String>,
14 ignore: bool,
15 group: bool,
16}
17
18
19/// Derives the `AsApiErrorTrait` for an enum, allowing it to be converted into an `ApiError`
20/// suitable for Actix-Web responses. It also conditionally implements `std::fmt::Display`.
21///
22/// ## Attributes
23///
24/// Attributes are placed on enum variants using `#[api_error(...)]`:
25///
26/// - `code = <u16>`: Specifies a raw HTTP status code (e.g., `code = 404`).
27/// If both `code` and `status` are provided, `code` takes precedence.
28///
29/// - `status = "<StatusCodeString>"`: Specifies the HTTP status using a predefined string.
30/// (e.g., `status = "NotFound"`). See below for a list of supported strings.
31/// If neither `code` nor `status` is provided, defaults to `500` (Internal Server Error).
32///
33/// - `kind = "<string>"`: Sets the `kind` field in the `ApiError`.
34/// Defaults to the `snake_case` version of the variant name (e.g., `MyVariant` becomes `"my_variant"`).
35///
36/// - `msg = "<string>"`: Provides a custom error message.
37/// - For variants with named fields: `msg = "Error for {field_name}"`.
38/// - For variants with unnamed (tuple) fields: `msg = "Error with value {0} and {1}"`.
39/// - If `msg` is not provided, the message is generated based on the `Display` trait:
40/// - If this macro generates `Display` (see "Conditional `std::fmt::Display` Implementation" below),
41/// it will be the variant name or a simple format derived from it.
42/// - If the user provides `Display` (e.g., via `thiserror`), that implementation is used (`self.to_string()`).
43///
44/// - `ignore = <bool>`: (Default: `false`)
45/// - If `true`, `msg` is *not* provided, and the macro does *not* generate `Display`,
46/// the message will be the variant name, and fields will not be automatically formatted into the message.
47/// - This attribute does *not* prevent field interpolation if a `msg` attribute *is* provided
48/// (e.g., `#[api_error(msg = "Value: {0}", ignore)] MyVariant(i32)` will still print the value).
49/// - Its primary use is to simplify the message to just the variant name when no `msg` is given
50/// and `Display` is not generated by this macro, overriding default field formatting.
51///
52/// - `group = <bool>`: (Default: `false`)
53/// - If `true`, the variant is expected to hold a single field that itself implements `AsApiErrorTrait`.
54/// The `as_api_error()` method of this inner error will be called.
55/// Other attributes like `code`, `status`, `msg`, `kind` on the group variant are ignored.
56///
57/// ## Automatic `details` Field Population
58///
59/// If a variant is *not* a `group` and contains a single field of type `serde_json::Value`
60/// or `Option<serde_json::Value>`, this field's value will automatically populate the
61/// `details` field of the generated `ApiError`.
62///
63/// ## Conditional `std::fmt::Display` Implementation
64///
65/// The `std::fmt::Display` trait is implemented for the enum by this macro *if and only if*
66/// at least one variant has an explicit `#[api_error(msg = "...")]` attribute.
67/// - If implemented by the macro:
68/// - Variants with `msg` will use that formatted message for their `Display` output.
69/// - Variants without `msg` will display as their variant name (e.g., `MyEnum::VariantName` displays as "VariantName").
70///
71/// If no variants use `#[api_error(msg = "...")]`, you are expected to provide your own
72/// `Display` implementation (e.g., using the `thiserror` crate or manually).
73/// The `as_api_error` method will then use `self.to_string()` for the `ApiError` message if `msg` is not set on the variant.
74///
75/// ## Supported `status` Strings and Their Codes
76///
77/// ```rust
78/// // "BadRequest" => 400
79/// // "Unauthorized" => 401
80/// // "Forbidden" => 403
81/// // "NotFound" => 404
82/// // "MethodNotAllowed" => 405
83/// // "Conflict" => 409
84/// // "Gone" => 410
85/// // "PayloadTooLarge" => 413
86/// // "UnsupportedMediaType" => 415
87/// // "UnprocessableEntity" => 422
88/// // "TooManyRequests" => 429
89/// // "InternalServerError" => 500 (Default if no code/status is specified)
90/// // "NotImplemented" => 501
91/// // "BadGateway" => 502
92/// // "ServiceUnavailable" => 503
93/// // "GatewayTimeout" => 504
94/// ```
95/// Using an unsupported string in `status` will result in a compile-time error.
96///
97/// ## Example
98///
99/// ```rust
100/// use actix_error_derive::AsApiError;
101/// // Ensure ApiError and AsApiErrorTrait are in scope, typically via:
102/// // use actix_error::{ApiError, AsApiErrorTrait};
103/// use serde_json::json;
104///
105/// // Dummy AnotherErrorType for the group example
106/// #[derive(Debug)]
107/// pub struct AnotherErrorType;
108/// impl actix_error::AsApiErrorTrait for AnotherErrorType {
109/// fn as_api_error(&self) -> actix_error::ApiError {
110/// actix_error::ApiError::new(401, "auth_failure", "Authentication failed".to_string(), None)
111/// }
112/// }
113/// impl std::fmt::Display for AnotherErrorType {
114/// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115/// write!(f, "AnotherErrorType: Authentication Failed")
116/// }
117/// }
118///
119/// #[derive(Debug, AsApiError)]
120/// pub enum MyError {
121/// #[api_error(status = "NotFound", msg = "Resource not found.")]
122/// NotFound, // Display will be "Resource not found."
123///
124/// // No msg, so if Display is macro-generated, it's "InvalidInput".
125/// // If user provides Display (e.g. with thiserror), that's used for ApiError.message.
126/// #[api_error(code = 400, kind = "input_validation")]
127/// InvalidInput { field: String, reason: String },
128///
129/// #[api_error(status = "UnprocessableEntity", msg = "Cannot process item: {0}")]
130/// Unprocessable(String), // Display will be "Cannot process item: <value>"
131///
132/// // 'details' will be auto-populated from the serde_json::Value field.
133/// // msg is present, so Display is "Detailed error occurred."
134/// #[api_error(status = "BadRequest", msg = "Detailed error occurred.")]
135/// DetailedError(serde_json::Value),
136///
137/// #[api_error(group)]
138/// AuthError(AnotherErrorType), // Delegates to AnotherErrorType's AsApiErrorTrait
139/// }
140///
141/// // Since MyError has variants with `msg`, `Display` is generated by AsApiError.
142/// // If no variants had `msg`, you would need to implement `Display` manually or with `thiserror`:
143/// //
144/// // #[derive(Debug, AsApiError, thiserror::Error)] // Example with thiserror
145/// // pub enum MyErrorWithoutMacroDisplay {
146/// // #[error("Item {0} was not found")] // thiserror message
147/// // #[api_error(status = "NotFound")]
148/// // NotFound(String),
149/// //
150/// // #[error("Input is invalid: {reason}")]
151/// // #[api_error(code = 400, kind = "bad_input")]
152/// // InvalidInput { reason: String }
153/// // }
154/// ```
155#[proc_macro_derive(AsApiError, attributes(api_error))]
156pub fn derive(input: TokenStream) -> TokenStream {
157 // Parse the input tokens into a syntax tree
158 let ast = parse_macro_input!(input as DeriveInput);
159 let ident_name = &ast.ident;
160
161 // Get the variants
162 let enum_data = match &ast.data {
163 syn::Data::Enum(data) => data,
164 _ => {
165 return syn::Error::new_spanned(
166 &ast, "AsApiError can only be derived for enums"
167 ).to_compile_error().into();
168 }
169 };
170 let variants_data = &enum_data.variants;
171
172 // Determine if any variant has an explicit 'msg' attribute.
173 // This will decide if a Display impl should be generated by this macro.
174 let mut any_variant_has_explicit_msg = false;
175 for v in variants_data.iter() {
176 match Opts::from_variant(v) {
177 Ok(opts) => {
178 if opts.msg.is_some() {
179 any_variant_has_explicit_msg = true;
180 break;
181 }
182 }
183 Err(e) => return TokenStream::from(e.write_errors()), // Propagate error from Opts parsing
184 }
185 }
186
187 // Generate the match arms for the as_api_error method
188 let match_arms_results: Vec<Result<proc_macro2::TokenStream, syn::Error>> = variants_data.iter().map(|v| {
189 let variant_ident = &v.ident;
190
191 // Determine the pattern for matching fields
192 let field_pats = match &v.fields {
193 syn::Fields::Unnamed(f) => {
194 let idents = f.unnamed.iter().enumerate().map(|(i, _)| format_ident!("a{}", i));
195 quote! { ( #( #idents ),* ) }
196 }
197 syn::Fields::Named(f) => {
198 let idents = f.named.iter().map(|field| field.ident.as_ref().unwrap());
199 quote! { { #( #idents ),* } }
200 }
201 syn::Fields::Unit => quote! {},
202 };
203
204 let opts = match Opts::from_variant(&v) {
205 Ok(opts) => opts,
206 Err(e) => return Err(e.into()),
207 };
208
209 let status_code_val = if let Some(code) = opts.code {
210 code
211 } else if let Some(ref error_kind_str) = opts.status {
212 match error_kind_str.as_str() {
213 "BadRequest" => 400,
214 "Unauthorized" => 401,
215 "Forbidden" => 403,
216 "NotFound" => 404,
217 "MethodNotAllowed" => 405,
218 "Conflict" => 409,
219 "Gone" => 410,
220 "PayloadTooLarge" => 413,
221 "UnsupportedMediaType" => 415,
222 "UnprocessableEntity" => 422,
223 "TooManyRequests" => 429,
224 "InternalServerError" => 500,
225 "NotImplemented" => 501,
226 "BadGateway" => 502,
227 "ServiceUnavailable" => 503,
228 "GatewayTimeout" => 504,
229 _ => {
230 // Handle unknown status string
231 return Err(syn::Error::new_spanned(
232 // Span to where 'status = "..."' would be, or the variant if not directly available
233 v, // Spanning to the variant is a good approximation
234 format!("Invalid status attribute \"{}\" for variant {}. Supported values are: BadRequest, Unauthorized, etc.", error_kind_str, variant_ident),
235 ));
236 }
237 }
238 } else {
239 500 // Default status code
240 };
241
242 // Validate status code
243 if let Err(e) = actix_web::http::StatusCode::from_u16(status_code_val) {
244 return Err(syn::Error::new_spanned(
245 &v.ident,
246 format!("Invalid status code {} for variant {}: {}", status_code_val, variant_ident, e)
247 )); // Removed .into() as to_compile_error is not needed here
248 }
249
250 let kind_str = opts.kind.unwrap_or_else(|| variant_ident.to_string().to_case(Case::Snake));
251
252 // Generate the message expression
253 let message_expr = match opts.msg {
254 Some(ref msg_s) => {
255 match &v.fields {
256 syn::Fields::Unnamed(f) => {
257 // For unnamed fields, format if msg_s contains placeholders and there are fields.
258 // The 'ignore' attribute does not prevent formatting for unnamed fields here.
259 if f.unnamed.is_empty() || !msg_s.contains('{') { // Heuristic: check for presence of '{'
260 quote! { #msg_s.to_owned() } // Treat as literal
261 } else {
262 let field_vars_for_format = f.unnamed.iter().enumerate().map(|(i, _)| format_ident!("a{}", i));
263 quote! { format!(#msg_s, #( #field_vars_for_format ),*) }
264 }
265 }
266 syn::Fields::Named(f) => {
267 // For named fields, format only if 'ignore' is false, msg_s has placeholders, and there are fields.
268 if opts.ignore || f.named.is_empty() || !msg_s.contains('{') { // Heuristic: check for presence of '{'
269 quote! { #msg_s.to_owned() } // Treat as literal
270 } else {
271 let named_field_idents = f.named.iter().map(|field| field.ident.as_ref().unwrap());
272 let format_assignments = named_field_idents.map(|ident| quote! { #ident = #ident }).collect::<Vec<_>>();
273 quote! { format!(#msg_s, #( #format_assignments ),*) }
274 }
275 }
276 syn::Fields::Unit => {
277 // For unit variants, msg_s is always used as a literal string.
278 quote! { #msg_s.to_owned() }
279 }
280 }
281 }
282 None => {
283 // If no `msg` attribute is provided in `api_error`:
284 if any_variant_has_explicit_msg {
285 // If the macro is generating a Display impl for this enum (because some other variant has a msg),
286 // we default to the variant's name to avoid recursion with the macro-generated Display.
287 // This matches test expectations for variants like ErrorEn::MissingMessageVariant.
288 let variant_name_str = variant_ident.to_string();
289 quote! { #variant_name_str.to_owned() }
290 } else {
291 // If the macro is NOT generating a Display impl (no variant has any msg attribute),
292 // we delegate to self.to_string() to allow using an external Display (e.g., from thiserror).
293 // This matches test expectations for enums like ErrorWithThiserrorDisplay.
294 quote! { self.to_string() }
295 }
296 }
297 };
298
299 let mut details_expr = quote! { None };
300
301 // Automatic detection of a field to be used for 'details'.
302 // This logic applies if the variant is not a 'group' error.
303 if !opts.group {
304 match &v.fields {
305 syn::Fields::Named(fields_named) => {
306 for field in &fields_named.named {
307 if let Some(field_ident) = &field.ident {
308 let field_ty = &field.ty;
309 let type_string = quote!(#field_ty).to_string().replace(" ", ""); // Normalize spaces
310
311 if type_string == "Option<serde_json::Value>" || type_string == "std::option::Option<serde_json::Value>" {
312 details_expr = quote! { #field_ident.clone() };
313 break; // Use the first found Option<serde_json::Value> field
314 } else if type_string == "serde_json::Value" {
315 details_expr = quote! { Some(#field_ident.clone()) };
316 break; // Use the first found serde_json::Value field
317 }
318 }
319 }
320 }
321 syn::Fields::Unnamed(fields_unnamed) => {
322 for (i, field) in fields_unnamed.unnamed.iter().enumerate() {
323 let field_ty = &field.ty;
324 let field_pat_ident = format_ident!("a{}", i); // Field pattern is a0, a1, etc.
325 let type_string = quote!(#field_ty).to_string().replace(" ", ""); // Normalize spaces
326
327 if type_string == "Option<serde_json::Value>" || type_string == "std::option::Option<serde_json::Value>" {
328 details_expr = quote! { #field_pat_ident.clone() };
329 break; // Use the first found Option<serde_json::Value> field
330 } else if type_string == "serde_json::Value" {
331 details_expr = quote! { Some(#field_pat_ident.clone()) };
332 break; // Use the first found serde_json::Value field
333 }
334 }
335 }
336 syn::Fields::Unit => {
337 // Unit variants cannot have details fields.
338 }
339 }
340 }
341
342 // Generate the ApiError construction call
343 let api_error_call = if opts.group {
344 // Assumes the first field of a tuple variant is 'a0' if 'group' is true
345 let group_var = format_ident!("a0");
346 quote! { #group_var.as_api_error() }
347 } else {
348 quote! { ApiError::new(#status_code_val, #kind_str, #message_expr, #details_expr) }
349 };
350
351 // If fields are destructured by field_pats but not necessarily used directly in api_error_call
352 // (e.g. if message comes from self.to_string() or variant_name),
353 // this dummy assignment helps to silence "unused variable" warnings.
354 let dummy_field_usage = match (opts.msg.is_none(), &v.fields) {
355 (true, syn::Fields::Unnamed(f)) if !f.unnamed.is_empty() && !opts.group => {
356 let idents = f.unnamed.iter().enumerate().map(|(i, _)| format_ident!("a{}", i));
357 quote! { let _ = (#( #idents ),*); }
358 }
359 (true, syn::Fields::Named(f)) if !f.named.is_empty() && !opts.group => {
360 let idents = f.named.iter().map(|field| field.ident.as_ref().unwrap());
361 quote! { let _ = (#( #idents ),*); }
362 }
363 _ => quote! {}, // No dummy usage needed if msg is Some, or it's a unit variant, or a group error
364 };
365
366 Ok(quote! {
367 #ident_name::#variant_ident #field_pats => {
368 #dummy_field_usage
369 #api_error_call
370 }
371 })
372 }).collect();
373
374 // Handle any errors that occurred during match arm generation
375 let mut compiled_match_arms = Vec::new();
376 for result in match_arms_results {
377 match result {
378 Ok(ts) => compiled_match_arms.push(ts),
379 Err(e) => return TokenStream::from(e.to_compile_error()),
380 }
381 }
382
383 // Conditionally generate Display implementation for the enum.
384 // It's generated if any variant has an explicit 'msg' attribute.
385 // Otherwise, the user is expected to provide Display (e.g., via thiserror).
386 let display_impl_block = if any_variant_has_explicit_msg {
387 quote! {
388 impl std::fmt::Display for #ident_name {
389 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390 // The message for display should be consistent with ApiError's message.
391 // This message is constructed within the as_api_error method for each variant,
392 // which itself might call self.to_string() if a variant has no 'msg' attribute.
393 write!(f, "{}", self.as_api_error().message)
394 }
395 }
396 }
397 } else {
398 quote! {} // Empty if no variant has an explicit 'msg' attribute.
399 };
400
401 // Generate the final implementations
402 let expanded = quote! {
403 impl AsApiErrorTrait for #ident_name {
404 fn as_api_error(&self) -> ApiError {
405 match self {
406 #(#compiled_match_arms)*
407 }
408 }
409 }
410
411 #display_impl_block // Include Display impl only if any_variant_has_explicit_msg is true
412
413 // The user is expected to provide Debug, e.g., via #[derive(Debug)]
414 // No Debug impl generated by this macro.
415
416 impl actix_web::ResponseError for #ident_name {
417 fn status_code(&self) -> actix_web::http::StatusCode {
418 // Delegate to the status_code method of the ApiError generated from this enum variant.
419 self.as_api_error().status_code()
420 }
421
422 fn error_response(&self) -> actix_web::HttpResponse {
423 // Delegate to the error_response method of the ApiError generated from this enum variant.
424 // This will ensure the ApiError struct (with kind, message, details) is serialized.
425 self.as_api_error().error_response()
426 }
427 }
428 };
429
430 TokenStream::from(expanded)
431}