pretty_panic_proc_macro/
lib.rs

1use proc_macro::{self, TokenStream};
2use proc_macro_error::{abort_call_site, proc_macro_error};
3use quote::quote;
4use syn::{parse_macro_input, ExprPath, ItemFn, PathArguments, ReturnType, Type};
5
6#[proc_macro_attribute]
7#[proc_macro_error]
8#[doc(hidden)]
9pub fn pretty_panic(attrs: TokenStream, input: TokenStream) -> TokenStream {
10    // Parse the function
11    let input_fn = parse_macro_input!(input as ItemFn);
12
13    // Check if the function is `main`
14    if input_fn.sig.ident != "main" {
15        abort_call_site!("The `#[pretty]` attribute can only be used on the `main` function.");
16    }
17
18    let mut format_error: Option<ExprPath> = None;
19    let mut format_panic: Option<ExprPath> = None;
20    // Parse the attributes using syn
21    let parser = syn::meta::parser(|meta| {
22        if meta.path.is_ident("formatter") {
23            format_error = Some(meta.value()?.parse()?);
24            Ok(())
25        } else if meta.path.is_ident("panic_formatter") {
26            format_panic = Some(meta.value()?.parse()?);
27            Ok(())
28        } else {
29            Err(meta.error("unsupported pretty_panic property"))
30        }
31    });
32
33    parse_macro_input!(attrs with parser);
34
35    let format_error = if let Some(format_error) = &format_error {
36        quote! { #format_error }
37    } else {
38        if cfg!(feature = "default_formatters") {
39            quote! { pretty_panics::default_formatters::error_formatter }
40        } else {
41            abort_call_site!(
42                "`formatter` not provided and `default_formatters` feature is not enabled."
43            );
44        }
45    };
46
47    let format_panic = if let Some(format_panic) = &format_panic {
48        quote! { eprintln!("{}", #format_panic(panic_hook_info, message.to_string())); }
49    } else {
50        if cfg!(feature = "default_formatters") {
51            quote! { eprintln!("{}", pretty_panics::default_formatters::panic_formatter(panic_hook_info, message.to_string())); }
52        } else {
53            abort_call_site!(
54                "`panic_formatter` not provided and `default_formatters` feature is not enabled."
55            );
56        }
57    };
58
59    let output = match &input_fn.sig.output {
60        ReturnType::Type(_, ty) => {
61            // Check if the return type is `Result<T, E>`
62            if let Type::Path(type_path) = &**ty {
63                let path = &type_path.path;
64
65                // Check if the path is `Result`
66                if let Some(segment) = path.segments.last() {
67                    if segment.ident == "Result" {
68                        // Ensure that the generic arguments are exactly two
69                        if let PathArguments::AngleBracketed(args) = &segment.arguments {
70                            if args.args.len() >= 1 {
71                                let return_type = quote! { #ty };
72                                let body = &input_fn.block;
73                                // Generate a new function with the original content
74                                let new_fn_name = syn::Ident::new(
75                                    &format!("{}_wrapped", input_fn.sig.ident),
76                                    input_fn.sig.ident.span(),
77                                );
78
79                                let gen_new_fn = quote! {
80                                    fn #new_fn_name() -> #return_type {
81                                        // Original function body
82                                        #body
83                                    }
84                                };
85
86                                // Generate the modified main function that calls the new function
87                                let gen_main = quote! {
88                                    fn main() {
89                                        std::panic::set_hook(Box::new(|panic_hook_info| {
90                                            let payload = panic_hook_info.payload();
91                                            if let Some(message) = payload.downcast_ref::<String>() {
92                                                #format_panic
93                                            } else if let Some(message) = payload.downcast_ref::<&str>() {
94                                                #format_panic
95                                            }
96                                        }));
97
98                                        if let Err(e) = #new_fn_name() {
99                                            eprintln!("{}", #format_error(&e));
100                                            std::process::exit(1);
101                                        }
102                                    }
103                                };
104
105                                // Combine the two function definitions
106                                let expanded = quote! {
107                                    #gen_new_fn
108                                    #gen_main
109                                };
110
111                                return expanded.into();
112                            }
113                        }
114                        abort_call_site!("The `main` function must return a `Result<T, E>` type.");
115                    }
116                }
117            }
118
119            quote! {
120                #input_fn
121            }
122            .into()
123        }
124        _ => quote! {
125            #input_fn
126        }
127        .into(),
128    };
129
130    output
131}