flipperzero_test_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{ToTokens, quote};
4use syn::{
5    Expr, ExprArray, Ident, Item, ItemMod, ReturnType, Stmt, Token,
6    parse::{self, Parse},
7    punctuated::Punctuated,
8    spanned::Spanned,
9    token,
10};
11
12mod deassert;
13
14struct TestRunner {
15    manifest_args: Punctuated<TestRunnerArg, Token![,]>,
16    test_suites: ExprArray,
17}
18
19impl Parse for TestRunner {
20    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
21        let mut manifest_args = Punctuated::new();
22        if !input.peek(token::Bracket) {
23            loop {
24                if input.is_empty() || input.peek(token::Bracket) {
25                    break;
26                }
27                let value = input.parse()?;
28                manifest_args.push_value(value);
29                if input.is_empty() || input.peek(token::Bracket) {
30                    break;
31                }
32                let punct = input.parse()?;
33                manifest_args.push_punct(punct);
34            }
35        };
36
37        let test_suites = input.parse()?;
38
39        Ok(TestRunner {
40            manifest_args,
41            test_suites,
42        })
43    }
44}
45
46struct TestRunnerArg {
47    ident: Ident,
48    eq_token: Token![=],
49    value: Box<Expr>,
50}
51
52impl Parse for TestRunnerArg {
53    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
54        let ident = input.parse()?;
55        let eq_token = input.parse()?;
56        let value = input.parse()?;
57        Ok(TestRunnerArg {
58            ident,
59            eq_token,
60            value,
61        })
62    }
63}
64
65#[proc_macro]
66pub fn tests_runner(args: TokenStream) -> TokenStream {
67    match tests_runner_impl(args) {
68        Ok(ts) => ts,
69        Err(e) => e.to_compile_error().into(),
70    }
71}
72
73fn tests_runner_impl(args: TokenStream) -> parse::Result<TokenStream> {
74    let TestRunner {
75        manifest_args,
76        test_suites,
77    } = syn::parse(args)?;
78
79    let test_suites = test_suites
80        .elems
81        .into_iter()
82        .map(|attr| {
83            let mut module = String::new();
84            for token in attr.to_token_stream() {
85                module.push_str(&token.to_string());
86            }
87            let module = module.trim_start_matches("crate::");
88
89            (
90                quote!(#attr::__test_list().len()),
91                quote!(#attr::__test_list().iter().copied().map(|(name, test_fn)| (#module, name, test_fn))),
92            )
93        })
94        .collect::<Vec<_>>();
95
96    let test_counts = test_suites.iter().map(|(count, _)| count);
97    let test_lists = test_suites.iter().map(|(_, list)| list);
98
99    let manifest_args = manifest_args.into_iter().map(
100        |TestRunnerArg {
101             ident,
102             eq_token,
103             value,
104         }| { quote!(#ident #eq_token #value) },
105    );
106
107    Ok(quote!(
108        #[cfg(test)]
109        mod __test_runner {
110            // Required for panic handler
111            extern crate flipperzero_rt;
112
113            // Required for allocator
114            #[cfg(feature = "alloc")]
115            extern crate flipperzero_alloc;
116
117            use flipperzero_rt::{entry, manifest};
118
119            manifest!(#(#manifest_args),*);
120            entry!(main);
121
122            const fn test_count() -> usize {
123                let ret = 0;
124                #( let ret = ret + #test_counts; )*
125                ret
126            }
127
128            fn test_list() -> impl Iterator<Item = (&'static str, &'static str, ::flipperzero_test::TestFn)> + Clone {
129                let ret = ::core::iter::empty();
130                #( let ret = ret.chain(#test_lists); )*
131                ret
132            }
133
134            // Test runner entry point
135            pub(super) fn main(args: Option<&::core::ffi::CStr>) -> i32 {
136                let args = ::flipperzero_test::__macro_support::Args::parse(args);
137                match ::flipperzero_test::__macro_support::run_tests(test_count(), test_list(), args) {
138                    Ok(()) => 0,
139                    Err(e) => e,
140                }
141            }
142        }
143
144        #[cfg(all(test, miri))]
145        #[start]
146        fn main(argc: isize, argv: *const *const u8) -> isize {
147            // TODO: Is there any benefit to Miri in hooking up the binary arguments to
148            // the test runner?
149            let ret = __test_runner::main(None);
150
151            // Clean up app state.
152            ::flipperzero_rt::__macro_support::__wait_for_thread_completion();
153
154            ret.try_into().unwrap_or(isize::MAX)
155        }
156    )
157    .into())
158}
159
160#[proc_macro_attribute]
161pub fn tests(args: TokenStream, input: TokenStream) -> TokenStream {
162    match tests_impl(args, input) {
163        Ok(ts) => ts,
164        Err(e) => e.to_compile_error().into(),
165    }
166}
167
168fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStream> {
169    if !args.is_empty() {
170        return Err(parse::Error::new(
171            Span::call_site(),
172            "`#[tests]` attribute takes no arguments",
173        ));
174    }
175
176    let module: ItemMod = syn::parse(input)?;
177
178    let items = if let Some(content) = module.content {
179        content.1
180    } else {
181        return Err(parse::Error::new(
182            module.span(),
183            "module must be inline (e.g. `mod foo {}`)",
184        ));
185    };
186
187    let mut tests = vec![];
188    let mut test_cfgs = vec![];
189    let mut untouched_tokens = vec![];
190    for item in items {
191        match item {
192            Item::Fn(mut f) => {
193                let mut is_test = false;
194                let mut cfg = vec![];
195
196                // Find and extract the `#[test]` and `#[cfg(..)] attributes, if present.
197                f.attrs.retain(|attr| {
198                    if attr.path.is_ident("test") {
199                        is_test = true;
200                        false
201                    } else {
202                        if attr.path.is_ident("cfg") {
203                            cfg.push(attr.clone());
204                        }
205                        true
206                    }
207                });
208
209                if is_test {
210                    // Enforce expected function signature.
211                    if !f.sig.inputs.is_empty() {
212                        return Err(parse::Error::new(
213                            f.sig.inputs.span(),
214                            "`#[test]` function must have signature `fn()`",
215                        ));
216                    }
217                    if !matches!(f.sig.output, ReturnType::Default) {
218                        return Err(parse::Error::new(
219                            f.sig.output.span(),
220                            "`#[test]` function must have signature `fn()`",
221                        ));
222                    }
223
224                    // Add a `TestResult` return type.
225                    f.sig.output = syn::parse(quote!(-> ::flipperzero_test::TestResult).into())?;
226
227                    // Replace `assert` macros in the test with `TestResult` functions.
228                    f.block = deassert::box_block(f.block)?;
229
230                    // Enforce that the test doesn't return anything. This is somewhat
231                    // redundant with the function signature check above, as the compiler
232                    // will enforce that no value is returned by the unmodified function.
233                    // However, in certain cases this results in better errors, due to us
234                    // appending an `Ok(())` that can interfere with the previous expression.
235                    check_ret_block(&mut f.block.stmts)?;
236
237                    // Append an `Ok(())` to the test.
238                    f.block.stmts.push(Stmt::Expr(syn::parse(
239                        quote!(::core::result::Result::Ok(())).into(),
240                    )?));
241
242                    tests.push(f);
243                    test_cfgs.push(cfg);
244                } else {
245                    untouched_tokens.push(Item::Fn(f));
246                }
247            }
248            _ => {
249                untouched_tokens.push(item);
250            }
251        }
252    }
253
254    let ident = module.ident;
255    let test_names = tests.iter().zip(test_cfgs).map(|(test, cfg)| {
256        let ident = &test.sig.ident;
257        let name = ident.to_string();
258        quote! {
259            #(#cfg)*
260            (#name, #ident)
261        }
262    });
263
264    Ok(quote!(
265        #[cfg(test)]
266        pub(crate) mod #ident {
267            #(#untouched_tokens)*
268
269            #(#tests)*
270
271            pub(crate) const fn __test_list() -> &'static [(&'static str, ::flipperzero_test::TestFn)] {
272                &[#(#test_names), *]
273            }
274        }
275    )
276    .into())
277}
278
279fn check_ret_block(stmts: &mut [Stmt]) -> parse::Result<()> {
280    if let Some(stmt) = stmts.last_mut() {
281        if let Stmt::Expr(expr) = stmt {
282            if let Some(new_stmt) = check_ret_expr(expr)? {
283                *stmt = new_stmt;
284            }
285        }
286    }
287    Ok(())
288}
289
290fn check_ret_expr(expr: &mut Expr) -> parse::Result<Option<Stmt>> {
291    match expr {
292        // If `expr` is a block that implicitly returns `()`, do nothing.
293        Expr::ForLoop(_) | Expr::While(_) => Ok(None),
294        // If `expr` is a block, recurse into it.
295        Expr::Async(e) => check_ret_block(&mut e.block.stmts).map(|()| None),
296        Expr::Block(e) => check_ret_block(&mut e.block.stmts).map(|()| None),
297        Expr::If(e) => {
298            // Checking the first branch is sufficient; the compiler will enforce that the
299            // other branches match.
300            check_ret_block(&mut e.then_branch.stmts).map(|()| None)
301        }
302        Expr::Loop(e) => check_ret_block(&mut e.body.stmts).map(|()| None),
303        Expr::Match(e) => {
304            if let Some(arm) = e.arms.first_mut() {
305                if let Some(stmt) = check_ret_expr(&mut arm.body)? {
306                    *arm.body = Expr::Block(syn::parse(quote!({#stmt}).into())?);
307                }
308            }
309            Ok(None)
310        }
311        Expr::TryBlock(e) => check_ret_block(&mut e.block.stmts).map(|()| None),
312        Expr::Unsafe(e) => check_ret_block(&mut e.block.stmts).map(|()| None),
313        // If `expr` implicitly returns `()`, append a semicolon.
314        Expr::Assign(_) | Expr::AssignOp(_) => {
315            Ok(Some(Stmt::Semi(expr.clone(), Token!(;)(expr.span()))))
316        }
317        Expr::Break(brk) if brk.expr.is_none() => {
318            Ok(Some(Stmt::Semi(expr.clone(), Token!(;)(expr.span()))))
319        }
320        // For all other expressions, raise an error.
321        _ => Err(parse::Error::new(
322            expr.span(),
323            "`#[test]` function must not return anything",
324        )),
325    }
326}