// lpl_token_metadata_context_derive/
lib.rs1use proc_macro::TokenStream;
2use quote::quote;
3use std::collections::HashMap;
4use syn::{
5 self, parse_macro_input, DeriveInput, Expr, ExprPath, GenericArgument, Lit, Meta, MetaList,
6 MetaNameValue, NestedMeta, Path, PathArguments, Type, TypePath,
7};
8
9#[derive(Default)]
10struct Variant {
11 pub name: String,
12 pub tuple: Option<String>,
13 pub accounts: Vec<Account>,
14 pub args: Vec<(String, String, Option<String>)>,
16}
/// One account attached to an enum variant through `#[account(...)]`.
#[derive(Debug)]
struct Account {
    /// Value of the `name = "..."` property; generated field and setter names
    /// are built from it.
    pub name: String,
    /// Set when the bare `optional` flag appears in the attribute.
    pub optional: bool,
}
/// Helper attribute declaring one account of a variant:
/// `#[account(name = "...", optional)]`.
const ACCOUNT_ATTRIBUTE: &str = "account";
/// Helper attribute declaring one builder argument: `#[args(name: Type)]`.
const ARGS_ATTRIBUTE: &str = "args";
/// `#[account]` property whose string literal gives the account's name.
const NAME_PROPERTY: &str = "name";
/// Bare `#[account]` flag marking the account as optional.
const OPTIONAL_PROPERTY: &str = "optional";
32
33#[proc_macro_derive(AccountContext, attributes(account, args))]
34pub fn account_context_derive(input: TokenStream) -> TokenStream {
35 let ast = parse_macro_input!(input as DeriveInput);
36
37 let variants = if let syn::Data::Enum(syn::DataEnum { ref variants, .. }) = ast.data {
40 let mut enum_variants = Vec::new();
41
42 for v in variants {
43 let mut variant = Variant {
45 tuple: if let syn::Fields::Unnamed(syn::FieldsUnnamed { unnamed, .. }) = &v.fields {
46 match unnamed.first() {
47 Some(syn::Field {
48 ty:
49 Type::Path(TypePath {
50 path: Path { segments, .. },
51 ..
52 }),
53 ..
54 }) => Some(segments.first().unwrap().ident.to_string()),
55 _ => None,
56 }
57 } else {
58 None
59 },
60 name: v.ident.to_string(),
61 ..Default::default()
62 };
63
64 for a in &v.attrs {
66 let syn::Attribute {
67 path: syn::Path { segments, .. },
68 ..
69 } = &a;
70 let mut skip = true;
71 let mut attribute = String::new();
72
73 for path in segments {
74 let ident = path.ident.to_string();
75 if ident == ACCOUNT_ATTRIBUTE || ident == ARGS_ATTRIBUTE {
77 attribute = ident;
78 skip = false;
79 }
80 }
81
82 if !skip {
83 if attribute == ACCOUNT_ATTRIBUTE {
84 let meta_tokens = a.parse_meta().unwrap();
85 let nested_meta = if let Meta::List(MetaList { nested, .. }) = &meta_tokens
86 {
87 nested
88 } else {
89 panic!("#[account] requires attributes account name");
90 };
91
92 let mut property: (Option<String>, Option<String>) = (None, None);
94
95 for element in nested_meta {
96 match element {
97 NestedMeta::Meta(Meta::NameValue(MetaNameValue {
99 path,
100 lit,
101 ..
102 })) => {
103 let ident = path.get_ident();
104 if let Some(ident) = ident {
105 if *ident == NAME_PROPERTY {
106 let token = match lit {
107 Lit::Str(lit) => {
109 lit.token().to_string().replace('\"', "")
110 }
111 _ => panic!("Invalid value for property {ident}"),
112 };
113 property.0 = Some(token);
114 }
115 }
116 }
117 NestedMeta::Meta(Meta::Path(path)) => {
119 let name = path.get_ident().map(|x| x.to_string());
120 if let Some(name) = name {
121 if name == OPTIONAL_PROPERTY {
122 property.1 = Some(name);
123 }
124 }
125 }
126 _ => {}
127 }
128 }
129 variant.accounts.push(Account {
130 name: property.0.unwrap(),
131 optional: property.1.is_some(),
132 });
133 } else if attribute == ARGS_ATTRIBUTE {
134 let args_tokens: syn::ExprType = a.parse_args().unwrap();
135 let name = match *args_tokens.expr {
137 Expr::Path(ExprPath {
138 path: Path { segments, .. },
139 ..
140 }) => segments.first().unwrap().ident.to_string(),
141 _ => panic!("#[args] requires an expression 'name: type'"),
142 };
143 match *args_tokens.ty {
145 Type::Path(TypePath {
146 path: Path { segments, .. },
147 ..
148 }) => {
149 let segment = segments.first().unwrap();
150
151 let generic_ty = match &segment.arguments {
153 PathArguments::AngleBracketed(arguments) => {
154 if let Some(GenericArgument::Type(Type::Path(ty))) =
155 arguments.args.first()
156 {
157 Some(
158 ty.path.segments.first().unwrap().ident.to_string(),
159 )
160 } else {
161 None
162 }
163 }
164 _ => None,
165 };
166
167 let ty = segment.ident.to_string();
168 variant.args.push((name, ty, generic_ty));
169 }
170 _ => panic!("#[args] requires an expression 'name: type'"),
171 }
172 }
173 }
174 }
175
176 enum_variants.push(variant);
177 }
178
179 enum_variants
180 } else {
181 panic!("No enum variants found");
182 };
183
184 let mut account_structs = generate_accounts(&variants);
185 account_structs.extend(generate_builders(&variants));
186
187 account_structs
188}
189
190fn generate_accounts(variants: &[Variant]) -> TokenStream {
217 let variant_structs = variants.iter().map(|variant| {
219 let name = syn::parse_str::<syn::Ident>(&variant.name).unwrap();
220 let fields = variant.accounts.iter().map(|account| {
222 let account_name = syn::parse_str::<syn::Ident>(format!("{}_info", &account.name).as_str()).unwrap();
223 quote! { #account_name }
224 });
225 let struct_fields = variant.accounts.iter().map(|account| {
227 let account_name = syn::parse_str::<syn::Ident>(format!("{}_info", &account.name).as_str()).unwrap();
228 if account.optional {
229 quote! {
230 pub #account_name: Option<&'a safecoin_program::account_info::AccountInfo<'a>>
231 }
232 } else {
233 quote! {
234 pub #account_name:&'a safecoin_program::account_info::AccountInfo<'a>
235 }
236 }
237 });
238 let impl_fields = variant.accounts.iter().map(|account| {
240 let account_name = syn::parse_str::<syn::Ident>(format!("{}_info", &account.name).as_str()).unwrap();
241 if account.optional {
242 quote! {
243 let #account_name = crate::processor::next_optional_account_info(account_info_iter)?;
244 }
245 } else {
246 quote! {
247 let #account_name = safecoin_program::account_info::next_account_info(account_info_iter)?;
248 }
249 }
250 });
251
252 quote! {
253 pub struct #name<'a> {
254 #(#struct_fields,)*
255 }
256 impl<'a> #name<'a> {
257 pub fn to_context(accounts: &'a [safecoin_program::account_info::AccountInfo<'a>]) -> Result<Context<'a, Self>, safecoin_program::sysvar::slot_history::ProgramError> {
258 let account_info_iter = &mut accounts.iter();
259
260 #(#impl_fields)*
261
262 let accounts = Self {
263 #(#fields,)*
264 };
265
266 Ok(Context {
267 accounts,
268 remaining_accounts: Vec::<&'a AccountInfo<'a>>::from_iter(account_info_iter),
269 })
270 }
271 }
272 }
273 });
274
275 TokenStream::from(quote! {
276 #(#variant_structs)*
277 })
278}
279
280fn generate_builders(variants: &[Variant]) -> TokenStream {
281 let mut default_pubkeys = HashMap::new();
282 default_pubkeys.insert(
283 "system_program".to_string(),
284 syn::parse_str::<syn::ExprPath>("safecoin_program::system_program::ID").unwrap(),
285 );
286 default_pubkeys.insert(
287 "safe_token_program".to_string(),
288 syn::parse_str::<syn::ExprPath>("safe_token::ID").unwrap(),
289 );
290 default_pubkeys.insert(
291 "spl_ata_program".to_string(),
292 syn::parse_str::<syn::ExprPath>("safe_associated_token_account::ID").unwrap(),
293 );
294 default_pubkeys.insert(
295 "sysvar_instructions".to_string(),
296 syn::parse_str::<syn::ExprPath>("safecoin_program::sysvar::instructions::ID").unwrap(),
297 );
298 default_pubkeys.insert(
299 "authorization_rules_program".to_string(),
300 syn::parse_str::<syn::ExprPath>("lpl_token_auth_rules::ID").unwrap(),
301 );
302
303 let variant_structs = variants.iter().map(|variant| {
305 let name = syn::parse_str::<syn::Ident>(&variant.name).unwrap();
306
307 let struct_accounts = variant.accounts.iter().map(|account| {
312 let account_name = syn::parse_str::<syn::Ident>(&account.name).unwrap();
313 if account.optional {
314 quote! {
315 pub #account_name: Option<safecoin_program::pubkey::Pubkey>
316 }
317 } else {
318 quote! {
319 pub #account_name: safecoin_program::pubkey::Pubkey
320 }
321 }
322 });
323
324 let struct_args = variant.args.iter().map(|(name, ty, generic_ty)| {
326 let ident_ty = syn::parse_str::<syn::Ident>(ty).unwrap();
327 let arg_ty = if let Some(genetic_ty) = generic_ty {
328 let arg_generic_ty = syn::parse_str::<syn::Ident>(genetic_ty).unwrap();
329 quote! { #ident_ty<#arg_generic_ty> }
330 } else {
331 quote! { #ident_ty }
332 };
333 let arg_name = syn::parse_str::<syn::Ident>(name).unwrap();
334
335 quote! {
336 pub #arg_name: #arg_ty
337 }
338 });
339
340 let builder_accounts = variant.accounts.iter().map(|account| {
345 let account_name = syn::parse_str::<syn::Ident>(&account.name).unwrap();
346 quote! {
347 pub #account_name: Option<safecoin_program::pubkey::Pubkey>
348 }
349 });
350
351 let builder_initialize_accounts = variant.accounts.iter().map(|account| {
353 let account_name = syn::parse_str::<syn::Ident>(&account.name).unwrap();
354 quote! {
355 #account_name: None
356 }
357 });
358
359 let builder_args = variant.args.iter().map(|(name, ty, generic_ty)| {
361 let ident_ty = syn::parse_str::<syn::Ident>(ty).unwrap();
362 let arg_ty = if let Some(genetic_ty) = generic_ty {
363 let arg_generic_ty = syn::parse_str::<syn::Ident>(genetic_ty).unwrap();
364 quote! { #ident_ty<#arg_generic_ty> }
365 } else {
366 quote! { #ident_ty }
367 };
368 let arg_name = syn::parse_str::<syn::Ident>(name).unwrap();
369
370 quote! {
371 pub #arg_name: Option<#arg_ty>
372 }
373 });
374
375 let builder_initialize_args = variant.args.iter().map(|(name, _ty, _generi_ty)| {
377 let arg_name = syn::parse_str::<syn::Ident>(name).unwrap();
378 quote! {
379 #arg_name: None
380 }
381 });
382
383 let builder_accounts_methods = variant.accounts.iter().map(|account| {
385 let account_name = syn::parse_str::<syn::Ident>(&account.name).unwrap();
386 quote! {
387 pub fn #account_name(&mut self, #account_name: safecoin_program::pubkey::Pubkey) -> &mut Self {
388 self.#account_name = Some(#account_name);
389 self
390 }
391 }
392 });
393
394 let builder_args_methods = variant.args.iter().map(|(name, ty, generic_ty)| {
396 let ident_ty = syn::parse_str::<syn::Ident>(ty).unwrap();
397 let arg_ty = if let Some(genetic_ty) = generic_ty {
398 let arg_generic_ty = syn::parse_str::<syn::Ident>(genetic_ty).unwrap();
399 quote! { #ident_ty<#arg_generic_ty> }
400 } else {
401 quote! { #ident_ty }
402 };
403 let arg_name = syn::parse_str::<syn::Ident>(name).unwrap();
404
405 quote! {
406 pub fn #arg_name(&mut self, #arg_name: #arg_ty) -> &mut Self {
407 self.#arg_name = Some(#arg_name);
408 self
409 }
410 }
411 });
412
413 let required_accounts = variant.accounts.iter().map(|account| {
415 let account_name = syn::parse_str::<syn::Ident>(&account.name).unwrap();
416
417 if account.optional {
418 quote! {
419 #account_name: self.#account_name
420 }
421 } else {
422 if default_pubkeys.contains_key(&account.name) {
424 let pubkey = default_pubkeys.get(&account.name).unwrap();
425 quote! {
427 #account_name: self.#account_name.unwrap_or(#pubkey)
428 }
429 }
430 else {
431 quote! {
433 #account_name: self.#account_name.ok_or(concat!(stringify!(#account_name), " is not set"))?
434 }
435 }
436 }
437 });
438
439 let required_args = variant.args.iter().map(|(name, _ty, _generic_ty)| {
441 let arg_name = syn::parse_str::<syn::Ident>(name).unwrap();
442 quote! {
443 #arg_name: self.#arg_name.clone().ok_or(concat!(stringify!(#arg_name), " is not set"))?
444 }
445 });
446
447 let args = if let Some(args) = &variant.tuple {
449 let arg_ty = syn::parse_str::<syn::Ident>(args).unwrap();
450 quote! { &mut self, args: #arg_ty }
451 } else {
452 quote! { &mut self }
453 };
454
455 let instruction_args = if let Some(args) = &variant.tuple {
457 let arg_ty = syn::parse_str::<syn::Ident>(args).unwrap();
458 quote! { pub args: #arg_ty, }
459 } else {
460 quote! { }
461 };
462
463 let required_instruction_args = if variant.tuple.is_some() {
465 quote! { args, }
466 } else {
467 quote! { }
468 };
469
470 let builder_name = syn::parse_str::<syn::Ident>(&format!("{}Builder", name)).unwrap();
472
473 quote! {
474 pub struct #name {
475 #(#struct_accounts,)*
476 #(#struct_args,)*
477 #instruction_args
478 }
479
480 pub struct #builder_name {
481 #(#builder_accounts,)*
482 #(#builder_args,)*
483 }
484
485 impl #builder_name {
486 pub fn new() -> Box<#builder_name> {
487 Box::new(#builder_name {
488 #(#builder_initialize_accounts,)*
489 #(#builder_initialize_args,)*
490 })
491 }
492
493 #(#builder_accounts_methods)*
494 #(#builder_args_methods)*
495
496 pub fn build(#args) -> Result<Box<#name>, Box<dyn std::error::Error>> {
497 Ok(Box::new(#name {
498 #(#required_accounts,)*
499 #(#required_args,)*
500 #required_instruction_args
501 }))
502 }
503 }
504 }
505 });
506
507 TokenStream::from(quote! {
508 pub mod builders {
509 use super::*;
510
511 #(#variant_structs)*
512 }
513 })
514}