enum_assoc/
lib.rs

1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
2
3use proc_macro::TokenStream;
4use quote::{quote, ToTokens};
5use syn::{parenthesized, parse::Parser, punctuated::Punctuated, spanned::Spanned, Error, FnArg, Result, Token, Variant};
6
/// Name of the item-level attribute that declares a function to generate.
const FUNC_ATTR: &str = "func";
/// Name of the variant-level attribute that supplies an association.
const ASSOC_ATTR: &str = "assoc";
9
10#[proc_macro_derive(Assoc, attributes(func, assoc))]
11pub fn derive_assoc(input: TokenStream) -> TokenStream {
12    impl_macro(&syn::parse(input).expect("Failed to parse macro input"))
13        //.map(|t| {println!("{}", quote!(#t)); t})
14        .unwrap_or_else(syn::Error::into_compile_error)
15        .into()
16}
17
18fn impl_macro(ast: &syn::DeriveInput) -> Result<proc_macro2::TokenStream> {
19    let name = &ast.ident;
20    let generics = &ast.generics;
21    let generic_params = &generics.params;
22    let fns = ast
23        .attrs
24        .iter()
25        .filter(|attr| attr.path().is_ident(FUNC_ATTR))
26        .map(|attr| syn::parse2::<DeriveFunc>(attr.meta.to_token_stream()))
27        .collect::<Result<Vec<DeriveFunc>>>()?;
28    let variants: Vec<&Variant> = if let syn::Data::Enum(data) = &ast.data {
29        data.variants.iter().collect()
30    } else {
31        panic!("#[derive(Assoc)] only applicable to enums")
32    };
33    let functions: Vec<proc_macro2::TokenStream> = fns
34        .into_iter()
35        .map(|func| build_function(&variants, func))
36        .collect::<Result<Vec<proc_macro2::TokenStream>>>()?;
37    Ok(quote! {
38        impl <#generic_params> #name #generics
39        {
40            #(#functions)*
41        }
42    }
43    .into())
44}
45
46fn build_function(variants: &[&Variant], func: DeriveFunc) -> Result<proc_macro2::TokenStream> {
47    let vis = &func.vis;
48    let sig = &func.sig;
49    // has_self determines whether or not this a reverse assoc
50    let has_self = match func.sig.inputs.first() {
51        Some(FnArg::Receiver(_)) => true,
52        Some(FnArg::Typed(pat_type)) => {
53            let pat = &pat_type.pat;
54            quote!(#pat).to_string().trim() == "self"
55        }
56        None => false,
57    };
58    let is_option = if let syn::ReturnType::Type(_, ty) = &func.sig.output {
59        let s = quote!(#ty).to_string();
60        let trimmed = s.trim();
61        trimmed.starts_with("Option") && trimmed.len() > 6 && trimmed[6..].trim().starts_with("<")
62    } else {
63        false
64    };
65    let mut arms = variants
66        .iter()
67        .map(|variant| build_variant_arm(variant, &func.sig.ident, is_option, has_self, &func.def))
68        .collect::<Result<Vec<(proc_macro2::TokenStream, Wildcard)>>>()?;
69    if is_option
70        && !arms
71            .iter()
72            .any(|(_, wildcard)| matches!(wildcard, Wildcard::True))
73    {
74        arms.push((quote!(_ => None,), Wildcard::True))
75    }
76    // make sure wildcards are last
77    if has_self == false {
78        arms.sort_by(|(_, wildcard1), (_, wildcard2)| wildcard1.cmp(wildcard2));
79    }
80    let arms = arms.into_iter().map(|(toks, _)| toks);
81    let match_on = if has_self {
82        quote!(self)
83    } else if func.sig.inputs.is_empty() {
84        return Err(syn::Error::new(func.span, "Missing parameter"));
85    } else {
86        let mut result = quote!();
87        for input in &func.sig.inputs {
88            match input {
89                FnArg::Receiver(_) => {
90                    result = quote!(self);
91                    break;
92                }
93                FnArg::Typed(pat_type) => {
94                    let pat = &pat_type.pat;
95                    result = if result.is_empty() {
96                        quote!(#pat)
97                    } else {
98                        quote!(#result, #pat)
99                    };
100                }
101            }
102        }
103        if func.sig.inputs.len() > 1 {
104            result = quote!((#result));
105        }
106        result
107    };
108    Ok(quote! {
109        #vis #sig
110        {
111            match #match_on
112            {
113                #(#arms)*
114            }
115        }
116    })
117}
118
119fn build_variant_arm(
120    variant: &Variant,
121    func: &syn::Ident,
122    is_option: bool,
123    has_self: bool,
124    def: &Option<proc_macro2::TokenStream>,
125) -> Result<(proc_macro2::TokenStream, Wildcard)> {
126    // Partially parse associations
127    let assocs =
128        Association::get_variant_assocs(variant, !has_self).filter(|assoc| assoc.func == *func);
129    if has_self {
130        build_fwd_assoc(assocs, variant, is_option, func, def)
131    } else {
132        build_rev_assoc(assocs, variant, is_option)
133    }
134}
135
136fn build_fwd_assoc(
137    assocs: impl Iterator<Item = Association>,
138    variant: &Variant,
139    is_option: bool,
140    func_ident: &syn::Ident,
141    def: &Option<proc_macro2::TokenStream>,
142) -> Result<(proc_macro2::TokenStream, Wildcard)> {
143    let var_ident = &variant.ident;
144    let fields = match &variant.fields {
145        syn::Fields::Named(fields) => {
146            let named = fields
147                .named
148                .iter()
149                .map(|f| {
150                    let ident = &f.ident;
151                    let val: &Option<proc_macro2::Ident> = &f.ident.as_ref().map(|s| {
152                        proc_macro2::Ident::new(
153                            &("_".to_string() + &s.to_string()),
154                            f.span().clone(),
155                        )
156                    });
157                    quote!(#ident: #val)
158                })
159                .collect::<Vec<proc_macro2::TokenStream>>();
160            quote!({#(#named),*})
161        }
162        syn::Fields::Unnamed(fields) => {
163            let unnamed = fields
164                .unnamed
165                .iter()
166                .enumerate()
167                .map(|(i, f)| {
168                    let ident = proc_macro2::Ident::new(
169                        &("_".to_string() + &i.to_string()),
170                        f.span().clone(),
171                    );
172                    quote!(#ident)
173                })
174                .collect::<Vec<proc_macro2::TokenStream>>();
175            quote!((#(#unnamed),*))
176        }
177        _ => quote!(),
178    };
179    let assocs = assocs
180        .filter_map(|assoc| 
181        {
182            if let AssociationType::Forward(expr) = assoc.assoc {
183                Some(Ok(expr))
184            } else {
185                None
186            }
187        })
188        .collect::<Result<Vec<syn::Expr>>>()?;
189    match assocs.len() {
190        0 => {
191            if let Some(tokens) = def {
192                Ok(quote! { Self::#var_ident #fields => #tokens, })
193            } else if is_option {
194                Ok(quote! { Self::#var_ident #fields => None, })
195            } else {
196                Err(Error::new_spanned(
197                    variant,
198                    format!("Missing `assoc` attribute for {}", func_ident),
199                ))
200            }
201        }
202        1 => {
203            let val = &assocs[0];
204            if is_option {
205                if quote!(#val).to_string().trim() == "None" {
206                    Ok(quote! { Self::#var_ident #fields => #val, })
207                } else {
208                    Ok(quote! { Self::#var_ident #fields => Some(#val), })
209                }
210            } else {
211                Ok(quote! { Self::#var_ident #fields => #val, })
212            }
213        }
214        _ => Err(Error::new_spanned(
215            variant,
216            format!("Too many `assoc` attributes for {}", func_ident),
217        )),
218    }
219    .map(|toks| (toks, Wildcard::None))
220}
221
222fn build_rev_assoc(
223    assocs: impl Iterator<Item = Association>,
224    variant: &Variant,
225    is_option: bool,
226) -> Result<(proc_macro2::TokenStream, Wildcard)> {
227    let var_ident = &variant.ident;
228    let assocs = assocs
229        .filter_map(|assoc| 
230        {
231            if let AssociationType::Reverse(pat) = assoc.assoc {
232                Some(Ok(pat))
233            } else {
234                None
235            }
236        })
237        .collect::<Result<Vec<syn::Pat>>>()?;
238    let mut concrete_pats: Vec<proc_macro2::TokenStream> = Vec::new();
239    let mut wildcard_pat: Option<proc_macro2::TokenStream> = None;
240    let mut wildcard_status = Wildcard::False;
241    for pat in assocs.iter() {
242        if !matches!(variant.fields, syn::Fields::Unit) {
243            return Err(Error::new_spanned(
244                variant,
245                "Reverse associations not allowed for tuple or struct-like variants",
246            ));
247        }
248        let arm = if is_option {
249            quote!(#pat => Some(Self::#var_ident),)
250        } else {
251            quote!(#pat => Self::#var_ident,)
252        };
253        if matches!(pat, syn::Pat::Wild(_)) {
254            if wildcard_pat.is_some() {
255                return Err(syn::Error::new_spanned(
256                    pat,
257                    "Only 1 wildcard allowed per reverse association",
258                ));
259            }
260            wildcard_status = Wildcard::True;
261            wildcard_pat = Some(arm);
262        } else {
263            concrete_pats.push(arm);
264        }
265    }
266    if let Some(wildcard_pat) = wildcard_pat {
267        concrete_pats.push(wildcard_pat)
268    }
269    Ok((quote!(#(#concrete_pats) *), wildcard_status))
270}
271
272/// A container for a function parsed within a `func` attribute. Note that the
273/// span of the `func` atribute is included because the syn nodes were
274/// manipulated as a string and have lost therr own span information.
275struct DeriveFunc {
276    vis: syn::Visibility,
277    sig: syn::Signature,
278    span: proc_macro2::Span,
279    def: Option<proc_macro2::TokenStream>,
280}
281
282/// An association. Contains a function ident as well as the actual tokens of
283/// the VALUE (not the variant) of the association.
284struct Association {
285    func: syn::Ident,
286    assoc: AssociationType,
287}
288
289enum AssociationType {
290    Forward(syn::Expr),
291    Reverse(syn::Pat),
292}
293
/// Records whether a generated group of match arms contains a wildcard pattern.
///
/// Reverse associations use `False` or `True`; forward associations always
/// use `None`. The derived ordering (`False < None < True`) is used to sort
/// the arms of reverse associations so that wildcards come last. If more
/// complex sorting is ever needed, extending this enum is the place to start.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
enum Wildcard {
    /// Reverse association arms without a wildcard pattern.
    False = 0,
    /// Forward association arm; wildcards do not apply.
    None = 1,
    /// Arms that include a wildcard pattern.
    True = 2,
}
304
305impl syn::parse::Parse for DeriveFunc {
306    /// Parse a function signature from an attribute
307    fn parse(input: syn::parse::ParseStream) -> Result<Self> {
308        input.step(|cursor| {
309            if let Some((_, next)) = cursor.token_tree() {
310                Ok(((), next))
311            } else {
312                Err(cursor.error("Missing function signature"))
313            }
314        })?;
315        let content;
316        parenthesized!(content in input);
317        let vis = content.parse::<syn::Visibility>()?;
318        let sig = content.parse::<syn::Signature>()?;
319        let def = if content.is_empty() {
320            None
321        } else {
322            let block = content.parse::<syn::Block>()?;
323            Some(proc_macro2::TokenStream::from(ToTokens::into_token_stream(
324                block,
325            )))
326        };
327        Ok(DeriveFunc {
328            vis,
329            sig,
330            span: content.span(),
331            def,
332        })
333    }
334}
335
336/// Used to parse forward associations, which are of form Ident = Expr
337struct ForwardAssocTokens(syn::Ident, syn::Expr);
338impl syn::parse::Parse for ForwardAssocTokens
339{
340    fn parse(input: syn::parse::ParseStream) -> Result<Self> {
341        let ident = input.parse()?;
342        input.parse::<syn::Token!(=)>()?;
343        let expr = input.parse()?;
344        Ok(Self(ident, expr))
345    }
346}
347
348/// Used to parse reverse associations, which are of form Ident = Pat
349struct ReverseAssocTokens(syn::Ident, syn::Pat);
350impl syn::parse::Parse for ReverseAssocTokens
351{
352    fn parse(input: syn::parse::ParseStream) -> Result<Self> {
353        let ident = input.parse()?;
354        input.parse::<syn::Token!(=)>()?;
355        let pat = syn::Pat::parse_multi_with_leading_vert(input)?;
356        Ok(Self(ident, pat))
357    }
358}
359
360impl Into<Association> for ForwardAssocTokens
361{
362    fn into(self) -> Association {
363        Association {
364            func: self.0,
365            assoc: AssociationType::Forward(self.1),
366        }
367    }
368}
369
370impl Into<Association> for ReverseAssocTokens
371{
372    fn into(self) -> Association {
373        Association {
374            func: self.0,
375            assoc: AssociationType::Reverse(self.1),
376        }
377    }
378}
379
380impl Association {
381    fn get_variant_assocs(
382        variant: &Variant,
383        is_reverse: bool,
384    ) -> impl Iterator<Item = Self> + '_ {
385        variant
386            .attrs
387            .iter()
388            .filter(|assoc_attr| assoc_attr.path().is_ident(ASSOC_ATTR))
389            .filter_map(move |attr| if let syn::Meta::List(meta_list) = &attr.meta { 
390                if is_reverse {
391                    let parser = Punctuated::<ReverseAssocTokens, Token![,]>::parse_terminated;
392                    parser.parse2(meta_list.tokens.clone()).map(|tokens| tokens.into_iter().map(|tokens| tokens.into()).collect::<Vec<Self>>()).ok() 
393                } else {
394                    let parser = Punctuated::<ForwardAssocTokens, Token![,]>::parse_terminated;
395                    parser.parse2(meta_list.tokens.clone()).map(|tokens| tokens.into_iter().map(|tokens| tokens.into()).collect::<Vec<Self>>()).ok()
396                }
397            } else {
398                 None 
399            })
400            .flat_map(std::convert::identity)
401    }
402}