derive_trait/
lib.rs

1//! Derive a trait and a delegating impl from an inherent impl block.
2//!
3//! # Why go the opposite way?
//! This macro is designed for single generic types with many small impl blocks
//! and complex type bounds in each impl block.
6//!
7//! - Without a trait, the function user needs to repeat all the type bounds in the impl block
8//!   in every function that requests a type supporting the associated functions.
9//! - Without a macro, the function author needs to write each function signature four times
10//!   (the trait, the inherent impl, the trait impl and delegation)
11//!   and the type bounds twice.
//! - With the `#[inherent]` macro, the function author would still need to write each
//!   signature twice (the trait and the trait impl).
14//!
//! Note that use of this crate is only advisable for impl blocks with complicated type bounds.
16//! It is not advisable to create single-implementor traits blindly.
17
18use std::collections::HashMap;
19
20use heck::ToPascalCase;
21use itertools::Itertools;
22use proc_macro2::{Span, TokenStream};
23use quote::quote;
24use syn::parse::{Parse, ParseStream};
25use syn::punctuated::Punctuated;
26use syn::spanned::Spanned;
27use syn::{parse_quote, parse_quote_spanned, Error, Result};
28
29#[proc_macro_attribute]
30pub fn derive_trait(
31    attr: proc_macro::TokenStream,
32    item: proc_macro::TokenStream,
33) -> proc_macro::TokenStream {
34    real_derive_trait(attr.into(), item.into()).unwrap_or_else(Error::into_compile_error).into()
35}
36
37fn real_derive_trait(attr: TokenStream, item: TokenStream) -> Result<TokenStream> {
38    let attr: Attr = syn::parse2(attr)?;
39    let Attr { debug_print, vis, trait_ident, supers, generics: trait_generics, fixed_assoc_types } =
40        &attr;
41    let debug_print = debug_print.is_some();
42    let (_, trait_generics_ref, _) = trait_generics.split_for_impl();
43
44    let inherent_impl: syn::ItemImpl = syn::parse2(item)?;
45
46    let mut item_trait = syn::ItemTrait {
47        attrs:       Vec::new(),
48        vis:         vis.clone(),
49        unsafety:    None,
50        auto_token:  None,
51        restriction: None,
52        trait_token: syn::Token![trait](Span::call_site()),
53        ident:       trait_ident.clone(),
54        generics:    trait_generics.clone(),
55        colon_token: supers.as_ref().map(|&(colon, _)| colon),
56        supertraits: supers.clone().map(|(_, supers)| supers).unwrap_or_default(),
57        brace_token: syn::token::Brace(Span::call_site()),
58        items:       Vec::new(),
59    };
60
61    let mut item_impl = syn::ItemImpl {
62        attrs:       Vec::new(),
63        unsafety:    None,
64        defaultness: None,
65        impl_token:  syn::Token![impl](Span::call_site()),
66        generics:    inherent_impl.generics.clone(),
67        trait_:      Some((
68            None,
69            syn::parse_quote!(#trait_ident #trait_generics_ref),
70            syn::Token![for](Span::call_site()),
71        )),
72        self_ty:     inherent_impl.self_ty.clone(),
73        brace_token: syn::token::Brace(Span::call_site()),
74        items:       Vec::new(),
75    };
76
77    let mut ident_assoc_map = HashMap::new();
78
79    for assoc in fixed_assoc_types {
80        if assoc.generics.params.len() > 0 {
81            return Err(syn::Error::new_spanned(&assoc.generics, "GATs here are not supported"));
82        }
83
84        let Some((
85            eq_token,
86            default @ syn::Type::Path(syn::TypePath { qself: None, path: default_ident }),
87        )) = &assoc.default
88        else {
89            return Err(syn::Error::new_spanned(assoc, "expected `type Type: Bounds = Ident;`"));
90        };
91        let default_ident = default_ident.require_ident()?;
92
93        ident_assoc_map.insert(default_ident.clone(), assoc.ident.clone());
94
95        item_trait.items.push(syn::TraitItem::Type(syn::TraitItemType {
96            attrs:       assoc.attrs.clone(),
97            type_token:  assoc.type_token,
98            ident:       assoc.ident.clone(),
99            generics:    assoc.generics.clone(),
100            colon_token: assoc.colon_token,
101            bounds:      assoc.bounds.clone(),
102            default:     None,
103            semi_token:  assoc.semi_token,
104        }));
105
106        item_impl.items.push(syn::ImplItem::Type(syn::ImplItemType {
107            attrs:       Vec::new(),
108            vis:         syn::Visibility::Inherited,
109            defaultness: None,
110            type_token:  assoc.type_token,
111            ident:       assoc.ident.clone(),
112            generics:    assoc.generics.clone(),
113            eq_token:    *eq_token,
114            ty:          default.clone(),
115            semi_token:  assoc.semi_token,
116        }));
117    }
118
119    struct ReplaceIdentVisitor<'t>(&'t HashMap<syn::Ident, syn::Ident>);
120    impl<'t> syn::visit_mut::VisitMut for ReplaceIdentVisitor<'t> {
121        fn visit_type_path_mut(&mut self, type_path: &mut syn::TypePath) {
122            if let Some(ident) = type_path.path.segments.first() {
123                if let Some(target) = self.0.get(&ident.ident) {
124                    let mut segments: Vec<_> =
125                        type_path.path.segments.clone().into_pairs().collect();
126                    segments.insert(
127                        0,
128                        syn::punctuated::Pair::Punctuated(
129                            syn::PathSegment {
130                                ident:     syn::Ident::new("Self", ident.span()),
131                                arguments: syn::PathArguments::None,
132                            },
133                            syn::Token![::](ident.span()),
134                        ),
135                    );
136                    *segments[1].value_mut() = syn::PathSegment {
137                        ident:     target.clone(),
138                        arguments: syn::PathArguments::None,
139                    };
140                    type_path.path.segments = segments.into_iter().collect();
141                }
142            }
143
144            if let Some(qself) = &mut type_path.qself {
145                self.visit_type_mut(&mut qself.ty);
146            }
147            self.visit_path_mut(&mut type_path.path);
148        }
149    }
150    let mut replace_ident_visitor = ReplaceIdentVisitor(&ident_assoc_map);
151
152    let self_ty = &*inherent_impl.self_ty;
153
154    for item in &inherent_impl.items {
155        match item {
156            syn::ImplItem::Fn(item) => {
157                let mut sig = item.sig.clone();
158                let sig_span = sig.span();
159
160                if let syn::ReturnType::Type(r_arrow, ret_ty) = &sig.output {
161                    let transformed = for_each_impl_trait(ret_ty, &mut |tit| {
162                        let span = tit.span();
163
164                        // convert return-position-impl-trait into associated-type-impl-trait
165                        let assoc_ident = item.sig.ident.to_string().to_pascal_case();
166                        let assoc_ident = syn::Ident::new(&assoc_ident, span);
167                        let ty_bounds = &tit.bounds;
168
169                        // for now, assume all generic parameters are required.
170                        // we cannot infer whether the signature involves an implicit lifetime,
171                        // so for simplicity we require all implicit lifetimes to be explicitly
172                        // documented for now.
173
174                        let (assoc_generics, assoc_generics_names, assoc_where) = if sig
175                            .generics
176                            .params
177                            .is_empty()
178                        {
179                            (None, None, None)
180                        } else {
181                            let (sig_impl_generics, sig_ty_generics, sig_where_generics) =
182                                sig.generics.split_for_impl();
183                            let mut sig_impl_generics: syn::Generics =
184                                syn::parse_quote!(#sig_impl_generics);
185                            let mut sig_ty_generics: syn::AngleBracketedGenericArguments =
186                                syn::parse_quote!(#sig_ty_generics);
187                            let mut sig_where_generics = sig_where_generics.cloned();
188
189                            if let Some(recv) = sig.receiver() {
190                                if let Some((and, lt)) = &recv.reference {
191                                    let lt = match lt {
192                                        Some(lt) => lt.clone(),
193                                        None => {
194                                            let lt: syn::Lifetime =
195                                                syn::parse_quote_spanned!(and.span() => '__self);
196                                            sig_impl_generics.params.push(
197                                                syn::GenericParam::Lifetime(parse_quote!(#lt)),
198                                            );
199                                            sig_ty_generics
200                                                .args
201                                                .push(syn::parse_quote_spanned!(and.span() => '_));
202                                            lt
203                                        }
204                                    };
205                                    let where_predicate: syn::WherePredicate =
206                                        syn::parse_quote_spanned!(and.span() => Self: #lt);
207                                    sig_where_generics.get_or_insert(syn::WhereClause {
208                                        where_token: syn::Token![where](sig_span),
209                                        predicates: Punctuated::new(),
210                                    }).predicates.push(syn::parse_quote_spanned!(and.span() => #where_predicate));
211                                }
212                            }
213
214                            (
215                                Some(quote!(#sig_impl_generics)),
216                                Some(quote!(#sig_ty_generics)),
217                                Some(quote!(#sig_where_generics)),
218                            )
219                        };
220
221                        let assoc_doc = format!(
222                            "Return value for [`{fn_ident}`](Self::{fn_ident})",
223                            fn_ident = &sig.ident
224                        );
225                        let mut trait_item_ty: syn::TraitItemType = parse_quote_spanned! { span =>
226                            #[doc = #assoc_doc]
227                            type #assoc_ident #assoc_generics: #ty_bounds #assoc_where;
228                        };
229                        syn::visit_mut::visit_trait_item_type_mut(
230                            &mut replace_ident_visitor,
231                            &mut trait_item_ty,
232                        );
233                        item_trait.items.push(syn::TraitItem::Type(trait_item_ty));
234                        item_impl.items.push(parse_quote_spanned! { span =>
235                            type #assoc_ident #assoc_generics = #tit #assoc_where;
236                        });
237
238                        parse_quote_spanned! { span =>
239                            Self::#assoc_ident #assoc_generics_names
240                        }
241                    });
242                    sig.output = syn::ReturnType::Type(*r_arrow, Box::new(transformed));
243                }
244
245                let sig_ident = &sig.ident;
246                let sig_args: Vec<syn::Pat> = sig
247                    .inputs
248                    .iter()
249                    .map(|input| match input {
250                        syn::FnArg::Receiver(syn::Receiver { self_token, .. }) => {
251                            parse_quote!(#self_token)
252                        }
253                        syn::FnArg::Typed(typed) => (*typed.pat).clone(),
254                    })
255                    .collect();
256
257                let fn_docs: Vec<_> =
258                    item.attrs.iter().filter(|attr| attr.path().is_ident("doc")).cloned().collect();
259
260                let mut trait_item_fn = syn::TraitItemFn {
261                    attrs:      fn_docs.clone(),
262                    sig:        sig.clone(),
263                    default:    None,
264                    semi_token: Some(syn::Token![;](item.span())),
265                };
266                syn::visit_mut::visit_trait_item_fn_mut(
267                    &mut replace_ident_visitor,
268                    &mut trait_item_fn,
269                );
270                item_trait.items.push(syn::TraitItem::Fn(trait_item_fn));
271
272                item_impl.items.push(syn::ImplItem::Fn(syn::ImplItemFn {
273                    attrs:       fn_docs.clone(),
274                    vis:         syn::Visibility::Inherited,
275                    defaultness: None,
276                    sig:         sig.clone(),
277                    block:       parse_quote_spanned! { item.span() => {
278                        <#self_ty>::#sig_ident(#(#sig_args),*)
279                    }},
280                }));
281            }
282            _ => return Err(Error::new_spanned(item, "only associated functions are supported")),
283        }
284    }
285
286    let trait_item_doc = format!(
287        "Derived trait for [`{}`].",
288        match self_ty {
289            syn::Type::Path(path) =>
290                path.path.segments.iter().map(|ident| ident.ident.to_string()).join("::"),
291            _ => quote!(#self_ty).to_string(),
292        }
293    );
294
295    let output = quote! {
296        #[allow(clippy::needless_lifetimes)]
297        #inherent_impl
298        #[allow(clippy::needless_lifetimes, non_camel_case_types)]
299        #[doc = #trait_item_doc]
300        #item_trait
301        #[automatically_derived]
302        #[allow(clippy::needless_lifetimes, non_camel_case_types)]
303        #item_impl
304    };
305    if debug_print {
306        println!("{}", output);
307    }
308    Ok(output)
309}
310
311fn for_each_impl_trait(
312    ty: &syn::Type,
313    f: &mut impl FnMut(&syn::TypeImplTrait) -> syn::Type,
314) -> syn::Type {
315    match ty {
316        syn::Type::Array(ty) => syn::Type::Array(syn::TypeArray {
317            elem: Box::new(for_each_impl_trait(&*ty.elem, f)),
318            ..ty.clone()
319        }),
320        syn::Type::BareFn(ty) => syn::Type::BareFn(syn::TypeBareFn {
321            inputs: ty
322                .inputs
323                .clone()
324                .into_pairs()
325                .map(|mut pair| {
326                    let value = pair.value_mut();
327                    value.ty = for_each_impl_trait(&value.ty, f);
328                    pair
329                })
330                .collect(),
331            ..ty.clone()
332        }),
333        syn::Type::Group(_) => ty.clone(),
334        syn::Type::ImplTrait(ty) => f(ty),
335        syn::Type::Infer(_) => ty.clone(),
336        syn::Type::Macro(_) => ty.clone(),
337        syn::Type::Never(_) => ty.clone(),
338        syn::Type::Paren(ty) => syn::Type::Paren(syn::TypeParen {
339            elem: Box::new(for_each_impl_trait(&*ty.elem, f)),
340            ..ty.clone()
341        }),
342        syn::Type::Path(ty) => syn::Type::Path(syn::TypePath {
343            qself: ty.qself.clone().map(|mut qself| {
344                qself.ty = Box::new(for_each_impl_trait(&*qself.ty, f));
345                qself
346            }),
347            path:  for_each_impl_trait_in_path(&ty.path, f),
348        }),
349        syn::Type::Ptr(ty) => syn::Type::Ptr(syn::TypePtr {
350            elem: Box::new(for_each_impl_trait(&*ty.elem, f)),
351            ..ty.clone()
352        }),
353        syn::Type::Reference(ty) => syn::Type::Reference(syn::TypeReference {
354            elem: Box::new(for_each_impl_trait(&*ty.elem, f)),
355            ..ty.clone()
356        }),
357        syn::Type::Slice(ty) => syn::Type::Slice(syn::TypeSlice {
358            elem: Box::new(for_each_impl_trait(&*ty.elem, f)),
359            ..ty.clone()
360        }),
361        syn::Type::TraitObject(ty) => syn::Type::TraitObject(syn::TypeTraitObject {
362            bounds: ty
363                .bounds
364                .clone()
365                .into_pairs()
366                .map(|mut pair| {
367                    if let syn::TypeParamBound::Trait(bound) = pair.value_mut() {
368                        bound.path = for_each_impl_trait_in_path(&bound.path, f);
369                    }
370                    pair
371                })
372                .collect(),
373            ..ty.clone()
374        }),
375        syn::Type::Tuple(ty) => syn::Type::Tuple(syn::TypeTuple {
376            elems: ty
377                .elems
378                .clone()
379                .into_pairs()
380                .map(|mut pair| {
381                    let value = pair.value_mut();
382                    *value = for_each_impl_trait(&value, f);
383                    pair
384                })
385                .collect(),
386            ..ty.clone()
387        }),
388        syn::Type::Verbatim(_) => ty.clone(),
389        _ => ty.clone(),
390    }
391}
392
393fn for_each_impl_trait_in_path(
394    path: &syn::Path,
395    f: &mut impl FnMut(&syn::TypeImplTrait) -> syn::Type,
396) -> syn::Path {
397    syn::Path {
398        leading_colon: path.leading_colon,
399        segments:      path
400            .segments
401            .clone()
402            .into_pairs()
403            .map(|mut pair| {
404                let value = pair.value_mut();
405                match &mut value.arguments {
406                    syn::PathArguments::None => {}
407                    syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
408                        args,
409                        ..
410                    }) => {
411                        for pair in args.pairs_mut() {
412                            match pair.into_value() {
413                                syn::GenericArgument::Type(ty) => *ty = for_each_impl_trait(ty, f),
414                                syn::GenericArgument::AssocType(ty) => {
415                                    ty.ty = for_each_impl_trait(&ty.ty, f)
416                                }
417                                _ => {}
418                            }
419                        }
420                    }
421                    syn::PathArguments::Parenthesized(syn::ParenthesizedGenericArguments {
422                        inputs,
423                        output,
424                        ..
425                    }) => {
426                        for input in inputs {
427                            *input = for_each_impl_trait(input, f);
428                        }
429                        if let syn::ReturnType::Type(_, ty) = output {
430                            *ty = Box::new(for_each_impl_trait(ty, f));
431                        }
432                    }
433                }
434                pair
435            })
436            .collect(),
437    }
438}
439
440struct Attr {
441    debug_print:       Option<kw::__debug_print>,
442    vis:               syn::Visibility,
443    trait_ident:       syn::Ident,
444    generics:          syn::Generics,
445    supers:            Option<(syn::Token![:], Punctuated<syn::TypeParamBound, syn::Token![+]>)>,
446    fixed_assoc_types: Vec<syn::TraitItemType>,
447}
448
449impl Parse for Attr {
450    fn parse(input: ParseStream) -> Result<Self> {
451        let debug_print = input.parse::<kw::__debug_print>().ok();
452        let vis = input.parse()?;
453        let trait_ident = input.parse()?;
454        let mut generics = syn::Generics::default();
455        let mut supers = None;
456        let mut fixed_assoc_types = Vec::new();
457
458        while !input.is_empty() {
459            let lh = input.lookahead1();
460            if generics.lt_token.is_none() && lh.peek(syn::Token![<]) {
461                generics = input.parse()?;
462            } else if lh.peek(syn::Token![:]) {
463                supers = Some((input.parse()?, Punctuated::parse_separated_nonempty(input)?));
464            } else if !generics.params.is_empty() && lh.peek(syn::Token![where]) {
465                generics.where_clause = Some(input.parse()?);
466            } else if lh.peek(syn::token::Brace) {
467                let inner;
468                _ = syn::braced!(inner in input);
469                while !inner.is_empty() {
470                    fixed_assoc_types.push(inner.parse()?);
471                }
472            } else {
473                return Err(lh.error());
474            }
475        }
476
477        Ok(Self { debug_print, vis, trait_ident, supers, generics, fixed_assoc_types })
478    }
479}
480
481mod kw {
482    use syn::custom_keyword;
483
484    custom_keyword!(Sized);
485    custom_keyword!(__debug_print);
486}