holder_derive/
lib.rs

1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4use syn::{
5    parse_macro_input, parse_quote, parse_str, punctuated::Punctuated, token::Comma, Attribute,
6    GenericArgument, GenericParam, Ident, Item, ItemStruct, Meta, Type, Visibility, WhereClause,
7};
8
/// Suffix appended to a type name to form the name of its holder trait
/// (e.g. `Foo` -> `FooHolder`).
const HOLDER_SUFFIX: &str = "Holder";

11struct ItemEnumOrStruct {
12    ident: Ident,
13    generic_params: Punctuated<GenericParam, Comma>,
14    where_clause: Option<WhereClause>,
15    vis: Visibility,
16}
17#[proc_macro_derive(Holdable)]
18pub fn holder_derive(input: TokenStream) -> TokenStream {
19    let item = parse_macro_input!(input as Item);
20    let item = match item {
21        Item::Enum(value) => ItemEnumOrStruct {
22            ident: value.ident,
23            generic_params: value.generics.params,
24            where_clause: value.generics.where_clause,
25            vis: value.vis,
26        },
27        Item::Struct(value) => ItemEnumOrStruct {
28            ident: value.ident,
29            generic_params: value.generics.params,
30            where_clause: value.generics.where_clause,
31            vis: value.vis,
32        },
33        _ => panic!("unimplemented item type"),
34    };
35    let struct_name = &item.ident;
36    let struct_generic = item.generic_params;
37    let struct_where_clause = item.where_clause;
38    let struct_visibility = item.vis;
39    let mut struct_generic_without_bounds = struct_generic.clone();
40    remove_bounds_from_generic(&mut struct_generic_without_bounds);
41    let holder_trait_name = format_ident!("{}{HOLDER_SUFFIX}", struct_name);
42    let fn_name = struct_name.to_string().clone().to_case(Case::Snake);
43    let mut_fn_name = format!("{}_mut", fn_name);
44    let fn_name: Ident = parse_str(fn_name.as_str()).unwrap();
45    let mut_fn_name: Ident = parse_str(mut_fn_name.as_str()).unwrap();
46    #[cfg(feature = "fast_delegate")]
47    let attr: Attribute = parse_quote!(#[fast_delegate::delegate]);
48    #[cfg(not(feature = "fast_delegate"))]
49    let attr: Option<Attribute> = None;
50    quote! {
51        #attr
52        #struct_visibility trait #holder_trait_name<#struct_generic> #struct_where_clause {
53            fn #fn_name<'__a: '__b, '__b>(&'__a self) -> &'__b #struct_name<#struct_generic_without_bounds>;
54            fn #mut_fn_name<'__a: '__b, '__b>(&'__a mut self) -> &'__b mut #struct_name<#struct_generic_without_bounds>;
55        }
56    }
57    .into()
58}
59
60#[proc_macro_derive(Holder, attributes(hold))]
61pub fn holder(input: TokenStream) -> TokenStream {
62    let item_struct = parse_macro_input!(input as ItemStruct);
63    let struct_name = &item_struct.ident;
64    let struct_generic = &item_struct.generics.params;
65    let mut struct_generic_without_bounds = struct_generic.clone();
66    remove_bounds_from_generic(&mut struct_generic_without_bounds);
67    let struct_where_clause = &item_struct.generics.where_clause;
68
69    let quotes: Vec<_> = item_struct
70        .fields
71        .iter()
72        .filter_map(|field| {
73            let mut holder_trait_name_in_attr: Option<Ident> = None;
74            let is_holdable_field = field.attrs.iter().any(|attr| match &attr.meta {
75                Meta::List(list) => {
76                    if list.path.is_ident("hold") {
77                        holder_trait_name_in_attr = attr.parse_args().unwrap();
78                        true
79                    } else {
80                        false
81                    }
82                },
83                Meta::Path(path) => path.is_ident("hold"),
84                _ => panic!("unimplemented attr meta type"),
85            });
86            if !is_holdable_field {
87                return Option::<proc_macro2::TokenStream>::None;
88            }
89            let field_name = field
90                .ident
91                .clone()
92                .expect("unimplemented non field name case");
93            let field_type_ident = get_ident_by_type(&field.ty);
94            let field_type = &field.ty;
95            let holder_trait_name = holder_trait_name_in_attr
96                .or_else(|| Some(parse_str(format!("{}{HOLDER_SUFFIX}", field_type_ident).as_str()).unwrap())).unwrap();
97            let mut type_name = holder_trait_name.clone().to_string();
98            type_name.truncate(holder_trait_name.to_string().len() - HOLDER_SUFFIX.len());
99            let fn_name = type_name.to_case(Case::Snake);
100            let field_type_name: Ident = parse_str(type_name.as_str()).unwrap();
101            let fn_name: Ident = parse_str(fn_name.as_str()).unwrap();
102            let mut_fn_name = format!("{}_mut", fn_name.to_string());
103            let mut_fn_name: Ident = parse_str(mut_fn_name.as_str()).unwrap();
104            let field_bounds = get_generic_by_type(field_type);
105            Some(
106                quote! {
107                    impl<#struct_generic>
108                        #holder_trait_name<#field_bounds> for #struct_name<#struct_generic_without_bounds> #struct_where_clause {
109                        fn #fn_name<'__a: '__b, '__b>(&'__a self) -> &'__b #field_type_name<#field_bounds> {
110                            &self.#field_name
111                        }
112                        fn #mut_fn_name<'__a: '__b, '__b>(&'__a mut self) -> &'__b mut #field_type_name<#field_bounds> {
113                            &mut self.#field_name
114                        }
115                    }
116                }
117                .into(),
118            )
119        })
120        .collect();
121    quote! {#(#quotes)*}.into()
122}
123
124fn get_ident_by_type(ty: &Type) -> Ident {
125    match ty {
126        Type::Path(value) => value.path.segments.last().unwrap().ident.clone(),
127        Type::Reference(value) => get_ident_by_type(&*value.elem),
128        _ => panic!("unimplemented field type"),
129    }
130}
131
132fn get_generic_by_type(ty: &Type) -> Option<Punctuated<GenericArgument, Comma>> {
133    match ty {
134        Type::Path(value) => match &value.path.segments.last().unwrap().arguments {
135            syn::PathArguments::None => None,
136            syn::PathArguments::AngleBracketed(value) => Some(value.args.clone()),
137            syn::PathArguments::Parenthesized(_) => None,
138        },
139        Type::Reference(value) => get_generic_by_type(&*value.elem),
140        _ => panic!("unimplemented field type"),
141    }
142}
143
144fn remove_bounds_from_generic(generic: &mut Punctuated<GenericParam, Comma>) {
145    for struct_generic in generic.iter_mut() {
146        match struct_generic {
147            syn::GenericParam::Lifetime(lifetime) => {
148                lifetime.bounds.clear();
149            }
150            syn::GenericParam::Type(lifetime) => {
151                lifetime.bounds.clear();
152            }
153            _ => {}
154        }
155    }
156}