double_derive/
lib.rs

1use quote::quote;
2use syn::{Error, Ident, ItemTrait, parse_macro_input};
3
4mod double_trait;
5mod trait_impl;
6
7use self::{double_trait::double_trait, trait_impl::trait_impl};
8
9/// Generates a trait which replicates the original trait method for method. It does implement the
10/// original trait for each of its implementations, by means of forwarding the method calls. The
11/// utility comes from the fact that the generated trait has default implementations for each method
12/// using `unimplemented!()`, which makes it useful for testing purposes.
13///
14/// If a test requires an implementation of an original trait `Org` yet would only invoke one of its
15/// methods, implementing the mirrored method on an implementation of the generated trait `OrgDummy`
16/// is sufficient. The other methods would not be inovked in the test, so their default
17/// implementation using `unimplemented!()` would not be reached.
18///
19/// The argument passed to the attribute is used as the name of the generated trait.
20///
21/// * Existing default implementations are respected and not overridden.
22/// * Visibility of the generated trait is the same as the original trait.
23/// * `async` methods are supported
24/// * Methods returning `impl` Traits are not supported, with the exception of `impl Future`.
25/// * Generated double trait is implemented for `Dummy`.
26/// 
27/// # Example
28/// 
29/// Basic usage allows creating test stubs for traits, without worrying about implementing methods
30/// not called in test code
31/// 
32/// ```no_run
33/// use double_trait::double;
34/// 
35/// #[double(MyTraitDouble)]
36/// trait MyTrait {
37///    fn answer(&self) -> i32;
38/// 
39///    fn some_other_method(&self);
40/// }
41///  
42/// struct MyStub;
43/// 
44/// impl MyTraitDouble for MyStub {
45///     fn answer(&self) -> i32 {
46///         42
47///     }
48/// }
49/// 
50/// assert_eq!(42, MyTrait::answer(&MyStub));
51/// ```
52/// 
53/// Then interacting with the `async_trait` crate, make sure to put the `#[async_trait]` attribute
54/// on top.
55/// 
56/// ```no_run
57/// use double_trait::double;
58/// use async_trait::async_trait;
59///
60/// #[async_trait]
61/// #[double(MyTraitDouble)]
62/// trait MyTrait {
63///     async fn answer(&self) -> i32;
64/// }
65/// ```
66#[proc_macro_attribute]
67pub fn double(
68    attr: proc_macro::TokenStream,
69    item: proc_macro::TokenStream,
70) -> proc_macro::TokenStream {
71    let double_name = parse_macro_input!(attr as Ident);
72    let item = parse_macro_input!(item as ItemTrait);
73
74    let output = expand(double_name, item).unwrap_or_else(Error::into_compile_error);
75
76    proc_macro::TokenStream::from(output)
77}
78
79/// The main implementation of [`crate::double`]. This function is not annotated with
80/// `#[proc_macro_attribute]` so it can exist in unit tests. It uses only APIs build on top of
81/// [`proc_macro2`] in order to be unit testable.
82fn expand(double_trait_name: Ident, org_trait: ItemTrait) -> syn::Result<proc_macro2::TokenStream> {
83    let double_trait = double_trait(double_trait_name.clone(), org_trait.clone())?;
84    let trait_impl = trait_impl(double_trait_name.clone(), org_trait.clone());
85
86    // We generate three items as part of our output.
87    // 1. The orginal trait, which we put in the output unaltered.
88    // 2. The double trait, we genarate, which mirrors the original traits methods and provides
89    //    default implementations using `unimplemented!()`.
90    // 3. An implementation of the original trait for all types which implement the double trait.
91    //    This is done by forwarding the method calls to the double trait.
92    let token_stream = quote! {
93        #org_trait
94
95        #double_trait
96
97        #trait_impl
98
99        impl #double_trait_name for double_trait::Dummy{}
100    };
101    Ok(token_stream)
102}
103
104#[cfg(test)]
105mod tests {
106
107    use super::{Ident, expand};
108    use quote::quote;
109    use syn::{ItemTrait, parse2};
110
111    #[test]
112    fn generate_double_trait() {
113        let (attr, item) = given(quote! { MyTraitDummy }, quote! { trait MyTrait {} });
114
115        let output = expand(attr, item).unwrap();
116
117        let expected = quote! {
118            trait MyTrait {}
119
120            trait MyTraitDummy {}
121
122            impl<T> MyTrait for T where T: MyTraitDummy {}
123
124            impl MyTraitDummy for double_trait::Dummy {}
125        };
126        assert_eq!(expected.to_string(), output.to_string());
127    }
128
129    #[test]
130    fn forward_visibility() {
131        // Given a public trait
132        let (attr, item) = given(quote! { MyTraitDummy }, quote! { pub trait MyTrait {} });
133
134        // When generating the dummy
135        let output = expand(attr, item).unwrap();
136
137        // Then the generated trait should be public, too
138        let expected = quote! {
139            pub trait MyTrait {}
140
141            pub trait MyTraitDummy {}
142
143            impl<T> MyTrait for T where T: MyTraitDummy {}
144
145            impl MyTraitDummy for double_trait::Dummy {}
146        };
147        assert_eq!(expected.to_string(), output.to_string());
148    }
149
150    #[test]
151    fn forward_method() {
152        // Given a trait with a method
153        let (attr, item) = given(
154            quote! { MyTraitDummy },
155            quote! {
156                trait MyTrait {
157                    fn foobar(&self);
158                }
159            },
160        );
161
162        // When generating the dummy
163        let output = expand(attr, item).unwrap();
164
165        // Then the generated trait should contain that method, too
166        let expected = quote! {
167            trait MyTrait {
168                fn foobar(&self);
169            }
170
171            trait MyTraitDummy {
172                fn foobar (&self) { unimplemented!() }
173            }
174
175            impl<T> MyTrait for T where T: MyTraitDummy {
176                fn foobar(&self) { <Self as MyTraitDummy>::foobar(self,) }
177            }
178
179            impl MyTraitDummy for double_trait::Dummy {}
180        };
181        assert_eq!(expected.to_string(), output.to_string());
182    }
183
184    #[test]
185    fn respect_existing_default_impl() {
186        // Given a method with a default implementation in the original trait
187        let (attr, item) = given(
188            quote! { MyTraitDummy },
189            quote! {
190                pub trait MyTrait {
191                    fn foobar() { println!("Hello Default!") }
192                }
193            },
194        );
195
196        // When generating the dummy
197        let output = expand(attr, item).unwrap();
198
199        // Then the generated trait should not overide the existing default
200        let expected = quote! {
201            pub trait MyTrait {
202                fn foobar() { println!("Hello Default!") }
203            }
204
205            pub trait MyTraitDummy {}
206
207            impl<T> MyTrait for T where T: MyTraitDummy {
208                fn foobar() { <Self as MyTraitDummy>::foobar() }
209            }
210
211            impl MyTraitDummy for double_trait::Dummy {}
212        };
213        assert_eq!(expected.to_string(), output.to_string());
214    }
215
216    #[test]
217    fn forward_async_method() {
218        // Given a trait with a method
219        let (attr, item) = given(
220            quote! { MyTraitDummy },
221            quote! {
222                trait MyTrait {
223                    async fn foobar(&self);
224                }
225            },
226        );
227
228        // When generating the dummy
229        let output = expand(attr, item).unwrap();
230
231        // Then the generated trait should contain that method, too
232        let expected = quote! {
233            trait MyTrait {
234                async fn foobar(&self);
235            }
236
237            trait MyTraitDummy {
238                async fn foobar (&self) { unimplemented!() }
239            }
240
241            impl<T> MyTrait for T where T: MyTraitDummy {
242                async fn foobar(&self) { <Self as MyTraitDummy>::foobar(self,).await }
243            }
244
245            impl MyTraitDummy for double_trait::Dummy {}
246        };
247        assert_eq!(expected.to_string(), output.to_string());
248    }
249
250    fn given(attr: proc_macro2::TokenStream, item: proc_macro2::TokenStream) -> (Ident, ItemTrait) {
251        let attr: Ident = parse2(attr).unwrap();
252        let item: ItemTrait = parse2(item).unwrap();
253        (attr, item)
254    }
255}