init_space/
lib.rs

//! # init_space_derive
//!
//! A procedural macro that calculates how many bytes are needed to store a struct
//! in Solana programs (or any other Borsh-serialized data).
//!
//! - Fixed-size types such as `u8`, `u64` and `bool` contribute their fixed size.
//! - Dynamic types (`String`, `Vec<u8>`) reserve room for their contents only when
//!   annotated with `#[max_len = N]`; without it, just their 4-byte length prefix is counted.
//! - Any other type whose size is not known to the macro also falls back to 4 bytes.
//!
//! ## Example
//!
//! ```rust
//! use init_space_derive::InitSpace;
//! use borsh::{BorshSerialize, BorshDeserialize};
//!
//! #[derive(BorshSerialize, BorshDeserialize, InitSpace)]
//! pub struct Movie {
//!     pub title: u64,
//!     pub rating: u8,
//!     #[max_len = 1000]
//!     pub description: Vec<u8>,
//! }
//!
//! let space_needed = Movie::space();
//! // 8 (u64) + 1 (u8) for the fixed fields, 4 (length prefix) + 1000 for the Vec.
//! assert_eq!(space_needed, 9 + 1004);
//! ```
26
27
28
29
30use proc_macro::TokenStream;
31use quote::quote;
32use syn::{
33    parse_macro_input, DeriveInput, Data, Fields, Type, Attribute, Meta,
34    Expr, ExprLit, Lit,
35};
36
37#[proc_macro_derive(InitSpace, attributes(max_len))]
38pub fn derive_init_space(input: TokenStream) -> TokenStream {
39    let input = parse_macro_input!(input as DeriveInput);
40    let name = input.ident;
41
42    let mut total_size = quote! { 0 };
43
44    if let Data::Struct(data_struct) = input.data {
45        if let Fields::Named(fields_named) = data_struct.fields {
46            for field in &fields_named.named {
47                let ty = &field.ty;
48                let attrs = &field.attrs;
49
50                // If #[max_len = N] is provided, handle dynamic size
51                if let Some(max_len)= get_max_len(attrs) {
52                       // Only allow #[max_len = N] on Vec or String
53                    if is_dynamic_type(ty) {
54                      total_size = quote! { #total_size + 4 + #max_len };
55                        continue;
56                    } else {
57                      panic!("`#[max_len = N]` is only allowed on Vec or String types");
58                  }
59                }
60
61                // Otherwise, use fixed size or fallback to 4 for unknown
62                let field_size = if let Some(size) = get_fixed_size(ty) {
63                    quote! { #size }
64                } else {
65                    quote! { 4 }
66                };
67
68                total_size = quote! { #total_size + #field_size };
69            }
70        }
71    }
72
73    let expanded = quote! {
74        impl #name {
75            pub fn space() -> usize {
76                #total_size
77            }
78        }
79    };
80
81    TokenStream::from(expanded)
82}
83
84// Known fixed-size types
85fn get_fixed_size(ty: &Type) -> Option<usize> {
86    if let Type::Path(typepath) = ty {
87        //gets the last segment of the type path
88        let ident = &typepath.path.segments.last()?.ident;
89        match ident.to_string().as_str() {
90            "u8" => Some(1),
91            "u16" => Some(2),
92            "u32" => Some(4),
93            "u64" => Some(8),
94            "u128" => Some(16),
95            "i8" => Some(1),
96            "i16" => Some(2),
97            "i32" => Some(4),
98            "i64" => Some(8),
99            "i128" => Some(16),
100            "bool" => Some(1),
101            "Pubkey" => Some(32),
102            _ => None,
103        }
104    } else {
105        None
106    }
107}
108
109// Parses #[max_len = 1000]
110fn get_max_len(attrs: &[Attribute]) -> Option<usize> {
111    for attr in attrs {
112        if attr.path().is_ident("max_len") {
113            if let Ok(meta) = attr.meta.clone().require_name_value() {
114                if let Expr::Lit(ExprLit {
115                    lit: Lit::Int(lit_int),
116                    ..
117                }) = &meta.value
118                //here the meta.value is passed as arg to be used for lint_int
119                {
120                    return lit_int.base10_parse::<usize>().ok();
121                }
122            }
123        }
124    }
125    None
126}
127
128fn is_dynamic_type(ty: &Type) -> bool {
129    if let Type::Path(typepath) = ty {
130        //return the last segment of the the type
131        if let Some(segment) = typepath.path.segments.last() {
132            let ident = segment.ident.to_string();
133            return ident == "Vec" || ident == "String";
134        }
135    }
136    false
137}