1use core::mem::MaybeUninit;
8
9use pinocchio::{
10 address::{MAX_SEEDS, PDA_MARKER},
11 error::ProgramError,
12 Address,
13};
14use sha2_const_stable::Sha256;
15
16#[inline(always)]
23pub fn derive_address<const N: usize>(
24 seeds: &[&[u8]; N],
25 bump: Option<u8>,
26 program_id: &[u8; 32],
27) -> [u8; 32] {
28 const {
29 assert!(N < MAX_SEEDS, "number of seeds must be less than MAX_SEEDS");
30 }
31
32 const UNINIT: MaybeUninit<&[u8]> = MaybeUninit::<&[u8]>::uninit();
33 let mut data = [UNINIT; MAX_SEEDS + 2];
34 let mut i = 0;
35
36 while i < N {
37 unsafe {
38 data.get_unchecked_mut(i).write(seeds.get_unchecked(i));
39 }
40 i += 1;
41 }
42
43 let bump_seed = [bump.unwrap_or_default()];
44
45 unsafe {
46 if bump.is_some() {
47 data.get_unchecked_mut(i).write(&bump_seed);
48 i += 1;
49 }
50 data.get_unchecked_mut(i).write(program_id.as_ref());
51 data.get_unchecked_mut(i + 1).write(PDA_MARKER.as_ref());
52 }
53
54 #[cfg(target_os = "solana")]
55 {
56 let mut pda = MaybeUninit::<[u8; 32]>::uninit();
57
58 unsafe {
59 pinocchio::syscalls::sol_sha256(
60 data.as_ptr() as *const u8,
61 (i + 2) as u64,
62 pda.as_mut_ptr() as *mut u8,
63 );
64 }
65
66 unsafe { pda.assume_init() }
67 }
68
69 #[cfg(not(target_os = "solana"))]
70 {
71 let _ = data;
72 unreachable!("deriving a pda is only available on target `solana`");
73 }
74}
75
76#[inline(always)]
81pub const fn derive_address_const<const N: usize>(
82 seeds: &[&[u8]; N],
83 bump: Option<u8>,
84 program_id: &[u8; 32],
85) -> [u8; 32] {
86 const {
87 assert!(N < MAX_SEEDS, "number of seeds must be less than MAX_SEEDS");
88 }
89
90 let mut hasher = Sha256::new();
91 let mut i = 0;
92
93 while i < seeds.len() {
94 hasher = hasher.update(seeds[i]);
95 i += 1;
96 }
97
98 if let Some(bump) = bump {
99 hasher
100 .update(&[bump])
101 .update(program_id)
102 .update(PDA_MARKER)
103 .finalize()
104 } else {
105 hasher.update(program_id).update(PDA_MARKER).finalize()
106 }
107}
108
/// Finds the associated token account address (and its bump) for `wallet` and `mint`
/// under the `crate::programs::TOKEN` token program.
///
/// This is shorthand for `derive_ata_with_program` with `crate::programs::TOKEN`. See that
/// function for the target-specific behavior and error.
#[cfg(feature = "programs")]
#[inline(always)]
pub fn derive_ata(
    wallet: &Address,
    mint: &Address,
) -> Result<(Address, u8), ProgramError> {
    let token_program = &crate::programs::TOKEN;
    derive_ata_with_program(wallet, mint, token_program)
}
118
/// Finds the associated token account address and bump for `wallet`, `mint` and the given
/// `token_program`.
///
/// The seeds are `[wallet, token_program, mint]`, searched under
/// `crate::programs::ASSOCIATED_TOKEN` with `Address::find_program_address`.
///
/// # Errors
///
/// On any target other than `solana` no search is performed and
/// `ProgramError::InvalidSeeds` is returned.
#[cfg(feature = "programs")]
#[inline(always)]
pub fn derive_ata_with_program(
    wallet: &Address,
    mint: &Address,
    token_program: &Address,
) -> Result<(Address, u8), ProgramError> {
    #[cfg(not(target_os = "solana"))]
    {
        let _ = (wallet, mint, token_program);
        Err(ProgramError::InvalidSeeds)
    }

    #[cfg(target_os = "solana")]
    {
        // Associated token account seed order: wallet, token program, mint.
        let ata_seeds: [&[u8]; 3] = [wallet.as_ref(), token_program.as_ref(), mint.as_ref()];
        Ok(Address::find_program_address(
            &ata_seeds,
            &crate::programs::ASSOCIATED_TOKEN,
        ))
    }
}
143
/// Derives the associated token account address for `wallet` and `mint` under the
/// `crate::programs::TOKEN` token program, using an already-known `bump`.
///
/// Skips the bump search done by `derive_ata`. The seeds are
/// `[wallet, TOKEN, mint, bump]` under `crate::programs::ASSOCIATED_TOKEN`. No off-curve
/// check is performed, so `bump` must be one previously obtained from a search.
#[cfg(feature = "programs")]
#[inline(always)]
pub fn derive_ata_with_bump(
    wallet: &Address,
    mint: &Address,
    bump: u8,
) -> Address {
    // Kept in one statement so the borrowed program-id temporaries live for the whole call.
    let address_bytes = derive_address(
        &[
            wallet.as_ref(),
            crate::programs::TOKEN.as_array().as_ref(),
            mint.as_ref(),
        ],
        Some(bump),
        crate::programs::ASSOCIATED_TOKEN.as_array(),
    );
    Address::new_from_array(address_bytes)
}
158
/// Derives an associated token account address from a known bump using the
/// `const fn` `derive_address_const`.
///
/// Expands to
/// `Address::new_from_array(derive_address_const(&[&wallet, &TOKEN_BYTES, &mint], Some(bump), &ATA_BYTES))`,
/// where `TOKEN_BYTES` / `ATA_BYTES` are the `to_bytes()` forms of `crate::programs::TOKEN`
/// and `crate::programs::ASSOCIATED_TOKEN`.
///
/// `$wallet` and `$mint` are borrowed with `&` and must coerce to `&[u8]` (for example
/// `[u8; 32]` arrays). The bump is not validated; no off-curve check is performed.
#[cfg(feature = "programs")]
#[macro_export]
macro_rules! derive_ata_const {
    ($wallet:expr, $mint:expr, $bump:expr) => {{
        // Byte forms of the program ids, held in `const` items so they can be borrowed
        // as a seed and as the program id below.
        const TOKEN_BYTES: [u8; 32] = $crate::programs::TOKEN.to_bytes();
        const ATA_BYTES: [u8; 32] = $crate::programs::ASSOCIATED_TOKEN.to_bytes();
        // Associated token account seed order: wallet, token program, mint.
        ::pinocchio::Address::new_from_array($crate::check::pda::derive_address_const(
            &[&$wallet, &TOKEN_BYTES, &$mint],
            Some($bump),
            &ATA_BYTES,
        ))
    }};
}
173
/// Searches for a program derived address for `$program_id` and the given seeds.
///
/// On `target_os = "solana"` this expands to
/// `pinocchio::Address::find_program_address(&[seed.as_ref(), ...], $program_id)` and
/// evaluates to its `(Address, u8)` result. `$program_id` is passed through unchanged, so
/// it must already be a reference to the program id.
///
/// # Panics
///
/// On any other target the expansion reaches `unreachable!`.
#[macro_export]
macro_rules! find_pda {
    ($program_id:expr, $($seed:expr),+ $(,)?) => {{
        // Exactly one of the two cfg-gated blocks survives; it becomes the tail expression.
        #[cfg(target_os = "solana")]
        {
            let seeds: &[&[u8]] = &[$($seed.as_ref()),+];
            ::pinocchio::Address::find_program_address(seeds, $program_id)
        }
        #[cfg(not(target_os = "solana"))]
        {
            // Mention every argument so off-chain builds do not warn about unused values.
            let _ = ($program_id, $($seed),+);
            unreachable!("find_pda! is only available on target solana")
        }
    }};
}
194
/// Derives a program derived address from a known `$bump`, without searching.
///
/// Expands to
/// `Address::new_from_array(derive_address(&[seed.as_ref(), ...], Some(bump), program_id.as_array()))`.
/// `$program_id` must provide `as_array()` (e.g. an `Address`). No off-curve check is
/// performed, so the bump should come from an earlier search such as `find_pda!`.
/// Target-specific behavior is that of `derive_address`.
#[macro_export]
macro_rules! derive_pda {
    ($program_id:expr, $bump:expr, $($seed:expr),+ $(,)?) => {{
        ::pinocchio::Address::new_from_array($crate::check::pda::derive_address(
            &[$($seed.as_ref()),+],
            Some($bump),
            ($program_id).as_array(),
        ))
    }};
}
208
/// Derives a program derived address from a known `$bump` via the `const fn`
/// `derive_address_const`.
///
/// Every seed and `$program_id` are borrowed with `&`. The seeds must coerce to `&[u8]`,
/// and `$program_id` must be a `[u8; 32]`, because it is passed as `&$program_id` to a
/// `&[u8; 32]` parameter. No off-curve check is performed on the result.
#[macro_export]
macro_rules! derive_pda_const {
    // Expands to a bare expression (single braces), not a block.
    ($program_id:expr, $bump:expr, $($seed:expr),+ $(,)?) => {
        ::pinocchio::Address::new_from_array($crate::check::pda::derive_address_const(
            &[$(&$seed),+],
            Some($bump),
            &$program_id,
        ))
    };
}
220
/// Verifies that `account` is the associated token account of `wallet` and `mint` under
/// the `crate::programs::TOKEN` token program.
///
/// # Errors
///
/// Returns `ProgramError::InvalidSeeds` when the address does not match, and propagates any
/// error from deriving the expected address (off-chain this is always
/// `ProgramError::InvalidSeeds`).
#[cfg(feature = "programs")]
#[inline(always)]
pub fn check_ata(
    account: &pinocchio::AccountView,
    wallet: &Address,
    mint: &Address,
) -> pinocchio::ProgramResult {
    check_ata_with_program(account, wallet, mint, &crate::programs::TOKEN)
}
235
/// Verifies that `account` is the associated token account of `wallet` and `mint` under
/// `token_program`.
///
/// The expected address is found with `derive_ata_with_program`; its bump is discarded.
///
/// # Errors
///
/// Returns `ProgramError::InvalidSeeds` when the address does not match, and propagates any
/// error from `derive_ata_with_program`.
#[cfg(feature = "programs")]
#[inline(always)]
pub fn check_ata_with_program(
    account: &pinocchio::AccountView,
    wallet: &Address,
    mint: &Address,
    token_program: &Address,
) -> pinocchio::ProgramResult {
    let (expected, _bump) = derive_ata_with_program(wallet, mint, token_program)?;
    if account.address() == &expected {
        Ok(())
    } else {
        Err(ProgramError::InvalidSeeds)
    }
}
251
/// Checks that `$account` matches the PDA of `$program_id` and the given seeds.
///
/// Each seed is converted with `.as_ref()` into a `&[&[u8]]` list. The macro then forwards
/// to `crate::check::assert_pda(account, seeds, program_id)` and evaluates to its result.
/// NOTE(review): whether the bump byte must be included among the seeds depends on
/// `assert_pda`, which is not shown here; confirm against that function.
#[macro_export]
macro_rules! require_pda {
    ($account:expr, $program_id:expr, $($seed:expr),+ $(,)?) => {{
        let seeds: &[&[u8]] = &[$($seed.as_ref()),+];
        $crate::check::assert_pda($account, seeds, $program_id)
    }};
}
266}