Skip to main content

jiminy_core/check/
pda.rs

1//! PDA derivation utilities and ATA helpers.
2//!
3//! Macros and helpers for deriving program addresses without manual
4//! seed-array construction. Wraps [`derive_address`], [`derive_address_const`],
5//! and `Address::find_program_address`.
6
7use core::mem::MaybeUninit;
8
9use pinocchio::{
10    address::{MAX_SEEDS, PDA_MARKER},
11    error::ProgramError,
12    Address,
13};
14use sha2_const_stable::Sha256;
15
16/// Derive a [program address](https://solana.com/docs/core/pda) from the
17/// given seeds, optional bump, and program id.
18///
19/// Uses the `sol_sha256` syscall directly - avoids the cost of
20/// `create_program_address` (~1500 CU) at the expense of no curve-point
21/// validation.
22#[inline(always)]
23pub fn derive_address<const N: usize>(
24    seeds: &[&[u8]; N],
25    bump: Option<u8>,
26    program_id: &[u8; 32],
27) -> [u8; 32] {
28    const {
29        assert!(N < MAX_SEEDS, "number of seeds must be less than MAX_SEEDS");
30    }
31
32    const UNINIT: MaybeUninit<&[u8]> = MaybeUninit::<&[u8]>::uninit();
33    let mut data = [UNINIT; MAX_SEEDS + 2];
34    let mut i = 0;
35
36    while i < N {
37        unsafe {
38            data.get_unchecked_mut(i).write(seeds.get_unchecked(i));
39        }
40        i += 1;
41    }
42
43    let bump_seed = [bump.unwrap_or_default()];
44
45    unsafe {
46        if bump.is_some() {
47            data.get_unchecked_mut(i).write(&bump_seed);
48            i += 1;
49        }
50        data.get_unchecked_mut(i).write(program_id.as_ref());
51        data.get_unchecked_mut(i + 1).write(PDA_MARKER.as_ref());
52    }
53
54    #[cfg(target_os = "solana")]
55    {
56        let mut pda = MaybeUninit::<[u8; 32]>::uninit();
57
58        unsafe {
59            pinocchio::syscalls::sol_sha256(
60                data.as_ptr() as *const u8,
61                (i + 2) as u64,
62                pda.as_mut_ptr() as *mut u8,
63            );
64        }
65
66        unsafe { pda.assume_init() }
67    }
68
69    #[cfg(not(target_os = "solana"))]
70    {
71        let _ = data;
72        unreachable!("deriving a pda is only available on target `solana`");
73    }
74}
75
76/// Compile-time version of [`derive_address`].
77///
78/// Uses pure-Rust SHA-256 (`sha2-const-stable`) so the result is computed at
79/// compile time with zero runtime cost.
80#[inline(always)]
81pub const fn derive_address_const<const N: usize>(
82    seeds: &[&[u8]; N],
83    bump: Option<u8>,
84    program_id: &[u8; 32],
85) -> [u8; 32] {
86    const {
87        assert!(N < MAX_SEEDS, "number of seeds must be less than MAX_SEEDS");
88    }
89
90    let mut hasher = Sha256::new();
91    let mut i = 0;
92
93    while i < seeds.len() {
94        hasher = hasher.update(seeds[i]);
95        i += 1;
96    }
97
98    if let Some(bump) = bump {
99        hasher
100            .update(&[bump])
101            .update(program_id)
102            .update(PDA_MARKER)
103            .finalize()
104    } else {
105        hasher.update(program_id).update(PDA_MARKER).finalize()
106    }
107}
108
/// Derive the associated token account (ATA) address for a wallet + mint pair.
///
/// Uses the token program `crate::programs::TOKEN`; this is exactly
/// [`derive_ata_with_program`] with that program. Returns the address together
/// with the canonical bump found by the bump search.
///
/// # Errors
///
/// Propagates any error returned by [`derive_ata_with_program`].
#[cfg(feature = "programs")]
#[inline(always)]
pub fn derive_ata(wallet: &Address, mint: &Address) -> Result<(Address, u8), ProgramError> {
    let token_program: &Address = &crate::programs::TOKEN;
    derive_ata_with_program(wallet, mint, token_program)
}
118
/// Derive an ATA address with an explicit token program (SPL Token or Token-2022).
///
/// The seeds are `[wallet, token_program, mint]` and the deriving program is
/// `crate::programs::ASSOCIATED_TOKEN`. The bump search is performed by
/// `Address::find_program_address`, so the returned bump is the canonical one.
///
/// # Errors
///
/// On targets other than `solana` no derivation is attempted and
/// `ProgramError::InvalidSeeds` is returned. On `solana` this always succeeds.
#[cfg(feature = "programs")]
#[inline(always)]
pub fn derive_ata_with_program(
    wallet: &Address,
    mint: &Address,
    token_program: &Address,
) -> Result<(Address, u8), ProgramError> {
    #[cfg(not(target_os = "solana"))]
    {
        let _ = (wallet, mint, token_program);
        Err(ProgramError::InvalidSeeds)
    }
    #[cfg(target_os = "solana")]
    {
        let ata_seeds: &[&[u8]] = &[wallet.as_ref(), token_program.as_ref(), mint.as_ref()];
        Ok(Address::find_program_address(
            ata_seeds,
            &crate::programs::ASSOCIATED_TOKEN,
        ))
    }
}
143
/// Derive an ATA address with a known bump. Skips the bump search.
///
/// Uses the token program `crate::programs::TOKEN` and the deriving program
/// `crate::programs::ASSOCIATED_TOKEN`, hashing the seeds
/// `[wallet, token_program, mint]` plus `bump` via [`derive_address`]. No
/// curve-point validation is done, so an incorrect bump produces a wrong
/// address rather than an error.
#[cfg(feature = "programs")]
#[inline(always)]
pub fn derive_ata_with_bump(wallet: &Address, mint: &Address, bump: u8) -> Address {
    let token_program: &Address = &crate::programs::TOKEN;
    let ata_program: &Address = &crate::programs::ASSOCIATED_TOKEN;

    let ata_seeds: [&[u8]; 3] = [wallet.as_ref(), token_program.as_ref(), mint.as_ref()];
    let digest = derive_address(&ata_seeds, Some(bump), ata_program.as_array());

    Address::new_from_array(digest)
}
158
/// Derive an ATA address at compile time. Requires known bump.
///
/// `$wallet` and `$mint` are borrowed and passed as seed slices, so they
/// should be byte arrays such as `[u8; 32]`. The token program is
/// `crate::programs::TOKEN` and the deriving program is
/// `crate::programs::ASSOCIATED_TOKEN`; the result is computed with
/// `derive_address_const`, so it is evaluated at compile time when the macro
/// is invoked in a const context (e.g. a `const` initializer).
///
/// The program-id bytes come from anonymous `const { }` blocks rather than
/// named items: `macro_rules!` hygiene does not cover items, so a named
/// constant declared here could capture an identically named identifier in
/// the caller's `$wallet`/`$mint` expressions.
#[cfg(feature = "programs")]
#[macro_export]
macro_rules! derive_ata_const {
    ($wallet:expr, $mint:expr, $bump:expr) => {{
        ::pinocchio::Address::new_from_array($crate::check::pda::derive_address_const(
            &[
                &$wallet,
                &const { $crate::programs::TOKEN.to_bytes() },
                &$mint,
            ],
            Some($bump),
            &const { $crate::programs::ASSOCIATED_TOKEN.to_bytes() },
        ))
    }};
}
173
174// ── Macros ───────────────────────────────────────────────────────────────────
175
/// Find a PDA and return `(Address, u8)` with the canonical bump.
///
/// Each seed is converted with `.as_ref()` into a `&[u8]` and the bump search
/// is delegated to `pinocchio::Address::find_program_address` with
/// `$program_id` (expected to be a `&Address`). Only available on-chain.
///
/// # Panics
///
/// On any target other than `solana` the arguments are evaluated and the
/// macro then hits `unreachable!`.
#[macro_export]
macro_rules! find_pda {
    ($pid:expr, $($s:expr),+ $(,)?) => {{
        #[cfg(not(target_os = "solana"))]
        {
            let _ = ($pid, $($s),+);
            unreachable!("find_pda! is only available on target solana")
        }
        #[cfg(target_os = "solana")]
        {
            let seed_slices: &[&[u8]] = &[$($s.as_ref()),+];
            ::pinocchio::Address::find_program_address(seed_slices, $pid)
        }
    }};
}
194
/// Derive a PDA with a known bump. Cheap (~100 CU, no curve check).
///
/// Expands to a call to [`derive_address`]: every seed is converted via
/// `.as_ref()`, and `Some($bump)` supplies the bump byte that is appended
/// after the seeds. `$program_id` must have an `as_array()` method returning
/// `&[u8; 32]` (e.g. an `Address`). Because there is no curve check, a wrong
/// bump silently yields a wrong address. Returns `Address`.
#[macro_export]
macro_rules! derive_pda {
    ($pid:expr, $b:expr, $($s:expr),+ $(,)?) => {{
        let raw: [u8; 32] = $crate::check::pda::derive_address(
            &[$($s.as_ref()),+],
            Some($b),
            ($pid).as_array(),
        );
        ::pinocchio::Address::new_from_array(raw)
    }};
}
208
/// Derive a PDA at compile time. Requires `const` seeds and bump.
///
/// Expands to [`derive_address_const`] with each seed borrowed as `&$seed`
/// and `Some($bump)` as the bump. Unlike `derive_pda!`, `$program_id` is
/// borrowed directly and passed where a `&[u8; 32]` is expected, so it should
/// be a raw 32-byte array. Evaluated at compile time when invoked in a const
/// context. Returns `Address`.
#[macro_export]
macro_rules! derive_pda_const {
    ($pid:expr, $b:expr, $($s:expr),+ $(,)?) => {
        ::pinocchio::Address::new_from_array(
            $crate::check::pda::derive_address_const(&[$(&$s),+], Some($b), &$pid),
        )
    };
}
220
/// Verify a token account is the correct ATA for a wallet + mint pair.
///
/// Shorthand for [`check_ata_with_program`] using the token program
/// `crate::programs::TOKEN`.
///
/// # Errors
///
/// Returns `ProgramError::InvalidSeeds` when `account`'s address differs from
/// the derived ATA, and propagates any error from the derivation itself.
#[cfg(feature = "programs")]
#[inline(always)]
pub fn check_ata(
    account: &pinocchio::AccountView,
    wallet: &Address,
    mint: &Address,
) -> pinocchio::ProgramResult {
    check_ata_with_program(account, wallet, mint, &crate::programs::TOKEN)
}
235
/// Verify a token account is the correct ATA for a specific token program.
///
/// Derives the expected ATA with [`derive_ata_with_program`] (the bump is
/// discarded) and compares it with `account`'s address.
///
/// # Errors
///
/// Returns `ProgramError::InvalidSeeds` on an address mismatch, and propagates
/// any error from [`derive_ata_with_program`].
#[cfg(feature = "programs")]
#[inline(always)]
pub fn check_ata_with_program(
    account: &pinocchio::AccountView,
    wallet: &Address,
    mint: &Address,
    token_program: &Address,
) -> pinocchio::ProgramResult {
    let (derived, _bump) = derive_ata_with_program(wallet, mint, token_program)?;
    if account.address() == &derived {
        Ok(())
    } else {
        Err(ProgramError::InvalidSeeds)
    }
}
251
/// Derive a PDA from seeds, verify the account matches, and return the bump.
///
/// Wraps [`assert_pda`](crate::check::assert_pda) as a macro so you can pass
/// seeds inline without manual slice construction. Each seed is converted
/// with `.as_ref()` into a `&[u8]`, and whatever `assert_pda` returns is
/// passed through unchanged.
///
/// ```rust,ignore
/// let bump = require_pda!(vault_account, program_id, b"vault", user.address())?;
/// ```
#[macro_export]
macro_rules! require_pda {
    ($acct:expr, $pid:expr, $($s:expr),+ $(,)?) => {{
        let seed_slices: &[&[u8]] = &[$($s.as_ref()),+];
        $crate::check::assert_pda($acct, seed_slices, $pid)
    }};
}