burn_tensor_testgen/
lib.rs1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3
4use syn::parse::{Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::token::Comma;
7use syn::{Attribute, Expr, ItemFn, Lit, Meta, MetaNameValue, parse_macro_input};
8
9struct AttributeArgs {
11 args: Punctuated<Meta, Comma>,
12}
13
14impl Parse for AttributeArgs {
15 fn parse(input: ParseStream) -> syn::Result<Self> {
16 Ok(AttributeArgs {
17 args: Punctuated::parse_terminated(input)?,
18 })
19 }
20}
21
22#[allow(clippy::test_attr_in_doctest)]
23#[proc_macro_attribute]
45pub fn might_panic(args: TokenStream, input: TokenStream) -> TokenStream {
46 let args = parse_macro_input!(args as AttributeArgs);
48 let input_fn = parse_macro_input!(input as ItemFn);
49
50 let mut expected_reason = None;
52 for arg in args.args.iter() {
53 if let Meta::NameValue(MetaNameValue { path, value, .. }) = arg
54 && path.is_ident("reason")
55 && let Expr::Lit(lit) = value
56 && let Lit::Str(ref lit_str) = lit.lit
57 {
58 expected_reason = Some(lit_str.value());
59 }
60 }
61
62 let expected_reason = match expected_reason {
63 Some(reason) => reason,
64 None => {
65 return syn::Error::new(
66 proc_macro2::Span::call_site(),
67 "The #[might_panic] attribute requires a 'reason' parameter",
68 )
69 .to_compile_error()
70 .into();
71 }
72 };
73
74 let fn_name = &input_fn.sig.ident;
75 let fn_vis = &input_fn.vis;
76 let fn_generics = &input_fn.sig.generics;
77 let fn_block = &input_fn.block;
78 let fn_attrs = input_fn
79 .attrs
80 .iter()
81 .filter(|attr| !attr.path().is_ident("test"))
82 .collect::<Vec<&Attribute>>();
83
84 let wrapper_name = format_ident!("{}_might_panic", fn_name);
86
87 let expanded = quote! {
88 #(#fn_attrs)*
89 #fn_vis fn #fn_name #fn_generics() {
90 #fn_block
91 }
92
93 #[test]
94 #fn_vis fn #wrapper_name #fn_generics() {
95 use std::panic::{self, AssertUnwindSafe};
96
97 let expected_reason = #expected_reason;
98 let result = panic::catch_unwind(AssertUnwindSafe(|| {
99 #fn_name();
100 }));
101
102 match result {
103 Ok(_) => {
104 }
106 Err(e) => {
107 let panic_msg = if let Some(s) = e.downcast_ref::<String>() {
109 s.to_string()
110 } else if let Some(s) = e.downcast_ref::<&str>() {
111 s.to_string()
112 } else {
113 "Unknown panic".to_string()
114 };
115
116 if !panic_msg.starts_with(expected_reason) {
118 panic!(
119 "Test '{}' marked as 'might_panic' failed. Expected reason: '{}'",
120 stringify!(#fn_name),
121 expected_reason
122 );
123 }
124 }
125 }
126 }
127 };
128
129 expanded.into()
130}
131
132#[allow(missing_docs)]
133#[proc_macro_attribute]
134pub fn testgen(attr: TokenStream, item: TokenStream) -> TokenStream {
135 let item: proc_macro2::TokenStream = proc_macro2::TokenStream::from(item);
136 let attr: proc_macro2::TokenStream = proc_macro2::TokenStream::from(attr);
137 let macro_ident = format_ident!("testgen_{}", attr.to_string());
138
139 let macro_gen = quote! {
140 #[allow(missing_docs)]
141 #[macro_export]
142 macro_rules! #macro_ident {
143 () => {
144 mod #attr {
145 use super::*;
146
147 #item
148 }
149 };
150 }
151 };
152
153 macro_gen.into()
154}