rialo_sol_derive_space/
lib.rs1use std::collections::VecDeque;
5
6use proc_macro::TokenStream;
7use proc_macro2::{Ident, TokenStream as TokenStream2, TokenTree};
8use quote::{quote, quote_spanned, ToTokens};
9use syn::{
10 parse::ParseStream, parse2, parse_macro_input, punctuated::Punctuated, token::Comma, Attribute,
11 DeriveInput, Field, Fields, GenericArgument, LitInt, PathArguments, Type, TypeArray,
12};
13
14#[proc_macro_derive(InitSpace, attributes(max_len))]
42pub fn derive_init_space(item: TokenStream) -> TokenStream {
43 let input = parse_macro_input!(item as DeriveInput);
44 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
45 let name = input.ident;
46
47 let process_struct_fields = |fields: Punctuated<Field, Comma>| {
48 let recurse = fields.into_iter().map(|f| {
49 let mut max_len_args = get_max_len_args(&f.attrs);
50 len_from_type(f.ty, &mut max_len_args)
51 });
52
53 quote! {
54 #[automatically_derived]
55 impl #impl_generics rialo_sol_lang::Space for #name #ty_generics #where_clause {
56 const INIT_SPACE: usize = 0 #(+ #recurse)*;
57 }
58 }
59 };
60
61 let expanded: TokenStream2 = match input.data {
62 syn::Data::Struct(strct) => match strct.fields {
63 Fields::Named(named) => process_struct_fields(named.named),
64 Fields::Unnamed(unnamed) => process_struct_fields(unnamed.unnamed),
65 Fields::Unit => quote! {
66 #[automatically_derived]
67 impl #impl_generics rialo_sol_lang::Space for #name #ty_generics #where_clause {
68 const INIT_SPACE: usize = 0;
69 }
70 },
71 },
72 syn::Data::Enum(enm) => {
73 let variants = enm.variants.into_iter().map(|v| {
74 let len = v.fields.into_iter().map(|f| {
75 let mut max_len_args = get_max_len_args(&f.attrs);
76 len_from_type(f.ty, &mut max_len_args)
77 });
78
79 quote! {
80 0 #(+ #len)*
81 }
82 });
83
84 let max = gen_max(variants);
85
86 quote! {
87 #[automatically_derived]
88 impl rialo_sol_lang::Space for #name {
89 const INIT_SPACE: usize = 1 + #max;
90 }
91 }
92 }
93 _ => unimplemented!(),
94 };
95
96 TokenStream::from(expanded)
97}
98
99fn gen_max<T: Iterator<Item = TokenStream2>>(mut iter: T) -> TokenStream2 {
100 if let Some(item) = iter.next() {
101 let next_item = gen_max(iter);
102 quote!(rialo_sol_lang::__private::max(#item, #next_item))
103 } else {
104 quote!(0)
105 }
106}
107
108fn len_from_type(ty: Type, attrs: &mut Option<VecDeque<TokenStream2>>) -> TokenStream2 {
109 match ty {
110 Type::Array(TypeArray { elem, len, .. }) => {
111 let array_len = len.to_token_stream();
112 let type_len = len_from_type(*elem, attrs);
113 quote!((#array_len * #type_len))
114 }
115 Type::Path(ty_path) => {
116 let path_segment = ty_path.path.segments.last().unwrap();
117 let ident = &path_segment.ident;
118 let type_name = ident.to_string();
119 let first_ty = get_first_ty_arg(&path_segment.arguments);
120
121 match type_name.as_str() {
122 "i8" | "u8" | "bool" => quote!(1),
123 "i16" | "u16" => quote!(2),
124 "i32" | "u32" | "f32" => quote!(4),
125 "i64" | "u64" | "f64" => quote!(8),
126 "i128" | "u128" => quote!(16),
127 "String" => {
128 let max_len = get_next_arg(ident, attrs);
129 quote!((4 + #max_len))
130 }
131 "Pubkey" => quote!(32),
132 "Option" => {
133 if let Some(ty) = first_ty {
134 let type_len = len_from_type(ty, attrs);
135
136 quote!((1 + #type_len))
137 } else {
138 quote_spanned!(ident.span() => compile_error!("Invalid argument in Option"))
139 }
140 }
141 "Vec" => {
142 if let Some(ty) = first_ty {
143 let max_len = get_next_arg(ident, attrs);
144 let type_len = len_from_type(ty, attrs);
145
146 quote!((4 + #type_len * #max_len))
147 } else {
148 quote_spanned!(ident.span() => compile_error!("Invalid argument in Vec"))
149 }
150 }
151 _ => {
152 let ty = &ty_path.path;
153 quote!(<#ty as rialo_sol_lang::Space>::INIT_SPACE)
154 }
155 }
156 }
157 Type::Tuple(ty_tuple) => {
158 let recurse = ty_tuple
159 .elems
160 .iter()
161 .map(|t| len_from_type(t.clone(), attrs));
162 quote! {
163 (0 #(+ #recurse)*)
164 }
165 }
166 _ => panic!("Type {ty:?} is not supported"),
167 }
168}
169
170fn get_first_ty_arg(args: &PathArguments) -> Option<Type> {
171 match args {
172 PathArguments::AngleBracketed(bracket) => bracket.args.iter().find_map(|el| match el {
173 GenericArgument::Type(ty) => Some(ty.to_owned()),
174 _ => None,
175 }),
176 _ => None,
177 }
178}
179
180fn parse_len_arg(item: ParseStream<'_>) -> Result<VecDeque<TokenStream2>, syn::Error> {
181 let mut result = VecDeque::new();
182 while let Some(token_tree) = item.parse()? {
183 match token_tree {
184 TokenTree::Ident(ident) => result.push_front(quote!((#ident as usize))),
185 TokenTree::Literal(lit) => {
186 if let Ok(lit_int) = parse2::<LitInt>(lit.into_token_stream()) {
187 result.push_front(quote!(#lit_int))
188 }
189 }
190 _ => (),
191 }
192 }
193
194 Ok(result)
195}
196
197fn get_max_len_args(attributes: &[Attribute]) -> Option<VecDeque<TokenStream2>> {
198 attributes
199 .iter()
200 .find(|a| a.path().is_ident("max_len"))
201 .and_then(|a| a.parse_args_with(parse_len_arg).ok())
202}
203
204fn get_next_arg(ident: &Ident, args: &mut Option<VecDeque<TokenStream2>>) -> TokenStream2 {
205 if let Some(arg_list) = args {
206 if let Some(arg) = arg_list.pop_back() {
207 quote!(#arg)
208 } else {
209 quote_spanned!(ident.span() => compile_error!("The number of lengths are invalid."))
210 }
211 } else {
212 quote_spanned!(ident.span() => compile_error!("Expected max_len attribute."))
213 }
214}