bolt_anchor_derive_space/
lib.rs

1use std::collections::VecDeque;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, TokenStream as TokenStream2, TokenTree};
5use quote::{quote, quote_spanned, ToTokens};
6use syn::{
7    parse::ParseStream, parse2, parse_macro_input, Attribute, DeriveInput, Fields, GenericArgument,
8    LitInt, PathArguments, Type, TypeArray,
9};
10
11/// Implements a [`Space`](./trait.Space.html) trait on the given
12/// struct or enum.
13///
14/// For types that have a variable size like String and Vec, it is necessary to indicate the size by the `max_len` attribute.
15/// For nested types, it is necessary to specify a size for each variable type (see example).
16///
17/// # Example
18/// ```ignore
19/// #[account]
20/// #[derive(InitSpace)]
21/// pub struct ExampleAccount {
22///     pub data: u64,
23///     #[max_len(50)]
24///     pub string_one: String,
25///     #[max_len(10, 5)]
26///     pub nested: Vec<Vec<u8>>,
27/// }
28///
29/// #[derive(Accounts)]
30/// pub struct Initialize<'info> {
31///    #[account(mut)]
32///    pub payer: Signer<'info>,
33///    pub system_program: Program<'info, System>,
34///    #[account(init, payer = payer, space = 8 + ExampleAccount::INIT_SPACE)]
35///    pub data: Account<'info, ExampleAccount>,
36/// }
37/// ```
38#[proc_macro_derive(InitSpace, attributes(max_len))]
39pub fn derive_init_space(item: TokenStream) -> TokenStream {
40    let input = parse_macro_input!(item as DeriveInput);
41    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
42    let name = input.ident;
43
44    let expanded: TokenStream2 = match input.data {
45        syn::Data::Struct(strct) => match strct.fields {
46            Fields::Named(named) => {
47                let recurse = named.named.into_iter().map(|f| {
48                    let mut max_len_args = get_max_len_args(&f.attrs);
49                    len_from_type(f.ty, &mut max_len_args)
50                });
51
52                quote! {
53                    #[automatically_derived]
54                    impl #impl_generics anchor_lang::Space for #name #ty_generics #where_clause {
55                        const INIT_SPACE: usize = 0 #(+ #recurse)*;
56                    }
57                }
58            }
59            _ => panic!("Please use named fields in account structure"),
60        },
61        syn::Data::Enum(enm) => {
62            let variants = enm.variants.into_iter().map(|v| {
63                let len = v.fields.into_iter().map(|f| {
64                    let mut max_len_args = get_max_len_args(&f.attrs);
65                    len_from_type(f.ty, &mut max_len_args)
66                });
67
68                quote! {
69                    0 #(+ #len)*
70                }
71            });
72
73            let max = gen_max(variants);
74
75            quote! {
76                #[automatically_derived]
77                impl anchor_lang::Space for #name {
78                    const INIT_SPACE: usize = 1 + #max;
79                }
80            }
81        }
82        _ => unimplemented!(),
83    };
84
85    TokenStream::from(expanded)
86}
87
88fn gen_max<T: Iterator<Item = TokenStream2>>(mut iter: T) -> TokenStream2 {
89    if let Some(item) = iter.next() {
90        let next_item = gen_max(iter);
91        quote!(anchor_lang::__private::max(#item, #next_item))
92    } else {
93        quote!(0)
94    }
95}
96
97fn len_from_type(ty: Type, attrs: &mut Option<VecDeque<TokenStream2>>) -> TokenStream2 {
98    match ty {
99        Type::Array(TypeArray { elem, len, .. }) => {
100            let array_len = len.to_token_stream();
101            let type_len = len_from_type(*elem, attrs);
102            quote!((#array_len * #type_len))
103        }
104        Type::Path(ty_path) => {
105            let path_segment = ty_path.path.segments.last().unwrap();
106            let ident = &path_segment.ident;
107            let type_name = ident.to_string();
108            let first_ty = get_first_ty_arg(&path_segment.arguments);
109
110            match type_name.as_str() {
111                "i8" | "u8" | "bool" => quote!(1),
112                "i16" | "u16" => quote!(2),
113                "i32" | "u32" | "f32" => quote!(4),
114                "i64" | "u64" | "f64" => quote!(8),
115                "i128" | "u128" => quote!(16),
116                "String" => {
117                    let max_len = get_next_arg(ident, attrs);
118                    quote!((4 + #max_len))
119                }
120                "Pubkey" => quote!(32),
121                "Option" => {
122                    if let Some(ty) = first_ty {
123                        let type_len = len_from_type(ty, attrs);
124
125                        quote!((1 + #type_len))
126                    } else {
127                        quote_spanned!(ident.span() => compile_error!("Invalid argument in Vec"))
128                    }
129                }
130                "Vec" => {
131                    if let Some(ty) = first_ty {
132                        let max_len = get_next_arg(ident, attrs);
133                        let type_len = len_from_type(ty, attrs);
134
135                        quote!((4 + #type_len * #max_len))
136                    } else {
137                        quote_spanned!(ident.span() => compile_error!("Invalid argument in Vec"))
138                    }
139                }
140                _ => {
141                    let ty = &ty_path.path;
142                    quote!(<#ty as anchor_lang::Space>::INIT_SPACE)
143                }
144            }
145        }
146        _ => panic!("Type {ty:?} is not supported"),
147    }
148}
149
150fn get_first_ty_arg(args: &PathArguments) -> Option<Type> {
151    match args {
152        PathArguments::AngleBracketed(bracket) => bracket.args.iter().find_map(|el| match el {
153            GenericArgument::Type(ty) => Some(ty.to_owned()),
154            _ => None,
155        }),
156        _ => None,
157    }
158}
159
160fn parse_len_arg(item: ParseStream) -> Result<VecDeque<TokenStream2>, syn::Error> {
161    let mut result = VecDeque::new();
162    while let Some(token_tree) = item.parse()? {
163        match token_tree {
164            TokenTree::Ident(ident) => result.push_front(quote!((#ident as usize))),
165            TokenTree::Literal(lit) => {
166                if let Ok(lit_int) = parse2::<LitInt>(lit.into_token_stream()) {
167                    result.push_front(quote!(#lit_int))
168                }
169            }
170            _ => (),
171        }
172    }
173
174    Ok(result)
175}
176
177fn get_max_len_args(attributes: &[Attribute]) -> Option<VecDeque<TokenStream2>> {
178    attributes
179        .iter()
180        .find(|a| a.path.is_ident("max_len"))
181        .and_then(|a| a.parse_args_with(parse_len_arg).ok())
182}
183
184fn get_next_arg(ident: &Ident, args: &mut Option<VecDeque<TokenStream2>>) -> TokenStream2 {
185    if let Some(arg_list) = args {
186        if let Some(arg) = arg_list.pop_back() {
187            quote!(#arg)
188        } else {
189            quote_spanned!(ident.span() => compile_error!("The number of lengths are invalid."))
190        }
191    } else {
192        quote_spanned!(ident.span() => compile_error!("Expected max_len attribute."))
193    }
194}