Skip to main content

ephemeral_vrf_sdk_vrf_macro/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{parse_macro_input, ItemStruct};
5
6/// Resolve the base path to the VRF SDK that generated code references.
7///
8/// The generated code needs an absolute path to the VRF SDK. Users may depend on it in two ways:
9/// - directly, as `ephemeral-vrf-sdk` (the macro emits `::ephemeral_vrf_sdk`), or
10/// - via `ephemeral-rollups-sdk`, which re-exports it under `::ephemeral_rollups_sdk::vrf`.
11///
12/// The `rollups` feature selects between the two. `ephemeral-rollups-sdk` enables it (see its
13/// `anchor-support` feature) so users depending only on the rollups SDK build without a direct
14/// `ephemeral-vrf-sdk` dependency. With the feature off (direct VRF SDK users), the historical
15/// `::ephemeral_vrf_sdk` path is emitted unchanged.
16fn vrf_sdk_path() -> TokenStream2 {
17    if cfg!(feature = "rollups") {
18        quote!(::ephemeral_rollups_sdk::vrf)
19    } else {
20        quote!(::ephemeral_vrf_sdk)
21    }
22}
23
24#[proc_macro_attribute]
25pub fn vrf(_attr: TokenStream, item: TokenStream) -> TokenStream {
26    let input = parse_macro_input!(item as ItemStruct);
27
28    let vrf = vrf_sdk_path();
29    let unchecked_account = generated_unchecked_account_type();
30    let struct_name = &input.ident;
31    let fields = &input.fields;
32    let original_attrs = &input.attrs;
33    let mut new_fields = Vec::new();
34    let mut has_program_identity = false;
35    let mut has_slot_hashes = false;
36    let mut has_vrf_program = false;
37    let mut has_system_program = false;
38
39    for field in fields.iter() {
40        let field_attrs = field.attrs.clone();
41
42        let field_name = match &field.ident {
43            Some(name) => name,
44            None => {
45                return syn::Error::new_spanned(
46                    field,
47                    "Unnamed fields are not supported in this macro",
48                )
49                .to_compile_error()
50                .into();
51            }
52        };
53
54        let field_type = &field.ty;
55        new_fields.push(quote! {
56            #(#field_attrs)*
57            pub #field_name: #field_type,
58        });
59
60        // Check for existing required fields
61        if field_name.eq("program_identity") {
62            has_program_identity = true;
63        }
64        if field_name.eq("vrf_program") {
65            has_vrf_program = true;
66        }
67        if field_name.eq("slot_hashes") {
68            has_slot_hashes = true;
69        }
70        if field_name.eq("system_program") {
71            has_system_program = true;
72        }
73    }
74
75    // Add missing required fields
76    if !has_program_identity {
77        new_fields.push(quote! {
78            /// CHECK: Used to verify the identity of the program
79            #[account(seeds = [b"identity"], bump)]
80            pub program_identity: #unchecked_account,
81        });
82    }
83    if !has_vrf_program {
84        new_fields.push(quote! {
85            pub vrf_program: Program<'info, #vrf::anchor::VrfProgram>,
86        });
87    }
88    if !has_slot_hashes {
89        new_fields.push(quote! {
90            /// CHECK: Slot hashes sysvar
91            #[account(address = #vrf::compat::slot_hashes::ID)]
92            pub slot_hashes: #unchecked_account,
93        });
94    }
95    if !has_system_program {
96        new_fields.push(quote! {
97            pub system_program: Program<'info, System>,
98        });
99    }
100
101    // Generate the new struct definition
102    let expanded = quote! {
103        #(#original_attrs)*
104        pub struct #struct_name<'info> {
105            #(#new_fields)*
106        }
107
108        impl<'info> #struct_name<'info> {
109            fn invoke_signed_vrf<'a>(&self, payer: &'a AccountInfo<'info>, ix: &#vrf::compat::Instruction) -> #vrf::compat::anchor_lang::solana_program::entrypoint::ProgramResult {
110                let bump = Pubkey::try_find_program_address(&[#vrf::consts::IDENTITY], &crate::ID).ok_or(#vrf::compat::anchor_lang::prelude::ProgramError::InvalidSeeds)?;
111                // `#[vrf]` issues scoped randomness requests by default: the fulfillment signs
112                // the callback with the per-program scoped identity PDA, which the callback
113                // validates (see `#[vrf_callback]`). Map any legacy request discriminator to its
114                // scoped equivalent (3/11 -> 11 high priority, else -> 10).
115                let mut ix = ix.clone();
116                if let Some(disc) = ix.data.first_mut() {
117                    *disc = if *disc == 3 || *disc == 11 { 11 } else { 10 };
118                }
119                #vrf::compat::anchor_lang::solana_program::program::invoke_signed(
120                    &ix,
121                    &[
122                        payer.clone(),
123                        self.program_identity.to_account_info(),
124                        self.oracle_queue.to_account_info(),
125                        self.slot_hashes.to_account_info(),
126                    ],
127                    &[&[#vrf::consts::IDENTITY, &[bump.1]]],
128                )
129            }
130        }
131    };
132
133    TokenStream::from(expanded)
134}
135
136/// Attribute macro for a callback (consume) `#[derive(Accounts)]` struct.
137///
138/// Injects a `vrf_program_identity: Signer<'info>` constrained to the scoped per-program VRF
139/// identity PDA (`scoped_vrf_identity(&crate::ID)`). This is the default way to authenticate
140/// the VRF program in a callback; the identity is bound to this program. The legacy
141/// global-identity check (`address = VRF_PROGRAM_IDENTITY`) is deprecated.
142///
143/// Place `#[vrf_callback]` ABOVE `#[derive(Accounts)]`.
144#[proc_macro_attribute]
145pub fn vrf_callback(_attr: TokenStream, item: TokenStream) -> TokenStream {
146    let input = parse_macro_input!(item as ItemStruct);
147    let vrf = vrf_sdk_path();
148    let struct_name = &input.ident;
149    let original_attrs = &input.attrs;
150
151    let mut new_fields = Vec::new();
152    let mut has_identity = false;
153    for field in input.fields.iter() {
154        let field_attrs = field.attrs.clone();
155        let field_name = match &field.ident {
156            Some(name) => name,
157            None => {
158                return syn::Error::new_spanned(field, "Unnamed fields are not supported")
159                    .to_compile_error()
160                    .into();
161            }
162        };
163        let field_type = &field.ty;
164        new_fields.push(quote! {
165            #(#field_attrs)*
166            pub #field_name: #field_type,
167        });
168        if field_name.eq("vrf_program_identity") {
169            has_identity = true;
170        }
171    }
172
173    if !has_identity {
174        new_fields.insert(
175            0,
176            quote! {
177                /// Scoped VRF identity PDA, bound to this program. Its presence as a signer proves
178                /// the callback was issued by the VRF program for this program.
179                #[account(address = #vrf::consts::scoped_vrf_identity(&crate::ID))]
180                pub vrf_program_identity: Signer<'info>,
181            },
182        );
183    }
184
185    let expanded = quote! {
186        #(#original_attrs)*
187        pub struct #struct_name<'info> {
188            #(#new_fields)*
189        }
190    };
191
192    TokenStream::from(expanded)
193}
194
195fn generated_unchecked_account_type() -> proc_macro2::TokenStream {
196    if cfg!(feature = "backward-compat") {
197        quote! { AccountInfo<'info> }
198    } else {
199        quote! { UncheckedAccount<'info> }
200    }
201}