use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::{Attribute, Expr, ItemFn, Lit, Meta, MetaNameValue, parse_macro_input};
/// Parsed argument list of `#[might_panic(...)]`.
///
/// Holds the comma-separated attribute metas (for example `reason = "..."`)
/// exactly as they appeared between the parentheses.
struct AttributeArgs {
    /// The individual metas, in source order.
    args: Punctuated<Meta, Comma>,
}
impl Parse for AttributeArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
Ok(AttributeArgs {
args: Punctuated::parse_terminated(input)?,
})
}
}
#[allow(clippy::test_attr_in_doctest)]
#[proc_macro_attribute]
pub fn might_panic(args: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(args as AttributeArgs);
let input_fn = parse_macro_input!(input as ItemFn);
let mut expected_reason = None;
for arg in args.args.iter() {
if let Meta::NameValue(MetaNameValue { path, value, .. }) = arg
&& path.is_ident("reason")
&& let Expr::Lit(lit) = value
&& let Lit::Str(ref lit_str) = lit.lit
{
expected_reason = Some(lit_str.value());
}
}
let expected_reason = match expected_reason {
Some(reason) => reason,
None => {
return syn::Error::new(
proc_macro2::Span::call_site(),
"The #[might_panic] attribute requires a 'reason' parameter",
)
.to_compile_error()
.into();
}
};
let fn_name = &input_fn.sig.ident;
let fn_vis = &input_fn.vis;
let fn_generics = &input_fn.sig.generics;
let fn_block = &input_fn.block;
let fn_attrs = input_fn
.attrs
.iter()
.filter(|attr| !attr.path().is_ident("test"))
.collect::<Vec<&Attribute>>();
let wrapper_name = format_ident!("{}_might_panic", fn_name);
quote! {
#(#fn_attrs)*
#fn_vis fn #fn_name #fn_generics() { #fn_block }
#[test]
#fn_vis fn #wrapper_name #fn_generics() {
use std::panic::{self, AssertUnwindSafe};
use std::sync::{Arc, Mutex, OnceLock};
let get_msg = |p: &(dyn std::any::Any + Send)| -> String {
p.downcast_ref::<String>().cloned()
.or_else(|| p.downcast_ref::<&str>().map(|s| s.to_string()))
.unwrap_or_else(|| "Unknown panic".to_string())
};
static PANIC_LOG: OnceLock<Mutex<Vec<String>>> = OnceLock::new();
let log = PANIC_LOG.get_or_init(|| Mutex::new(Vec::new()));
static HOOK: OnceLock<()> = OnceLock::new();
HOOK.get_or_init(|| {
let prev = panic::take_hook();
panic::set_hook(Box::new(move |info| {
if let Ok(mut v) = log.lock() {
v.push(get_msg(info.payload()));
}
prev(info);
}));
});
let start_idx = log.lock().unwrap().len();
let result = panic::catch_unwind(AssertUnwindSafe(|| #fn_name()));
if let Err(e) = result {
let main_msg = get_msg(&*e);
let panic_logs = log.lock().unwrap();
let window = &panic_logs[start_idx..];
let matched = window.iter().chain(std::iter::once(&main_msg))
.any(|m| m.contains(#expected_reason));
if !matched {
let all = window.iter().chain(std::iter::once(&main_msg))
.map(|m| format!("- {m}")).collect::<Vec<_>>().join("\n");
panic!("\nTest '{}' failed.\nExpected: '{}'\nFound:\n{}\n",
stringify!(#fn_name), #expected_reason, all);
} else {
let all = window.iter().chain(std::iter::once(&main_msg))
.map(|m| format!("- {m}")).collect::<Vec<_>>().join("\n");
println!("\nTest '{}' failed.\nExpected: '{}'\nFound:\n{}\n",
stringify!(#fn_name), #expected_reason, all);
}
}
}
}
.into()
}