impl_tools_lib/
generics.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License in the LICENSE-APACHE file or at:
4//     https://www.apache.org/licenses/LICENSE-2.0
5
6//! Custom version of [`syn`] generics supporting 'X: trait' bound
7
8use proc_macro2::TokenStream;
9use quote::{quote, ToTokens, TokenStreamExt};
10use syn::parse::{Parse, ParseStream, Result};
11use syn::punctuated::{Pair, Punctuated};
12use syn::token;
13use syn::{Attribute, ConstParam, LifetimeParam, PredicateLifetime};
14use syn::{BoundLifetimes, Ident, Lifetime, Token, Type};
15
16/// Lifetimes and type parameters attached an item
17///
18/// This is a custom variant of [`syn::Generics`]
19/// which supports `trait` as a parameter bound.
20#[derive(Debug)]
21pub struct Generics {
22    /// `<`
23    pub lt_token: Option<Token![<]>,
24    /// Parameters
25    pub params: Punctuated<GenericParam, Token![,]>,
26    /// `>`
27    pub gt_token: Option<Token![>]>,
28    /// `where` bounds
29    pub where_clause: Option<WhereClause>,
30}
31
32impl Default for Generics {
33    fn default() -> Self {
34        Generics {
35            lt_token: None,
36            params: Punctuated::new(),
37            gt_token: None,
38            where_clause: None,
39        }
40    }
41}
42
43/// A generic type parameter, lifetime, or const generic
44///
45/// This is a custom variant of [`syn::GenericParam`]
46/// which supports `trait` as a parameter bound.
47#[derive(Debug)]
48#[allow(clippy::large_enum_variant)]
49pub enum GenericParam {
50    /// Type parameter
51    Type(TypeParam),
52    /// Lifetime parameter
53    Lifetime(LifetimeParam),
54    /// `const` parameter
55    Const(ConstParam),
56}
57
58/// A generic type parameter: `T: Into<String>`.
59///
60/// This is a custom variant of [`syn::TypeParam`]
61/// which supports `trait` as a parameter bound.
62#[derive(Debug)]
63pub struct TypeParam {
64    /// Type parameter attributes
65    pub attrs: Vec<Attribute>,
66    /// Type parameter identifier
67    pub ident: Ident,
68    /// `:`
69    pub colon_token: Option<Token![:]>,
70    /// List of type bounds
71    pub bounds: Punctuated<TypeParamBound, Token![+]>,
72    /// `=`
73    pub eq_token: Option<Token![=]>,
74    /// Optional default value
75    pub default: Option<Type>,
76}
77
78/// A trait or lifetime used as a bound on a type parameter.
79///
80/// This is a superset of [`syn::TypeParamBound`].
81#[derive(Debug)]
82pub enum TypeParamBound {
83    /// `trait` used as a bound (substituted for the trait name by [`ToTokensSubst`])
84    TraitSubst(Token![trait]),
85    /// Everything else
86    Other(syn::TypeParamBound),
87}
88
89/// A `where` clause in a definition: `where T: Deserialize<'de>, D: 'static`.
90///
91/// This is a custom variant of [`syn::WhereClause`]
92/// which supports `trait` as a parameter bound.
93#[derive(Debug)]
94pub struct WhereClause {
95    /// `where`
96    pub where_token: Token![where],
97    /// Parameter bounds
98    pub predicates: Punctuated<WherePredicate, Token![,]>,
99}
100
101/// A single predicate in a `where` clause: `T: Deserialize<'de>`.
102///
103/// This is a custom variant of [`syn::WherePredicate`]
104/// which supports `trait` as a parameter bound.
105#[allow(clippy::large_enum_variant)]
106#[derive(Debug)]
107pub enum WherePredicate {
108    /// A type predicate in a `where` clause: `for<'c> Foo<'c>: Trait<'c>`.
109    Type(PredicateType),
110
111    /// A lifetime predicate in a `where` clause: `'a: 'b + 'c`.
112    Lifetime(PredicateLifetime),
113}
114
115/// A type predicate in a `where` clause: `for<'c> Foo<'c>: Trait<'c>`.
116///
117/// This is a custom variant of [`syn::PredicateType`]
118/// which supports `trait` as a parameter bound.
119#[derive(Debug)]
120pub struct PredicateType {
121    /// Any lifetimes from a `for` binding
122    pub lifetimes: Option<BoundLifetimes>,
123    /// The type being bounded
124    pub bounded_ty: Type,
125    /// `:` before bounds
126    pub colon_token: Token![:],
127    /// Trait and lifetime bounds (`Clone+Send+'static`)
128    pub bounds: Punctuated<TypeParamBound, Token![+]>,
129}
130
131mod parsing {
132    use super::*;
133    use syn::ext::IdentExt;
134
135    impl Parse for Generics {
136        fn parse(input: ParseStream) -> Result<Self> {
137            if !input.peek(Token![<]) {
138                return Ok(Generics::default());
139            }
140
141            let lt_token: Token![<] = input.parse()?;
142
143            let mut params = Punctuated::new();
144            loop {
145                if input.peek(Token![>]) {
146                    break;
147                }
148
149                let attrs = input.call(Attribute::parse_outer)?;
150                let lookahead = input.lookahead1();
151                if lookahead.peek(Lifetime) {
152                    params.push_value(GenericParam::Lifetime(LifetimeParam {
153                        attrs,
154                        ..input.parse()?
155                    }));
156                } else if lookahead.peek(Ident) {
157                    params.push_value(GenericParam::Type(TypeParam {
158                        attrs,
159                        ..input.parse()?
160                    }));
161                } else if lookahead.peek(Token![const]) {
162                    params.push_value(GenericParam::Const(ConstParam {
163                        attrs,
164                        ..input.parse()?
165                    }));
166                } else if input.peek(Token![_]) {
167                    params.push_value(GenericParam::Type(TypeParam {
168                        attrs,
169                        ident: input.call(Ident::parse_any)?,
170                        colon_token: None,
171                        bounds: Punctuated::new(),
172                        eq_token: None,
173                        default: None,
174                    }));
175                } else {
176                    return Err(lookahead.error());
177                }
178
179                if input.peek(Token![>]) {
180                    break;
181                }
182                let punct = input.parse()?;
183                params.push_punct(punct);
184            }
185
186            let gt_token: Token![>] = input.parse()?;
187
188            Ok(Generics {
189                lt_token: Some(lt_token),
190                params,
191                gt_token: Some(gt_token),
192                where_clause: None,
193            })
194        }
195    }
196
197    impl Parse for TypeParam {
198        fn parse(input: ParseStream) -> Result<Self> {
199            let attrs = input.call(Attribute::parse_outer)?;
200            let ident: Ident = input.parse()?;
201            let colon_token: Option<Token![:]> = input.parse()?;
202
203            let mut bounds = Punctuated::new();
204            if colon_token.is_some() {
205                loop {
206                    if input.peek(Token![,]) || input.peek(Token![>]) || input.peek(Token![=]) {
207                        break;
208                    }
209                    let value: TypeParamBound = input.parse()?;
210                    bounds.push_value(value);
211                    if !input.peek(Token![+]) {
212                        break;
213                    }
214                    let punct: Token![+] = input.parse()?;
215                    bounds.push_punct(punct);
216                }
217            }
218
219            let eq_token: Option<Token![=]> = input.parse()?;
220            let default = if eq_token.is_some() {
221                Some(input.parse::<Type>()?)
222            } else {
223                None
224            };
225
226            Ok(TypeParam {
227                attrs,
228                ident,
229                colon_token,
230                bounds,
231                eq_token,
232                default,
233            })
234        }
235    }
236
237    impl Parse for TypeParamBound {
238        fn parse(input: ParseStream) -> Result<Self> {
239            if input.peek(Token![trait]) {
240                input.parse().map(TypeParamBound::TraitSubst)
241            } else {
242                syn::TypeParamBound::parse(input).map(TypeParamBound::Other)
243            }
244        }
245    }
246
247    impl Parse for WhereClause {
248        fn parse(input: ParseStream) -> Result<Self> {
249            Ok(WhereClause {
250                where_token: input.parse()?,
251                predicates: {
252                    let mut predicates = Punctuated::new();
253                    loop {
254                        if input.is_empty()
255                            || input.peek(token::Brace)
256                            || input.peek(Token![,])
257                            || input.peek(Token![;])
258                            || input.peek(Token![:]) && !input.peek(Token![::])
259                            || input.peek(Token![=])
260                        {
261                            break;
262                        }
263                        let value = input.parse()?;
264                        predicates.push_value(value);
265                        if !input.peek(Token![,]) {
266                            break;
267                        }
268                        let punct = input.parse()?;
269                        predicates.push_punct(punct);
270                    }
271                    predicates
272                },
273            })
274        }
275    }
276
277    impl Parse for WherePredicate {
278        fn parse(input: ParseStream) -> Result<Self> {
279            if input.peek(Lifetime) && input.peek2(Token![:]) {
280                Ok(WherePredicate::Lifetime(PredicateLifetime {
281                    lifetime: input.parse()?,
282                    colon_token: input.parse()?,
283                    bounds: {
284                        let mut bounds = Punctuated::new();
285                        loop {
286                            if input.is_empty()
287                                || input.peek(token::Brace)
288                                || input.peek(Token![,])
289                                || input.peek(Token![;])
290                                || input.peek(Token![:])
291                                || input.peek(Token![=])
292                            {
293                                break;
294                            }
295                            let value = input.parse()?;
296                            bounds.push_value(value);
297                            if !input.peek(Token![+]) {
298                                break;
299                            }
300                            let punct = input.parse()?;
301                            bounds.push_punct(punct);
302                        }
303                        bounds
304                    },
305                }))
306            } else {
307                Ok(WherePredicate::Type(PredicateType {
308                    lifetimes: input.parse()?,
309                    bounded_ty: input.parse()?,
310                    colon_token: input.parse()?,
311                    bounds: {
312                        let mut bounds = Punctuated::new();
313                        loop {
314                            if input.is_empty()
315                                || input.peek(token::Brace)
316                                || input.peek(Token![,])
317                                || input.peek(Token![;])
318                                || input.peek(Token![:]) && !input.peek(Token![::])
319                                || input.peek(Token![=])
320                            {
321                                break;
322                            }
323                            let value = input.parse()?;
324                            bounds.push_value(value);
325                            if !input.peek(Token![+]) {
326                                break;
327                            }
328                            let punct = input.parse()?;
329                            bounds.push_punct(punct);
330                        }
331                        bounds
332                    },
333                }))
334            }
335        }
336    }
337}
338
339/// Tokenization trait with substitution
340///
341/// This is similar to [`quote::ToTokens`], but replaces instances of `trait`
342/// as a parameter bound with `subst`.
343pub trait ToTokensSubst {
344    /// Write `self` to the given [`TokenStream`]
345    fn to_tokens_subst(&self, tokens: &mut TokenStream, subst: &TokenStream);
346}
347
348mod printing_subst {
349    use super::*;
350    use syn::AttrStyle;
351
352    impl ToTokensSubst for Generics {
353        fn to_tokens_subst(&self, tokens: &mut TokenStream, subst: &TokenStream) {
354            if self.params.is_empty() {
355                return;
356            }
357
358            self.lt_token.unwrap_or_default().to_tokens(tokens);
359
360            let mut trailing_or_empty = true;
361            for pair in self.params.pairs() {
362                if let GenericParam::Lifetime(param) = *pair.value() {
363                    param.to_tokens(tokens);
364                    pair.punct().to_tokens(tokens);
365                    trailing_or_empty = pair.punct().is_some();
366                }
367            }
368            for pair in self.params.pairs() {
369                match *pair.value() {
370                    GenericParam::Type(param) => {
371                        if !trailing_or_empty {
372                            <Token![,]>::default().to_tokens(tokens);
373                            trailing_or_empty = true;
374                        }
375                        param.to_tokens_subst(tokens, subst);
376                        pair.punct().to_tokens(tokens);
377                    }
378                    GenericParam::Const(param) => {
379                        if !trailing_or_empty {
380                            <Token![,]>::default().to_tokens(tokens);
381                            trailing_or_empty = true;
382                        }
383                        param.to_tokens(tokens);
384                        pair.punct().to_tokens(tokens);
385                    }
386                    GenericParam::Lifetime(_) => {}
387                }
388            }
389
390            self.gt_token.unwrap_or_default().to_tokens(tokens);
391        }
392    }
393
394    impl ToTokensSubst for TypeParam {
395        fn to_tokens_subst(&self, tokens: &mut TokenStream, subst: &TokenStream) {
396            tokens.append_all(
397                self.attrs
398                    .iter()
399                    .filter(|attr| matches!(attr.style, AttrStyle::Outer)),
400            );
401            self.ident.to_tokens(tokens);
402            if !self.bounds.is_empty() {
403                self.colon_token.unwrap_or_default().to_tokens(tokens);
404                self.bounds.to_tokens_subst(tokens, subst);
405            }
406            if let Some(default) = &self.default {
407                self.eq_token.unwrap_or_default().to_tokens(tokens);
408                default.to_tokens(tokens);
409            }
410        }
411    }
412
413    impl ToTokensSubst for WhereClause {
414        fn to_tokens_subst(&self, tokens: &mut TokenStream, subst: &TokenStream) {
415            if !self.predicates.is_empty() {
416                self.where_token.to_tokens(tokens);
417                self.predicates.to_tokens_subst(tokens, subst);
418            }
419        }
420    }
421
422    impl<T, P> ToTokensSubst for Punctuated<T, P>
423    where
424        T: ToTokensSubst,
425        P: ToTokens,
426    {
427        fn to_tokens_subst(&self, tokens: &mut TokenStream, subst: &TokenStream) {
428            for pair in self.pairs() {
429                pair.value().to_tokens_subst(tokens, subst);
430                if let Some(punct) = pair.punct() {
431                    punct.to_tokens(tokens);
432                }
433            }
434        }
435    }
436
437    impl ToTokensSubst for WherePredicate {
438        fn to_tokens_subst(&self, tokens: &mut TokenStream, subst: &TokenStream) {
439            match self {
440                WherePredicate::Type(ty) => ty.to_tokens_subst(tokens, subst),
441                WherePredicate::Lifetime(lt) => lt.to_tokens(tokens),
442            }
443        }
444    }
445
446    impl ToTokensSubst for PredicateType {
447        fn to_tokens_subst(&self, tokens: &mut TokenStream, subst: &TokenStream) {
448            self.lifetimes.to_tokens(tokens);
449            self.bounded_ty.to_tokens(tokens);
450            self.colon_token.to_tokens(tokens);
451            self.bounds.to_tokens_subst(tokens, subst);
452        }
453    }
454
455    impl ToTokensSubst for TypeParamBound {
456        fn to_tokens_subst(&self, tokens: &mut TokenStream, subst: &TokenStream) {
457            match self {
458                TypeParamBound::TraitSubst(_) => tokens.append_all(quote! { #subst }),
459                TypeParamBound::Other(bound) => bound.to_tokens(tokens),
460            }
461        }
462    }
463}
464
465fn map_type_param_bound(bound: &syn::TypeParamBound) -> TypeParamBound {
466    TypeParamBound::Other(bound.clone())
467}
468
469fn map_generic_param(param: &syn::GenericParam) -> GenericParam {
470    match param {
471        syn::GenericParam::Type(ty) => GenericParam::Type(TypeParam {
472            attrs: ty.attrs.clone(),
473            ident: ty.ident.clone(),
474            colon_token: ty.colon_token,
475            bounds: Punctuated::from_iter(
476                ty.bounds
477                    .pairs()
478                    .map(|pair| map_pair(pair, map_type_param_bound)),
479            ),
480            eq_token: ty.eq_token,
481            default: ty.default.clone(),
482        }),
483        syn::GenericParam::Lifetime(lt) => GenericParam::Lifetime(lt.clone()),
484        syn::GenericParam::Const(c) => GenericParam::Const(c.clone()),
485    }
486}
487
488fn map_pair<U, V, P: Clone, F: Fn(U) -> V>(pair: Pair<U, &P>, f: F) -> Pair<V, P> {
489    match pair {
490        Pair::Punctuated(value, punct) => Pair::Punctuated(f(value), punct.clone()),
491        Pair::End(value) => Pair::End(f(value)),
492    }
493}
494
495impl Generics {
496    /// Generate (`impl_generics`, `where_clause`) tokens
497    ///
498    /// Combines generics from `self` and `item_generics`.
499    ///
500    /// This is the equivalent of the first and third items output by
501    /// [`syn::Generics::split_for_impl`]. Any instance of `trait` as a parameter
502    /// bound is replaced by `subst`.
503    ///
504    /// Note: use `ty_generics` from [`syn::Generics::split_for_impl`] or
505    /// [`Self::ty_generics`] as appropriate.
506    pub fn impl_generics(
507        mut self,
508        item_generics: &syn::Generics,
509        subst: &TokenStream,
510    ) -> (TokenStream, TokenStream) {
511        let mut impl_generics = quote! {};
512        if self.params.is_empty() {
513            item_generics.to_tokens(&mut impl_generics);
514        } else {
515            if !self.params.empty_or_trailing() {
516                self.params.push_punct(Default::default());
517            }
518            self.params.extend(
519                item_generics
520                    .params
521                    .pairs()
522                    .map(|pair| map_pair(pair, map_generic_param)),
523            );
524            self.to_tokens_subst(&mut impl_generics, subst);
525        }
526
527        let where_clause = clause_to_toks(
528            &self.where_clause,
529            item_generics.where_clause.as_ref(),
530            subst,
531        );
532        (impl_generics, where_clause)
533    }
534    /// Generate `ty_generics` tokens
535    ///
536    /// Combines generics from `self` and `item_generics`.
537    ///
538    /// This is the equivalent to the second item output by
539    /// [`syn::Generics::split_for_impl`].
540    pub fn ty_generics(&self, item_generics: &syn::Generics) -> TokenStream {
541        let mut toks = TokenStream::new();
542        let tokens = &mut toks;
543
544        if self.params.is_empty() {
545            let (_, ty_generics, _) = item_generics.split_for_impl();
546            ty_generics.to_tokens(tokens);
547            return toks;
548        }
549
550        self.lt_token.unwrap_or_default().to_tokens(tokens);
551
552        // Print lifetimes before types and consts (see syn impl)
553        for (def, punct) in self
554            .params
555            .pairs()
556            .filter_map(|param| {
557                if let GenericParam::Lifetime(def) = *param.value() {
558                    Some((def, param.punct().map(|p| **p).unwrap_or_default()))
559                } else {
560                    None
561                }
562            })
563            .chain(item_generics.params.pairs().filter_map(|param| {
564                if let syn::GenericParam::Lifetime(def) = *param.value() {
565                    Some((def, param.punct().map(|p| **p).unwrap_or_default()))
566                } else {
567                    None
568                }
569            }))
570        {
571            // Leave off the lifetime bounds and attributes
572            def.lifetime.to_tokens(tokens);
573            punct.to_tokens(tokens);
574        }
575
576        for param in self.params.pairs() {
577            match *param.value() {
578                GenericParam::Lifetime(_) => continue,
579                GenericParam::Type(param) => {
580                    // Leave off the type parameter defaults
581                    param.ident.to_tokens(tokens);
582                }
583                GenericParam::Const(param) => {
584                    // Leave off the const parameter defaults
585                    param.ident.to_tokens(tokens);
586                }
587            }
588            param
589                .punct()
590                .map(|p| **p)
591                .unwrap_or_default()
592                .to_tokens(tokens);
593        }
594        for param in item_generics.params.pairs() {
595            match *param.value() {
596                syn::GenericParam::Lifetime(_) => continue,
597                syn::GenericParam::Type(param) => {
598                    // Leave off the type parameter defaults
599                    param.ident.to_tokens(tokens);
600                }
601                syn::GenericParam::Const(param) => {
602                    // Leave off the const parameter defaults
603                    param.ident.to_tokens(tokens);
604                }
605            }
606            param
607                .punct()
608                .map(|p| **p)
609                .unwrap_or_default()
610                .to_tokens(tokens);
611        }
612
613        self.gt_token.unwrap_or_default().to_tokens(tokens);
614
615        toks
616    }
617}
618
619/// Generate a `where_clause`
620///
621/// This merges a [`WhereClause`] with a [`syn::WhereClause`], replacing any
622/// instance of `trait` as a parameter bound in `wc` with `subst`.
623pub fn clause_to_toks(
624    wc: &Option<WhereClause>,
625    item_wc: Option<&syn::WhereClause>,
626    subst: &TokenStream,
627) -> TokenStream {
628    match (wc, item_wc) {
629        (None, None) => quote! {},
630        (Some(wc), None) => {
631            let mut toks = quote! {};
632            wc.to_tokens_subst(&mut toks, subst);
633            toks
634        }
635        (None, Some(wc)) => quote! { #wc },
636        (Some(wc), Some(item_wc)) => {
637            let mut toks = quote! { #item_wc };
638            if !item_wc.predicates.empty_or_trailing() {
639                toks.append_all(quote! { , });
640            }
641            wc.predicates.to_tokens_subst(&mut toks, subst);
642            toks
643        }
644    }
645}