atlas_sdk_macro/
lib.rs

1//! Convenience macro to declare a static public key and functions to interact with it
2//!
3//! Input: a single literal base58 string representation of a program's id
4#![cfg_attr(docsrs, feature(doc_cfg))]
5
6extern crate proc_macro;
7
8use {
9    proc_macro::TokenStream,
10    proc_macro2::Span,
11    quote::{quote, ToTokens},
12    syn::{
13        bracketed,
14        parse::{Parse, ParseStream, Result},
15        parse_macro_input,
16        punctuated::Punctuated,
17        token::Bracket,
18        Expr, Ident, LitByte, LitStr, Token,
19    },
20};
21
22fn parse_id(
23    input: ParseStream,
24    pubkey_type: proc_macro2::TokenStream,
25) -> Result<proc_macro2::TokenStream> {
26    let id = if input.peek(syn::LitStr) {
27        let id_literal: LitStr = input.parse()?;
28        parse_pubkey(&id_literal, &pubkey_type)?
29    } else {
30        let expr: Expr = input.parse()?;
31        quote! { #expr }
32    };
33
34    if !input.is_empty() {
35        let stream: proc_macro2::TokenStream = input.parse()?;
36        return Err(syn::Error::new_spanned(stream, "unexpected token"));
37    }
38    Ok(id)
39}
40
41fn id_to_tokens(
42    id: &proc_macro2::TokenStream,
43    pubkey_type: proc_macro2::TokenStream,
44    tokens: &mut proc_macro2::TokenStream,
45) {
46    tokens.extend(quote! {
47        /// The const program ID.
48        pub const ID: #pubkey_type = #id;
49
50        /// Returns `true` if given pubkey is the program ID.
51        // TODO make this const once `derive_const` makes it out of nightly
52        // and we can `derive_const(PartialEq)` on `Pubkey`.
53        pub fn check_id(id: &#pubkey_type) -> bool {
54            id == &ID
55        }
56
57        /// Returns the program ID.
58        pub const fn id() -> #pubkey_type {
59            ID
60        }
61
62        #[cfg(test)]
63        #[test]
64        fn test_id() {
65            assert!(check_id(&id()));
66        }
67    });
68}
69
70fn deprecated_id_to_tokens(
71    id: &proc_macro2::TokenStream,
72    pubkey_type: proc_macro2::TokenStream,
73    tokens: &mut proc_macro2::TokenStream,
74) {
75    tokens.extend(quote! {
76        /// The static program ID.
77        pub static ID: #pubkey_type = #id;
78
79        /// Returns `true` if given pubkey is the program ID.
80        #[deprecated()]
81        pub fn check_id(id: &#pubkey_type) -> bool {
82            id == &ID
83        }
84
85        /// Returns the program ID.
86        #[deprecated()]
87        pub fn id() -> #pubkey_type {
88            ID
89        }
90
91        #[cfg(test)]
92        #[test]
93        #[allow(deprecated)]
94        fn test_id() {
95            assert!(check_id(&id()));
96        }
97    });
98}
99
100struct Id(proc_macro2::TokenStream);
101
102impl Parse for Id {
103    fn parse(input: ParseStream) -> Result<Self> {
104        parse_id(input, quote! { ::atlas_sdk::pubkey::Pubkey }).map(Self)
105    }
106}
107
108impl ToTokens for Id {
109    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
110        id_to_tokens(&self.0, quote! { ::atlas_sdk::pubkey::Pubkey }, tokens)
111    }
112}
113
114struct IdDeprecated(proc_macro2::TokenStream);
115
116impl Parse for IdDeprecated {
117    fn parse(input: ParseStream) -> Result<Self> {
118        parse_id(input, quote! { ::atlas_sdk::pubkey::Pubkey }).map(Self)
119    }
120}
121
122impl ToTokens for IdDeprecated {
123    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
124        deprecated_id_to_tokens(&self.0, quote! { ::atlas_sdk::pubkey::Pubkey }, tokens)
125    }
126}
127
128#[proc_macro]
129pub fn declare_id(input: TokenStream) -> TokenStream {
130    let id = parse_macro_input!(input as Id);
131    TokenStream::from(quote! {#id})
132}
133
134#[proc_macro]
135pub fn declare_deprecated_id(input: TokenStream) -> TokenStream {
136    let id = parse_macro_input!(input as IdDeprecated);
137    TokenStream::from(quote! {#id})
138}
139
140fn parse_pubkey(
141    id_literal: &LitStr,
142    pubkey_type: &proc_macro2::TokenStream,
143) -> Result<proc_macro2::TokenStream> {
144    let id_vec = bs58::decode(id_literal.value())
145        .into_vec()
146        .map_err(|_| syn::Error::new_spanned(id_literal, "failed to decode base58 string"))?;
147    let id_array = <[u8; 32]>::try_from(<&[u8]>::clone(&&id_vec[..])).map_err(|_| {
148        syn::Error::new_spanned(
149            id_literal,
150            format!("pubkey array is not 32 bytes long: len={}", id_vec.len()),
151        )
152    })?;
153    let bytes = id_array.iter().map(|b| LitByte::new(*b, Span::call_site()));
154    Ok(quote! {
155        #pubkey_type::new_from_array(
156            [#(#bytes,)*]
157        )
158    })
159}
160
161struct Pubkeys {
162    method: Ident,
163    num: usize,
164    pubkeys: proc_macro2::TokenStream,
165}
166impl Parse for Pubkeys {
167    fn parse(input: ParseStream) -> Result<Self> {
168        let pubkey_type = quote! {
169            ::atlas_sdk::pubkey::Pubkey
170        };
171
172        let method = input.parse()?;
173        let _comma: Token![,] = input.parse()?;
174        let (num, pubkeys) = if input.peek(syn::LitStr) {
175            let id_literal: LitStr = input.parse()?;
176            (1, parse_pubkey(&id_literal, &pubkey_type)?)
177        } else if input.peek(Bracket) {
178            let pubkey_strings;
179            bracketed!(pubkey_strings in input);
180            let punctuated: Punctuated<LitStr, Token![,]> =
181                Punctuated::parse_terminated(&pubkey_strings)?;
182            let mut pubkeys: Punctuated<proc_macro2::TokenStream, Token![,]> = Punctuated::new();
183            for string in punctuated.iter() {
184                pubkeys.push(parse_pubkey(string, &pubkey_type)?);
185            }
186            (pubkeys.len(), quote! {#pubkeys})
187        } else {
188            let stream: proc_macro2::TokenStream = input.parse()?;
189            return Err(syn::Error::new_spanned(stream, "unexpected token"));
190        };
191
192        Ok(Pubkeys {
193            method,
194            num,
195            pubkeys,
196        })
197    }
198}
199
200impl ToTokens for Pubkeys {
201    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
202        let Pubkeys {
203            method,
204            num,
205            pubkeys,
206        } = self;
207
208        let pubkey_type = quote! {
209            ::atlas_sdk::pubkey::Pubkey
210        };
211        if *num == 1 {
212            tokens.extend(quote! {
213                pub fn #method() -> #pubkey_type {
214                    #pubkeys
215                }
216            });
217        } else {
218            tokens.extend(quote! {
219                pub fn #method() -> ::std::vec::Vec<#pubkey_type> {
220                    vec![#pubkeys]
221                }
222            });
223        }
224    }
225}
226
227#[proc_macro]
228pub fn pubkeys(input: TokenStream) -> TokenStream {
229    let pubkeys = parse_macro_input!(input as Pubkeys);
230    TokenStream::from(quote! {#pubkeys})
231}
232
233// Sets padding in structures to zero explicitly.
234// Otherwise padding could be inconsistent across the network and lead to divergence / consensus failures.
235#[proc_macro_derive(CloneZeroed)]
236pub fn derive_clone_zeroed(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
237    match parse_macro_input!(input as syn::Item) {
238        syn::Item::Struct(item_struct) => {
239            let clone_statements = match item_struct.fields {
240                syn::Fields::Named(ref fields) => fields.named.iter().map(|f| {
241                    let name = &f.ident;
242                    quote! {
243                        core::ptr::addr_of_mut!((*ptr).#name).write(self.#name.clone());
244                    }
245                }),
246                _ => unimplemented!(),
247            };
248            let name = &item_struct.ident;
249            quote! {
250                impl Clone for #name {
251                    // Clippy lint `incorrect_clone_impl_on_copy_type` requires that clone
252                    // implementations on `Copy` types are simply wrappers of `Copy`.
253                    // This is not the case here, and intentionally so because we want to
254                    // guarantee zeroed padding.
255                    fn clone(&self) -> Self {
256                        let mut value = core::mem::MaybeUninit::<Self>::uninit();
257                        unsafe {
258                            core::ptr::write_bytes(&mut value, 0, 1);
259                            let ptr = value.as_mut_ptr();
260                            #(#clone_statements)*
261                            value.assume_init()
262                        }
263                    }
264                }
265            }
266        }
267        _ => unimplemented!(),
268    }
269    .into()
270}