1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
2
3use proc_macro::TokenStream;
4use quote::{quote, ToTokens};
5use syn::{parenthesized, parse::Parser, punctuated::Punctuated, spanned::Spanned, Error, FnArg, Result, Token, Variant};
6
/// Name of the enum-level attribute that declares a function to generate.
const FUNC_ATTR: &str = "func";
/// Name of the variant-level attribute that supplies per-variant associations.
const ASSOC_ATTR: &str = "assoc";
9
10#[proc_macro_derive(Assoc, attributes(func, assoc))]
11pub fn derive_assoc(input: TokenStream) -> TokenStream {
12 impl_macro(&syn::parse(input).expect("Failed to parse macro input"))
13 .unwrap_or_else(syn::Error::into_compile_error)
15 .into()
16}
17
18fn impl_macro(ast: &syn::DeriveInput) -> Result<proc_macro2::TokenStream> {
19 let name = &ast.ident;
20 let generics = &ast.generics;
21 let generic_params = &generics.params;
22 let fns = ast
23 .attrs
24 .iter()
25 .filter(|attr| attr.path().is_ident(FUNC_ATTR))
26 .map(|attr| syn::parse2::<DeriveFunc>(attr.meta.to_token_stream()))
27 .collect::<Result<Vec<DeriveFunc>>>()?;
28 let variants: Vec<&Variant> = if let syn::Data::Enum(data) = &ast.data {
29 data.variants.iter().collect()
30 } else {
31 panic!("#[derive(Assoc)] only applicable to enums")
32 };
33 let functions: Vec<proc_macro2::TokenStream> = fns
34 .into_iter()
35 .map(|func| build_function(&variants, func))
36 .collect::<Result<Vec<proc_macro2::TokenStream>>>()?;
37 Ok(quote! {
38 impl <#generic_params> #name #generics
39 {
40 #(#functions)*
41 }
42 }
43 .into())
44}
45
46fn build_function(variants: &[&Variant], func: DeriveFunc) -> Result<proc_macro2::TokenStream> {
47 let vis = &func.vis;
48 let sig = &func.sig;
49 let has_self = match func.sig.inputs.first() {
51 Some(FnArg::Receiver(_)) => true,
52 Some(FnArg::Typed(pat_type)) => {
53 let pat = &pat_type.pat;
54 quote!(#pat).to_string().trim() == "self"
55 }
56 None => false,
57 };
58 let is_option = if let syn::ReturnType::Type(_, ty) = &func.sig.output {
59 let s = quote!(#ty).to_string();
60 let trimmed = s.trim();
61 trimmed.starts_with("Option") && trimmed.len() > 6 && trimmed[6..].trim().starts_with("<")
62 } else {
63 false
64 };
65 let mut arms = variants
66 .iter()
67 .map(|variant| build_variant_arm(variant, &func.sig.ident, is_option, has_self, &func.def))
68 .collect::<Result<Vec<(proc_macro2::TokenStream, Wildcard)>>>()?;
69 if is_option
70 && !arms
71 .iter()
72 .any(|(_, wildcard)| matches!(wildcard, Wildcard::True))
73 {
74 arms.push((quote!(_ => None,), Wildcard::True))
75 }
76 if has_self == false {
78 arms.sort_by(|(_, wildcard1), (_, wildcard2)| wildcard1.cmp(wildcard2));
79 }
80 let arms = arms.into_iter().map(|(toks, _)| toks);
81 let match_on = if has_self {
82 quote!(self)
83 } else if func.sig.inputs.is_empty() {
84 return Err(syn::Error::new(func.span, "Missing parameter"));
85 } else {
86 let mut result = quote!();
87 for input in &func.sig.inputs {
88 match input {
89 FnArg::Receiver(_) => {
90 result = quote!(self);
91 break;
92 }
93 FnArg::Typed(pat_type) => {
94 let pat = &pat_type.pat;
95 result = if result.is_empty() {
96 quote!(#pat)
97 } else {
98 quote!(#result, #pat)
99 };
100 }
101 }
102 }
103 if func.sig.inputs.len() > 1 {
104 result = quote!((#result));
105 }
106 result
107 };
108 Ok(quote! {
109 #vis #sig
110 {
111 match #match_on
112 {
113 #(#arms)*
114 }
115 }
116 })
117}
118
119fn build_variant_arm(
120 variant: &Variant,
121 func: &syn::Ident,
122 is_option: bool,
123 has_self: bool,
124 def: &Option<proc_macro2::TokenStream>,
125) -> Result<(proc_macro2::TokenStream, Wildcard)> {
126 let assocs =
128 Association::get_variant_assocs(variant, !has_self).filter(|assoc| assoc.func == *func);
129 if has_self {
130 build_fwd_assoc(assocs, variant, is_option, func, def)
131 } else {
132 build_rev_assoc(assocs, variant, is_option)
133 }
134}
135
136fn build_fwd_assoc(
137 assocs: impl Iterator<Item = Association>,
138 variant: &Variant,
139 is_option: bool,
140 func_ident: &syn::Ident,
141 def: &Option<proc_macro2::TokenStream>,
142) -> Result<(proc_macro2::TokenStream, Wildcard)> {
143 let var_ident = &variant.ident;
144 let fields = match &variant.fields {
145 syn::Fields::Named(fields) => {
146 let named = fields
147 .named
148 .iter()
149 .map(|f| {
150 let ident = &f.ident;
151 let val: &Option<proc_macro2::Ident> = &f.ident.as_ref().map(|s| {
152 proc_macro2::Ident::new(
153 &("_".to_string() + &s.to_string()),
154 f.span().clone(),
155 )
156 });
157 quote!(#ident: #val)
158 })
159 .collect::<Vec<proc_macro2::TokenStream>>();
160 quote!({#(#named),*})
161 }
162 syn::Fields::Unnamed(fields) => {
163 let unnamed = fields
164 .unnamed
165 .iter()
166 .enumerate()
167 .map(|(i, f)| {
168 let ident = proc_macro2::Ident::new(
169 &("_".to_string() + &i.to_string()),
170 f.span().clone(),
171 );
172 quote!(#ident)
173 })
174 .collect::<Vec<proc_macro2::TokenStream>>();
175 quote!((#(#unnamed),*))
176 }
177 _ => quote!(),
178 };
179 let assocs = assocs
180 .filter_map(|assoc|
181 {
182 if let AssociationType::Forward(expr) = assoc.assoc {
183 Some(Ok(expr))
184 } else {
185 None
186 }
187 })
188 .collect::<Result<Vec<syn::Expr>>>()?;
189 match assocs.len() {
190 0 => {
191 if let Some(tokens) = def {
192 Ok(quote! { Self::#var_ident #fields => #tokens, })
193 } else if is_option {
194 Ok(quote! { Self::#var_ident #fields => None, })
195 } else {
196 Err(Error::new_spanned(
197 variant,
198 format!("Missing `assoc` attribute for {}", func_ident),
199 ))
200 }
201 }
202 1 => {
203 let val = &assocs[0];
204 if is_option {
205 if quote!(#val).to_string().trim() == "None" {
206 Ok(quote! { Self::#var_ident #fields => #val, })
207 } else {
208 Ok(quote! { Self::#var_ident #fields => Some(#val), })
209 }
210 } else {
211 Ok(quote! { Self::#var_ident #fields => #val, })
212 }
213 }
214 _ => Err(Error::new_spanned(
215 variant,
216 format!("Too many `assoc` attributes for {}", func_ident),
217 )),
218 }
219 .map(|toks| (toks, Wildcard::None))
220}
221
222fn build_rev_assoc(
223 assocs: impl Iterator<Item = Association>,
224 variant: &Variant,
225 is_option: bool,
226) -> Result<(proc_macro2::TokenStream, Wildcard)> {
227 let var_ident = &variant.ident;
228 let assocs = assocs
229 .filter_map(|assoc|
230 {
231 if let AssociationType::Reverse(pat) = assoc.assoc {
232 Some(Ok(pat))
233 } else {
234 None
235 }
236 })
237 .collect::<Result<Vec<syn::Pat>>>()?;
238 let mut concrete_pats: Vec<proc_macro2::TokenStream> = Vec::new();
239 let mut wildcard_pat: Option<proc_macro2::TokenStream> = None;
240 let mut wildcard_status = Wildcard::False;
241 for pat in assocs.iter() {
242 if !matches!(variant.fields, syn::Fields::Unit) {
243 return Err(Error::new_spanned(
244 variant,
245 "Reverse associations not allowed for tuple or struct-like variants",
246 ));
247 }
248 let arm = if is_option {
249 quote!(#pat => Some(Self::#var_ident),)
250 } else {
251 quote!(#pat => Self::#var_ident,)
252 };
253 if matches!(pat, syn::Pat::Wild(_)) {
254 if wildcard_pat.is_some() {
255 return Err(syn::Error::new_spanned(
256 pat,
257 "Only 1 wildcard allowed per reverse association",
258 ));
259 }
260 wildcard_status = Wildcard::True;
261 wildcard_pat = Some(arm);
262 } else {
263 concrete_pats.push(arm);
264 }
265 }
266 if let Some(wildcard_pat) = wildcard_pat {
267 concrete_pats.push(wildcard_pat)
268 }
269 Ok((quote!(#(#concrete_pats) *), wildcard_status))
270}
271
272struct DeriveFunc {
276 vis: syn::Visibility,
277 sig: syn::Signature,
278 span: proc_macro2::Span,
279 def: Option<proc_macro2::TokenStream>,
280}
281
282struct Association {
285 func: syn::Ident,
286 assoc: AssociationType,
287}
288
289enum AssociationType {
290 Forward(syn::Expr),
291 Reverse(syn::Pat),
292}
293
/// Sort key for the arms a variant contributes to a generated `match`.
///
/// Reverse-function arms are stable-sorted by this key. Variants whose arms
/// contain no `_` pattern (`False`) therefore come first, and arms ending in a
/// `_` pattern (`True`) come last. Forward-function arms are tagged `None`.
/// The variant order and discriminants drive the derived ordering.
#[derive(Eq, PartialEq, Ord, PartialOrd)]
enum Wildcard {
    /// The arm(s) contain no `_` pattern.
    False = 0,
    /// Not applicable: a forward-function arm.
    None = 1,
    /// The arm(s) end with a `_` pattern.
    True = 2,
}
304
305impl syn::parse::Parse for DeriveFunc {
306 fn parse(input: syn::parse::ParseStream) -> Result<Self> {
308 input.step(|cursor| {
309 if let Some((_, next)) = cursor.token_tree() {
310 Ok(((), next))
311 } else {
312 Err(cursor.error("Missing function signature"))
313 }
314 })?;
315 let content;
316 parenthesized!(content in input);
317 let vis = content.parse::<syn::Visibility>()?;
318 let sig = content.parse::<syn::Signature>()?;
319 let def = if content.is_empty() {
320 None
321 } else {
322 let block = content.parse::<syn::Block>()?;
323 Some(proc_macro2::TokenStream::from(ToTokens::into_token_stream(
324 block,
325 )))
326 };
327 Ok(DeriveFunc {
328 vis,
329 sig,
330 span: content.span(),
331 def,
332 })
333 }
334}
335
336struct ForwardAssocTokens(syn::Ident, syn::Expr);
338impl syn::parse::Parse for ForwardAssocTokens
339{
340 fn parse(input: syn::parse::ParseStream) -> Result<Self> {
341 let ident = input.parse()?;
342 input.parse::<syn::Token!(=)>()?;
343 let expr = input.parse()?;
344 Ok(Self(ident, expr))
345 }
346}
347
348struct ReverseAssocTokens(syn::Ident, syn::Pat);
350impl syn::parse::Parse for ReverseAssocTokens
351{
352 fn parse(input: syn::parse::ParseStream) -> Result<Self> {
353 let ident = input.parse()?;
354 input.parse::<syn::Token!(=)>()?;
355 let pat = syn::Pat::parse_multi_with_leading_vert(input)?;
356 Ok(Self(ident, pat))
357 }
358}
359
360impl Into<Association> for ForwardAssocTokens
361{
362 fn into(self) -> Association {
363 Association {
364 func: self.0,
365 assoc: AssociationType::Forward(self.1),
366 }
367 }
368}
369
370impl Into<Association> for ReverseAssocTokens
371{
372 fn into(self) -> Association {
373 Association {
374 func: self.0,
375 assoc: AssociationType::Reverse(self.1),
376 }
377 }
378}
379
380impl Association {
381 fn get_variant_assocs(
382 variant: &Variant,
383 is_reverse: bool,
384 ) -> impl Iterator<Item = Self> + '_ {
385 variant
386 .attrs
387 .iter()
388 .filter(|assoc_attr| assoc_attr.path().is_ident(ASSOC_ATTR))
389 .filter_map(move |attr| if let syn::Meta::List(meta_list) = &attr.meta {
390 if is_reverse {
391 let parser = Punctuated::<ReverseAssocTokens, Token![,]>::parse_terminated;
392 parser.parse2(meta_list.tokens.clone()).map(|tokens| tokens.into_iter().map(|tokens| tokens.into()).collect::<Vec<Self>>()).ok()
393 } else {
394 let parser = Punctuated::<ForwardAssocTokens, Token![,]>::parse_terminated;
395 parser.parse2(meta_list.tokens.clone()).map(|tokens| tokens.into_iter().map(|tokens| tokens.into()).collect::<Vec<Self>>()).ok()
396 }
397 } else {
398 None
399 })
400 .flat_map(std::convert::identity)
401 }
402}