Skip to main content

alien_error_derive/
lib.rs

//! Procedural macro that powers `alien-error`.
//!
//! Usage:
//! ```rust
//! use alien_error::AlienErrorData;
//! #[derive(Debug, AlienErrorData)]
//! enum MyError {
//!     #[error(code = "SOMETHING_WRONG", message = "something went wrong", retryable = "false", internal = "false", http_status_code = 420)]
//!     Oops,
//! }
//! ```
//! The `error(...)` attribute supplies compile-time metadata:
//! • `code`              – short machine-friendly identifier (defaults to the variant name).
//! • `message`           – human-readable error message with field interpolation (defaults to `code`).
//! • `hint`              – optional hint text, interpolated the same way as `message`.
//! • `retryable`         – required string: `"true"` if the operation can be retried, `"false"`, or `"inherit"`.
//! • `internal`          – required string: `"true"` if this error should not be exposed, `"false"`, or `"inherit"`.
//! • `http_status_code`  – HTTP status code for this error, as an integer or string, or `"inherit"` (defaults to 500).
//! • `human`             – optional presentation mode: `"normal"` (the default) or `"transparent"`.
//!
//! The macro implements `AlienErrorData`, including a `context()` method that
//! serialises named variant fields into a JSON map for diagnostic payloads, and
//! `Display`, which prints the result of `message()`.
21
22use proc_macro::TokenStream;
23use quote::quote;
24use syn::{parse_macro_input, Attribute, Data, DeriveInput};
25
26#[proc_macro_derive(AlienErrorData, attributes(error))]
27pub fn derive_alien_error(input: TokenStream) -> TokenStream {
28    let input = parse_macro_input!(input as DeriveInput);
29    let name = input.ident;
30
31    let (
32        code_match_arms,
33        retryable_match_arms,
34        internal_match_arms,
35        http_status_code_match_arms,
36        context_match_arms,
37        message_match_arms,
38        hint_match_arms,
39        retryable_inherit_match_arms,
40        internal_inherit_match_arms,
41        http_status_code_inherit_match_arms,
42        human_layer_presentation_match_arms,
43    ) = match input.data {
44        Data::Enum(ref data_enum) => {
45            let mut code_arms = Vec::new();
46            let mut retryable_arms = Vec::new();
47            let mut internal_arms = Vec::new();
48            let mut http_status_code_arms = Vec::new();
49            let mut context_arms = Vec::new();
50            let mut message_arms = Vec::new();
51            let mut hint_arms = Vec::new();
52            let mut retryable_inherit_arms = Vec::new();
53            let mut internal_inherit_arms = Vec::new();
54            let mut http_status_code_inherit_arms = Vec::new();
55            let mut human_layer_presentation_arms = Vec::new();
56
57            for variant in &data_enum.variants {
58                let ident = &variant.ident;
59
60                let (
61                    code_val,
62                    retryable_val,
63                    internal_val,
64                    http_status_code_val,
65                    message_val,
66                    hint_val,
67                    retryable_inherit,
68                    internal_inherit,
69                    http_status_code_inherit,
70                    human_layer_presentation,
71                ) = parse_error_attrs(&variant.attrs, ident.to_string());
72
73                let matcher = if variant.fields.is_empty() {
74                    quote! { #name::#ident }
75                } else {
76                    quote! { #name::#ident { .. } }
77                };
78
79                let code_lit = code_val;
80                let retry_bool = retryable_val;
81                let internal_bool = internal_val;
82                let http_status_code_u16 = http_status_code_val;
83
84                code_arms.push(quote! { #matcher => #code_lit });
85                retryable_arms.push(quote! { #matcher => #retry_bool });
86                internal_arms.push(quote! { #matcher => #internal_bool });
87                http_status_code_arms.push(quote! { #matcher => #http_status_code_u16 });
88                retryable_inherit_arms.push(quote! { #matcher => #retryable_inherit });
89                internal_inherit_arms.push(quote! { #matcher => #internal_inherit });
90                http_status_code_inherit_arms
91                    .push(quote! { #matcher => #http_status_code_inherit });
92                human_layer_presentation_arms
93                    .push(quote! { #matcher => #human_layer_presentation });
94
95                // Generate message arm with field interpolation
96                match &variant.fields {
97                    syn::Fields::Named(fields_named) if !fields_named.named.is_empty() => {
98                        // SAFETY: Named fields are guaranteed to have identifiers.
99                        // The Option exists only for compatibility with tuple struct fields.
100                        let field_idents: Vec<_> = fields_named
101                            .named
102                            .iter()
103                            .map(|f| f.ident.as_ref().unwrap())
104                            .collect();
105                        let matcher = quote! { #name::#ident { #( ref #field_idents ),* } };
106
107                        // Generate message with field interpolation
108                        let interpolated_message =
109                            generate_message_interpolation(&message_val, &field_idents);
110                        message_arms.push(quote! { #matcher => #interpolated_message });
111                        let interpolated_hint =
112                            generate_optional_interpolation(&hint_val, &field_idents);
113                        hint_arms.push(quote! { #matcher => #interpolated_hint });
114
115                        context_arms.push(quote! { #matcher => {
116                            let mut map = serde_json::Map::new();
117                            #( map.insert(
118                                stringify!(#field_idents).to_string(),
119                                serde_json::to_value(#field_idents)
120                                    .expect(&format!("Failed to serialize field '{}' to JSON. This field must implement Serialize correctly.", stringify!(#field_idents)))
121                            ); )*
122                            Some(serde_json::Value::Object(map))
123                        } });
124                    }
125                    _ => {
126                        let matcher = if variant.fields.is_empty() {
127                            quote! { #name::#ident }
128                        } else {
129                            quote! { #name::#ident { .. } }
130                        };
131                        message_arms.push(quote! { #matcher => #message_val.to_string() });
132                        let interpolated_hint = generate_optional_interpolation(&hint_val, &[]);
133                        hint_arms.push(quote! { #matcher => #interpolated_hint });
134                        context_arms.push(quote! { #matcher => None });
135                    }
136                }
137            }
138            (
139                code_arms,
140                retryable_arms,
141                internal_arms,
142                http_status_code_arms,
143                context_arms,
144                message_arms,
145                hint_arms,
146                retryable_inherit_arms,
147                internal_inherit_arms,
148                http_status_code_inherit_arms,
149                human_layer_presentation_arms,
150            )
151        }
152        _ => {
153            return quote! { compile_error!("AlienErrorData can only be derived for enums"); }
154                .into();
155        }
156    };
157
158    let expanded = quote! {
159        impl alien_error::AlienErrorData for #name {
160            fn code(&self) -> &'static str {
161                match self {
162                    #(#code_match_arms),*
163                }
164            }
165            fn retryable(&self) -> bool {
166                match self {
167                    #(#retryable_match_arms),*
168                }
169            }
170            fn internal(&self) -> bool {
171                match self {
172                    #(#internal_match_arms),*
173                }
174            }
175            fn http_status_code(&self) -> u16 {
176                match self {
177                    #(#http_status_code_match_arms),*
178                }
179            }
180            fn message(&self) -> String {
181                match self {
182                    #(#message_match_arms),*
183                }
184            }
185            fn hint(&self) -> Option<String> {
186                match self {
187                    #(#hint_match_arms),*
188                }
189            }
190            fn context(&self) -> Option<serde_json::Value> {
191                match self {
192                    #(#context_match_arms),*
193                }
194            }
195            fn retryable_inherit(&self) -> Option<bool> {
196                match self {
197                    #(#retryable_inherit_match_arms),*
198                }
199            }
200            fn internal_inherit(&self) -> Option<bool> {
201                match self {
202                    #(#internal_inherit_match_arms),*
203                }
204            }
205            fn http_status_code_inherit(&self) -> Option<u16> {
206                match self {
207                    #(#http_status_code_inherit_match_arms),*
208                }
209            }
210            fn human_layer_presentation(&self) -> alien_error::HumanLayerPresentation {
211                match self {
212                    #(#human_layer_presentation_match_arms),*
213                }
214            }
215        }
216
217        impl std::fmt::Display for #name {
218            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219                write!(f, "{}", self.message())
220            }
221        }
222    };
223
224    TokenStream::from(expanded)
225}
226
227fn parse_error_attrs(
228    attrs: &[Attribute],
229    default_code: String,
230) -> (
231    proc_macro2::TokenStream,
232    proc_macro2::TokenStream,
233    proc_macro2::TokenStream,
234    proc_macro2::TokenStream,
235    String,
236    Option<String>,
237    proc_macro2::TokenStream,
238    proc_macro2::TokenStream,
239    proc_macro2::TokenStream,
240    proc_macro2::TokenStream,
241) {
242    let mut code = default_code;
243    let mut retryable: Option<String> = None;
244    let mut internal: Option<String> = None;
245    let mut http_status_code: Option<String> = None;
246    let mut message: Option<String> = None;
247    let mut hint: Option<String> = None;
248    let mut human: Option<String> = None;
249
250    for attr in attrs {
251        if !attr.path().is_ident("error") {
252            continue;
253        }
254        if let Err(e) = attr.parse_nested_meta(|meta| {
255            if meta.path.is_ident("code") {
256                let lit: syn::LitStr = meta.value()?.parse()?;
257                code = lit.value();
258                Ok(())
259            } else if meta.path.is_ident("retryable") {
260                let lit: syn::LitStr = meta.value()?.parse()?;
261                retryable = Some(lit.value());
262                Ok(())
263            } else if meta.path.is_ident("internal") {
264                let lit: syn::LitStr = meta.value()?.parse()?;
265                internal = Some(lit.value());
266                Ok(())
267            } else if meta.path.is_ident("http_status_code") {
268                // Parse the value as a literal (either string or int)
269                let value = meta.value()?;
270
271                // Try to parse as a literal expression
272                let lit: syn::Lit = value.parse()?;
273
274                match lit {
275                    syn::Lit::Str(lit_str) => {
276                        // String literal like "inherit" or "404"
277                        http_status_code = Some(lit_str.value());
278                    }
279                    syn::Lit::Int(lit_int) => {
280                        // Integer literal like 404
281                        let parsed_value = lit_int.base10_parse::<u16>()?;
282                        http_status_code = Some(parsed_value.to_string());
283                    }
284                    _ => {
285                        return Err(
286                            meta.error("http_status_code must be a string or integer literal")
287                        );
288                    }
289                }
290                Ok(())
291            } else if meta.path.is_ident("message") {
292                let lit: syn::LitStr = meta.value()?.parse()?;
293                message = Some(lit.value());
294                Ok(())
295            } else if meta.path.is_ident("hint") {
296                let lit: syn::LitStr = meta.value()?.parse()?;
297                hint = Some(lit.value());
298                Ok(())
299            } else if meta.path.is_ident("human") {
300                let lit: syn::LitStr = meta.value()?.parse()?;
301                human = Some(lit.value());
302                Ok(())
303            } else {
304                Err(meta.error("unsupported error attribute key"))
305            }
306        }) {
307            // Re-emit the actual syn error instead of a generic message
308            return (
309                syn::Error::new(e.span(), e.to_string()).to_compile_error(),
310                syn::Error::new(e.span(), e.to_string()).to_compile_error(),
311                syn::Error::new(e.span(), e.to_string()).to_compile_error(),
312                syn::Error::new(e.span(), e.to_string()).to_compile_error(),
313                String::new(),
314                None,
315                syn::Error::new(e.span(), e.to_string()).to_compile_error(),
316                syn::Error::new(e.span(), e.to_string()).to_compile_error(),
317                syn::Error::new(e.span(), e.to_string()).to_compile_error(),
318                syn::Error::new(e.span(), e.to_string()).to_compile_error(),
319            );
320        }
321    }
322
323    // ensure all required fields are specified
324    macro_rules! parse_flag {
325        ($val:expr,$name:expr) => {
326            match $val {
327                Some(ref s) if s == "true" => quote! { true },
328                Some(ref s) if s == "false" => quote! { false },
329                Some(ref s) if s == "inherit" => quote! { false }, // For backward compatibility in the main method
330                Some(ref _other) => syn::Error::new(proc_macro2::Span::call_site(), format!("{} must be \"true\", \"false\" or \"inherit\"", $name)).to_compile_error(),
331                None => syn::Error::new(proc_macro2::Span::call_site(), format!("{}=\"...\" is required in #[error(...)]", $name)).to_compile_error(),
332            }
333        };
334    }
335
336    // Parse inheritance flags from the same values
337    macro_rules! parse_inherit_flag {
338        ($val:expr) => {
339            match $val {
340                Some(ref s) if s == "inherit" => quote! { None },
341                Some(ref s) if s == "true" => quote! { Some(true) },
342                Some(ref s) if s == "false" => quote! { Some(false) },
343                Some(_) => quote! { Some(false) }, // fallback for any other value
344                None => syn::Error::new(proc_macro2::Span::call_site(), "flag is required")
345                    .to_compile_error(),
346            }
347        };
348    }
349
350    let retry_ts = parse_flag!(retryable.clone(), "retryable");
351    let internal_ts = parse_flag!(internal.clone(), "internal");
352    let retryable_inherit_ts = parse_inherit_flag!(retryable);
353    let internal_inherit_ts = parse_inherit_flag!(internal);
354
355    let code_ts = {
356        let lit = syn::LitStr::new(&code, proc_macro2::Span::call_site());
357        quote! { #lit }
358    };
359
360    // Parse HTTP status code with inheritance support
361    let (http_status_code_ts, http_status_code_inherit_ts) = match http_status_code {
362        Some(ref s) if s == "inherit" => {
363            // When inherit is specified, use 500 as the default but return None for inherit
364            (quote! { 500 }, quote! { None })
365        }
366        Some(ref s) => {
367            // Parse as number
368            match s.parse::<u16>() {
369                Ok(status_code) => (quote! { #status_code }, quote! { Some(#status_code) }),
370                Err(_) => (
371                    syn::Error::new(
372                        proc_macro2::Span::call_site(),
373                        "http_status_code must be a number or \"inherit\"",
374                    )
375                    .to_compile_error(),
376                    syn::Error::new(
377                        proc_macro2::Span::call_site(),
378                        "http_status_code must be a number or \"inherit\"",
379                    )
380                    .to_compile_error(),
381                ),
382            }
383        }
384        None => {
385            // Default to 500
386            (quote! { 500 }, quote! { Some(500) })
387        }
388    };
389
390    let message_str = message.unwrap_or_else(|| code.clone());
391    let human_ts = match human.as_deref() {
392        None | Some("normal") => quote! { alien_error::HumanLayerPresentation::Normal },
393        Some("transparent") => quote! { alien_error::HumanLayerPresentation::Transparent },
394        Some(other) => syn::Error::new(
395            proc_macro2::Span::call_site(),
396            format!(
397                "human must be \"normal\" or \"transparent\", got \"{}\"",
398                other
399            ),
400        )
401        .to_compile_error(),
402    };
403
404    (
405        code_ts,
406        retry_ts,
407        internal_ts,
408        http_status_code_ts,
409        message_str,
410        hint,
411        retryable_inherit_ts,
412        internal_inherit_ts,
413        http_status_code_inherit_ts,
414        human_ts,
415    )
416}
417
418fn generate_message_interpolation(
419    message_template: &str,
420    field_idents: &[&syn::Ident],
421) -> proc_macro2::TokenStream {
422    // Let Rust's format! macro handle the parsing - just pass the template and fields directly
423    // This leverages Rust's built-in format string parsing which handles all cases correctly
424
425    if field_idents.is_empty() {
426        quote! { #message_template.to_string() }
427    } else {
428        // Find which fields are actually used in the message template
429        let used_fields: Vec<&syn::Ident> = field_idents
430            .iter()
431            .filter(|field| {
432                let field_name = field.to_string();
433                message_template.contains(&format!("{{{}", field_name))
434            })
435            .cloned()
436            .collect();
437
438        if used_fields.is_empty() {
439            // No fields are referenced in the template
440            quote! { #message_template.to_string() }
441        } else {
442            // Use named parameters - only pass fields that are actually used
443            quote! { format!(#message_template, #(#used_fields = #used_fields),*) }
444        }
445    }
446}
447
448fn generate_optional_interpolation(
449    template: &Option<String>,
450    field_idents: &[&syn::Ident],
451) -> proc_macro2::TokenStream {
452    match template {
453        Some(template) => {
454            let interpolated = generate_message_interpolation(template, field_idents);
455            quote! { Some(#interpolated) }
456        }
457        None => quote! { None },
458    }
459}