damock_macros/
lib.rs

//! Derive proc-macro definitions for the `Mock` trait.
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, spanned::Spanned, Data, DeriveInput, Fields, FieldsNamed};
6
7/// Derives `Mock` for both structs and enums if all their fields implement either `Mock` or
8/// `Default`.
9#[proc_macro_derive(Mock, attributes(mock, mock_default))]
10pub fn derive_mock(token_stream: TokenStream) -> TokenStream {
11    derive_mock_impl(token_stream)
12}
13
14fn derive_mock_impl(token_stream: TokenStream) -> TokenStream {
15    let type_definition = parse_macro_input!(token_stream as DeriveInput);
16
17    let identifier = type_definition.ident;
18
19    let cfg_scope = match cfg_scope(type_definition.attrs) {
20        Ok(scope) => scope,
21        Err(err) => return err.into_compile_error().into(),
22    };
23
24    let self_definition_result = match type_definition.data {
25        Data::Struct(data_struct) => derive_struct(data_struct),
26        Data::Enum(data_enum) => derive_enum(data_enum),
27        // TODO:
28        Data::Union(data_union) => Err(syn::Error::new(data_union.union_token.span, "union types not supported")),
29    };
30
31    match self_definition_result {
32        Ok(self_definition) => {
33            quote! {
34                #cfg_scope
35                impl ::damock::Mock for #identifier {
36                    fn mock() -> Self {
37                        #self_definition
38                    }
39                }
40            }
41        }
42        Err(err) => err.to_compile_error(),
43    }
44    .into()
45}
46
47fn cfg_scope(container_attributes: Vec<syn::Attribute>) -> syn::Result<proc_macro2::TokenStream> {
48    let mut cfg_override: Option<syn::MetaNameValue> = None;
49
50    let mock_attributes = container_attributes.into_iter().filter(|attribute| {
51        matches!(&attribute.meta,
52            syn::Meta::List(meta_list) if meta_list.path.is_ident("mock"))
53    });
54
55    for mock_attribute in mock_attributes {
56        let cfg_args: syn::MetaNameValue = mock_attribute.parse_args()?;
57
58        match &cfg_override {
59            Some(_pre_existing) => {
60                Err(syn::Error::new(cfg_args.span(), "multiple #[cfg], values provided"))?;
61            }
62            None => cfg_override = Some(cfg_args),
63        }
64    }
65
66    Ok(match cfg_override {
67        Some(overrides) => quote! { #[cfg(#overrides)] },
68        None => quote! { #[cfg(test)] },
69    })
70}
71
72fn derive_struct(data_struct: syn::DataStruct) -> syn::Result<proc_macro2::TokenStream> {
73    Ok(match data_struct.fields {
74        Fields::Named(named_fields) => {
75            let fields = fields::named(named_fields);
76
77            quote! {
78                Self {
79                    #(#fields),*
80                }
81            }
82        }
83        Fields::Unnamed(tuple_fields) => {
84            let fields = fields::tuple(tuple_fields);
85
86            quote! { Self(#(#fields),*) }
87        }
88        Fields::Unit => quote! { Self },
89    })
90}
91
92fn derive_enum(data_enum: syn::DataEnum) -> syn::Result<proc_macro2::TokenStream> {
93    let mut variant_to_mock_iter = data_enum.variants.into_iter().filter_map(|variant| {
94        variant
95            .attrs
96            .clone()
97            .iter()
98            .find(|attribute| match &attribute.meta {
99                syn::Meta::Path(path) => path.is_ident("mock"),
100                _ => false,
101            })
102            .map(|_| variant)
103    });
104
105    let Some(variant_to_mock) = variant_to_mock_iter.next() else {
106        return Err(syn::Error::new(
107            data_enum.enum_token.span,
108            "no #[mock] attribute found in any of the listed variants",
109        ));
110    };
111
112    if let Some(_another_variant_to_mock) = variant_to_mock_iter.next() {
113        return Err(syn::Error::new(
114            data_enum.enum_token.span,
115            "expected only one #[mock] enum variant attribute, unable to infer which one to use.",
116        ));
117    }
118
119    let variant_name = variant_to_mock.ident;
120
121    Ok(match variant_to_mock.fields {
122        Fields::Named(named_fields) => {
123            let fields = fields::named(named_fields);
124
125            quote! {
126                Self::#variant_name {
127                    #(#fields),*
128                }
129            }
130        }
131        Fields::Unnamed(tuple_fields) => {
132            let fields = fields::tuple(tuple_fields);
133            quote! {
134                Self::#variant_name(#(#fields),*)
135            }
136        }
137        Fields::Unit => {
138            quote! {
139                Self::#variant_name
140            }
141        }
142    })
143}
144
145mod fields {
146    use super::*;
147
148    pub fn named(named_fields: FieldsNamed) -> impl Iterator<Item = proc_macro2::TokenStream> {
149        named_fields
150            .named
151            .into_iter()
152            .map(|field| {
153                (
154                    field.ident.expect("encountered named field without an identifier"),
155                    mock_or_default(field.attrs),
156                )
157            })
158            .map(|(field_name, mock_or_default)| quote! { #field_name: #mock_or_default })
159    }
160
161    pub fn tuple(tuple_fields: syn::FieldsUnnamed) -> impl Iterator<Item = proc_macro2::TokenStream> {
162        tuple_fields.unnamed.into_iter().map(|field| mock_or_default(field.attrs))
163    }
164
165    fn mock_or_default(field_attributes: Vec<syn::Attribute>) -> proc_macro2::TokenStream {
166        match field_attributes
167            .into_iter()
168            .any(|attribute| matches!(&attribute.meta, syn::Meta::Path(path) if path.is_ident("mock_default")))
169        {
170            true => quote! { Default::default() },
171            false => quote! { ::damock::Mock::mock() },
172        }
173    }
174}