use std::mem::MaybeUninit;
use light_zero_copy::{ZeroCopy, ZeroCopyMut};
use pinocchio::pubkey::Pubkey;
use solana_pubkey::MAX_SEEDS;
use tinyvec::ArrayVec;
use crate::{AnchorDeserialize, AnchorSerialize, TokenError};
/// Instruction data for the compressible extension of a token account (per the type name).
///
/// Serializable through the crate's `AnchorSerialize`/`AnchorDeserialize`, with zero-copy
/// views derived via `ZeroCopy`/`ZeroCopyMut`; `#[repr(C)]` fixes the in-memory field order.
#[derive(
    Debug, Clone, PartialEq, Eq, AnchorSerialize, AnchorDeserialize, ZeroCopy, ZeroCopyMut,
)]
#[repr(C)]
pub struct CompressibleExtensionInstructionData {
    // NOTE(review): version tag of the token account layout; valid values are not visible here.
    pub token_account_version: u8,
    // NOTE(review): presumably the number of rent periods paid up front — confirm with the consumer.
    pub rent_payment: u8,
    // Boolean-like flag stored as a `u8` (presumably 0 = false, non-zero = true) — confirm.
    pub compression_only: u8,
    // NOTE(review): presumably a top-up amount applied on writes — confirm units (lamports?).
    pub write_top_up: u32,
    // When `Some`, carries the derivation inputs (seeds, bump, program id) that
    // `CompressToPubkey::check_seeds` verifies against a pubkey.
    pub compress_to_account_pubkey: Option<CompressToPubkey>,
}
/// Inputs for deriving a program address: the `seeds`, followed by `bump`, under `program_id`.
///
/// [`CompressToPubkey::check_seeds`] re-derives the address from these fields with
/// `derive_address` and compares it to an expected pubkey.
#[derive(
    Debug, Clone, PartialEq, Eq, AnchorSerialize, AnchorDeserialize, ZeroCopy, ZeroCopyMut,
)]
#[repr(C)]
pub struct CompressToPubkey {
    // Single byte hashed immediately after the seeds.
    pub bump: u8,
    // Program id hashed after the bump.
    pub program_id: [u8; 32],
    // Seeds hashed first, in order; `check_seeds` rejects `MAX_SEEDS` or more entries
    // because the bump occupies the final seed slot.
    pub seeds: Vec<Vec<u8>>,
}
impl CompressToPubkey {
pub fn check_seeds(&self, pubkey: &Pubkey) -> Result<(), TokenError> {
if self.seeds.len() >= MAX_SEEDS {
return Err(TokenError::TooManySeeds(MAX_SEEDS - 1));
}
let mut references = ArrayVec::<[&[u8]; MAX_SEEDS]>::new();
for seed in self.seeds.iter() {
references.push(seed.as_slice());
}
let derived_pubkey = derive_address(references.as_slice(), self.bump, &self.program_id)?;
if derived_pubkey != *pubkey {
Err(TokenError::InvalidAccountData)
} else {
Ok(())
}
}
}
/// Derives a program address from `seeds`, `bump` and `program_id`.
///
/// The `sol_sha256` syscall is given, in order, every seed, the single `bump`
/// byte, `program_id` and the `"ProgramDerivedAddress"` marker, and its
/// 32-byte output is returned. No off-curve validation is performed on the
/// result, so callers should only use this when the inputs are trusted or the
/// result is compared against a known address.
///
/// # Errors
///
/// Returns [`TokenError::TooManySeeds`] (carrying `MAX_SEEDS - 1`, the maximum
/// accepted) when `seeds.len() >= MAX_SEEDS`; the bump uses the final seed slot.
///
/// # Panics
///
/// On any target other than `solana` the derivation is unavailable and the
/// function panics after the seed-count check.
pub fn derive_address(
    seeds: &[&[u8]],
    bump: u8,
    program_id: &Pubkey,
) -> Result<Pubkey, TokenError> {
    const PDA_MARKER: &[u8; 21] = b"ProgramDerivedAddress";

    if seeds.len() >= MAX_SEEDS {
        return Err(TokenError::TooManySeeds(MAX_SEEDS - 1));
    }

    // Slots for up to `MAX_SEEDS - 1` seeds plus bump, program id and marker.
    // Only the first `seeds.len() + 3` slots are written, and only those are
    // handed to the hash below, so the rest may stay uninitialized.
    const UNINIT: MaybeUninit<&[u8]> = MaybeUninit::<&[u8]>::uninit();
    let mut data = [UNINIT; MAX_SEEDS + 2];

    let seed_count = seeds.len();
    for (slot, seed) in data.iter_mut().zip(seeds.iter().copied()) {
        slot.write(seed);
    }

    let bump_seed = [bump];
    // `seed_count <= MAX_SEEDS - 1` (checked above), so the highest index used
    // here is `MAX_SEEDS + 1`, the last slot of `data`; these accesses are safe
    // and cannot go out of bounds.
    data[seed_count].write(&bump_seed);
    data[seed_count + 1].write(program_id.as_ref());
    data[seed_count + 2].write(PDA_MARKER.as_ref());

    #[cfg(target_os = "solana")]
    {
        use pinocchio::syscalls::sol_sha256;

        let mut pda = MaybeUninit::<[u8; 32]>::uninit();
        // SAFETY: the first `seed_count + 3` entries of `data` were initialized
        // above, and `MaybeUninit<&[u8]>` has the same layout as `&[u8]`, so the
        // syscall reads exactly that many valid slice descriptors. `pda` provides
        // 32 writable bytes for the digest.
        unsafe {
            sol_sha256(
                data.as_ptr() as *const u8,
                (seed_count + 3) as u64,
                pda.as_mut_ptr() as *mut u8,
            );
        }
        // SAFETY: `sol_sha256` is expected to write the full 32-byte digest into `pda`.
        Ok(unsafe { pda.assume_init() })
    }
    #[cfg(not(target_os = "solana"))]
    unreachable!("deriving a pda is only available on target `solana`");
}