mainstay_derive_space/
lib.rs1use std::collections::VecDeque;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, TokenStream as TokenStream2, TokenTree};
5use quote::{quote, quote_spanned, ToTokens};
6use syn::{
7 parse::ParseStream, parse2, parse_macro_input, punctuated::Punctuated, token::Comma, Attribute,
8 DeriveInput, Field, Fields, GenericArgument, LitInt, PathArguments, Type, TypeArray,
9};
10
11#[proc_macro_derive(InitSpace, attributes(max_len))]
39pub fn derive_init_space(item: TokenStream) -> TokenStream {
40 let input = parse_macro_input!(item as DeriveInput);
41 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
42 let name = input.ident;
43
44 let process_struct_fields = |fields: Punctuated<Field, Comma>| {
45 let recurse = fields.into_iter().map(|f| {
46 let mut max_len_args = get_max_len_args(&f.attrs);
47 len_from_type(f.ty, &mut max_len_args)
48 });
49
50 quote! {
51 #[automatically_derived]
52 impl #impl_generics mainstay_lang::Space for #name #ty_generics #where_clause {
53 const INIT_SPACE: usize = 0 #(+ #recurse)*;
54 }
55 }
56 };
57
58 let expanded: TokenStream2 = match input.data {
59 syn::Data::Struct(strct) => match strct.fields {
60 Fields::Named(named) => process_struct_fields(named.named),
61 Fields::Unnamed(unnamed) => process_struct_fields(unnamed.unnamed),
62 Fields::Unit => quote! {
63 #[automatically_derived]
64 impl #impl_generics mainstay_lang::Space for #name #ty_generics #where_clause {
65 const INIT_SPACE: usize = 0;
66 }
67 },
68 },
69 syn::Data::Enum(enm) => {
70 let variants = enm.variants.into_iter().map(|v| {
71 let len = v.fields.into_iter().map(|f| {
72 let mut max_len_args = get_max_len_args(&f.attrs);
73 len_from_type(f.ty, &mut max_len_args)
74 });
75
76 quote! {
77 0 #(+ #len)*
78 }
79 });
80
81 let max = gen_max(variants);
82
83 quote! {
84 #[automatically_derived]
85 impl mainstay_lang::Space for #name {
86 const INIT_SPACE: usize = 1 + #max;
87 }
88 }
89 }
90 _ => unimplemented!(),
91 };
92
93 TokenStream::from(expanded)
94}
95
96fn gen_max<T: Iterator<Item = TokenStream2>>(mut iter: T) -> TokenStream2 {
97 if let Some(item) = iter.next() {
98 let next_item = gen_max(iter);
99 quote!(mainstay_lang::__private::max(#item, #next_item))
100 } else {
101 quote!(0)
102 }
103}
104
105fn len_from_type(ty: Type, attrs: &mut Option<VecDeque<TokenStream2>>) -> TokenStream2 {
106 match ty {
107 Type::Array(TypeArray { elem, len, .. }) => {
108 let array_len = len.to_token_stream();
109 let type_len = len_from_type(*elem, attrs);
110 quote!((#array_len * #type_len))
111 }
112 Type::Path(ty_path) => {
113 let path_segment = ty_path.path.segments.last().unwrap();
114 let ident = &path_segment.ident;
115 let type_name = ident.to_string();
116 let first_ty = get_first_ty_arg(&path_segment.arguments);
117
118 match type_name.as_str() {
119 "i8" | "u8" | "bool" => quote!(1),
120 "i16" | "u16" => quote!(2),
121 "i32" | "u32" | "f32" => quote!(4),
122 "i64" | "u64" | "f64" => quote!(8),
123 "i128" | "u128" => quote!(16),
124 "String" => {
125 let max_len = get_next_arg(ident, attrs);
126 quote!((4 + #max_len))
127 }
128 "Pubkey" => quote!(32),
129 "Option" => {
130 if let Some(ty) = first_ty {
131 let type_len = len_from_type(ty, attrs);
132
133 quote!((1 + #type_len))
134 } else {
135 quote_spanned!(ident.span() => compile_error!("Invalid argument in Vec"))
136 }
137 }
138 "Vec" => {
139 if let Some(ty) = first_ty {
140 let max_len = get_next_arg(ident, attrs);
141 let type_len = len_from_type(ty, attrs);
142
143 quote!((4 + #type_len * #max_len))
144 } else {
145 quote_spanned!(ident.span() => compile_error!("Invalid argument in Vec"))
146 }
147 }
148 _ => {
149 let ty = &ty_path.path;
150 quote!(<#ty as mainstay_lang::Space>::INIT_SPACE)
151 }
152 }
153 }
154 _ => panic!("Type {ty:?} is not supported"),
155 }
156}
157
158fn get_first_ty_arg(args: &PathArguments) -> Option<Type> {
159 match args {
160 PathArguments::AngleBracketed(bracket) => bracket.args.iter().find_map(|el| match el {
161 GenericArgument::Type(ty) => Some(ty.to_owned()),
162 _ => None,
163 }),
164 _ => None,
165 }
166}
167
168fn parse_len_arg(item: ParseStream) -> Result<VecDeque<TokenStream2>, syn::Error> {
169 let mut result = VecDeque::new();
170 while let Some(token_tree) = item.parse()? {
171 match token_tree {
172 TokenTree::Ident(ident) => result.push_front(quote!((#ident as usize))),
173 TokenTree::Literal(lit) => {
174 if let Ok(lit_int) = parse2::<LitInt>(lit.into_token_stream()) {
175 result.push_front(quote!(#lit_int))
176 }
177 }
178 _ => (),
179 }
180 }
181
182 Ok(result)
183}
184
185fn get_max_len_args(attributes: &[Attribute]) -> Option<VecDeque<TokenStream2>> {
186 attributes
187 .iter()
188 .find(|a| a.path.is_ident("max_len"))
189 .and_then(|a| a.parse_args_with(parse_len_arg).ok())
190}
191
192fn get_next_arg(ident: &Ident, args: &mut Option<VecDeque<TokenStream2>>) -> TokenStream2 {
193 if let Some(arg_list) = args {
194 if let Some(arg) = arg_list.pop_back() {
195 quote!(#arg)
196 } else {
197 quote_spanned!(ident.span() => compile_error!("The number of lengths are invalid."))
198 }
199 } else {
200 quote_spanned!(ident.span() => compile_error!("Expected max_len attribute."))
201 }
202}