1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use proc_macro2::Span;
5use quote::{format_ident, quote, quote_spanned};
6use syn::{parse, spanned::Spanned, Attribute, Item, ItemFn, ItemMod, ReturnType, Type};
7
8#[proc_macro_attribute]
9pub fn tests(args: TokenStream, input: TokenStream) -> TokenStream {
10 match tests_impl(args, input) {
11 Ok(ts) => ts,
12 Err(e) => e.to_compile_error().into(),
13 }
14}
15
16fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStream> {
17 if !args.is_empty() {
18 return Err(parse::Error::new(
19 Span::call_site(),
20 "`#[test]` attribute takes no arguments",
21 ));
22 }
23
24 let module: ItemMod = syn::parse(input)?;
25
26 let items = if let Some(content) = module.content {
27 content.1
28 } else {
29 return Err(parse::Error::new(
30 module.span(),
31 "module must be inline (e.g. `mod foo {}`)",
32 ));
33 };
34
35 let mut init = None;
36 let mut before_each = None;
37 let mut after_each = None;
38 let mut tests = vec![];
39 let mut untouched_tokens = vec![];
40 for item in items {
41 match item {
42 Item::Fn(mut f) => {
43 let mut test_kind = None;
44 let mut should_error = false;
45 let mut ignore = false;
46
47 f.attrs.retain(|attr| {
48 if attr.path.is_ident("init") {
49 test_kind = Some(Attr::Init);
50 false
51 } else if attr.path.is_ident("test") {
52 test_kind = Some(Attr::Test);
53 false
54 } else if attr.path.is_ident("before_each") {
55 test_kind = Some(Attr::BeforeEach);
56 false
57 } else if attr.path.is_ident("after_each") {
58 test_kind = Some(Attr::AfterEach);
59 false
60 } else if attr.path.is_ident("should_error") {
61 should_error = true;
62 false
63 } else if attr.path.is_ident("ignore") {
64 ignore = true;
65 false
66 } else {
67 true
68 }
69 });
70
71 let attr = match test_kind {
72 Some(it) => it,
73 None => {
74 return Err(parse::Error::new(
75 f.span(),
76 "function requires `#[init]`, `#[before_each]`, `#[after_each]`, or `#[test]` attribute",
77 ));
78 }
79 };
80
81 match attr {
82 Attr::Init => {
83 if init.is_some() {
84 return Err(parse::Error::new(
85 f.sig.ident.span(),
86 "only a single `#[init]` function can be defined",
87 ));
88 }
89
90 if should_error {
91 return Err(parse::Error::new(
92 f.sig.ident.span(),
93 "`#[should_error]` is not allowed on the `#[init]` function",
94 ));
95 }
96
97 if ignore {
98 return Err(parse::Error::new(
99 f.sig.ident.span(),
100 "`#[ignore]` is not allowed on the `#[init]` function",
101 ));
102 }
103
104 if check_fn_sig(&f.sig).is_err() || !f.sig.inputs.is_empty() {
105 return Err(parse::Error::new(
106 f.sig.ident.span(),
107 "`#[init]` function must have signature `fn() [-> Type]` (the return type is optional)",
108 ));
109 }
110
111 let state = match &f.sig.output {
112 ReturnType::Default => None,
113 ReturnType::Type(.., ty) => Some(ty.clone()),
114 };
115
116 init = Some(Init { func: f, state });
117 }
118
119 Attr::Test => {
120 if check_fn_sig(&f.sig).is_err() || f.sig.inputs.len() > 1 {
121 return Err(parse::Error::new(
122 f.sig.ident.span(),
123 "`#[test]` function must have signature `fn(state: &mut Type)` (parameter is optional)",
124 ));
125 }
126
127 let input = if f.sig.inputs.len() == 1 {
128 let arg = &f.sig.inputs[0];
129
130 if let Some(ty) = get_mutable_reference_type(arg).cloned() {
133 Some(Input { ty })
134 } else {
135 return Err(parse::Error::new(
137 arg.span(),
138 "parameter must be a mutable reference (`&mut $Type`)",
139 ));
140 }
141 } else {
142 None
143 };
144
145 tests.push(Test {
146 cfgs: extract_cfgs(&f.attrs),
147 func: f,
148 input,
149 should_error,
150 ignore,
151 })
152 }
153 Attr::BeforeEach => {
154 if before_each.is_some() {
155 return Err(parse::Error::new(
156 f.sig.ident.span(),
157 "only a single `#[before_each]` function can be defined",
158 ));
159 }
160
161 if should_error {
162 return Err(parse::Error::new(
163 f.sig.ident.span(),
164 "`#[should_error]` is not allowed on the `#[before_each]` function",
165 ));
166 }
167
168 if ignore {
169 return Err(parse::Error::new(
170 f.sig.ident.span(),
171 "`#[ignore]` is not allowed on the `#[before_each]` function",
172 ));
173 }
174
175 if check_fn_sig(&f.sig).is_err() || f.sig.inputs.len() > 1 {
176 return Err(parse::Error::new(
177 f.sig.ident.span(),
178 "`#[before_each]` function must have signature `fn(state: &mut Type)` (parameter is optional)",
179 ));
180 }
181
182 let input = if f.sig.inputs.len() == 1 {
183 let arg = &f.sig.inputs[0];
184
185 if let Some(ty) = get_mutable_reference_type(arg).cloned() {
188 Some(Input { ty })
189 } else {
190 return Err(parse::Error::new(
192 arg.span(),
193 "parameter must be a mutable reference (`&mut $Type`)",
194 ));
195 }
196 } else {
197 None
198 };
199
200 before_each = Some(BeforeEach { func: f, input });
201 }
202 Attr::AfterEach => {
203 if after_each.is_some() {
204 return Err(parse::Error::new(
205 f.sig.ident.span(),
206 "only a single `#[after_each]` function can be defined",
207 ));
208 }
209
210 if should_error {
211 return Err(parse::Error::new(
212 f.sig.ident.span(),
213 "`#[should_error]` is not allowed on the `#[after_each]` function",
214 ));
215 }
216
217 if ignore {
218 return Err(parse::Error::new(
219 f.sig.ident.span(),
220 "`#[ignore]` is not allowed on the `#[after_each]` function",
221 ));
222 }
223
224 if check_fn_sig(&f.sig).is_err() || f.sig.inputs.len() > 1 {
225 return Err(parse::Error::new(
226 f.sig.ident.span(),
227 "`#[after_each]` function must have signature `fn(state: &mut Type)` (parameter is optional)",
228 ));
229 }
230
231 let input = if f.sig.inputs.len() == 1 {
232 let arg = &f.sig.inputs[0];
233
234 if let Some(ty) = get_mutable_reference_type(arg).cloned() {
237 Some(Input { ty })
238 } else {
239 return Err(parse::Error::new(
241 arg.span(),
242 "parameter must be a mutable reference (`&mut $Type`)",
243 ));
244 }
245 } else {
246 None
247 };
248
249 after_each = Some(AfterEach { func: f, input });
250 }
251 }
252 }
253
254 _ => {
255 untouched_tokens.push(item);
256 }
257 }
258 }
259
260 let krate = format_ident!("mos_test");
261 let ident = module.ident;
262 let mut state_ty = None;
263 let (init_fn, init_expr) = if let Some(init) = init {
264 let init_func = &init.func;
265 let init_ident = &init.func.sig.ident;
266 state_ty = init.state;
267
268 (
269 Some(quote!(#init_func)),
270 Some(quote!(#[allow(dead_code)] let mut state = #init_ident();)),
271 )
272 } else {
273 (None, None)
274 };
275
276 let (before_each_fn, before_each_call) = if let Some(before_each) = before_each {
277 let before_each_func = &before_each.func;
278 let before_each_ident = &before_each.func.sig.ident;
279 let span = before_each.func.sig.ident.span();
280
281 let call = if let Some(input) = before_each.input.as_ref() {
282 if let Some(state) = &state_ty {
283 if input.ty != **state {
284 return Err(parse::Error::new(
285 input.ty.span(),
286 format!(
287 "this type must match `#[init]`s return type: {}",
288 type_ident(state)
289 ),
290 ));
291 }
292 } else {
293 return Err(parse::Error::new(
294 span,
295 "no state was initialized by `#[init]`; signature must be `fn()`",
296 ));
297 }
298
299 quote!(#before_each_ident(&mut state))
300 } else {
301 quote!(#before_each_ident())
302 };
303
304 (Some(quote!(#before_each_func)), Some(quote!(#call)))
305 } else {
306 (None, None)
307 };
308
309 let (after_each_fn, after_each_call) = if let Some(after_each) = after_each {
310 let after_each_func = &after_each.func;
311 let after_each_ident = &after_each.func.sig.ident;
312 let span = after_each.func.sig.ident.span();
313
314 let call = if let Some(input) = after_each.input.as_ref() {
315 if let Some(state) = &state_ty {
316 if input.ty != **state {
317 return Err(parse::Error::new(
318 input.ty.span(),
319 format!(
320 "this type must match `#[init]`s return type: {}",
321 type_ident(state)
322 ),
323 ));
324 }
325 } else {
326 return Err(parse::Error::new(
327 span,
328 "no state was initialized by `#[init]`; signature must be `fn()`",
329 ));
330 }
331
332 quote!(#after_each_ident(&mut state))
333 } else {
334 quote!(#after_each_ident())
335 };
336
337 (Some(quote!(#after_each_func)), Some(quote!(#call)))
338 } else {
339 (None, None)
340 };
341
342 let mut unit_test_calls = vec![];
343 for test in &tests {
344 let should_error = test.should_error;
345 let ignore = test.ignore;
346 let ident = &test.func.sig.ident;
347 let span = test.func.sig.ident.span();
348 let call = if let Some(input) = test.input.as_ref() {
349 if let Some(state) = &state_ty {
350 if input.ty != **state {
351 return Err(parse::Error::new(
352 input.ty.span(),
353 format!(
354 "this type must match `#[init]`s return type: {}",
355 type_ident(state)
356 ),
357 ));
358 }
359 } else {
360 return Err(parse::Error::new(
361 span,
362 "no state was initialized by `#[init]`; signature must be `fn()`",
363 ));
364 }
365
366 quote!(#ident(&mut state))
367 } else {
368 quote!(#ident())
369 };
370 if ignore {
371 unit_test_calls.push(quote!(let _ = #call;));
372 } else {
373 unit_test_calls.push(quote!(
374 #before_each_call;
375 #krate::export::check_outcome(#krate::OutcomeWrapper(#call), #should_error);
376 #after_each_call;
377 ));
378 }
379 }
380
381 let test_functions = tests.iter().map(|test| &test.func);
382 let test_cfgs = tests.iter().map(|test| &test.cfgs);
383 let declare_test_count = {
384 let test_cfgs = test_cfgs.clone();
385 quote!(
386 #[used]
389 #[no_mangle]
390 static DEFMT_TEST_COUNT: usize = {
391 let mut counter = 0;
392 #(
393 #(#test_cfgs)*
394 { counter += 1; }
395 )*
396 counter
397 };
398 )
399 };
400 let unit_test_progress = tests
401 .iter()
402 .map(|test| {
403 let message = format!(
404 "({{}}/{{}}) {} `{}`...",
405 if test.ignore { "ignoring" } else { "running" },
406 test.func.sig.ident
407 );
408 quote_spanned! {
409 test.func.sig.ident.span() => ufmt_stdio::println!(#message, __defmt_test_number, DEFMT_TEST_COUNT);
410 }
411 })
412 .collect::<Vec<_>>();
413 Ok(quote!(
414 #[cfg(test)]
415 mod #ident {
416 use ufmt_stdio::ufmt;
417 #(#untouched_tokens)*
418 #[export_name = "main"]
420 unsafe extern "C" fn __defmt_test_entry() -> isize {
421 #declare_test_count
422 #init_expr
423
424 let mut __defmt_test_number: usize = 1;
425 #(
426 #(#test_cfgs)*
427 {
428 #unit_test_progress
429 #unit_test_calls
430 __defmt_test_number += 1;
431 }
432 )*
433
434 ufmt_stdio::println!("all tests passed!");
435 0
437 }
438
439 #init_fn
440
441 #before_each_fn
442
443 #after_each_fn
444
445 #(
446 #test_functions
447 )*
448 })
449 .into())
450}
451
/// Which role a function inside the `#[tests]` module plays, taken from its marker attribute.
#[derive(Clone, Copy)]
enum Attr {
    /// `#[init]`: runs once and may produce the shared state.
    Init,
    /// `#[before_each]`: runs before every non-ignored test.
    BeforeEach,
    /// `#[test]`: a test case.
    Test,
    /// `#[after_each]`: runs after every non-ignored test.
    AfterEach,
}
459
460struct AfterEach {
461 func: ItemFn,
462 input: Option<Input>,
463}
464
465struct BeforeEach {
466 func: ItemFn,
467 input: Option<Input>,
468}
469
470struct Init {
471 func: ItemFn,
472 state: Option<Box<Type>>,
473}
474
475struct Test {
476 func: ItemFn,
477 cfgs: Vec<Attribute>,
478 input: Option<Input>,
479 should_error: bool,
480 ignore: bool,
481}
482
483struct Input {
484 ty: Type,
485}
486
487fn check_fn_sig(sig: &syn::Signature) -> Result<(), ()> {
489 if sig.constness.is_none()
490 && sig.asyncness.is_none()
491 && sig.unsafety.is_none()
492 && sig.abi.is_none()
493 && sig.generics.params.is_empty()
494 && sig.generics.where_clause.is_none()
495 && sig.variadic.is_none()
496 {
497 Ok(())
498 } else {
499 Err(())
500 }
501}
502
503fn get_mutable_reference_type(arg: &syn::FnArg) -> Option<&Type> {
504 if let syn::FnArg::Typed(pat) = arg {
505 if let syn::Type::Reference(refty) = &*pat.ty {
506 if refty.mutability.is_some() {
507 Some(&refty.elem)
508 } else {
509 None
510 }
511 } else {
512 None
513 }
514 } else {
515 None
516 }
517}
518
519fn extract_cfgs(attrs: &[Attribute]) -> Vec<Attribute> {
520 let mut cfgs = vec![];
521
522 for attr in attrs {
523 if attr.path.is_ident("cfg") {
524 cfgs.push(attr.clone());
525 }
526 }
527
528 cfgs
529}
530
531fn type_ident(ty: impl AsRef<syn::Type>) -> String {
532 let mut ident = String::new();
533 let ty = ty.as_ref();
534 let ty = format!("{}", quote!(#ty));
535 ty.split_whitespace().for_each(|t| ident.push_str(t));
536 ident
537}