auto_trait/
lib.rs

1//!Automatic trait extension macro for wrapper types
2#![warn(missing_docs)]
3#![allow(clippy::style)]
4
5use proc_macro::TokenStream;
6
7use quote::quote;
8
9struct TypeInfo {
10    ident: syn::Ident,
11    generics: Option<syn::AngleBracketedGenericArguments>,
12    reference: Option<syn::Lifetime>,
13    mutability: Option<syn::token::Mut>
14}
15
16fn generate_self_trait_bound(generic_name: syn::Ident, trait_name: &syn::Ident) -> syn::GenericArgument {
17    let mut segments = syn::punctuated::Punctuated::new();
18    segments.push(syn::PathSegment {
19        ident: trait_name.clone(),
20        arguments: syn::PathArguments::None,
21    });
22
23    let mut bounds = syn::punctuated::Punctuated::new();
24    bounds.push(syn::TypeParamBound::Trait(syn::TraitBound {
25        paren_token: None,
26        modifier: syn::TraitBoundModifier::None,
27        lifetimes: None,
28        path: syn::Path {
29            leading_colon: None,
30            segments
31        }
32    }));
33    syn::GenericArgument::Constraint(syn::Constraint {
34        ident: generic_name,
35        generics: None,
36        colon_token: syn::Token![:](proc_macro2::Span::call_site()),
37        bounds
38    })
39}
40
41fn extract_type(typ: &mut syn::Type, trait_name: &syn::Ident, deref_type: &mut Option<syn::Ident>) -> Result<TypeInfo, TokenStream> {
42    match typ {
43        syn::Type::Path(ref mut typ) => {
44            let ident = match typ.path.segments.first() {
45                Some(path) => path.ident.clone(),
46                None => return Err(syn::Error::new_spanned(typ, "Type has no path segments").to_compile_error().into()),
47            };
48
49            match typ.path.segments.last_mut().expect("To have at least on type path segment").arguments {
50                syn::PathArguments::AngleBracketed(ref mut args) => {
51                    let result = args.clone();
52
53                    for arg in args.args.iter_mut() {
54                        if let syn::GenericArgument::Constraint(constraint) = arg {
55
56                            for param in constraint.bounds.iter() {
57                                if let syn::TypeParamBound::Trait(bound) = param {
58                                    if bound.path.is_ident(trait_name) {
59                                        if let Some(ident) = deref_type.replace(constraint.ident.clone()) {
60                                            return Err(syn::Error::new_spanned(ident, "Multiple bounds to trait, can be problematic so how about no?").to_compile_error().into());
61                                        }
62                                    }
63                                }
64                            }
65
66                            let mut segments = syn::punctuated::Punctuated::new();
67                            segments.push(syn::PathSegment {
68                                ident: constraint.ident.clone(),
69                                arguments: syn::PathArguments::None
70                            });
71
72                            *arg = syn::GenericArgument::Type(syn::Type::Path(syn::TypePath {
73                                qself: None,
74                                path: syn::Path {
75                                    leading_colon: None,
76                                    segments
77                                },
78                            }));
79                        }
80                    }
81
82                    //if deref_type.is_none() && result.args.len() == 1 {
83                    //    result.args.last_mut();
84                    //}
85
86                    Ok(TypeInfo {
87                        ident,
88                        generics: Some(result),
89                        reference: None,
90                        mutability: None,
91                    })
92                },
93                syn::PathArguments::None => Ok(TypeInfo {
94                    ident,
95                    generics: None,
96                    reference: None,
97                    mutability: None,
98                }),
99                syn::PathArguments::Parenthesized(ref args) => Err(syn::Error::new_spanned(args, "Unsupported type arguments").to_compile_error().into()),
100            }
101        },
102        syn::Type::Reference(reference) => match extract_type(&mut reference.elem, trait_name, deref_type) {
103            Ok(mut result) => {
104                result.mutability = reference.mutability;
105                result.reference = reference.lifetime.clone();
106                Ok(result)
107            },
108            Err(error) => Err(error),
109        }
110        other => Err(syn::Error::new_spanned(other, "Unsupported type").to_compile_error().into()),
111    }
112}
113
114///Generates trait implementation for specified type, relying on `Deref` or `Into` depending on
115///whether `self` is reference or owned
116///
117///Note that this crate is only needed due to lack of specialization that would allow to have
118///generic implementation over `T: Deref<Target=O>`
119///
120///## Example
121///
122///```rust
123///use auto_trait::auto_trait;
124///pub struct Wrapper(u32);
125///
126///impl Into<u32> for Wrapper {
127///    fn into(self) -> u32 {
128///        self.0
129///    }
130///}
131///
132///impl core::ops::Deref for Wrapper {
133///    type Target = u32;
134///    fn deref(&self) -> &Self::Target {
135///        &self.0
136///    }
137///}
138///
139///impl core::ops::DerefMut for Wrapper {
140///    fn deref_mut(&mut self) -> &mut Self::Target {
141///        &mut self.0
142///    }
143///}
144///
145///#[auto_trait(Wrapper)]
146///pub trait Lolka3 {
147///}
148///
149///impl Lolka3 for u32 {}
150///
151///#[auto_trait(Box<T: Lolka2>)]
152///#[auto_trait(Wrapper)]
153///#[auto_trait(&'a mut R)]
154///pub trait Lolka2 {
155///   fn lolka2_ref(&self) -> u32;
156///   fn lolka2_mut(&mut self) -> u32;
157///}
158///
159///impl Lolka2 for u32 {
160///   fn lolka2_ref(&self) -> u32 {
161///       10
162///   }
163///   fn lolka2_mut(&mut self) -> u32 {
164///       11
165///   }
166///}
167///
168///#[auto_trait(Box<T: Lolka + From<Box<T>>>)]
169///pub trait Lolka {
170///   fn lolka() -> u32;
171///
172///   fn lolka_ref(&self) -> u32;
173///
174///   fn lolka_mut(&mut self) -> u32;
175///
176///   fn lolka_self(self) -> u32;
177///}
178///
179///impl Lolka for u32 {
180///   fn lolka() -> u32 {
181///       1
182///   }
183///
184///   fn lolka_ref(&self) -> u32 {
185///       2
186///   }
187///
188///   fn lolka_mut(&mut self) -> u32 {
189///       3
190///   }
191///
192///   fn lolka_self(self) -> u32 {
193///       4
194///   }
195///
196///}
197///
198///let mut lolka = 0u32;
199///let mut wrapped = Box::new(lolka);
200///
201///assert_eq!(lolka.lolka_ref(), wrapped.lolka_ref());
202///assert_eq!(lolka.lolka_mut(), wrapped.lolka_mut());
203///assert_eq!(lolka.lolka_self(), wrapped.lolka_self());
204///
205///assert_eq!(lolka.lolka2_ref(), wrapped.lolka2_ref());
206///assert_eq!(lolka.lolka2_mut(), wrapped.lolka2_mut());
207///
208///assert_eq!(lolka.lolka2_ref(), (&mut lolka).lolka2_ref());
209///assert_eq!(lolka.lolka2_mut(), (&mut lolka).lolka2_mut());
210///```
211#[proc_macro_attribute]
212pub fn auto_trait(args: TokenStream, input: TokenStream) -> TokenStream {
213    let mut input = syn::parse_macro_input!(input as syn::ItemTrait);
214    let args: syn::Type = match syn::parse(args) {
215        Ok(args) => args,
216        Err(error) => {
217            return syn::Error::new(error.span(), "Argument is required and must be a type").to_compile_error().into()
218        }
219    };
220
221    let mut args = vec![args];
222
223    //We need to remove attributes that we're going to parse
224    let mut remaining_attrs = Vec::new();
225    for attr in input.attrs.drain(..) {
226        if attr.path().is_ident("auto_trait") {
227            match attr.parse_args() {
228                Ok(arg) => match arg {
229                    syn::Type::Paren(arg) => args.push(*arg.elem),
230                    arg => args.push(arg),
231                },
232                Err(error) => {
233                    return syn::Error::new(error.span(), "Argument is required and must be a type").to_compile_error().into()
234                }
235            }
236        } else {
237            remaining_attrs.push(attr)
238        }
239    }
240    input.attrs = remaining_attrs;
241
242    let mut impls = Vec::new();
243
244    for mut args in args.drain(..) {
245        let trait_name = input.ident.clone();
246        let mut deref_type = None;
247        let type_info = match extract_type(&mut args, &trait_name, &mut deref_type) {
248            Ok(type_info) => type_info,
249            Err(error) => return error,
250        };
251
252        let deref_name = deref_type.unwrap_or_else(|| trait_name.clone());
253
254        let mut methods = Vec::new();
255
256        for item in input.items.iter() {
257            match item {
258                syn::TraitItem::Fn(ref method) => {
259                    let method_name = method.sig.ident.clone();
260                    let mut method_args = Vec::new();
261                    for arg in method.sig.inputs.iter() {
262                        match arg {
263                            syn::FnArg::Receiver(arg) => {
264                                if arg.reference.is_some() {
265                                    if arg.mutability.is_some() {
266                                        if type_info.reference.is_some() {
267                                            method_args.push(quote! {
268                                                &mut **self
269                                            })
270                                        } else {
271                                            method_args.push(quote! {
272                                                core::ops::DerefMut::deref_mut(self)
273                                            })
274                                        }
275                                    } else {
276                                        if type_info.reference.is_some() {
277                                            method_args.push(quote! {
278                                                &**self
279                                            })
280                                        } else {
281                                            method_args.push(quote! {
282                                                core::ops::Deref::deref(self)
283                                            })
284                                        }
285                                    }
286                                } else {
287                                    method_args.push(quote! {
288                                        self.into()
289                                    })
290                                }
291                            },
292                            syn::FnArg::Typed(arg) => {
293                                let name = &arg.pat;
294                                method_args.push(quote! {
295                                    #name
296                                })
297                            },
298                        }
299                    }
300
301                    let deref_block: syn::Block = syn::parse2(quote! {
302                        {
303                            #deref_name::#method_name(#(#method_args,)*)
304                        }
305                    }).unwrap();
306
307                    let mut method = method.clone();
308                    method.default = Some(deref_block);
309                    method.semi_token = None;
310
311                    methods.push(method);
312                },
313                unsupported => return syn::Error::new_spanned(unsupported, "Trait contains non-method definitions which is unsupported").to_compile_error().into(),
314
315            }
316        }
317
318        let type_generics = if let Some(lifetime) = type_info.reference {
319            match type_info.generics {
320                Some(mut generics) => {
321                    let mut new_args = syn::punctuated::Punctuated::new();
322                    new_args.insert(0, generate_self_trait_bound(type_info.ident, &trait_name));
323                    new_args.insert(0, syn::GenericArgument::Lifetime(lifetime));
324                    while let Some(arg) = generics.args.pop() {
325                        new_args.push(arg.into_tuple().0);
326                    }
327                    generics.args = new_args;
328                    Some(generics)
329                },
330                None => {
331                    let mut args = syn::punctuated::Punctuated::new();
332                    args.push(syn::GenericArgument::Lifetime(lifetime));
333                    args.push(generate_self_trait_bound(type_info.ident, &trait_name));
334
335                    Some(syn::AngleBracketedGenericArguments {
336                        colon2_token: None,
337                        lt_token: syn::Token![<](proc_macro2::Span::call_site()),
338                        args,
339                        gt_token: syn::Token![>](proc_macro2::Span::call_site()),
340                    })
341                }
342            }
343        } else {
344            type_info.generics
345        };
346
347        impls.push(quote! {
348            impl#type_generics #trait_name for #args {
349                #(
350                    #methods
351                )*
352            }
353        });
354    }
355
356    let mut result = quote! {
357        #input
358    };
359    result.extend(impls.drain(..));
360
361    result.into()
362}