flipperzero_test_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{quote, ToTokens};
4use syn::{
5    parse::{self, Parse},
6    punctuated::Punctuated,
7    spanned::Spanned,
8    token, Expr, ExprArray, Ident, Item, ItemMod, ReturnType, Stmt, Token,
9};
10
11mod deassert;
12
13struct TestRunner {
14    manifest_args: Punctuated<TestRunnerArg, Token![,]>,
15    test_suites: ExprArray,
16}
17
18impl Parse for TestRunner {
19    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
20        let mut manifest_args = Punctuated::new();
21        if !input.peek(token::Bracket) {
22            loop {
23                if input.is_empty() || input.peek(token::Bracket) {
24                    break;
25                }
26                let value = input.parse()?;
27                manifest_args.push_value(value);
28                if input.is_empty() || input.peek(token::Bracket) {
29                    break;
30                }
31                let punct = input.parse()?;
32                manifest_args.push_punct(punct);
33            }
34        };
35
36        let test_suites = input.parse()?;
37
38        Ok(TestRunner {
39            manifest_args,
40            test_suites,
41        })
42    }
43}
44
45struct TestRunnerArg {
46    ident: Ident,
47    eq_token: Token![=],
48    value: Box<Expr>,
49}
50
51impl Parse for TestRunnerArg {
52    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
53        let ident = input.parse()?;
54        let eq_token = input.parse()?;
55        let value = input.parse()?;
56        Ok(TestRunnerArg {
57            ident,
58            eq_token,
59            value,
60        })
61    }
62}
63
64#[proc_macro]
65pub fn tests_runner(args: TokenStream) -> TokenStream {
66    match tests_runner_impl(args) {
67        Ok(ts) => ts,
68        Err(e) => e.to_compile_error().into(),
69    }
70}
71
72fn tests_runner_impl(args: TokenStream) -> parse::Result<TokenStream> {
73    let TestRunner {
74        manifest_args,
75        test_suites,
76    } = syn::parse(args)?;
77
78    let test_suites = test_suites
79        .elems
80        .into_iter()
81        .map(|attr| {
82            let mut module = String::new();
83            for token in attr.to_token_stream() {
84                module.push_str(&token.to_string());
85            }
86            let module = module.trim_start_matches("crate::");
87
88            (
89                quote!(#attr::__test_list().len()),
90                quote!(#attr::__test_list().iter().copied().map(|(name, test_fn)| (#module, name, test_fn))),
91            )
92        })
93        .collect::<Vec<_>>();
94
95    let test_counts = test_suites.iter().map(|(count, _)| count);
96    let test_lists = test_suites.iter().map(|(_, list)| list);
97
98    let manifest_args = manifest_args.into_iter().map(
99        |TestRunnerArg {
100             ident,
101             eq_token,
102             value,
103         }| { quote!(#ident #eq_token #value) },
104    );
105
106    Ok(quote!(
107        #[cfg(test)]
108        mod __test_runner {
109            // Required for panic handler
110            extern crate flipperzero_rt;
111
112            // Required for allocator
113            #[cfg(feature = "alloc")]
114            extern crate flipperzero_alloc;
115
116            use flipperzero_rt::{entry, manifest};
117
118            manifest!(#(#manifest_args),*);
119            entry!(main);
120
121            const fn test_count() -> usize {
122                let ret = 0;
123                #( let ret = ret + #test_counts; )*
124                ret
125            }
126
127            fn test_list() -> impl Iterator<Item = (&'static str, &'static str, ::flipperzero_test::TestFn)> + Clone {
128                let ret = ::core::iter::empty();
129                #( let ret = ret.chain(#test_lists); )*
130                ret
131            }
132
133            // Test runner entry point
134            pub(super) fn main(args: Option<&::core::ffi::CStr>) -> i32 {
135                let args = ::flipperzero_test::__macro_support::Args::parse(args);
136                match ::flipperzero_test::__macro_support::run_tests(test_count(), test_list(), args) {
137                    Ok(()) => 0,
138                    Err(e) => e,
139                }
140            }
141        }
142
143        #[cfg(all(test, miri))]
144        #[start]
145        fn main(argc: isize, argv: *const *const u8) -> isize {
146            // TODO: Is there any benefit to Miri in hooking up the binary arguments to
147            // the test runner?
148            let ret = __test_runner::main(None);
149
150            // Clean up app state.
151            ::flipperzero_rt::__macro_support::__wait_for_thread_completion();
152
153            ret.try_into().unwrap_or(isize::MAX)
154        }
155    )
156    .into())
157}
158
159#[proc_macro_attribute]
160pub fn tests(args: TokenStream, input: TokenStream) -> TokenStream {
161    match tests_impl(args, input) {
162        Ok(ts) => ts,
163        Err(e) => e.to_compile_error().into(),
164    }
165}
166
167fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStream> {
168    if !args.is_empty() {
169        return Err(parse::Error::new(
170            Span::call_site(),
171            "`#[tests]` attribute takes no arguments",
172        ));
173    }
174
175    let module: ItemMod = syn::parse(input)?;
176
177    let items = if let Some(content) = module.content {
178        content.1
179    } else {
180        return Err(parse::Error::new(
181            module.span(),
182            "module must be inline (e.g. `mod foo {}`)",
183        ));
184    };
185
186    let mut tests = vec![];
187    let mut test_cfgs = vec![];
188    let mut untouched_tokens = vec![];
189    for item in items {
190        match item {
191            Item::Fn(mut f) => {
192                let mut is_test = false;
193                let mut cfg = vec![];
194
195                // Find and extract the `#[test]` and `#[cfg(..)] attributes, if present.
196                f.attrs.retain(|attr| {
197                    if attr.path.is_ident("test") {
198                        is_test = true;
199                        false
200                    } else {
201                        if attr.path.is_ident("cfg") {
202                            cfg.push(attr.clone());
203                        }
204                        true
205                    }
206                });
207
208                if is_test {
209                    // Enforce expected function signature.
210                    if !f.sig.inputs.is_empty() {
211                        return Err(parse::Error::new(
212                            f.sig.inputs.span(),
213                            "`#[test]` function must have signature `fn()`",
214                        ));
215                    }
216                    if !matches!(f.sig.output, ReturnType::Default) {
217                        return Err(parse::Error::new(
218                            f.sig.output.span(),
219                            "`#[test]` function must have signature `fn()`",
220                        ));
221                    }
222
223                    // Add a `TestResult` return type.
224                    f.sig.output = syn::parse(quote!(-> ::flipperzero_test::TestResult).into())?;
225
226                    // Replace `assert` macros in the test with `TestResult` functions.
227                    f.block = deassert::box_block(f.block)?;
228
229                    // Enforce that the test doesn't return anything. This is somewhat
230                    // redundant with the function signature check above, as the compiler
231                    // will enforce that no value is returned by the unmodified function.
232                    // However, in certain cases this results in better errors, due to us
233                    // appending an `Ok(())` that can interfere with the previous expression.
234                    check_ret_block(&mut f.block.stmts)?;
235
236                    // Append an `Ok(())` to the test.
237                    f.block.stmts.push(Stmt::Expr(syn::parse(
238                        quote!(::core::result::Result::Ok(())).into(),
239                    )?));
240
241                    tests.push(f);
242                    test_cfgs.push(cfg);
243                } else {
244                    untouched_tokens.push(Item::Fn(f));
245                }
246            }
247            _ => {
248                untouched_tokens.push(item);
249            }
250        }
251    }
252
253    let ident = module.ident;
254    let test_names = tests.iter().zip(test_cfgs).map(|(test, cfg)| {
255        let ident = &test.sig.ident;
256        let name = ident.to_string();
257        quote! {
258            #(#cfg)*
259            (#name, #ident)
260        }
261    });
262
263    Ok(quote!(
264        #[cfg(test)]
265        pub(crate) mod #ident {
266            #(#untouched_tokens)*
267
268            #(#tests)*
269
270            pub(crate) const fn __test_list() -> &'static [(&'static str, ::flipperzero_test::TestFn)] {
271                &[#(#test_names), *]
272            }
273        }
274    )
275    .into())
276}
277
278fn check_ret_block(stmts: &mut [Stmt]) -> parse::Result<()> {
279    if let Some(stmt) = stmts.last_mut() {
280        if let Stmt::Expr(expr) = stmt {
281            if let Some(new_stmt) = check_ret_expr(expr)? {
282                *stmt = new_stmt;
283            }
284        }
285    }
286    Ok(())
287}
288
289fn check_ret_expr(expr: &mut Expr) -> parse::Result<Option<Stmt>> {
290    match expr {
291        // If `expr` is a block that implicitly returns `()`, do nothing.
292        Expr::ForLoop(_) | Expr::While(_) => Ok(None),
293        // If `expr` is a block, recurse into it.
294        Expr::Async(e) => check_ret_block(&mut e.block.stmts).map(|()| None),
295        Expr::Block(e) => check_ret_block(&mut e.block.stmts).map(|()| None),
296        Expr::If(e) => {
297            // Checking the first branch is sufficient; the compiler will enforce that the
298            // other branches match.
299            check_ret_block(&mut e.then_branch.stmts).map(|()| None)
300        }
301        Expr::Loop(e) => check_ret_block(&mut e.body.stmts).map(|()| None),
302        Expr::Match(e) => {
303            if let Some(arm) = e.arms.first_mut() {
304                if let Some(stmt) = check_ret_expr(&mut arm.body)? {
305                    *arm.body = Expr::Block(syn::parse(quote!({#stmt}).into())?);
306                }
307            }
308            Ok(None)
309        }
310        Expr::TryBlock(e) => check_ret_block(&mut e.block.stmts).map(|()| None),
311        Expr::Unsafe(e) => check_ret_block(&mut e.block.stmts).map(|()| None),
312        // If `expr` implicitly returns `()`, append a semicolon.
313        Expr::Assign(_) | Expr::AssignOp(_) => {
314            Ok(Some(Stmt::Semi(expr.clone(), Token!(;)(expr.span()))))
315        }
316        Expr::Break(brk) if brk.expr.is_none() => {
317            Ok(Some(Stmt::Semi(expr.clone(), Token!(;)(expr.span()))))
318        }
319        // For all other expressions, raise an error.
320        _ => Err(parse::Error::new(
321            expr.span(),
322            "`#[test]` function must not return anything",
323        )),
324    }
325}