1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
2
3use proc_macro::TokenStream;
4use quote::{quote, ToTokens};
5use syn::{
6 parenthesized, parse::Parser, punctuated::Punctuated, spanned::Spanned, Error, FnArg, Result,
7 Token, Variant,
8};
9
/// Name of the enum-level attribute that declares the functions to generate.
const FUNC_ATTR: &str = "func";
/// Name of the variant-level attribute that supplies per-variant values or patterns.
const ASSOC_ATTR: &str = "assoc";
12
13#[proc_macro_derive(Assoc, attributes(func, assoc))]
14pub fn derive_assoc(input: TokenStream) -> TokenStream {
15 impl_macro(&syn::parse(input).expect("Failed to parse macro input"))
16 .unwrap_or_else(syn::Error::into_compile_error)
18 .into()
19}
20
21fn impl_macro(ast: &syn::DeriveInput) -> Result<proc_macro2::TokenStream> {
22 let name = &ast.ident;
23 let generics = &ast.generics;
24 let generic_params = &generics.params;
25 let fns = ast
26 .attrs
27 .iter()
28 .filter(|attr| attr.path().is_ident(FUNC_ATTR))
29 .map(|attr| syn::parse2::<DeriveFuncs>(attr.meta.to_token_stream()))
30 .collect::<Result<Vec<DeriveFuncs>>>()?;
31 let variants: Vec<&Variant> = if let syn::Data::Enum(data) = &ast.data {
32 data.variants.iter().collect()
33 } else {
34 panic!("#[derive(Assoc)] only applicable to enums")
35 };
36 let functions: Vec<proc_macro2::TokenStream> = fns
37 .into_iter()
38 .flat_map(|DeriveFuncs(funcs)| {
39 funcs
40 .iter()
41 .map(|func| build_function(&variants, func, funcs.clone()))
42 .collect::<Vec<_>>()
43 })
44 .collect::<Result<Vec<proc_macro2::TokenStream>>>()?;
45 Ok(quote! {
46 impl <#generic_params> #name #generics
47 {
48 #(#functions)*
49 }
50 }
51 .into())
52}
53
54fn build_function(
55 variants: &[&Variant],
56 func: &DeriveFunc,
57 associated_funcs: Vec<DeriveFunc>,
58) -> Result<proc_macro2::TokenStream> {
59 let vis = &func.vis;
60 let sig = &func.sig;
61 let has_self = match func.sig.inputs.first() {
63 Some(FnArg::Receiver(_)) => true,
64 Some(FnArg::Typed(pat_type)) => {
65 let pat = &pat_type.pat;
66 quote!(#pat).to_string().trim() == "self"
67 }
68 None => false,
69 };
70 let is_option = if let syn::ReturnType::Type(_, ty) = &func.sig.output {
71 let s = quote!(#ty).to_string();
72 let trimmed = s.trim();
73 trimmed.starts_with("Option") && trimmed.len() > 6 && trimmed[6..].trim().starts_with("<")
74 } else {
75 false
76 };
77 let mut arms = variants
78 .iter()
79 .map(|variant| {
80 build_variant_arm(
81 variant,
82 &func.sig.ident,
83 associated_funcs.iter().map(|func| func.sig.ident.clone()),
84 is_option,
85 has_self,
86 &func.def,
87 )
88 })
89 .collect::<Result<Vec<(proc_macro2::TokenStream, Wildcard)>>>()?;
90 if is_option
91 && !arms
92 .iter()
93 .any(|(_, wildcard)| matches!(wildcard, Wildcard::True))
94 {
95 arms.push((quote!(_ => None,), Wildcard::True))
96 }
97 if has_self == false {
99 arms.sort_by(|(_, wildcard1), (_, wildcard2)| wildcard1.cmp(wildcard2));
100 }
101 let arms = arms.into_iter().map(|(toks, _)| toks);
102 let match_on = if has_self {
103 quote!(self)
104 } else if func.sig.inputs.is_empty() {
105 return Err(syn::Error::new(func.span, "Missing parameter"));
106 } else {
107 let mut result = quote!();
108 for input in &func.sig.inputs {
109 match input {
110 FnArg::Receiver(_) => {
111 result = quote!(self);
112 break;
113 }
114 FnArg::Typed(pat_type) => {
115 let pat = &pat_type.pat;
116 result = if result.is_empty() {
117 quote!(#pat)
118 } else {
119 quote!(#result, #pat)
120 };
121 }
122 }
123 }
124 if func.sig.inputs.len() > 1 {
125 result = quote!((#result));
126 }
127 result
128 };
129 Ok(quote! {
130 #vis #sig
131 {
132 match #match_on
133 {
134 #(#arms)*
135 }
136 }
137 })
138}
139
140fn build_variant_arm(
141 variant: &Variant,
142 func: &syn::Ident,
143 mut assoc_funcs: impl Iterator<Item = syn::Ident>,
144 is_option: bool,
145 has_self: bool,
146 def: &Option<proc_macro2::TokenStream>,
147) -> Result<(proc_macro2::TokenStream, Wildcard)> {
148 let assocs = Association::get_variant_assocs(variant, !has_self).filter(|assoc| {
150 assoc.func == *func || assoc_funcs.any(|assoc_func| assoc_func == assoc.func)
151 });
152 if has_self {
153 build_fwd_assoc(assocs, variant, is_option, func, def)
154 } else {
155 build_rev_assoc(assocs, variant, is_option)
156 }
157}
158
159fn build_fwd_assoc(
160 assocs: impl Iterator<Item = Association>,
161 variant: &Variant,
162 is_option: bool,
163 func_ident: &syn::Ident,
164 def: &Option<proc_macro2::TokenStream>,
165) -> Result<(proc_macro2::TokenStream, Wildcard)> {
166 let var_ident = &variant.ident;
167 let fields = match &variant.fields {
168 syn::Fields::Named(fields) => {
169 let named = fields
170 .named
171 .iter()
172 .map(|f| {
173 let ident = &f.ident;
174 let val: &Option<proc_macro2::Ident> = &f.ident.as_ref().map(|s| {
175 proc_macro2::Ident::new(
176 &("_".to_string() + &s.to_string()),
177 f.span().clone(),
178 )
179 });
180 quote!(#ident: #val)
181 })
182 .collect::<Vec<proc_macro2::TokenStream>>();
183 quote!({#(#named),*})
184 }
185 syn::Fields::Unnamed(fields) => {
186 let unnamed = fields
187 .unnamed
188 .iter()
189 .enumerate()
190 .map(|(i, f)| {
191 let ident = proc_macro2::Ident::new(
192 &("_".to_string() + &i.to_string()),
193 f.span().clone(),
194 );
195 quote!(#ident)
196 })
197 .collect::<Vec<proc_macro2::TokenStream>>();
198 quote!((#(#unnamed),*))
199 }
200 _ => quote!(),
201 };
202 let assocs = assocs
203 .filter_map(|assoc| {
204 if let AssociationType::Forward(expr) = assoc.assoc {
205 Some(Ok(expr))
206 } else {
207 None
208 }
209 })
210 .collect::<Result<Vec<syn::Expr>>>()?;
211 match assocs.len() {
212 0 => {
213 if let Some(tokens) = def {
214 Ok(quote! { Self::#var_ident #fields => #tokens, })
215 } else if is_option {
216 Ok(quote! { Self::#var_ident #fields => None, })
217 } else {
218 Err(Error::new_spanned(
219 variant,
220 format!("Missing `assoc` attribute for {}", func_ident),
221 ))
222 }
223 }
224 1 => {
225 let val = &assocs[0];
226 if is_option {
227 if quote!(#val).to_string().trim() == "None" {
228 Ok(quote! { Self::#var_ident #fields => #val, })
229 } else {
230 Ok(quote! { Self::#var_ident #fields => Some(#val), })
231 }
232 } else {
233 Ok(quote! { Self::#var_ident #fields => #val, })
234 }
235 }
236 _ => Err(Error::new_spanned(
237 variant,
238 format!("Too many `assoc` attributes for {}", func_ident),
239 )),
240 }
241 .map(|toks| (toks, Wildcard::None))
242}
243
244fn build_rev_assoc(
245 assocs: impl Iterator<Item = Association>,
246 variant: &Variant,
247 is_option: bool,
248) -> Result<(proc_macro2::TokenStream, Wildcard)> {
249 let var_ident = &variant.ident;
250 let assocs = assocs
251 .filter_map(|assoc| {
252 if let AssociationType::Reverse(pat) = assoc.assoc {
253 Some(Ok(pat))
254 } else {
255 None
256 }
257 })
258 .collect::<Result<Vec<syn::Pat>>>()?;
259 let mut concrete_pats: Vec<proc_macro2::TokenStream> = Vec::new();
260 let mut wildcard_pat: Option<proc_macro2::TokenStream> = None;
261 let mut wildcard_status = Wildcard::False;
262 for pat in assocs.iter() {
263 if !matches!(variant.fields, syn::Fields::Unit) {
264 return Err(Error::new_spanned(
265 variant,
266 "Reverse associations not allowed for tuple or struct-like variants",
267 ));
268 }
269 let arm = if is_option {
270 quote!(#pat => Some(Self::#var_ident),)
271 } else {
272 quote!(#pat => Self::#var_ident,)
273 };
274 if matches!(pat, syn::Pat::Wild(_)) {
275 if wildcard_pat.is_some() {
276 return Err(syn::Error::new_spanned(
277 pat,
278 "Only 1 wildcard allowed per reverse association",
279 ));
280 }
281 wildcard_status = Wildcard::True;
282 wildcard_pat = Some(arm);
283 } else {
284 concrete_pats.push(arm);
285 }
286 }
287 if let Some(wildcard_pat) = wildcard_pat {
288 concrete_pats.push(wildcard_pat)
289 }
290 Ok((quote!(#(#concrete_pats) *), wildcard_status))
291}
292
293#[derive(Clone)]
297struct DeriveFunc {
298 vis: syn::Visibility,
299 sig: syn::Signature,
300 span: proc_macro2::Span,
301 def: Option<proc_macro2::TokenStream>,
302}
303
304struct Association {
307 func: syn::Ident,
308 assoc: AssociationType,
309}
310
311enum AssociationType {
312 Forward(syn::Expr),
313 Reverse(syn::Pat),
314}
315
/// Wildcard status of the arm(s) a variant contributes to a generated `match`.
///
/// The derived ordering (`False < None < True`) follows declaration order. It is
/// used to sort reverse-function arms so that groups containing `_` come last.
#[derive(Eq, Ord, PartialEq, PartialOrd)]
enum Wildcard {
    /// Reverse-function arms that contain no `_` pattern.
    False = 0,
    /// A forward-function arm, for which wildcards do not apply.
    None = 1,
    /// Arms that contain a `_` pattern.
    True = 2,
}
326
327impl syn::parse::Parse for DeriveFunc {
328 fn parse(input: syn::parse::ParseStream) -> Result<Self> {
330 let vis = input.parse::<syn::Visibility>()?;
331 let sig = input.parse::<syn::Signature>()?;
332 let def = if let Ok(block) = input.parse::<syn::Block>() {
333 Some(proc_macro2::TokenStream::from(ToTokens::into_token_stream(
334 block,
335 )))
336 } else {
337 None
338 };
339 Ok(DeriveFunc {
340 vis,
341 sig,
342 span: input.span(),
343 def,
344 })
345 }
346}
347
348struct DeriveFuncs(Vec<DeriveFunc>);
349impl syn::parse::Parse for DeriveFuncs {
350 fn parse(input: syn::parse::ParseStream) -> Result<Self> {
352 input.step(|cursor| {
353 if let Some((_, next)) = cursor.token_tree() {
354 Ok(((), next))
355 } else {
356 Err(cursor.error("Missing function signature"))
357 }
358 })?;
359 let content;
360 parenthesized!(content in input);
361 Ok(Self(
362 content
363 .parse_terminated(DeriveFunc::parse, Token!(,))
364 .map(|parsed| parsed.into_iter().collect())?,
365 ))
366 }
367}
368
369struct ForwardAssocTokens(syn::Ident, syn::Expr);
371impl syn::parse::Parse for ForwardAssocTokens {
372 fn parse(input: syn::parse::ParseStream) -> Result<Self> {
373 let ident = input.parse()?;
374 input.parse::<syn::Token!(=)>()?;
375 let expr = input.parse()?;
376 Ok(Self(ident, expr))
377 }
378}
379
380struct ReverseAssocTokens(syn::Ident, syn::Pat);
382impl syn::parse::Parse for ReverseAssocTokens {
383 fn parse(input: syn::parse::ParseStream) -> Result<Self> {
384 let ident = input.parse()?;
385 input.parse::<syn::Token!(=)>()?;
386 let pat = syn::Pat::parse_multi_with_leading_vert(input)?;
387 Ok(Self(ident, pat))
388 }
389}
390
391impl Into<Association> for ForwardAssocTokens {
392 fn into(self) -> Association {
393 Association {
394 func: self.0,
395 assoc: AssociationType::Forward(self.1),
396 }
397 }
398}
399
400impl Into<Association> for ReverseAssocTokens {
401 fn into(self) -> Association {
402 Association {
403 func: self.0,
404 assoc: AssociationType::Reverse(self.1),
405 }
406 }
407}
408
409impl Association {
410 fn get_variant_assocs(variant: &Variant, is_reverse: bool) -> impl Iterator<Item = Self> + '_ {
411 variant
412 .attrs
413 .iter()
414 .filter(|assoc_attr| assoc_attr.path().is_ident(ASSOC_ATTR))
415 .filter_map(move |attr| {
416 if let syn::Meta::List(meta_list) = &attr.meta {
417 if is_reverse {
418 let parser = Punctuated::<ReverseAssocTokens, Token![,]>::parse_terminated;
419 parser
420 .parse2(meta_list.tokens.clone())
421 .map(|tokens| {
422 tokens
423 .into_iter()
424 .map(|tokens| tokens.into())
425 .collect::<Vec<Self>>()
426 })
427 .ok()
428 } else {
429 let parser = Punctuated::<ForwardAssocTokens, Token![,]>::parse_terminated;
430 parser
431 .parse2(meta_list.tokens.clone())
432 .map(|tokens| {
433 tokens
434 .into_iter()
435 .map(|tokens| tokens.into())
436 .collect::<Vec<Self>>()
437 })
438 .ok()
439 }
440 } else {
441 None
442 }
443 })
444 .flat_map(std::convert::identity)
445 }
446}