datafile_test/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use std::fs;
6use syn::{parse_macro_input, ItemFn, LitStr};
7
8/// Define data-file-driven tests using JSON/YAML files.
9///
10/// This attribute macro reads a JSON/YAML file at compile time and generates a test function for each
11/// test case in the file. The test function must take a single argument, which is a structured type
12/// that implements `serde::Deserialize`.
13/// The file is read from the file system relative to the current working directory of the
14/// compiler.
15///
16/// Note that `serde` and `serde_json` crate is required in caller's `Cargo.toml`.
17///
18/// # Example
19/// ```rust
20/// use datafile_test::datafile_test;
21///
22/// #[derive(Debug, serde::Deserialize)]
23/// struct TestCaseInput {
24///     a: i32,
25///     b: i32,
26/// }
27///
28/// #[derive(Debug, serde::Deserialize)]
29/// struct TestCase {
30///     input: TestCaseInput,
31///     output: String,
32/// }
33///
34/// #[datafile_test("tests/testcase.yml")]
35/// fn test(testcase: TestCase) {
36///     assert_eq!(testcase.input.a + testcase.input.b, testcase.output);
37/// }
38/// ```
39///
40/// The yaml file should look like this:
41///
42/// ```yaml
43/// - input:
44///     a: 1
45///     b: 2
46///   expect: 3
47/// - input:
48///     a: 2
49///     b: 3
50///   expect: 5
51/// ```
52///
53#[proc_macro_attribute]
54pub fn datafile_test(attr: TokenStream, item: TokenStream) -> TokenStream {
55    // Parse attribute
56    let attr = parse_macro_input!(attr as LitStr);
57    let file_path = attr.value();
58
59    // Parse the function item
60    let input_fn = parse_macro_input!(item as ItemFn);
61    let fn_name = &input_fn.sig.ident;
62    let fn_body = &input_fn.block;
63    let fn_args = &input_fn.sig.inputs;
64
65    // Ensure the function has exactly one argument
66    if fn_args.len() != 1 {
67        return syn::Error::new_spanned(
68            &input_fn.sig,
69            "datafile_test function must have exactly one argument",
70        )
71        .to_compile_error()
72        .into();
73    }
74
75    let test_case_type = match fn_args.first().unwrap() {
76        syn::FnArg::Typed(pat_type) => &pat_type.ty,
77        _ => {
78            return syn::Error::new_spanned(
79                &input_fn.sig,
80                "datafile_test function must take a structured argument",
81            )
82            .to_compile_error()
83            .into();
84        }
85    };
86
87    // Load JSON/YAML file at compile time
88    let data_text = match fs::read_to_string(&file_path) {
89        Ok(content) => content,
90        Err(e) => {
91            return syn::Error::new_spanned(
92                &attr,
93                format!("Failed to read data file '{:?}': {}", &file_path, e),
94            )
95            .to_compile_error()
96            .into();
97        }
98    };
99
100    let ext = std::path::Path::new(&file_path)
101        .extension()
102        .and_then(std::ffi::OsStr::to_str)
103        .unwrap_or_default()
104        .to_lowercase();
105    // Parse JSON/YAML into Vec<serde_json::Value>
106    let test_cases: Vec<serde_json::Value> = match ext.as_str() {
107        "json" => match serde_json::from_str(&data_text) {
108            Ok(cases) => cases,
109            Err(e) => {
110                return syn::Error::new_spanned(
111                    &attr,
112                    format!("Failed to parse JSON file '{:?}': {}", &file_path, e),
113                )
114                .to_compile_error()
115                .into();
116            }
117        },
118        "yaml" | "yml" => match serde_yaml::from_str(&data_text) {
119            Ok(cases) => cases,
120            Err(e) => {
121                return syn::Error::new_spanned(
122                    &attr,
123                    format!("Failed to parse YAML file '{:?}': {}", &file_path, e),
124                )
125                .to_compile_error()
126                .into();
127            }
128        },
129        _ => {
130            return syn::Error::new_spanned(
131                &attr,
132                format!("Unsupported file extension: {:?}", ext),
133            )
134            .to_compile_error()
135            .into();
136        }
137    };
138
139    // Generate test functions for each case
140    let test_fns: Vec<_> = test_cases
141        .iter()
142        .enumerate()
143        .map(|(i, test_case)| {
144            let test_fn_name = format_ident!("{}_case_{}", fn_name, i);
145
146            // Convert serde_yaml::Value to JSON string
147            let json_str = match serde_json::to_string(test_case) {
148                Ok(s) => s,
149                Err(e) => {
150                    return syn::Error::new_spanned(
151                        &attr,
152                        format!("Failed to convert test case to JSON: {}", e),
153                    )
154                    .to_compile_error();
155                }
156            };
157
158            // Convert JSON string to Rust expression
159            let test_case_expr: syn::Expr = match syn::parse_str(&format!(
160                "serde_json::from_str::<{}>({:?}).unwrap()",
161                quote!(#test_case_type),
162                json_str
163            )) {
164                Ok(expr) => expr,
165                Err(e) => {
166                    return syn::Error::new_spanned(
167                        &attr,
168                        format!("Failed to parse test case JSON as Rust expression: {}", e),
169                    )
170                    .to_compile_error();
171                }
172            };
173
174            quote! {
175                #[test]
176                fn #test_fn_name() {
177                    let testcase: #test_case_type = #test_case_expr;
178                    #fn_body
179                }
180            }
181        })
182        .collect();
183
184    let output = quote! {
185        #(#test_fns)*
186    };
187
188    output.into()
189}