bolt_anchor_attribute_account/
lib.rs1extern crate proc_macro;
2
3use quote::quote;
4use syn::parse_macro_input;
5
6mod id;
7
8#[proc_macro_attribute]
66pub fn account(
67 args: proc_macro::TokenStream,
68 input: proc_macro::TokenStream,
69) -> proc_macro::TokenStream {
70 let mut namespace = "".to_string();
71 let mut is_zero_copy = false;
72 let mut unsafe_bytemuck = false;
73 let args_str = args.to_string();
74 let args: Vec<&str> = args_str.split(',').collect();
75 if args.len() > 2 {
76 panic!("Only two args are allowed to the account attribute.")
77 }
78 for arg in args {
79 let ns = arg
80 .to_string()
81 .replace('\"', "")
82 .chars()
83 .filter(|c| !c.is_whitespace())
84 .collect();
85 if ns == "zero_copy" {
86 is_zero_copy = true;
87 unsafe_bytemuck = false;
88 } else if ns == "zero_copy(unsafe)" {
89 is_zero_copy = true;
90 unsafe_bytemuck = true;
91 } else {
92 namespace = ns;
93 }
94 }
95
96 let account_strct = parse_macro_input!(input as syn::ItemStruct);
97 let account_name = &account_strct.ident;
98 let account_name_str = account_name.to_string();
99 let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl();
100
101 let discriminator: proc_macro2::TokenStream = {
102 let discriminator_preimage = {
104 if namespace.is_empty() {
106 format!("account:{account_name}")
107 } else {
108 format!("{namespace}:{account_name}")
109 }
110 };
111
112 let mut discriminator = [0u8; 8];
113 discriminator.copy_from_slice(
114 &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
115 );
116 format!("{discriminator:?}").parse().unwrap()
117 };
118
119 let owner_impl = {
120 if namespace.is_empty() {
121 quote! {
122 #[automatically_derived]
123 impl #impl_gen anchor_lang::Owner for #account_name #type_gen #where_clause {
124 fn owner() -> Pubkey {
125 crate::ID
126 }
127 }
128 }
129 } else {
130 quote! {}
131 }
132 };
133
134 let unsafe_bytemuck_impl = {
135 if unsafe_bytemuck {
136 quote! {
137 #[automatically_derived]
138 unsafe impl #impl_gen anchor_lang::__private::bytemuck::Pod for #account_name #type_gen #where_clause {}
139 #[automatically_derived]
140 unsafe impl #impl_gen anchor_lang::__private::bytemuck::Zeroable for #account_name #type_gen #where_clause {}
141 }
142 } else {
143 quote! {}
144 }
145 };
146
147 let bytemuck_derives = {
148 if !unsafe_bytemuck {
149 quote! {
150 #[zero_copy]
151 }
152 } else {
153 quote! {
154 #[zero_copy(unsafe)]
155 }
156 }
157 };
158
159 proc_macro::TokenStream::from({
160 if is_zero_copy {
161 quote! {
162 #bytemuck_derives
163 #account_strct
164
165 #unsafe_bytemuck_impl
166
167 #[automatically_derived]
168 impl #impl_gen anchor_lang::ZeroCopy for #account_name #type_gen #where_clause {}
169
170 #[automatically_derived]
171 impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
172 const DISCRIMINATOR: [u8; 8] = #discriminator;
173 }
174
175 #[automatically_derived]
178 impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
179 fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
180 if buf.len() < #discriminator.len() {
181 return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
182 }
183 let given_disc = &buf[..8];
184 if &#discriminator != given_disc {
185 return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str));
186 }
187 Self::try_deserialize_unchecked(buf)
188 }
189
190 fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
191 let data: &[u8] = &buf[8..];
192 let account = anchor_lang::__private::bytemuck::from_bytes(data);
194 Ok(*account)
196 }
197 }
198
199 #owner_impl
200 }
201 } else {
202 quote! {
203 #[derive(AnchorSerialize, AnchorDeserialize, Clone)]
204 #account_strct
205
206 #[automatically_derived]
207 impl #impl_gen anchor_lang::AccountSerialize for #account_name #type_gen #where_clause {
208 fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> anchor_lang::Result<()> {
209 if writer.write_all(&#discriminator).is_err() {
210 return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
211 }
212
213 if AnchorSerialize::serialize(self, writer).is_err() {
214 return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
215 }
216 Ok(())
217 }
218 }
219
220 #[automatically_derived]
221 impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
222 fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
223 if buf.len() < #discriminator.len() {
224 return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
225 }
226 let given_disc = &buf[..8];
227 if &#discriminator != given_disc {
228 return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str));
229 }
230 Self::try_deserialize_unchecked(buf)
231 }
232
233 fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
234 let mut data: &[u8] = &buf[8..];
235 AnchorDeserialize::deserialize(&mut data)
236 .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
237 }
238 }
239
240 #[automatically_derived]
241 impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
242 const DISCRIMINATOR: [u8; 8] = #discriminator;
243 }
244
245 #owner_impl
246 }
247 }
248 })
249}
250
251#[proc_macro_derive(ZeroCopyAccessor, attributes(accessor))]
252pub fn derive_zero_copy_accessor(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
253 let account_strct = parse_macro_input!(item as syn::ItemStruct);
254 let account_name = &account_strct.ident;
255 let (impl_gen, ty_gen, where_clause) = account_strct.generics.split_for_impl();
256
257 let fields = match &account_strct.fields {
258 syn::Fields::Named(n) => n,
259 _ => panic!("Fields must be named"),
260 };
261 let methods: Vec<proc_macro2::TokenStream> = fields
262 .named
263 .iter()
264 .filter_map(|field: &syn::Field| {
265 field
266 .attrs
267 .iter()
268 .find(|attr| anchor_syn::parser::tts_to_string(&attr.path) == "accessor")
269 .map(|attr| {
270 let mut tts = attr.tokens.clone().into_iter();
271 let g_stream = match tts.next().expect("Must have a token group") {
272 proc_macro2::TokenTree::Group(g) => g.stream(),
273 _ => panic!("Invalid syntax"),
274 };
275 let accessor_ty = match g_stream.into_iter().next() {
276 Some(token) => token,
277 _ => panic!("Missing accessor type"),
278 };
279
280 let field_name = field.ident.as_ref().unwrap();
281
282 let get_field: proc_macro2::TokenStream =
283 format!("get_{field_name}").parse().unwrap();
284 let set_field: proc_macro2::TokenStream =
285 format!("set_{field_name}").parse().unwrap();
286
287 quote! {
288 pub fn #get_field(&self) -> #accessor_ty {
289 anchor_lang::__private::ZeroCopyAccessor::get(&self.#field_name)
290 }
291 pub fn #set_field(&mut self, input: &#accessor_ty) {
292 self.#field_name = anchor_lang::__private::ZeroCopyAccessor::set(input);
293 }
294 }
295 })
296 })
297 .collect();
298 proc_macro::TokenStream::from(quote! {
299 #[automatically_derived]
300 impl #impl_gen #account_name #ty_gen #where_clause {
301 #(#methods)*
302 }
303 })
304}
305
306#[proc_macro_attribute]
319pub fn zero_copy(
320 args: proc_macro::TokenStream,
321 item: proc_macro::TokenStream,
322) -> proc_macro::TokenStream {
323 let mut is_unsafe = false;
324 for arg in args.into_iter() {
325 match arg {
326 proc_macro::TokenTree::Ident(ident) => {
327 if ident.to_string() == "unsafe" {
328 is_unsafe = true;
336 } else {
337 panic!("expected single ident `unsafe`");
339 }
340 }
341 _ => {
342 panic!("expected single ident `unsafe`");
343 }
344 }
345 }
346
347 let account_strct = parse_macro_input!(item as syn::ItemStruct);
348
349 let attr = account_strct
352 .attrs
353 .iter()
354 .find(|attr| anchor_syn::parser::tts_to_string(&attr.path) == "repr");
355
356 let repr = match attr {
357 Some(_attr) => quote! {},
359 None => {
360 if is_unsafe {
361 quote! {#[repr(packed)]}
362 } else {
363 quote! {#[repr(C)]}
364 }
365 }
366 };
367
368 let mut has_pod_attr = false;
369 let mut has_zeroable_attr = false;
370 for attr in account_strct.attrs.iter() {
371 let token_string = attr.tokens.to_string();
372 if token_string.contains("bytemuck :: Pod") {
373 has_pod_attr = true;
374 }
375 if token_string.contains("bytemuck :: Zeroable") {
376 has_zeroable_attr = true;
377 }
378 }
379
380 let pod = if has_pod_attr || is_unsafe {
385 quote! {}
386 } else {
387 quote! {#[derive(::bytemuck::Pod)]}
388 };
389 let zeroable = if has_zeroable_attr || is_unsafe {
390 quote! {}
391 } else {
392 quote! {#[derive(::bytemuck::Zeroable)]}
393 };
394
395 let ret = quote! {
396 #[derive(anchor_lang::__private::ZeroCopyAccessor, Copy, Clone)]
397 #repr
398 #pod
399 #zeroable
400 #account_strct
401 };
402
403 #[cfg(feature = "idl-build")]
404 {
405 let derive_unsafe = if is_unsafe {
406 quote! { #[derive(bytemuck::Unsafe)] }
408 } else {
409 quote! {}
410 };
411 let zc_struct = syn::parse2(quote! {
412 #derive_unsafe
413 #ret
414 })
415 .unwrap();
416 let idl_build_impl = anchor_syn::idl::build::impl_idl_build_struct(&zc_struct);
417 return proc_macro::TokenStream::from(quote! {
418 #ret
419 #idl_build_impl
420 });
421 }
422
423 #[allow(unreachable_code)]
424 proc_macro::TokenStream::from(ret)
425}
426
427#[proc_macro]
430pub fn declare_id(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
431 #[cfg(feature = "idl-build")]
432 let address = input.clone().to_string();
433
434 let id = parse_macro_input!(input as id::Id);
435 let ret = quote! { #id };
436
437 #[cfg(feature = "idl-build")]
438 {
439 let idl_print = anchor_syn::idl::build::gen_idl_print_fn_address(address);
440 return proc_macro::TokenStream::from(quote! {
441 #ret
442 #idl_print
443 });
444 }
445
446 #[allow(unreachable_code)]
447 proc_macro::TokenStream::from(ret)
448}