1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#[macro_use]
extern crate syn;
extern crate proc_macro;

use linked_hash_map::LinkedHashMap;
use proc_macro2::TokenStream;
use quote::quote;
use syn::export::fmt::Display;
use syn::parse_macro_input;
use syn::spanned::Spanned;

mod attribute;

/// Expands a parameterized test definition into one `#[test]` function per test case.
///
/// The attribute takes one `ident = { expr, expr, ... }` list per function parameter:
///
/// ```text
/// #[parameterized(chars = { 'a', 'b' }, ints = { 1, 2 })]
/// fn my_test(chars: char, ints: i8) { ... }
/// ```
///
/// The expansion is a `#[cfg(test)]` module named after the annotated function, starting with
/// `use super::*;`. It contains functions `case_0`, `case_1`, ... — one per position in the
/// argument lists. Each case begins with `let <param>: <Type> = <expr>;` for every function
/// parameter (using the expression at that position) followed by the original function body.
/// The original function's attributes and visibility are copied onto every generated case.
///
/// In the example above, `case_0` binds `chars = 'a'` and `ints = 1`, and `case_1` binds
/// `chars = 'b'` and `ints = 2`.
///
/// # Panics
///
/// Expansion panics when:
/// * the same identifier is given more than once in the attribute;
/// * the argument lists do not all have the same length;
/// * a function parameter is not typed (e.g. a `self` receiver) or its pattern is not a plain
///   identifier (e.g. a tuple pattern);
/// * a function parameter has no argument list in the attribute.
#[proc_macro_attribute]
pub fn parameterized(
    args: proc_macro::TokenStream,
    input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let argument_lists = parse_macro_input!(args as attribute::AttributeArgList);
    let func = parse_macro_input!(input as syn::ItemFn);

    let name = &func.sig.ident;
    let vis = &func.vis;
    let func_args = &func.sig.inputs;
    let body_block = &func.block;
    let attributes = &func.attrs;

    // Reuse the function's identifier as-is for the module name. Round-tripping it through a
    // string (`Ident::new(&name.to_string(), ..)`) panics for raw identifiers such as
    // `r#match`, whose string form keeps the `r#` prefix.
    let mod_ident = name;

    let identifiers_len = argument_lists.args.len();

    // Parameter identifier -> the expressions supplied for it, in attribute order.
    let values = argument_lists
        .args
        .iter()
        .map(|v| {
            (
                v.id.clone(),
                v.param_args.iter().cloned().collect::<Vec<syn::Expr>>(),
            )
        })
        .collect::<LinkedHashMap<syn::Ident, Vec<syn::Expr>>>();

    // Collecting into a map collapses duplicate keys, so a size mismatch means the attribute
    // named some parameter more than once.
    if values.len() != identifiers_len {
        panic!("[parameterized-macro] error: Duplicate identifier(s) found. Please use unique parameter names.")
    }

    let amount_of_test_cases = check_all_input_lengths(&values);

    let test_case_fns = (0..amount_of_test_cases).map(|i| {
        // One `let` per function parameter, in signature order, binding the i-th value.
        let binds: Vec<TokenStream> = func_args
            .iter()
            .map(|fn_arg| let_binding_for_case(fn_arg, &values, i))
            .collect();

        let ident = syn::Ident::new(&format!("case_{}", i), func.span()); // fixme: span

        quote! {
            #[test]
            #(#attributes)*
            #vis fn #ident() {
                #(#binds)*

                #body_block
            }
        }
    });

    // The cases are placed in a new module, so `use super::*` is needed to keep the items
    // surrounding the original function in scope for the copied body and bound expressions.
    let token_stream = quote! {
        #[cfg(test)]
        #vis mod #mod_ident {
            use super::*;

            #(#test_case_fns)*
        }
    };

    token_stream.into()
}

/// Builds the `let <ident>: <Type> = <expr>;` statement for one function parameter in test case
/// `case`, where `<expr>` is the `case`-th expression listed for `<ident>` in the attribute.
///
/// The explicit type annotation makes the compiler type check each supplied expression against
/// the parameter's declared type.
///
/// # Panics
///
/// Panics if `fn_arg` is not typed (a `self` receiver), if its pattern is not a plain identifier,
/// or if the attribute supplied no values for it. `case` must be below the length of the
/// parameter's argument list.
fn let_binding_for_case(
    fn_arg: &syn::FnArg,
    values: &LinkedHashMap<syn::Ident, Vec<syn::Expr>>,
    case: usize,
) -> TokenStream {
    let (pat, ty) = match fn_arg {
        syn::FnArg::Typed(syn::PatType { pat, ty, .. }) => (pat, ty),
        _ => panic!("[parameterized-macro] error: Given function argument should be typed"),
    };

    let ident = match pat.as_ref() {
        syn::Pat::Ident(syn::PatIdent { ident, .. }) => ident,
        _ => panic!("[parameterized-macro] error: Function parameter identifier was not found"),
    };

    let exprs = match values.get(ident) {
        Some(exprs) => exprs,
        None => panic!(
            "[parameterized-macro] error: No matching values found for '{}'",
            ident
        ),
    };
    let expr = &exprs[case];

    quote! {
        let #ident: #ty = #expr;
    }
}

/// Returns the common length of all argument lists, i.e. the number of test cases to generate.
///
/// Test case `i` takes the `i`-th expression from every argument list, so all lists must be
/// equally long. For example, in `#[parameterized(v = { "a", "b", "c" }, w = { 1, 2 })]`
/// the list for `v` has three values and the list for `w` has two. The first two cases can be
/// built (`"a"`/`1` and `"b"`/`2`), but the third would have `"c"` for `v` and nothing for `w`.
/// No complete set of test cases exists in that situation.
///
/// Returns `0` when `map` is empty.
///
/// # Panics
///
/// Panics via `panic_on_inequal_length`, naming the first identifier (in insertion order) whose
/// list length differs from that of the first list.
fn check_all_input_lengths(map: &LinkedHashMap<syn::Ident, Vec<syn::Expr>>) -> usize {
    let mut lists = map.iter();

    // The first list sets the length every other list must match.
    let expected = match lists.next() {
        Some((_, first)) => first.len(),
        None => return 0,
    };

    if let Some((offender, _)) = lists.find(|(_, exprs)| exprs.len() != expected) {
        panic_on_inequal_length(map.iter(), offender);
    }

    expected
}

/// Panics with an error reporting that the argument list of `ident` differs in length from the
/// others.
///
/// `map` supplies the identifiers of all inputs; only the keys are used. They are listed in the
/// message in iteration order, separated by `", "`. The function never returns normally, hence
/// the `!` return type.
#[cold]
fn panic_on_inequal_length<K: Display, V, D: Display>(
    map: impl Iterator<Item = (K, V)>,
    ident: D,
) -> ! {
    let ids = map
        .map(|(id, _)| id.to_string())
        .collect::<Vec<String>>()
        .join(", ");

    panic!(
        "[parameterized-macro] error: Inconsistent argument list length for '{}'; all inputs ({}) should have equal length",
        ident,
        ids
    )
}