mainstay_derive_space/
lib.rs

1use std::collections::VecDeque;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, TokenStream as TokenStream2, TokenTree};
5use quote::{quote, quote_spanned, ToTokens};
6use syn::{
7    parse::ParseStream, parse2, parse_macro_input, punctuated::Punctuated, token::Comma, Attribute,
8    DeriveInput, Field, Fields, GenericArgument, LitInt, PathArguments, Type, TypeArray,
9};
10
11/// Implements a [`Space`](./trait.Space.html) trait on the given
12/// struct or enum.
13///
14/// For types that have a variable size like String and Vec, it is necessary to indicate the size by the `max_len` attribute.
15/// For nested types, it is necessary to specify a size for each variable type (see example).
16///
17/// # Example
18/// ```ignore
19/// #[account]
20/// #[derive(InitSpace)]
21/// pub struct ExampleAccount {
22///     pub data: u64,
23///     #[max_len(50)]
24///     pub string_one: String,
25///     #[max_len(10, 5)]
26///     pub nested: Vec<Vec<u8>>,
27/// }
28///
29/// #[derive(Accounts)]
30/// pub struct Initialize<'info> {
31///    #[account(mut)]
32///    pub payer: Signer<'info>,
33///    pub system_program: Program<'info, System>,
34///    #[account(init, payer = payer, space = 8 + ExampleAccount::INIT_SPACE)]
35///    pub data: Account<'info, ExampleAccount>,
36/// }
37/// ```
38#[proc_macro_derive(InitSpace, attributes(max_len))]
39pub fn derive_init_space(item: TokenStream) -> TokenStream {
40    let input = parse_macro_input!(item as DeriveInput);
41    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
42    let name = input.ident;
43
44    let process_struct_fields = |fields: Punctuated<Field, Comma>| {
45        let recurse = fields.into_iter().map(|f| {
46            let mut max_len_args = get_max_len_args(&f.attrs);
47            len_from_type(f.ty, &mut max_len_args)
48        });
49
50        quote! {
51            #[automatically_derived]
52            impl #impl_generics mainstay_lang::Space for #name #ty_generics #where_clause {
53                const INIT_SPACE: usize = 0 #(+ #recurse)*;
54            }
55        }
56    };
57
58    let expanded: TokenStream2 = match input.data {
59        syn::Data::Struct(strct) => match strct.fields {
60            Fields::Named(named) => process_struct_fields(named.named),
61            Fields::Unnamed(unnamed) => process_struct_fields(unnamed.unnamed),
62            Fields::Unit => quote! {
63                #[automatically_derived]
64                impl #impl_generics mainstay_lang::Space for #name #ty_generics #where_clause {
65                    const INIT_SPACE: usize = 0;
66                }
67            },
68        },
69        syn::Data::Enum(enm) => {
70            let variants = enm.variants.into_iter().map(|v| {
71                let len = v.fields.into_iter().map(|f| {
72                    let mut max_len_args = get_max_len_args(&f.attrs);
73                    len_from_type(f.ty, &mut max_len_args)
74                });
75
76                quote! {
77                    0 #(+ #len)*
78                }
79            });
80
81            let max = gen_max(variants);
82
83            quote! {
84                #[automatically_derived]
85                impl mainstay_lang::Space for #name {
86                    const INIT_SPACE: usize = 1 + #max;
87                }
88            }
89        }
90        _ => unimplemented!(),
91    };
92
93    TokenStream::from(expanded)
94}
95
96fn gen_max<T: Iterator<Item = TokenStream2>>(mut iter: T) -> TokenStream2 {
97    if let Some(item) = iter.next() {
98        let next_item = gen_max(iter);
99        quote!(mainstay_lang::__private::max(#item, #next_item))
100    } else {
101        quote!(0)
102    }
103}
104
105fn len_from_type(ty: Type, attrs: &mut Option<VecDeque<TokenStream2>>) -> TokenStream2 {
106    match ty {
107        Type::Array(TypeArray { elem, len, .. }) => {
108            let array_len = len.to_token_stream();
109            let type_len = len_from_type(*elem, attrs);
110            quote!((#array_len * #type_len))
111        }
112        Type::Path(ty_path) => {
113            let path_segment = ty_path.path.segments.last().unwrap();
114            let ident = &path_segment.ident;
115            let type_name = ident.to_string();
116            let first_ty = get_first_ty_arg(&path_segment.arguments);
117
118            match type_name.as_str() {
119                "i8" | "u8" | "bool" => quote!(1),
120                "i16" | "u16" => quote!(2),
121                "i32" | "u32" | "f32" => quote!(4),
122                "i64" | "u64" | "f64" => quote!(8),
123                "i128" | "u128" => quote!(16),
124                "String" => {
125                    let max_len = get_next_arg(ident, attrs);
126                    quote!((4 + #max_len))
127                }
128                "Pubkey" => quote!(32),
129                "Option" => {
130                    if let Some(ty) = first_ty {
131                        let type_len = len_from_type(ty, attrs);
132
133                        quote!((1 + #type_len))
134                    } else {
135                        quote_spanned!(ident.span() => compile_error!("Invalid argument in Vec"))
136                    }
137                }
138                "Vec" => {
139                    if let Some(ty) = first_ty {
140                        let max_len = get_next_arg(ident, attrs);
141                        let type_len = len_from_type(ty, attrs);
142
143                        quote!((4 + #type_len * #max_len))
144                    } else {
145                        quote_spanned!(ident.span() => compile_error!("Invalid argument in Vec"))
146                    }
147                }
148                _ => {
149                    let ty = &ty_path.path;
150                    quote!(<#ty as mainstay_lang::Space>::INIT_SPACE)
151                }
152            }
153        }
154        _ => panic!("Type {ty:?} is not supported"),
155    }
156}
157
158fn get_first_ty_arg(args: &PathArguments) -> Option<Type> {
159    match args {
160        PathArguments::AngleBracketed(bracket) => bracket.args.iter().find_map(|el| match el {
161            GenericArgument::Type(ty) => Some(ty.to_owned()),
162            _ => None,
163        }),
164        _ => None,
165    }
166}
167
168fn parse_len_arg(item: ParseStream) -> Result<VecDeque<TokenStream2>, syn::Error> {
169    let mut result = VecDeque::new();
170    while let Some(token_tree) = item.parse()? {
171        match token_tree {
172            TokenTree::Ident(ident) => result.push_front(quote!((#ident as usize))),
173            TokenTree::Literal(lit) => {
174                if let Ok(lit_int) = parse2::<LitInt>(lit.into_token_stream()) {
175                    result.push_front(quote!(#lit_int))
176                }
177            }
178            _ => (),
179        }
180    }
181
182    Ok(result)
183}
184
185fn get_max_len_args(attributes: &[Attribute]) -> Option<VecDeque<TokenStream2>> {
186    attributes
187        .iter()
188        .find(|a| a.path.is_ident("max_len"))
189        .and_then(|a| a.parse_args_with(parse_len_arg).ok())
190}
191
192fn get_next_arg(ident: &Ident, args: &mut Option<VecDeque<TokenStream2>>) -> TokenStream2 {
193    if let Some(arg_list) = args {
194        if let Some(arg) = arg_list.pop_back() {
195            quote!(#arg)
196        } else {
197            quote_spanned!(ident.span() => compile_error!("The number of lengths are invalid."))
198        }
199    } else {
200        quote_spanned!(ident.span() => compile_error!("Expected max_len attribute."))
201    }
202}