mock_it_codegen/
lib.rs

1extern crate proc_macro;
2
3mod generics;
4mod mock_fn;
5mod trait_method;
6
7use generics::MockItTraitGenerics;
8use mock_fn::{mock_fns, MockFn};
9use proc_macro2::TokenStream;
10use quote::quote;
11use syn::{parse_macro_input, Generics, Ident, Item};
12use trait_method::get_trait_method_types;
13
14/// Generate a mock struct from a trait. The mock struct will be named after the
15/// trait, with "Mock" appended.
16#[proc_macro_attribute]
17pub fn mock_it(
18    _attr: proc_macro::TokenStream,
19    item: proc_macro::TokenStream,
20) -> proc_macro::TokenStream {
21    // Parse the tokens
22    let input: Item = parse_macro_input!(item as Item);
23
24    // Make sure it's a trait
25    let item_trait = match input {
26        Item::Trait(item_trait) => item_trait,
27        _ => panic!("Only traits can be mocked with the mock_it macro"),
28    };
29
30    let trait_method_types = get_trait_method_types(&item_trait);
31    let mock_fns = mock_fns(trait_method_types.clone());
32    let helper_functions: Vec<TokenStream> = mock_fns
33        .iter()
34        .map(|mock_fn| mock_fn.helper_functions())
35        .collect();
36
37    // Create the mock identifier
38    let trait_ident = &item_trait.ident;
39    let mock_ident = Ident::new(&format!("{}Mock", trait_ident), trait_ident.span());
40
41    // Generate the mock
42    let fields = create_fields(&mock_fns);
43    let field_init = create_field_init(&mock_ident, &mock_fns);
44    let trait_impls = create_trait_impls(&mock_fns);
45    let clone_impl = create_clone_impl(&mock_fns);
46    let async_attribute = async_attribute(&mock_fns);
47
48    // Configure trait generics
49    let generics = configure_trait_generics(&mock_fns, &item_trait.generics);
50    let (generics_impl, generics_ty, generics_where) = generics.split_for_impl();
51
52    let output = quote! {
53        #item_trait
54
55        #[derive(Debug)]
56        pub struct #mock_ident #generics_ty #generics_where {
57            #(#fields),*
58        }
59
60        impl #generics_impl #mock_ident #generics_ty #generics_where {
61            pub fn new() -> Self {
62                #mock_ident {
63                    #(#field_init),*
64                }
65            }
66
67            #(#helper_functions)*
68        }
69
70        impl #generics_impl std::clone::Clone for #mock_ident #generics_ty #generics_where {
71            fn clone(&self) -> Self {
72                #mock_ident {
73                    #(#clone_impl),*
74                }
75            }
76        }
77
78        #async_attribute
79        impl #generics_impl #trait_ident #generics_ty for #mock_ident #generics_ty #generics_where {
80            #(#trait_impls)*
81        }
82    };
83
84    output.into()
85}
86
87fn configure_trait_generics(mock_fns: &Vec<MockFn>, generics: &Generics) -> Generics {
88    let mut trait_generics = MockItTraitGenerics::new(generics);
89    trait_generics.configure_predicates(&mock_fns);
90    trait_generics.into()
91}
92
93fn async_attribute(mock_fns: &Vec<MockFn>) -> TokenStream {
94    for mock_fn in mock_fns.iter() {
95        if mock_fn.is_async() {
96            return quote! { #[async_trait::async_trait] }.into();
97        }
98    }
99
100    quote! {}.into()
101}
102
103/// Create the struct fields
104fn create_fields(mock_fns: &Vec<MockFn>) -> Vec<TokenStream> {
105    mock_fns
106        .iter()
107        .map(|mock_fn| {
108            let name = mock_fn.name();
109            let return_input_types = mock_fn.return_input_types();
110            let return_output_type = mock_fn.return_output_type();
111
112            quote! {
113                #name: mock_it::Mock<#return_input_types, #return_output_type>
114            }
115        })
116        .collect()
117}
118
119/// Create the field initializers for the `new` method
120fn create_field_init(mock_ident: &Ident, mock_fns: &Vec<MockFn>) -> Vec<TokenStream> {
121    mock_fns
122        .iter()
123        .map(|mock_fn| {
124            let name = mock_fn.name();
125
126            quote! {
127                #name: mock_it::Mock::new(format!("{}.{}", stringify!(#mock_ident), stringify!(#name)))
128            }
129        })
130        .collect()
131}
132
133/// Create the clone implementation
134fn create_clone_impl(mock_fns: &Vec<MockFn>) -> impl Iterator<Item = TokenStream> + '_ {
135    mock_fns.iter().map(|mock_fn| {
136        let ident = &mock_fn.signature().ident;
137        quote! {
138            #ident: self.#ident.clone()
139        }
140    })
141}
142
143/// Create the trait method implementations
144fn create_trait_impls(mock_fns: &Vec<MockFn>) -> impl Iterator<Item = TokenStream> + '_ {
145    mock_fns.iter().map(|mock_fn| {
146        let called_fn_name = mock_fn.called_fn_name();
147        let arg_names = mock_fn.args().into_iter().map(|arg| {
148            let name = &arg.name;
149            quote! {
150                mock_it::Matcher::Val(#name)
151            }
152        });
153        let signature = mock_fn.signature();
154
155        quote! {
156            #signature {
157                self.#called_fn_name(#(#arg_names),*)
158            }
159        }
160    })
161}