1use std::str::FromStr;
2
3extern crate proc_macro;
4
5use {
6 proc_macro::TokenStream,
7 proc_macro2::Span,
8 quote::{quote, ToTokens},
9 solana_program::pubkey::Pubkey,
10 std::convert::TryFrom,
11 syn::{
12 parse::{Parse, ParseStream, Result},
13 parse_macro_input, Expr, LitByte, LitStr, Token,
14 },
15};
16
17fn parse_id(input: ParseStream) -> Result<proc_macro2::TokenStream> {
18 let id = if input.peek(syn::LitStr) {
19 let id_literal: LitStr = input.parse()?;
20 parse_pubkey(&id_literal)?
21 } else {
22 let expr: Expr = input.parse()?;
23 quote! { #expr }
24 };
25
26 if !input.is_empty() {
27 let stream: proc_macro2::TokenStream = input.parse()?;
28 return Err(syn::Error::new_spanned(stream, "unexpected token"));
29 }
30 Ok(id)
31}
32
33fn parse_pubkey(id_literal: &LitStr) -> Result<proc_macro2::TokenStream> {
34 let id_vec = bs58::decode(id_literal.value())
35 .into_vec()
36 .map_err(|_| syn::Error::new_spanned(id_literal, "failed to decode base58 string"))?;
37 let id_array = <[u8; 32]>::try_from(<&[u8]>::clone(&&id_vec[..])).map_err(|_| {
38 syn::Error::new_spanned(
39 id_literal,
40 format!("pubkey array is not 32 bytes long: len={}", id_vec.len()),
41 )
42 })?;
43 let bytes = id_array.iter().map(|b| LitByte::new(*b, Span::call_site()));
44 Ok(quote! {
45 Pubkey::new_from_array(
46 [#(#bytes,)*]
47 )
48 })
49}
50
51fn parse_pda(
52 id_literal: &LitStr,
53 program_id: &LitStr,
54 seed: &LitStr,
55) -> Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> {
56 let pda_key = Pubkey::from_str(&id_literal.value())
57 .map_err(|_| syn::Error::new_spanned(id_literal, "failed to decode base58 string"))?;
58 let program_id = Pubkey::from_str(&program_id.value())
59 .map_err(|_| syn::Error::new_spanned(id_literal, "failed to decode base58 string"))?;
60
61 let (computed_key, bump_seed) =
62 Pubkey::find_program_address(&[&seed.value().as_ref()], &program_id);
63
64 if pda_key != computed_key {
65 return Err(syn::Error::new_spanned(
66 id_literal,
67 "provided PDA does not match the computed PDA",
68 ));
69 }
70
71 let pda_token_stream = parse_pubkey(id_literal)?;
72
73 let bump = LitByte::new(bump_seed, Span::call_site());
74 let bump_token_stream = quote! {
75 #bump
76 };
77 Ok((pda_token_stream, bump_token_stream))
78}
79
80fn generate_static_pubkey_code(
81 id: &proc_macro2::TokenStream,
82 tokens: &mut proc_macro2::TokenStream,
83) {
84 tokens.extend(quote! {
85 pub static ID: Pubkey = #id;
87
88 pub fn check_id(id: &Pubkey) -> bool {
90 id == &ID
91 }
92
93 pub fn id() -> Pubkey {
95 ID
96 }
97
98 #[cfg(test)]
99 #[test]
100 fn test_id() {
101 assert!(check_id(&id()));
102 }
103 });
104}
105
106fn generate_static_bump_code(
107 bump: &proc_macro2::TokenStream,
108 tokens: &mut proc_macro2::TokenStream,
109) {
110 tokens.extend(quote! {
111 pub const BUMP: u8 = #bump;
112
113 pub fn bump() -> u8 {
114 BUMP
115 }
116 });
117}
118
119struct Id(proc_macro2::TokenStream);
120
121impl Parse for Id {
122 fn parse(input: ParseStream) -> Result<Self> {
123 parse_id(input).map(Self)
124 }
125}
126
127impl ToTokens for Id {
128 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
129 generate_static_pubkey_code(&self.0, tokens)
130 }
131}
132
133struct ProgramPdaArgs {
134 pda: proc_macro2::TokenStream,
135 bump: proc_macro2::TokenStream,
136}
137
138impl Parse for ProgramPdaArgs {
139 fn parse(input: ParseStream) -> Result<Self> {
140 let pda_address: LitStr = input.parse()?;
141 input.parse::<Token![,]>()?;
142 let program_id: LitStr = input.parse()?;
143 input.parse::<Token![,]>()?;
144 let seed: LitStr = input.parse()?;
145 if !input.is_empty() {
146 return Err(input.error("unexpected token"));
147 }
148 let (pda, bump) = parse_pda(&pda_address, &program_id, &seed)?;
149 Ok(Self { pda, bump })
150 }
151}
152
153impl ToTokens for ProgramPdaArgs {
154 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
155 generate_static_bump_code(&self.bump, tokens);
156 generate_static_pubkey_code(&self.pda, tokens)
157 }
158}
159
160#[proc_macro]
161pub fn declare_id(input: TokenStream) -> TokenStream {
162 let id = parse_macro_input!(input as Id);
163 TokenStream::from(quote! {#id})
164}
165
166#[proc_macro]
167pub fn declare_pda(input: TokenStream) -> TokenStream {
168 let id = parse_macro_input!(input as ProgramPdaArgs);
169 TokenStream::from(quote! {#id})
170}