defmt_test_macros/
lib.rs

1#![doc(html_logo_url = "https://knurling.ferrous-systems.com/knurling_logo_light_text.svg")]
2
3extern crate proc_macro;
4
5use proc_macro::TokenStream;
6use proc_macro2::Span;
7use quote::{format_ident, quote, quote_spanned};
8use syn::{parse, spanned::Spanned, Attribute, Item, ItemFn, ItemMod, ReturnType, Type};
9
10#[proc_macro_attribute]
11pub fn tests(args: TokenStream, input: TokenStream) -> TokenStream {
12    match tests_impl(args, input) {
13        Ok(ts) => ts,
14        Err(e) => e.to_compile_error().into(),
15    }
16}
17
// Builds, at compile time, the "wrong signature" diagnostic for functions that
// may optionally take the shared state (`#[test]`, `#[teardown]`,
// `#[before_each]`, `#[after_each]`). `$attr` is the attribute name without
// the surrounding `#[...]`; it must be a literal because `concat!` only
// accepts literals.
macro_rules! fn_state_signature_msg {
    ($attr:literal) => {
        concat!("`#[", $attr, "]` function must have signature `fn()` or `fn(state: &mut T)`")
    };
}
27
28fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStream> {
29    if !args.is_empty() {
30        return Err(parse::Error::new(
31            Span::call_site(),
32            "`#[test]` attribute takes no arguments",
33        ));
34    }
35
36    let module: ItemMod = syn::parse(input)?;
37
38    let items = if let Some(content) = module.content {
39        content.1
40    } else {
41        return Err(parse::Error::new(
42            module.span(),
43            "module must be inline (e.g. `mod foo {}`)",
44        ));
45    };
46
47    let mut init = None;
48    let mut teardown = None;
49    let mut before_each = None;
50    let mut after_each = None;
51    let mut tests = vec![];
52    let mut untouched_tokens = vec![];
53    for item in items {
54        match item {
55            Item::Fn(mut f) => {
56                let mut test_kind = None;
57                let mut should_error = false;
58                let mut ignore = false;
59
60                f.attrs.retain(|attr| {
61                    if attr.path().is_ident("init") {
62                        test_kind = Some(Attr::Init);
63                        false
64                    } else if attr.path().is_ident("teardown") {
65                        test_kind = Some(Attr::Teardown);
66                        false
67                    } else if attr.path().is_ident("test") {
68                        test_kind = Some(Attr::Test);
69                        false
70                    } else if attr.path().is_ident("before_each") {
71                        test_kind = Some(Attr::BeforeEach);
72                        false
73                    } else if attr.path().is_ident("after_each") {
74                        test_kind = Some(Attr::AfterEach);
75                        false
76                    } else if attr.path().is_ident("should_error") {
77                        should_error = true;
78                        false
79                    } else if attr.path().is_ident("ignore") {
80                        ignore = true;
81                        false
82                    } else {
83                        true
84                    }
85                });
86
87                let attr = match test_kind {
88                    Some(it) => it,
89                    None => {
90                        return Err(parse::Error::new(
91                            f.span(),
92                            "function requires `#[init]`, `#[teardown]`, `#[before_each]`, `#[after_each]`, or `#[test]` attribute",
93                        ));
94                    }
95                };
96
97                match attr {
98                    Attr::Init => {
99                        if init.is_some() {
100                            return Err(parse::Error::new(
101                                f.sig.ident.span(),
102                                "only a single `#[init]` function can be defined",
103                            ));
104                        }
105
106                        if should_error {
107                            return Err(parse::Error::new(
108                                f.sig.ident.span(),
109                                "`#[should_error]` is not allowed on the `#[init]` function",
110                            ));
111                        }
112
113                        if ignore {
114                            return Err(parse::Error::new(
115                                f.sig.ident.span(),
116                                "`#[ignore]` is not allowed on the `#[init]` function",
117                            ));
118                        }
119
120                        if check_fn_sig(&f.sig).is_err() || !f.sig.inputs.is_empty() {
121                            return Err(parse::Error::new(
122                                f.sig.ident.span(),
123                                "`#[init]` function must have signature `fn()` or `fn() -> T`",
124                            ));
125                        }
126
127                        let state = match &f.sig.output {
128                            ReturnType::Default => None,
129                            ReturnType::Type(.., ty) => Some(ty.clone()),
130                        };
131
132                        init = Some(Init { func: f, state });
133                    }
134
135                    Attr::Teardown => {
136                        if teardown.is_some() {
137                            return Err(parse::Error::new(
138                                f.sig.ident.span(),
139                                "only a single `#[teardown]` function can be defined",
140                            ));
141                        }
142
143                        if should_error {
144                            return Err(parse::Error::new(
145                                f.sig.ident.span(),
146                                "`#[should_error]` is not allowed on the `#[teardown]` function",
147                            ));
148                        }
149
150                        if ignore {
151                            return Err(parse::Error::new(
152                                f.sig.ident.span(),
153                                "`#[ignore]` is not allowed on the `#[teardown]` function",
154                            ));
155                        }
156
157                        if check_fn_sig(&f.sig).is_err() || f.sig.inputs.len() > 1 {
158                            return Err(parse::Error::new(
159                                f.sig.ident.span(),
160                                fn_state_signature_msg!("teardown"),
161                            ));
162                        }
163
164                        let input = if f.sig.inputs.len() == 1 {
165                            let arg = &f.sig.inputs[0];
166
167                            // NOTE we cannot check the argument type matches `init.state` at this
168                            // point
169                            if let Some(ty) = get_mutable_reference_type(arg).cloned() {
170                                Some(Input { ty })
171                            } else {
172                                // was not `&mut T`
173                                return Err(parse::Error::new(
174                                    arg.span(),
175                                    "parameter must be a mutable reference (`&mut $Type`)",
176                                ));
177                            }
178                        } else {
179                            None
180                        };
181
182                        teardown = Some(Teardown { func: f, input });
183                    }
184
185                    Attr::Test => {
186                        if check_fn_sig(&f.sig).is_err() || f.sig.inputs.len() > 1 {
187                            return Err(parse::Error::new(
188                                f.sig.ident.span(),
189                                fn_state_signature_msg!("test"),
190                            ));
191                        }
192
193                        let input = if f.sig.inputs.len() == 1 {
194                            let arg = &f.sig.inputs[0];
195
196                            // NOTE we cannot check the argument type matches `init.state` at this
197                            // point
198                            if let Some(ty) = get_mutable_reference_type(arg).cloned() {
199                                Some(Input { ty })
200                            } else {
201                                // was not `&mut T`
202                                return Err(parse::Error::new(
203                                    arg.span(),
204                                    "parameter must be a mutable reference (`&mut $Type`)",
205                                ));
206                            }
207                        } else {
208                            None
209                        };
210
211                        tests.push(Test {
212                            cfgs: extract_cfgs(&f.attrs),
213                            func: f,
214                            input,
215                            should_error,
216                            ignore,
217                        })
218                    }
219                    Attr::BeforeEach => {
220                        if before_each.is_some() {
221                            return Err(parse::Error::new(
222                                f.sig.ident.span(),
223                                "only a single `#[before_each]` function can be defined",
224                            ));
225                        }
226
227                        if should_error {
228                            return Err(parse::Error::new(
229                                f.sig.ident.span(),
230                                "`#[should_error]` is not allowed on the `#[before_each]` function",
231                            ));
232                        }
233
234                        if ignore {
235                            return Err(parse::Error::new(
236                                f.sig.ident.span(),
237                                "`#[ignore]` is not allowed on the `#[before_each]` function",
238                            ));
239                        }
240
241                        if check_fn_sig(&f.sig).is_err() || f.sig.inputs.len() > 1 {
242                            return Err(parse::Error::new(
243                                f.sig.ident.span(),
244                                fn_state_signature_msg!("before_each"),
245                            ));
246                        }
247
248                        let input = if f.sig.inputs.len() == 1 {
249                            let arg = &f.sig.inputs[0];
250
251                            // NOTE we cannot check the argument type matches `init.state` at this
252                            // point
253                            if let Some(ty) = get_mutable_reference_type(arg).cloned() {
254                                Some(Input { ty })
255                            } else {
256                                // was not `&mut T`
257                                return Err(parse::Error::new(
258                                    arg.span(),
259                                    "parameter must be a mutable reference (`&mut $Type`)",
260                                ));
261                            }
262                        } else {
263                            None
264                        };
265
266                        before_each = Some(BeforeEach { func: f, input });
267                    }
268                    Attr::AfterEach => {
269                        if after_each.is_some() {
270                            return Err(parse::Error::new(
271                                f.sig.ident.span(),
272                                "only a single `#[after_each]` function can be defined",
273                            ));
274                        }
275
276                        if should_error {
277                            return Err(parse::Error::new(
278                                f.sig.ident.span(),
279                                "`#[should_error]` is not allowed on the `#[after_each]` function",
280                            ));
281                        }
282
283                        if ignore {
284                            return Err(parse::Error::new(
285                                f.sig.ident.span(),
286                                "`#[ignore]` is not allowed on the `#[after_each]` function",
287                            ));
288                        }
289
290                        if check_fn_sig(&f.sig).is_err() || f.sig.inputs.len() > 1 {
291                            return Err(parse::Error::new(
292                                f.sig.ident.span(),
293                                fn_state_signature_msg!("after_each"),
294                            ));
295                        }
296
297                        let input = if f.sig.inputs.len() == 1 {
298                            let arg = &f.sig.inputs[0];
299
300                            // NOTE we cannot check the argument type matches `init.state` at this
301                            // point
302                            if let Some(ty) = get_mutable_reference_type(arg).cloned() {
303                                Some(Input { ty })
304                            } else {
305                                // was not `&mut T`
306                                return Err(parse::Error::new(
307                                    arg.span(),
308                                    "parameter must be a mutable reference (`&mut $Type`)",
309                                ));
310                            }
311                        } else {
312                            None
313                        };
314
315                        after_each = Some(AfterEach { func: f, input });
316                    }
317                }
318            }
319
320            _ => {
321                untouched_tokens.push(item);
322            }
323        }
324    }
325
326    let krate = format_ident!("defmt_test");
327    let ident = module.ident;
328    let mut state_ty = None;
329    let (init_fn, init_expr) = if let Some(init) = init {
330        let init_func = &init.func;
331        let init_ident = &init.func.sig.ident;
332        state_ty = init.state;
333
334        (
335            Some(quote!(#init_func)),
336            Some(quote!(#[allow(dead_code)] let mut state = #init_ident();)),
337        )
338    } else {
339        (None, None)
340    };
341
342    let (teardown_fn, teardown_expr) = if let Some(teardown) = teardown {
343        let teardown_func = &teardown.func;
344        let teardown_ident = &teardown.func.sig.ident;
345        let span = teardown.func.sig.ident.span();
346
347        let call = if let Some(input) = teardown.input.as_ref() {
348            if let Some(state) = &state_ty {
349                if input.ty != **state {
350                    return Err(parse::Error::new(
351                        input.ty.span(),
352                        format!(
353                            "this type must match `#[init]`s return type: {}",
354                            type_ident(state)
355                        ),
356                    ));
357                }
358            } else {
359                return Err(parse::Error::new(
360                    span,
361                    "no state was initialized by `#[init]`; signature must be `fn()`",
362                ));
363            }
364
365            quote!(#teardown_ident(&mut state))
366        } else {
367            quote!(#teardown_ident())
368        };
369
370        (Some(quote!(#teardown_func)), Some(quote!(#call;)))
371    } else {
372        (None, None)
373    };
374
375    let (before_each_fn, before_each_call) = if let Some(before_each) = before_each {
376        let before_each_func = &before_each.func;
377        let before_each_ident = &before_each.func.sig.ident;
378        let span = before_each.func.sig.ident.span();
379
380        let call = if let Some(input) = before_each.input.as_ref() {
381            if let Some(state) = &state_ty {
382                if input.ty != **state {
383                    return Err(parse::Error::new(
384                        input.ty.span(),
385                        format!(
386                            "this type must match `#[init]`s return type: {}",
387                            type_ident(state)
388                        ),
389                    ));
390                }
391            } else {
392                return Err(parse::Error::new(
393                    span,
394                    "no state was initialized by `#[init]`; signature must be `fn()`",
395                ));
396            }
397
398            quote!(#before_each_ident(&mut state))
399        } else {
400            quote!(#before_each_ident())
401        };
402
403        (Some(quote!(#before_each_func)), Some(quote!(#call)))
404    } else {
405        (None, None)
406    };
407
408    let (after_each_fn, after_each_call) = if let Some(after_each) = after_each {
409        let after_each_func = &after_each.func;
410        let after_each_ident = &after_each.func.sig.ident;
411        let span = after_each.func.sig.ident.span();
412
413        let call = if let Some(input) = after_each.input.as_ref() {
414            if let Some(state) = &state_ty {
415                if input.ty != **state {
416                    return Err(parse::Error::new(
417                        input.ty.span(),
418                        format!(
419                            "this type must match `#[init]`s return type: {}",
420                            type_ident(state)
421                        ),
422                    ));
423                }
424            } else {
425                return Err(parse::Error::new(
426                    span,
427                    "no state was initialized by `#[init]`; signature must be `fn()`",
428                ));
429            }
430
431            quote!(#after_each_ident(&mut state))
432        } else {
433            quote!(#after_each_ident())
434        };
435
436        (Some(quote!(#after_each_func)), Some(quote!(#call)))
437    } else {
438        (None, None)
439    };
440
441    let mut unit_test_calls = vec![];
442    for test in &tests {
443        let should_error = test.should_error;
444        let ignore = test.ignore;
445        let ident = &test.func.sig.ident;
446        let span = test.func.sig.ident.span();
447        let call = if let Some(input) = test.input.as_ref() {
448            if let Some(state) = &state_ty {
449                if input.ty != **state {
450                    return Err(parse::Error::new(
451                        input.ty.span(),
452                        format!(
453                            "this type must match `#[init]`s return type: {}",
454                            type_ident(state)
455                        ),
456                    ));
457                }
458            } else {
459                return Err(parse::Error::new(
460                    span,
461                    "no state was initialized by `#[init]`; signature must be `fn()`",
462                ));
463            }
464
465            quote!(#ident(&mut state))
466        } else {
467            quote!(#ident())
468        };
469        if ignore {
470            unit_test_calls.push(quote!(let _ = #call;));
471        } else {
472            unit_test_calls.push(quote!(
473                #before_each_call;
474                #krate::export::check_outcome(#call, #should_error);
475                #after_each_call;
476            ));
477        }
478    }
479
480    let test_functions = tests.iter().map(|test| &test.func);
481    let test_cfgs = tests.iter().map(|test| &test.cfgs);
482    let declare_test_count = {
483        let test_cfgs = test_cfgs.clone();
484        quote!(
485            // We can't evaluate `#[cfg]`s in the macro, but this works too.
486            // Below value can be used to read the number of tests from the produced ELF.
487            #[used]
488            #[no_mangle]
489            static DEFMT_TEST_COUNT: usize = {
490                let mut counter = 0;
491                #(
492                    #(#test_cfgs)*
493                    { counter += 1; }
494                )*
495                counter
496            };
497        )
498    };
499    let unit_test_progress = tests
500        .iter()
501        .map(|test| {
502            let message = format!(
503                "({{=usize}}/{{=usize}}) {} `{}`...",
504                if test.ignore { "ignoring" } else { "running" },
505                test.func.sig.ident
506            );
507            quote_spanned! {
508                test.func.sig.ident.span() => defmt::println!(#message, __defmt_test_number, DEFMT_TEST_COUNT);
509            }
510        })
511        .collect::<Vec<_>>();
512    Ok(quote!(
513    #[cfg(test)]
514    mod #ident {
515        #(#untouched_tokens)*
516        // TODO use `cortex-m-rt::entry` here to get the `static mut` transform
517        #[export_name = "main"]
518        unsafe extern "C" fn __defmt_test_entry() -> ! {
519            #declare_test_count
520            #init_expr
521
522            let mut __defmt_test_number: usize = 1;
523            #(
524                #(#test_cfgs)*
525                {
526                    #unit_test_progress
527                    #unit_test_calls
528                    __defmt_test_number += 1;
529                }
530            )*
531
532            defmt::println!("all tests passed!");
533
534            #teardown_expr
535
536            #krate::export::exit()
537        }
538
539        #init_fn
540
541        #teardown_fn
542
543        #before_each_fn
544
545        #after_each_fn
546
547        #(
548            #test_functions
549        )*
550    })
551    .into())
552}
553
/// The role a function plays inside the test module, as selected by the
/// attribute found on it.
#[derive(Clone, Copy)]
enum Attr {
    /// `#[init]`
    Init,
    /// `#[before_each]`
    BeforeEach,
    /// `#[test]`
    Test,
    /// `#[after_each]`
    AfterEach,
    /// `#[teardown]`
    Teardown,
}
562
563struct AfterEach {
564    func: ItemFn,
565    input: Option<Input>,
566}
567
568struct BeforeEach {
569    func: ItemFn,
570    input: Option<Input>,
571}
572
573struct Init {
574    func: ItemFn,
575    state: Option<Box<Type>>,
576}
577
578struct Teardown {
579    func: ItemFn,
580    input: Option<Input>,
581}
582
583struct Test {
584    func: ItemFn,
585    cfgs: Vec<Attribute>,
586    input: Option<Input>,
587    should_error: bool,
588    ignore: bool,
589}
590
591struct Input {
592    ty: Type,
593}
594
595// NOTE doesn't check the parameters or the return type
596fn check_fn_sig(sig: &syn::Signature) -> Result<(), ()> {
597    if sig.constness.is_none()
598        && sig.asyncness.is_none()
599        && sig.unsafety.is_none()
600        && sig.abi.is_none()
601        && sig.generics.params.is_empty()
602        && sig.generics.where_clause.is_none()
603        && sig.variadic.is_none()
604    {
605        Ok(())
606    } else {
607        Err(())
608    }
609}
610
611fn get_mutable_reference_type(arg: &syn::FnArg) -> Option<&Type> {
612    if let syn::FnArg::Typed(pat) = arg {
613        if let syn::Type::Reference(refty) = &*pat.ty {
614            if refty.mutability.is_some() {
615                Some(&refty.elem)
616            } else {
617                None
618            }
619        } else {
620            None
621        }
622    } else {
623        None
624    }
625}
626
627fn extract_cfgs(attrs: &[Attribute]) -> Vec<Attribute> {
628    let mut cfgs = vec![];
629
630    for attr in attrs {
631        if attr.path().is_ident("cfg") {
632            cfgs.push(attr.clone());
633        }
634    }
635
636    cfgs
637}
638
639fn type_ident(ty: impl AsRef<syn::Type>) -> String {
640    let mut ident = String::new();
641    let ty = ty.as_ref();
642    let ty = format!("{}", quote!(#ty));
643    ty.split_whitespace().for_each(|t| ident.push_str(t));
644    ident
645}