double_derive/
lib.rs

1use quote::quote;
2use syn::{Error, Ident, ItemTrait, parse_macro_input};
3
4mod double_trait;
5mod dummy_impl;
6mod trait_impl;
7
8use self::{double_trait::double_trait, trait_impl::trait_impl};
9
10/// Generates a trait which replicates the original trait method for method. It does implement the
11/// original trait for each of its implementations, by means of forwarding the method calls. The
12/// utility comes from the fact that the generated trait has default implementations for each method
13/// using `unimplemented!()`, which makes it useful for testing purposes.
14///
15/// If a test requires an implementation of an original trait `Org` yet would only invoke one of its
16/// methods, implementing the mirrored method on an implementation of the generated trait `OrgDummy`
17/// is sufficient. The other methods would not be inovked in the test, so their default
18/// implementation using `unimplemented!()` would not be reached.
19///
20/// The argument passed to the attribute is used as the name of the generated trait.
21///
22/// * Existing default implementations are respected and not overridden.
23/// * Visibility of the generated trait is the same as the original trait.
24/// * `async` methods are supported
25/// * Methods returning `impl` Traits are not supported, with the exception of `impl Future`.
26/// * Generated double trait is implemented for `Dummy`.
27///
28/// # Example
29///
30/// Basic usage allows creating test stubs for traits, without worrying about implementing methods
31/// not called in test code
32///
33/// ```no_run
34/// use double_trait::double;
35///
36/// #[double(MyTraitDouble)]
37/// trait MyTrait {
38///    fn answer(&self) -> i32;
39///
40///    fn some_other_method(&self);
41/// }
42///  
43/// struct MyStub;
44///
45/// impl MyTraitDouble for MyStub {
46///     fn answer(&self) -> i32 {
47///         42
48///     }
49/// }
50///
51/// assert_eq!(42, MyTrait::answer(&MyStub));
52/// ```
53///
54/// Then interacting with the `async_trait` crate, make sure to put the `#[async_trait]` attribute
55/// on top.
56///
57/// ```no_run
58/// use double_trait::double;
59/// use async_trait::async_trait;
60///
61/// #[async_trait]
62/// #[double(MyTraitDouble)]
63/// trait MyTrait {
64///     async fn answer(&self) -> i32;
65/// }
66/// ```
67#[proc_macro_attribute]
68pub fn double(
69    attr: proc_macro::TokenStream,
70    item: proc_macro::TokenStream,
71) -> proc_macro::TokenStream {
72    let double_name = parse_macro_input!(attr as Ident);
73    let item = parse_macro_input!(item as ItemTrait);
74
75    let output = expand(double_name, item).unwrap_or_else(Error::into_compile_error);
76
77    proc_macro::TokenStream::from(output)
78}
79
80/// The main implementation of [`crate::double`]. This function is not annotated with
81/// `#[proc_macro_attribute]` so it can exist in unit tests. It uses only APIs build on top of
82/// [`proc_macro2`] in order to be unit testable.
83fn expand(double_trait_name: Ident, org_trait: ItemTrait) -> syn::Result<proc_macro2::TokenStream> {
84    let double_trait = double_trait(double_trait_name.clone(), org_trait.clone())?;
85    let trait_impl = trait_impl(double_trait_name.clone(), org_trait.clone());
86    let dummy_impl = dummy_impl::dummy_impl(double_trait_name, org_trait.clone());
87
88    // We generate three items as part of our output.
89    // 1. The orginal trait, which we put in the output unaltered.
90    // 2. The double trait, we genarate, which mirrors the original traits methods and provides
91    //    default implementations using `unimplemented!()`.
92    // 3. An implementation of the original trait for all types which implement the double trait.
93    //    This is done by forwarding the method calls to the double trait.
94    let token_stream = quote! {
95        #org_trait
96
97        #double_trait
98
99        #trait_impl
100
101        #dummy_impl
102    };
103    Ok(token_stream)
104}
105
#[cfg(test)]
mod tests {

    use super::{Ident, expand};
    use quote::quote;
    use syn::{ItemTrait, parse2};

    /// An empty trait expands to itself, an empty double trait, the blanket forwarding impl and
    /// the impl of the double trait for `double_trait::Dummy`.
    #[test]
    fn generate_double_trait() {
        assert_expands_to(
            quote! { MyTraitDummy },
            quote! { trait MyTrait {} },
            quote! {
                trait MyTrait {}

                trait MyTraitDummy {}

                impl<T> MyTrait for T where T: MyTraitDummy {}

                impl MyTraitDummy for double_trait::Dummy {}
            },
        );
    }

    /// A public original trait must yield a public double trait.
    #[test]
    fn forward_visibility() {
        assert_expands_to(
            quote! { MyTraitDummy },
            quote! { pub trait MyTrait {} },
            quote! {
                pub trait MyTrait {}

                pub trait MyTraitDummy {}

                impl<T> MyTrait for T where T: MyTraitDummy {}

                impl MyTraitDummy for double_trait::Dummy {}
            },
        );
    }

    /// A required method is mirrored in the double trait with an `unimplemented!()` default and
    /// forwarded by the blanket impl.
    #[test]
    fn forward_method() {
        assert_expands_to(
            quote! { MyTraitDummy },
            quote! {
                trait MyTrait {
                    fn foobar(&self);
                }
            },
            quote! {
                trait MyTrait {
                    fn foobar(&self);
                }

                trait MyTraitDummy {
                    fn foobar (&self) {
                        let double_trait_name = stringify!(MyTraitDummy);
                        let fn_name = stringify!(foobar);
                        unimplemented!("{double_trait_name}::{fn_name}")
                    }
                }

                impl<T> MyTrait for T where T: MyTraitDummy {
                    fn foobar(&self) { <Self as MyTraitDummy>::foobar(self,) }
                }

                impl MyTraitDummy for double_trait::Dummy {}
            },
        );
    }

    /// A default implementation present in the original trait is copied into the double trait
    /// as-is, instead of being overridden by an `unimplemented!()` default.
    #[test]
    fn respect_existing_default_impl() {
        assert_expands_to(
            quote! { MyTraitDummy },
            quote! {
                pub trait MyTrait {
                    fn foobar() { println!("Hello Default!") }
                }
            },
            quote! {
                pub trait MyTrait {
                    fn foobar() { println!("Hello Default!") }
                }

                pub trait MyTraitDummy {
                    fn foobar() { println!("Hello Default!") }
                }

                impl<T> MyTrait for T where T: MyTraitDummy {
                    fn foobar() { <Self as MyTraitDummy>::foobar() }
                }

                impl MyTraitDummy for double_trait::Dummy {}
            },
        );
    }

    /// An `async` method is mirrored as `async`, and the forwarding call is `.await`ed.
    #[test]
    fn forward_async_method() {
        assert_expands_to(
            quote! { MyTraitDummy },
            quote! {
                trait MyTrait {
                    async fn foobar(&self);
                }
            },
            quote! {
                trait MyTrait {
                    async fn foobar(&self);
                }

                trait MyTraitDummy {
                    async fn foobar (&self) {
                        let double_trait_name = stringify!(MyTraitDummy);
                        let fn_name = stringify!(foobar);
                        unimplemented!("{double_trait_name}::{fn_name}")
                    }
                }

                impl<T> MyTrait for T where T: MyTraitDummy {
                    async fn foobar(&self) { <Self as MyTraitDummy>::foobar(self,).await }
                }

                impl MyTraitDummy for double_trait::Dummy {}
            },
        );
    }

    /// Parses `attr` as the double trait name and `item` as the original trait, runs `expand`
    /// on them and asserts that the string rendering of the output equals that of `expected`.
    ///
    /// `#[track_caller]` makes a failing assertion point at the calling test.
    #[track_caller]
    fn assert_expands_to(
        attr: proc_macro2::TokenStream,
        item: proc_macro2::TokenStream,
        expected: proc_macro2::TokenStream,
    ) {
        let double_trait_name: Ident = parse2(attr).unwrap();
        let org_trait: ItemTrait = parse2(item).unwrap();

        let output = expand(double_trait_name, org_trait).unwrap();

        assert_eq!(expected.to_string(), output.to_string());
    }
}