Skip to main content

alien_error_derive/
lib.rs

1//! Procedural macro that powers `alien-error`.
2//!
3//! Usage:
4//! ```rust
5//! use alien_error::AlienErrorData;
6//! #[derive(Debug, AlienErrorData)]
7//! enum MyError {
8//!     #[error(code = "SOMETHING_WRONG", message = "something went wrong", retryable = "false", internal = "false", http_status_code = 420)]
9//!     Oops,
10//! }
11//! ```
12//! The `error(...)` attribute supplies compile-time metadata:
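//! • `code`              – short machine-friendly identifier (defaults to the variant name).
//! • `message`           – human-readable error message; for variants with named fields,
//!                         `{field}` placeholders are filled in (defaults to the code).
//! • `retryable`         – required: `"true"` if the operation can be retried, `"false"`,
//!                         or `"inherit"` (then `retryable_inherit()` returns `None`).
//! • `internal`          – required: `"true"` if this error should not be exposed, `"false"`,
//!                         or `"inherit"` (then `internal_inherit()` returns `None`).
//! • `http_status_code`  – HTTP status code as an integer or string literal, or `"inherit"`
//!                         (defaults to 500).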
18//!
19//! The macro also auto-implements `AlienErrorData` including a `context()` method
20//! that serialises variant fields into a JSON map for diagnostic payloads.
21
22use proc_macro::TokenStream;
23use quote::quote;
24use syn::{parse_macro_input, Attribute, Data, DeriveInput};
25
26#[proc_macro_derive(AlienErrorData, attributes(error))]
27pub fn derive_alien_error(input: TokenStream) -> TokenStream {
28    let input = parse_macro_input!(input as DeriveInput);
29    let name = input.ident;
30
31    let (
32        code_match_arms,
33        retryable_match_arms,
34        internal_match_arms,
35        http_status_code_match_arms,
36        context_match_arms,
37        message_match_arms,
38        retryable_inherit_match_arms,
39        internal_inherit_match_arms,
40        http_status_code_inherit_match_arms,
41    ) = match input.data {
42        Data::Enum(ref data_enum) => {
43            let mut code_arms = Vec::new();
44            let mut retryable_arms = Vec::new();
45            let mut internal_arms = Vec::new();
46            let mut http_status_code_arms = Vec::new();
47            let mut context_arms = Vec::new();
48            let mut message_arms = Vec::new();
49            let mut retryable_inherit_arms = Vec::new();
50            let mut internal_inherit_arms = Vec::new();
51            let mut http_status_code_inherit_arms = Vec::new();
52
53            for variant in &data_enum.variants {
54                let ident = &variant.ident;
55
56                let (
57                    code_val,
58                    retryable_val,
59                    internal_val,
60                    http_status_code_val,
61                    message_val,
62                    retryable_inherit,
63                    internal_inherit,
64                    http_status_code_inherit,
65                ) = parse_error_attrs(&variant.attrs, ident.to_string());
66
67                let matcher = if variant.fields.is_empty() {
68                    quote! { #name::#ident }
69                } else {
70                    quote! { #name::#ident { .. } }
71                };
72
73                let code_lit = code_val;
74                let retry_bool = retryable_val;
75                let internal_bool = internal_val;
76                let http_status_code_u16 = http_status_code_val;
77
78                code_arms.push(quote! { #matcher => #code_lit });
79                retryable_arms.push(quote! { #matcher => #retry_bool });
80                internal_arms.push(quote! { #matcher => #internal_bool });
81                http_status_code_arms.push(quote! { #matcher => #http_status_code_u16 });
82                retryable_inherit_arms.push(quote! { #matcher => #retryable_inherit });
83                internal_inherit_arms.push(quote! { #matcher => #internal_inherit });
84                http_status_code_inherit_arms
85                    .push(quote! { #matcher => #http_status_code_inherit });
86
87                // Generate message arm with field interpolation
88                match &variant.fields {
89                    syn::Fields::Named(fields_named) if !fields_named.named.is_empty() => {
90                        // SAFETY: Named fields are guaranteed to have identifiers.
91                        // The Option exists only for compatibility with tuple struct fields.
92                        let field_idents: Vec<_> = fields_named
93                            .named
94                            .iter()
95                            .map(|f| f.ident.as_ref().unwrap())
96                            .collect();
97                        let matcher = quote! { #name::#ident { #( ref #field_idents ),* } };
98
99                        // Generate message with field interpolation
100                        let interpolated_message =
101                            generate_message_interpolation(&message_val, &field_idents);
102                        message_arms.push(quote! { #matcher => #interpolated_message });
103
104                        context_arms.push(quote! { #matcher => {
105                            let mut map = serde_json::Map::new();
106                            #( map.insert(
107                                stringify!(#field_idents).to_string(), 
108                                serde_json::to_value(#field_idents)
109                                    .expect(&format!("Failed to serialize field '{}' to JSON. This field must implement Serialize correctly.", stringify!(#field_idents)))
110                            ); )*
111                            Some(serde_json::Value::Object(map))
112                        } });
113                    }
114                    _ => {
115                        let matcher = if variant.fields.is_empty() {
116                            quote! { #name::#ident }
117                        } else {
118                            quote! { #name::#ident { .. } }
119                        };
120                        message_arms.push(quote! { #matcher => #message_val.to_string() });
121                        context_arms.push(quote! { #matcher => None });
122                    }
123                }
124            }
125            (
126                code_arms,
127                retryable_arms,
128                internal_arms,
129                http_status_code_arms,
130                context_arms,
131                message_arms,
132                retryable_inherit_arms,
133                internal_inherit_arms,
134                http_status_code_inherit_arms,
135            )
136        }
137        _ => {
138            return quote! { compile_error!("AlienErrorData can only be derived for enums"); }
139                .into();
140        }
141    };
142
143    let expanded = quote! {
144        impl alien_error::AlienErrorData for #name {
145            fn code(&self) -> &'static str {
146                match self {
147                    #(#code_match_arms),*
148                }
149            }
150            fn retryable(&self) -> bool {
151                match self {
152                    #(#retryable_match_arms),*
153                }
154            }
155            fn internal(&self) -> bool {
156                match self {
157                    #(#internal_match_arms),*
158                }
159            }
160            fn http_status_code(&self) -> u16 {
161                match self {
162                    #(#http_status_code_match_arms),*
163                }
164            }
165            fn message(&self) -> String {
166                match self {
167                    #(#message_match_arms),*
168                }
169            }
170            fn context(&self) -> Option<serde_json::Value> {
171                match self {
172                    #(#context_match_arms),*
173                }
174            }
175            fn retryable_inherit(&self) -> Option<bool> {
176                match self {
177                    #(#retryable_inherit_match_arms),*
178                }
179            }
180            fn internal_inherit(&self) -> Option<bool> {
181                match self {
182                    #(#internal_inherit_match_arms),*
183                }
184            }
185            fn http_status_code_inherit(&self) -> Option<u16> {
186                match self {
187                    #(#http_status_code_inherit_match_arms),*
188                }
189            }
190        }
191
192        impl std::fmt::Display for #name {
193            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194                write!(f, "{}", self.message())
195            }
196        }
197    };
198
199    TokenStream::from(expanded)
200}
201
202fn parse_error_attrs(
203    attrs: &[Attribute],
204    default_code: String,
205) -> (
206    proc_macro2::TokenStream,
207    proc_macro2::TokenStream,
208    proc_macro2::TokenStream,
209    proc_macro2::TokenStream,
210    String,
211    proc_macro2::TokenStream,
212    proc_macro2::TokenStream,
213    proc_macro2::TokenStream,
214) {
215    let mut code = default_code;
216    let mut retryable: Option<String> = None;
217    let mut internal: Option<String> = None;
218    let mut http_status_code: Option<String> = None;
219    let mut message: Option<String> = None;
220
221    for attr in attrs {
222        if !attr.path().is_ident("error") {
223            continue;
224        }
225        if let Err(e) = attr.parse_nested_meta(|meta| {
226            if meta.path.is_ident("code") {
227                let lit: syn::LitStr = meta.value()?.parse()?;
228                code = lit.value();
229                Ok(())
230            } else if meta.path.is_ident("retryable") {
231                let lit: syn::LitStr = meta.value()?.parse()?;
232                retryable = Some(lit.value());
233                Ok(())
234            } else if meta.path.is_ident("internal") {
235                let lit: syn::LitStr = meta.value()?.parse()?;
236                internal = Some(lit.value());
237                Ok(())
238            } else if meta.path.is_ident("http_status_code") {
239                // Parse the value as a literal (either string or int)
240                let value = meta.value()?;
241
242                // Try to parse as a literal expression
243                let lit: syn::Lit = value.parse()?;
244
245                match lit {
246                    syn::Lit::Str(lit_str) => {
247                        // String literal like "inherit" or "404"
248                        http_status_code = Some(lit_str.value());
249                    }
250                    syn::Lit::Int(lit_int) => {
251                        // Integer literal like 404
252                        let parsed_value = lit_int.base10_parse::<u16>()?;
253                        http_status_code = Some(parsed_value.to_string());
254                    }
255                    _ => {
256                        return Err(
257                            meta.error("http_status_code must be a string or integer literal")
258                        );
259                    }
260                }
261                Ok(())
262            } else if meta.path.is_ident("message") {
263                let lit: syn::LitStr = meta.value()?.parse()?;
264                message = Some(lit.value());
265                Ok(())
266            } else {
267                Err(meta.error("unsupported error attribute key"))
268            }
269        }) {
270            // Re-emit the actual syn error instead of a generic message
271            return (
272                syn::Error::new(e.span(), e.to_string()).to_compile_error(),
273                syn::Error::new(e.span(), e.to_string()).to_compile_error(),
274                syn::Error::new(e.span(), e.to_string()).to_compile_error(),
275                syn::Error::new(e.span(), e.to_string()).to_compile_error(),
276                String::new(),
277                syn::Error::new(e.span(), e.to_string()).to_compile_error(),
278                syn::Error::new(e.span(), e.to_string()).to_compile_error(),
279                syn::Error::new(e.span(), e.to_string()).to_compile_error(),
280            );
281        }
282    }
283
284    // ensure all required fields are specified
285    macro_rules! parse_flag {
286        ($val:expr,$name:expr) => {
287            match $val {
288                Some(ref s) if s == "true" => quote! { true },
289                Some(ref s) if s == "false" => quote! { false },
290                Some(ref s) if s == "inherit" => quote! { false }, // For backward compatibility in the main method
291                Some(ref _other) => syn::Error::new(proc_macro2::Span::call_site(), format!("{} must be \"true\", \"false\" or \"inherit\"", $name)).to_compile_error(),
292                None => syn::Error::new(proc_macro2::Span::call_site(), format!("{}=\"...\" is required in #[error(...)]", $name)).to_compile_error(),
293            }
294        };
295    }
296
297    // Parse inheritance flags from the same values
298    macro_rules! parse_inherit_flag {
299        ($val:expr) => {
300            match $val {
301                Some(ref s) if s == "inherit" => quote! { None },
302                Some(ref s) if s == "true" => quote! { Some(true) },
303                Some(ref s) if s == "false" => quote! { Some(false) },
304                Some(_) => quote! { Some(false) }, // fallback for any other value
305                None => syn::Error::new(proc_macro2::Span::call_site(), "flag is required")
306                    .to_compile_error(),
307            }
308        };
309    }
310
311    let retry_ts = parse_flag!(retryable.clone(), "retryable");
312    let internal_ts = parse_flag!(internal.clone(), "internal");
313    let retryable_inherit_ts = parse_inherit_flag!(retryable);
314    let internal_inherit_ts = parse_inherit_flag!(internal);
315
316    let code_ts = {
317        let lit = syn::LitStr::new(&code, proc_macro2::Span::call_site());
318        quote! { #lit }
319    };
320
321    // Parse HTTP status code with inheritance support
322    let (http_status_code_ts, http_status_code_inherit_ts) = match http_status_code {
323        Some(ref s) if s == "inherit" => {
324            // When inherit is specified, use 500 as the default but return None for inherit
325            (quote! { 500 }, quote! { None })
326        }
327        Some(ref s) => {
328            // Parse as number
329            match s.parse::<u16>() {
330                Ok(status_code) => (quote! { #status_code }, quote! { Some(#status_code) }),
331                Err(_) => (
332                    syn::Error::new(
333                        proc_macro2::Span::call_site(),
334                        "http_status_code must be a number or \"inherit\"",
335                    )
336                    .to_compile_error(),
337                    syn::Error::new(
338                        proc_macro2::Span::call_site(),
339                        "http_status_code must be a number or \"inherit\"",
340                    )
341                    .to_compile_error(),
342                ),
343            }
344        }
345        None => {
346            // Default to 500
347            (quote! { 500 }, quote! { Some(500) })
348        }
349    };
350
351    let message_str = message.unwrap_or_else(|| code.clone());
352
353    (
354        code_ts,
355        retry_ts,
356        internal_ts,
357        http_status_code_ts,
358        message_str,
359        retryable_inherit_ts,
360        internal_inherit_ts,
361        http_status_code_inherit_ts,
362    )
363}
364
365fn generate_message_interpolation(
366    message_template: &str,
367    field_idents: &[&syn::Ident],
368) -> proc_macro2::TokenStream {
369    // Let Rust's format! macro handle the parsing - just pass the template and fields directly
370    // This leverages Rust's built-in format string parsing which handles all cases correctly
371
372    if field_idents.is_empty() {
373        quote! { #message_template.to_string() }
374    } else {
375        // Find which fields are actually used in the message template
376        let used_fields: Vec<&syn::Ident> = field_idents
377            .iter()
378            .filter(|field| {
379                let field_name = field.to_string();
380                message_template.contains(&format!("{{{}", field_name))
381            })
382            .cloned()
383            .collect();
384
385        if used_fields.is_empty() {
386            // No fields are referenced in the template
387            quote! { #message_template.to_string() }
388        } else {
389            // Use named parameters - only pass fields that are actually used
390            quote! { format!(#message_template, #(#used_fields = #used_fields),*) }
391        }
392    }
393}