burn_tensor_testgen/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3
4use syn::parse::{Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::token::Comma;
7use syn::{Attribute, Expr, ItemFn, Lit, Meta, MetaNameValue, parse_macro_input};
8
9// Define a structure to parse the attribute arguments
10struct AttributeArgs {
11    args: Punctuated<Meta, Comma>,
12}
13
14impl Parse for AttributeArgs {
15    fn parse(input: ParseStream) -> syn::Result<Self> {
16        Ok(AttributeArgs {
17            args: Punctuated::parse_terminated(input)?,
18        })
19    }
20}
21
22#[allow(clippy::test_attr_in_doctest)]
23/// **This is only meaningful when the `reason` is specific and clear.**
24///
25/// A proc macro attribute that adds panic handling to test functions.
26///
27/// # Usage
28/// ```rust, ignore
29/// #[might_panic(reason = "expected panic message prefix")]
30/// #[test]
31/// fn test_that_might_panic() {
32///     // test code that might panic (with acceptable reason)
33/// }
34/// ```
35///
36/// # Behavior
37/// - If the test does not panic, it passes.
38/// - If the test panics with a message starting with the expected prefix, the failure is ignored.
39/// - If the test panics with a different message, the test fails.
40///
41/// # Note
42/// This proc macro uses [`std::panic::catch_unwind`]. As such, it does not work in a no-std environment.
43/// Make sure it is feature gated when an `"std"` feature is available.
44#[proc_macro_attribute]
45pub fn might_panic(args: TokenStream, input: TokenStream) -> TokenStream {
46    // Parse the attribute arguments
47    let args = parse_macro_input!(args as AttributeArgs);
48    let input_fn = parse_macro_input!(input as ItemFn);
49
50    // Extract the expected panic reason
51    let mut expected_reason = None;
52    for arg in args.args.iter() {
53        if let Meta::NameValue(MetaNameValue { path, value, .. }) = arg
54            && path.is_ident("reason")
55            && let Expr::Lit(lit) = value
56            && let Lit::Str(ref lit_str) = lit.lit
57        {
58            expected_reason = Some(lit_str.value());
59        }
60    }
61
62    let expected_reason = match expected_reason {
63        Some(reason) => reason,
64        None => {
65            return syn::Error::new(
66                proc_macro2::Span::call_site(),
67                "The #[might_panic] attribute requires a 'reason' parameter",
68            )
69            .to_compile_error()
70            .into();
71        }
72    };
73
74    let fn_name = &input_fn.sig.ident;
75    let fn_vis = &input_fn.vis;
76    let fn_generics = &input_fn.sig.generics;
77    let fn_block = &input_fn.block;
78    let fn_attrs = input_fn
79        .attrs
80        .iter()
81        .filter(|attr| !attr.path().is_ident("test"))
82        .collect::<Vec<&Attribute>>();
83
84    // Create a wrapped test function
85    let wrapper_name = format_ident!("{}_might_panic", fn_name);
86
87    let expanded = quote! {
88        #(#fn_attrs)*
89        #fn_vis fn #fn_name #fn_generics() {
90            #fn_block
91        }
92
93        #[test]
94        #fn_vis fn #wrapper_name #fn_generics() {
95            use std::panic::{self, AssertUnwindSafe};
96
97            let expected_reason = #expected_reason;
98            let result = panic::catch_unwind(AssertUnwindSafe(|| {
99                #fn_name();
100            }));
101
102            match result {
103                Ok(_) => {
104                    // Test passed without panic - this is OK
105                }
106                Err(e) => {
107                    // Convert the panic payload to a string
108                    let panic_msg = if let Some(s) = e.downcast_ref::<String>() {
109                        s.to_string()
110                    } else if let Some(s) = e.downcast_ref::<&str>() {
111                        s.to_string()
112                    } else {
113                        "Unknown panic".to_string()
114                    };
115
116                    // Check if the panic message starts with the expected reason
117                    if !panic_msg.starts_with(expected_reason) {
118                        panic!(
119                            "Test '{}' marked as 'might_panic' failed. Expected reason: '{}'",
120                            stringify!(#fn_name),
121                            expected_reason
122                        );
123                    }
124                }
125            }
126        }
127    };
128
129    expanded.into()
130}
131
132#[allow(missing_docs)]
133#[proc_macro_attribute]
134pub fn testgen(attr: TokenStream, item: TokenStream) -> TokenStream {
135    let item: proc_macro2::TokenStream = proc_macro2::TokenStream::from(item);
136    let attr: proc_macro2::TokenStream = proc_macro2::TokenStream::from(attr);
137    let macro_ident = format_ident!("testgen_{}", attr.to_string());
138
139    let macro_gen = quote! {
140        #[allow(missing_docs)]
141        #[macro_export]
142        macro_rules! #macro_ident {
143            () => {
144                mod #attr {
145                    use super::*;
146
147                    #item
148                }
149            };
150        }
151    };
152
153    macro_gen.into()
154}