anchor_i11n_derive/
lib.rs

1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use quote::{quote, ToTokens};
4use sha2::{Digest, Sha256};
5use syn::{parse_macro_input, Data, DeriveInput, Fields, Ident, PathArguments, PathSegment, Type};
6
7#[proc_macro_derive(TryFromInstruction)]
8pub fn try_from_instruction(input: TokenStream) -> TokenStream {
9    let context_struct = parse_macro_input!(input as DeriveInput);
10    let context_name = &context_struct.ident;
11
12    // Extract lifetime from the generic parameters, if any
13    let lifetime = match context_struct.generics.lifetimes().next() {
14        Some(l) => {
15            let lifetime_name = &l.lifetime;
16            quote! { #lifetime_name }.into()
17        }
18        None => quote! {},
19    };
20
21    let mut has_accounts_info_lifetime = false;
22    let mut has_accounts_path = None;
23    let mut has_args_path = None;
24
25    if let Data::Struct(data_struct) = &context_struct.data {
26        if let Fields::Named(fields) = &data_struct.fields {
27            for field in &fields.named {
28                if let Some(ident) = &field.ident {
29                    match ident.to_string().as_str() {
30                        "accounts" => {
31                            if let Type::Path(type_path) = &field.ty {
32                                let mut new_type_path = type_path.clone();
33                                if let Some(last_segment) = new_type_path.path.segments.last_mut() {
34                                    if matches!(last_segment.arguments, PathArguments::AngleBracketed(_)) {
35                                        has_accounts_info_lifetime = true;
36                                    }
37                                    *last_segment = PathSegment {
38                                        ident: last_segment.ident.clone(),
39                                        arguments: PathArguments::None
40                                    };
41                                }
42                                has_accounts_path = Some(new_type_path);
43                            }
44                        }
45                        "args" => {
46                            if let Type::Path(type_path) = &field.ty {
47                                has_args_path = Some(type_path);
48                            }
49                        }
50                        "remaining_accounts" => {}
51                        _ => panic!("Expected field name of \"accounts\", \"args\" or \"remaining_accounts\""),
52                    }
53                }
54            }
55        } else {
56            panic!("Expected named fields in the struct.");
57        }
58    } else {
59        panic!("Expected named fields in the struct.");
60    }
61
62    // Unwrap is safe here because we ensure both fields are present
63    let (accounts_path, args_path) = (has_accounts_path.unwrap(), has_args_path.unwrap());
64
65    let accounts_derive = match has_accounts_info_lifetime {
66        true => quote! { #accounts_path::<#lifetime>::try_from(&ix.accounts)?; },
67        false => quote! { #accounts_path::try_from(&ix.accounts)?; }
68    };
69
70    // Generate the discriminator
71    let expanded = quote! {
72        impl<#lifetime> TryFrom<&#lifetime anchor_lang::solana_program::instruction::Instruction> for #context_name<#lifetime> {
73            type Error = Error;
74
75            fn try_from(ix: &#lifetime anchor_lang::solana_program::instruction::Instruction) -> Result<#context_name<#lifetime>> {
76                require_keys_eq!(ix.program_id, ID, ErrorCode::InvalidProgramId);
77
78                require!(ix.data[..8].eq(&#args_path::DISCRIMINATOR), ErrorCode::InstructionDidNotDeserialize);
79
80                let accounts = #accounts_derive;
81                let remaining_accounts = #accounts_path::try_remaining_accounts_from(&ix.accounts)?;
82                let args = #args_path::try_from_slice(&ix.data[8..])?;
83
84                Ok(#context_name {
85                    accounts,
86                    args,
87                    remaining_accounts
88                })
89            }
90        }
91    };
92
93    // Convert the generated implementation back into tokens and return it
94    TokenStream::from(expanded)
95}
96
97// Derive the discriminator from an instruction struct
98#[proc_macro_derive(AnchorDiscriminator)]
99pub fn anchor_discriminator(input: TokenStream) -> TokenStream {
100    let args_struct = parse_macro_input!(input as DeriveInput);
101    let args_type = &args_struct.ident;
102    let mut hasher = Sha256::new();
103    hasher.update(format!("global:{}", args_type.to_string().to_case(Case::Snake)).as_bytes());
104
105    let mut discriminator_bytes: [u8; 8] = [0u8; 8];
106    discriminator_bytes.clone_from_slice(&hasher.finalize().to_vec()[..8]);
107
108    let discriminator: Vec<_> = discriminator_bytes
109        .into_iter()
110        .map(|i| {
111            let idx = i as u8;
112            quote! { #idx }
113        })
114        .collect();
115
116    quote! {
117        impl Discriminator for #args_type {
118            const DISCRIMINATOR: [u8; 8] = [#(#discriminator),*];
119            fn discriminator() -> [u8; 8] {
120                Self::DISCRIMINATOR
121            }
122        }
123    }
124    .into()
125}
126
127#[proc_macro_derive(TryFromAccountMetas)]
128pub fn try_from_account_metas(input: TokenStream) -> TokenStream {
129    let accounts_struct = parse_macro_input!(input as DeriveInput);
130    let accounts_name = &accounts_struct.ident;
131
132    // Extract lifetime from the generic parameters, if any
133    let lifetime = match accounts_struct.generics.lifetimes().next() {
134        Some(l) => {
135            let lifetime_name = &l.lifetime;
136            quote! {#lifetime_name}
137        }
138        None => quote! {},
139    };
140
141    // Extract the field names from the struct
142    let mut optional_account_names: Vec<&Ident> = vec![];
143    let account_names = if let Data::Struct(data_struct) = &accounts_struct.data {
144        if let syn::Fields::Named(fields) = &data_struct.fields {
145            fields
146                .named
147                .iter()
148                .map(|field| {
149                    if let syn::Type::Path(type_path) = &field.ty {
150                        if let Some(segment) = type_path.path.segments.last() {
151                            if segment.ident.to_string() == "Option" {
152                                optional_account_names.push(field.ident.as_ref().unwrap())
153                            }
154                        }
155                    }
156                    field.ident.as_ref().unwrap()
157                })
158                .collect::<Vec<_>>()
159        } else {
160            Vec::new() // Handle tuple structs or unit structs
161        }
162    } else {
163        Vec::new() // Handle enums
164    };
165
166    // Handle optional accounts
167    let optional_accounts: Vec<_> = optional_account_names
168        .iter()
169        .map(|ident| {
170            let id = ident.to_token_stream();
171            quote! {
172                let #id = match &#id.pubkey.eq(&ID) {
173                    true => None,
174                    false => Some(#id), // Dereference the reference here
175                };
176            }
177        })
178        .collect();
179
180    // Extract the number of fields in the struct
181    let accounts_length = account_names.len();
182
183    // Generate array access expressions for the value vector
184    let value_indices: Vec<_> = (0..account_names.len())
185        .map(|i| {
186            let idx = syn::Index::from(i);
187            quote! { &value[#idx] }
188        })
189        .collect();
190
191    let account_generators = match account_names.len() > 0 {
192        true => quote! {
193            if value.len() < #accounts_length {
194                return Err(ProgramError::NotEnoughAccountKeys.into());
195            }
196
197            let [#(#account_names),*] = [#(#value_indices),*];
198
199            #(#optional_accounts)*
200
201            Ok(Self {
202                #(#account_names),*
203            })
204        },
205        false => quote! {
206            Ok(Self {})
207        },
208    };
209
210    // Generate the implementation
211    let expanded = quote! {
212        impl<#lifetime> TryFrom<&#lifetime Vec<AccountMeta>> for #accounts_name<#lifetime> {
213            type Error = Error;
214
215            fn try_from(value: &#lifetime Vec<AccountMeta>) -> Result<Self> {
216                if value.len() < #accounts_length {
217                    return Err(ProgramError::NotEnoughAccountKeys.into());
218                }
219
220                #account_generators
221            }
222        }
223
224        impl<#lifetime> #accounts_name<#lifetime> {
225            fn try_remaining_accounts_from(value: &#lifetime Vec<AccountMeta>) -> Result<Vec<&#lifetime AccountMeta>> {
226                if value.len() < #accounts_length {
227                    return Err(ProgramError::NotEnoughAccountKeys.into());
228                }
229                Ok(value[#accounts_length..].iter().map(|a| a).collect())
230            }
231        }
232    };
233
234    // Convert the generated implementation back into tokens and return it
235    TokenStream::from(expanded)
236}