1
2use proc_macro::TokenStream;
3use quote::quote;
4use syn::{parse_macro_input, punctuated::Punctuated};
5
/// Marker trait sharing its name with the `Forwarder` derive below.
///
/// NOTE(review): nothing in this crate implements or references it, and a
/// proc-macro crate cannot export non-macro items, so it is unreachable from
/// other crates. The lint is allowed (instead of deleting the item) so the
/// crate builds cleanly under `-D warnings`; consider removing it.
#[allow(dead_code)]
trait Forwarder {}
7
8#[proc_macro_derive(Forwarder, attributes(forward))]
9pub fn forwarder(input_stream: TokenStream) -> TokenStream {
10 let input = parse_macro_input!(input_stream as syn::ItemEnum);
11 let ident = input.ident.clone();
12 let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl().to_owned();
13
14 let mut to_derive_idents = Vec::new();
15 let mut to_derive_variants = Vec::new();
16 let mut to_derive_args: Vec<Vec<syn::Path>> = Vec::new();
17
18 for variant in input.variants.into_iter() {
19 for attr in variant.attrs {
20 if attr.path().is_ident("forward") {
21 match attr.parse_args_with(
22 Punctuated::<syn::Path, syn::Token!(,)>::parse_terminated
23 ) {
24 Ok(args) => {
25 to_derive_args.push(args.into_iter().collect());
26 },
27 Err(_) => {
28 to_derive_args.push(Vec::new());
29 }
30 };
31
32 to_derive_idents.push(match variant.fields {
33 syn::Fields::Unnamed(ref fields) => {
34 fields.unnamed.clone().into_iter().next().expect("expected a single field")
35 },
36 _ => panic!("expected unnamed tuple-like fields")
37 });
38 to_derive_variants.push(variant.ident.clone());
39 }
40 }
41 }
42
43 let gen = quote! {
44 #(
45 impl #impl_generics From<#to_derive_idents> for #ident #type_generics #where_clause {
46 fn from(value: #to_derive_idents) -> #ident #type_generics {
47 Self::#to_derive_variants(value)
48 }
49 }
50
51 #(
52 impl #impl_generics From<#to_derive_args> for #ident #type_generics #where_clause {
53 fn from(value: #to_derive_args) -> #ident #type_generics {
54 Self::#to_derive_variants(value.into())
55 }
56 }
57 )*
58 )*
59 };
60
61 gen.into()
62}
63
64#[proc_macro_attribute]
65pub fn deref(args: TokenStream, input_stream: TokenStream) -> TokenStream {
66 let input: syn::ItemStruct = parse_macro_input!(input_stream);
67 let field_name: syn::Ident = parse_macro_input!(args);
68
69 let ident = input.ident.clone();
70 let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
71
72 let mut fields_iter = input.fields.iter();
73 let field_type = loop {
74 if let Some(field) = fields_iter.next() {
75 if field.ident.clone().expect("expected struct to be {}") == field_name {
76 break field.ty.clone();
77 }
78 } else {
79 panic!("field not found");
80 }
81 };
82
83 let gen = quote! {
84 #input
85 impl #impl_generics std::ops::Deref for #ident #type_generics #where_clause {
86 type Target = #field_type;
87 fn deref(&self) -> &Self::Target {
88 &self.#field_name
89 }
90 }
91 };
92 gen.into()
93}