// lpl_token_metadata_context_derive/lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use std::collections::HashMap;
4use syn::{
5    self, parse_macro_input, DeriveInput, Expr, ExprPath, GenericArgument, Lit, Meta, MetaList,
6    MetaNameValue, NestedMeta, Path, PathArguments, Type, TypePath,
7};
8
9#[derive(Default)]
10struct Variant {
11    pub name: String,
12    pub tuple: Option<String>,
13    pub accounts: Vec<Account>,
14    // (name, type, generic type)
15    pub args: Vec<(String, String, Option<String>)>,
16}
17
/// One account slot of a variant, declared as `#[account(name = "...", optional)]`.
#[derive(Debug)]
struct Account {
    /// The string given to the `name` property.
    pub name: String,
    /// Set when the bare `optional` flag is present in the attribute.
    pub optional: bool,
}
23
/// Variant helper attribute listing one instruction account
/// (the same attribute name shank uses for its annotation).
const ACCOUNT_ATTRIBUTE: &'static str = "account";
/// Variant helper attribute declaring an extra builder argument.
const ARGS_ATTRIBUTE: &'static str = "args";
/// `#[account]` property holding the account name (`name = "..."`).
const NAME_PROPERTY: &'static str = "name";
/// Bare `#[account]` flag marking the account as optional.
const OPTIONAL_PROPERTY: &'static str = "optional";
32
33#[proc_macro_derive(AccountContext, attributes(account, args))]
34pub fn account_context_derive(input: TokenStream) -> TokenStream {
35    let ast = parse_macro_input!(input as DeriveInput);
36
37    // identifies the accounts associated with each enum variant
38
39    let variants = if let syn::Data::Enum(syn::DataEnum { ref variants, .. }) = ast.data {
40        let mut enum_variants = Vec::new();
41
42        for v in variants {
43            // extract the enum data (if there is one present)
44            let mut variant = Variant {
45                tuple: if let syn::Fields::Unnamed(syn::FieldsUnnamed { unnamed, .. }) = &v.fields {
46                    match unnamed.first() {
47                        Some(syn::Field {
48                            ty:
49                                Type::Path(TypePath {
50                                    path: Path { segments, .. },
51                                    ..
52                                }),
53                            ..
54                        }) => Some(segments.first().unwrap().ident.to_string()),
55                        _ => None,
56                    }
57                } else {
58                    None
59                },
60                name: v.ident.to_string(),
61                ..Default::default()
62            };
63
64            // parse the attribute of the variant
65            for a in &v.attrs {
66                let syn::Attribute {
67                    path: syn::Path { segments, .. },
68                    ..
69                } = &a;
70                let mut skip = true;
71                let mut attribute = String::new();
72
73                for path in segments {
74                    let ident = path.ident.to_string();
75                    // we are only interested in #[account] and #[args] attributes
76                    if ident == ACCOUNT_ATTRIBUTE || ident == ARGS_ATTRIBUTE {
77                        attribute = ident;
78                        skip = false;
79                    }
80                }
81
82                if !skip {
83                    if attribute == ACCOUNT_ATTRIBUTE {
84                        let meta_tokens = a.parse_meta().unwrap();
85                        let nested_meta = if let Meta::List(MetaList { nested, .. }) = &meta_tokens
86                        {
87                            nested
88                        } else {
89                            panic!("#[account] requires attributes account name");
90                        };
91
92                        // (name, optional)
93                        let mut property: (Option<String>, Option<String>) = (None, None);
94
95                        for element in nested_meta {
96                            match element {
97                                // name = value (ignores any other attribute)
98                                NestedMeta::Meta(Meta::NameValue(MetaNameValue {
99                                    path,
100                                    lit,
101                                    ..
102                                })) => {
103                                    let ident = path.get_ident();
104                                    if let Some(ident) = ident {
105                                        if *ident == NAME_PROPERTY {
106                                            let token = match lit {
107                                                // removes the surrounding "'s from string values"
108                                                Lit::Str(lit) => {
109                                                    lit.token().to_string().replace('\"', "")
110                                                }
111                                                _ => panic!("Invalid value for property {ident}"),
112                                            };
113                                            property.0 = Some(token);
114                                        }
115                                    }
116                                }
117                                // optional
118                                NestedMeta::Meta(Meta::Path(path)) => {
119                                    let name = path.get_ident().map(|x| x.to_string());
120                                    if let Some(name) = name {
121                                        if name == OPTIONAL_PROPERTY {
122                                            property.1 = Some(name);
123                                        }
124                                    }
125                                }
126                                _ => {}
127                            }
128                        }
129                        variant.accounts.push(Account {
130                            name: property.0.unwrap(),
131                            optional: property.1.is_some(),
132                        });
133                    } else if attribute == ARGS_ATTRIBUTE {
134                        let args_tokens: syn::ExprType = a.parse_args().unwrap();
135                        // name
136                        let name = match *args_tokens.expr {
137                            Expr::Path(ExprPath {
138                                path: Path { segments, .. },
139                                ..
140                            }) => segments.first().unwrap().ident.to_string(),
141                            _ => panic!("#[args] requires an expression 'name: type'"),
142                        };
143                        // type
144                        match *args_tokens.ty {
145                            Type::Path(TypePath {
146                                path: Path { segments, .. },
147                                ..
148                            }) => {
149                                let segment = segments.first().unwrap();
150
151                                // check whether we are dealing with a generic type
152                                let generic_ty = match &segment.arguments {
153                                    PathArguments::AngleBracketed(arguments) => {
154                                        if let Some(GenericArgument::Type(Type::Path(ty))) =
155                                            arguments.args.first()
156                                        {
157                                            Some(
158                                                ty.path.segments.first().unwrap().ident.to_string(),
159                                            )
160                                        } else {
161                                            None
162                                        }
163                                    }
164                                    _ => None,
165                                };
166
167                                let ty = segment.ident.to_string();
168                                variant.args.push((name, ty, generic_ty));
169                            }
170                            _ => panic!("#[args] requires an expression 'name: type'"),
171                        }
172                    }
173                }
174            }
175
176            enum_variants.push(variant);
177        }
178
179        enum_variants
180    } else {
181        panic!("No enum variants found");
182    };
183
184    let mut account_structs = generate_accounts(&variants);
185    account_structs.extend(generate_builders(&variants));
186
187    account_structs
188}
189
190/// Generates a struct for each enum variant.
191///
192/// The struct will contain all shank annotated accounts and the impl block
193/// will initialize them using the accounts iterators. It support the use of
194/// optional accounts, which would generate an account field with an
195/// `Option<AccountInfo<'a>>` type.
196///
197/// ```ignore
198/// pub struct MyAccount<'a> {
199///     my_first_account: safecoin_program::account_info::AccountInfo<'a>,
200///     my_second_optional_account: Option<safecoin_program::account_info::AccountInfo<'a>>,
201///     ..
202/// }
203/// impl<'a> MyAccount<'a> {
204///     pub fn to_context(
205///         accounts: &'a [safecoin_program::account_info::AccountInfo<'a>]
206///     ) -> Result<Context<'a, Self>, safecoin_program::sysvar::slot_history::ProgramError> {
207///         let account_info_iter = &mut accounts.iter();
208///
209///         let my_first_account = safecoin_program::account_info::next_account_info(account_info_iter)?;
210///
211///         ..
212///
213///     }
214/// }
215/// ```
216fn generate_accounts(variants: &[Variant]) -> TokenStream {
217    // build the trait implementation
218    let variant_structs = variants.iter().map(|variant| {
219        let name = syn::parse_str::<syn::Ident>(&variant.name).unwrap();
220        // accounts names
221        let fields = variant.accounts.iter().map(|account| {
222            let account_name = syn::parse_str::<syn::Ident>(format!("{}_info", &account.name).as_str()).unwrap();
223            quote! { #account_name }
224        });
225        // accounts fields
226        let struct_fields = variant.accounts.iter().map(|account| {
227            let account_name = syn::parse_str::<syn::Ident>(format!("{}_info", &account.name).as_str()).unwrap();
228            if account.optional {
229                quote! {
230                    pub #account_name: Option<&'a safecoin_program::account_info::AccountInfo<'a>>
231                }
232            } else {
233                quote! {
234                    pub #account_name:&'a safecoin_program::account_info::AccountInfo<'a>
235                }
236            }
237        });
238        // accounts initialization for the impl block
239        let impl_fields = variant.accounts.iter().map(|account| {
240            let account_name = syn::parse_str::<syn::Ident>(format!("{}_info", &account.name).as_str()).unwrap();
241            if account.optional {
242                quote! {
243                    let #account_name = crate::processor::next_optional_account_info(account_info_iter)?;
244                }
245            } else {
246                quote! {
247                    let #account_name = safecoin_program::account_info::next_account_info(account_info_iter)?;
248                }
249            }
250        });
251
252        quote! {
253            pub struct #name<'a> {
254                #(#struct_fields,)*
255            }
256            impl<'a> #name<'a> {
257                pub fn to_context(accounts: &'a [safecoin_program::account_info::AccountInfo<'a>]) -> Result<Context<'a, Self>, safecoin_program::sysvar::slot_history::ProgramError> {
258                    let account_info_iter = &mut accounts.iter();
259
260                    #(#impl_fields)*
261
262                    let accounts = Self {
263                        #(#fields,)*
264                    };
265
266                    Ok(Context {
267                        accounts,
268                        remaining_accounts: Vec::<&'a AccountInfo<'a>>::from_iter(account_info_iter),
269                    })
270                }
271            }
272        }
273    });
274
275    TokenStream::from(quote! {
276        #(#variant_structs)*
277    })
278}
279
280fn generate_builders(variants: &[Variant]) -> TokenStream {
281    let mut default_pubkeys = HashMap::new();
282    default_pubkeys.insert(
283        "system_program".to_string(),
284        syn::parse_str::<syn::ExprPath>("safecoin_program::system_program::ID").unwrap(),
285    );
286    default_pubkeys.insert(
287        "safe_token_program".to_string(),
288        syn::parse_str::<syn::ExprPath>("safe_token::ID").unwrap(),
289    );
290    default_pubkeys.insert(
291        "spl_ata_program".to_string(),
292        syn::parse_str::<syn::ExprPath>("safe_associated_token_account::ID").unwrap(),
293    );
294    default_pubkeys.insert(
295        "sysvar_instructions".to_string(),
296        syn::parse_str::<syn::ExprPath>("safecoin_program::sysvar::instructions::ID").unwrap(),
297    );
298    default_pubkeys.insert(
299        "authorization_rules_program".to_string(),
300        syn::parse_str::<syn::ExprPath>("lpl_token_auth_rules::ID").unwrap(),
301    );
302
303    // build the trait implementation
304    let variant_structs = variants.iter().map(|variant| {
305        let name = syn::parse_str::<syn::Ident>(&variant.name).unwrap();
306
307        // struct block for the builder: this will contain both accounts and
308        // args for the builder
309
310        // accounts
311        let struct_accounts = variant.accounts.iter().map(|account| {
312            let account_name = syn::parse_str::<syn::Ident>(&account.name).unwrap();
313            if account.optional {
314                quote! {
315                    pub #account_name: Option<safecoin_program::pubkey::Pubkey>
316                }
317            } else {
318                quote! {
319                    pub #account_name: safecoin_program::pubkey::Pubkey
320                }
321            }
322        });
323
324        // args
325        let struct_args = variant.args.iter().map(|(name, ty, generic_ty)| {
326            let ident_ty = syn::parse_str::<syn::Ident>(ty).unwrap();
327            let arg_ty = if let Some(genetic_ty) = generic_ty {
328                let arg_generic_ty = syn::parse_str::<syn::Ident>(genetic_ty).unwrap();
329                quote! { #ident_ty<#arg_generic_ty> }
330            } else {
331                quote! { #ident_ty }
332            };
333            let arg_name = syn::parse_str::<syn::Ident>(name).unwrap();
334              
335            quote! {
336                pub #arg_name: #arg_ty
337            }
338        });
339
340        // builder block: this will have all accounts and args as optional fields
341        // that need to be set before the build method is called
342
343        // accounts
344        let builder_accounts = variant.accounts.iter().map(|account| {
345            let account_name = syn::parse_str::<syn::Ident>(&account.name).unwrap();
346            quote! {
347                pub #account_name: Option<safecoin_program::pubkey::Pubkey>
348            }
349        });
350
351        // accounts initialization
352        let builder_initialize_accounts = variant.accounts.iter().map(|account| {
353            let account_name = syn::parse_str::<syn::Ident>(&account.name).unwrap();
354            quote! {
355                #account_name: None
356            }
357        });
358
359        // args
360        let builder_args = variant.args.iter().map(|(name, ty, generic_ty)| {
361            let ident_ty = syn::parse_str::<syn::Ident>(ty).unwrap();
362            let arg_ty = if let Some(genetic_ty) = generic_ty {
363                let arg_generic_ty = syn::parse_str::<syn::Ident>(genetic_ty).unwrap();
364                quote! { #ident_ty<#arg_generic_ty> }
365            } else {
366                quote! { #ident_ty }
367            };
368            let arg_name = syn::parse_str::<syn::Ident>(name).unwrap();
369
370            quote! {
371                pub #arg_name: Option<#arg_ty>
372            }
373        });
374
375        // args initialization
376        let builder_initialize_args = variant.args.iter().map(|(name, _ty, _generi_ty)| {
377            let arg_name = syn::parse_str::<syn::Ident>(name).unwrap();
378            quote! {
379                #arg_name: None
380            }
381        });
382
383        // account setter methods
384        let builder_accounts_methods = variant.accounts.iter().map(|account| {
385            let account_name = syn::parse_str::<syn::Ident>(&account.name).unwrap();
386            quote! {
387                pub fn #account_name(&mut self, #account_name: safecoin_program::pubkey::Pubkey) -> &mut Self {
388                    self.#account_name = Some(#account_name);
389                    self
390                }
391            }
392        });
393
394        // args setter methods
395        let builder_args_methods = variant.args.iter().map(|(name, ty, generic_ty)| {
396            let ident_ty = syn::parse_str::<syn::Ident>(ty).unwrap();
397            let arg_ty = if let Some(genetic_ty) = generic_ty {
398                let arg_generic_ty = syn::parse_str::<syn::Ident>(genetic_ty).unwrap();
399                quote! { #ident_ty<#arg_generic_ty> }
400            } else {
401                quote! { #ident_ty }
402            };
403            let arg_name = syn::parse_str::<syn::Ident>(name).unwrap();
404
405            quote! {
406                pub fn #arg_name(&mut self, #arg_name: #arg_ty) -> &mut Self {
407                    self.#arg_name = Some(#arg_name);
408                    self
409                }
410            }
411        });
412
413        // required accounts
414        let required_accounts = variant.accounts.iter().map(|account| {
415            let account_name = syn::parse_str::<syn::Ident>(&account.name).unwrap();
416
417            if account.optional {
418                quote! {
419                    #account_name: self.#account_name
420                }
421            } else {
422                // are we dealing with a default pubkey?
423                if default_pubkeys.contains_key(&account.name) {
424                    let pubkey = default_pubkeys.get(&account.name).unwrap();
425                    // we add the default key as the fallback value
426                    quote! {
427                        #account_name: self.#account_name.unwrap_or(#pubkey)
428                    }
429                }
430                else {
431                    // if not a default pubkey, we will need to have it set
432                    quote! {
433                        #account_name: self.#account_name.ok_or(concat!(stringify!(#account_name), " is not set"))?
434                    }
435                }
436            }
437        });
438
439        // required args
440        let required_args = variant.args.iter().map(|(name, _ty, _generic_ty)| {
441            let arg_name = syn::parse_str::<syn::Ident>(name).unwrap();
442            quote! {
443                #arg_name: self.#arg_name.clone().ok_or(concat!(stringify!(#arg_name), " is not set"))?
444            }
445        });
446
447        // args parameter list
448        let args = if let Some(args) = &variant.tuple {
449            let arg_ty = syn::parse_str::<syn::Ident>(args).unwrap();
450            quote! { &mut self, args: #arg_ty }
451        } else {
452            quote! { &mut self }
453        };
454
455        // instruction args
456        let instruction_args = if let Some(args) = &variant.tuple {
457            let arg_ty = syn::parse_str::<syn::Ident>(args).unwrap();
458            quote! { pub args: #arg_ty, }
459        } else {
460            quote! { }
461        };
462
463        // required instruction args
464        let required_instruction_args = if variant.tuple.is_some() {
465            quote! { args, }
466        } else {
467            quote! { }
468        };
469
470        // builder name
471        let builder_name = syn::parse_str::<syn::Ident>(&format!("{}Builder", name)).unwrap();
472
473        quote! {
474            pub struct #name {
475                #(#struct_accounts,)*
476                #(#struct_args,)*
477                #instruction_args
478            }
479
480            pub struct #builder_name {
481                #(#builder_accounts,)*
482                #(#builder_args,)*
483            }
484
485            impl #builder_name {
486                pub fn new() -> Box<#builder_name> {
487                    Box::new(#builder_name {
488                        #(#builder_initialize_accounts,)*
489                        #(#builder_initialize_args,)*
490                    })
491                }
492
493                #(#builder_accounts_methods)*
494                #(#builder_args_methods)*
495
496                pub fn build(#args) -> Result<Box<#name>, Box<dyn std::error::Error>> {
497                    Ok(Box::new(#name {
498                        #(#required_accounts,)*
499                        #(#required_args,)*
500                        #required_instruction_args
501                    }))
502                }
503            }
504        }
505    });
506
507    TokenStream::from(quote! {
508        pub mod builders {
509            use super::*;
510
511            #(#variant_structs)*
512        }
513    })
514}