mpl_token_auth_rules/
utils.rs

1//! Utilities for the program
2use crate::{
3    error::RuleSetError,
4    payload::ProofInfo,
5    state::{
6        RuleSetHeader, RuleSetRevisionMapV1, RULE_SET_REV_MAP_VERSION,
7        RULE_SET_SERIALIZED_HEADER_LEN,
8    },
9};
10use borsh::BorshDeserialize;
11use solana_program::{
12    account_info::AccountInfo,
13    entrypoint::ProgramResult,
14    msg,
15    program::{invoke, invoke_signed},
16    program_error::ProgramError,
17    program_memory::sol_memcmp,
18    pubkey::{Pubkey, PUBKEY_BYTES},
19    rent::Rent,
20    system_instruction,
21    sysvar::Sysvar,
22};
23// TODO: Uncomment this when the syscall is available.
24//use solana_zk_token_sdk::curve25519::curve_syscall_traits::CURVE25519_EDWARDS;
25
26/// Create account almost from scratch, lifted from
27/// <https://github.com/solana-labs/solana-program-library/tree/master/associated-token-account/program/src/processor.rs#L51-L98>
28#[inline(always)]
29pub fn create_or_allocate_account_raw<'a>(
30    program_id: Pubkey,
31    new_account_info: &AccountInfo<'a>,
32    system_program_info: &AccountInfo<'a>,
33    payer_info: &AccountInfo<'a>,
34    size: usize,
35    signer_seeds: &[&[u8]],
36) -> ProgramResult {
37    let rent = &Rent::get()?;
38    let required_lamports = rent
39        .minimum_balance(size)
40        .max(1)
41        .saturating_sub(new_account_info.lamports());
42
43    if required_lamports > 0 {
44        msg!("Transfer {} lamports to the new account", required_lamports);
45        invoke(
46            &system_instruction::transfer(payer_info.key, new_account_info.key, required_lamports),
47            &[
48                payer_info.clone(),
49                new_account_info.clone(),
50                system_program_info.clone(),
51            ],
52        )?;
53    }
54
55    let accounts = &[new_account_info.clone(), system_program_info.clone()];
56
57    msg!("Allocate space for the account");
58    invoke_signed(
59        &system_instruction::allocate(new_account_info.key, size.try_into().unwrap()),
60        accounts,
61        &[signer_seeds],
62    )?;
63
64    msg!("Assign the account to the owning program");
65    invoke_signed(
66        &system_instruction::assign(new_account_info.key, &program_id),
67        accounts,
68        &[signer_seeds],
69    )?;
70
71    Ok(())
72}
73
74/// Resize an account using realloc, lifted from Solana Cookbook.
75#[inline(always)]
76pub fn resize_or_reallocate_account_raw<'a>(
77    target_account: &AccountInfo<'a>,
78    funding_account: &AccountInfo<'a>,
79    system_program: &AccountInfo<'a>,
80    new_size: usize,
81) -> ProgramResult {
82    let rent = Rent::get()?;
83    let new_minimum_balance = rent.minimum_balance(new_size);
84
85    let lamports_diff = new_minimum_balance.saturating_sub(target_account.lamports());
86    invoke(
87        &system_instruction::transfer(funding_account.key, target_account.key, lamports_diff),
88        &[
89            funding_account.clone(),
90            target_account.clone(),
91            system_program.clone(),
92        ],
93    )?;
94
95    target_account.realloc(new_size, false)?;
96
97    Ok(())
98}
99
100/// Verify the derivation of the seeds against the given account.
101pub fn assert_derivation(
102    program_id: &Pubkey,
103    account: &Pubkey,
104    path: &[&[u8]],
105) -> Result<u8, ProgramError> {
106    let (key, bump) = Pubkey::find_program_address(path, program_id);
107    if key != *account {
108        return Err(RuleSetError::DerivedKeyInvalid.into());
109    }
110    Ok(bump)
111}
112
113/// Assert that the given account is owned by the given pubkey.
114pub fn assert_owned_by(account: &AccountInfo, owner: &Pubkey) -> ProgramResult {
115    if account.owner != owner {
116        Err(RuleSetError::IncorrectOwner.into())
117    } else {
118        Ok(())
119    }
120}
121
122/// Convenience function for comparing two [`Pubkey`]s.
123pub fn cmp_pubkeys(a: &Pubkey, b: &Pubkey) -> bool {
124    sol_memcmp(a.as_ref(), b.as_ref(), PUBKEY_BYTES) == 0
125}
126
127/// Compute the root of a Merkle tree given a leaf and a proof.  Uses a constant value
128/// of 0x01 as an input to the hashing function along with the values to be hashed.
129pub fn compute_merkle_root(leaf: &Pubkey, merkle_proof: &ProofInfo) -> [u8; 32] {
130    let mut computed_hash = leaf.to_bytes();
131    for proof_element in merkle_proof.proof.iter() {
132        if computed_hash <= *proof_element {
133            // Hash(current computed hash + current element of the proof).
134            computed_hash =
135                solana_program::keccak::hashv(&[&[0x01], &computed_hash, proof_element]).0;
136        } else {
137            // Hash(current element of the proof + current computed hash).
138            computed_hash =
139                solana_program::keccak::hashv(&[&[0x01], proof_element, &computed_hash]).0;
140        }
141    }
142
143    computed_hash
144}
145
146/// Get a revision map by looking at the header, finding its location, and deserializing it.
147pub fn get_existing_revision_map(
148    rule_set_pda_info: &AccountInfo,
149) -> Result<(RuleSetRevisionMapV1, usize), ProgramError> {
150    // Mutably borrow the existing `RuleSet` PDA data.
151    let data = rule_set_pda_info
152        .data
153        .try_borrow()
154        .map_err(|_| ProgramError::AccountBorrowFailed)?;
155
156    // Deserialize header.
157    let header = if data.len() >= RULE_SET_SERIALIZED_HEADER_LEN {
158        RuleSetHeader::try_from_slice(&data[..RULE_SET_SERIALIZED_HEADER_LEN])?
159    } else {
160        return Err(RuleSetError::DataTypeMismatch.into());
161    };
162
163    // Get revision map version location from header and use it check revision map version.
164    match data.get(header.rev_map_version_location) {
165        Some(&RULE_SET_REV_MAP_VERSION) => {
166            // Increment starting location by size of the revision map version.
167            let start = header
168                .rev_map_version_location
169                .checked_add(1)
170                .ok_or(RuleSetError::NumericalOverflow)?;
171
172            // Deserialize revision map.
173            if start < data.len() {
174                let mut location = &data[start..];
175                let revision_map = RuleSetRevisionMapV1::deserialize(&mut location)?;
176
177                Ok((revision_map, header.rev_map_version_location))
178            } else {
179                Err(RuleSetError::DataTypeMismatch.into())
180            }
181        }
182        Some(_) => Err(RuleSetError::UnsupportedRuleSetRevMapVersion.into()),
183        None => Err(RuleSetError::DataTypeMismatch.into()),
184    }
185}
186
187/// Get the latest revision number stored on the revision map.
188///
189/// This will first deserialize the header to find the map location and then deserialize the
190/// revision map.
191pub fn get_latest_revision(rule_set_pda_info: &AccountInfo) -> Result<Option<usize>, ProgramError> {
192    let (revision_map, _) = get_existing_revision_map(rule_set_pda_info)?;
193
194    match revision_map.rule_set_revisions.len() {
195        // we should always have at least one revision
196        0 => Err(RuleSetError::RuleSetRevisionNotAvailable.into()),
197        // determine the index of the last revision
198        length => Ok(Some(length - 1)),
199    }
200}
201
202/// Return whether the pubkey is on the Edwards 25519 curve.
203pub fn is_on_curve(pubkey: &Pubkey) -> bool {
204    let _point = pubkey.to_bytes();
205    let mut _validate_result = 0u8;
206    // TODO: Uncomment this when the syscall is available.
207    // let result = unsafe {
208    //     solana_program::syscalls::sol_curve_validate_point(
209    //         CURVE25519_EDWARDS,
210    //         &point as *const u8,
211    //         &mut validate_result,
212    //     )
213    // };
214
215    // For now return false instead of checking the result.
216    // result == 0
217    false
218}
219
/// Return `true` when every byte of `buf` is zero; an empty slice counts as zeroed.
///
/// Useful for checking an account's data. Most of the slice is compared in 1024-byte blocks
/// against a constant all-zero block. The leftover tail, shorter than one block, is checked
/// byte by byte.
pub fn is_zeroed(buf: &[u8]) -> bool {
    const BLOCK_LEN: usize = 1024;
    const ZERO_BLOCK: [u8; BLOCK_LEN] = [0; BLOCK_LEN];

    let mut blocks = buf.chunks_exact(BLOCK_LEN);
    // `remainder` borrows from `buf`, not from the iterator, so it can be taken up front.
    let tail = blocks.remainder();

    if blocks.any(|block| block != &ZERO_BLOCK[..]) {
        return false;
    }
    tail.iter().all(|&byte| byte == 0)
}