Skip to main content

hopper_native/
pda.rs

//! PDA (Program Derived Address) helpers.
//!
//! PDA creation and derivation built directly on Solana syscalls, with no external dependencies.
4
5use crate::account_view::AccountView;
6use crate::address::{Address, MAX_SEEDS};
7use crate::error::ProgramError;
8
/// Suffix hashed after the seeds and the program id when computing a PDA
/// (an alias of `crate::address::PDA_MARKER`).
#[cfg(target_os = "solana")]
const PDA_MARKER_BYTES: &[u8; 21] = crate::address::PDA_MARKER;

/// Curve identifier passed to `sol_curve_validate_point`: Curve25519 in
/// Edwards form.
#[cfg(target_os = "solana")]
const CURVE25519_EDWARDS: u64 = 0;
13
14/// Create a program-derived address from seeds and a program ID.
15///
16/// Returns `Err(InvalidSeeds)` if the derived address falls on the
17/// ed25519 curve (not a valid PDA).
18#[inline(always)]
19pub fn create_program_address(
20    seeds: &[&[u8]],
21    program_id: &Address,
22) -> Result<Address, ProgramError> {
23    #[cfg(target_os = "solana")]
24    {
25        // Build the seeds array in the format expected by the syscall:
26        // each seed is a (ptr, len) pair packed as two u64 values.
27        let mut seed_buf: [u64; 32] = [0; 32]; // MAX_SEEDS * 2
28        let num_seeds = seeds.len().min(16);
29        let mut i = 0;
30        while i < num_seeds {
31            seed_buf[i * 2] = seeds[i].as_ptr() as u64;
32            seed_buf[i * 2 + 1] = seeds[i].len() as u64;
33            i += 1;
34        }
35
36        let mut result = Address::default();
37        let rc = unsafe {
38            crate::syscalls::sol_create_program_address(
39                seed_buf.as_ptr() as *const u8,
40                num_seeds as u64,
41                program_id.as_array().as_ptr(),
42                result.0.as_mut_ptr(),
43            )
44        };
45        if rc == 0 {
46            Ok(result)
47        } else {
48            Err(ProgramError::InvalidSeeds)
49        }
50    }
51    #[cfg(not(target_os = "solana"))]
52    {
53        let _ = (seeds, program_id);
54        Err(ProgramError::InvalidSeeds)
55    }
56}
57
58/// Find a program-derived address and its bump seed.
59///
60/// Iterates bump seeds 255..=0 until a valid PDA is found.
61#[inline(always)]
62pub fn find_program_address(seeds: &[&[u8]], program_id: &Address) -> (Address, u8) {
63    #[cfg(target_os = "solana")]
64    {
65        based_try_find_program_address(seeds, program_id).unwrap_or((Address::default(), 0))
66    }
67    #[cfg(not(target_os = "solana"))]
68    {
69        let _ = (seeds, program_id);
70        (Address::default(), 0)
71    }
72}
73
74/// Verify that an expected address matches the PDA hash for the provided seeds.
75///
76/// The seeds slice must already include the bump byte.
77#[inline(always)]
78pub fn verify_program_address(
79    seeds: &[&[u8]],
80    program_id: &Address,
81    expected: &Address,
82) -> Result<(), ProgramError> {
83    if seeds.len() > MAX_SEEDS + 1 {
84        return Err(ProgramError::InvalidSeeds);
85    }
86
87    #[cfg(target_os = "solana")]
88    {
89        let n = seeds.len();
90        let mut slices = core::mem::MaybeUninit::<[&[u8]; MAX_SEEDS + 3]>::uninit();
91        let slice_ptr = slices.as_mut_ptr() as *mut &[u8];
92
93        let mut i = 0;
94        while i < n {
95            unsafe { slice_ptr.add(i).write(seeds[i]) };
96            i += 1;
97        }
98        unsafe {
99            slice_ptr.add(n).write(program_id.as_ref());
100            slice_ptr.add(n + 1).write(PDA_MARKER_BYTES.as_slice());
101        }
102
103        let input = unsafe { core::slice::from_raw_parts(slice_ptr, n + 2) };
104        let mut hash = core::mem::MaybeUninit::<[u8; 32]>::uninit();
105
106        unsafe {
107            crate::syscalls::sol_sha256(
108                input as *const _ as *const u8,
109                input.len() as u64,
110                hash.as_mut_ptr() as *mut u8,
111            );
112        }
113
114        let derived = unsafe { &*(hash.as_ptr() as *const Address) };
115        if derived == expected {
116            Ok(())
117        } else {
118            Err(ProgramError::InvalidSeeds)
119        }
120    }
121    #[cfg(not(target_os = "solana"))]
122    {
123        let _ = (seeds, program_id, expected);
124        Err(ProgramError::InvalidSeeds)
125    }
126}
127
128/// Find a valid PDA by hashing seeds directly and checking curve validity.
129///
130/// This avoids the `sol_try_find_program_address` syscall and substantially
131/// reduces the per-attempt CU cost on SBF.
132#[inline(always)]
133pub fn based_try_find_program_address(
134    seeds: &[&[u8]],
135    program_id: &Address,
136) -> Result<(Address, u8), ProgramError> {
137    if seeds.len() > MAX_SEEDS {
138        return Err(ProgramError::InvalidSeeds);
139    }
140
141    #[cfg(target_os = "solana")]
142    {
143        let n = seeds.len();
144        let mut slices = core::mem::MaybeUninit::<[&[u8]; MAX_SEEDS + 3]>::uninit();
145        let slice_ptr = slices.as_mut_ptr() as *mut &[u8];
146
147        let mut i = 0;
148        while i < n {
149            unsafe { slice_ptr.add(i).write(seeds[i]) };
150            i += 1;
151        }
152        unsafe {
153            slice_ptr.add(n + 1).write(program_id.as_ref());
154            slice_ptr.add(n + 2).write(PDA_MARKER_BYTES.as_slice());
155        }
156
157        let mut bump_seed = [u8::MAX];
158        let bump_ptr = bump_seed.as_mut_ptr();
159        unsafe {
160            slice_ptr
161                .add(n)
162                .write(core::slice::from_raw_parts(bump_ptr, 1))
163        };
164
165        let input = unsafe { core::slice::from_raw_parts(slice_ptr, n + 3) };
166        let mut hash = core::mem::MaybeUninit::<[u8; 32]>::uninit();
167        let mut bump: u64 = u8::MAX as u64;
168
169        loop {
170            unsafe { bump_ptr.write(bump as u8) };
171
172            unsafe {
173                crate::syscalls::sol_sha256(
174                    input as *const _ as *const u8,
175                    input.len() as u64,
176                    hash.as_mut_ptr() as *mut u8,
177                );
178            }
179
180            let on_curve = unsafe {
181                crate::syscalls::sol_curve_validate_point(
182                    CURVE25519_EDWARDS,
183                    hash.as_ptr() as *const u8,
184                    core::ptr::null_mut(),
185                )
186            };
187
188            if on_curve != 0 {
189                return Ok((
190                    Address::new_from_array(unsafe { hash.assume_init() }),
191                    bump as u8,
192                ));
193            }
194
195            if bump == 0 {
196                break;
197            }
198            bump -= 1;
199        }
200
201        Err(ProgramError::InvalidSeeds)
202    }
203    #[cfg(not(target_os = "solana"))]
204    {
205        let _ = (seeds, program_id);
206        Err(ProgramError::InvalidSeeds)
207    }
208}
209
210/// Verify that an account's address matches a PDA derived from the given seeds.
211///
212/// Returns `Ok(())` if the account address matches the derived PDA,
213/// or `Err(InvalidSeeds)` if it does not.
214#[inline(always)]
215pub fn verify_pda(
216    account: &AccountView,
217    seeds: &[&[u8]],
218    program_id: &Address,
219) -> Result<(), ProgramError> {
220    let expected = create_program_address(seeds, program_id)?;
221    if account.address() == &expected {
222        Ok(())
223    } else {
224        Err(ProgramError::InvalidSeeds)
225    }
226}
227
228/// Verify a PDA with an explicit bump seed appended to the seeds.
229///
230/// Appends `&[bump]` to the end of the seed list before verifying via
231/// SHA-256 (~200 CU). This is substantially cheaper than the syscall-based
232/// `create_program_address` approach (~1500 CU).
233#[inline]
234pub fn verify_pda_with_bump(
235    account: &AccountView,
236    seeds: &[&[u8]],
237    bump: u8,
238    program_id: &Address,
239) -> Result<(), ProgramError> {
240    // Build a seed list with the bump appended.
241    // We use a stack-allocated array since MAX_SEEDS is 16.
242    let mut full_seeds: [&[u8]; 17] = [&[]; 17];
243    let num = seeds.len().min(15);
244    let mut i = 0;
245    while i < num {
246        full_seeds[i] = seeds[i];
247        i += 1;
248    }
249    let bump_bytes = [bump];
250    full_seeds[num] = &bump_bytes;
251
252    verify_program_address(&full_seeds[..num + 1], program_id, account.address())
253}
254
255/// Verify that an address matches a PDA derived from the given seeds.
256///
257/// Unlike `verify_pda` which takes an `AccountView`, this accepts a raw
258/// `Address` reference directly. Useful when validating addresses outside
259/// of the account parsing flow (e.g. instruction data, cross-program reads).
260///
261/// The seeds slice must already include the bump byte (like
262/// `verify_program_address`). Uses SHA-256 verify-only path (~200 CU)
263/// instead of the full `find_program_address` (~1500 CU).
264///
265/// Returns `Ok(())` if the address matches the derived PDA,
266/// or `Err(InvalidSeeds)` if it does not.
267#[inline]
268pub fn verify_pda_strict(
269    expected: &Address,
270    seeds: &[&[u8]],
271    program_id: &Address,
272) -> Result<(), ProgramError> {
273    verify_program_address(seeds, program_id, expected)
274}
275
276/// Find the bump seed for a known PDA address, skipping curve validation.
277///
278/// When you already know the expected address (e.g. from a transaction
279/// account), there is no need to validate the derived hash is off-curve.
280/// If the hash matches `expected` and the account exists on-chain, it
281/// must be a valid PDA. This saves ~90 CU per attempt compared to
282/// `based_try_find_program_address` which calls `sol_curve_validate_point`.
283///
284/// Returns the bump seed, or `Err(InvalidSeeds)` if no bump produces a match.
285#[inline(always)]
286pub fn find_bump_for_address(
287    seeds: &[&[u8]],
288    program_id: &Address,
289    expected: &Address,
290) -> Result<u8, ProgramError> {
291    if seeds.len() > MAX_SEEDS {
292        return Err(ProgramError::InvalidSeeds);
293    }
294
295    #[cfg(target_os = "solana")]
296    {
297        let n = seeds.len();
298        let mut slices = core::mem::MaybeUninit::<[&[u8]; MAX_SEEDS + 3]>::uninit();
299        let slice_ptr = slices.as_mut_ptr() as *mut &[u8];
300
301        let mut i = 0;
302        while i < n {
303            unsafe { slice_ptr.add(i).write(seeds[i]) };
304            i += 1;
305        }
306        unsafe {
307            slice_ptr.add(n + 1).write(program_id.as_ref());
308            slice_ptr.add(n + 2).write(PDA_MARKER_BYTES.as_slice());
309        }
310
311        let mut bump_seed = [u8::MAX];
312        let bump_ptr = bump_seed.as_mut_ptr();
313        unsafe {
314            slice_ptr
315                .add(n)
316                .write(core::slice::from_raw_parts(bump_ptr, 1))
317        };
318
319        let input = unsafe { core::slice::from_raw_parts(slice_ptr, n + 3) };
320        let mut hash = core::mem::MaybeUninit::<[u8; 32]>::uninit();
321        let mut bump: u64 = u8::MAX as u64;
322
323        loop {
324            unsafe { bump_ptr.write(bump as u8) };
325
326            unsafe {
327                crate::syscalls::sol_sha256(
328                    input as *const _ as *const u8,
329                    input.len() as u64,
330                    hash.as_mut_ptr() as *mut u8,
331                );
332            }
333
334            // Address-match shortcut: skip curve check entirely.
335            // If the hash matches the expected address and that address
336            // exists on-chain, it is guaranteed to be a valid PDA.
337            let derived = unsafe { &*(hash.as_ptr() as *const Address) };
338            if derived == expected {
339                return Ok(bump as u8);
340            }
341
342            if bump == 0 {
343                break;
344            }
345            bump -= 1;
346        }
347
348        Err(ProgramError::InvalidSeeds)
349    }
350    #[cfg(not(target_os = "solana"))]
351    {
352        let _ = (seeds, program_id, expected);
353        Err(ProgramError::InvalidSeeds)
354    }
355}
356
357/// Read the bump byte directly from account data at a known offset.
358///
359/// Used with `BUMP_OFFSET` from `hopper_layout!` types to read the stored
360/// bump without any derivation. Combined with `verify_program_address`,
361/// the total PDA verification cost is ~200 CU vs ~1500 CU for
362/// `find_program_address`.
363///
364/// Returns `Err(AccountDataTooSmall)` if the account data is shorter than
365/// `bump_offset + 1`.
366#[inline(always)]
367pub fn read_bump_from_account(
368    account: &AccountView,
369    bump_offset: usize,
370) -> Result<u8, ProgramError> {
371    let data = account.try_borrow()?;
372    if data.len() <= bump_offset {
373        return Err(ProgramError::AccountDataTooSmall);
374    }
375    Ok(data[bump_offset])
376}
377
378/// Verify a PDA using the bump stored in account data (cheapest path).
379///
380/// Reads the bump at `bump_offset`, appends it to seeds, then uses
381/// SHA-256 verify-only. Total cost: ~200 CU vs ~1500 CU.
382///
383/// This is the optimal PDA verification path and should be the default
384/// for Hopper programs that store bumps in their account layout.
385#[inline]
386pub fn verify_pda_from_stored_bump(
387    account: &AccountView,
388    seeds: &[&[u8]],
389    bump_offset: usize,
390    program_id: &Address,
391) -> Result<(), ProgramError> {
392    let bump = read_bump_from_account(account, bump_offset)?;
393
394    let mut full_seeds: [&[u8]; 17] = [&[]; 17];
395    let num = seeds.len().min(15);
396    let mut i = 0;
397    while i < num {
398        full_seeds[i] = seeds[i];
399        i += 1;
400    }
401    let bump_bytes = [bump];
402    full_seeds[num] = &bump_bytes;
403
404    verify_program_address(&full_seeds[..num + 1], program_id, account.address())
405}