#[allow(deprecated)]
use anchor_lang::{
solana_program::{
borsh0_10::try_from_slice_unchecked, program_error::ProgramError, program_pack::Pack,
},
AnchorDeserialize,
};
use anchor_spl::{
token::spl_token::state::{Account, Multisig},
token_interface::spl_token_2022::extension::{
AccountType, BaseState, Extension, ExtensionType, Length,
},
};
use bytemuck::Pod;
use std::mem::size_of;
/// Length of the base token `Account` state, taken from `Account::LEN`.
const BASE_ACCOUNT_LENGTH: usize = Account::LEN;
/// `BASE_ACCOUNT_LENGTH` plus the size of one `AccountType` value. Added to the
/// summed TLV entry lengths when computing the length of an account that
/// carries extensions.
const BASE_ACCOUNT_AND_TYPE_LENGTH: usize = BASE_ACCOUNT_LENGTH + size_of::<AccountType>();
/// Byte offsets of the three parts of one type-length-value (TLV) entry, all
/// measured from the start of the TLV buffer being scanned.
struct TlvIndices {
/// Offset of the first byte of the entry's type tag.
pub type_start: usize,
/// Offset of the first byte of the entry's length field (end of the type tag).
pub length_start: usize,
/// Offset of the first byte of the entry's value (end of the length field).
pub value_start: usize,
}
/// Looks up the extension `V` inside `tlv_data` and reinterprets its value
/// bytes as a `&V` without copying.
///
/// # Errors
///
/// Propagates any error from locating the extension's value bytes, and
/// returns [`ProgramError::InvalidAccountData`] when those bytes cannot be
/// viewed as a `V` by `bytemuck`.
pub fn get_extension<V: Extension + Pod>(
    tlv_data: &[u8],
) -> core::result::Result<&V, ProgramError> {
    let value_bytes = get_extension_bytes::<V>(tlv_data)?;
    match bytemuck::try_from_bytes::<V>(value_bytes) {
        Ok(extension) => Ok(extension),
        Err(_) => Err(ProgramError::InvalidAccountData),
    }
}
/// Returns the raw value bytes of the first TLV entry whose type is `V::TYPE`.
///
/// # Errors
///
/// Propagates errors from `get_extension_indices`, and returns
/// [`ProgramError::InvalidAccountData`] when the entry's length field cannot
/// be read or when the declared value would extend past the end of `tlv_data`.
fn get_extension_bytes<V: Extension>(tlv_data: &[u8]) -> core::result::Result<&[u8], ProgramError> {
    let indices = get_extension_indices::<V>(tlv_data)?;
    let (length_start, value_start) = (indices.length_start, indices.value_start);

    let length_field = &tlv_data[length_start..value_start];
    let declared_len = match bytemuck::try_from_bytes::<Length>(length_field) {
        Ok(length) => usize::from(*length),
        Err(_) => return Err(ProgramError::InvalidAccountData),
    };

    // `value_end` is never smaller than `value_start`, so the checked slice
    // below only fails when the value would run past the end of the buffer.
    let value_end = value_start.saturating_add(declared_len);
    tlv_data
        .get(value_start..value_end)
        .ok_or(ProgramError::InvalidAccountData)
}
/// Walks `tlv_data` entry by entry and returns the offsets of the first entry
/// whose type tag equals `V::TYPE`.
///
/// Entries whose tag does not parse into an [`ExtensionType`] are skipped via
/// their length field, the same way as entries of a different known type.
///
/// # Errors
///
/// * [`ProgramError::InvalidAccountData`] when an entry header (type tag plus
///   length field) does not fit in the remaining bytes, or when the end of the
///   buffer is reached without finding an entry of type `V::TYPE`.
/// * [`ProgramError::InvalidArgument`] when the length field of a skipped
///   entry cannot be read as a [`Length`].
fn get_extension_indices<V: Extension>(
    tlv_data: &[u8],
) -> core::result::Result<TlvIndices, ProgramError> {
    let mut start_index = 0;
    while start_index < tlv_data.len() {
        let tlv_indices = get_tlv_indices(start_index);
        // The whole header must be present before any of it is sliced.
        if tlv_data.len() < tlv_indices.value_start {
            return Err(ProgramError::InvalidAccountData);
        }
        let extension_type =
            ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start]);
        // A single pattern match replaces the former `is_ok()` + `unwrap()` pair.
        if matches!(extension_type, Ok(found) if found == V::TYPE) {
            return Ok(tlv_indices);
        }
        let length = bytemuck::try_from_bytes::<Length>(
            &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
        )
        .map_err(|_| ProgramError::InvalidArgument)?;
        // Jump over this entry's value to the next entry's type tag.
        start_index = tlv_indices
            .value_start
            .saturating_add(usize::from(*length));
    }
    Err(ProgramError::InvalidAccountData)
}
/// Computes where the length field and the value of a TLV entry begin, given
/// the offset at which the entry's type tag starts.
///
/// The type tag occupies `size_of::<ExtensionType>()` bytes and the length
/// field `size_of::<Length>()` bytes; all additions saturate instead of
/// overflowing.
fn get_tlv_indices(type_start: usize) -> TlvIndices {
    let type_width = size_of::<ExtensionType>();
    let length_width = size_of::<Length>();
    let length_start = type_start.saturating_add(type_width);
    TlvIndices {
        type_start,
        length_start,
        value_start: length_start.saturating_add(length_width),
    }
}
pub fn get_extension_types(tlv_data: &[u8]) -> Result<Vec<IExtensionType>, ProgramError> {
let mut extension_types = vec![];
let mut start_index = 0;
while start_index < tlv_data.len() {
let tlv_indices = get_tlv_indices(start_index);
if tlv_data.len() < tlv_indices.length_start {
return Ok(extension_types);
}
let extension_type = u16::from_le_bytes(
(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])
.try_into()
.map_err(|_| ProgramError::InvalidAccountData)?,
);
if let Ok(extension_type) = IExtensionType::try_from(extension_type) {
extension_types.push(extension_type);
}
if tlv_data.len() < tlv_indices.value_start {
return Err(ProgramError::InvalidAccountData);
}
let length = bytemuck::try_from_bytes::<Length>(
&tlv_data[tlv_indices.length_start..tlv_indices.value_start],
)
.map_err(|_| ProgramError::InvalidAccountData)?;
let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
if value_end_index > tlv_data.len() {
return Err(ProgramError::InvalidAccountData);
}
start_index = value_end_index;
}
Ok(extension_types)
}
/// Looks up the extension `V` inside `tlv_data` and deserializes an owned `V`
/// from its value bytes with `try_from_slice_unchecked`.
///
/// # Errors
///
/// Propagates any error from locating the extension's value bytes, and
/// returns [`ProgramError::InvalidAccountData`] when deserialization fails.
pub fn get_variable_len_extension<V: Extension + AnchorDeserialize>(
    tlv_data: &[u8],
) -> core::result::Result<V, ProgramError> {
    let value_bytes = get_extension_bytes::<V>(tlv_data)?;
    // The deprecation lint is silenced only for this single call.
    #[allow(deprecated)]
    let decoded = try_from_slice_unchecked::<V>(value_bytes);
    match decoded {
        Ok(extension) => Ok(extension),
        Err(_) => Err(ProgramError::InvalidAccountData),
    }
}
/// Extension types this module recognises when scanning TLV data.
///
/// Each discriminant is the `u16` type tag that `get_extension_types` reads
/// (little-endian) from the TLV buffer and converts with `TryFrom<u16>`.
///
/// The enum is field-less, so it derives `Clone`, `Copy`, `Eq` and `Hash` in
/// addition to `Debug` and `PartialEq`; values can be passed by value and used
/// as set or map keys.
#[repr(u16)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IExtensionType {
    MintCloseAuthority = 3,
    DefaultAccountState = 6,
    ImmutableOwner = 7,
    NonTransferable = 9,
    CpiGuard = 11,
    PermanentDelegate = 12,
    NonTransferableAccount = 13,
    TransferHook = 14,
    TransferHookAccount = 15,
    MetadataPointer = 18,
    GroupPointer = 20,
    TokenGroup = 21,
    GroupMemberPointer = 22,
    TokenGroupMember = 23,
}
impl IExtensionType {
    /// Length in bytes of the value stored for this extension type, excluding
    /// the TLV type tag and length field.
    fn get_type_len(&self) -> usize {
        match self {
            Self::ImmutableOwner | Self::NonTransferable | Self::NonTransferableAccount => 0,
            Self::DefaultAccountState | Self::CpiGuard | Self::TransferHookAccount => 1,
            Self::MintCloseAuthority | Self::PermanentDelegate => 32,
            Self::TransferHook
            | Self::MetadataPointer
            | Self::GroupPointer
            | Self::GroupMemberPointer => 64,
            Self::TokenGroupMember => 68,
            Self::TokenGroup => 72,
        }
    }

    /// Length of one full TLV entry for this extension type: the value length
    /// plus the type tag and length field. Currently always `Ok`.
    fn try_get_tlv_len(&self) -> Result<usize, ProgramError> {
        let value_len = self.get_type_len();
        Ok(add_type_and_length_to_len(value_len))
    }

    /// Sums the TLV entry lengths of `extension_types`, counting each distinct
    /// extension type only once.
    fn try_get_total_tlv_len(extension_types: &[Self]) -> Result<usize, ProgramError> {
        extension_types
            .iter()
            .enumerate()
            // Keep an element only if it did not already occur earlier in the slice.
            .filter(|(position, candidate)| !extension_types[..*position].contains(*candidate))
            .map(|(_, unique)| unique.try_get_tlv_len())
            .sum()
    }

    /// Computes the account length needed to hold base state `S` together
    /// with the given extensions.
    ///
    /// With no extensions the result is `S::LEN`. Otherwise it is the summed
    /// length of the distinct TLV entries plus `BASE_ACCOUNT_AND_TYPE_LENGTH`,
    /// passed through `adjust_len_for_multisig`.
    pub fn try_calculate_account_len<S: BaseState>(
        extension_types: &[Self],
    ) -> Result<usize, ProgramError> {
        if extension_types.is_empty() {
            return Ok(S::LEN);
        }
        let tlv_len = Self::try_get_total_tlv_len(extension_types)?;
        let unadjusted_len = tlv_len.saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
        Ok(adjust_len_for_multisig(unadjusted_len))
    }

    /// Maps mint extension types to the account extension types they imply.
    ///
    /// `NonTransferable` yields `NonTransferableAccount` and `TransferHook`
    /// yields `TransferHookAccount`; every other type contributes nothing.
    /// Output order follows input order and duplicates are preserved.
    pub fn get_required_init_account_extensions(mint_extension_types: &[Self]) -> Vec<Self> {
        mint_extension_types
            .iter()
            .filter_map(|mint_extension| match mint_extension {
                Self::NonTransferable => Some(Self::NonTransferableAccount),
                Self::TransferHook => Some(Self::TransferHookAccount),
                _ => None,
            })
            .collect()
    }
}
/// Converts a raw `u16` type tag into the matching `IExtensionType`.
///
/// Tags without a corresponding variant produce
/// [`ProgramError::InvalidArgument`].
impl TryFrom<u16> for IExtensionType {
    type Error = ProgramError;

    fn try_from(value: u16) -> Result<Self, Self::Error> {
        match value {
            3 => Ok(Self::MintCloseAuthority),
            6 => Ok(Self::DefaultAccountState),
            7 => Ok(Self::ImmutableOwner),
            9 => Ok(Self::NonTransferable),
            11 => Ok(Self::CpiGuard),
            12 => Ok(Self::PermanentDelegate),
            13 => Ok(Self::NonTransferableAccount),
            14 => Ok(Self::TransferHook),
            15 => Ok(Self::TransferHookAccount),
            18 => Ok(Self::MetadataPointer),
            20 => Ok(Self::GroupPointer),
            21 => Ok(Self::TokenGroup),
            22 => Ok(Self::GroupMemberPointer),
            23 => Ok(Self::TokenGroupMember),
            _ => Err(ProgramError::InvalidArgument),
        }
    }
}
/// Returns the length of a full TLV entry whose value is `value_len` bytes:
/// the value plus `size_of::<ExtensionType>()` for the type tag and
/// `size_of::<Length>()` for the length field. Additions saturate.
const fn add_type_and_length_to_len(value_len: usize) -> usize {
    let header_len = size_of::<ExtensionType>().saturating_add(size_of::<Length>());
    value_len.saturating_add(header_len)
}
/// Returns `account_len` unchanged unless it equals `Multisig::LEN`, in which
/// case `size_of::<ExtensionType>()` is added (saturating) so the result no
/// longer equals `Multisig::LEN`.
///
/// NOTE(review): presumably this keeps extended accounts distinguishable from
/// multisig accounts by length alone — confirm against the token program.
const fn adjust_len_for_multisig(account_len: usize) -> usize {
    if account_len != Multisig::LEN {
        return account_len;
    }
    account_len.saturating_add(size_of::<ExtensionType>())
}