magic_space_derive/
lib.rs1use std::collections::VecDeque;
9
10use proc_macro::TokenStream;
11use proc_macro2::{Ident, TokenStream as TokenStream2, TokenTree};
12use quote::{quote, quote_spanned, ToTokens};
13use syn::{
14 parse::ParseStream, parse2, parse_macro_input, punctuated::Punctuated, token::Comma, Attribute,
15 DeriveInput, Field, Fields, GenericArgument, LitInt, PathArguments, Type, TypeArray,
16};
17
18#[proc_macro_derive(MagicSpace, attributes(max_len))]
45pub fn derive_init_space(item: TokenStream) -> TokenStream {
46 let input = parse_macro_input!(item as DeriveInput);
47 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
48 let name = input.ident;
49
50 let process_struct_fields = |fields: Punctuated<Field, Comma>| {
51 let recurse = fields.into_iter().map(|f| {
52 let mut max_len_args = get_max_len_args(&f.attrs);
53 len_from_type(f.ty, &mut max_len_args)
54 });
55
56 quote! {
57 #[automatically_derived]
58 impl #impl_generics magic_space::Space for #name #ty_generics #where_clause {
59 const MAGIC_SPACE: usize = 0 #(+ #recurse)*;
60 }
61 }
62 };
63
64 let expanded: TokenStream2 = match input.data {
65 syn::Data::Struct(strct) => match strct.fields {
66 Fields::Named(named) => process_struct_fields(named.named),
67 Fields::Unnamed(unnamed) => process_struct_fields(unnamed.unnamed),
68 Fields::Unit => quote! {
69 #[automatically_derived]
70 impl #impl_generics magic_space::Space for #name #ty_generics #where_clause {
71 const MAGIC_SPACE: usize = 0;
72 }
73 },
74 },
75 syn::Data::Enum(enm) => {
76 let variants = enm.variants.into_iter().map(|v| {
77 let len = v.fields.into_iter().map(|f| {
78 let mut max_len_args = get_max_len_args(&f.attrs);
79 len_from_type(f.ty, &mut max_len_args)
80 });
81
82 quote! {
83 0 #(+ #len)*
84 }
85 });
86
87 let max = gen_max(variants);
88
89 quote! {
90 #[automatically_derived]
91 impl magic_space::Space for #name {
92 const MAGIC_SPACE: usize = 1 + #max;
93 }
94 }
95 }
96 _ => unimplemented!(),
97 };
98
99 TokenStream::from(expanded)
100}
101
102fn gen_max<T: Iterator<Item = TokenStream2>>(mut iter: T) -> TokenStream2 {
103 if let Some(item) = iter.next() {
104 let next_item = gen_max(iter);
105 quote!(magic_space::__private::max(#item, #next_item))
106 } else {
107 quote!(0)
108 }
109}
110
111fn len_from_type(ty: Type, attrs: &mut Option<VecDeque<TokenStream2>>) -> TokenStream2 {
112 match ty {
113 Type::Array(TypeArray { elem, len, .. }) => {
114 let array_len = len.to_token_stream();
115 let type_len = len_from_type(*elem, attrs);
116 quote!((#array_len * #type_len))
117 }
118 Type::Path(ty_path) => {
119 let path_segment = ty_path.path.segments.last().unwrap();
120 let ident = &path_segment.ident;
121 let type_name = ident.to_string();
122 let first_ty = get_first_ty_arg(&path_segment.arguments);
123
124 match type_name.as_str() {
125 "i8" | "u8" | "bool" => quote!(1),
126 "i16" | "u16" => quote!(2),
127 "i32" | "u32" | "f32" => quote!(4),
128 "i64" | "u64" | "f64" => quote!(8),
129 "i128" | "u128" => quote!(16),
130 "String" => {
131 let max_len = get_next_arg(ident, attrs);
132 quote!((4 + #max_len))
133 }
134 "Pubkey" => quote!(32),
135 "Option" => {
136 if let Some(ty) = first_ty {
137 let type_len = len_from_type(ty, attrs);
138
139 quote!((1 + #type_len))
140 } else {
141 quote_spanned!(ident.span() => compile_error!("Invalid argument in Vec"))
142 }
143 }
144 "Vec" => {
145 if let Some(ty) = first_ty {
146 let max_len = get_next_arg(ident, attrs);
147 let type_len = len_from_type(ty, attrs);
148
149 quote!((4 + #type_len * #max_len))
150 } else {
151 quote_spanned!(ident.span() => compile_error!("Invalid argument in Vec"))
152 }
153 }
154 _ => {
155 let ty = &ty_path.path;
156 quote!(<#ty as magic_space::Space>::MAGIC_SPACE)
157 }
158 }
159 }
160 _ => panic!("Type {ty:?} is not supported"),
161 }
162}
163
164fn get_first_ty_arg(args: &PathArguments) -> Option<Type> {
165 match args {
166 PathArguments::AngleBracketed(bracket) => bracket.args.iter().find_map(|el| match el {
167 GenericArgument::Type(ty) => Some(ty.to_owned()),
168 _ => None,
169 }),
170 _ => None,
171 }
172}
173
174fn parse_len_arg(item: ParseStream) -> Result<VecDeque<TokenStream2>, syn::Error> {
175 let mut result = VecDeque::new();
176 while let Some(token_tree) = item.parse()? {
177 match token_tree {
178 TokenTree::Ident(ident) => result.push_front(quote!((#ident as usize))),
179 TokenTree::Literal(lit) => {
180 if let Ok(lit_int) = parse2::<LitInt>(lit.into_token_stream()) {
181 result.push_front(quote!(#lit_int))
182 }
183 }
184 _ => (),
185 }
186 }
187
188 Ok(result)
189}
190
191fn get_max_len_args(attributes: &[Attribute]) -> Option<VecDeque<TokenStream2>> {
192 attributes
193 .iter()
194 .find(|a| a.path().is_ident("max_len"))
195 .and_then(|a| a.parse_args_with(parse_len_arg).ok())
196}
197
198fn get_next_arg(ident: &Ident, args: &mut Option<VecDeque<TokenStream2>>) -> TokenStream2 {
199 if let Some(arg_list) = args {
200 if let Some(arg) = arg_list.pop_back() {
201 quote!(#arg)
202 } else {
203 quote_spanned!(ident.span() => compile_error!("The number of lengths are invalid."))
204 }
205 } else {
206 quote_spanned!(ident.span() => compile_error!("Expected max_len attribute."))
207 }
208}