solana_address/derive.rs
1use {
2 crate::{Address, MAX_SEEDS, PDA_MARKER},
3 core::{mem::MaybeUninit, slice::from_raw_parts},
4 sha2_const_stable::Sha256,
5 solana_sha256_hasher::hashv,
6};
7
8impl Address {
9 /// Derive a [program address][pda] from the given seeds, optional bump and
10 /// program id.
11 ///
12 /// [pda]: https://solana.com/docs/core/pda
13 ///
14 /// In general, the derivation uses an optional bump (byte) value to ensure a
15 /// valid PDA (off-curve) is generated. Even when a program stores a bump to
16 /// derive a program address, it is necessary to use the
17 /// [`Address::create_program_address`] to validate the derivation. In
18 /// most cases, the program has the correct seeds for the derivation, so it would
19 /// be sufficient to just perform the derivation and compare it against the
20 /// expected resulting address.
21 ///
22 /// This function avoids the cost of the `create_program_address` syscall
23 /// (`1500` compute units) by directly computing the derived address
24 /// calculating the hash of the seeds, bump and program id using the
25 /// `sol_sha256` syscall.
26 ///
27 /// # Important
28 ///
29 /// This function differs from [`Address::create_program_address`] in that
30 /// it does not perform a validation to ensure that the derived address is a valid
31 /// (off-curve) program derived address. It is intended for use in cases where the
32 /// seeds, bump, and program id are known to be valid, and the caller wants to derive
33 /// the address without incurring the cost of the `create_program_address` syscall.
34 pub fn derive_address<const N: usize>(
35 seeds: &[&[u8]; N],
36 bump: Option<u8>,
37 program_id: &Address,
38 ) -> Address {
39 const {
40 assert!(N < MAX_SEEDS, "number of seeds must be less than MAX_SEEDS");
41 }
42
43 let mut data = [const { MaybeUninit::<&[u8]>::uninit() }; MAX_SEEDS + 2];
44 let mut i = 0;
45
46 while i < N {
47 // SAFETY: `data` is guaranteed to have enough space for `N` seeds,
48 // so `i` will always be within bounds.
49 unsafe {
50 data.get_unchecked_mut(i).write(seeds.get_unchecked(i));
51 }
52 i += 1;
53 }
54
55 // SAFETY: `data` is guaranteed to have enough space for `MAX_SEEDS + 2`
56 // elements, and `MAX_SEEDS` is larger than `N`.
57 unsafe {
58 if bump.is_some() {
59 data.get_unchecked_mut(i).write(bump.as_slice());
60 i += 1;
61 }
62 data.get_unchecked_mut(i).write(program_id.as_ref());
63 data.get_unchecked_mut(i + 1).write(PDA_MARKER.as_ref());
64 }
65
66 let hash = hashv(unsafe { from_raw_parts(data.as_ptr() as *const &[u8], i + 2) });
67 Address::from(hash.to_bytes())
68 }
69
70 /// Derive a [program address][pda] from the given seeds, optional bump and
71 /// program id.
72 ///
73 /// [pda]: https://solana.com/docs/core/pda
74 ///
75 /// In general, the derivation uses an optional bump (byte) value to ensure a
76 /// valid PDA (off-curve) is generated.
77 ///
78 /// This function is intended for use in `const` contexts - i.e., the seeds and
79 /// bump are known at compile time and the program id is also a constant. It avoids
80 /// the cost of the `create_program_address` syscall (`1500` compute units) by
81 /// directly computing the derived address using the SHA-256 hash of the seeds,
82 /// bump and program id.
83 ///
84 /// # Important
85 ///
86 /// This function differs from [`Address::create_program_address`] in that
87 /// it does not perform a validation to ensure that the derived address is a valid
88 /// (off-curve) program derived address. It is intended for use in cases where the
89 /// seeds, bump, and program id are known to be valid, and the caller wants to derive
90 /// the address without incurring the cost of the `create_program_address` syscall.
91 ///
92 /// This function is a compile-time constant version of [`Address::derive_address`].
93 /// It has worse performance than `derive_address`, so only use this function in
94 /// `const` contexts, where all parameters are known at compile-time.
95 pub const fn derive_address_const<const N: usize>(
96 seeds: &[&[u8]; N],
97 bump: Option<u8>,
98 program_id: &Address,
99 ) -> Address {
100 const {
101 assert!(N < MAX_SEEDS, "number of seeds must be less than MAX_SEEDS");
102 }
103
104 let mut hasher = Sha256::new();
105 let mut i = 0;
106
107 while i < seeds.len() {
108 hasher = hasher.update(seeds[i]);
109 i += 1;
110 }
111
112 // TODO: replace this with `bump.as_slice()` when the MSRV is
113 // upgraded to `1.84.0+`.
114 Address::new_from_array(if let Some(bump) = bump {
115 hasher
116 .update(&[bump])
117 .update(program_id.as_array())
118 .update(PDA_MARKER)
119 .finalize()
120 } else {
121 hasher
122 .update(program_id.as_array())
123 .update(PDA_MARKER)
124 .finalize()
125 })
126 }
127}
128
#[cfg(test)]
mod tests {
    use crate::{Address, MAX_SEEDS};

    /// Both derivation functions agree with `find_program_address`, whether the
    /// bump is passed separately or appended as the last seed.
    #[test]
    fn test_derive_address() {
        let program_id = Address::new_from_array([1u8; 32]);
        let seeds: &[&[u8]; 2] = &[b"seed1", b"seed2"];
        let (address, bump) = Address::find_program_address(seeds, &program_id);

        let derived_address = Address::derive_address(seeds, Some(bump), &program_id);
        let derived_address_const = Address::derive_address_const(seeds, Some(bump), &program_id);

        assert_eq!(address, derived_address);
        assert_eq!(address, derived_address_const);

        let extended_seeds: &[&[u8]; 3] = &[b"seed1", b"seed2", &[bump]];

        let derived_address = Address::derive_address(extended_seeds, None, &program_id);
        let derived_address_const =
            Address::derive_address_const(extended_seeds, None, &program_id);

        assert_eq!(address, derived_address);
        assert_eq!(address, derived_address_const);
    }

    /// Deriving with no seeds at all (only bump, program id and marker).
    #[test]
    fn test_derive_address_no_seeds() {
        let program_id = Address::new_from_array([3u8; 32]);
        let (address, bump) = Address::find_program_address(&[], &program_id);

        assert_eq!(address, Address::derive_address(&[], Some(bump), &program_id));
        assert_eq!(
            address,
            Address::derive_address_const(&[], Some(bump), &program_id)
        );

        // Without a bump, both implementations must still agree.
        assert_eq!(
            Address::derive_address(&[], None, &program_id),
            Address::derive_address_const(&[], None, &program_id),
        );
    }

    /// The largest accepted seed count plus a bump fills every slot of the
    /// internal buffer used by `derive_address`.
    #[test]
    fn test_derive_address_max_seeds() {
        let program_id = Address::new_from_array([4u8; 32]);
        let seed: &[u8] = b"seed";
        let seeds = [seed; MAX_SEEDS - 1];
        let (address, bump) = Address::find_program_address(&seeds, &program_id);

        assert_eq!(
            address,
            Address::derive_address(&seeds, Some(bump), &program_id)
        );
        assert_eq!(
            address,
            Address::derive_address_const(&seeds, Some(bump), &program_id)
        );
    }

    /// `derive_address_const` is usable to initialize a `const` item and yields
    /// the same address as the runtime version.
    #[test]
    fn test_derive_address_const_in_const_context() {
        const PROGRAM_ID: Address = Address::new_from_array([2u8; 32]);
        const SEEDS: &[&[u8]; 2] = &[b"seed1", b"seed2"];
        const DERIVED: Address = Address::derive_address_const(SEEDS, Some(254), &PROGRAM_ID);

        assert_eq!(
            DERIVED,
            Address::derive_address(SEEDS, Some(254), &PROGRAM_ID)
        );
    }
}