// error_forge_derive/lib.rs

1extern crate proc_macro;
2use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4use syn::{parse_macro_input, Data, DeriveInput, Fields};
5
6/// Derive macro for ModError
7///
8/// This macro automatically implements the ForgeError trait and common
9/// error handling functionality for a struct or enum, allowing for
10/// "lazy mode" error creation with minimal boilerplate.
11///
12/// # Example
13///
14/// When the macro is used in your application where error-forge is a dependency:
15///
16/// ```ignore
17/// use error_forge::ModError;
18///
19/// #[derive(Debug, ModError)]
20/// #[error_prefix("Database")]
21/// pub enum DbError {
22///     #[error_display("Connection to {0} failed")]
23///     ConnectionFailed(String),
24///
25///     #[error_display("Query execution failed: {reason}")]
26///     QueryFailed { reason: String },
27///
28///     #[error_display("Transaction error")]
29///     #[error_http_status(400)]
30///     TransactionError,
31/// }
32/// ```
33///
34/// Note: This is a procedural macro that is re-exported by the `error-forge` crate.
35/// When using in your application, import it from the main crate with `use error_forge::ModError;`.
36#[proc_macro_derive(
37    ModError,
38    attributes(
39        error_prefix,
40        error_display,
41        error_kind,
42        error_caption,
43        error_retryable,
44        error_http_status,
45        error_exit_code,
46        error_fatal
47    )
48)]
49pub fn derive_mod_error(input: TokenStream) -> TokenStream {
50    // Parse the input
51    let input = parse_macro_input!(input as DeriveInput);
52
53    // Check if this is an enum or struct
54    let is_enum = match &input.data {
55        Data::Enum(_) => true,
56        Data::Struct(_) => false,
57        Data::Union(_) => panic!("ModError cannot be derived for unions"),
58    };
59
60    // Get the error prefix from attributes
61    let error_prefix = get_error_prefix(&input.attrs);
62
63    // Generate implementation based on whether it's an enum or struct
64    let implementation = if is_enum {
65        implement_for_enum(&input, &error_prefix)
66    } else {
67        implement_for_struct(&input, &error_prefix)
68    };
69
70    // Return the generated implementation
71    TokenStream::from(implementation)
72}
73
74// Extract error_prefix attribute value
75fn get_error_prefix(attrs: &[syn::Attribute]) -> String {
76    for attr in attrs {
77        if attr.path.is_ident("error_prefix") {
78            if let Some(value) = parse_string_attribute(attr) {
79                return value;
80            }
81        }
82    }
83    String::new()
84}
85
86fn parse_string_attribute(attr: &syn::Attribute) -> Option<String> {
87    match attr.parse_meta().ok()? {
88        syn::Meta::NameValue(meta) => match meta.lit {
89            syn::Lit::Str(lit) => Some(lit.value()),
90            _ => None,
91        },
92        syn::Meta::List(meta) => match meta.nested.iter().next() {
93            Some(syn::NestedMeta::Lit(syn::Lit::Str(lit))) => Some(lit.value()),
94            _ => None,
95        },
96        syn::Meta::Path(_) => None,
97    }
98}
99
100fn parse_int_attribute<T>(attr: &syn::Attribute) -> Option<T>
101where
102    T: std::str::FromStr,
103    T::Err: std::fmt::Display,
104{
105    match attr.parse_meta().ok()? {
106        syn::Meta::NameValue(meta) => match meta.lit {
107            syn::Lit::Int(lit) => lit.base10_parse().ok(),
108            _ => None,
109        },
110        syn::Meta::List(meta) => match meta.nested.iter().next() {
111            Some(syn::NestedMeta::Lit(syn::Lit::Int(lit))) => lit.base10_parse().ok(),
112            _ => None,
113        },
114        syn::Meta::Path(_) => None,
115    }
116}
117
118fn has_flag_attribute(attr: &syn::Attribute, name: &str) -> bool {
119    attr.path.is_ident(name)
120}
121
122// Implement ModError for an enum
123fn implement_for_enum(input: &DeriveInput, error_prefix: &str) -> proc_macro2::TokenStream {
124    let name = &input.ident;
125    let data_enum = match &input.data {
126        Data::Enum(data) => data,
127        _ => panic!("Expected enum"),
128    };
129
130    // Generate match arms for each variant
131    let mut kind_match_arms = Vec::new();
132    let mut caption_match_arms = Vec::new();
133    let mut display_match_arms = Vec::new();
134    let mut retryable_match_arms = Vec::new();
135    let mut fatal_match_arms = Vec::new();
136    let mut status_code_match_arms = Vec::new();
137    let mut exit_code_match_arms = Vec::new();
138
139    // Process each variant
140    for variant in &data_enum.variants {
141        let variant_name = &variant.ident;
142        let variant_name_str = variant_name.to_string();
143
144        // Default values
145        let mut display_format = variant_name_str.clone();
146        let mut kind_name = variant_name_str.clone();
147        let mut caption = format!("{}: Error", error_prefix);
148        let mut retryable = false;
149        let mut fatal = false;
150        let mut status_code: u16 = 500;
151        let mut exit_code: i32 = 1;
152
153        // Extract attributes
154        for attr in &variant.attrs {
155            if attr.path.is_ident("error_display") {
156                if let Some(value) = parse_string_attribute(attr) {
157                    display_format = value;
158                }
159            } else if attr.path.is_ident("error_kind") {
160                if let Some(value) = parse_string_attribute(attr) {
161                    kind_name = value;
162                }
163            } else if attr.path.is_ident("error_caption") {
164                if let Some(value) = parse_string_attribute(attr) {
165                    caption = value;
166                }
167            } else if attr.path.is_ident("error_retryable") {
168                retryable = true;
169            } else if has_flag_attribute(attr, "error_fatal") {
170                fatal = true;
171            } else if attr.path.is_ident("error_http_status") {
172                if let Some(value) = parse_int_attribute(attr) {
173                    status_code = value;
174                }
175            } else if attr.path.is_ident("error_exit_code") {
176                if let Some(value) = parse_int_attribute(attr) {
177                    exit_code = value;
178                }
179            }
180        }
181
182        // Generate pattern matching based on the variant's fields
183        match &variant.fields {
184            Fields::Named(fields) => {
185                let field_names: Vec<_> = fields
186                    .named
187                    .iter()
188                    .map(|f| f.ident.as_ref().unwrap())
189                    .collect();
190
191                // Format string handled directly in match arm
192
193                kind_match_arms.push(quote! {
194                    Self::#variant_name { .. } => #kind_name
195                });
196
197                caption_match_arms.push(quote! {
198                    Self::#variant_name { .. } => #caption
199                });
200
201                display_match_arms.push(quote! {
202                    Self::#variant_name { #(#field_names),* } => format!(#display_format, #(#field_names = #field_names),*)
203                });
204
205                retryable_match_arms.push(quote! {
206                    Self::#variant_name { .. } => #retryable
207                });
208
209                fatal_match_arms.push(quote! {
210                    Self::#variant_name { .. } => #fatal
211                });
212
213                status_code_match_arms.push(quote! {
214                    Self::#variant_name { .. } => #status_code
215                });
216
217                exit_code_match_arms.push(quote! {
218                    Self::#variant_name { .. } => #exit_code
219                });
220            }
221            Fields::Unnamed(fields) => {
222                let field_count = fields.unnamed.len();
223                let field_names: Vec<_> =
224                    (0..field_count).map(|i| format_ident!("_{}", i)).collect();
225
226                // Generate display format with tuple fields
227                kind_match_arms.push(quote! {
228                    Self::#variant_name(..) => #kind_name
229                });
230
231                caption_match_arms.push(quote! {
232                    Self::#variant_name(..) => #caption
233                });
234
235                let field_pattern_list = field_names.iter().map(|name| quote! { #name, });
236                display_match_arms.push(quote! {
237                    Self::#variant_name(#(#field_pattern_list)*) => format!(#display_format #(, #field_names)*)
238                });
239
240                retryable_match_arms.push(quote! {
241                    Self::#variant_name(..) => #retryable
242                });
243
244                fatal_match_arms.push(quote! {
245                    Self::#variant_name(..) => #fatal
246                });
247
248                status_code_match_arms.push(quote! {
249                    Self::#variant_name(..) => #status_code
250                });
251
252                exit_code_match_arms.push(quote! {
253                    Self::#variant_name(..) => #exit_code
254                });
255            }
256            Fields::Unit => {
257                // Unit variant (no fields)
258                kind_match_arms.push(quote! {
259                    Self::#variant_name => #kind_name
260                });
261
262                caption_match_arms.push(quote! {
263                    Self::#variant_name => #caption
264                });
265
266                display_match_arms.push(quote! {
267                    Self::#variant_name => #display_format.to_string()
268                });
269
270                retryable_match_arms.push(quote! {
271                    Self::#variant_name => #retryable
272                });
273
274                fatal_match_arms.push(quote! {
275                    Self::#variant_name => #fatal
276                });
277
278                status_code_match_arms.push(quote! {
279                    Self::#variant_name => #status_code
280                });
281
282                exit_code_match_arms.push(quote! {
283                    Self::#variant_name => #exit_code
284                });
285            }
286        }
287    }
288
289    // Generate implementation
290    quote! {
291        impl ::std::fmt::Display for #name {
292            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
293                let msg = match self {
294                    #(#display_match_arms,)*
295                };
296                write!(f, "{}", msg)
297            }
298        }
299
300        impl ::error_forge::error::ForgeError for #name {
301            fn kind(&self) -> &'static str {
302                match self {
303                    #(#kind_match_arms,)*
304                }
305            }
306
307            fn caption(&self) -> &'static str {
308                match self {
309                    #(#caption_match_arms,)*
310                }
311            }
312
313            fn is_retryable(&self) -> bool {
314                match self {
315                    #(#retryable_match_arms,)*
316                }
317            }
318
319            fn is_fatal(&self) -> bool {
320                match self {
321                    #(#fatal_match_arms,)*
322                }
323            }
324
325            fn status_code(&self) -> u16 {
326                match self {
327                    #(#status_code_match_arms,)*
328                }
329            }
330
331            fn exit_code(&self) -> i32 {
332                match self {
333                    #(#exit_code_match_arms,)*
334                }
335            }
336        }
337
338        impl ::std::error::Error for #name {
339            fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
340                None
341            }
342        }
343    }
344}
345
346// Implement ModError for a struct
347fn implement_for_struct(input: &DeriveInput, error_prefix: &str) -> proc_macro2::TokenStream {
348    let name = &input.ident;
349    let name_str = name.to_string();
350
351    quote! {
352        impl ::std::fmt::Display for #name {
353            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
354                write!(f, "{}: Error", #error_prefix)
355            }
356        }
357
358        impl ::error_forge::error::ForgeError for #name {
359            fn kind(&self) -> &'static str {
360                #name_str
361            }
362
363            fn caption(&self) -> &'static str {
364                concat!(#error_prefix, ": Error")
365            }
366        }
367
368        impl ::std::error::Error for #name {
369            fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
370                None
371            }
372        }
373    }
374}
375
// Note: the generated code formats enum variant messages directly inside the `Display` match arms rather than calling a shared helper function.