double_derive/
lib.rs

1use quote::quote;
2use syn::{Error, Ident, ItemTrait, parse_macro_input};
3
4mod double_trait;
5mod trait_impl;
6
7use self::{double_trait::double_trait, trait_impl::trait_impl};
8
9/// Generates a trait which replicates the original trait method for method. It does implement the
10/// original trait for each of its implementations, by means of forwarding the method calls. The
11/// utility comes from the fact that the generated trait has default implementations for each method
12/// using `unimplemented!()`, which makes it useful for testing purposes.
13///
14/// If a test requires an implementation of an original trait `Org` yet would only invoke one of its
15/// methods, implementing the mirrored method on an implementation of the generated trait `OrgDummy`
16/// is sufficient. The other methods would not be inovked in the test, so their default
17/// implementation using `unimplemented!()` would not be reached.
18///
19/// The argument passed to the attribute is used as the name of the generated trait.
20#[proc_macro_attribute]
21pub fn double(
22    attr: proc_macro::TokenStream,
23    item: proc_macro::TokenStream,
24) -> proc_macro::TokenStream {
25    let double_name = parse_macro_input!(attr as Ident);
26    let item = parse_macro_input!(item as ItemTrait);
27
28    let output = double_impl(double_name, item).unwrap_or_else(Error::into_compile_error);
29
30    proc_macro::TokenStream::from(output)
31}
32
33/// The main implementation of [`crate::double`]. This function is not annotated with
34/// `#[proc_macro_attribute]` so it can exist in unit tests. It uses only APIs build on top of
35/// [`proc_macro2`] in order to be unit testable.
36fn double_impl(
37    double_trait_name: Ident,
38    org_trait: ItemTrait,
39) -> syn::Result<proc_macro2::TokenStream> {
40    let double_trait = double_trait(double_trait_name.clone(), org_trait.clone())?;
41    let trait_impl = trait_impl(double_trait_name, org_trait.clone());
42
43    // We generate three items as part of our output.
44    // 1. The orginal trait, which we put in the output unaltered.
45    // 2. The double trait, we genarate, which mirrors the original traits methods and provides
46    //    default implementations using `unimplemented!()`.
47    // 3. An implementation of the original trait for all types which implement the double trait.
48    //    This is done by forwarding the method calls to the double trait.
49    let token_stream = quote! {
50        #org_trait
51
52        #double_trait
53
54        #trait_impl
55    };
56    Ok(token_stream)
57}
58
#[cfg(test)]
mod tests {

    use super::{Ident, double_impl};
    use quote::quote;
    use syn::{ItemTrait, parse2};

    /// Parses `attr` (the attribute argument) and `item` (the annotated trait), expands them with
    /// [`double_impl`] and asserts that the rendered output equals the rendered `expected`.
    ///
    /// Because the token streams are compared via `to_string`, whitespace inside the `quote!`
    /// invocations does not matter, while punctuation such as trailing commas does.
    fn assert_expansion(
        attr: proc_macro2::TokenStream,
        item: proc_macro2::TokenStream,
        expected: proc_macro2::TokenStream,
    ) {
        let double_name: Ident = parse2(attr).unwrap();
        let org_trait: ItemTrait = parse2(item).unwrap();

        let output = double_impl(double_name, org_trait).unwrap();

        assert_eq!(expected.to_string(), output.to_string());
    }

    #[test]
    fn generate_double_trait() {
        // An empty trait yields an empty double trait and an empty blanket impl.
        assert_expansion(
            quote! { MyTraitDummy },
            quote! { trait MyTrait {} },
            quote! {
                trait MyTrait {}

                trait MyTraitDummy {}

                impl<T> MyTrait for T where T: MyTraitDummy {}
            },
        );
    }

    #[test]
    fn forward_visibility() {
        // A public original trait must produce a public double trait.
        assert_expansion(
            quote! { MyTraitDummy },
            quote! { pub trait MyTrait {} },
            quote! {
                pub trait MyTrait {}

                pub trait MyTraitDummy {}

                impl<T> MyTrait for T where T: MyTraitDummy {}
            },
        );
    }

    #[test]
    fn forward_method() {
        // A required method is mirrored with an `unimplemented!()` body and forwarded to.
        assert_expansion(
            quote! { MyTraitDummy },
            quote! {
                trait MyTrait {
                    fn foobar(&self);
                }
            },
            quote! {
                trait MyTrait {
                    fn foobar(&self);
                }

                trait MyTraitDummy {
                    fn foobar(&self) { unimplemented!() }
                }

                impl<T> MyTrait for T where T: MyTraitDummy {
                    fn foobar(&self) { <Self as MyTraitDummy>::foobar(self,) }
                }
            },
        );
    }

    #[test]
    fn respect_existing_default_impl() {
        // A method that already has a default body in the original trait must not receive an
        // `unimplemented!()` stub in the double trait, so the existing default is not overridden.
        //
        // NOTE(review): the expected blanket impl still forwards to
        // `<Self as MyTraitDummy>::foobar()`, although `MyTraitDummy` declares no `foobar`. That
        // expansion would fail to resolve in a user crate — confirm whether `trait_impl` should
        // skip methods which have a default in the original trait.
        assert_expansion(
            quote! { MyTraitDummy },
            quote! {
                pub trait MyTrait {
                    fn foobar() { println!("Hello Default!") }
                }
            },
            quote! {
                pub trait MyTrait {
                    fn foobar() { println!("Hello Default!") }
                }

                pub trait MyTraitDummy {}

                impl<T> MyTrait for T where T: MyTraitDummy {
                    fn foobar() { <Self as MyTraitDummy>::foobar() }
                }
            },
        );
    }

    #[test]
    fn forward_async_method() {
        // An async method is mirrored as async and forwarded by awaiting the double's method.
        assert_expansion(
            quote! { MyTraitDummy },
            quote! {
                trait MyTrait {
                    async fn foobar(&self);
                }
            },
            quote! {
                trait MyTrait {
                    async fn foobar(&self);
                }

                trait MyTraitDummy {
                    async fn foobar(&self) { unimplemented!() }
                }

                impl<T> MyTrait for T where T: MyTraitDummy {
                    async fn foobar(&self) { <Self as MyTraitDummy>::foobar(self,).await }
                }
            },
        );
    }
}