enum_assoc/
lib.rs

1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
2
3use proc_macro::TokenStream;
4use quote::{quote, ToTokens};
5use syn::{
6    parenthesized, parse::Parser, punctuated::Punctuated, spanned::Spanned, Error, FnArg, Result,
7    Token, Variant,
8};
9
/// Name of the enum-level attribute that declares the functions to generate.
const FUNC_ATTR: &str = "func";
/// Name of the variant-level attribute that supplies association values.
const ASSOC_ATTR: &str = "assoc";
12
13#[proc_macro_derive(Assoc, attributes(func, assoc))]
14pub fn derive_assoc(input: TokenStream) -> TokenStream {
15    impl_macro(&syn::parse(input).expect("Failed to parse macro input"))
16        //.map(|t| {println!("{}", quote!(#t)); t})
17        .unwrap_or_else(syn::Error::into_compile_error)
18        .into()
19}
20
21fn impl_macro(ast: &syn::DeriveInput) -> Result<proc_macro2::TokenStream> {
22    let name = &ast.ident;
23    let generics = &ast.generics;
24    let generic_params = &generics.params;
25    let fns = ast
26        .attrs
27        .iter()
28        .filter(|attr| attr.path().is_ident(FUNC_ATTR))
29        .map(|attr| syn::parse2::<DeriveFuncs>(attr.meta.to_token_stream()))
30        .collect::<Result<Vec<DeriveFuncs>>>()?;
31    let variants: Vec<&Variant> = if let syn::Data::Enum(data) = &ast.data {
32        data.variants.iter().collect()
33    } else {
34        panic!("#[derive(Assoc)] only applicable to enums")
35    };
36    let functions: Vec<proc_macro2::TokenStream> = fns
37        .into_iter()
38        .flat_map(|DeriveFuncs(funcs)| {
39            funcs
40                .iter()
41                .map(|func| build_function(&variants, func, funcs.clone()))
42                .collect::<Vec<_>>()
43        })
44        .collect::<Result<Vec<proc_macro2::TokenStream>>>()?;
45    Ok(quote! {
46        impl <#generic_params> #name #generics
47        {
48            #(#functions)*
49        }
50    }
51    .into())
52}
53
54fn build_function(
55    variants: &[&Variant],
56    func: &DeriveFunc,
57    associated_funcs: Vec<DeriveFunc>,
58) -> Result<proc_macro2::TokenStream> {
59    let vis = &func.vis;
60    let sig = &func.sig;
61    // has_self determines whether or not this a reverse assoc
62    let has_self = match func.sig.inputs.first() {
63        Some(FnArg::Receiver(_)) => true,
64        Some(FnArg::Typed(pat_type)) => {
65            let pat = &pat_type.pat;
66            quote!(#pat).to_string().trim() == "self"
67        }
68        None => false,
69    };
70    let is_option = if let syn::ReturnType::Type(_, ty) = &func.sig.output {
71        let s = quote!(#ty).to_string();
72        let trimmed = s.trim();
73        trimmed.starts_with("Option") && trimmed.len() > 6 && trimmed[6..].trim().starts_with("<")
74    } else {
75        false
76    };
77    let mut arms = variants
78        .iter()
79        .map(|variant| {
80            build_variant_arm(
81                variant,
82                &func.sig.ident,
83                associated_funcs.iter().map(|func| func.sig.ident.clone()),
84                is_option,
85                has_self,
86                &func.def,
87            )
88        })
89        .collect::<Result<Vec<(proc_macro2::TokenStream, Wildcard)>>>()?;
90    if is_option
91        && !arms
92            .iter()
93            .any(|(_, wildcard)| matches!(wildcard, Wildcard::True))
94    {
95        arms.push((quote!(_ => None,), Wildcard::True))
96    }
97    // make sure wildcards are last
98    if has_self == false {
99        arms.sort_by(|(_, wildcard1), (_, wildcard2)| wildcard1.cmp(wildcard2));
100    }
101    let arms = arms.into_iter().map(|(toks, _)| toks);
102    let match_on = if has_self {
103        quote!(self)
104    } else if func.sig.inputs.is_empty() {
105        return Err(syn::Error::new(func.span, "Missing parameter"));
106    } else {
107        let mut result = quote!();
108        for input in &func.sig.inputs {
109            match input {
110                FnArg::Receiver(_) => {
111                    result = quote!(self);
112                    break;
113                }
114                FnArg::Typed(pat_type) => {
115                    let pat = &pat_type.pat;
116                    result = if result.is_empty() {
117                        quote!(#pat)
118                    } else {
119                        quote!(#result, #pat)
120                    };
121                }
122            }
123        }
124        if func.sig.inputs.len() > 1 {
125            result = quote!((#result));
126        }
127        result
128    };
129    Ok(quote! {
130        #vis #sig
131        {
132            match #match_on
133            {
134                #(#arms)*
135            }
136        }
137    })
138}
139
140fn build_variant_arm(
141    variant: &Variant,
142    func: &syn::Ident,
143    mut assoc_funcs: impl Iterator<Item = syn::Ident>,
144    is_option: bool,
145    has_self: bool,
146    def: &Option<proc_macro2::TokenStream>,
147) -> Result<(proc_macro2::TokenStream, Wildcard)> {
148    // Partially parse associations
149    let assocs = Association::get_variant_assocs(variant, !has_self).filter(|assoc| {
150        assoc.func == *func || assoc_funcs.any(|assoc_func| assoc_func == assoc.func)
151    });
152    if has_self {
153        build_fwd_assoc(assocs, variant, is_option, func, def)
154    } else {
155        build_rev_assoc(assocs, variant, is_option)
156    }
157}
158
159fn build_fwd_assoc(
160    assocs: impl Iterator<Item = Association>,
161    variant: &Variant,
162    is_option: bool,
163    func_ident: &syn::Ident,
164    def: &Option<proc_macro2::TokenStream>,
165) -> Result<(proc_macro2::TokenStream, Wildcard)> {
166    let var_ident = &variant.ident;
167    let fields = match &variant.fields {
168        syn::Fields::Named(fields) => {
169            let named = fields
170                .named
171                .iter()
172                .map(|f| {
173                    let ident = &f.ident;
174                    let val: &Option<proc_macro2::Ident> = &f.ident.as_ref().map(|s| {
175                        proc_macro2::Ident::new(
176                            &("_".to_string() + &s.to_string()),
177                            f.span().clone(),
178                        )
179                    });
180                    quote!(#ident: #val)
181                })
182                .collect::<Vec<proc_macro2::TokenStream>>();
183            quote!({#(#named),*})
184        }
185        syn::Fields::Unnamed(fields) => {
186            let unnamed = fields
187                .unnamed
188                .iter()
189                .enumerate()
190                .map(|(i, f)| {
191                    let ident = proc_macro2::Ident::new(
192                        &("_".to_string() + &i.to_string()),
193                        f.span().clone(),
194                    );
195                    quote!(#ident)
196                })
197                .collect::<Vec<proc_macro2::TokenStream>>();
198            quote!((#(#unnamed),*))
199        }
200        _ => quote!(),
201    };
202    let assocs = assocs
203        .filter_map(|assoc| {
204            if let AssociationType::Forward(expr) = assoc.assoc {
205                Some(Ok(expr))
206            } else {
207                None
208            }
209        })
210        .collect::<Result<Vec<syn::Expr>>>()?;
211    match assocs.len() {
212        0 => {
213            if let Some(tokens) = def {
214                Ok(quote! { Self::#var_ident #fields => #tokens, })
215            } else if is_option {
216                Ok(quote! { Self::#var_ident #fields => None, })
217            } else {
218                Err(Error::new_spanned(
219                    variant,
220                    format!("Missing `assoc` attribute for {}", func_ident),
221                ))
222            }
223        }
224        1 => {
225            let val = &assocs[0];
226            if is_option {
227                if quote!(#val).to_string().trim() == "None" {
228                    Ok(quote! { Self::#var_ident #fields => #val, })
229                } else {
230                    Ok(quote! { Self::#var_ident #fields => Some(#val), })
231                }
232            } else {
233                Ok(quote! { Self::#var_ident #fields => #val, })
234            }
235        }
236        _ => Err(Error::new_spanned(
237            variant,
238            format!("Too many `assoc` attributes for {}", func_ident),
239        )),
240    }
241    .map(|toks| (toks, Wildcard::None))
242}
243
244fn build_rev_assoc(
245    assocs: impl Iterator<Item = Association>,
246    variant: &Variant,
247    is_option: bool,
248) -> Result<(proc_macro2::TokenStream, Wildcard)> {
249    let var_ident = &variant.ident;
250    let assocs = assocs
251        .filter_map(|assoc| {
252            if let AssociationType::Reverse(pat) = assoc.assoc {
253                Some(Ok(pat))
254            } else {
255                None
256            }
257        })
258        .collect::<Result<Vec<syn::Pat>>>()?;
259    let mut concrete_pats: Vec<proc_macro2::TokenStream> = Vec::new();
260    let mut wildcard_pat: Option<proc_macro2::TokenStream> = None;
261    let mut wildcard_status = Wildcard::False;
262    for pat in assocs.iter() {
263        if !matches!(variant.fields, syn::Fields::Unit) {
264            return Err(Error::new_spanned(
265                variant,
266                "Reverse associations not allowed for tuple or struct-like variants",
267            ));
268        }
269        let arm = if is_option {
270            quote!(#pat => Some(Self::#var_ident),)
271        } else {
272            quote!(#pat => Self::#var_ident,)
273        };
274        if matches!(pat, syn::Pat::Wild(_)) {
275            if wildcard_pat.is_some() {
276                return Err(syn::Error::new_spanned(
277                    pat,
278                    "Only 1 wildcard allowed per reverse association",
279                ));
280            }
281            wildcard_status = Wildcard::True;
282            wildcard_pat = Some(arm);
283        } else {
284            concrete_pats.push(arm);
285        }
286    }
287    if let Some(wildcard_pat) = wildcard_pat {
288        concrete_pats.push(wildcard_pat)
289    }
290    Ok((quote!(#(#concrete_pats) *), wildcard_status))
291}
292
293/// A container for a function parsed within a `func` attribute. Note that the
294/// span of the `func` atribute is included because the syn nodes were
295/// manipulated as a string and have lost therr own span information.
296#[derive(Clone)]
297struct DeriveFunc {
298    vis: syn::Visibility,
299    sig: syn::Signature,
300    span: proc_macro2::Span,
301    def: Option<proc_macro2::TokenStream>,
302}
303
304/// An association. Contains a function ident as well as the actual tokens of
305/// the VALUE (not the variant) of the association.
306struct Association {
307    func: syn::Ident,
308    assoc: AssociationType,
309}
310
311enum AssociationType {
312    Forward(syn::Expr),
313    Reverse(syn::Pat),
314}
315
/// Ordering key attached to each generated match arm.
///
/// Reverse associations use it to record whether a variant's arms include a
/// `_` wildcard pattern (`True`) or not (`False`). Forward associations
/// always use `None`. Reverse-association arms are sorted by this value so
/// that wildcard arms come last. Extending this enum is the natural starting
/// point for any more elaborate arm ordering.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
enum Wildcard {
    False = 0,
    None = 1,
    True = 2,
}
326
327impl syn::parse::Parse for DeriveFunc {
328    /// Parse a function signature from an attribute
329    fn parse(input: syn::parse::ParseStream) -> Result<Self> {
330        let vis = input.parse::<syn::Visibility>()?;
331        let sig = input.parse::<syn::Signature>()?;
332        let def = if let Ok(block) = input.parse::<syn::Block>() {
333            Some(proc_macro2::TokenStream::from(ToTokens::into_token_stream(
334                block,
335            )))
336        } else {
337            None
338        };
339        Ok(DeriveFunc {
340            vis,
341            sig,
342            span: input.span(),
343            def,
344        })
345    }
346}
347
348struct DeriveFuncs(Vec<DeriveFunc>);
349impl syn::parse::Parse for DeriveFuncs {
350    /// Parse a list of function signatures form an attribute
351    fn parse(input: syn::parse::ParseStream) -> Result<Self> {
352        input.step(|cursor| {
353            if let Some((_, next)) = cursor.token_tree() {
354                Ok(((), next))
355            } else {
356                Err(cursor.error("Missing function signature"))
357            }
358        })?;
359        let content;
360        parenthesized!(content in input);
361        Ok(Self(
362            content
363                .parse_terminated(DeriveFunc::parse, Token!(,))
364                .map(|parsed| parsed.into_iter().collect())?,
365        ))
366    }
367}
368
369/// Used to parse forward associations, which are of form Ident = Expr
370struct ForwardAssocTokens(syn::Ident, syn::Expr);
371impl syn::parse::Parse for ForwardAssocTokens {
372    fn parse(input: syn::parse::ParseStream) -> Result<Self> {
373        let ident = input.parse()?;
374        input.parse::<syn::Token!(=)>()?;
375        let expr = input.parse()?;
376        Ok(Self(ident, expr))
377    }
378}
379
380/// Used to parse reverse associations, which are of form Ident = Pat
381struct ReverseAssocTokens(syn::Ident, syn::Pat);
382impl syn::parse::Parse for ReverseAssocTokens {
383    fn parse(input: syn::parse::ParseStream) -> Result<Self> {
384        let ident = input.parse()?;
385        input.parse::<syn::Token!(=)>()?;
386        let pat = syn::Pat::parse_multi_with_leading_vert(input)?;
387        Ok(Self(ident, pat))
388    }
389}
390
391impl Into<Association> for ForwardAssocTokens {
392    fn into(self) -> Association {
393        Association {
394            func: self.0,
395            assoc: AssociationType::Forward(self.1),
396        }
397    }
398}
399
400impl Into<Association> for ReverseAssocTokens {
401    fn into(self) -> Association {
402        Association {
403            func: self.0,
404            assoc: AssociationType::Reverse(self.1),
405        }
406    }
407}
408
409impl Association {
410    fn get_variant_assocs(variant: &Variant, is_reverse: bool) -> impl Iterator<Item = Self> + '_ {
411        variant
412            .attrs
413            .iter()
414            .filter(|assoc_attr| assoc_attr.path().is_ident(ASSOC_ATTR))
415            .filter_map(move |attr| {
416                if let syn::Meta::List(meta_list) = &attr.meta {
417                    if is_reverse {
418                        let parser = Punctuated::<ReverseAssocTokens, Token![,]>::parse_terminated;
419                        parser
420                            .parse2(meta_list.tokens.clone())
421                            .map(|tokens| {
422                                tokens
423                                    .into_iter()
424                                    .map(|tokens| tokens.into())
425                                    .collect::<Vec<Self>>()
426                            })
427                            .ok()
428                    } else {
429                        let parser = Punctuated::<ForwardAssocTokens, Token![,]>::parse_terminated;
430                        parser
431                            .parse2(meta_list.tokens.clone())
432                            .map(|tokens| {
433                                tokens
434                                    .into_iter()
435                                    .map(|tokens| tokens.into())
436                                    .collect::<Vec<Self>>()
437                            })
438                            .ok()
439                    }
440                } else {
441                    None
442                }
443            })
444            .flat_map(std::convert::identity)
445    }
446}