Skip to main content

desugar_impl/
lib.rs

1// Copyright (c) 2021 René Kijewski <rene.[SURNAME]@fu-berlin.de>
2// All rights reserved.
3//
4// This software and the accompanying materials are made available under
5// the terms of the ISC License which is available in the project root as LICENSE-ISC, AND/OR
6// the terms of the MIT License which is available in the project root as LICENSE-MIT, AND/OR
7// the terms of the Apache License, Version 2.0 which is available in the project root as LICENSE-APACHE.
8//
9// You have to accept AT LEAST one of the aforementioned licenses to use, copy, modify, and/or distribute this software.
10// At your will you may redistribute the software under the terms of only one, two, or all three of the aforementioned licenses.
11
12#![forbid(unsafe_code)]
13#![deny(missing_docs)]
14
15//! ## `impl Trait` not allowed outside of function and method return types
16//!
17//! **… but it is now!**
18//!
19//! This library gives you one macro, and one macro only: [`#[desugar_impl]`][macro@desugar_impl].
20//!
21//! Annotate any struct, enum, or union with [`#[desugar_impl]`][macro@desugar_impl]
22//! to allow the use of `field_name: impl SomeTrait` in their declaration. E.g.
23//!
24//! ```
25//! #[desugar_impl::desugar_impl]
26//! struct Test {
27//!     a: impl Clone + PartialOrd,
28//!     b: impl Clone + PartialOrd,
29//!     c: impl Copy,
30//! }
31//! ```
32//!
33//! desugars to
34//!
35//! ```
36//! struct Test<Ty1, Ty2, Ty3>
37//! where
38//!     Ty1: Clone + PartialOrd,
39//!     Ty2: Clone + PartialOrd,
40//!     Ty3: Copy,
41//! {
42//!     a: Ty1,
43//!     b: Ty2,
44//!     c: Ty3,
45//! }
46//! ```
47//!
48//! You can still place any `#[derive(…)]` macros just below `#[desugar_impl]`, and they will
49//! see and work with the desugared code.
50
51use std::iter::FromIterator;
52
53use proc_macro::{Span, TokenStream};
54use quote::quote;
55use syn::punctuated::Punctuated;
56use syn::token::{Colon, Where};
57use syn::{
58    parse_macro_input, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields,
59    FieldsNamed, FieldsUnnamed, GenericParam, Ident, Path, PathArguments, PathSegment,
60    PredicateType, Type, TypeImplTrait, TypeParam, TypePath, WhereClause, WherePredicate,
61};
62
63/// Desugar `impl Trait` fields in a struct, enum, or union declaration.
64///
65/// Please see the library documentation for an explanation: [desugar_impl](index.html).
66#[proc_macro_attribute]
67pub fn desugar_impl(_: TokenStream, item: TokenStream) -> TokenStream {
68    let mut ast = parse_macro_input!(item as DeriveInput);
69    let mut ty_index = 1;
70
71    let ast_generics = &mut ast.generics;
72    let ast_data = &mut ast.data;
73
74    let mut convert_fields = |fields: &mut Punctuated<_, _>| {
75        for Field { ty, .. } in fields {
76            if let Type::ImplTrait(TypeImplTrait { bounds, .. }) = ty {
77                let type_ident = format!("Ty{}", ty_index);
78                ty_index += 1;
79                let type_ident = Ident::new(&type_ident, Span::call_site().into());
80                let type_path = Type::Path(TypePath {
81                    qself: None,
82                    path: Path {
83                        leading_colon: None,
84                        segments: Punctuated::from_iter([PathSegment {
85                            ident: type_ident.clone(),
86                            arguments: PathArguments::None,
87                        }]),
88                    },
89                });
90
91                let predicate = WherePredicate::Type(PredicateType {
92                    lifetimes: None,
93                    bounded_ty: type_path.clone(),
94                    colon_token: Colon::default(),
95                    bounds: bounds.clone(),
96                });
97                match &mut ast_generics.where_clause {
98                    Some(where_clause) => {
99                        where_clause.predicates.push(predicate);
100                    }
101                    where_clause @ None => {
102                        *where_clause = Some(WhereClause {
103                            where_token: Where::default(),
104                            predicates: Punctuated::from_iter([predicate]),
105                        });
106                    }
107                }
108
109                ast_generics.params.push(GenericParam::Type(TypeParam {
110                    attrs: Vec::new(),
111                    ident: type_ident,
112                    colon_token: None,
113                    bounds: Default::default(),
114                    eq_token: None,
115                    default: None,
116                }));
117
118                *ty = type_path;
119            }
120        }
121    };
122
123    let mut convert_some_fields = |fields: &mut Fields| match fields {
124        Fields::Named(FieldsNamed { named: fields, .. })
125        | Fields::Unnamed(FieldsUnnamed {
126            unnamed: fields, ..
127        }) => {
128            convert_fields(fields);
129        }
130        Fields::Unit => {}
131    };
132
133    match ast_data {
134        Data::Struct(DataStruct { fields, .. }) => {
135            convert_some_fields(fields);
136        }
137        Data::Union(DataUnion {
138            fields: FieldsNamed { named: fields, .. },
139            ..
140        }) => {
141            convert_fields(fields);
142        }
143        Data::Enum(DataEnum { variants, .. }) => {
144            for variant in variants {
145                convert_some_fields(&mut variant.fields);
146            }
147        }
148    }
149
150    TokenStream::from(quote! { #ast })
151}