1#![recursion_limit = "128"]
2#![deny(unused_must_use)]
3extern crate proc_macro;
4
5use proc_macro2::{Span, TokenStream};
6use quote::quote;
7use std::collections::HashMap;
8use syn::parse::{Parse, ParseStream, Result as ParseResult};
9use syn::punctuated::Punctuated;
10use syn::spanned::Spanned;
11use syn::token::Comma;
12use syn::{braced, parse_macro_input, FnArg, Ident, ItemFn, Pat, PatIdent, PatType, Token, Type};
13
/// Shorthand for syn's parse error; the macros below turn it into a `compile_error!`
/// via `to_compile_error()` to report misuse at the offending span.
type Error = syn::parse::Error;
15
/// One entry of the `files` argument mapping: either `name = "..."` or
/// `name in "..."` (the pattern entry), the latter optionally followed by `if !path`.
struct TemplateArg {
    /// Name of the test-function parameter this entry maps.
    ident: syn::Ident,
    /// `true` for `name in "..."` (the pattern entry), `false` for `name = "..."`.
    is_pattern: bool,
    /// Path given after `if !`; only parsed for pattern entries.
    ignore_fn: Option<syn::Path>,
    /// The string literal on the right-hand side (a regular expression for the pattern entry).
    value: syn::LitStr,
}
22
23impl Parse for TemplateArg {
24 fn parse(input: ParseStream) -> ParseResult<Self> {
25 let mut ignore_fn = None;
26 let ident = input.parse::<syn::Ident>()?;
27
28 let is_pattern = if input.peek(syn::token::In) {
29 let _in = input.parse::<syn::token::In>()?;
30 true
31 } else {
32 let _eq = input.parse::<syn::token::Eq>()?;
33 false
34 };
35 let value = input.parse::<syn::LitStr>()?;
36 if is_pattern && input.peek(syn::token::If) {
37 let _if = input.parse::<syn::token::If>()?;
38 let _not = input.parse::<syn::token::Not>()?;
39 ignore_fn = Some(input.parse::<syn::Path>()?);
40 }
41 Ok(Self {
42 ident,
43 is_pattern,
44 ignore_fn,
45 value,
46 })
47 }
48}
49
/// Parsed arguments of the `files` test attributes: `"root", { mapping, ... }`.
struct FilesTestArgs {
    /// Value of the leading string literal (the root directory passed to the descriptor).
    root: String,
    /// Mapping entries keyed by the function parameter they describe.
    args: HashMap<Ident, TemplateArg>,
}
63
64impl Parse for FilesTestArgs {
66 fn parse(input: ParseStream) -> ParseResult<Self> {
67 let root = input.parse::<syn::LitStr>()?;
68 let _comma = input.parse::<syn::token::Comma>()?;
69 let content;
70 let _brace_token = braced!(content in input);
71
72 let args: Punctuated<TemplateArg, Comma> =
73 content.parse_terminated(TemplateArg::parse, Token![,])?;
74 let args = args
75 .into_pairs()
76 .map(|p| {
77 let value = p.into_value();
78 (value.ident.clone(), value)
79 })
80 .collect();
81
82 Ok(Self {
83 root: root.value(),
84 args,
85 })
86 }
87}
88
/// How a generated test descriptor gets registered (interpreted by `test_registration`).
enum Registration {
    /// Emit a registration function annotated with `::datatest::__internal::ctor`.
    Ctor,
    /// Annotate the descriptor static with `#[test_case]`.
    Nightly,
}
96
97#[proc_macro_attribute]
99pub fn files_ctor_registration(
100 args: proc_macro::TokenStream,
101 func: proc_macro::TokenStream,
102) -> proc_macro::TokenStream {
103 guarded_test_attribute(
104 args,
105 func,
106 Ident::new("files_ctor_internal", Span::call_site()),
107 )
108}
109
110#[proc_macro_attribute]
112pub fn files_test_case_registration(
113 args: proc_macro::TokenStream,
114 func: proc_macro::TokenStream,
115) -> proc_macro::TokenStream {
116 guarded_test_attribute(
117 args,
118 func,
119 Ident::new("files_test_case_internal", Span::call_site()),
120 )
121}
122
123#[proc_macro_attribute]
124pub fn files_ctor_internal(
125 args: proc_macro::TokenStream,
126 func: proc_macro::TokenStream,
127) -> proc_macro::TokenStream {
128 files_internal(args, func, Registration::Ctor)
129}
130
131#[proc_macro_attribute]
132pub fn files_test_case_internal(
133 args: proc_macro::TokenStream,
134 func: proc_macro::TokenStream,
135) -> proc_macro::TokenStream {
136 files_internal(args, func, Registration::Nightly)
137}
138
139fn files_internal(
181 args: proc_macro::TokenStream,
182 func: proc_macro::TokenStream,
183 channel: Registration,
184) -> proc_macro::TokenStream {
185 let mut func_item: ItemFn = parse_macro_input!(func as ItemFn);
186 let args: FilesTestArgs = parse_macro_input!(args as FilesTestArgs);
187 let info = handle_common_attrs(&mut func_item, false);
188 let func_ident = &func_item.sig.ident;
189 let func_name_str = func_ident.to_string();
190 let desc_ident = Ident::new(&format!("__TEST_{}", func_ident), func_ident.span());
191 let trampoline_func_ident = Ident::new(
192 &format!("__TEST_TRAMPOLINE_{}", func_ident),
193 func_ident.span(),
194 );
195 let ignore = info.ignore;
196 let root = args.root;
197 let mut pattern_idx = None;
198 let mut params: Vec<String> = Vec::new();
199 let mut invoke_args: Vec<TokenStream> = Vec::new();
200 let mut ignore_fn = None;
201
202 for (mut idx, arg) in func_item.sig.inputs.iter().enumerate() {
209 match match_arg(arg) {
210 Some((pat_ident, ty)) => {
211 if info.bench {
212 if idx == 0 {
213 invoke_args.push(quote!(#pat_ident));
215 continue;
216 } else {
217 idx -= 1;
218 }
219 }
220
221 if let Some(arg) = args.args.get(&pat_ident.ident) {
222 if arg.is_pattern {
223 if pattern_idx.is_some() {
224 return Error::new(arg.ident.span(), "two patterns are not allowed!")
225 .to_compile_error()
226 .into();
227 }
228 pattern_idx = Some(idx);
229 ignore_fn = arg.ignore_fn.clone();
230 }
231
232 params.push(arg.value.value());
233 invoke_args.push(quote! {
234 ::datatest::__internal::TakeArg::take(&mut <#ty as ::datatest::__internal::DeriveArg>::derive(&paths_arg[#idx]))
235 })
236 } else {
237 return Error::new(pat_ident.span(), "mapping is not defined for the argument")
238 .to_compile_error()
239 .into();
240 }
241 }
242 None => {
243 return Error::new(
244 arg.span(),
245 "unexpected argument; only simple argument types are allowed (`&str`, `String`, `&[u8]`, `Vec<u8>`, `&Path`, etc)",
246 ).to_compile_error().into();
247 }
248 }
249 }
250
251 let ignore_func_ref = if let Some(ignore_fn) = ignore_fn {
252 quote!(Some(#ignore_fn))
253 } else {
254 quote!(None)
255 };
256
257 if pattern_idx.is_none() {
258 return Error::new(
259 Span::call_site(),
260 "must have exactly one pattern mapping defined via `pattern in r#\"<regular expression>\"`",
261 )
262 .to_compile_error()
263 .into();
264 }
265
266 let (kind, bencher_param) = if info.bench {
267 (
268 quote!(BenchFn),
269 quote!(bencher: &mut ::datatest::__internal::Bencher,),
270 )
271 } else {
272 (quote!(TestFn), quote!())
273 };
274
275 let registration = test_registration(channel, &desc_ident);
276 let output = quote! {
277 #registration
278 #[automatically_derived]
279 #[allow(non_upper_case_globals)]
280 static #desc_ident: ::datatest::__internal::FilesTestDesc = ::datatest::__internal::FilesTestDesc {
281 name: concat!(module_path!(), "::", #func_name_str),
282 ignore: #ignore,
283 root: #root,
284 params: &[#(#params),*],
285 pattern: #pattern_idx,
286 ignorefn: #ignore_func_ref,
287 testfn: ::datatest::__internal::FilesTestFn::#kind(#trampoline_func_ident),
288 source_file: file!(),
289 };
290
291 #[automatically_derived]
292 #[allow(non_snake_case)]
293 fn #trampoline_func_ident(#bencher_param paths_arg: &[::std::path::PathBuf]) {
294 let result = #func_ident(#(#invoke_args),*);
295 ::datatest::__internal::assert_test_result(result);
296 }
297
298 #func_item
299 };
300 output.into()
301}
302
303fn match_arg(arg: &FnArg) -> Option<(&PatIdent, &Type)> {
304 if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
305 if let Pat::Ident(pat_ident) = pat.as_ref() {
306 return Some((pat_ident, ty));
307 }
308 }
309 None
310}
311
/// Panic expectation read from a `#[should_panic]` attribute on a regular test.
enum ShouldPanic {
    /// No `#[should_panic]` attribute was present.
    No,
    /// `#[should_panic]` without a recognised expected message.
    Yes,
    /// `#[should_panic]` carrying an expected message.
    YesWithMessage(String),
}
317
/// Test-related attributes stripped from a function by `handle_common_attrs`.
struct FuncInfo {
    /// `#[ignore]` was present.
    ignore: bool,
    /// `#[bench]` was present.
    bench: bool,
    /// Parsed `#[should_panic]` (only inspected for regular tests).
    should_panic: ShouldPanic,
}
323
324fn handle_common_attrs(func: &mut ItemFn, regular_test: bool) -> FuncInfo {
327 let test_pos = func
331 .attrs
332 .iter()
333 .position(|attr| attr.path().is_ident("test"));
334 if let Some(pos) = test_pos {
335 func.attrs.remove(pos);
336 }
337
338 let bench_pos = func
340 .attrs
341 .iter()
342 .position(|attr| attr.path().is_ident("bench"));
343 if let Some(pos) = bench_pos {
344 func.attrs.remove(pos);
345 }
346
347 let ignore_pos = func
349 .attrs
350 .iter()
351 .position(|attr| attr.path().is_ident("ignore"));
352 if let Some(pos) = ignore_pos {
353 func.attrs.remove(pos);
354 }
355
356 let mut should_panic = ShouldPanic::No;
357 if regular_test {
358 let should_panic_pos = func
360 .attrs
361 .iter()
362 .position(|attr| attr.path().is_ident("should_panic"));
363 if let Some(pos) = should_panic_pos {
364 let attr = &func.attrs[pos];
365 should_panic = parse_should_panic(attr);
366 func.attrs.remove(pos);
367 }
368 }
369
370 FuncInfo {
371 ignore: ignore_pos.is_some(),
372 bench: bench_pos.is_some(),
373 should_panic,
374 }
375}
376
377#[allow(clippy::collapsible_match)]
378fn parse_should_panic(attr: &syn::Attribute) -> ShouldPanic {
379 let mut message: Option<String> = None;
380 _ = attr.parse_nested_meta(|meta| {
381 if meta.path.is_ident("expected") {
382 if let Ok(v) = meta.value() {
383 let mut value = v.to_string();
384 if value.starts_with("\"") {
385 value.remove(0);
386 }
387 if value.ends_with("\"") {
388 value.pop();
389 }
390 message = Some(value);
391 }
392 }
393 Ok(())
394 });
395 match message {
396 Some(message) => ShouldPanic::YesWithMessage(message),
397 None => ShouldPanic::Yes,
398 }
399}
400
/// Argument of the `data` test attributes: a string literal path (turned into a `yaml(...)`
/// call by `data_internal`) or an arbitrary expression producing the test cases.
// clippy flags the size gap between the variants; presumably tolerated because only one value
// is parsed per attribute — NOTE(review): confirm before boxing `Expr`.
#[allow(clippy::large_enum_variant)]
enum DataTestArgs {
    /// `#[data("path")]`
    Literal(syn::LitStr),
    /// `#[data(some_expression)]`
    Expression(syn::Expr),
}
409
410impl Parse for DataTestArgs {
412 fn parse(input: ParseStream) -> ParseResult<Self> {
413 let lookahead = input.lookahead1();
414 if lookahead.peek(syn::LitStr) {
415 input.parse::<syn::LitStr>().map(DataTestArgs::Literal)
416 } else {
417 input.parse::<syn::Expr>().map(DataTestArgs::Expression)
418 }
419 }
420}
421
422#[proc_macro_attribute]
424pub fn data_ctor_registration(
425 args: proc_macro::TokenStream,
426 func: proc_macro::TokenStream,
427) -> proc_macro::TokenStream {
428 guarded_test_attribute(
429 args,
430 func,
431 Ident::new("data_ctor_internal", Span::call_site()),
432 )
433}
434
435#[proc_macro_attribute]
437pub fn data_test_case_registration(
438 args: proc_macro::TokenStream,
439 func: proc_macro::TokenStream,
440) -> proc_macro::TokenStream {
441 guarded_test_attribute(
442 args,
443 func,
444 Ident::new("data_test_case_internal", Span::call_site()),
445 )
446}
447
448#[proc_macro_attribute]
449pub fn data_ctor_internal(
450 args: proc_macro::TokenStream,
451 func: proc_macro::TokenStream,
452) -> proc_macro::TokenStream {
453 data_internal(args, func, Registration::Ctor)
454}
455
456#[proc_macro_attribute]
457pub fn data_test_case_internal(
458 args: proc_macro::TokenStream,
459 func: proc_macro::TokenStream,
460) -> proc_macro::TokenStream {
461 data_internal(args, func, Registration::Nightly)
462}
463
464fn data_internal(
465 args: proc_macro::TokenStream,
466 func: proc_macro::TokenStream,
467 channel: Registration,
468) -> proc_macro::TokenStream {
469 let mut func_item = parse_macro_input!(func as ItemFn);
470 let cases: DataTestArgs = parse_macro_input!(args as DataTestArgs);
471 let info = handle_common_attrs(&mut func_item, false);
472 let cases = match cases {
473 DataTestArgs::Literal(path) => quote!(datatest::yaml(#path)),
474 DataTestArgs::Expression(expr) => quote!(#expr),
475 };
476 let func_ident = &func_item.sig.ident;
477
478 let func_name_str = func_ident.to_string();
479 let desc_ident = Ident::new(&format!("__TEST_{}", func_ident), func_ident.span());
480 let describe_func_ident = Ident::new(
481 &format!("__TEST_DESCRIBE_{}", func_ident),
482 func_ident.span(),
483 );
484 let trampoline_func_ident = Ident::new(
485 &format!("__TEST_TRAMPOLINE_{}", func_ident),
486 func_ident.span(),
487 );
488
489 let ignore = info.ignore;
490 let mut args = func_item.sig.inputs.iter();
492
493 if info.bench {
494 args.next();
497 }
498
499 let arg = args.next();
500 let ty = match arg {
501 Some(FnArg::Typed(PatType { ty, .. })) => Some(ty.as_ref()),
502 _ => None,
503 };
504 let (ref_token, ty) = match ty {
505 Some(syn::Type::Reference(type_ref)) => (quote!(&), Some(type_ref.elem.as_ref())),
506 _ => (TokenStream::new(), ty),
507 };
508
509 let (case_ctor, bencher_param, bencher_arg) = if info.bench {
510 (
511 quote!(::datatest::__internal::DataTestFn::BenchFn(Box::new(::datatest::__internal::DataBenchFn(#trampoline_func_ident, case)))),
512 quote!(bencher: &mut ::datatest::__internal::Bencher,),
513 quote!(bencher,),
514 )
515 } else {
516 (
517 quote!(::datatest::__internal::DataTestFn::TestFn(Box::new(move || #trampoline_func_ident(case)))),
518 quote!(),
519 quote!(),
520 )
521 };
522
523 let registration = test_registration(channel, &desc_ident);
524 let output = quote! {
525 #registration
526 #[automatically_derived]
527 #[allow(non_upper_case_globals)]
528 static #desc_ident: ::datatest::__internal::DataTestDesc = ::datatest::__internal::DataTestDesc {
529 name: concat!(module_path!(), "::", #func_name_str),
530 ignore: #ignore,
531 describefn: #describe_func_ident,
532 source_file: file!(),
533 };
534
535 #[automatically_derived]
536 #[allow(non_snake_case)]
537 fn #trampoline_func_ident(#bencher_param arg: #ty) {
538 let result = #func_ident(#bencher_arg #ref_token arg);
539 ::datatest::__internal::assert_test_result(result);
540 }
541
542 #[automatically_derived]
543 #[allow(non_snake_case)]
544 fn #describe_func_ident() -> Vec<::datatest::DataTestCaseDesc<::datatest::__internal::DataTestFn>> {
545 let result = #cases
546 .into_iter()
547 .map(|input| {
548 let case = input.case;
549 ::datatest::DataTestCaseDesc {
550 case: #case_ctor,
551 name: input.name,
552 location: input.location,
553 }
554 })
555 .collect::<Vec<_>>();
556 assert!(!result.is_empty(), "no test cases were found!");
557 result
558 }
559
560 #func_item
561 };
562 output.into()
563}
564
565fn test_registration(channel: Registration, desc_ident: &syn::Ident) -> TokenStream {
566 match channel {
567 Registration::Nightly => quote!(#[test_case]),
569 Registration::Ctor => {
571 let registration_fn =
572 syn::Ident::new(&format!("{}__REGISTRATION", desc_ident), desc_ident.span());
573 let check_fn = syn::Ident::new(&format!("{}__CHECK", desc_ident), desc_ident.span());
574 let tokens = quote! {
575 #[automatically_derived]
576 #[allow(non_snake_case)]
577 #[datatest::__internal::ctor]
578 fn #registration_fn() {
579 use ::datatest::__internal::RegistrationNode;
580 static mut REGISTRATION: RegistrationNode = RegistrationNode {
581 descriptor: &#desc_ident,
582 next: None,
583 };
584 ::datatest::__internal::register(unsafe { &mut REGISTRATION });
586 }
587
588 #[automatically_derived]
594 #[allow(non_snake_case)]
595 mod #check_fn {
596 #[datatest::__internal::dtor]
597 fn check_fn() {
598 ::datatest::__internal::check_test_runner();
599 }
600 }
601 };
602 tokens
603 }
604 }
605}
606
607#[proc_macro_attribute]
610pub fn test_ctor_registration(
611 _args: proc_macro::TokenStream,
612 func: proc_macro::TokenStream,
613) -> proc_macro::TokenStream {
614 let mut func_item = parse_macro_input!(func as ItemFn);
615 let info = handle_common_attrs(&mut func_item, true);
616 let func_ident = &func_item.sig.ident;
617 let func_name_str = func_ident.to_string();
618 let desc_ident = Ident::new(&format!("__TEST_{}", func_ident), func_ident.span());
619
620 let ignore = info.ignore;
621 let should_panic = match info.should_panic {
622 ShouldPanic::No => quote!(::datatest::__internal::RegularShouldPanic::No),
623 ShouldPanic::Yes => quote!(::datatest::__internal::RegularShouldPanic::Yes),
624 ShouldPanic::YesWithMessage(v) => {
625 quote!(::datatest::__internal::RegularShouldPanic::YesWithMessage(#v))
626 }
627 };
628 let registration = test_registration(Registration::Ctor, &desc_ident);
629 let output = quote! {
630 #registration
631 #[automatically_derived]
632 #[allow(non_upper_case_globals)]
633 static #desc_ident: ::datatest::__internal::RegularTestDesc = ::datatest::__internal::RegularTestDesc {
634 name: concat!(module_path!(), "::", #func_name_str),
635 ignore: #ignore,
636 testfn: || {
637 let result = #func_ident();
638 ::datatest::__internal::assert_test_result(result);
639 },
640 should_panic: #should_panic,
641 source_file: file!(),
642 };
643
644 #func_item
645 };
646
647 output.into()
648}
649
650fn guarded_test_attribute(
651 args: proc_macro::TokenStream,
652 item: proc_macro::TokenStream,
653 implementation: Ident,
654) -> proc_macro::TokenStream {
655 let args: TokenStream = args.into();
656 let header = quote! {
657 #[cfg(test)]
658 #[::datatest::__internal::#implementation(#args)]
659 };
660 let mut out: proc_macro::TokenStream = header.into();
661 out.extend(item);
662 out
663}