syn_helpers/
derive.rs

1use std::{collections::HashMap, iter::FromIterator};
2
3use proc_macro2::{Span, TokenStream};
4use quote::quote;
5use syn::{
6    punctuated::Punctuated, visit_mut::VisitMut, Data, GenericParam, Ident, Path, PathSegment,
7};
8
9use crate::{
10    dyn_error_to_compile_error_tokens, generic_helpers,
11    generic_param_to_generic_argument_token_stream, generic_parameters_have_same_name, model::Item,
12    syn_fields_to_fields, EnumStructure, EnumVariant, Field, StructStructure, Structure, Trait,
13    TraitItem,
14};
15
16/// Same as [derive_trait] function, but handles parsing the [TokenStream]
17pub fn derive_trait_from_token_stream(tokens: TokenStream, r#trait: Trait) -> TokenStream {
18    let result = syn::parse2::<syn::DeriveInput>(tokens);
19    match result {
20        Ok(input) => derive_trait(input, r#trait),
21        Err(_) => quote! { compile_error!("Invalid input") },
22    }
23}
24
25/// Creates an impl block for the respective trait over the item
26///
27/// Handles a few complex things
28/// - Collisions with the names of generics on the trait and the structure
29/// - Creating where clauses for referenced items
30pub fn derive_trait(item: syn::DeriveInput, r#trait: Trait) -> TokenStream {
31    let syn::DeriveInput {
32        generics: mut structure_generics,
33        attrs,
34        data,
35        ident: structure_name,
36        ..
37    } = item;
38
39    let Trait {
40        name: trait_name,
41        generic_parameters: trait_generic_parameters,
42        items,
43    } = r#trait;
44
45    let mut generic_conflicts_map = HashMap::new();
46
47    let trait_with_arguments: TokenStream = if let Some(trait_generic_parameters) =
48        trait_generic_parameters.as_ref()
49    {
50        // Rename clashing trait names
51        if structure_generics.lt_token.is_some() {
52            for structure_generic_parameter in structure_generics.params.iter_mut() {
53                let collision = trait_generic_parameters
54                    .iter()
55                    .any(|trait_generic_parameter| {
56                        generic_parameters_have_same_name(
57                            trait_generic_parameter,
58                            structure_generic_parameter,
59                        )
60                    });
61
62                if collision {
63                    // Just hope nothing called `_gp...`...
64                    let new_ident = Ident::new(
65                        &format!("_gp{}", generic_conflicts_map.len()),
66                        Span::call_site(),
67                    );
68
69                    let ident: &mut Ident = match structure_generic_parameter {
70                        GenericParam::Type(gtp) => &mut gtp.ident,
71                        GenericParam::Lifetime(glp) => &mut glp.lifetime.ident,
72                        GenericParam::Const(gcp) => &mut gcp.ident,
73                    };
74
75                    generic_conflicts_map.insert(ident.clone(), new_ident.clone());
76                    *ident = new_ident;
77                }
78            }
79        }
80
81        // Removes bounds off parameters thus becoming arguments
82        let trait_generic_arguments =
83            trait_generic_parameters
84                .iter()
85                .map(|trait_generic_parameter| {
86                    generic_param_to_generic_argument_token_stream(trait_generic_parameter)
87                });
88
89        quote!(#trait_name<#(#trait_generic_arguments),*>)
90    } else {
91        quote::ToTokens::to_token_stream(&trait_name)
92    };
93
94    // Combination of structure and trait generics, retains bounds
95    let generics_for_impl =
96        if trait_generic_parameters.is_some() || !structure_generics.params.is_empty() {
97            let mut references = trait_generic_parameters
98                .iter()
99                .flatten()
100                .chain(structure_generics.params.iter())
101                .collect::<Vec<_>>();
102
103            // This is the order in which the AST is defined
104            references.sort_unstable_by_key(|generic| match generic {
105                GenericParam::Lifetime(_) => 0,
106                GenericParam::Type(_) => 1,
107                GenericParam::Const(_) => 2,
108            });
109
110            let token_stream = quote!(<#(#references),*>);
111
112            // Clear bounds for when reprinted
113            structure_generics
114                .type_params_mut()
115                .for_each(|tp| tp.bounds = Default::default());
116
117            Some(token_stream)
118        } else {
119            None
120        };
121
122    // Could be `HashSet` but Rust tolerates duplicate where clause (when that rarely occurs)
123    // Vec ensures consistent order which makes tests easy
124    let mut where_clauses = Vec::new();
125
126    let mut structure = match data {
127        Data::Struct(r#struct) => {
128            let fields = syn_fields_to_fields(r#struct.fields, attrs);
129            Structure::Struct(StructStructure {
130                fields,
131                name: structure_name.clone(),
132            })
133        }
134        Data::Enum(r#enum) => Structure::Enum(EnumStructure {
135            name: structure_name.clone(),
136            attrs,
137            variants: r#enum
138                .variants
139                .into_iter()
140                .enumerate()
141                .map(|(idx, variant)| EnumVariant {
142                    full_path: Path {
143                        leading_colon: None,
144                        segments: Punctuated::from_iter([
145                            PathSegment::from(structure_name.clone()),
146                            PathSegment::from(variant.ident),
147                        ]),
148                    },
149                    idx,
150                    fields: syn_fields_to_fields(variant.fields, variant.attrs),
151                })
152                .collect(),
153        }),
154        Data::Union(_) => {
155            return quote!( compile_error!("syn-helpers does not support derives on unions"); );
156        }
157    };
158
159    let impl_items = items
160        .into_iter()
161        .flat_map(|item| {
162            match item {
163                TraitItem::Method {
164                    name,
165                    generic_parameters,
166                    self_type,
167                    other_parameters,
168                    return_type,
169                    handler,
170                } => {
171                    let result = handler(Item { structure: &mut structure, self_type });
172                    let stmts = match result {
173                        Ok(stmts) => stmts,
174                        Err(err) => {
175                            return dyn_error_to_compile_error_tokens(err);
176                        },
177                    };
178                    let chevroned_generic_params = if let Some(generic_parameters) = generic_parameters {
179                        quote! { <#(#generic_parameters),*> }
180                    } else {
181                        TokenStream::default()
182                    };
183                    let return_type = return_type.iter();
184                    let self_parameter = self_type.as_parameter_tokens();
185                    quote! {
186                        fn #name #chevroned_generic_params(#self_parameter #(,#other_parameters)*) #(-> #return_type)* {
187                            #(#stmts)*
188                        }
189                    }
190                }
191                TraitItem::AssociatedFunction {
192                    name,
193                    generic_parameters,
194                    parameters,
195                    return_type,
196                    handler,
197                } => {
198                    let result = handler(&mut structure);
199                    let stmts = match result {
200                        Ok(stmts) => stmts,
201                        Err(err) => {
202                            return dyn_error_to_compile_error_tokens(err);
203                        },
204                    };
205                    let chevroned_generic_params = if let Some(generic_parameters) = generic_parameters {
206                        quote! { <#(#generic_parameters),*> }
207                    } else {
208                        TokenStream::default()
209                    };
210                    let return_type = return_type.iter();
211                    quote! {
212                        fn #name #chevroned_generic_params(#(#parameters),*) #(-> #return_type)* {
213                            #(#stmts)*
214                        }
215                    }
216                }
217            }
218        }).collect::<TokenStream>();
219
220    {
221        let iter = structure
222            .all_fields()
223            .flat_map(|field| field.get_type_that_needs_constraint())
224            .map(|mut ty| {
225                generic_helpers::RenameGenerics(&generic_conflicts_map).visit_type_mut(&mut ty);
226                ty
227            })
228            .filter(|ty| generic_helpers::ReferencesAGeneric::has_generic(ty, &structure_generics));
229
230        where_clauses.extend(iter);
231    }
232
233    let where_clause: Option<_> = if !where_clauses.is_empty() {
234        Some(quote!(where #( #where_clauses: #trait_with_arguments ),* ))
235    } else {
236        None
237    };
238
239    quote! {
240        #[automatically_derived]
241        impl #generics_for_impl #trait_with_arguments for #structure_name #structure_generics #where_clause {
242            #impl_items
243        }
244    }
245}