Skip to main content

rialo_sol_derive_space/
lib.rs

// Copyright (c) Subzero Labs, Inc.
// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::VecDeque;
5
6use proc_macro::TokenStream;
7use proc_macro2::{Ident, TokenStream as TokenStream2, TokenTree};
8use quote::{quote, quote_spanned, ToTokens};
9use syn::{
10    parse::ParseStream, parse2, parse_macro_input, punctuated::Punctuated, token::Comma, Attribute,
11    DeriveInput, Field, Fields, GenericArgument, LitInt, PathArguments, Type, TypeArray,
12};
13
14/// Implements a [`Space`](./trait.Space.html) trait on the given
15/// struct or enum.
16///
17/// For types that have a variable size like String and Vec, it is necessary to indicate the size by the `max_len` attribute.
18/// For nested types, it is necessary to specify a size for each variable type (see example).
19///
20/// # Example
21/// ```ignore
22/// #[account]
23/// #[derive(InitSpace)]
24/// pub struct ExampleAccount {
25///     pub data: u64,
26///     #[max_len(50)]
27///     pub string_one: String,
28///     #[max_len(10, 5)]
29///     pub nested: Vec<Vec<u8>>,
30/// }
31///
32/// #[derive(Accounts)]
33/// pub struct Initialize<'info> {
34///    #[account(mut)]
35///    pub payer: Signer<'info>,
36///    pub system_program: Program<'info, System>,
37///    #[account(init, payer = payer, space = 8 + ExampleAccount::INIT_SPACE)]
38///    pub data: Account<'info, ExampleAccount>,
39/// }
40/// ```
41#[proc_macro_derive(InitSpace, attributes(max_len))]
42pub fn derive_init_space(item: TokenStream) -> TokenStream {
43    let input = parse_macro_input!(item as DeriveInput);
44    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
45    let name = input.ident;
46
47    let process_struct_fields = |fields: Punctuated<Field, Comma>| {
48        let recurse = fields.into_iter().map(|f| {
49            let mut max_len_args = get_max_len_args(&f.attrs);
50            len_from_type(f.ty, &mut max_len_args)
51        });
52
53        quote! {
54            #[automatically_derived]
55            impl #impl_generics rialo_sol_lang::Space for #name #ty_generics #where_clause {
56                const INIT_SPACE: usize = 0 #(+ #recurse)*;
57            }
58        }
59    };
60
61    let expanded: TokenStream2 = match input.data {
62        syn::Data::Struct(strct) => match strct.fields {
63            Fields::Named(named) => process_struct_fields(named.named),
64            Fields::Unnamed(unnamed) => process_struct_fields(unnamed.unnamed),
65            Fields::Unit => quote! {
66                #[automatically_derived]
67                impl #impl_generics rialo_sol_lang::Space for #name #ty_generics #where_clause {
68                    const INIT_SPACE: usize = 0;
69                }
70            },
71        },
72        syn::Data::Enum(enm) => {
73            let variants = enm.variants.into_iter().map(|v| {
74                let len = v.fields.into_iter().map(|f| {
75                    let mut max_len_args = get_max_len_args(&f.attrs);
76                    len_from_type(f.ty, &mut max_len_args)
77                });
78
79                quote! {
80                    0 #(+ #len)*
81                }
82            });
83
84            let max = gen_max(variants);
85
86            quote! {
87                #[automatically_derived]
88                impl rialo_sol_lang::Space for #name {
89                    const INIT_SPACE: usize = 1 + #max;
90                }
91            }
92        }
93        _ => unimplemented!(),
94    };
95
96    TokenStream::from(expanded)
97}
98
99fn gen_max<T: Iterator<Item = TokenStream2>>(mut iter: T) -> TokenStream2 {
100    if let Some(item) = iter.next() {
101        let next_item = gen_max(iter);
102        quote!(rialo_sol_lang::__private::max(#item, #next_item))
103    } else {
104        quote!(0)
105    }
106}
107
108fn len_from_type(ty: Type, attrs: &mut Option<VecDeque<TokenStream2>>) -> TokenStream2 {
109    match ty {
110        Type::Array(TypeArray { elem, len, .. }) => {
111            let array_len = len.to_token_stream();
112            let type_len = len_from_type(*elem, attrs);
113            quote!((#array_len * #type_len))
114        }
115        Type::Path(ty_path) => {
116            let path_segment = ty_path.path.segments.last().unwrap();
117            let ident = &path_segment.ident;
118            let type_name = ident.to_string();
119            let first_ty = get_first_ty_arg(&path_segment.arguments);
120
121            match type_name.as_str() {
122                "i8" | "u8" | "bool" => quote!(1),
123                "i16" | "u16" => quote!(2),
124                "i32" | "u32" | "f32" => quote!(4),
125                "i64" | "u64" | "f64" => quote!(8),
126                "i128" | "u128" => quote!(16),
127                "String" => {
128                    let max_len = get_next_arg(ident, attrs);
129                    quote!((4 + #max_len))
130                }
131                "Pubkey" => quote!(32),
132                "Option" => {
133                    if let Some(ty) = first_ty {
134                        let type_len = len_from_type(ty, attrs);
135
136                        quote!((1 + #type_len))
137                    } else {
138                        quote_spanned!(ident.span() => compile_error!("Invalid argument in Option"))
139                    }
140                }
141                "Vec" => {
142                    if let Some(ty) = first_ty {
143                        let max_len = get_next_arg(ident, attrs);
144                        let type_len = len_from_type(ty, attrs);
145
146                        quote!((4 + #type_len * #max_len))
147                    } else {
148                        quote_spanned!(ident.span() => compile_error!("Invalid argument in Vec"))
149                    }
150                }
151                _ => {
152                    let ty = &ty_path.path;
153                    quote!(<#ty as rialo_sol_lang::Space>::INIT_SPACE)
154                }
155            }
156        }
157        Type::Tuple(ty_tuple) => {
158            let recurse = ty_tuple
159                .elems
160                .iter()
161                .map(|t| len_from_type(t.clone(), attrs));
162            quote! {
163                (0 #(+ #recurse)*)
164            }
165        }
166        _ => panic!("Type {ty:?} is not supported"),
167    }
168}
169
170fn get_first_ty_arg(args: &PathArguments) -> Option<Type> {
171    match args {
172        PathArguments::AngleBracketed(bracket) => bracket.args.iter().find_map(|el| match el {
173            GenericArgument::Type(ty) => Some(ty.to_owned()),
174            _ => None,
175        }),
176        _ => None,
177    }
178}
179
180fn parse_len_arg(item: ParseStream<'_>) -> Result<VecDeque<TokenStream2>, syn::Error> {
181    let mut result = VecDeque::new();
182    while let Some(token_tree) = item.parse()? {
183        match token_tree {
184            TokenTree::Ident(ident) => result.push_front(quote!((#ident as usize))),
185            TokenTree::Literal(lit) => {
186                if let Ok(lit_int) = parse2::<LitInt>(lit.into_token_stream()) {
187                    result.push_front(quote!(#lit_int))
188                }
189            }
190            _ => (),
191        }
192    }
193
194    Ok(result)
195}
196
197fn get_max_len_args(attributes: &[Attribute]) -> Option<VecDeque<TokenStream2>> {
198    attributes
199        .iter()
200        .find(|a| a.path().is_ident("max_len"))
201        .and_then(|a| a.parse_args_with(parse_len_arg).ok())
202}
203
204fn get_next_arg(ident: &Ident, args: &mut Option<VecDeque<TokenStream2>>) -> TokenStream2 {
205    if let Some(arg_list) = args {
206        if let Some(arg) = arg_list.pop_back() {
207            quote!(#arg)
208        } else {
209            quote_spanned!(ident.span() => compile_error!("The number of lengths are invalid."))
210        }
211    } else {
212        quote_spanned!(ident.span() => compile_error!("Expected max_len attribute."))
213    }
214}