bolt_anchor_attribute_account/
lib.rs

1extern crate proc_macro;
2
3use quote::quote;
4use syn::parse_macro_input;
5
6mod id;
7
8/// An attribute for a data structure representing a Solana account.
9///
10/// `#[account]` generates trait implementations for the following traits:
11///
12/// - [`AccountSerialize`](./trait.AccountSerialize.html)
13/// - [`AccountDeserialize`](./trait.AccountDeserialize.html)
14/// - [`AnchorSerialize`](./trait.AnchorSerialize.html)
15/// - [`AnchorDeserialize`](./trait.AnchorDeserialize.html)
16/// - [`Clone`](https://doc.rust-lang.org/std/clone/trait.Clone.html)
17/// - [`Discriminator`](./trait.Discriminator.html)
18/// - [`Owner`](./trait.Owner.html)
19///
20/// When implementing account serialization traits the first 8 bytes are
21/// reserved for a unique account discriminator, self described by the first 8
22/// bytes of the SHA256 of the account's Rust ident.
23///
24/// As a result, any calls to `AccountDeserialize`'s `try_deserialize` will
25/// check this discriminator. If it doesn't match, an invalid account was given,
26/// and the account deserialization will exit with an error.
27///
28/// # Zero Copy Deserialization
29///
30/// **WARNING**: Zero copy deserialization is an experimental feature. It's
31/// recommended to use it only when necessary, i.e., when you have extremely
32/// large accounts that cannot be Borsh deserialized without hitting stack or
33/// heap limits.
34///
35/// ## Usage
36///
37/// To enable zero-copy-deserialization, one can pass in the `zero_copy`
38/// argument to the macro as follows:
39///
40/// ```ignore
41/// #[account(zero_copy)]
42/// ```
43///
44/// This can be used to conveniently implement
45/// [`ZeroCopy`](./trait.ZeroCopy.html) so that the account can be used
46/// with [`AccountLoader`](./accounts/account_loader/struct.AccountLoader.html).
47///
48/// Other than being more efficient, the most salient benefit this provides is
49/// the ability to define account types larger than the max stack or heap size.
50/// When using borsh, the account has to be copied and deserialized into a new
51/// data structure and thus is constrained by stack and heap limits imposed by
52/// the BPF VM. With zero copy deserialization, all bytes from the account's
53/// backing `RefCell<&mut [u8]>` are simply re-interpreted as a reference to
54/// the data structure. No allocations or copies necessary. Hence the ability
55/// to get around stack and heap limitations.
56///
57/// To facilitate this, all fields in an account must be constrained to be
58/// "plain old  data", i.e., they must implement
59/// [`Pod`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html). Please review the
60/// [`safety`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html#safety)
61/// section before using.
62///
63/// Using `zero_copy` requires adding the following to your `cargo.toml` file:
64/// `bytemuck = { version = "1.4.0", features = ["derive", "min_const_generics"]}`
65#[proc_macro_attribute]
66pub fn account(
67    args: proc_macro::TokenStream,
68    input: proc_macro::TokenStream,
69) -> proc_macro::TokenStream {
70    let mut namespace = "".to_string();
71    let mut is_zero_copy = false;
72    let mut unsafe_bytemuck = false;
73    let args_str = args.to_string();
74    let args: Vec<&str> = args_str.split(',').collect();
75    if args.len() > 2 {
76        panic!("Only two args are allowed to the account attribute.")
77    }
78    for arg in args {
79        let ns = arg
80            .to_string()
81            .replace('\"', "")
82            .chars()
83            .filter(|c| !c.is_whitespace())
84            .collect();
85        if ns == "zero_copy" {
86            is_zero_copy = true;
87            unsafe_bytemuck = false;
88        } else if ns == "zero_copy(unsafe)" {
89            is_zero_copy = true;
90            unsafe_bytemuck = true;
91        } else {
92            namespace = ns;
93        }
94    }
95
96    let account_strct = parse_macro_input!(input as syn::ItemStruct);
97    let account_name = &account_strct.ident;
98    let account_name_str = account_name.to_string();
99    let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl();
100
101    let discriminator: proc_macro2::TokenStream = {
102        // Namespace the discriminator to prevent collisions.
103        let discriminator_preimage = {
104            // For now, zero copy accounts can't be namespaced.
105            if namespace.is_empty() {
106                format!("account:{account_name}")
107            } else {
108                format!("{namespace}:{account_name}")
109            }
110        };
111
112        let mut discriminator = [0u8; 8];
113        discriminator.copy_from_slice(
114            &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
115        );
116        format!("{discriminator:?}").parse().unwrap()
117    };
118
119    let owner_impl = {
120        if namespace.is_empty() {
121            quote! {
122                #[automatically_derived]
123                impl #impl_gen anchor_lang::Owner for #account_name #type_gen #where_clause {
124                    fn owner() -> Pubkey {
125                        crate::ID
126                    }
127                }
128            }
129        } else {
130            quote! {}
131        }
132    };
133
134    let unsafe_bytemuck_impl = {
135        if unsafe_bytemuck {
136            quote! {
137                #[automatically_derived]
138                unsafe impl #impl_gen anchor_lang::__private::bytemuck::Pod for #account_name #type_gen #where_clause {}
139                #[automatically_derived]
140                unsafe impl #impl_gen anchor_lang::__private::bytemuck::Zeroable for #account_name #type_gen #where_clause {}
141            }
142        } else {
143            quote! {}
144        }
145    };
146
147    let bytemuck_derives = {
148        if !unsafe_bytemuck {
149            quote! {
150                #[zero_copy]
151            }
152        } else {
153            quote! {
154                #[zero_copy(unsafe)]
155            }
156        }
157    };
158
159    proc_macro::TokenStream::from({
160        if is_zero_copy {
161            quote! {
162                #bytemuck_derives
163                #account_strct
164
165                #unsafe_bytemuck_impl
166
167                #[automatically_derived]
168                impl #impl_gen anchor_lang::ZeroCopy for #account_name #type_gen #where_clause {}
169
170                #[automatically_derived]
171                impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
172                    const DISCRIMINATOR: [u8; 8] = #discriminator;
173                }
174
175                // This trait is useful for clients deserializing accounts.
176                // It's expected on-chain programs deserialize via zero-copy.
177                #[automatically_derived]
178                impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
179                    fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
180                        if buf.len() < #discriminator.len() {
181                            return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
182                        }
183                        let given_disc = &buf[..8];
184                        if &#discriminator != given_disc {
185                            return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str));
186                        }
187                        Self::try_deserialize_unchecked(buf)
188                    }
189
190                    fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
191                        let data: &[u8] = &buf[8..];
192                        // Re-interpret raw bytes into the POD data structure.
193                        let account = anchor_lang::__private::bytemuck::from_bytes(data);
194                        // Copy out the bytes into a new, owned data structure.
195                        Ok(*account)
196                    }
197                }
198
199                #owner_impl
200            }
201        } else {
202            quote! {
203                #[derive(AnchorSerialize, AnchorDeserialize, Clone)]
204                #account_strct
205
206                #[automatically_derived]
207                impl #impl_gen anchor_lang::AccountSerialize for #account_name #type_gen #where_clause {
208                    fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> anchor_lang::Result<()> {
209                        if writer.write_all(&#discriminator).is_err() {
210                            return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
211                        }
212
213                        if AnchorSerialize::serialize(self, writer).is_err() {
214                            return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
215                        }
216                        Ok(())
217                    }
218                }
219
220                #[automatically_derived]
221                impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
222                    fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
223                        if buf.len() < #discriminator.len() {
224                            return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
225                        }
226                        let given_disc = &buf[..8];
227                        if &#discriminator != given_disc {
228                            return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str));
229                        }
230                        Self::try_deserialize_unchecked(buf)
231                    }
232
233                    fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
234                        let mut data: &[u8] = &buf[8..];
235                        AnchorDeserialize::deserialize(&mut data)
236                            .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
237                    }
238                }
239
240                #[automatically_derived]
241                impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
242                    const DISCRIMINATOR: [u8; 8] = #discriminator;
243                }
244
245                #owner_impl
246            }
247        }
248    })
249}
250
251#[proc_macro_derive(ZeroCopyAccessor, attributes(accessor))]
252pub fn derive_zero_copy_accessor(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
253    let account_strct = parse_macro_input!(item as syn::ItemStruct);
254    let account_name = &account_strct.ident;
255    let (impl_gen, ty_gen, where_clause) = account_strct.generics.split_for_impl();
256
257    let fields = match &account_strct.fields {
258        syn::Fields::Named(n) => n,
259        _ => panic!("Fields must be named"),
260    };
261    let methods: Vec<proc_macro2::TokenStream> = fields
262        .named
263        .iter()
264        .filter_map(|field: &syn::Field| {
265            field
266                .attrs
267                .iter()
268                .find(|attr| anchor_syn::parser::tts_to_string(&attr.path) == "accessor")
269                .map(|attr| {
270                    let mut tts = attr.tokens.clone().into_iter();
271                    let g_stream = match tts.next().expect("Must have a token group") {
272                        proc_macro2::TokenTree::Group(g) => g.stream(),
273                        _ => panic!("Invalid syntax"),
274                    };
275                    let accessor_ty = match g_stream.into_iter().next() {
276                        Some(token) => token,
277                        _ => panic!("Missing accessor type"),
278                    };
279
280                    let field_name = field.ident.as_ref().unwrap();
281
282                    let get_field: proc_macro2::TokenStream =
283                        format!("get_{field_name}").parse().unwrap();
284                    let set_field: proc_macro2::TokenStream =
285                        format!("set_{field_name}").parse().unwrap();
286
287                    quote! {
288                        pub fn #get_field(&self) -> #accessor_ty {
289                            anchor_lang::__private::ZeroCopyAccessor::get(&self.#field_name)
290                        }
291                        pub fn #set_field(&mut self, input: &#accessor_ty) {
292                            self.#field_name = anchor_lang::__private::ZeroCopyAccessor::set(input);
293                        }
294                    }
295                })
296        })
297        .collect();
298    proc_macro::TokenStream::from(quote! {
299        #[automatically_derived]
300        impl #impl_gen #account_name #ty_gen #where_clause {
301            #(#methods)*
302        }
303    })
304}
305
306/// A data structure that can be used as an internal field for a zero copy
307/// deserialized account, i.e., a struct marked with `#[account(zero_copy)]`.
308///
309/// `#[zero_copy]` is just a convenient alias for
310///
311/// ```ignore
312/// #[derive(Copy, Clone)]
313/// #[derive(bytemuck::Zeroable)]
314/// #[derive(bytemuck::Pod)]
315/// #[repr(C)]
316/// struct MyStruct {...}
317/// ```
318#[proc_macro_attribute]
319pub fn zero_copy(
320    args: proc_macro::TokenStream,
321    item: proc_macro::TokenStream,
322) -> proc_macro::TokenStream {
323    let mut is_unsafe = false;
324    for arg in args.into_iter() {
325        match arg {
326            proc_macro::TokenTree::Ident(ident) => {
327                if ident.to_string() == "unsafe" {
328                    // `#[zero_copy(unsafe)]` maintains the old behaviour
329                    //
330                    // ```ignore
331                    // #[derive(Copy, Clone)]
332                    // #[repr(packed)]
333                    // struct MyStruct {...}
334                    // ```
335                    is_unsafe = true;
336                } else {
337                    // TODO: how to return a compile error with a span (can't return prase error because expected type TokenStream)
338                    panic!("expected single ident `unsafe`");
339                }
340            }
341            _ => {
342                panic!("expected single ident `unsafe`");
343            }
344        }
345    }
346
347    let account_strct = parse_macro_input!(item as syn::ItemStruct);
348
349    // Takes the first repr. It's assumed that more than one are not on the
350    // struct.
351    let attr = account_strct
352        .attrs
353        .iter()
354        .find(|attr| anchor_syn::parser::tts_to_string(&attr.path) == "repr");
355
356    let repr = match attr {
357        // Users might want to manually specify repr modifiers e.g. repr(C, packed)
358        Some(_attr) => quote! {},
359        None => {
360            if is_unsafe {
361                quote! {#[repr(packed)]}
362            } else {
363                quote! {#[repr(C)]}
364            }
365        }
366    };
367
368    let mut has_pod_attr = false;
369    let mut has_zeroable_attr = false;
370    for attr in account_strct.attrs.iter() {
371        let token_string = attr.tokens.to_string();
372        if token_string.contains("bytemuck :: Pod") {
373            has_pod_attr = true;
374        }
375        if token_string.contains("bytemuck :: Zeroable") {
376            has_zeroable_attr = true;
377        }
378    }
379
380    // Once the Pod derive macro is expanded the compiler has to use the local crate's
381    // bytemuck `::bytemuck::Pod` anyway, so we're no longer using the privately
382    // exported anchor bytemuck `__private::bytemuck`, so that there won't be any
383    // possible disparity between the anchor version and the local crate's version.
384    let pod = if has_pod_attr || is_unsafe {
385        quote! {}
386    } else {
387        quote! {#[derive(::bytemuck::Pod)]}
388    };
389    let zeroable = if has_zeroable_attr || is_unsafe {
390        quote! {}
391    } else {
392        quote! {#[derive(::bytemuck::Zeroable)]}
393    };
394
395    let ret = quote! {
396        #[derive(anchor_lang::__private::ZeroCopyAccessor, Copy, Clone)]
397        #repr
398        #pod
399        #zeroable
400        #account_strct
401    };
402
403    #[cfg(feature = "idl-build")]
404    {
405        let derive_unsafe = if is_unsafe {
406            // Not a real proc-macro but exists in order to pass the serialization info
407            quote! { #[derive(bytemuck::Unsafe)] }
408        } else {
409            quote! {}
410        };
411        let zc_struct = syn::parse2(quote! {
412            #derive_unsafe
413            #ret
414        })
415        .unwrap();
416        let idl_build_impl = anchor_syn::idl::build::impl_idl_build_struct(&zc_struct);
417        return proc_macro::TokenStream::from(quote! {
418            #ret
419            #idl_build_impl
420        });
421    }
422
423    #[allow(unreachable_code)]
424    proc_macro::TokenStream::from(ret)
425}
426
427/// Defines the program's ID. This should be used at the root of all Anchor
428/// based programs.
429#[proc_macro]
430pub fn declare_id(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
431    #[cfg(feature = "idl-build")]
432    let address = input.clone().to_string();
433
434    let id = parse_macro_input!(input as id::Id);
435    let ret = quote! { #id };
436
437    #[cfg(feature = "idl-build")]
438    {
439        let idl_print = anchor_syn::idl::build::gen_idl_print_fn_address(address);
440        return proc_macro::TokenStream::from(quote! {
441            #ret
442            #idl_print
443        });
444    }
445
446    #[allow(unreachable_code)]
447    proc_macro::TokenStream::from(ret)
448}