double_derive/
lib.rs

1use quote::quote;
2use syn::{Error, Ident, ItemTrait, parse_macro_input};
3
4mod double_trait;
5mod trait_impl;
6
7use self::{double_trait::double_trait, trait_impl::trait_impl};
8
9/// Generates a trait which replicates the original trait method for method. It does implement the
10/// original trait for each of its implementations, by means of forwarding the method calls. The
11/// utility comes from the fact that the generated trait has default implementations for each method
12/// using `unimplemented!()`, which makes it useful for testing purposes.
13///
14/// If a test requires an implementation of an original trait `Org` yet would only invoke one of its
15/// methods, implementing the mirrored method on an implementation of the generated trait `OrgDummy`
16/// is sufficient. The other methods would not be inovked in the test, so their default
17/// implementation using `unimplemented!()` would not be reached.
18///
19/// The argument passed to the attribute is used as the name of the generated trait.
20///
21/// * Existing default implementations are respected and not overridden.
22/// * Visibility of the generated trait is the same as the original trait.
23/// * `async` methods are supported
24/// * Methods returning `impl` Traits are not supported, with the exception of `impl Future`.
25/// * Generated double trait is implemented for `Dummy`.
26#[proc_macro_attribute]
27pub fn double(
28    attr: proc_macro::TokenStream,
29    item: proc_macro::TokenStream,
30) -> proc_macro::TokenStream {
31    let double_name = parse_macro_input!(attr as Ident);
32    let item = parse_macro_input!(item as ItemTrait);
33
34    let output = expand(double_name, item).unwrap_or_else(Error::into_compile_error);
35
36    proc_macro::TokenStream::from(output)
37}
38
39/// The main implementation of [`crate::double`]. This function is not annotated with
40/// `#[proc_macro_attribute]` so it can exist in unit tests. It uses only APIs build on top of
41/// [`proc_macro2`] in order to be unit testable.
42fn expand(double_trait_name: Ident, org_trait: ItemTrait) -> syn::Result<proc_macro2::TokenStream> {
43    let double_trait = double_trait(double_trait_name.clone(), org_trait.clone())?;
44    let trait_impl = trait_impl(double_trait_name.clone(), org_trait.clone());
45
46    // We generate three items as part of our output.
47    // 1. The orginal trait, which we put in the output unaltered.
48    // 2. The double trait, we genarate, which mirrors the original traits methods and provides
49    //    default implementations using `unimplemented!()`.
50    // 3. An implementation of the original trait for all types which implement the double trait.
51    //    This is done by forwarding the method calls to the double trait.
52    let token_stream = quote! {
53        #org_trait
54
55        #double_trait
56
57        #trait_impl
58
59        impl #double_trait_name for double_trait::Dummy{}
60    };
61    Ok(token_stream)
62}
63
#[cfg(test)]
mod tests {

    use super::{Ident, expand};
    use quote::quote;
    use syn::{ItemTrait, parse2};

    #[test]
    fn generate_double_trait() {
        // A trait without methods yields an empty double, an empty blanket impl and the `Dummy`
        // impl.
        assert_expands_to(
            quote! { MyTraitDummy },
            quote! { trait MyTrait {} },
            quote! {
                trait MyTrait {}

                trait MyTraitDummy {}

                impl<T> MyTrait for T where T: MyTraitDummy {}

                impl MyTraitDummy for double_trait::Dummy {}
            },
        );
    }

    #[test]
    fn forward_visibility() {
        // A public original trait must yield a public double trait.
        assert_expands_to(
            quote! { MyTraitDummy },
            quote! { pub trait MyTrait {} },
            quote! {
                pub trait MyTrait {}

                pub trait MyTraitDummy {}

                impl<T> MyTrait for T where T: MyTraitDummy {}

                impl MyTraitDummy for double_trait::Dummy {}
            },
        );
    }

    #[test]
    fn forward_method() {
        // A required method is mirrored with an `unimplemented!()` default and forwarded by the
        // blanket impl.
        assert_expands_to(
            quote! { MyTraitDummy },
            quote! {
                trait MyTrait {
                    fn foobar(&self);
                }
            },
            quote! {
                trait MyTrait {
                    fn foobar(&self);
                }

                trait MyTraitDummy {
                    fn foobar (&self) { unimplemented!() }
                }

                impl<T> MyTrait for T where T: MyTraitDummy {
                    fn foobar(&self) { <Self as MyTraitDummy>::foobar(self,) }
                }

                impl MyTraitDummy for double_trait::Dummy {}
            },
        );
    }

    #[test]
    fn respect_existing_default_impl() {
        // A method that already has a default implementation must not be overridden with an
        // `unimplemented!()` default in the double trait.
        //
        // NOTE(review): the expected blanket impl below still forwards to
        // `<Self as MyTraitDummy>::foobar()`, although the expected `MyTraitDummy` declares no
        // `foobar`. As pinned here, that expansion would not compile in a user crate. Confirm
        // whether the forwarding impl should skip such methods.
        assert_expands_to(
            quote! { MyTraitDummy },
            quote! {
                pub trait MyTrait {
                    fn foobar() { println!("Hello Default!") }
                }
            },
            quote! {
                pub trait MyTrait {
                    fn foobar() { println!("Hello Default!") }
                }

                pub trait MyTraitDummy {}

                impl<T> MyTrait for T where T: MyTraitDummy {
                    fn foobar() { <Self as MyTraitDummy>::foobar() }
                }

                impl MyTraitDummy for double_trait::Dummy {}
            },
        );
    }

    #[test]
    fn forward_async_method() {
        // An `async` method is mirrored as `async` and its forwarded call is awaited.
        assert_expands_to(
            quote! { MyTraitDummy },
            quote! {
                trait MyTrait {
                    async fn foobar(&self);
                }
            },
            quote! {
                trait MyTrait {
                    async fn foobar(&self);
                }

                trait MyTraitDummy {
                    async fn foobar (&self) { unimplemented!() }
                }

                impl<T> MyTrait for T where T: MyTraitDummy {
                    async fn foobar(&self) { <Self as MyTraitDummy>::foobar(self,).await }
                }

                impl MyTraitDummy for double_trait::Dummy {}
            },
        );
    }

    /// Parses `attr` and `item` the way the attribute macro receives them, expands them, and
    /// asserts that the string rendering of the expansion equals that of `expected`.
    fn assert_expands_to(
        attr: proc_macro2::TokenStream,
        item: proc_macro2::TokenStream,
        expected: proc_macro2::TokenStream,
    ) {
        let (double_name, org_trait) = given(attr, item);
        let actual = expand(double_name, org_trait).unwrap();
        assert_eq!(expected.to_string(), actual.to_string());
    }

    /// Parses the attribute argument as an identifier and the annotated item as a trait,
    /// panicking if either does not parse.
    fn given(attr: proc_macro2::TokenStream, item: proc_macro2::TokenStream) -> (Ident, ItemTrait) {
        let name: Ident = parse2(attr).unwrap();
        let parsed: ItemTrait = parse2(item).unwrap();
        (name, parsed)
    }
}