1use std::collections::HashSet;
30
31use proc_macro::TokenStream;
32use proc_macro2::TokenStream as TokenStream2;
33use quote::{format_ident, quote};
34use syn::parse::{Parse, ParseStream};
35use syn::punctuated::Punctuated;
36use syn::spanned::Spanned;
37use syn::{Expr, FnArg, Ident, Index, ItemFn, LitStr, Pat, Path, Token, braced, parenthesized};
38
39struct StructPattern {
41 path: Path,
42 fields: Vec<(Ident, Expr)>,
43 rest: bool,
44}
45
46impl Parse for StructPattern {
47 fn parse(input: ParseStream) -> syn::Result<Self> {
48 let path: Path = input.parse()?;
49 let content;
50 braced!(content in input);
51 let (fields, rest) = parse_named_fields(&content)?;
52 Ok(Self { path, fields, rest })
53 }
54}
55
56struct TuplePattern {
58 path: Path,
59 elems: Vec<Expr>,
60 rest: bool,
61}
62
63impl Parse for TuplePattern {
64 fn parse(input: ParseStream) -> syn::Result<Self> {
65 let path: Path = input.parse()?;
66 let content;
67 parenthesized!(content in input);
68 let (elems, rest) = parse_positional_fields(&content)?;
69 Ok(Self { path, elems, rest })
70 }
71}
72
73enum VariantBody {
75 Struct {
76 fields: Vec<(Ident, Expr)>,
77 rest: bool,
78 },
79 Tuple {
80 elems: Vec<Expr>,
81 rest: bool,
82 },
83 Unit,
84}
85
86struct VariantPattern {
89 path: Path,
90 body: VariantBody,
91}
92
93impl Parse for VariantPattern {
94 fn parse(input: ParseStream) -> syn::Result<Self> {
95 let path: Path = input.parse()?;
96 let body = if input.peek(syn::token::Brace) {
97 let content;
98 braced!(content in input);
99 let (fields, rest) = parse_named_fields(&content)?;
100 VariantBody::Struct { fields, rest }
101 } else if input.peek(syn::token::Paren) {
102 let content;
103 parenthesized!(content in input);
104 let (elems, rest) = parse_positional_fields(&content)?;
105 VariantBody::Tuple { elems, rest }
106 } else {
107 VariantBody::Unit
108 };
109 Ok(Self { path, body })
110 }
111}
112
113fn parse_named_fields(content: ParseStream) -> syn::Result<(Vec<(Ident, Expr)>, bool)> {
116 let mut fields = Vec::new();
117 let mut rest = false;
118 while !content.is_empty() {
119 if content.peek(Token![..]) {
120 content.parse::<Token![..]>()?;
121 rest = true;
122 break;
123 }
124 let name: Ident = content.parse()?;
125 content.parse::<Token![:]>()?;
126 let expr: Expr = content.parse()?;
127 fields.push((name, expr));
128 if content.is_empty() {
129 break;
130 }
131 content.parse::<Token![,]>()?;
132 }
133 if !content.is_empty() {
134 return Err(content.error("`..` must be the final element of the pattern"));
135 }
136 Ok((fields, rest))
137}
138
139fn parse_positional_fields(content: ParseStream) -> syn::Result<(Vec<Expr>, bool)> {
142 let mut elems = Vec::new();
143 let mut rest = false;
144 while !content.is_empty() {
145 if content.peek(Token![..]) {
146 content.parse::<Token![..]>()?;
147 rest = true;
148 break;
149 }
150 elems.push(content.parse()?);
151 if content.is_empty() {
152 break;
153 }
154 content.parse::<Token![,]>()?;
155 }
156 if !content.is_empty() {
157 return Err(content.error("`..` must be the final element of the pattern"));
158 }
159 Ok((elems, rest))
160}
161
162fn split_variant_path(path: &Path) -> syn::Result<(Path, Ident)> {
165 if path.segments.len() < 2 {
166 return Err(syn::Error::new_spanned(
167 path,
168 "expected an enum variant path like `MyEnum::Variant`",
169 ));
170 }
171 let kept = path.segments.len() - 1;
172 let segments: Punctuated<syn::PathSegment, Token![::]> =
173 path.segments.iter().take(kept).cloned().collect();
174 let enum_path = Path {
175 leading_colon: path.leading_colon,
176 segments,
177 };
178 let variant_ident = match path.segments.last() {
179 Some(seg) => seg.ident.clone(),
180 None => return Err(syn::Error::new_spanned(path, "missing variant name")),
181 };
182 Ok((enum_path, variant_ident))
183}
184
185struct FieldIdents {
189 matcher_ty: Vec<Ident>,
190 field_ty: Vec<Ident>,
191 matcher_field: Vec<Ident>,
192 binding: Vec<Ident>,
193}
194
195fn field_idents(n: usize) -> FieldIdents {
196 FieldIdents {
197 matcher_ty: (0..n).map(|i| format_ident!("__TbM{}", i)).collect(),
198 field_ty: (0..n).map(|i| format_ident!("__TbF{}", i)).collect(),
199 matcher_field: (0..n).map(|i| format_ident!("__tb_m{}", i)).collect(),
200 binding: (0..n).map(|i| format_ident!("__tb_f{}", i)).collect(),
201 }
202}
203
204fn field_check_blocks(
208 matcher_field: &[Ident],
209 binding: &[Ident],
210 labels: &[String],
211) -> Vec<TokenStream2> {
212 matcher_field
213 .iter()
214 .zip(binding)
215 .zip(labels)
216 .map(|((field, bind), label)| {
217 let label = label.as_str();
218 quote! {
219 {
220 let __tb_result = ::test_better::Matcher::check(&self.#field, #bind);
221 if !__tb_result.matched {
222 let __tb_inner = match __tb_result.failure {
223 ::core::option::Option::Some(__tb_mismatch) => __tb_mismatch,
224 ::core::option::Option::None => ::test_better::Mismatch::new(
225 ::test_better::Matcher::description(&self.#field),
226 "the field matcher reported failure without detail",
227 ),
228 };
229 return ::test_better::MatchResult::fail(::test_better::Mismatch {
230 expected: ::test_better::Description::labeled(
231 #label,
232 __tb_inner.expected,
233 ),
234 actual: __tb_inner.actual,
235 diff: __tb_inner.diff,
236 });
237 }
238 }
239 }
240 })
241 .collect()
242}
243
244fn description_fold(matcher_field: &[Ident], labels: &[String]) -> TokenStream2 {
246 let mut parts = matcher_field.iter().zip(labels).map(|(field, label)| {
247 let label = label.as_str();
248 quote! {
249 ::test_better::Description::labeled(
250 #label,
251 ::test_better::Matcher::description(&self.#field),
252 )
253 }
254 });
255 match parts.next() {
256 Some(first) => {
257 let mut acc = first;
258 for part in parts {
259 acc = quote! { #acc.and(#part) };
260 }
261 acc
262 }
263 None => quote! { ::test_better::Description::text("a matching value") },
264 }
265}
266
267fn exhaustiveness_fn(target: &TokenStream2, stmt: Option<TokenStream2>) -> TokenStream2 {
270 match stmt {
271 Some(stmt) => quote! {
272 #[allow(dead_code, unused_variables, irrefutable_let_patterns, clippy::all)]
273 fn __tb_assert_exhaustive(__tb_value: &#target) {
274 #stmt
275 }
276 },
277 None => quote! {},
278 }
279}
280
/// Emits the matcher expression shared by `matches_struct!` and `matches_tuple!`.
///
/// The result is a block expression that locally defines:
/// - `__TbStructuralMatcher`, which holds the projection closure and one matcher
///   per field, with an `impl ::test_better::Matcher<__TbS>`;
/// - `__tb_make`, a constructor returning `impl ::test_better::Matcher<__TbS>`;
/// - the function produced by [`exhaustiveness_fn`], if any.
///
/// The block then evaluates to `__tb_make(<projection>, <field_exprs>...)`.
///
/// * `target` — the matched type; only used as the parameter type of the
///   exhaustiveness function.
/// * `labels` — one label per checked field, used in descriptions and failure
///   messages. Its length decides how many field matchers are generated, so it
///   is expected to equal `field_exprs.len()` and the arity of the tuple
///   returned by `projection`.
/// * `field_exprs` — the user-supplied matcher expressions, in label order.
/// * `projection` — a closure expression taking `&S` and returning a tuple of
///   references to the checked fields, or `()` when there are none.
/// * `exhaustiveness` — optional statement placed in the body of the generated
///   `__tb_assert_exhaustive` function.
fn gen_plain(
    target: &TokenStream2,
    labels: &[String],
    field_exprs: &[&Expr],
    projection: TokenStream2,
    exhaustiveness: Option<TokenStream2>,
) -> TokenStream2 {
    // Synthetic names for the per-field generic parameters, struct fields and
    // `check` bindings.
    let idents = field_idents(labels.len());
    let FieldIdents {
        matcher_ty,
        field_ty,
        matcher_field,
        binding,
    } = &idents;
    let n = labels.len();
    let assertion = exhaustiveness_fn(target, exhaustiveness);

    // No fields to check: the matcher only runs the projection (which returns
    // `()`) and then always passes. Its description is "a matching value".
    if n == 0 {
        return quote! {
            {
                #[allow(non_camel_case_types, dead_code, clippy::all)]
                struct __TbStructuralMatcher<__TbP> {
                    __tb_project: __TbP,
                }

                #[allow(clippy::all)]
                impl<__TbS, __TbP> ::test_better::Matcher<__TbS>
                    for __TbStructuralMatcher<__TbP>
                where
                    __TbP: ::core::ops::Fn(&__TbS) -> (),
                {
                    fn check(&self, __tb_actual: &__TbS) -> ::test_better::MatchResult {
                        let () = (self.__tb_project)(__tb_actual);
                        ::test_better::MatchResult::pass()
                    }

                    fn description(&self) -> ::test_better::Description {
                        ::test_better::Description::text("a matching value")
                    }
                }

                #[allow(clippy::all)]
                fn __tb_make<__TbS, __TbP>(
                    __tb_project: __TbP,
                ) -> impl ::test_better::Matcher<__TbS>
                where
                    __TbP: ::core::ops::Fn(&__TbS) -> (),
                {
                    __TbStructuralMatcher { __tb_project }
                }

                #assertion

                __tb_make(#projection)
            }
        };
    }

    // One early-returning check per field, run in field order, plus the
    // `.and(...)`-joined description of all field matchers.
    let checks = field_check_blocks(matcher_field, binding, labels);
    let desc = description_fold(matcher_field, labels);

    // General case: the projection returns `(&F0, &F1, ...)`, which fixes each
    // `__TbF{i}`; each stored matcher `__TbM{i}` must implement `Matcher<__TbF{i}>`.
    quote! {
        {
            #[allow(non_camel_case_types, dead_code, clippy::all)]
            struct __TbStructuralMatcher<__TbP, #( #matcher_ty, )*> {
                __tb_project: __TbP,
                #( #matcher_field: #matcher_ty, )*
            }

            #[allow(clippy::all)]
            impl<__TbS, #( #field_ty, )* __TbP, #( #matcher_ty, )*>
                ::test_better::Matcher<__TbS>
                for __TbStructuralMatcher<__TbP, #( #matcher_ty, )*>
            where
                __TbP: ::core::ops::Fn(&__TbS) -> ( #( &#field_ty, )* ),
                #( #matcher_ty: ::test_better::Matcher<#field_ty>, )*
            {
                fn check(&self, __tb_actual: &__TbS) -> ::test_better::MatchResult {
                    let ( #( #binding, )* ) = (self.__tb_project)(__tb_actual);
                    #( #checks )*
                    ::test_better::MatchResult::pass()
                }

                fn description(&self) -> ::test_better::Description {
                    #desc
                }
            }

            #[allow(clippy::all)]
            fn __tb_make<__TbS, #( #field_ty, )* __TbP, #( #matcher_ty, )*>(
                __tb_project: __TbP,
                #( #matcher_field: #matcher_ty, )*
            ) -> impl ::test_better::Matcher<__TbS>
            where
                __TbP: ::core::ops::Fn(&__TbS) -> ( #( &#field_ty, )* ),
                #( #matcher_ty: ::test_better::Matcher<#field_ty>, )*
            {
                __TbStructuralMatcher {
                    __tb_project,
                    #( #matcher_field, )*
                }
            }

            #assertion

            __tb_make(#projection, #( #field_exprs, )*)
        }
    }
}
396
397fn gen_struct(path: &Path, fields: &[(Ident, Expr)], rest: bool) -> TokenStream2 {
398 let target = quote! { #path };
399 let labels: Vec<String> = fields.iter().map(|(name, _)| name.to_string()).collect();
400 let field_exprs: Vec<&Expr> = fields.iter().map(|(_, expr)| expr).collect();
401 let field_names: Vec<&Ident> = fields.iter().map(|(name, _)| name).collect();
402
403 let projection = if fields.is_empty() {
404 quote! { |_: &#path| () }
405 } else {
406 quote! { |__tb_subject: &#path| ( #( &__tb_subject.#field_names, )* ) }
407 };
408
409 let exhaustiveness = if rest {
410 None
411 } else {
412 Some(quote! { let #path { #( #field_names: _, )* } = __tb_value; })
413 };
414
415 gen_plain(&target, &labels, &field_exprs, projection, exhaustiveness)
416}
417
418fn gen_tuple(path: &Path, elems: &[Expr], rest: bool) -> TokenStream2 {
419 let target = quote! { #path };
420 let labels: Vec<String> = (0..elems.len()).map(|i| i.to_string()).collect();
421 let field_exprs: Vec<&Expr> = elems.iter().collect();
422 let indices: Vec<Index> = (0..elems.len()).map(Index::from).collect();
423
424 let projection = if elems.is_empty() {
425 quote! { |_: &#path| () }
426 } else {
427 quote! { |__tb_subject: &#path| ( #( &__tb_subject.#indices, )* ) }
428 };
429
430 let exhaustiveness = if rest {
431 None
432 } else {
433 let holes = elems.iter().map(|_| quote!(_));
434 Some(quote! { let #path( #( #holes, )* ) = __tb_value; })
435 };
436
437 gen_plain(&target, &labels, &field_exprs, projection, exhaustiveness)
438}
439
/// Emits the matcher expression produced by `matches_variant!`.
///
/// `pattern.path` is split by [`split_variant_path`]:
/// - everything before the last segment is used as the matched enum type;
/// - the last segment names the variant.
///
/// The expansion is a block expression that locally defines:
/// - a `__TbVariantMatcher` type implementing `::test_better::Matcher`;
/// - a `__tb_make` constructor;
/// - when the pattern has fields and no `..`, an unused `__tb_assert_exhaustive`
///   function.
///
/// It evaluates to `__tb_make(<projection>, <field_exprs>...)`.
///
/// The generated matcher requires the subject type to implement `Debug`. When
/// the subject is a different variant, the failure's expected side is the text
/// "the <Variant> variant" and its actual side is the subject's `{:?}` output.
///
/// # Errors
///
/// Propagates the error from [`split_variant_path`] when the path has fewer
/// than two segments.
fn gen_variant(pattern: &VariantPattern) -> syn::Result<TokenStream2> {
    let (enum_path, variant_ident) = split_variant_path(&pattern.path)?;
    let path = &pattern.path;
    let target = quote! { #enum_path };
    let variant_name = variant_ident.to_string();
    let variant_label = format!("the {variant_name} variant");

    // Per-shape pieces:
    // - the field labels;
    // - the user's matcher expressions;
    // - a closure projecting `&Enum` to `Option<(field refs...)>`, yielding
    //   `None` for any other variant;
    // - an optional exhaustiveness statement.
    let (labels, field_exprs, projection, exhaustiveness): (
        Vec<String>,
        Vec<&Expr>,
        TokenStream2,
        Option<TokenStream2>,
    ) = match &pattern.body {
        VariantBody::Struct { fields, rest } => {
            let labels: Vec<String> = fields.iter().map(|(name, _)| name.to_string()).collect();
            let field_exprs: Vec<&Expr> = fields.iter().map(|(_, expr)| expr).collect();
            let field_names: Vec<&Ident> = fields.iter().map(|(name, _)| name).collect();
            let bindings: Vec<Ident> = (0..fields.len())
                .map(|i| format_ident!("__tb_p{}", i))
                .collect();
            // The projection always ends in `..` so it can bind just the listed
            // fields; completeness is checked separately by `exhaustiveness`.
            let projection = quote! {
                |__tb_subject: &#enum_path| match __tb_subject {
                    #path { #( #field_names: #bindings, )* .. } =>
                        ::core::option::Option::Some(( #( #bindings, )* )),
                    _ => ::core::option::Option::None,
                }
            };
            // Without `..`, this `if let` fails to compile if any field of the
            // variant is missing from the list.
            let exhaustiveness = if *rest {
                None
            } else {
                Some(quote! { if let #path { #( #field_names: _, )* } = __tb_value {} })
            };
            (labels, field_exprs, projection, exhaustiveness)
        }
        VariantBody::Tuple { elems, rest } => {
            let labels: Vec<String> = (0..elems.len()).map(|i| i.to_string()).collect();
            let field_exprs: Vec<&Expr> = elems.iter().collect();
            let bindings: Vec<Ident> = (0..elems.len())
                .map(|i| format_ident!("__tb_p{}", i))
                .collect();
            let projection = quote! {
                |__tb_subject: &#enum_path| match __tb_subject {
                    #path( #( #bindings, )* .. ) =>
                        ::core::option::Option::Some(( #( #bindings, )* )),
                    _ => ::core::option::Option::None,
                }
            };
            // Without `..`, the `_` holes must match the variant's arity exactly.
            let exhaustiveness = if *rest {
                None
            } else {
                let holes = elems.iter().map(|_| quote!(_));
                Some(quote! { if let #path( #( #holes, )* ) = __tb_value {} })
            };
            (labels, field_exprs, projection, exhaustiveness)
        }
        // A unit variant has no fields and no exhaustiveness statement.
        VariantBody::Unit => {
            let projection = quote! {
                |__tb_subject: &#enum_path| match __tb_subject {
                    #path => ::core::option::Option::Some(()),
                    _ => ::core::option::Option::None,
                }
            };
            (Vec::new(), Vec::new(), projection, None)
        }
    };

    let idents = field_idents(labels.len());
    let FieldIdents {
        matcher_ty,
        field_ty,
        matcher_field,
        binding,
    } = &idents;
    let n = labels.len();
    let assertion = exhaustiveness_fn(&target, exhaustiveness);

    // Failure returned when the projection yields `None` (subject is another
    // variant); it renders the subject with `Debug`.
    let wrong_variant = quote! {
        ::test_better::MatchResult::fail(::test_better::Mismatch::new(
            ::test_better::Description::text(#variant_label),
            ::std::format!("{:?}", __tb_actual),
        ))
    };

    // No fields: passing only requires the subject to be the right variant.
    if n == 0 {
        return Ok(quote! {
            {
                #[allow(non_camel_case_types, dead_code, clippy::all)]
                struct __TbVariantMatcher<__TbP> {
                    __tb_project: __TbP,
                }

                #[allow(clippy::all)]
                impl<__TbS, __TbP> ::test_better::Matcher<__TbS>
                    for __TbVariantMatcher<__TbP>
                where
                    __TbP: ::core::ops::Fn(&__TbS) -> ::core::option::Option<()>,
                    __TbS: ::core::fmt::Debug,
                {
                    fn check(&self, __tb_actual: &__TbS) -> ::test_better::MatchResult {
                        match (self.__tb_project)(__tb_actual) {
                            ::core::option::Option::Some(()) => {
                                ::test_better::MatchResult::pass()
                            }
                            ::core::option::Option::None => #wrong_variant,
                        }
                    }

                    fn description(&self) -> ::test_better::Description {
                        ::test_better::Description::text(#variant_label)
                    }
                }

                #[allow(clippy::all)]
                fn __tb_make<__TbS, __TbP>(
                    __tb_project: __TbP,
                ) -> impl ::test_better::Matcher<__TbS>
                where
                    __TbP: ::core::ops::Fn(&__TbS) -> ::core::option::Option<()>,
                    __TbS: ::core::fmt::Debug,
                {
                    __TbVariantMatcher { __tb_project }
                }

                #assertion

                __tb_make(#projection)
            }
        });
    }

    // With fields, checks run in order and the first failing field returns.
    // The description is labelled with the bare variant name.
    let checks = field_check_blocks(matcher_field, binding, &labels);
    let desc_inner = description_fold(matcher_field, &labels);
    let desc = quote! { ::test_better::Description::labeled(#variant_name, #desc_inner) };

    Ok(quote! {
        {
            #[allow(non_camel_case_types, dead_code, clippy::all)]
            struct __TbVariantMatcher<__TbP, #( #matcher_ty, )*> {
                __tb_project: __TbP,
                #( #matcher_field: #matcher_ty, )*
            }

            #[allow(clippy::all)]
            impl<__TbS, #( #field_ty, )* __TbP, #( #matcher_ty, )*>
                ::test_better::Matcher<__TbS>
                for __TbVariantMatcher<__TbP, #( #matcher_ty, )*>
            where
                __TbP: ::core::ops::Fn(&__TbS)
                    -> ::core::option::Option<( #( &#field_ty, )* )>,
                #( #matcher_ty: ::test_better::Matcher<#field_ty>, )*
                __TbS: ::core::fmt::Debug,
            {
                fn check(&self, __tb_actual: &__TbS) -> ::test_better::MatchResult {
                    match (self.__tb_project)(__tb_actual) {
                        ::core::option::Option::Some(( #( #binding, )* )) => {
                            #( #checks )*
                            ::test_better::MatchResult::pass()
                        }
                        ::core::option::Option::None => #wrong_variant,
                    }
                }

                fn description(&self) -> ::test_better::Description {
                    #desc
                }
            }

            #[allow(clippy::all)]
            fn __tb_make<__TbS, #( #field_ty, )* __TbP, #( #matcher_ty, )*>(
                __tb_project: __TbP,
                #( #matcher_field: #matcher_ty, )*
            ) -> impl ::test_better::Matcher<__TbS>
            where
                __TbP: ::core::ops::Fn(&__TbS)
                    -> ::core::option::Option<( #( &#field_ty, )* )>,
                #( #matcher_ty: ::test_better::Matcher<#field_ty>, )*
                __TbS: ::core::fmt::Debug,
            {
                __TbVariantMatcher {
                    __tb_project,
                    #( #matcher_field, )*
                }
            }

            #assertion

            __tb_make(#projection, #( #field_exprs, )*)
        }
    })
}
632
633#[proc_macro]
660pub fn matches_struct(input: TokenStream) -> TokenStream {
661 match syn::parse::<StructPattern>(input) {
662 Ok(pattern) => gen_struct(&pattern.path, &pattern.fields, pattern.rest).into(),
663 Err(error) => error.to_compile_error().into(),
664 }
665}
666
667#[proc_macro]
686pub fn matches_tuple(input: TokenStream) -> TokenStream {
687 match syn::parse::<TuplePattern>(input) {
688 Ok(pattern) => gen_tuple(&pattern.path, &pattern.elems, pattern.rest).into(),
689 Err(error) => error.to_compile_error().into(),
690 }
691}
692
693#[proc_macro]
717pub fn matches_variant(input: TokenStream) -> TokenStream {
718 let result = syn::parse::<VariantPattern>(input).and_then(|pattern| gen_variant(&pattern));
719 match result {
720 Ok(tokens) => tokens.into(),
721 Err(error) => error.to_compile_error().into(),
722 }
723}
724
725struct TestCase {
728 span: proc_macro2::Span,
731 args: Vec<Expr>,
733 label: Option<LitStr>,
735}
736
737impl Parse for TestCase {
738 fn parse(input: ParseStream) -> syn::Result<Self> {
739 let span = input.span();
740 let mut args = Vec::new();
741 let mut label = None;
742 while !input.is_empty() {
743 if input.peek(Token![;]) {
744 input.parse::<Token![;]>()?;
745 label = Some(input.parse::<LitStr>()?);
746 if !input.is_empty() {
747 return Err(input.error("unexpected tokens after the test-case label"));
748 }
749 break;
750 }
751 args.push(input.parse()?);
752 if input.is_empty() || input.peek(Token![;]) {
753 continue;
754 }
755 input.parse::<Token![,]>()?;
756 }
757 Ok(Self { span, args, label })
758 }
759}
760
761fn parse_test_case_attr(attribute: &syn::Attribute) -> syn::Result<TestCase> {
764 match &attribute.meta {
765 syn::Meta::Path(_) => Ok(TestCase {
766 span: attribute.span(),
767 args: Vec::new(),
768 label: None,
769 }),
770 _ => attribute.parse_args::<TestCase>(),
771 }
772}
773
774fn is_test_case_attr(attribute: &syn::Attribute) -> bool {
777 attribute
778 .path()
779 .segments
780 .last()
781 .is_some_and(|segment| segment.ident == "test_case")
782}
783
/// Turns a free-form test-case label into a string usable as a function name.
///
/// - ASCII letters and digits are kept (letters lowercased); every other
///   character, including non-ASCII ones, becomes `_`.
/// - An empty label becomes `"case"`.
/// - A leading digit gets a `_` prefix.
/// - A result that is a Rust keyword (e.g. `"match"`, or `"self"` from
///   `"Self"`) or a lone `"_"` gets a `_` suffix, since neither can name a
///   function.
fn sanitize_ident(label: &str) -> String {
    // Strict, reserved and edition-dependent keywords that cannot be used as a
    // plain identifier. The label is lowercased first, so only lowercase forms
    // are listed.
    const KEYWORDS: &[&str] = &[
        "abstract", "as", "async", "await", "become", "box", "break", "const", "continue",
        "crate", "do", "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "gen",
        "if", "impl", "in", "let", "loop", "macro", "match", "mod", "move", "mut", "override",
        "priv", "pub", "ref", "return", "self", "static", "struct", "super", "trait", "true",
        "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while",
        "yield",
    ];

    let mut out: String = label
        .chars()
        .map(|ch| {
            if ch.is_ascii_alphanumeric() {
                ch.to_ascii_lowercase()
            } else {
                '_'
            }
        })
        .collect();
    if out.is_empty() {
        out.push_str("case");
    }
    if out.starts_with(|ch: char| ch.is_ascii_digit()) {
        out.insert(0, '_');
    }
    // `fn match()` or `fn _()` would not compile.
    if out == "_" || KEYWORDS.contains(&out.as_str()) {
        out.push('_');
    }
    out
}
806
807fn test_case_impl(first: TestCase, mut func: ItemFn) -> syn::Result<TokenStream2> {
810 let mut cases = vec![first];
815 let mut forwarded = Vec::new();
816 for attribute in std::mem::take(&mut func.attrs) {
817 if is_test_case_attr(&attribute) {
818 cases.push(parse_test_case_attr(&attribute)?);
819 } else {
820 forwarded.push(attribute);
821 }
822 }
823
824 let fn_name = func.sig.ident.clone();
825 let fn_name_str = fn_name.to_string();
826 let ret = func.sig.output.clone();
827 let expected_arity = func.sig.inputs.len();
828 let returns_value = match &func.sig.output {
831 syn::ReturnType::Default => false,
832 syn::ReturnType::Type(_, ty) => {
833 !matches!(&**ty, syn::Type::Tuple(tuple) if tuple.elems.is_empty())
834 }
835 };
836
837 let mut used_names: HashSet<String> = HashSet::new();
838 let mut tests = Vec::with_capacity(cases.len());
839 for (index, case) in cases.iter().enumerate() {
840 if case.args.len() != expected_arity {
841 return Err(syn::Error::new(
842 case.span,
843 format!(
844 "this `#[test_case]` passes {} argument(s) but `{fn_name_str}` takes {}",
845 case.args.len(),
846 expected_arity,
847 ),
848 ));
849 }
850
851 let base = match &case.label {
852 Some(label) => sanitize_ident(&label.value()),
853 None => format!("case_{index}"),
854 };
855 let name = if used_names.contains(&base) {
858 format!("{base}_{index}")
859 } else {
860 base
861 };
862 used_names.insert(name.clone());
863 let test_ident = format_ident!("{name}");
864
865 let args = &case.args;
866 let args_rendered = args
867 .iter()
868 .map(|arg| quote!(#arg).to_string())
869 .collect::<Vec<_>>()
870 .join(", ");
871 let label_part = match &case.label {
872 Some(label) => format!("{:?}", label.value()),
873 None => format!("#{index}"),
874 };
875 let context_msg = format!("test case {label_part}: {fn_name_str}({args_rendered})");
876
877 let body = if returns_value {
878 quote! {
879 ::test_better::ContextExt::context(#fn_name(#(#args),*), #context_msg)
880 }
881 } else {
882 quote! { #fn_name(#(#args),*); }
883 };
884
885 tests.push(quote! {
889 #(#forwarded)*
890 #[test]
891 pub(super) fn #test_ident() #ret {
892 #body
893 }
894 });
895 }
896
897 Ok(quote! {
898 mod #fn_name {
899 #[allow(unused_imports)]
900 use super::*;
901
902 #func
903
904 #(#tests)*
905 }
906 })
907}
908
909#[proc_macro_attribute]
936pub fn test_case(attr: TokenStream, item: TokenStream) -> TokenStream {
937 let first = match syn::parse::<TestCase>(attr) {
938 Ok(case) => case,
939 Err(error) => return error.to_compile_error().into(),
940 };
941 let func = match syn::parse::<ItemFn>(item) {
942 Ok(func) => func,
943 Err(error) => return error.to_compile_error().into(),
944 };
945 match test_case_impl(first, func) {
946 Ok(tokens) => tokens.into(),
947 Err(error) => error.to_compile_error().into(),
948 }
949}
950
951enum FixtureScope {
954 Test,
956 Module,
958}
959
960struct FixtureArgs {
962 scope: FixtureScope,
963}
964
965impl Parse for FixtureArgs {
966 fn parse(input: ParseStream) -> syn::Result<Self> {
967 if input.is_empty() {
969 return Ok(Self {
970 scope: FixtureScope::Test,
971 });
972 }
973 let key: Ident = input.parse()?;
974 if key != "scope" {
975 return Err(syn::Error::new_spanned(
976 key,
977 "the only `#[fixture]` argument is `scope`",
978 ));
979 }
980 input.parse::<Token![=]>()?;
981 let value: LitStr = input.parse()?;
982 let scope = match value.value().as_str() {
983 "test" => FixtureScope::Test,
984 "module" => FixtureScope::Module,
985 other => {
986 return Err(syn::Error::new_spanned(
987 value,
988 format!("unknown fixture scope {other:?}, expected \"test\" or \"module\""),
989 ));
990 }
991 };
992 if !input.is_empty() {
993 return Err(input.error("unexpected tokens after the fixture scope"));
994 }
995 Ok(Self { scope })
996 }
997}
998
999fn fixture_impl(args: FixtureArgs, mut func: ItemFn) -> syn::Result<TokenStream2> {
1006 if let Some(param) = func.sig.inputs.first() {
1007 return Err(syn::Error::new_spanned(
1008 param,
1009 "a `#[fixture]` function takes no parameters",
1010 ));
1011 }
1012 let return_ty = match &func.sig.output {
1013 syn::ReturnType::Type(_, ty) => (**ty).clone(),
1014 syn::ReturnType::Default => {
1015 return Err(syn::Error::new_spanned(
1016 &func.sig,
1017 "a `#[fixture]` function must return a `TestResult<T>`",
1018 ));
1019 }
1020 };
1021
1022 let forwarded: Vec<syn::Attribute> = std::mem::take(&mut func.attrs);
1026
1027 let fn_name = func.sig.ident.clone();
1028 let fn_name_str = fn_name.to_string();
1029 let vis = func.vis.clone();
1030 let ret = func.sig.output.clone();
1031 let body = &func.block;
1032 let context_msg = format!("setting up fixture `{fn_name_str}`");
1033
1034 let impl_fn = quote! {
1035 fn __tb_fixture_impl() #ret #body
1036 };
1037
1038 let outer_body = match args.scope {
1039 FixtureScope::Test => quote! {
1043 #impl_fn
1044 ::core::result::Result::map_err(__tb_fixture_impl(), |__tb_error| {
1045 ::test_better::TestError::with_context_frame(
1046 ::test_better::TestError::with_kind(
1047 __tb_error,
1048 ::test_better::ErrorKind::Setup,
1049 ),
1050 ::test_better::ContextFrame::new(#context_msg),
1051 )
1052 })
1053 },
1054 FixtureScope::Module => quote! {
1058 #impl_fn
1059 static __TB_FIXTURE_CELL: ::std::sync::LazyLock<#return_ty> =
1060 ::std::sync::LazyLock::new(__tb_fixture_impl);
1061 match &*__TB_FIXTURE_CELL {
1062 ::core::result::Result::Ok(__tb_value) => {
1063 ::core::result::Result::Ok(::core::clone::Clone::clone(__tb_value))
1064 }
1065 ::core::result::Result::Err(__tb_error) => {
1066 ::core::result::Result::Err(
1067 ::test_better::TestError::with_message(
1068 ::test_better::TestError::new(
1069 ::test_better::ErrorKind::Setup,
1070 ),
1071 ::std::format!(
1072 "module-scoped fixture `{}` failed during setup: {}",
1073 #fn_name_str,
1074 __tb_error,
1075 ),
1076 ),
1077 )
1078 }
1079 }
1080 },
1081 };
1082
1083 Ok(quote! {
1084 #(#forwarded)*
1085 #vis fn #fn_name() #ret {
1086 #outer_body
1087 }
1088 })
1089}
1090
1091#[proc_macro_attribute]
1117pub fn fixture(attr: TokenStream, item: TokenStream) -> TokenStream {
1118 let args = match syn::parse::<FixtureArgs>(attr) {
1119 Ok(args) => args,
1120 Err(error) => return error.to_compile_error().into(),
1121 };
1122 let func = match syn::parse::<ItemFn>(item) {
1123 Ok(func) => func,
1124 Err(error) => return error.to_compile_error().into(),
1125 };
1126 match fixture_impl(args, func) {
1127 Ok(tokens) => tokens.into(),
1128 Err(error) => error.to_compile_error().into(),
1129 }
1130}
1131
1132fn test_with_fixtures_impl(mut func: ItemFn) -> syn::Result<TokenStream2> {
1135 let mut params = Vec::with_capacity(func.sig.inputs.len());
1136 for input in &func.sig.inputs {
1137 match input {
1138 FnArg::Receiver(receiver) => {
1139 return Err(syn::Error::new_spanned(
1140 receiver,
1141 "a `#[test_with_fixtures]` function cannot take `self`",
1142 ));
1143 }
1144 FnArg::Typed(pat_type) => match &*pat_type.pat {
1145 Pat::Ident(pat_ident) => params.push(pat_ident.ident.clone()),
1146 other => {
1147 return Err(syn::Error::new_spanned(
1148 other,
1149 "each `#[test_with_fixtures]` parameter must be a plain \
1150 `name: Type`, where `name` is the fixture function",
1151 ));
1152 }
1153 },
1154 }
1155 }
1156
1157 let forwarded: Vec<syn::Attribute> = std::mem::take(&mut func.attrs);
1160 let fn_name = func.sig.ident.clone();
1161 let vis = func.vis.clone();
1162 let ret = func.sig.output.clone();
1163
1164 let mut inner = func;
1167 inner.sig.ident = format_ident!("__tb_inner");
1168 inner.vis = syn::Visibility::Inherited;
1169
1170 Ok(quote! {
1171 #(#forwarded)*
1172 #[test]
1173 #vis fn #fn_name() #ret {
1174 #inner
1175 #( let #params = #params()?; )*
1176 __tb_inner(#(#params),*)
1177 }
1178 })
1179}
1180
1181#[proc_macro_attribute]
1205pub fn test_with_fixtures(_attr: TokenStream, item: TokenStream) -> TokenStream {
1206 let func = match syn::parse::<ItemFn>(item) {
1207 Ok(func) => func,
1208 Err(error) => return error.to_compile_error().into(),
1209 };
1210 match test_with_fixtures_impl(func) {
1211 Ok(tokens) => tokens.into(),
1212 Err(error) => error.to_compile_error().into(),
1213 }
1214}