Skip to main content

test_that_macro/
lib.rs

1// Copyright 2022 Google LLC
2// Copyright 2026 Bradford Hovinen <bradford@hovinen.me>
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//      http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use quote::quote;
17use syn::{Attribute, ItemFn, ReturnType, parse_macro_input};
18
19/// Marks a test which may have non fatal assertions.
20///
21/// Annotate tests the same way ordinary Rust tests are annotated:
22///
23/// ```ignore
24/// #[test_that::test]
25/// fn should_work() {
26///    ...
27/// }
28/// ```
29///
30/// The test function is not required to have a return type. If it does have a
31/// return type, that type must be [`test_that::Result`]. One may do this if
32/// one wishes to use both fatal and non-fatal assertions in the same test. For
33/// example:
34///
35/// ```
36/// # use test_that::prelude::*;
37/// #[test_that::test]
38/// fn should_work() -> TestResult<()> {
39///     let value = 2;
40///     expect_that!(value, gt(0));
41///     verify_that!(value, eq(2))
42/// }
43/// ```
44///
45/// This macro can be used with `#[should_panic]` to indicate that the test is
46/// expected to panic. For example:
47///
48/// ```
49/// # use test_that::prelude::*;
50/// #[test_that::test]
51/// #[should_panic]
52/// fn passes_due_to_should_panic() {
53///     let value = 2;
54///     expect_that!(value, gt(0));
55///     panic!("This panics");
56/// }
57/// ```
58///
59/// Using `#[should_panic]` modifies the behaviour of `#[test_that::test]` so
60/// that the test panics (and passes) if any non-fatal assertion occurs.
61/// For example, the following test passes:
62///
63/// ```
64/// # use test_that::prelude::*;
65/// #[test_that::test]
66/// #[should_panic]
67/// fn passes_due_to_should_panic_and_failing_assertion() {
68///     let value = 2;
69///     expect_that!(value, eq(0));
70/// }
71/// ```
72///
73/// This integrates with other common test attribute macros such as [`tokio::test`]
74/// and [`rstest`]. Just apply both attribute macros to your test.
75///
76/// ```ignore
77/// #[test_that::test]
78/// #[rstest]
79/// #[case(1)]
80/// #[case(2)]
81/// #[case(3)]
82/// fn rstest_works_with_test_that(#[case] value: u32) -> Result<()> {
83///     verify_that!(value, gt(0))
84/// }
85///
86/// #[test_that::test]
87/// #[tokio::test]
88/// async fn tokio_works_with_test_that() -> Result<()> {
89///     verify_that!(get_some_value_async().await, gt(0))
90/// }
91/// ```
92///
93/// > **Note:**
94/// > In the case of rstest, make sure to put `#[test_that::test]` *before*
95/// > `#[rstest]`. Otherwise the annotated test will run twice, since both macros will
96/// > attempt to register a test with the Rust test harness.
97///
98/// [`test_that::Result`]: type.Result.html
99/// [`tokio::test`]: https://docs.rs/tokio/latest/tokio/attr.test.html
100/// [`rstest`]: https://docs.rs/rstest/latest/rstest/attr.rstest.html
101#[proc_macro_attribute]
102pub fn test(
103    _args: proc_macro::TokenStream,
104    input: proc_macro::TokenStream,
105) -> proc_macro::TokenStream {
106    let mut parsed_fn = parse_macro_input!(input as ItemFn);
107    let attrs = parsed_fn.attrs.drain(..).collect::<Vec<_>>();
108    let (mut sig, block) = (parsed_fn.sig, parsed_fn.block);
109    let (outer_return_type, trailer) =
110        if attrs.iter().any(|attr| attr.path().is_ident("should_panic")) {
111            (quote! { () }, quote! { .unwrap(); })
112        } else {
113            (
114                quote! { std::result::Result<(), test_that::internal::test_outcome::TestFailure> },
115                quote! {},
116            )
117        };
118    let output_type = match sig.output.clone() {
119        ReturnType::Type(_, output_type) => Some(output_type),
120        ReturnType::Default => None,
121    };
122    sig.output = ReturnType::Default;
123    let (maybe_closure, invocation) = if sig.asyncness.is_some() {
124        (
125            // In the async case, the ? operator returns from the *block* rather than the
126            // surrounding function. So we just put the test content in an async block. Async
127            // closures are still unstable (see https://github.com/rust-lang/rust/issues/62290),
128            // so we can't use the same solution as the sync case below.
129            quote! {},
130            quote! {
131                async { #block }.await
132            },
133        )
134    } else {
135        (
136            // In the sync case, the ? operator returns from the surrounding function. So we must
137            // create a separate closure from which the ? operator can return in order to capture
138            // the output.
139            quote! {
140                let test = move || #block;
141            },
142            quote! {
143                test()
144            },
145        )
146    };
147    let function = if let Some(output_type) = output_type {
148        quote! {
149            #(#attrs)*
150            #sig -> #outer_return_type {
151                #maybe_closure
152                test_that::internal::test_outcome::TestOutcome::init_current_test_outcome();
153                let result: #output_type = #invocation;
154                test_that::internal::test_outcome::TestOutcome::close_current_test_outcome(result)
155                #trailer
156            }
157        }
158    } else {
159        quote! {
160            #(#attrs)*
161            #sig -> #outer_return_type {
162                #maybe_closure
163                test_that::internal::test_outcome::TestOutcome::init_current_test_outcome();
164                #invocation;
165                test_that::internal::test_outcome::TestOutcome::close_current_test_outcome(test_that::TestResult::Ok(()))
166                #trailer
167            }
168        }
169    };
170    let output = if attrs.iter().any(is_test_attribute) {
171        function
172    } else {
173        quote! {
174            #[::core::prelude::v1::test]
175            #function
176        }
177    };
178    output.into()
179}
180
181fn is_test_attribute(attr: &Attribute) -> bool {
182    let first_segment = match attr.path().segments.first() {
183        Some(first_segment) => first_segment,
184        None => return false,
185    };
186    let last_segment = match attr.path().segments.last() {
187        Some(last_segment) => last_segment,
188        None => return false,
189    };
190    last_segment.ident == "test"
191        || (first_segment.ident == "rstest"
192            && last_segment.ident == "rstest"
193            && attr.path().segments.len() <= 2)
194}