Skip to main content

pinocchio_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::quote;
4use syn::{
5    parse::Parse, parse::ParseStream, parse_macro_input, punctuated::Punctuated, Data, DeriveInput,
6    Field, Ident, Meta, Token,
7};
8
9/// Generates a trait implementation for `DataLen`:
10///
11/// ```rust
12/// impl DataLen for MyStruct {
13///     pub const LEN: usize = core::mem::size_of::<MyStruct>();
14/// }
15/// ```
16#[proc_macro_derive(DataLen)]
17pub fn derive_data_len(input: TokenStream) -> TokenStream {
18    let input = parse_macro_input!(input as DeriveInput);
19    let name = input.ident;
20
21    let expanded = quote! {
22        impl pinocchio_util::DataLen for #name {
23            const LEN: usize = core::mem::size_of::<#name>();
24        }
25    };
26
27    TokenStream::from(expanded)
28}
29
30/// Generates an update enum and trait implementation for `AccountUpdates`:
31///
32/// ```rust
33/// pub enum MyStructUpdate {
34///     SetField1(u32),
35///     SetField2(u32),
36/// }
37///
38/// impl AccountUpdates for MyStruct {
39///     type Update = MyStructUpdate;
40///     fn updates(&mut self, updates: Self::Update) {
41///         match updates {
42///             MyStructUpdate::SetField1(value) => self.field1 = value,
43///             MyStructUpdate::SetField2(value) => self.field2 = value,
44///         }
45///     }
46/// }
47/// ```
48#[proc_macro_derive(Updates)]
49pub fn derive_updates(input: TokenStream) -> TokenStream {
50    let input = parse_macro_input!(input as DeriveInput);
51    let name = input.ident;
52    let update_enum_name = Ident::new(&format!("{}Update", name), name.span());
53
54    let fields = match input.data {
55        Data::Struct(data) => data.fields,
56        _ => panic!("Updates derive macro only supports structs"),
57    };
58
59    let field_variants: Vec<_> = fields
60        .iter()
61        .enumerate()
62        .map(|(_i, field)| {
63            let field_name = field.ident.as_ref().unwrap();
64            let _field_type = &field.ty;
65            let variant_name = Ident::new(
66                &format!(
67                    "Set{}",
68                    field_name
69                        .to_string()
70                        .chars()
71                        .next()
72                        .unwrap()
73                        .to_uppercase()
74                        .chain(field_name.to_string().chars().skip(1))
75                        .collect::<String>()
76                ),
77                field_name.span(),
78            );
79
80            quote! {
81                #variant_name(#_field_type)
82            }
83        })
84        .collect();
85
86    let match_arms: Vec<_> = fields
87        .iter()
88        .enumerate()
89        .map(|(_i, field)| {
90            let field_name = field.ident.as_ref().unwrap();
91            let _field_type = &field.ty;
92            let variant_name = Ident::new(
93                &format!(
94                    "Set{}",
95                    field_name
96                        .to_string()
97                        .chars()
98                        .next()
99                        .unwrap()
100                        .to_uppercase()
101                        .chain(field_name.to_string().chars().skip(1))
102                        .collect::<String>()
103                ),
104                field_name.span(),
105            );
106
107            quote! {
108                #update_enum_name::#variant_name(value) => self.#field_name = value,
109            }
110        })
111        .collect();
112
113    let expanded = quote! {
114        pub enum #update_enum_name {
115            #(#field_variants),*
116        }
117
118        impl pinocchio_util::AccountUpdates for #name {
119            type Update = #update_enum_name;
120
121            fn updates(&mut self, updates: Self::Update) -> Result<(), pinocchio::program_error::ProgramError> {
122                match updates {
123                    #(#match_arms)*
124                }
125                Ok(())
126            }
127        }
128    };
129
130    TokenStream::from(expanded)
131}
132
133struct ValidationAttr {
134    non_empty: bool,
135    is_signer: bool,
136    is_executable: bool,
137    len: Option<usize>,
138    id: Option<syn::Expr>,
139}
140
141impl Parse for ValidationAttr {
142    fn parse(input: ParseStream) -> syn::Result<Self> {
143        let mut non_empty = false;
144        let mut len = None;
145        let mut id = None;
146        let mut is_signer = false;
147        let mut is_executable = false;
148
149        let args = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
150
151        for arg in args {
152            match arg {
153                Meta::Path(path) => {
154                    if path.is_ident("non_empty") {
155                        non_empty = true;
156                    }
157                }
158                Meta::NameValue(name_value) => {
159                    if name_value.path.is_ident("len") {
160                        if let syn::Expr::Lit(syn::ExprLit {
161                            lit: syn::Lit::Int(lit_int),
162                            ..
163                        }) = &name_value.value
164                        {
165                            len = Some(lit_int.base10_parse()?);
166                        }
167                    } else if name_value.path.is_ident("id") {
168                        id = Some(name_value.value);
169                    }
170                }
171                _ => {}
172            }
173        }
174
175        Ok(ValidationAttr {
176            non_empty,
177            len,
178            id,
179            is_signer,
180            is_executable,
181        })
182    }
183}
184
185/// Generates an implementation for `Validate`:
186///
187/// ```rust
188/// pub trait Validate {
189///     fn validate(&self) -> Result<(), ProgramError>;
190/// }
191///
192/// impl Validate for MyStruct {
193///     fn validate(&self) -> Result<(), ProgramError> {
194///         // Validations here
195///         Ok(())
196///     }
197/// }
198/// ```
199///
200/// Example usage:
201///
202/// ```rust
203/// #[derive(Validate)]
204/// struct MyStruct {
205///     // Data length is non-zero, `field_1.key()` is the SYSTEM_PROGRAM_ID (Pubkey)
206///     #[validate(non_empty, id = SYSTEM_PROGRAM_ID)]
207///     field_1: &'a AccountInfo,
208///
209///     // Data length is 64, `field_2.key()` is the SOME_ID (Pubkey)
210///     #[validate(len = 64, id = SOME_ID)]
211///     field_2: &'a AccountInfo,
212/// }
213/// ```
214#[proc_macro_derive(Validate, attributes(validate))]
215pub fn derive_validate(input: TokenStream) -> TokenStream {
216    let input = parse_macro_input!(input as DeriveInput);
217    let name = input.ident;
218
219    let fields = match input.data {
220        Data::Struct(data) => data.fields,
221        _ => panic!("This macro only supports structs"),
222    };
223
224    let validation_checks: Vec<_> = fields
225        .iter()
226        .enumerate()
227        .map(|(_i, field)| {
228            let field_name = field.ident.as_ref().unwrap();
229
230            let mut validation_attr = None;
231            for attr in &field.attrs {
232                if attr.path().is_ident("validate") {
233                    validation_attr = Some(attr.parse_args::<ValidationAttr>().unwrap());
234                    break;
235                }
236            }
237
238            if let Some(attr) = validation_attr {
239                let mut checks = Vec::new();
240
241                if attr.non_empty {
242                    checks.push(quote! {
243                        if self.#field_name.data_len() == 0 {
244                            return Err(pinocchio::program_error::ProgramError::InvalidAccountData);
245                        }
246                    });
247                }
248
249                if attr.is_signer {
250                    checks.push(quote! {
251                        if !self.#field_name.is_signer() {
252                            return Err(pinocchio::program_error::ProgramError::InvalidAccountData);
253                        }
254                    });
255                }
256
257                if attr.is_executable {
258                    checks.push(quote! {
259                        if !self.#field_name.is_executable() {
260                            return Err(pinocchio::program_error::ProgramError::InvalidAccountData);
261                        }
262                    });
263                }
264
265                if let Some(len) = attr.len {
266                    checks.push(quote! {
267                        if self.#field_name.data_len() != #len {
268                            return Err(pinocchio::program_error::ProgramError::InvalidAccountData);
269                        }
270                    });
271                }
272
273                if let Some(id) = attr.id {
274                    checks.push(quote! {
275                        if self.#field_name.key() != &#id {
276                            return Err(pinocchio::program_error::ProgramError::InvalidAccountData);
277                        }
278                    });
279                }
280
281                quote! {
282                    #(#checks)*
283                }
284            } else {
285                quote! {}
286            }
287        })
288        .collect();
289
290    let expanded = quote! {
291        impl<'info> pinocchio_util::Validate<'info> for #name<'info> {
292            fn validate(&self) -> Result<(), pinocchio::program_error::ProgramError> {
293                #(#validation_checks)*
294                Ok(())
295            }
296        }
297    };
298
299    TokenStream::from(expanded)
300}
301
302/// Generates an implementation for `Context`:
303///
304/// ```rust
305/// pub trait Context<'info> {
306///     const ACCOUNTS_LEN: usize;
307///     fn build(accounts: &'info [AccountInfo]) -> Result<Self, ProgramError>;
308/// }
309///
310/// impl<'info> Context<'info> for MyStruct<'info> {
311///     // # of fields in the struct
312///     const ACCOUNTS_LEN: usize = 1;
313///
314///     fn build(accounts: &'info [AccountInfo]) -> Result<Self, ProgramError> {
315///         let ctx = unsafe {
316///             Self {
317///                 field_1: &accounts.get_unchecked(0),
318///                 field_2: &accounts.get_unchecked(1),
319///             }
320///         }
321///
322///         Ok(ctx)
323///     }
324/// }
325/// ```
326#[proc_macro_derive(Context)]
327pub fn derive_context(input: TokenStream) -> TokenStream {
328    let input = parse_macro_input!(input as DeriveInput);
329    let name = &input.ident;
330
331    let lifetime_params: Vec<_> = input.generics.lifetimes().collect();
332
333    if lifetime_params.len() != 1 {
334        panic!("Context derive requires exactly one lifetime parameter");
335    }
336
337    let lifetime_param = &lifetime_params[0];
338    let lifetime = &lifetime_param.lifetime;
339
340    if lifetime.ident != "info" {
341        panic!("Context derive requires the lifetime parameter to be named 'info");
342    }
343
344    let fields = match input.data {
345        Data::Struct(ref data) => &data.fields,
346        _ => panic!("Context derive only works on structs"),
347    };
348
349    let accounts_len = fields.len();
350    let field_assignments: Vec<_> = fields
351        .iter()
352        .enumerate()
353        .map(|(i, field)| {
354            let field_name = field.ident.as_ref().unwrap();
355            quote! { #field_name: &accounts.get_unchecked(#i), }
356        })
357        .collect();
358
359    let expanded = quote! {
360        impl<'info> pinocchio_util::Context<'info> for #name<'info> {
361            const ACCOUNTS_LEN: usize = #accounts_len;
362
363            fn build(accounts: &'info [pinocchio::account_info::AccountInfo])
364                -> Result<Self, pinocchio::program_error::ProgramError>
365            {
366                if accounts.len() != Self::ACCOUNTS_LEN {
367                    return Err(pinocchio::program_error::ProgramError::InvalidAccountData);
368                }
369
370                Ok(unsafe {
371                    Self {
372                        #(#field_assignments)*
373                    }
374                })
375            }
376        }
377    };
378
379    TokenStream::from(expanded)
380}