mos_test_macros/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use proc_macro2::Span;
5use quote::{format_ident, quote, quote_spanned};
6use syn::{parse, spanned::Spanned, Attribute, Item, ItemFn, ItemMod, ReturnType, Type};
7
8#[proc_macro_attribute]
9pub fn tests(args: TokenStream, input: TokenStream) -> TokenStream {
10    match tests_impl(args, input) {
11        Ok(ts) => ts,
12        Err(e) => e.to_compile_error().into(),
13    }
14}
15
16fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStream> {
17    if !args.is_empty() {
18        return Err(parse::Error::new(
19            Span::call_site(),
20            "`#[test]` attribute takes no arguments",
21        ));
22    }
23
24    let module: ItemMod = syn::parse(input)?;
25
26    let items = if let Some(content) = module.content {
27        content.1
28    } else {
29        return Err(parse::Error::new(
30            module.span(),
31            "module must be inline (e.g. `mod foo {}`)",
32        ));
33    };
34
35    let mut init = None;
36    let mut before_each = None;
37    let mut after_each = None;
38    let mut tests = vec![];
39    let mut untouched_tokens = vec![];
40    for item in items {
41        match item {
42            Item::Fn(mut f) => {
43                let mut test_kind = None;
44                let mut should_error = false;
45                let mut ignore = false;
46
47                f.attrs.retain(|attr| {
48                    if attr.path.is_ident("init") {
49                        test_kind = Some(Attr::Init);
50                        false
51                    } else if attr.path.is_ident("test") {
52                        test_kind = Some(Attr::Test);
53                        false
54                    } else if attr.path.is_ident("before_each") {
55                        test_kind = Some(Attr::BeforeEach);
56                        false
57                    } else if attr.path.is_ident("after_each") {
58                        test_kind = Some(Attr::AfterEach);
59                        false
60                    } else if attr.path.is_ident("should_error") {
61                        should_error = true;
62                        false
63                    } else if attr.path.is_ident("ignore") {
64                        ignore = true;
65                        false
66                    } else {
67                        true
68                    }
69                });
70
71                let attr = match test_kind {
72                    Some(it) => it,
73                    None => {
74                        return Err(parse::Error::new(
75                            f.span(),
76                            "function requires `#[init]`, `#[before_each]`, `#[after_each]`, or `#[test]` attribute",
77                        ));
78                    }
79                };
80
81                match attr {
82                    Attr::Init => {
83                        if init.is_some() {
84                            return Err(parse::Error::new(
85                                f.sig.ident.span(),
86                                "only a single `#[init]` function can be defined",
87                            ));
88                        }
89
90                        if should_error {
91                            return Err(parse::Error::new(
92                                f.sig.ident.span(),
93                                "`#[should_error]` is not allowed on the `#[init]` function",
94                            ));
95                        }
96
97                        if ignore {
98                            return Err(parse::Error::new(
99                                f.sig.ident.span(),
100                                "`#[ignore]` is not allowed on the `#[init]` function",
101                            ));
102                        }
103
104                        if check_fn_sig(&f.sig).is_err() || !f.sig.inputs.is_empty() {
105                            return Err(parse::Error::new(
106                                f.sig.ident.span(),
107                                "`#[init]` function must have signature `fn() [-> Type]` (the return type is optional)",
108                            ));
109                        }
110
111                        let state = match &f.sig.output {
112                            ReturnType::Default => None,
113                            ReturnType::Type(.., ty) => Some(ty.clone()),
114                        };
115
116                        init = Some(Init { func: f, state });
117                    }
118
119                    Attr::Test => {
120                        if check_fn_sig(&f.sig).is_err() || f.sig.inputs.len() > 1 {
121                            return Err(parse::Error::new(
122                                f.sig.ident.span(),
123                                "`#[test]` function must have signature `fn(state: &mut Type)` (parameter is optional)",
124                            ));
125                        }
126
127                        let input = if f.sig.inputs.len() == 1 {
128                            let arg = &f.sig.inputs[0];
129
130                            // NOTE we cannot check the argument type matches `init.state` at this
131                            // point
132                            if let Some(ty) = get_mutable_reference_type(arg).cloned() {
133                                Some(Input { ty })
134                            } else {
135                                // was not `&mut T`
136                                return Err(parse::Error::new(
137                                    arg.span(),
138                                    "parameter must be a mutable reference (`&mut $Type`)",
139                                ));
140                            }
141                        } else {
142                            None
143                        };
144
145                        tests.push(Test {
146                            cfgs: extract_cfgs(&f.attrs),
147                            func: f,
148                            input,
149                            should_error,
150                            ignore,
151                        })
152                    }
153                    Attr::BeforeEach => {
154                        if before_each.is_some() {
155                            return Err(parse::Error::new(
156                                f.sig.ident.span(),
157                                "only a single `#[before_each]` function can be defined",
158                            ));
159                        }
160
161                        if should_error {
162                            return Err(parse::Error::new(
163                                f.sig.ident.span(),
164                                "`#[should_error]` is not allowed on the `#[before_each]` function",
165                            ));
166                        }
167
168                        if ignore {
169                            return Err(parse::Error::new(
170                                f.sig.ident.span(),
171                                "`#[ignore]` is not allowed on the `#[before_each]` function",
172                            ));
173                        }
174
175                        if check_fn_sig(&f.sig).is_err() || f.sig.inputs.len() > 1 {
176                            return Err(parse::Error::new(
177                                f.sig.ident.span(),
178                                "`#[before_each]` function must have signature `fn(state: &mut Type)` (parameter is optional)",
179                            ));
180                        }
181
182                        let input = if f.sig.inputs.len() == 1 {
183                            let arg = &f.sig.inputs[0];
184
185                            // NOTE we cannot check the argument type matches `init.state` at this
186                            // point
187                            if let Some(ty) = get_mutable_reference_type(arg).cloned() {
188                                Some(Input { ty })
189                            } else {
190                                // was not `&mut T`
191                                return Err(parse::Error::new(
192                                    arg.span(),
193                                    "parameter must be a mutable reference (`&mut $Type`)",
194                                ));
195                            }
196                        } else {
197                            None
198                        };
199
200                        before_each = Some(BeforeEach { func: f, input });
201                    }
202                    Attr::AfterEach => {
203                        if after_each.is_some() {
204                            return Err(parse::Error::new(
205                                f.sig.ident.span(),
206                                "only a single `#[after_each]` function can be defined",
207                            ));
208                        }
209
210                        if should_error {
211                            return Err(parse::Error::new(
212                                f.sig.ident.span(),
213                                "`#[should_error]` is not allowed on the `#[after_each]` function",
214                            ));
215                        }
216
217                        if ignore {
218                            return Err(parse::Error::new(
219                                f.sig.ident.span(),
220                                "`#[ignore]` is not allowed on the `#[after_each]` function",
221                            ));
222                        }
223
224                        if check_fn_sig(&f.sig).is_err() || f.sig.inputs.len() > 1 {
225                            return Err(parse::Error::new(
226                                f.sig.ident.span(),
227                                "`#[after_each]` function must have signature `fn(state: &mut Type)` (parameter is optional)",
228                            ));
229                        }
230
231                        let input = if f.sig.inputs.len() == 1 {
232                            let arg = &f.sig.inputs[0];
233
234                            // NOTE we cannot check the argument type matches `init.state` at this
235                            // point
236                            if let Some(ty) = get_mutable_reference_type(arg).cloned() {
237                                Some(Input { ty })
238                            } else {
239                                // was not `&mut T`
240                                return Err(parse::Error::new(
241                                    arg.span(),
242                                    "parameter must be a mutable reference (`&mut $Type`)",
243                                ));
244                            }
245                        } else {
246                            None
247                        };
248
249                        after_each = Some(AfterEach { func: f, input });
250                    }
251                }
252            }
253
254            _ => {
255                untouched_tokens.push(item);
256            }
257        }
258    }
259
260    let krate = format_ident!("mos_test");
261    let ident = module.ident;
262    let mut state_ty = None;
263    let (init_fn, init_expr) = if let Some(init) = init {
264        let init_func = &init.func;
265        let init_ident = &init.func.sig.ident;
266        state_ty = init.state;
267
268        (
269            Some(quote!(#init_func)),
270            Some(quote!(#[allow(dead_code)] let mut state = #init_ident();)),
271        )
272    } else {
273        (None, None)
274    };
275
276    let (before_each_fn, before_each_call) = if let Some(before_each) = before_each {
277        let before_each_func = &before_each.func;
278        let before_each_ident = &before_each.func.sig.ident;
279        let span = before_each.func.sig.ident.span();
280
281        let call = if let Some(input) = before_each.input.as_ref() {
282            if let Some(state) = &state_ty {
283                if input.ty != **state {
284                    return Err(parse::Error::new(
285                        input.ty.span(),
286                        format!(
287                            "this type must match `#[init]`s return type: {}",
288                            type_ident(state)
289                        ),
290                    ));
291                }
292            } else {
293                return Err(parse::Error::new(
294                    span,
295                    "no state was initialized by `#[init]`; signature must be `fn()`",
296                ));
297            }
298
299            quote!(#before_each_ident(&mut state))
300        } else {
301            quote!(#before_each_ident())
302        };
303
304        (Some(quote!(#before_each_func)), Some(quote!(#call)))
305    } else {
306        (None, None)
307    };
308
309    let (after_each_fn, after_each_call) = if let Some(after_each) = after_each {
310        let after_each_func = &after_each.func;
311        let after_each_ident = &after_each.func.sig.ident;
312        let span = after_each.func.sig.ident.span();
313
314        let call = if let Some(input) = after_each.input.as_ref() {
315            if let Some(state) = &state_ty {
316                if input.ty != **state {
317                    return Err(parse::Error::new(
318                        input.ty.span(),
319                        format!(
320                            "this type must match `#[init]`s return type: {}",
321                            type_ident(state)
322                        ),
323                    ));
324                }
325            } else {
326                return Err(parse::Error::new(
327                    span,
328                    "no state was initialized by `#[init]`; signature must be `fn()`",
329                ));
330            }
331
332            quote!(#after_each_ident(&mut state))
333        } else {
334            quote!(#after_each_ident())
335        };
336
337        (Some(quote!(#after_each_func)), Some(quote!(#call)))
338    } else {
339        (None, None)
340    };
341
342    let mut unit_test_calls = vec![];
343    for test in &tests {
344        let should_error = test.should_error;
345        let ignore = test.ignore;
346        let ident = &test.func.sig.ident;
347        let span = test.func.sig.ident.span();
348        let call = if let Some(input) = test.input.as_ref() {
349            if let Some(state) = &state_ty {
350                if input.ty != **state {
351                    return Err(parse::Error::new(
352                        input.ty.span(),
353                        format!(
354                            "this type must match `#[init]`s return type: {}",
355                            type_ident(state)
356                        ),
357                    ));
358                }
359            } else {
360                return Err(parse::Error::new(
361                    span,
362                    "no state was initialized by `#[init]`; signature must be `fn()`",
363                ));
364            }
365
366            quote!(#ident(&mut state))
367        } else {
368            quote!(#ident())
369        };
370        if ignore {
371            unit_test_calls.push(quote!(let _ = #call;));
372        } else {
373            unit_test_calls.push(quote!(
374                #before_each_call;
375                #krate::export::check_outcome(#krate::OutcomeWrapper(#call), #should_error);
376                #after_each_call;
377            ));
378        }
379    }
380
381    let test_functions = tests.iter().map(|test| &test.func);
382    let test_cfgs = tests.iter().map(|test| &test.cfgs);
383    let declare_test_count = {
384        let test_cfgs = test_cfgs.clone();
385        quote!(
386            // We can't evaluate `#[cfg]`s in the macro, but this works too.
387            // Below value can be used to read the number of tests from the produced ELF.
388            #[used]
389            #[no_mangle]
390            static DEFMT_TEST_COUNT: usize = {
391                let mut counter = 0;
392                #(
393                    #(#test_cfgs)*
394                    { counter += 1; }
395                )*
396                counter
397            };
398        )
399    };
400    let unit_test_progress = tests
401        .iter()
402        .map(|test| {
403            let message = format!(
404                "({{}}/{{}}) {} `{}`...",
405                if test.ignore { "ignoring" } else { "running" },
406                test.func.sig.ident
407            );
408            quote_spanned! {
409                test.func.sig.ident.span() => ufmt_stdio::println!(#message, __defmt_test_number, DEFMT_TEST_COUNT);
410            }
411        })
412        .collect::<Vec<_>>();
413    Ok(quote!(
414    #[cfg(test)]
415    mod #ident {
416        use ufmt_stdio::ufmt;
417        #(#untouched_tokens)*
418        // TODO use `cortex-m-rt::entry` here to get the `static mut` transform
419        #[export_name = "main"]
420        unsafe extern "C" fn __defmt_test_entry() -> isize {
421            #declare_test_count
422            #init_expr
423
424            let mut __defmt_test_number: usize = 1;
425            #(
426                #(#test_cfgs)*
427                {
428                    #unit_test_progress
429                    #unit_test_calls
430                    __defmt_test_number += 1;
431                }
432            )*
433
434            ufmt_stdio::println!("all tests passed!");
435            // #krate::export::exit()
436            0
437        }
438
439        #init_fn
440
441        #before_each_fn
442
443        #after_each_fn
444
445        #(
446            #test_functions
447        )*
448    })
449    .into())
450}
451
/// Role of a function inside a `#[tests]` module, selected by its marker attribute.
#[derive(Clone, Copy)]
enum Attr {
    /// `#[init]`: called once before the tests; its return value is bound to `state`.
    Init,
    /// `#[before_each]`: called before each test that is not ignored.
    BeforeEach,
    /// `#[test]`: a test case.
    Test,
    /// `#[after_each]`: called after each test that is not ignored.
    AfterEach,
}

460struct AfterEach {
461    func: ItemFn,
462    input: Option<Input>,
463}
464
465struct BeforeEach {
466    func: ItemFn,
467    input: Option<Input>,
468}
469
470struct Init {
471    func: ItemFn,
472    state: Option<Box<Type>>,
473}
474
475struct Test {
476    func: ItemFn,
477    cfgs: Vec<Attribute>,
478    input: Option<Input>,
479    should_error: bool,
480    ignore: bool,
481}
482
483struct Input {
484    ty: Type,
485}
486
487// NOTE doesn't check the parameters or the return type
488fn check_fn_sig(sig: &syn::Signature) -> Result<(), ()> {
489    if sig.constness.is_none()
490        && sig.asyncness.is_none()
491        && sig.unsafety.is_none()
492        && sig.abi.is_none()
493        && sig.generics.params.is_empty()
494        && sig.generics.where_clause.is_none()
495        && sig.variadic.is_none()
496    {
497        Ok(())
498    } else {
499        Err(())
500    }
501}
502
503fn get_mutable_reference_type(arg: &syn::FnArg) -> Option<&Type> {
504    if let syn::FnArg::Typed(pat) = arg {
505        if let syn::Type::Reference(refty) = &*pat.ty {
506            if refty.mutability.is_some() {
507                Some(&refty.elem)
508            } else {
509                None
510            }
511        } else {
512            None
513        }
514    } else {
515        None
516    }
517}
518
519fn extract_cfgs(attrs: &[Attribute]) -> Vec<Attribute> {
520    let mut cfgs = vec![];
521
522    for attr in attrs {
523        if attr.path.is_ident("cfg") {
524            cfgs.push(attr.clone());
525        }
526    }
527
528    cfgs
529}
530
531fn type_ident(ty: impl AsRef<syn::Type>) -> String {
532    let mut ident = String::new();
533    let ty = ty.as_ref();
534    let ty = format!("{}", quote!(#ty));
535    ty.split_whitespace().for_each(|t| ident.push_str(t));
536    ident
537}