reso_macros/
lib.rs

1use std::str::FromStr;
2
3extern crate proc_macro;
4
5use {
6    proc_macro::TokenStream,
7    proc_macro2::Span,
8    quote::{quote, ToTokens},
9    solana_program::pubkey::Pubkey,
10    std::convert::TryFrom,
11    syn::{
12        parse::{Parse, ParseStream, Result},
13        parse_macro_input, Expr, LitByte, LitStr, Token,
14    },
15};
16
17fn parse_id(input: ParseStream) -> Result<proc_macro2::TokenStream> {
18    let id = if input.peek(syn::LitStr) {
19        let id_literal: LitStr = input.parse()?;
20        parse_pubkey(&id_literal)?
21    } else {
22        let expr: Expr = input.parse()?;
23        quote! { #expr }
24    };
25
26    if !input.is_empty() {
27        let stream: proc_macro2::TokenStream = input.parse()?;
28        return Err(syn::Error::new_spanned(stream, "unexpected token"));
29    }
30    Ok(id)
31}
32
33fn parse_pubkey(id_literal: &LitStr) -> Result<proc_macro2::TokenStream> {
34    let id_vec = bs58::decode(id_literal.value())
35        .into_vec()
36        .map_err(|_| syn::Error::new_spanned(id_literal, "failed to decode base58 string"))?;
37    let id_array = <[u8; 32]>::try_from(<&[u8]>::clone(&&id_vec[..])).map_err(|_| {
38        syn::Error::new_spanned(
39            id_literal,
40            format!("pubkey array is not 32 bytes long: len={}", id_vec.len()),
41        )
42    })?;
43    let bytes = id_array.iter().map(|b| LitByte::new(*b, Span::call_site()));
44    Ok(quote! {
45        Pubkey::new_from_array(
46            [#(#bytes,)*]
47        )
48    })
49}
50
51fn parse_pda(
52    id_literal: &LitStr,
53    program_id: &LitStr,
54    seed: &LitStr,
55) -> Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> {
56    let pda_key = Pubkey::from_str(&id_literal.value())
57        .map_err(|_| syn::Error::new_spanned(id_literal, "failed to decode base58 string"))?;
58    let program_id = Pubkey::from_str(&program_id.value())
59        .map_err(|_| syn::Error::new_spanned(id_literal, "failed to decode base58 string"))?;
60
61    let (computed_key, bump_seed) =
62        Pubkey::find_program_address(&[&seed.value().as_ref()], &program_id);
63
64    if pda_key != computed_key {
65        return Err(syn::Error::new_spanned(
66            id_literal,
67            "provided PDA does not match the computed PDA",
68        ));
69    }
70
71    let pda_token_stream = parse_pubkey(id_literal)?;
72
73    let bump = LitByte::new(bump_seed, Span::call_site());
74    let bump_token_stream = quote! {
75        #bump
76    };
77    Ok((pda_token_stream, bump_token_stream))
78}
79
80fn generate_static_pubkey_code(
81    id: &proc_macro2::TokenStream,
82    tokens: &mut proc_macro2::TokenStream,
83) {
84    tokens.extend(quote! {
85        /// The static program ID
86        pub static ID: Pubkey = #id;
87
88        /// Confirms that a given pubkey is equivalent to the program ID
89        pub fn check_id(id: &Pubkey) -> bool {
90            id == &ID
91        }
92
93        /// Returns the program ID
94        pub fn id() -> Pubkey {
95            ID
96        }
97
98        #[cfg(test)]
99        #[test]
100        fn test_id() {
101            assert!(check_id(&id()));
102        }
103    });
104}
105
106fn generate_static_bump_code(
107    bump: &proc_macro2::TokenStream,
108    tokens: &mut proc_macro2::TokenStream,
109) {
110    tokens.extend(quote! {
111        pub const BUMP: u8 = #bump;
112
113        pub fn bump() -> u8 {
114            BUMP
115        }
116    });
117}
118
119struct Id(proc_macro2::TokenStream);
120
121impl Parse for Id {
122    fn parse(input: ParseStream) -> Result<Self> {
123        parse_id(input).map(Self)
124    }
125}
126
127impl ToTokens for Id {
128    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
129        generate_static_pubkey_code(&self.0, tokens)
130    }
131}
132
133struct ProgramPdaArgs {
134    pda: proc_macro2::TokenStream,
135    bump: proc_macro2::TokenStream,
136}
137
138impl Parse for ProgramPdaArgs {
139    fn parse(input: ParseStream) -> Result<Self> {
140        let pda_address: LitStr = input.parse()?;
141        input.parse::<Token![,]>()?;
142        let program_id: LitStr = input.parse()?;
143        input.parse::<Token![,]>()?;
144        let seed: LitStr = input.parse()?;
145        if !input.is_empty() {
146            return Err(input.error("unexpected token"));
147        }
148        let (pda, bump) = parse_pda(&pda_address, &program_id, &seed)?;
149        Ok(Self { pda, bump })
150    }
151}
152
153impl ToTokens for ProgramPdaArgs {
154    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
155        generate_static_bump_code(&self.bump, tokens);
156        generate_static_pubkey_code(&self.pda, tokens)
157    }
158}
159
160#[proc_macro]
161pub fn declare_id(input: TokenStream) -> TokenStream {
162    let id = parse_macro_input!(input as Id);
163    TokenStream::from(quote! {#id})
164}
165
166#[proc_macro]
167pub fn declare_pda(input: TokenStream) -> TokenStream {
168    let id = parse_macro_input!(input as ProgramPdaArgs);
169    TokenStream::from(quote! {#id})
170}