1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4use syn::{
5 parse_macro_input, parse_quote, parse_str, punctuated::Punctuated, token::Comma, Attribute,
6 GenericArgument, GenericParam, Ident, Item, ItemStruct, Meta, Type, Visibility, WhereClause,
7};
8
/// Suffix appended to a type name to form the name of its holder trait.
const HOLDER_SUFFIX: &str = "Holder";
10
11struct ItemEnumOrStruct {
12 ident: Ident,
13 generic_params: Punctuated<GenericParam, Comma>,
14 where_clause: Option<WhereClause>,
15 vis: Visibility,
16}
17#[proc_macro_derive(Holdable)]
18pub fn holder_derive(input: TokenStream) -> TokenStream {
19 let item = parse_macro_input!(input as Item);
20 let item = match item {
21 Item::Enum(value) => ItemEnumOrStruct {
22 ident: value.ident,
23 generic_params: value.generics.params,
24 where_clause: value.generics.where_clause,
25 vis: value.vis,
26 },
27 Item::Struct(value) => ItemEnumOrStruct {
28 ident: value.ident,
29 generic_params: value.generics.params,
30 where_clause: value.generics.where_clause,
31 vis: value.vis,
32 },
33 _ => panic!("unimplemented item type"),
34 };
35 let struct_name = &item.ident;
36 let struct_generic = item.generic_params;
37 let struct_where_clause = item.where_clause;
38 let struct_visibility = item.vis;
39 let mut struct_generic_without_bounds = struct_generic.clone();
40 remove_bounds_from_generic(&mut struct_generic_without_bounds);
41 let holder_trait_name = format_ident!("{}{HOLDER_SUFFIX}", struct_name);
42 let fn_name = struct_name.to_string().clone().to_case(Case::Snake);
43 let mut_fn_name = format!("{}_mut", fn_name);
44 let fn_name: Ident = parse_str(fn_name.as_str()).unwrap();
45 let mut_fn_name: Ident = parse_str(mut_fn_name.as_str()).unwrap();
46 #[cfg(feature = "fast_delegate")]
47 let attr: Attribute = parse_quote!(#[fast_delegate::delegate]);
48 #[cfg(not(feature = "fast_delegate"))]
49 let attr: Option<Attribute> = None;
50 quote! {
51 #attr
52 #struct_visibility trait #holder_trait_name<#struct_generic> #struct_where_clause {
53 fn #fn_name<'__a: '__b, '__b>(&'__a self) -> &'__b #struct_name<#struct_generic_without_bounds>;
54 fn #mut_fn_name<'__a: '__b, '__b>(&'__a mut self) -> &'__b mut #struct_name<#struct_generic_without_bounds>;
55 }
56 }
57 .into()
58}
59
60#[proc_macro_derive(Holder, attributes(hold))]
61pub fn holder(input: TokenStream) -> TokenStream {
62 let item_struct = parse_macro_input!(input as ItemStruct);
63 let struct_name = &item_struct.ident;
64 let struct_generic = &item_struct.generics.params;
65 let mut struct_generic_without_bounds = struct_generic.clone();
66 remove_bounds_from_generic(&mut struct_generic_without_bounds);
67 let struct_where_clause = &item_struct.generics.where_clause;
68
69 let quotes: Vec<_> = item_struct
70 .fields
71 .iter()
72 .filter_map(|field| {
73 let mut holder_trait_name_in_attr: Option<Ident> = None;
74 let is_holdable_field = field.attrs.iter().any(|attr| match &attr.meta {
75 Meta::List(list) => {
76 if list.path.is_ident("hold") {
77 holder_trait_name_in_attr = attr.parse_args().unwrap();
78 true
79 } else {
80 false
81 }
82 },
83 Meta::Path(path) => path.is_ident("hold"),
84 _ => panic!("unimplemented attr meta type"),
85 });
86 if !is_holdable_field {
87 return Option::<proc_macro2::TokenStream>::None;
88 }
89 let field_name = field
90 .ident
91 .clone()
92 .expect("unimplemented non field name case");
93 let field_type_ident = get_ident_by_type(&field.ty);
94 let field_type = &field.ty;
95 let holder_trait_name = holder_trait_name_in_attr
96 .or_else(|| Some(parse_str(format!("{}{HOLDER_SUFFIX}", field_type_ident).as_str()).unwrap())).unwrap();
97 let mut type_name = holder_trait_name.clone().to_string();
98 type_name.truncate(holder_trait_name.to_string().len() - HOLDER_SUFFIX.len());
99 let fn_name = type_name.to_case(Case::Snake);
100 let field_type_name: Ident = parse_str(type_name.as_str()).unwrap();
101 let fn_name: Ident = parse_str(fn_name.as_str()).unwrap();
102 let mut_fn_name = format!("{}_mut", fn_name.to_string());
103 let mut_fn_name: Ident = parse_str(mut_fn_name.as_str()).unwrap();
104 let field_bounds = get_generic_by_type(field_type);
105 Some(
106 quote! {
107 impl<#struct_generic>
108 #holder_trait_name<#field_bounds> for #struct_name<#struct_generic_without_bounds> #struct_where_clause {
109 fn #fn_name<'__a: '__b, '__b>(&'__a self) -> &'__b #field_type_name<#field_bounds> {
110 &self.#field_name
111 }
112 fn #mut_fn_name<'__a: '__b, '__b>(&'__a mut self) -> &'__b mut #field_type_name<#field_bounds> {
113 &mut self.#field_name
114 }
115 }
116 }
117 .into(),
118 )
119 })
120 .collect();
121 quote! {#(#quotes)*}.into()
122}
123
124fn get_ident_by_type(ty: &Type) -> Ident {
125 match ty {
126 Type::Path(value) => value.path.segments.last().unwrap().ident.clone(),
127 Type::Reference(value) => get_ident_by_type(&*value.elem),
128 _ => panic!("unimplemented field type"),
129 }
130}
131
132fn get_generic_by_type(ty: &Type) -> Option<Punctuated<GenericArgument, Comma>> {
133 match ty {
134 Type::Path(value) => match &value.path.segments.last().unwrap().arguments {
135 syn::PathArguments::None => None,
136 syn::PathArguments::AngleBracketed(value) => Some(value.args.clone()),
137 syn::PathArguments::Parenthesized(_) => None,
138 },
139 Type::Reference(value) => get_generic_by_type(&*value.elem),
140 _ => panic!("unimplemented field type"),
141 }
142}
143
144fn remove_bounds_from_generic(generic: &mut Punctuated<GenericParam, Comma>) {
145 for struct_generic in generic.iter_mut() {
146 match struct_generic {
147 syn::GenericParam::Lifetime(lifetime) => {
148 lifetime.bounds.clear();
149 }
150 syn::GenericParam::Type(lifetime) => {
151 lifetime.bounds.clear();
152 }
153 _ => {}
154 }
155 }
156}