1use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, spanned::Spanned, Data, DeriveInput, Fields, FieldsNamed};
6
7#[proc_macro_derive(Mock, attributes(mock, mock_default))]
10pub fn derive_mock(token_stream: TokenStream) -> TokenStream {
11 derive_mock_impl(token_stream)
12}
13
14fn derive_mock_impl(token_stream: TokenStream) -> TokenStream {
15 let type_definition = parse_macro_input!(token_stream as DeriveInput);
16
17 let identifier = type_definition.ident;
18
19 let cfg_scope = match cfg_scope(type_definition.attrs) {
20 Ok(scope) => scope,
21 Err(err) => return err.into_compile_error().into(),
22 };
23
24 let self_definition_result = match type_definition.data {
25 Data::Struct(data_struct) => derive_struct(data_struct),
26 Data::Enum(data_enum) => derive_enum(data_enum),
27 Data::Union(data_union) => Err(syn::Error::new(data_union.union_token.span, "union types not supported")),
29 };
30
31 match self_definition_result {
32 Ok(self_definition) => {
33 quote! {
34 #cfg_scope
35 impl ::damock::Mock for #identifier {
36 fn mock() -> Self {
37 #self_definition
38 }
39 }
40 }
41 }
42 Err(err) => err.to_compile_error(),
43 }
44 .into()
45}
46
47fn cfg_scope(container_attributes: Vec<syn::Attribute>) -> syn::Result<proc_macro2::TokenStream> {
48 let mut cfg_override: Option<syn::MetaNameValue> = None;
49
50 let mock_attributes = container_attributes.into_iter().filter(|attribute| {
51 matches!(&attribute.meta,
52 syn::Meta::List(meta_list) if meta_list.path.is_ident("mock"))
53 });
54
55 for mock_attribute in mock_attributes {
56 let cfg_args: syn::MetaNameValue = mock_attribute.parse_args()?;
57
58 match &cfg_override {
59 Some(_pre_existing) => {
60 Err(syn::Error::new(cfg_args.span(), "multiple #[cfg], values provided"))?;
61 }
62 None => cfg_override = Some(cfg_args),
63 }
64 }
65
66 Ok(match cfg_override {
67 Some(overrides) => quote! { #[cfg(#overrides)] },
68 None => quote! { #[cfg(test)] },
69 })
70}
71
72fn derive_struct(data_struct: syn::DataStruct) -> syn::Result<proc_macro2::TokenStream> {
73 Ok(match data_struct.fields {
74 Fields::Named(named_fields) => {
75 let fields = fields::named(named_fields);
76
77 quote! {
78 Self {
79 #(#fields),*
80 }
81 }
82 }
83 Fields::Unnamed(tuple_fields) => {
84 let fields = fields::tuple(tuple_fields);
85
86 quote! { Self(#(#fields),*) }
87 }
88 Fields::Unit => quote! { Self },
89 })
90}
91
92fn derive_enum(data_enum: syn::DataEnum) -> syn::Result<proc_macro2::TokenStream> {
93 let mut variant_to_mock_iter = data_enum.variants.into_iter().filter_map(|variant| {
94 variant
95 .attrs
96 .clone()
97 .iter()
98 .find(|attribute| match &attribute.meta {
99 syn::Meta::Path(path) => path.is_ident("mock"),
100 _ => false,
101 })
102 .map(|_| variant)
103 });
104
105 let Some(variant_to_mock) = variant_to_mock_iter.next() else {
106 return Err(syn::Error::new(
107 data_enum.enum_token.span,
108 "no #[mock] attribute found in any of the listed variants",
109 ));
110 };
111
112 if let Some(_another_variant_to_mock) = variant_to_mock_iter.next() {
113 return Err(syn::Error::new(
114 data_enum.enum_token.span,
115 "expected only one #[mock] enum variant attribute, unable to infer which one to use.",
116 ));
117 }
118
119 let variant_name = variant_to_mock.ident;
120
121 Ok(match variant_to_mock.fields {
122 Fields::Named(named_fields) => {
123 let fields = fields::named(named_fields);
124
125 quote! {
126 Self::#variant_name {
127 #(#fields),*
128 }
129 }
130 }
131 Fields::Unnamed(tuple_fields) => {
132 let fields = fields::tuple(tuple_fields);
133 quote! {
134 Self::#variant_name(#(#fields),*)
135 }
136 }
137 Fields::Unit => {
138 quote! {
139 Self::#variant_name
140 }
141 }
142 })
143}
144
145mod fields {
146 use super::*;
147
148 pub fn named(named_fields: FieldsNamed) -> impl Iterator<Item = proc_macro2::TokenStream> {
149 named_fields
150 .named
151 .into_iter()
152 .map(|field| {
153 (
154 field.ident.expect("encountered named field without an identifier"),
155 mock_or_default(field.attrs),
156 )
157 })
158 .map(|(field_name, mock_or_default)| quote! { #field_name: #mock_or_default })
159 }
160
161 pub fn tuple(tuple_fields: syn::FieldsUnnamed) -> impl Iterator<Item = proc_macro2::TokenStream> {
162 tuple_fields.unnamed.into_iter().map(|field| mock_or_default(field.attrs))
163 }
164
165 fn mock_or_default(field_attributes: Vec<syn::Attribute>) -> proc_macro2::TokenStream {
166 match field_attributes
167 .into_iter()
168 .any(|attribute| matches!(&attribute.meta, syn::Meta::Path(path) if path.is_ident("mock_default")))
169 {
170 true => quote! { Default::default() },
171 false => quote! { ::damock::Mock::mock() },
172 }
173 }
174}