error_forge_derive/
lib.rs

1extern crate proc_macro;
2use proc_macro::TokenStream;
3use quote::{quote, format_ident};
4use syn::{parse_macro_input, DeriveInput, Data, Fields};
5
6/// Derive macro for ModError
7///
8/// This macro automatically implements the ForgeError trait and common
9/// error handling functionality for a struct or enum, allowing for
10/// "lazy mode" error creation with minimal boilerplate.
11///
12/// # Example
13///
14/// ```rust
15/// use error_forge::ModError;
16///
17/// #[derive(Debug, ModError)]
18/// #[error_prefix("Database")]
19/// pub enum DbError {
20///     #[error_display("Connection to {0} failed")]
21///     ConnectionFailed(String),
22///
23///     #[error_display("Query execution failed: {reason}")]
24///     QueryFailed { reason: String },
25///
26///     #[error_display("Transaction error")]
27///     #[error_http_status(400)]
28///     TransactionError,
29/// }
30/// ```
31#[proc_macro_derive(ModError, attributes(error_prefix, error_display, error_kind,
32                                         error_caption, error_retryable, error_http_status,
33                                         error_exit_code))]
34pub fn derive_mod_error(input: TokenStream) -> TokenStream {
35    // Parse the input
36    let input = parse_macro_input!(input as DeriveInput);
37    
38    // Check if this is an enum or struct
39    let is_enum = match &input.data {
40        Data::Enum(_) => true,
41        Data::Struct(_) => false,
42        Data::Union(_) => panic!("ModError cannot be derived for unions"),
43    };
44    
45    // Get the error prefix from attributes
46    let error_prefix = get_error_prefix(&input.attrs);
47    
48    // Generate implementation based on whether it's an enum or struct
49    let implementation = if is_enum {
50        implement_for_enum(&input, &error_prefix)
51    } else {
52        implement_for_struct(&input, &error_prefix)
53    };
54    
55    // Return the generated implementation
56    TokenStream::from(implementation)
57}
58
59// Extract error_prefix attribute value
60fn get_error_prefix(attrs: &[syn::Attribute]) -> String {
61    for attr in attrs {
62        if attr.path.is_ident("error_prefix") {
63            // Try both attribute formats
64            // Format: #[error_prefix = "text"]
65            if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
66                if let syn::Lit::Str(lit) = meta.lit {
67                    return lit.value();
68                }
69            }
70            // Format: #[error_prefix("text")]
71            else if let Ok(syn::Meta::List(meta)) = attr.parse_meta() {
72                if let Some(nested) = meta.nested.iter().next() {
73                    if let syn::NestedMeta::Lit(syn::Lit::Str(lit)) = nested {
74                        return lit.value();
75                    }
76                }
77            }
78        }
79    }
80    String::new()
81}
82
83// Implement ModError for an enum
84fn implement_for_enum(input: &DeriveInput, error_prefix: &str) -> proc_macro2::TokenStream {
85    let name = &input.ident;
86    let data_enum = match &input.data {
87        Data::Enum(data) => data,
88        _ => panic!("Expected enum"),
89    };
90    
91    // Generate match arms for each variant
92    let mut kind_match_arms = Vec::new();
93    let mut caption_match_arms = Vec::new();
94    let mut display_match_arms = Vec::new();
95    let mut retryable_match_arms = Vec::new();
96    let mut status_code_match_arms = Vec::new();
97    let mut exit_code_match_arms = Vec::new();
98    
99    // Process each variant
100    for variant in &data_enum.variants {
101        let variant_name = &variant.ident;
102        let variant_name_str = variant_name.to_string();
103        
104        // Default values
105        let mut display_format = variant_name_str.clone();
106        let mut retryable = false;
107        let mut status_code: u16 = 500;
108        let mut exit_code: i32 = 1;
109        
110        // Extract attributes
111        for attr in &variant.attrs {
112            if attr.path.is_ident("error_display") {
113                if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
114                    if let syn::Lit::Str(lit) = meta.lit {
115                        display_format = lit.value();
116                    }
117                }
118            } else if attr.path.is_ident("error_retryable") {
119                retryable = true;
120            } else if attr.path.is_ident("error_http_status") {
121                if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
122                    if let syn::Lit::Int(lit) = meta.lit {
123                        status_code = lit.base10_parse().unwrap_or(500);
124                    }
125                }
126            } else if attr.path.is_ident("error_exit_code") {
127                if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
128                    if let syn::Lit::Int(lit) = meta.lit {
129                        exit_code = lit.base10_parse().unwrap_or(1);
130                    }
131                }
132            }
133        }
134        
135        // Generate pattern matching based on the variant's fields
136        match &variant.fields {
137            Fields::Named(fields) => {
138                let field_names: Vec<_> = fields.named.iter()
139                    .map(|f| f.ident.as_ref().unwrap())
140                    .collect();
141                
142                // Format string handled directly in match arm
143                
144                kind_match_arms.push(quote! {
145                    Self::#variant_name { .. } => #variant_name_str
146                });
147                
148                caption_match_arms.push(quote! {
149                    Self::#variant_name { .. } => concat!(#error_prefix, ": Error")
150                });
151                
152                let _field_patterns = field_names.iter().map(|name| {
153                    let _name_str = name.to_string();
154                    quote! { #name, }
155                });
156                
157                // For struct variants, create a properly formatted string without using fields
158                display_match_arms.push(quote! {
159                    Self::#variant_name { .. } => format!("{}: {}", #error_prefix, #display_format)
160                });
161                
162                retryable_match_arms.push(quote! {
163                    Self::#variant_name { .. } => #retryable
164                });
165                
166                status_code_match_arms.push(quote! {
167                    Self::#variant_name { .. } => #status_code
168                });
169                
170                exit_code_match_arms.push(quote! {
171                    Self::#variant_name { .. } => #exit_code
172                });
173            },
174            Fields::Unnamed(fields) => {
175                let field_count = fields.unnamed.len();
176                let field_names: Vec<_> = (0..field_count)
177                    .map(|i| format_ident!("_{}", i))
178                    .collect();
179                
180                // Generate display format with tuple fields
181                // field_names is already Vec<Ident> so we can pass it directly
182                // Format string handled directly in match arm
183                
184                kind_match_arms.push(quote! {
185                    Self::#variant_name(..) => #variant_name_str
186                });
187                
188                caption_match_arms.push(quote! {
189                    Self::#variant_name(..) => concat!(#error_prefix, ": Error")
190                });
191                
192                let _field_patterns = field_names.iter().map(|name| {
193                    quote! { #name, }
194                });
195                
196                // For tuple variants, handle simple positional formatting for {0}, {1}, etc.
197                if display_format.contains("{0}") || display_format.contains("{}") {
198                    // Recreate the field pattern list here to avoid conflicts with renamed variables
199                    let field_pattern_list = field_names.iter().map(|name| quote! { #name, });
200                    display_match_arms.push(quote! {
201                        Self::#variant_name(#(#field_pattern_list)*) => format!("{}: {}", #error_prefix, format!(#display_format #(, #field_names)*))
202                    });
203                } else {
204                    // Fall back to simple display if no formatting placeholders
205                    display_match_arms.push(quote! {
206                        Self::#variant_name(..) => format!("{}: {}", #error_prefix, #display_format)
207                    });
208                }
209                
210                retryable_match_arms.push(quote! {
211                    Self::#variant_name(..) => #retryable
212                });
213                
214                status_code_match_arms.push(quote! {
215                    Self::#variant_name(..) => #status_code
216                });
217                
218                exit_code_match_arms.push(quote! {
219                    Self::#variant_name(..) => #exit_code
220                });
221            },
222            Fields::Unit => {
223                // Unit variant (no fields)
224                kind_match_arms.push(quote! {
225                    Self::#variant_name => #variant_name_str
226                });
227                
228                caption_match_arms.push(quote! {
229                    Self::#variant_name => concat!(#error_prefix, ": Error")
230                });
231                
232                display_match_arms.push(quote! {
233                    Self::#variant_name => format!("{}: {}", #error_prefix, #display_format)
234                });
235                
236                retryable_match_arms.push(quote! {
237                    Self::#variant_name => #retryable
238                });
239                
240                status_code_match_arms.push(quote! {
241                    Self::#variant_name => #status_code
242                });
243                
244                exit_code_match_arms.push(quote! {
245                    Self::#variant_name => #exit_code
246                });
247            },
248        }
249    }
250    
251    // Generate implementation
252    quote! {
253        impl ::std::fmt::Display for #name {
254            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
255                let msg = match self {
256                    #(#display_match_arms,)*
257                };
258                write!(f, "{}", msg)
259            }
260        }
261        
262        impl ::error_forge::error::ForgeError for #name {
263            fn kind(&self) -> &'static str {
264                match self {
265                    #(#kind_match_arms,)*
266                }
267            }
268            
269            fn caption(&self) -> &'static str {
270                match self {
271                    #(#caption_match_arms,)*
272                }
273            }
274            
275            fn is_retryable(&self) -> bool {
276                match self {
277                    #(#retryable_match_arms,)*
278                }
279            }
280            
281            fn status_code(&self) -> u16 {
282                match self {
283                    #(#status_code_match_arms,)*
284                }
285            }
286            
287            fn exit_code(&self) -> i32 {
288                match self {
289                    #(#exit_code_match_arms,)*
290                }
291            }
292        }
293        
294        impl ::std::error::Error for #name {
295            fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
296                None
297            }
298        }
299    }
300}
301
302// Implement ModError for a struct
303fn implement_for_struct(input: &DeriveInput, error_prefix: &str) -> proc_macro2::TokenStream {
304    let name = &input.ident;
305    let name_str = name.to_string();
306    
307    quote! {
308        impl ::std::fmt::Display for #name {
309            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
310                write!(f, "{}: Error", #error_prefix)
311            }
312        }
313        
314        impl ::error_forge::error::ForgeError for #name {
315            fn kind(&self) -> &'static str {
316                #name_str
317            }
318            
319            fn caption(&self) -> &'static str {
320                concat!(#error_prefix, ": Error")
321            }
322        }
323        
324        impl ::std::error::Error for #name {
325            fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
326                None
327            }
328        }
329    }
330}
331
332// Note: The implementation now handles formatting directly in the match arms instead of using a helper function