pretty_panic_proc_macro/
lib.rs1use proc_macro::{self, TokenStream};
2use proc_macro_error::{abort_call_site, proc_macro_error};
3use quote::quote;
4use syn::{parse_macro_input, ExprPath, ItemFn, PathArguments, ReturnType, Type};
5
6#[proc_macro_attribute]
7#[proc_macro_error]
8#[doc(hidden)]
9pub fn pretty_panic(attrs: TokenStream, input: TokenStream) -> TokenStream {
10 let input_fn = parse_macro_input!(input as ItemFn);
12
13 if input_fn.sig.ident != "main" {
15 abort_call_site!("The `#[pretty]` attribute can only be used on the `main` function.");
16 }
17
18 let mut format_error: Option<ExprPath> = None;
19 let mut format_panic: Option<ExprPath> = None;
20 let parser = syn::meta::parser(|meta| {
22 if meta.path.is_ident("formatter") {
23 format_error = Some(meta.value()?.parse()?);
24 Ok(())
25 } else if meta.path.is_ident("panic_formatter") {
26 format_panic = Some(meta.value()?.parse()?);
27 Ok(())
28 } else {
29 Err(meta.error("unsupported pretty_panic property"))
30 }
31 });
32
33 parse_macro_input!(attrs with parser);
34
35 let format_error = if let Some(format_error) = &format_error {
36 quote! { #format_error }
37 } else {
38 if cfg!(feature = "default_formatters") {
39 quote! { pretty_panics::default_formatters::error_formatter }
40 } else {
41 abort_call_site!(
42 "`formatter` not provided and `default_formatters` feature is not enabled."
43 );
44 }
45 };
46
47 let format_panic = if let Some(format_panic) = &format_panic {
48 quote! { eprintln!("{}", #format_panic(panic_hook_info, message.to_string())); }
49 } else {
50 if cfg!(feature = "default_formatters") {
51 quote! { eprintln!("{}", pretty_panics::default_formatters::panic_formatter(panic_hook_info, message.to_string())); }
52 } else {
53 abort_call_site!(
54 "`panic_formatter` not provided and `default_formatters` feature is not enabled."
55 );
56 }
57 };
58
59 let output = match &input_fn.sig.output {
60 ReturnType::Type(_, ty) => {
61 if let Type::Path(type_path) = &**ty {
63 let path = &type_path.path;
64
65 if let Some(segment) = path.segments.last() {
67 if segment.ident == "Result" {
68 if let PathArguments::AngleBracketed(args) = &segment.arguments {
70 if args.args.len() >= 1 {
71 let return_type = quote! { #ty };
72 let body = &input_fn.block;
73 let new_fn_name = syn::Ident::new(
75 &format!("{}_wrapped", input_fn.sig.ident),
76 input_fn.sig.ident.span(),
77 );
78
79 let gen_new_fn = quote! {
80 fn #new_fn_name() -> #return_type {
81 #body
83 }
84 };
85
86 let gen_main = quote! {
88 fn main() {
89 std::panic::set_hook(Box::new(|panic_hook_info| {
90 let payload = panic_hook_info.payload();
91 if let Some(message) = payload.downcast_ref::<String>() {
92 #format_panic
93 } else if let Some(message) = payload.downcast_ref::<&str>() {
94 #format_panic
95 }
96 }));
97
98 if let Err(e) = #new_fn_name() {
99 eprintln!("{}", #format_error(&e));
100 std::process::exit(1);
101 }
102 }
103 };
104
105 let expanded = quote! {
107 #gen_new_fn
108 #gen_main
109 };
110
111 return expanded.into();
112 }
113 }
114 abort_call_site!("The `main` function must return a `Result<T, E>` type.");
115 }
116 }
117 }
118
119 quote! {
120 #input_fn
121 }
122 .into()
123 }
124 _ => quote! {
125 #input_fn
126 }
127 .into(),
128 };
129
130 output
131}