solana_address/derive.rs
1use {
2 crate::{Address, MAX_SEEDS, PDA_MARKER},
3 core::{mem::MaybeUninit, slice::from_raw_parts},
4 sha2_const_stable::Sha256,
5 solana_sha256_hasher::hashv,
6};
7
8impl Address {
9 /// Derive a [program address][pda] from the given seeds, optional bump and
10 /// program id.
11 ///
12 /// [pda]: https://solana.com/docs/core/pda
13 ///
14 /// In general, the derivation uses an optional bump (byte) value to ensure a
15 /// valid PDA (off-curve) is generated. Even when a program stores a bump to
16 /// derive a program address, it is necessary to use the
17 /// [`Address::create_program_address`] to validate the derivation. In
18 /// most cases, the program has the correct seeds for the derivation, so it would
19 /// be sufficient to just perform the derivation and compare it against the
20 /// expected resulting address.
21 ///
22 /// This function avoids the cost of the `create_program_address` syscall
23 /// (`1500` compute units) by directly computing the derived address
24 /// calculating the hash of the seeds, bump and program id using the
25 /// `sol_sha256` syscall.
26 ///
27 /// # Important
28 ///
29 /// This function differs from [`Address::create_program_address`] in that
30 /// it does not perform a validation to ensure that the derived address is a valid
31 /// (off-curve) program derived address. It is intended for use in cases where the
32 /// seeds, bump, and program id are known to be valid, and the caller wants to derive
33 /// the address without incurring the cost of the `create_program_address` syscall.
34 #[inline]
35 pub fn derive_address<const N: usize>(
36 seeds: &[&[u8]; N],
37 bump: Option<u8>,
38 program_id: &Address,
39 ) -> Address {
40 const {
41 assert!(N < MAX_SEEDS, "number of seeds must be less than MAX_SEEDS");
42 }
43
44 let mut data = [const { MaybeUninit::<&[u8]>::uninit() }; MAX_SEEDS + 2];
45 let mut i = 0;
46
47 while i < N {
48 // SAFETY: `data` is guaranteed to have enough space for `N` seeds,
49 // so `i` will always be within bounds.
50 unsafe {
51 data.get_unchecked_mut(i).write(seeds.get_unchecked(i));
52 }
53 i += 1;
54 }
55
56 // SAFETY: `data` is guaranteed to have enough space for `MAX_SEEDS + 2`
57 // elements, and `MAX_SEEDS` is larger than `N`.
58 unsafe {
59 if bump.is_some() {
60 data.get_unchecked_mut(i).write(bump.as_slice());
61 i += 1;
62 }
63 data.get_unchecked_mut(i).write(program_id.as_ref());
64 data.get_unchecked_mut(i + 1).write(PDA_MARKER.as_ref());
65 }
66
67 let hash = hashv(unsafe { from_raw_parts(data.as_ptr() as *const &[u8], i + 2) });
68 Address::from(hash.to_bytes())
69 }
70
71 /// Derive a [program address][pda] from the given seeds, optional bump and
72 /// program id.
73 ///
74 /// [pda]: https://solana.com/docs/core/pda
75 ///
76 /// In general, the derivation uses an optional bump (byte) value to ensure a
77 /// valid PDA (off-curve) is generated.
78 ///
79 /// This function is intended for use in `const` contexts - i.e., the seeds and
80 /// bump are known at compile time and the program id is also a constant. It avoids
81 /// the cost of the `create_program_address` syscall (`1500` compute units) by
82 /// directly computing the derived address using the SHA-256 hash of the seeds,
83 /// bump and program id.
84 ///
85 /// # Important
86 ///
87 /// This function differs from [`Address::create_program_address`] in that
88 /// it does not perform a validation to ensure that the derived address is a valid
89 /// (off-curve) program derived address. It is intended for use in cases where the
90 /// seeds, bump, and program id are known to be valid, and the caller wants to derive
91 /// the address without incurring the cost of the `create_program_address` syscall.
92 ///
93 /// This function is a compile-time constant version of [`Address::derive_address`].
94 /// It has worse performance than `derive_address`, so only use this function in
95 /// `const` contexts, where all parameters are known at compile-time.
96 pub const fn derive_address_const<const N: usize>(
97 seeds: &[&[u8]; N],
98 bump: Option<u8>,
99 program_id: &Address,
100 ) -> Address {
101 const {
102 assert!(N < MAX_SEEDS, "number of seeds must be less than MAX_SEEDS");
103 }
104
105 let mut hasher = Sha256::new();
106 let mut i = 0;
107
108 while i < seeds.len() {
109 hasher = hasher.update(seeds[i]);
110 i += 1;
111 }
112
113 // TODO: replace this with `bump.as_slice()` when the MSRV is
114 // upgraded to `1.84.0+`.
115 Address::new_from_array(if let Some(bump) = bump {
116 hasher
117 .update(&[bump])
118 .update(program_id.as_array())
119 .update(PDA_MARKER)
120 .finalize()
121 } else {
122 hasher
123 .update(program_id.as_array())
124 .update(PDA_MARKER)
125 .finalize()
126 })
127 }
128
129 /// Attempt to derive a valid [program derived address][pda] (PDA) and its corresponding
130 /// bump seed.
131 ///
132 /// [pda]: https://solana.com/docs/core/cpi#program-derived-addresses
133 ///
134 /// The main difference between this method and [`Address::derive_address`]
135 /// is that this method iterates through all possible bump seed values (starting from
136 /// `255` and decrementing) until it finds a valid (off-curve) program derived address.
137 ///
138 /// If a valid PDA is found, it returns the PDA and the bump seed used to derive it;
139 /// otherwise, it returns `None`.
140 #[inline]
141 pub fn derive_program_address<const N: usize>(
142 seeds: &[&[u8]; N],
143 program_id: &Address,
144 ) -> Option<(Address, u8)> {
145 let mut bump = u8::MAX;
146
147 loop {
148 let address = Self::derive_address(seeds, Some(bump), program_id);
149
150 // Check if the derived address is a valid (off-curve)
151 // program derived address.
152 if !address.is_on_curve() {
153 return Some((address, bump));
154 }
155
156 // If the derived address is on-curve, decrement the bump and
157 // try again until all possible bump values are tested.
158 if bump == 0 {
159 return None;
160 }
161
162 bump -= 1;
163 }
164 }
165}
166
#[cfg(test)]
mod tests {
    use crate::Address;

    /// `derive_address` and `derive_address_const` must reproduce the address
    /// found by `find_program_address`. This must hold both when the bump is
    /// passed separately and when it is appended as a trailing seed.
    #[test]
    fn test_derive_address() {
        let program_id = Address::new_from_array([1u8; 32]);
        let seeds: &[&[u8]; 2] = &[b"seed1", b"seed2"];
        let (expected, bump) = Address::find_program_address(seeds, &program_id);

        // Bump supplied through the dedicated parameter.
        assert_eq!(expected, Address::derive_address(seeds, Some(bump), &program_id));
        assert_eq!(
            expected,
            Address::derive_address_const(seeds, Some(bump), &program_id)
        );

        // Bump supplied as the last seed, with no separate bump.
        let bump_seed = [bump];
        let seeds_with_bump: &[&[u8]; 3] = &[b"seed1", b"seed2", &bump_seed];
        assert_eq!(
            expected,
            Address::derive_address(seeds_with_bump, None, &program_id)
        );
        assert_eq!(
            expected,
            Address::derive_address_const(seeds_with_bump, None, &program_id)
        );
    }

    /// `derive_program_address` must find the same address and bump as
    /// `find_program_address`.
    #[test]
    fn test_program_derive_address() {
        let program_id = Address::new_unique();
        let seeds: &[&[u8]; 3] = &[b"derived", b"programm", b"address"];

        let expected = Address::find_program_address(seeds, &program_id);
        let (derived, derived_bump) =
            Address::derive_program_address(seeds, &program_id).unwrap();

        assert_eq!(expected, (derived, derived_bump));
    }
}