1#![doc(html_logo_url = "https://knurling.ferrous-systems.com/knurling_logo_light_text.svg")]
2
3extern crate proc_macro;
4
5use proc_macro::TokenStream;
6use proc_macro2::Span;
7use quote::{format_ident, quote, quote_spanned};
8use syn::{parse, spanned::Spanned, Attribute, Item, ItemFn, ItemMod, ReturnType, Type};
9
10#[proc_macro_attribute]
11pub fn tests(args: TokenStream, input: TokenStream) -> TokenStream {
12 match tests_impl(args, input) {
13 Ok(ts) => ts,
14 Err(e) => e.to_compile_error().into(),
15 }
16}
17
18macro_rules! fn_state_signature_msg {
19 ($name:literal) => {
20 concat!(
21 "`#[",
22 $name,
23 "]` function must have signature `fn()` or `fn(state: &mut T)`"
24 )
25 };
26}
27
28fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStream> {
29 if !args.is_empty() {
30 return Err(parse::Error::new(
31 Span::call_site(),
32 "`#[test]` attribute takes no arguments",
33 ));
34 }
35
36 let module: ItemMod = syn::parse(input)?;
37
38 let items = if let Some(content) = module.content {
39 content.1
40 } else {
41 return Err(parse::Error::new(
42 module.span(),
43 "module must be inline (e.g. `mod foo {}`)",
44 ));
45 };
46
47 let mut init = None;
48 let mut teardown = None;
49 let mut before_each = None;
50 let mut after_each = None;
51 let mut tests = vec![];
52 let mut untouched_tokens = vec![];
53 for item in items {
54 match item {
55 Item::Fn(mut f) => {
56 let mut test_kind = None;
57 let mut should_error = false;
58 let mut ignore = false;
59
60 f.attrs.retain(|attr| {
61 if attr.path().is_ident("init") {
62 test_kind = Some(Attr::Init);
63 false
64 } else if attr.path().is_ident("teardown") {
65 test_kind = Some(Attr::Teardown);
66 false
67 } else if attr.path().is_ident("test") {
68 test_kind = Some(Attr::Test);
69 false
70 } else if attr.path().is_ident("before_each") {
71 test_kind = Some(Attr::BeforeEach);
72 false
73 } else if attr.path().is_ident("after_each") {
74 test_kind = Some(Attr::AfterEach);
75 false
76 } else if attr.path().is_ident("should_error") {
77 should_error = true;
78 false
79 } else if attr.path().is_ident("ignore") {
80 ignore = true;
81 false
82 } else {
83 true
84 }
85 });
86
87 let attr = match test_kind {
88 Some(it) => it,
89 None => {
90 return Err(parse::Error::new(
91 f.span(),
92 "function requires `#[init]`, `#[teardown]`, `#[before_each]`, `#[after_each]`, or `#[test]` attribute",
93 ));
94 }
95 };
96
97 match attr {
98 Attr::Init => {
99 if init.is_some() {
100 return Err(parse::Error::new(
101 f.sig.ident.span(),
102 "only a single `#[init]` function can be defined",
103 ));
104 }
105
106 if should_error {
107 return Err(parse::Error::new(
108 f.sig.ident.span(),
109 "`#[should_error]` is not allowed on the `#[init]` function",
110 ));
111 }
112
113 if ignore {
114 return Err(parse::Error::new(
115 f.sig.ident.span(),
116 "`#[ignore]` is not allowed on the `#[init]` function",
117 ));
118 }
119
120 if check_fn_sig(&f.sig).is_err() || !f.sig.inputs.is_empty() {
121 return Err(parse::Error::new(
122 f.sig.ident.span(),
123 "`#[init]` function must have signature `fn()` or `fn() -> T`",
124 ));
125 }
126
127 let state = match &f.sig.output {
128 ReturnType::Default => None,
129 ReturnType::Type(.., ty) => Some(ty.clone()),
130 };
131
132 init = Some(Init { func: f, state });
133 }
134
135 Attr::Teardown => {
136 if teardown.is_some() {
137 return Err(parse::Error::new(
138 f.sig.ident.span(),
139 "only a single `#[teardown]` function can be defined",
140 ));
141 }
142
143 if should_error {
144 return Err(parse::Error::new(
145 f.sig.ident.span(),
146 "`#[should_error]` is not allowed on the `#[teardown]` function",
147 ));
148 }
149
150 if ignore {
151 return Err(parse::Error::new(
152 f.sig.ident.span(),
153 "`#[ignore]` is not allowed on the `#[teardown]` function",
154 ));
155 }
156
157 if check_fn_sig(&f.sig).is_err() || f.sig.inputs.len() > 1 {
158 return Err(parse::Error::new(
159 f.sig.ident.span(),
160 fn_state_signature_msg!("teardown"),
161 ));
162 }
163
164 let input = if f.sig.inputs.len() == 1 {
165 let arg = &f.sig.inputs[0];
166
167 if let Some(ty) = get_mutable_reference_type(arg).cloned() {
170 Some(Input { ty })
171 } else {
172 return Err(parse::Error::new(
174 arg.span(),
175 "parameter must be a mutable reference (`&mut $Type`)",
176 ));
177 }
178 } else {
179 None
180 };
181
182 teardown = Some(Teardown { func: f, input });
183 }
184
185 Attr::Test => {
186 if check_fn_sig(&f.sig).is_err() || f.sig.inputs.len() > 1 {
187 return Err(parse::Error::new(
188 f.sig.ident.span(),
189 fn_state_signature_msg!("test"),
190 ));
191 }
192
193 let input = if f.sig.inputs.len() == 1 {
194 let arg = &f.sig.inputs[0];
195
196 if let Some(ty) = get_mutable_reference_type(arg).cloned() {
199 Some(Input { ty })
200 } else {
201 return Err(parse::Error::new(
203 arg.span(),
204 "parameter must be a mutable reference (`&mut $Type`)",
205 ));
206 }
207 } else {
208 None
209 };
210
211 tests.push(Test {
212 cfgs: extract_cfgs(&f.attrs),
213 func: f,
214 input,
215 should_error,
216 ignore,
217 })
218 }
219 Attr::BeforeEach => {
220 if before_each.is_some() {
221 return Err(parse::Error::new(
222 f.sig.ident.span(),
223 "only a single `#[before_each]` function can be defined",
224 ));
225 }
226
227 if should_error {
228 return Err(parse::Error::new(
229 f.sig.ident.span(),
230 "`#[should_error]` is not allowed on the `#[before_each]` function",
231 ));
232 }
233
234 if ignore {
235 return Err(parse::Error::new(
236 f.sig.ident.span(),
237 "`#[ignore]` is not allowed on the `#[before_each]` function",
238 ));
239 }
240
241 if check_fn_sig(&f.sig).is_err() || f.sig.inputs.len() > 1 {
242 return Err(parse::Error::new(
243 f.sig.ident.span(),
244 fn_state_signature_msg!("before_each"),
245 ));
246 }
247
248 let input = if f.sig.inputs.len() == 1 {
249 let arg = &f.sig.inputs[0];
250
251 if let Some(ty) = get_mutable_reference_type(arg).cloned() {
254 Some(Input { ty })
255 } else {
256 return Err(parse::Error::new(
258 arg.span(),
259 "parameter must be a mutable reference (`&mut $Type`)",
260 ));
261 }
262 } else {
263 None
264 };
265
266 before_each = Some(BeforeEach { func: f, input });
267 }
268 Attr::AfterEach => {
269 if after_each.is_some() {
270 return Err(parse::Error::new(
271 f.sig.ident.span(),
272 "only a single `#[after_each]` function can be defined",
273 ));
274 }
275
276 if should_error {
277 return Err(parse::Error::new(
278 f.sig.ident.span(),
279 "`#[should_error]` is not allowed on the `#[after_each]` function",
280 ));
281 }
282
283 if ignore {
284 return Err(parse::Error::new(
285 f.sig.ident.span(),
286 "`#[ignore]` is not allowed on the `#[after_each]` function",
287 ));
288 }
289
290 if check_fn_sig(&f.sig).is_err() || f.sig.inputs.len() > 1 {
291 return Err(parse::Error::new(
292 f.sig.ident.span(),
293 fn_state_signature_msg!("after_each"),
294 ));
295 }
296
297 let input = if f.sig.inputs.len() == 1 {
298 let arg = &f.sig.inputs[0];
299
300 if let Some(ty) = get_mutable_reference_type(arg).cloned() {
303 Some(Input { ty })
304 } else {
305 return Err(parse::Error::new(
307 arg.span(),
308 "parameter must be a mutable reference (`&mut $Type`)",
309 ));
310 }
311 } else {
312 None
313 };
314
315 after_each = Some(AfterEach { func: f, input });
316 }
317 }
318 }
319
320 _ => {
321 untouched_tokens.push(item);
322 }
323 }
324 }
325
326 let krate = format_ident!("defmt_test");
327 let ident = module.ident;
328 let mut state_ty = None;
329 let (init_fn, init_expr) = if let Some(init) = init {
330 let init_func = &init.func;
331 let init_ident = &init.func.sig.ident;
332 state_ty = init.state;
333
334 (
335 Some(quote!(#init_func)),
336 Some(quote!(#[allow(dead_code)] let mut state = #init_ident();)),
337 )
338 } else {
339 (None, None)
340 };
341
342 let (teardown_fn, teardown_expr) = if let Some(teardown) = teardown {
343 let teardown_func = &teardown.func;
344 let teardown_ident = &teardown.func.sig.ident;
345 let span = teardown.func.sig.ident.span();
346
347 let call = if let Some(input) = teardown.input.as_ref() {
348 if let Some(state) = &state_ty {
349 if input.ty != **state {
350 return Err(parse::Error::new(
351 input.ty.span(),
352 format!(
353 "this type must match `#[init]`s return type: {}",
354 type_ident(state)
355 ),
356 ));
357 }
358 } else {
359 return Err(parse::Error::new(
360 span,
361 "no state was initialized by `#[init]`; signature must be `fn()`",
362 ));
363 }
364
365 quote!(#teardown_ident(&mut state))
366 } else {
367 quote!(#teardown_ident())
368 };
369
370 (Some(quote!(#teardown_func)), Some(quote!(#call;)))
371 } else {
372 (None, None)
373 };
374
375 let (before_each_fn, before_each_call) = if let Some(before_each) = before_each {
376 let before_each_func = &before_each.func;
377 let before_each_ident = &before_each.func.sig.ident;
378 let span = before_each.func.sig.ident.span();
379
380 let call = if let Some(input) = before_each.input.as_ref() {
381 if let Some(state) = &state_ty {
382 if input.ty != **state {
383 return Err(parse::Error::new(
384 input.ty.span(),
385 format!(
386 "this type must match `#[init]`s return type: {}",
387 type_ident(state)
388 ),
389 ));
390 }
391 } else {
392 return Err(parse::Error::new(
393 span,
394 "no state was initialized by `#[init]`; signature must be `fn()`",
395 ));
396 }
397
398 quote!(#before_each_ident(&mut state))
399 } else {
400 quote!(#before_each_ident())
401 };
402
403 (Some(quote!(#before_each_func)), Some(quote!(#call)))
404 } else {
405 (None, None)
406 };
407
408 let (after_each_fn, after_each_call) = if let Some(after_each) = after_each {
409 let after_each_func = &after_each.func;
410 let after_each_ident = &after_each.func.sig.ident;
411 let span = after_each.func.sig.ident.span();
412
413 let call = if let Some(input) = after_each.input.as_ref() {
414 if let Some(state) = &state_ty {
415 if input.ty != **state {
416 return Err(parse::Error::new(
417 input.ty.span(),
418 format!(
419 "this type must match `#[init]`s return type: {}",
420 type_ident(state)
421 ),
422 ));
423 }
424 } else {
425 return Err(parse::Error::new(
426 span,
427 "no state was initialized by `#[init]`; signature must be `fn()`",
428 ));
429 }
430
431 quote!(#after_each_ident(&mut state))
432 } else {
433 quote!(#after_each_ident())
434 };
435
436 (Some(quote!(#after_each_func)), Some(quote!(#call)))
437 } else {
438 (None, None)
439 };
440
441 let mut unit_test_calls = vec![];
442 for test in &tests {
443 let should_error = test.should_error;
444 let ignore = test.ignore;
445 let ident = &test.func.sig.ident;
446 let span = test.func.sig.ident.span();
447 let call = if let Some(input) = test.input.as_ref() {
448 if let Some(state) = &state_ty {
449 if input.ty != **state {
450 return Err(parse::Error::new(
451 input.ty.span(),
452 format!(
453 "this type must match `#[init]`s return type: {}",
454 type_ident(state)
455 ),
456 ));
457 }
458 } else {
459 return Err(parse::Error::new(
460 span,
461 "no state was initialized by `#[init]`; signature must be `fn()`",
462 ));
463 }
464
465 quote!(#ident(&mut state))
466 } else {
467 quote!(#ident())
468 };
469 if ignore {
470 unit_test_calls.push(quote!(let _ = #call;));
471 } else {
472 unit_test_calls.push(quote!(
473 #before_each_call;
474 #krate::export::check_outcome(#call, #should_error);
475 #after_each_call;
476 ));
477 }
478 }
479
480 let test_functions = tests.iter().map(|test| &test.func);
481 let test_cfgs = tests.iter().map(|test| &test.cfgs);
482 let declare_test_count = {
483 let test_cfgs = test_cfgs.clone();
484 quote!(
485 #[used]
488 #[no_mangle]
489 static DEFMT_TEST_COUNT: usize = {
490 let mut counter = 0;
491 #(
492 #(#test_cfgs)*
493 { counter += 1; }
494 )*
495 counter
496 };
497 )
498 };
499 let unit_test_progress = tests
500 .iter()
501 .map(|test| {
502 let message = format!(
503 "({{=usize}}/{{=usize}}) {} `{}`...",
504 if test.ignore { "ignoring" } else { "running" },
505 test.func.sig.ident
506 );
507 quote_spanned! {
508 test.func.sig.ident.span() => defmt::println!(#message, __defmt_test_number, DEFMT_TEST_COUNT);
509 }
510 })
511 .collect::<Vec<_>>();
512 Ok(quote!(
513 #[cfg(test)]
514 mod #ident {
515 #(#untouched_tokens)*
516 #[export_name = "main"]
518 unsafe extern "C" fn __defmt_test_entry() -> ! {
519 #declare_test_count
520 #init_expr
521
522 let mut __defmt_test_number: usize = 1;
523 #(
524 #(#test_cfgs)*
525 {
526 #unit_test_progress
527 #unit_test_calls
528 __defmt_test_number += 1;
529 }
530 )*
531
532 defmt::println!("all tests passed!");
533
534 #teardown_expr
535
536 #krate::export::exit()
537 }
538
539 #init_fn
540
541 #teardown_fn
542
543 #before_each_fn
544
545 #after_each_fn
546
547 #(
548 #test_functions
549 )*
550 })
551 .into())
552}
553
/// Role of a function inside a `#[tests]` module, chosen by its marker attribute.
#[derive(Clone, Copy)]
enum Attr {
    /// `#[init]`: called once before the tests; its return value (if any) becomes `state`.
    Init,
    /// `#[before_each]`: called before each test that is run.
    BeforeEach,
    /// `#[test]`: a unit test.
    Test,
    /// `#[after_each]`: called after each test that is run.
    AfterEach,
    /// `#[teardown]`: called once after the test loop, right before exiting.
    Teardown,
}
562
563struct AfterEach {
564 func: ItemFn,
565 input: Option<Input>,
566}
567
568struct BeforeEach {
569 func: ItemFn,
570 input: Option<Input>,
571}
572
573struct Init {
574 func: ItemFn,
575 state: Option<Box<Type>>,
576}
577
578struct Teardown {
579 func: ItemFn,
580 input: Option<Input>,
581}
582
583struct Test {
584 func: ItemFn,
585 cfgs: Vec<Attribute>,
586 input: Option<Input>,
587 should_error: bool,
588 ignore: bool,
589}
590
591struct Input {
592 ty: Type,
593}
594
595fn check_fn_sig(sig: &syn::Signature) -> Result<(), ()> {
597 if sig.constness.is_none()
598 && sig.asyncness.is_none()
599 && sig.unsafety.is_none()
600 && sig.abi.is_none()
601 && sig.generics.params.is_empty()
602 && sig.generics.where_clause.is_none()
603 && sig.variadic.is_none()
604 {
605 Ok(())
606 } else {
607 Err(())
608 }
609}
610
611fn get_mutable_reference_type(arg: &syn::FnArg) -> Option<&Type> {
612 if let syn::FnArg::Typed(pat) = arg {
613 if let syn::Type::Reference(refty) = &*pat.ty {
614 if refty.mutability.is_some() {
615 Some(&refty.elem)
616 } else {
617 None
618 }
619 } else {
620 None
621 }
622 } else {
623 None
624 }
625}
626
627fn extract_cfgs(attrs: &[Attribute]) -> Vec<Attribute> {
628 let mut cfgs = vec![];
629
630 for attr in attrs {
631 if attr.path().is_ident("cfg") {
632 cfgs.push(attr.clone());
633 }
634 }
635
636 cfgs
637}
638
639fn type_ident(ty: impl AsRef<syn::Type>) -> String {
640 let mut ident = String::new();
641 let ty = ty.as_ref();
642 let ty = format!("{}", quote!(#ty));
643 ty.split_whitespace().for_each(|t| ident.push_str(t));
644 ident
645}