// anchor_derive_space/lib.rs

use std::collections::VecDeque;

use proc_macro::TokenStream;
use proc_macro2::{Ident, TokenStream as TokenStream2, TokenTree};
use quote::{quote, quote_spanned, ToTokens};
use syn::{
    parse::ParseStream, parse2, parse_macro_input, punctuated::Punctuated, token::Comma, Attribute,
    DeriveInput, Field, Fields, GenericArgument, LitInt, PathArguments, Type, TypeArray,
};
10
11/// Implements a [`Space`](./trait.Space.html) trait on the given
12/// struct or enum.
13///
14/// For types that have a variable size like String and Vec, it is necessary to indicate the size by the `max_len` attribute.
15/// For nested types, it is necessary to specify a size for each variable type (see example).
16///
17/// # Example
18/// ```ignore
19/// #[account]
20/// #[derive(InitSpace)]
21/// pub struct ExampleAccount {
22///     pub data: u64,
23///     #[max_len(50)]
24///     pub string_one: String,
25///     #[max_len(10, 5)]
26///     pub nested: Vec<Vec<u8>>,
27/// }
28///
29/// #[derive(Accounts)]
30/// pub struct Initialize<'info> {
31///    #[account(mut)]
32///    pub payer: Signer<'info>,
33///    pub system_program: Program<'info, System>,
34///    #[account(init, payer = payer, space = 8 + ExampleAccount::INIT_SPACE)]
35///    pub data: Account<'info, ExampleAccount>,
36/// }
37/// ```
38#[proc_macro_derive(InitSpace, attributes(max_len))]
39pub fn derive_init_space(item: TokenStream) -> TokenStream {
40    let input = parse_macro_input!(item as DeriveInput);
41    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
42    let name = input.ident;
43
44    let process_struct_fields = |fields: Punctuated<Field, Comma>| {
45        let recurse = fields.into_iter().map(|f| {
46            let mut max_len_args = get_max_len_args(&f.attrs);
47            len_from_type(f.ty, &mut max_len_args)
48        });
49
50        quote! {
51            #[automatically_derived]
52            impl #impl_generics anchor_lang::Space for #name #ty_generics #where_clause {
53                const INIT_SPACE: usize = 0 #(+ #recurse)*;
54            }
55        }
56    };
57
58    let expanded: TokenStream2 = match input.data {
59        syn::Data::Struct(strct) => match strct.fields {
60            Fields::Named(named) => process_struct_fields(named.named),
61            Fields::Unnamed(unnamed) => process_struct_fields(unnamed.unnamed),
62            Fields::Unit => quote! {
63                #[automatically_derived]
64                impl #impl_generics anchor_lang::Space for #name #ty_generics #where_clause {
65                    const INIT_SPACE: usize = 0;
66                }
67            },
68        },
69        syn::Data::Enum(enm) => {
70            let variants = enm.variants.into_iter().map(|v| {
71                let len = v.fields.into_iter().map(|f| {
72                    let mut max_len_args = get_max_len_args(&f.attrs);
73                    len_from_type(f.ty, &mut max_len_args)
74                });
75
76                quote! {
77                    0 #(+ #len)*
78                }
79            });
80
81            let max = gen_max(variants);
82
83            quote! {
84                #[automatically_derived]
85                impl anchor_lang::Space for #name {
86                    const INIT_SPACE: usize = 1 + #max;
87                }
88            }
89        }
90        _ => unimplemented!(),
91    };
92
93    TokenStream::from(expanded)
94}
95
96fn gen_max<T: Iterator<Item = TokenStream2>>(mut iter: T) -> TokenStream2 {
97    if let Some(item) = iter.next() {
98        let next_item = gen_max(iter);
99        quote!(anchor_lang::__private::max(#item, #next_item))
100    } else {
101        quote!(0)
102    }
103}
104
105fn len_from_type(ty: Type, attrs: &mut Option<VecDeque<TokenStream2>>) -> TokenStream2 {
106    match ty {
107        Type::Array(TypeArray { elem, len, .. }) => {
108            let array_len = len.to_token_stream();
109            let type_len = len_from_type(*elem, attrs);
110            quote!((#array_len * #type_len))
111        }
112        Type::Path(ty_path) => {
113            let path_segment = ty_path.path.segments.last().unwrap();
114            let ident = &path_segment.ident;
115            let type_name = ident.to_string();
116            let first_ty = get_first_ty_arg(&path_segment.arguments);
117
118            match type_name.as_str() {
119                "i8" | "u8" | "bool" => quote!(1),
120                "i16" | "u16" => quote!(2),
121                "i32" | "u32" | "f32" => quote!(4),
122                "i64" | "u64" | "f64" => quote!(8),
123                "i128" | "u128" => quote!(16),
124                "String" => {
125                    let max_len = get_next_arg(ident, attrs);
126                    quote!((4 + #max_len))
127                }
128                "Pubkey" => quote!(32),
129                "Option" => {
130                    if let Some(ty) = first_ty {
131                        let type_len = len_from_type(ty, attrs);
132
133                        quote!((1 + #type_len))
134                    } else {
135                        quote_spanned!(ident.span() => compile_error!("Invalid argument in Option"))
136                    }
137                }
138                "Vec" => {
139                    if let Some(ty) = first_ty {
140                        let max_len = get_next_arg(ident, attrs);
141                        let type_len = len_from_type(ty, attrs);
142
143                        quote!((4 + #type_len * #max_len))
144                    } else {
145                        quote_spanned!(ident.span() => compile_error!("Invalid argument in Vec"))
146                    }
147                }
148                _ => {
149                    let ty = &ty_path.path;
150                    quote!(<#ty as anchor_lang::Space>::INIT_SPACE)
151                }
152            }
153        }
154        Type::Tuple(ty_tuple) => {
155            let recurse = ty_tuple
156                .elems
157                .iter()
158                .map(|t| len_from_type(t.clone(), attrs));
159            quote! {
160                (0 #(+ #recurse)*)
161            }
162        }
163        _ => panic!("Type {ty:?} is not supported"),
164    }
165}
166
167fn get_first_ty_arg(args: &PathArguments) -> Option<Type> {
168    match args {
169        PathArguments::AngleBracketed(bracket) => bracket.args.iter().find_map(|el| match el {
170            GenericArgument::Type(ty) => Some(ty.to_owned()),
171            _ => None,
172        }),
173        _ => None,
174    }
175}
176
177fn parse_len_arg(item: ParseStream) -> Result<VecDeque<TokenStream2>, syn::Error> {
178    let mut result = VecDeque::new();
179    while let Some(token_tree) = item.parse()? {
180        match token_tree {
181            TokenTree::Ident(ident) => result.push_front(quote!((#ident as usize))),
182            TokenTree::Literal(lit) => {
183                if let Ok(lit_int) = parse2::<LitInt>(lit.into_token_stream()) {
184                    result.push_front(quote!(#lit_int))
185                }
186            }
187            _ => (),
188        }
189    }
190
191    Ok(result)
192}
193
194fn get_max_len_args(attributes: &[Attribute]) -> Option<VecDeque<TokenStream2>> {
195    attributes
196        .iter()
197        .find(|a| a.path.is_ident("max_len"))
198        .and_then(|a| a.parse_args_with(parse_len_arg).ok())
199}
200
201fn get_next_arg(ident: &Ident, args: &mut Option<VecDeque<TokenStream2>>) -> TokenStream2 {
202    if let Some(arg_list) = args {
203        if let Some(arg) = arg_list.pop_back() {
204            quote!(#arg)
205        } else {
206            quote_spanned!(ident.span() => compile_error!("The number of lengths are invalid."))
207        }
208    } else {
209        quote_spanned!(ident.span() => compile_error!("Expected max_len attribute."))
210    }
211}