magic_args_derive/
lib.rs

1extern crate proc_macro;
2
3use std::collections::HashSet;
4
5use proc_macro2::TokenStream;
6use quote::{ToTokens, quote};
7use syn::parse::{Parse, ParseStream};
8use syn::{Data, DeriveInput, Index};
9
10mod keyword {
11    syn::custom_keyword!(skip);
12}
13
14enum MagicArgsAttribute {
15    Skip,
16}
17
18impl Parse for MagicArgsAttribute {
19    fn parse(input: ParseStream) -> syn::Result<Self> {
20        let lookahead = input.lookahead1();
21
22        if lookahead.peek(keyword::skip) {
23            let _skip: keyword::skip = input.parse()?;
24
25            Ok(Self::Skip)
26        } else {
27            panic!("unknown attribute")
28        }
29    }
30}
31
32#[proc_macro_derive(MagicArgs, attributes(magic_args))]
33pub fn magic_args_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
34    let item = syn::parse_macro_input!(input as DeriveInput);
35
36    let Data::Struct(data) = item.data else {
37        panic!("MagicArgs can only be derived on structs")
38    };
39
40    let item_name = item.ident.clone();
41    let (impl_generics, type_generics, where_clause) = item.generics.split_for_impl();
42
43    let mut output = TokenStream::new();
44
45    let mut types_seen = HashSet::new();
46
47    for (idx, field) in data
48        .fields
49        .into_iter()
50        .enumerate()
51        .map(|(idx, field)| (Index::from(idx), field))
52    {
53        let mut skip = false;
54
55        field
56            .attrs
57            .iter()
58            .map(|attr| attr.parse_args().unwrap())
59            .for_each(|attr: MagicArgsAttribute| match attr {
60                MagicArgsAttribute::Skip => skip = true,
61            });
62
63        if skip {
64            continue;
65        }
66
67        let field_type = field.ty;
68
69        let field_accessor = match field.ident {
70            Some(ident) => ident.to_token_stream(),
71            None => idx.to_token_stream(),
72        };
73
74        output.extend(quote! {
75            impl #impl_generics ::magic_args::Args<::magic_args::__private::Tagged<#field_type, #idx>> for #item_name #type_generics
76                #where_clause
77            {
78                #[inline]
79                fn get(&self) -> ::magic_args::__private::Tagged<#field_type, #idx> {
80                    ::magic_args::__private::Tagged(::core::clone::Clone::clone(&self.#field_accessor))
81                }
82            }
83        });
84
85        if !types_seen.insert(field_type) {
86            panic!("MagicArgs cannot contain two items of the same type");
87        }
88    }
89
90    quote! {
91        const _: () = {
92            #output
93        };
94    }
95    .into()
96}