magic_space_derive/
lib.rs

/*
 * Code taken from the @coral-xyz/anchor project,
 * modified to work without the anchor crate.
 *
 * Source:
 * https://github.com/coral-xyz/anchor/blob/master/lang/derive/space/src/lib.rs
 */
8use std::collections::VecDeque;
9
10use proc_macro::TokenStream;
11use proc_macro2::{Ident, TokenStream as TokenStream2, TokenTree};
12use quote::{quote, quote_spanned, ToTokens};
13use syn::{
14    parse::ParseStream, parse2, parse_macro_input, punctuated::Punctuated, token::Comma, Attribute,
15    DeriveInput, Field, Fields, GenericArgument, LitInt, PathArguments, Type, TypeArray,
16};
17
18/// Implements a `Space` trait for a struct or enum.
19///
20/// For types that have a variable size like String and Vec, it is necessary to indicate the size by the `max_len` attribute.
21/// For nested types, it is necessary to specify a size for each variable type (see example).
22///
23/// # Example
24/// ```ignore
25/// #[account]
26/// #[derive(MagicSpace)]
27/// pub struct ExampleAccount {
28///     pub data: u64,
29///     #[max_len(50)]
30///     pub string_one: String,
31///     #[max_len(10, 5)]
32///     pub nested: Vec<Vec<u8>>,
33/// }
34///
35/// #[derive(Accounts)]
36/// pub struct Initialize<'info> {
37///    #[account(mut)]
38///    pub payer: Signer<'info>,
39///    pub system_program: Program<'info, System>,
40///    #[account(init, payer = payer, space = 8 + ExampleAccount::MAGIC_SPACE)]
41///    pub data: Account<'info, ExampleAccount>,
42/// }
43/// ```
44#[proc_macro_derive(MagicSpace, attributes(max_len))]
45pub fn derive_init_space(item: TokenStream) -> TokenStream {
46    let input = parse_macro_input!(item as DeriveInput);
47    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
48    let name = input.ident;
49
50    let process_struct_fields = |fields: Punctuated<Field, Comma>| {
51        let recurse = fields.into_iter().map(|f| {
52            let mut max_len_args = get_max_len_args(&f.attrs);
53            len_from_type(f.ty, &mut max_len_args)
54        });
55
56        quote! {
57            #[automatically_derived]
58            impl #impl_generics magic_space::Space for #name #ty_generics #where_clause {
59                const MAGIC_SPACE: usize = 0 #(+ #recurse)*;
60            }
61        }
62    };
63
64    let expanded: TokenStream2 = match input.data {
65        syn::Data::Struct(strct) => match strct.fields {
66            Fields::Named(named) => process_struct_fields(named.named),
67            Fields::Unnamed(unnamed) => process_struct_fields(unnamed.unnamed),
68            Fields::Unit => quote! {
69                #[automatically_derived]
70                impl #impl_generics magic_space::Space for #name #ty_generics #where_clause {
71                    const MAGIC_SPACE: usize = 0;
72                }
73            },
74        },
75        syn::Data::Enum(enm) => {
76            let variants = enm.variants.into_iter().map(|v| {
77                let len = v.fields.into_iter().map(|f| {
78                    let mut max_len_args = get_max_len_args(&f.attrs);
79                    len_from_type(f.ty, &mut max_len_args)
80                });
81
82                quote! {
83                    0 #(+ #len)*
84                }
85            });
86
87            let max = gen_max(variants);
88
89            quote! {
90                #[automatically_derived]
91                impl magic_space::Space for #name {
92                    const MAGIC_SPACE: usize = 1 + #max;
93                }
94            }
95        }
96        _ => unimplemented!(),
97    };
98
99    TokenStream::from(expanded)
100}
101
102fn gen_max<T: Iterator<Item = TokenStream2>>(mut iter: T) -> TokenStream2 {
103    if let Some(item) = iter.next() {
104        let next_item = gen_max(iter);
105        quote!(magic_space::__private::max(#item, #next_item))
106    } else {
107        quote!(0)
108    }
109}
110
111fn len_from_type(ty: Type, attrs: &mut Option<VecDeque<TokenStream2>>) -> TokenStream2 {
112    match ty {
113        Type::Array(TypeArray { elem, len, .. }) => {
114            let array_len = len.to_token_stream();
115            let type_len = len_from_type(*elem, attrs);
116            quote!((#array_len * #type_len))
117        }
118        Type::Path(ty_path) => {
119            let path_segment = ty_path.path.segments.last().unwrap();
120            let ident = &path_segment.ident;
121            let type_name = ident.to_string();
122            let first_ty = get_first_ty_arg(&path_segment.arguments);
123
124            match type_name.as_str() {
125                "i8" | "u8" | "bool" => quote!(1),
126                "i16" | "u16" => quote!(2),
127                "i32" | "u32" | "f32" => quote!(4),
128                "i64" | "u64" | "f64" => quote!(8),
129                "i128" | "u128" => quote!(16),
130                "String" => {
131                    let max_len = get_next_arg(ident, attrs);
132                    quote!((4 + #max_len))
133                }
134                "Pubkey" => quote!(32),
135                "Option" => {
136                    if let Some(ty) = first_ty {
137                        let type_len = len_from_type(ty, attrs);
138
139                        quote!((1 + #type_len))
140                    } else {
141                        quote_spanned!(ident.span() => compile_error!("Invalid argument in Vec"))
142                    }
143                }
144                "Vec" => {
145                    if let Some(ty) = first_ty {
146                        let max_len = get_next_arg(ident, attrs);
147                        let type_len = len_from_type(ty, attrs);
148
149                        quote!((4 + #type_len * #max_len))
150                    } else {
151                        quote_spanned!(ident.span() => compile_error!("Invalid argument in Vec"))
152                    }
153                }
154                _ => {
155                    let ty = &ty_path.path;
156                    quote!(<#ty as magic_space::Space>::MAGIC_SPACE)
157                }
158            }
159        }
160        _ => panic!("Type {ty:?} is not supported"),
161    }
162}
163
164fn get_first_ty_arg(args: &PathArguments) -> Option<Type> {
165    match args {
166        PathArguments::AngleBracketed(bracket) => bracket.args.iter().find_map(|el| match el {
167            GenericArgument::Type(ty) => Some(ty.to_owned()),
168            _ => None,
169        }),
170        _ => None,
171    }
172}
173
174fn parse_len_arg(item: ParseStream) -> Result<VecDeque<TokenStream2>, syn::Error> {
175    let mut result = VecDeque::new();
176    while let Some(token_tree) = item.parse()? {
177        match token_tree {
178            TokenTree::Ident(ident) => result.push_front(quote!((#ident as usize))),
179            TokenTree::Literal(lit) => {
180                if let Ok(lit_int) = parse2::<LitInt>(lit.into_token_stream()) {
181                    result.push_front(quote!(#lit_int))
182                }
183            }
184            _ => (),
185        }
186    }
187
188    Ok(result)
189}
190
191fn get_max_len_args(attributes: &[Attribute]) -> Option<VecDeque<TokenStream2>> {
192    attributes
193        .iter()
194        .find(|a| a.path().is_ident("max_len"))
195        .and_then(|a| a.parse_args_with(parse_len_arg).ok())
196}
197
198fn get_next_arg(ident: &Ident, args: &mut Option<VecDeque<TokenStream2>>) -> TokenStream2 {
199    if let Some(arg_list) = args {
200        if let Some(arg) = arg_list.pop_back() {
201            quote!(#arg)
202        } else {
203            quote_spanned!(ident.span() => compile_error!("The number of lengths are invalid."))
204        }
205    } else {
206        quote_spanned!(ident.span() => compile_error!("Expected max_len attribute."))
207    }
208}