use {
super::{
error::TokenError,
extension::{
default_account_state::DefaultAccountState,
immutable_owner::ImmutableOwner,
interest_bearing_mint::InterestBearingConfig,
memo_transfer::MemoTransfer,
mint_close_authority::MintCloseAuthority,
non_transferable::NonTransferable,
transfer_fee::{TransferFeeAmount, TransferFeeConfig},
},
pod::*,
state::{Account, Mint, Multisig},
},
bytemuck::{Pod, Zeroable},
num_enum::{IntoPrimitive, TryFromPrimitive},
solana_sdk::{
program_error::ProgramError,
program_pack::{IsInitialized, Pack},
},
std::{
convert::{TryFrom, TryInto},
mem::size_of,
},
};
pub mod default_account_state;
pub mod immutable_owner;
pub mod interest_bearing_mint;
pub mod memo_transfer;
pub mod mint_close_authority;
pub mod non_transferable;
pub mod reallocate;
pub mod transfer_fee;
/// Byte length of a TLV entry's value, stored as a `PodU16`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
#[repr(transparent)]
pub struct Length(PodU16);

impl From<Length> for usize {
    /// Widens the stored `u16` length to a `usize` (lossless).
    fn from(len: Length) -> Self {
        let raw: u16 = len.0.into();
        usize::from(raw)
    }
}

impl TryFrom<usize> for Length {
    type Error = ProgramError;

    /// Narrows a `usize` into a `Length`.
    ///
    /// # Errors
    /// `ProgramError::AccountDataTooSmall` when `n` does not fit in a `u16`.
    fn try_from(n: usize) -> Result<Self, Self::Error> {
        match u16::try_from(n) {
            Ok(raw) => Ok(Self(PodU16::from(raw))),
            Err(_) => Err(ProgramError::AccountDataTooSmall),
        }
    }
}
/// Computes where the type, length and value fields of a TLV entry begin,
/// given the offset of its type field. Offsets saturate instead of overflowing.
fn get_tlv_indices(type_start: usize) -> TlvIndices {
    let type_field_len = size_of::<ExtensionType>();
    let length_field_len = pod_get_packed_len::<Length>();
    let length_start = type_start.saturating_add(type_field_len);
    TlvIndices {
        type_start,
        length_start,
        value_start: length_start.saturating_add(length_field_len),
    }
}

/// Offsets (into the TLV slice) of the three fields of a single TLV entry.
#[derive(Debug)]
struct TlvIndices {
    /// Offset of the `ExtensionType` field
    pub type_start: usize,
    /// Offset of the `Length` field
    pub length_start: usize,
    /// Offset of the first value byte
    pub value_start: usize,
}
/// Walks the TLV entries in `tlv_data` looking for extension `V`.
///
/// Returns the indices of the entry whose type is `V::TYPE`. When `init` is
/// true, the first `Uninitialized` slot is returned instead (a place to write
/// `V`). When `init` is false, `Uninitialized` type fields are stepped over.
///
/// # Errors
/// - `ProgramError::InvalidAccountData` if a header is truncated, a type field is
///   unknown, or the data ends without a match
/// - `TokenError::ExtensionTypeMismatch` if an entry belongs to a different
///   account type than `V`
fn get_extension_indices<V: Extension>(
    tlv_data: &[u8],
    init: bool,
) -> Result<TlvIndices, ProgramError> {
    let wanted_account_type = V::TYPE.get_account_type();
    let mut cursor = 0;
    while cursor < tlv_data.len() {
        let indices = get_tlv_indices(cursor);
        if indices.value_start > tlv_data.len() {
            return Err(ProgramError::InvalidAccountData);
        }
        let found =
            ExtensionType::try_from(&tlv_data[indices.type_start..indices.length_start])?;
        cursor = match found {
            ExtensionType::Uninitialized if init => return Ok(indices),
            // Not initializing: skip only the empty type field and keep scanning.
            ExtensionType::Uninitialized => indices.length_start,
            other if other == V::TYPE => return Ok(indices),
            other if other.get_account_type() != wanted_account_type => {
                return Err(TokenError::ExtensionTypeMismatch.into());
            }
            _ => {
                let length = pod_from_bytes::<Length>(
                    &tlv_data[indices.length_start..indices.value_start],
                )?;
                indices.value_start.saturating_add(usize::from(*length))
            }
        };
    }
    Err(ProgramError::InvalidAccountData)
}
/// Lists the extension types present in `tlv_data`, in storage order.
///
/// Stops successfully at the first `Uninitialized` entry, or when fewer bytes
/// remain than a type+length header needs.
///
/// # Errors
/// `ProgramError::InvalidAccountData` if a type field holds an unknown
/// discriminant; errors from reading a length field are propagated.
fn get_extension_types(tlv_data: &[u8]) -> Result<Vec<ExtensionType>, ProgramError> {
    let mut found = Vec::new();
    let mut cursor = 0;
    while cursor < tlv_data.len() {
        let indices = get_tlv_indices(cursor);
        if indices.value_start > tlv_data.len() {
            break;
        }
        let extension_type =
            ExtensionType::try_from(&tlv_data[indices.type_start..indices.length_start])?;
        if extension_type == ExtensionType::Uninitialized {
            break;
        }
        found.push(extension_type);
        let length =
            pod_from_bytes::<Length>(&tlv_data[indices.length_start..indices.value_start])?;
        cursor = indices.value_start.saturating_add(usize::from(*length));
    }
    Ok(found)
}
fn get_first_extension_type(tlv_data: &[u8]) -> Result<Option<ExtensionType>, ProgramError> {
if tlv_data.is_empty() {
Ok(None)
} else {
let tlv_indices = get_tlv_indices(0);
if tlv_data.len() <= tlv_indices.length_start {
return Ok(None);
}
let extension_type =
ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
if extension_type == ExtensionType::Uninitialized {
Ok(None)
} else {
Ok(Some(extension_type))
}
}
}
/// Rejects a buffer shorter than `minimum_len` or exactly `Multisig::LEN` bytes
/// long, so a multisig-sized buffer is never treated as an extended state.
///
/// # Errors
/// `ProgramError::InvalidAccountData` in either rejected case.
fn check_min_len_and_not_multisig(input: &[u8], minimum_len: usize) -> Result<(), ProgramError> {
    let len = input.len();
    let is_multisig_sized = len == Multisig::LEN;
    let too_short = len < minimum_len;
    if !(is_multisig_sized || too_short) {
        return Ok(());
    }
    Err(ProgramError::InvalidAccountData)
}
/// Ensures a decoded account-type byte matches the base state `S`.
///
/// # Errors
/// `ProgramError::InvalidAccountData` on mismatch.
fn check_account_type<S: BaseState>(account_type: AccountType) -> Result<(), ProgramError> {
    if account_type == S::ACCOUNT_TYPE {
        Ok(())
    } else {
        Err(ProgramError::InvalidAccountData)
    }
}
/// Size of a packed base `Account`. In an extended buffer the account-type byte
/// sits at this absolute offset, whatever the base state, and shorter base states
/// are zero-padded up to it.
const BASE_ACCOUNT_LENGTH: usize = Account::LEN;

/// Given the bytes that follow the base state `S`, returns the offsets (relative
/// to `rest_input`) of the account-type byte and of the start of TLV data.
///
/// Returns `Ok(None)` when there is no extension area at all.
///
/// # Errors
/// `ProgramError::InvalidAccountData` if there is no byte after the account-type
/// byte, or if the padding before the account-type byte is not all zeroes.
fn type_and_tlv_indices<S: BaseState>(
    rest_input: &[u8],
) -> Result<Option<(usize, usize)>, ProgramError> {
    if rest_input.is_empty() {
        return Ok(None);
    }
    let account_type_index = BASE_ACCOUNT_LENGTH.saturating_sub(S::LEN);
    let tlv_start_index = account_type_index.saturating_add(size_of::<AccountType>());
    if rest_input.len() <= tlv_start_index {
        return Err(ProgramError::InvalidAccountData);
    }
    // Padding between the base state and the account-type byte must be zeroed.
    // Checked in place: this runs on every unpack, so avoid allocating a buffer
    // of zeroes just to compare against.
    if rest_input[..account_type_index].iter().any(|&byte| byte != 0) {
        Err(ProgramError::InvalidAccountData)
    } else {
        Ok(Some((account_type_index, tlv_start_index)))
    }
}
/// Reports whether a packed base `Account` has its initialization byte set
/// (any non-zero value counts as initialized).
///
/// # Errors
/// `ProgramError::InvalidAccountData` unless `input` is exactly
/// `BASE_ACCOUNT_LENGTH` bytes.
fn is_initialized_account(input: &[u8]) -> Result<bool, ProgramError> {
    // Byte offset of the initialization flag inside a packed `Account`.
    // NOTE(review): presumably the `state` field — confirm against `state.rs`.
    const ACCOUNT_INITIALIZED_INDEX: usize = 108;
    match input.len() {
        BASE_ACCOUNT_LENGTH => Ok(input[ACCOUNT_INITIALIZED_INDEX] != 0),
        _ => Err(ProgramError::InvalidAccountData),
    }
}
/// Looks up extension `V` in the TLV data of a base state `S` and reinterprets
/// its value bytes as `&V`.
///
/// # Errors
/// - `ProgramError::InvalidAccountData` if `V` does not belong to `S`'s account
///   type, the extension is absent, or its stored length runs past the end of
///   `tlv_data`
/// - errors from `get_extension_indices` and `pod_from_bytes` are propagated
fn get_extension<S: BaseState, V: Extension>(tlv_data: &[u8]) -> Result<&V, ProgramError> {
    if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
        return Err(ProgramError::InvalidAccountData);
    }
    let TlvIndices {
        type_start: _,
        length_start,
        value_start,
    } = get_extension_indices::<V>(tlv_data, false)?;
    // `get_extension_indices` only returns indices with
    // `value_start <= tlv_data.len()`, so this slice is in bounds.
    let length = pod_from_bytes::<Length>(&tlv_data[length_start..value_start])?;
    let value_end = value_start.saturating_add(usize::from(*length));
    // The stored length comes from account data and is not validated against the
    // buffer: reject it rather than panic if it points past the end.
    let value = tlv_data
        .get(value_start..value_end)
        .ok_or(ProgramError::InvalidAccountData)?;
    pod_from_bytes::<V>(value)
}
/// Owned form of an extended account: the unpacked base state plus an owned
/// copy of the TLV extension bytes.
#[derive(Debug, PartialEq)]
pub struct StateWithExtensionsOwned<S: BaseState> {
    /// Unpacked base state
    pub base: S,
    /// TLV bytes following the account-type byte (empty if there are none)
    tlv_data: Vec<u8>,
}

impl<S: BaseState> StateWithExtensionsOwned<S> {
    /// Takes ownership of raw account bytes and splits them into the base state
    /// and the TLV data.
    ///
    /// # Errors
    /// `ProgramError::InvalidAccountData` for a bad length, non-zero padding, or an
    /// unknown/mismatched account-type byte; errors from `S::unpack` are propagated.
    pub fn unpack(mut input: Vec<u8>) -> Result<Self, ProgramError> {
        check_min_len_and_not_multisig(&input, S::LEN)?;
        let mut rest = input.split_off(S::LEN);
        let base = S::unpack(&input)?;
        let tlv_data = match type_and_tlv_indices::<S>(&rest)? {
            None => Vec::new(),
            Some((account_type_index, tlv_start_index)) => {
                let account_type = AccountType::try_from(rest[account_type_index])
                    .map_err(|_| ProgramError::InvalidAccountData)?;
                check_account_type::<S>(account_type)?;
                rest.split_off(tlv_start_index)
            }
        };
        Ok(Self { base, tlv_data })
    }

    /// Borrows extension `V` from the owned TLV data.
    pub fn get_extension<V: Extension>(&self) -> Result<&V, ProgramError> {
        get_extension::<S, V>(&self.tlv_data)
    }

    /// Lists the extension types stored in the owned TLV data.
    pub fn get_extension_types(&self) -> Result<Vec<ExtensionType>, ProgramError> {
        get_extension_types(&self.tlv_data)
    }
}
/// Read-only view of an extended account: the unpacked base state plus a
/// borrowed slice of the TLV extension bytes.
#[derive(Debug, PartialEq)]
pub struct StateWithExtensions<'data, S: BaseState> {
    /// Unpacked base state
    pub base: S,
    /// TLV bytes following the account-type byte (empty if there are none)
    tlv_data: &'data [u8],
}

impl<'data, S: BaseState> StateWithExtensions<'data, S> {
    /// Unpacks the base state and borrows the TLV area of `input`.
    ///
    /// # Errors
    /// `ProgramError::InvalidAccountData` for a bad length, non-zero padding, or an
    /// unknown/mismatched account-type byte; errors from `S::unpack` are propagated.
    pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
        check_min_len_and_not_multisig(input, S::LEN)?;
        let (base_data, rest) = input.split_at(S::LEN);
        let base = S::unpack(base_data)?;
        let tlv_data: &'data [u8] = match type_and_tlv_indices::<S>(rest)? {
            Some((account_type_index, tlv_start_index)) => {
                let account_type = AccountType::try_from(rest[account_type_index])
                    .map_err(|_| ProgramError::InvalidAccountData)?;
                check_account_type::<S>(account_type)?;
                &rest[tlv_start_index..]
            }
            None => &[],
        };
        Ok(Self { base, tlv_data })
    }

    /// Borrows extension `V` from the TLV data.
    pub fn get_extension<V: Extension>(&self) -> Result<&V, ProgramError> {
        get_extension::<S, V>(self.tlv_data)
    }

    /// Lists the extension types stored in the TLV data.
    pub fn get_extension_types(&self) -> Result<Vec<ExtensionType>, ProgramError> {
        get_extension_types(self.tlv_data)
    }
}
/// Mutable view over raw account bytes: the unpacked base state plus mutable
/// slices for the base bytes, the account-type byte, and the TLV extension data.
#[derive(Debug, PartialEq)]
pub struct StateWithExtensionsMut<'data, S: BaseState> {
    /// Unpacked base state; call `pack_base` to write changes back to `base_data`
    pub base: S,
    /// The first `S::LEN` bytes of the input
    base_data: &'data mut [u8],
    /// The one-byte account-type marker, or empty if the input has no extension area
    account_type: &'data mut [u8],
    /// TLV bytes following the account-type byte (empty if there are none)
    tlv_data: &'data mut [u8],
}

impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
    /// Unpacks an initialized base state and borrows its extension area mutably.
    ///
    /// # Errors
    /// `ProgramError::InvalidAccountData` for a bad length, non-zero padding, or an
    /// unknown/mismatched account-type byte; errors from `S::unpack` are propagated.
    pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
        check_min_len_and_not_multisig(input, S::LEN)?;
        let (base_data, rest) = input.split_at_mut(S::LEN);
        let base = S::unpack(base_data)?;
        if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
            let account_type = AccountType::try_from(rest[account_type_index])
                .map_err(|_| ProgramError::InvalidAccountData)?;
            check_account_type::<S>(account_type)?;
            let (account_type, tlv_data) = rest.split_at_mut(tlv_start_index);
            Ok(Self {
                base,
                base_data,
                account_type: &mut account_type[account_type_index..tlv_start_index],
                tlv_data,
            })
        } else {
            Ok(Self {
                base,
                base_data,
                account_type: &mut [],
                tlv_data: &mut [],
            })
        }
    }

    /// Unpacks a buffer whose base state is not yet initialized.
    ///
    /// # Errors
    /// - `TokenError::AlreadyInUse` if the base state is already initialized
    /// - `ProgramError::InvalidAccountData` if the account-type byte is unknown or
    ///   already set, or for a bad length / non-zero padding
    /// - `TokenError::ExtensionBaseMismatch` if a pre-written first extension
    ///   belongs to a different account type than `S`
    pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
        check_min_len_and_not_multisig(input, S::LEN)?;
        let (base_data, rest) = input.split_at_mut(S::LEN);
        let base = S::unpack_unchecked(base_data)?;
        if base.is_initialized() {
            return Err(TokenError::AlreadyInUse.into());
        }
        if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
            let account_type = AccountType::try_from(rest[account_type_index])
                .map_err(|_| ProgramError::InvalidAccountData)?;
            if account_type != AccountType::Uninitialized {
                return Err(ProgramError::InvalidAccountData);
            }
            let (account_type, tlv_data) = rest.split_at_mut(tlv_start_index);
            let state = Self {
                base,
                base_data,
                account_type: &mut account_type[account_type_index..tlv_start_index],
                tlv_data,
            };
            if let Some(extension_type) = state.get_first_extension_type()? {
                let account_type = extension_type.get_account_type();
                if account_type != S::ACCOUNT_TYPE {
                    return Err(TokenError::ExtensionBaseMismatch.into());
                }
            }
            Ok(state)
        } else {
            Ok(Self {
                base,
                base_data,
                account_type: &mut [],
                tlv_data: &mut [],
            })
        }
    }

    /// Locates the value bytes of the existing extension `V` and returns them as a
    /// `(start, end)` pair of TLV offsets that is guaranteed to be in bounds.
    fn extension_value_bounds<V: Extension>(&self) -> Result<(usize, usize), ProgramError> {
        if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
            return Err(ProgramError::InvalidAccountData);
        }
        let TlvIndices {
            type_start,
            length_start,
            value_start,
        } = get_extension_indices::<V>(self.tlv_data, false)?;
        if self.tlv_data[type_start..].len() < V::TYPE.get_tlv_len() {
            return Err(ProgramError::InvalidAccountData);
        }
        let length = pod_from_bytes::<Length>(&self.tlv_data[length_start..value_start])?;
        let value_end = value_start.saturating_add(usize::from(*length));
        // The stored length is read from account data; refuse (instead of
        // panicking on the slice) if it points past the end of the buffer.
        if value_end > self.tlv_data.len() {
            return Err(ProgramError::InvalidAccountData);
        }
        Ok((value_start, value_end))
    }

    /// Mutably borrows the existing extension `V`.
    ///
    /// # Errors
    /// `ProgramError::InvalidAccountData` if `V` belongs to another account type,
    /// is absent, or its stored length is malformed; other lookup errors are
    /// propagated.
    pub fn get_extension_mut<V: Extension>(&mut self) -> Result<&mut V, ProgramError> {
        let (value_start, value_end) = self.extension_value_bounds::<V>()?;
        pod_from_bytes_mut::<V>(&mut self.tlv_data[value_start..value_end])
    }

    /// Immutably borrows the existing extension `V`.
    ///
    /// # Errors
    /// Same as `get_extension_mut`.
    pub fn get_extension<V: Extension>(&self) -> Result<&V, ProgramError> {
        let (value_start, value_end) = self.extension_value_bounds::<V>()?;
        pod_from_bytes::<V>(&self.tlv_data[value_start..value_end])
    }

    /// Writes the (possibly modified) `base` back into `base_data`.
    pub fn pack_base(&mut self) {
        S::pack_into_slice(&self.base, self.base_data);
    }

    /// Writes a default-valued `V` into the first free TLV slot, or over an
    /// existing `V` entry when `overwrite` is true, and returns it for mutation.
    ///
    /// No bytes are written unless the entry fits entirely in the buffer.
    ///
    /// # Errors
    /// - `ProgramError::InvalidAccountData` if `V` belongs to another account type
    ///   or there is not enough space
    /// - `TokenError::ExtensionAlreadyInitialized` if `V` already exists and
    ///   `overwrite` is false
    /// - `ProgramError::AccountDataTooSmall` if `V`'s size does not fit a `Length`
    pub fn init_extension<V: Extension>(
        &mut self,
        overwrite: bool,
    ) -> Result<&mut V, ProgramError> {
        if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
            return Err(ProgramError::InvalidAccountData);
        }
        let TlvIndices {
            type_start,
            length_start,
            value_start,
        } = get_extension_indices::<V>(self.tlv_data, true)?;
        if self.tlv_data[type_start..].len() < V::TYPE.get_tlv_len() {
            return Err(ProgramError::InvalidAccountData);
        }
        let extension_type = ExtensionType::try_from(&self.tlv_data[type_start..length_start])?;
        if extension_type != ExtensionType::Uninitialized && !overwrite {
            return Err(TokenError::ExtensionAlreadyInitialized.into());
        }
        // Validate everything before the first write so a failure leaves the
        // buffer untouched.
        let value_len = pod_get_packed_len::<V>();
        let value_end = value_start.saturating_add(value_len);
        if value_end > self.tlv_data.len() {
            return Err(ProgramError::InvalidAccountData);
        }
        let length = Length::try_from(value_len)?;

        let extension_type_array: [u8; 2] = V::TYPE.into();
        self.tlv_data[type_start..length_start].copy_from_slice(&extension_type_array);
        let length_ref =
            pod_from_bytes_mut::<Length>(&mut self.tlv_data[length_start..value_start])?;
        *length_ref = length;
        let extension_ref = pod_from_bytes_mut::<V>(&mut self.tlv_data[value_start..value_end])?;
        *extension_ref = V::default();
        Ok(extension_ref)
    }

    /// Initializes the account extension that `extension_type` requires.
    /// Non-account extension types are ignored.
    ///
    /// # Errors
    /// `ProgramError::InvalidArgument` for account extension types that cannot be
    /// initialized through this path; errors from `init_extension` are propagated.
    pub fn init_account_extension_from_type(
        &mut self,
        extension_type: ExtensionType,
    ) -> Result<(), ProgramError> {
        if extension_type.get_account_type() != AccountType::Account {
            return Ok(());
        }
        match extension_type {
            ExtensionType::TransferFeeAmount => {
                self.init_extension::<TransferFeeAmount>(true).map(|_| ())
            }
            // Deliberately a no-op: nothing is written for this extension here.
            ExtensionType::ConfidentialTransferAccount => Ok(()),
            #[cfg(test)]
            ExtensionType::AccountPaddingTest => {
                self.init_extension::<AccountPaddingTest>(true).map(|_| ())
            }
            // Other account extensions (e.g. ImmutableOwner, MemoTransfer) pass the
            // guard above. Report them as an error instead of panicking.
            _ => Err(ProgramError::InvalidArgument),
        }
    }

    /// Stamps `S::ACCOUNT_TYPE` into the account-type byte, if there is one.
    ///
    /// # Errors
    /// `TokenError::ExtensionBaseMismatch` if the first stored extension belongs to
    /// a different account type; TLV decoding errors are propagated.
    pub fn init_account_type(&mut self) -> Result<(), ProgramError> {
        if !self.account_type.is_empty() {
            if let Some(extension_type) = self.get_first_extension_type()? {
                let account_type = extension_type.get_account_type();
                if account_type != S::ACCOUNT_TYPE {
                    return Err(TokenError::ExtensionBaseMismatch.into());
                }
            }
            self.account_type[0] = S::ACCOUNT_TYPE.into();
        }
        Ok(())
    }

    /// Lists the extension types stored in the TLV data.
    pub fn get_extension_types(&self) -> Result<Vec<ExtensionType>, ProgramError> {
        get_extension_types(self.tlv_data)
    }

    /// Type of the first TLV entry, if any.
    fn get_first_extension_type(&self) -> Result<Option<ExtensionType>, ProgramError> {
        get_first_extension_type(self.tlv_data)
    }
}
/// Ensures the account-type byte of an extended buffer is set to
/// `S::ACCOUNT_TYPE`, writing it if it is still `Uninitialized`.
///
/// For `Account` base states, the base account must already be initialized.
///
/// # Errors
/// `ProgramError::InvalidAccountData` if the buffer has no extension area, the
/// base account is uninitialized, or the stored type is unknown or belongs to a
/// different base state.
pub fn set_account_type<S: BaseState>(input: &mut [u8]) -> Result<(), ProgramError> {
    check_min_len_and_not_multisig(input, S::LEN)?;
    let (base_data, rest) = input.split_at_mut(S::LEN);
    if S::ACCOUNT_TYPE == AccountType::Account && !is_initialized_account(base_data)? {
        return Err(ProgramError::InvalidAccountData);
    }
    let (account_type_index, _) =
        type_and_tlv_indices::<S>(rest)?.ok_or(ProgramError::InvalidAccountData)?;
    let stored = AccountType::try_from(rest[account_type_index])
        .map_err(|_| ProgramError::InvalidAccountData)?;
    let effective = if stored == AccountType::Uninitialized {
        rest[account_type_index] = S::ACCOUNT_TYPE.into();
        S::ACCOUNT_TYPE
    } else {
        stored
    };
    check_account_type::<S>(effective)
}
/// Value of the single account-type byte that precedes the TLV data.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)]
pub enum AccountType {
    /// Marker byte not yet set
    Uninitialized = 0,
    /// The base state is a `Mint`
    Mint = 1,
    /// The base state is an `Account`
    Account = 2,
}

impl Default for AccountType {
    /// New buffers start out with an unset marker.
    fn default() -> Self {
        AccountType::Uninitialized
    }
}
/// Identifies the extension held by a TLV entry; serialized as a little-endian
/// `u16` in the entry's type field. Discriminants are part of the stored layout.
#[repr(u16)]
#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)]
pub enum ExtensionType {
    /// Empty slot: unused TLV space that an extension can be written into
    Uninitialized = 0,
    /// Mint extension
    TransferFeeConfig = 1,
    /// Account extension
    TransferFeeAmount = 2,
    /// Mint extension
    MintCloseAuthority = 3,
    /// Mint extension
    ConfidentialTransferMint = 4,
    /// Account extension
    ConfidentialTransferAccount = 5,
    /// Mint extension
    DefaultAccountState = 6,
    /// Account extension
    ImmutableOwner = 7,
    /// Account extension
    MemoTransfer = 8,
    /// Mint extension
    NonTransferable = 9,
    /// Mint extension
    InterestBearingConfig = 10,
    /// Test-only account extension
    #[cfg(test)]
    AccountPaddingTest = u16::MAX - 1,
    /// Test-only mint extension
    #[cfg(test)]
    MintPaddingTest = u16::MAX,
}
impl TryFrom<&[u8]> for ExtensionType {
    type Error = ProgramError;

    /// Decodes a 2-byte little-endian type field.
    ///
    /// # Errors
    /// `ProgramError::InvalidAccountData` if `a` is not exactly 2 bytes or holds an
    /// unknown discriminant.
    fn try_from(a: &[u8]) -> Result<Self, Self::Error> {
        let bytes: [u8; 2] = a.try_into().map_err(|_| ProgramError::InvalidAccountData)?;
        let discriminant = u16::from_le_bytes(bytes);
        Self::try_from(discriminant).map_err(|_| ProgramError::InvalidAccountData)
    }
}

impl From<ExtensionType> for [u8; 2] {
    /// Encodes the discriminant as 2 little-endian bytes.
    fn from(a: ExtensionType) -> Self {
        let discriminant: u16 = a.into();
        discriminant.to_le_bytes()
    }
}
impl ExtensionType {
pub fn get_type_len(&self) -> usize {
match self {
ExtensionType::Uninitialized => 0,
ExtensionType::TransferFeeConfig => pod_get_packed_len::<TransferFeeConfig>(),
ExtensionType::TransferFeeAmount => pod_get_packed_len::<TransferFeeAmount>(),
ExtensionType::MintCloseAuthority => pod_get_packed_len::<MintCloseAuthority>(),
ExtensionType::ImmutableOwner => pod_get_packed_len::<ImmutableOwner>(),
ExtensionType::ConfidentialTransferMint => {
0
}
ExtensionType::ConfidentialTransferAccount => {
0
}
ExtensionType::DefaultAccountState => pod_get_packed_len::<DefaultAccountState>(),
ExtensionType::MemoTransfer => pod_get_packed_len::<MemoTransfer>(),
ExtensionType::NonTransferable => pod_get_packed_len::<NonTransferable>(),
ExtensionType::InterestBearingConfig => pod_get_packed_len::<InterestBearingConfig>(),
#[cfg(test)]
ExtensionType::AccountPaddingTest => pod_get_packed_len::<AccountPaddingTest>(),
#[cfg(test)]
ExtensionType::MintPaddingTest => pod_get_packed_len::<MintPaddingTest>(),
}
}
fn get_tlv_len(&self) -> usize {
self.get_type_len()
.saturating_add(size_of::<ExtensionType>())
.saturating_add(pod_get_packed_len::<Length>())
}
fn get_total_tlv_len(extension_types: &[Self]) -> usize {
let mut extensions = vec![];
for extension_type in extension_types {
if !extensions.contains(&extension_type) {
extensions.push(extension_type);
}
}
let tlv_len: usize = extensions.iter().map(|e| e.get_tlv_len()).sum();
if tlv_len
== Multisig::LEN
.saturating_sub(BASE_ACCOUNT_LENGTH)
.saturating_sub(size_of::<AccountType>())
{
tlv_len.saturating_add(size_of::<ExtensionType>())
} else {
tlv_len
}
}
pub fn get_account_len<S: BaseState>(extension_types: &[Self]) -> usize {
if extension_types.is_empty() {
S::LEN
} else {
let extension_size = Self::get_total_tlv_len(extension_types);
extension_size
.saturating_add(BASE_ACCOUNT_LENGTH)
.saturating_add(size_of::<AccountType>())
}
}
pub fn get_account_type(&self) -> AccountType {
match self {
ExtensionType::Uninitialized => AccountType::Uninitialized,
ExtensionType::TransferFeeConfig
| ExtensionType::MintCloseAuthority
| ExtensionType::ConfidentialTransferMint
| ExtensionType::DefaultAccountState
| ExtensionType::NonTransferable
| ExtensionType::InterestBearingConfig => AccountType::Mint,
ExtensionType::ImmutableOwner
| ExtensionType::TransferFeeAmount
| ExtensionType::ConfidentialTransferAccount
| ExtensionType::MemoTransfer => AccountType::Account,
#[cfg(test)]
ExtensionType::AccountPaddingTest => AccountType::Account,
#[cfg(test)]
ExtensionType::MintPaddingTest => AccountType::Mint,
}
}
pub fn get_required_init_account_extensions(mint_extension_types: &[Self]) -> Vec<Self> {
let mut account_extension_types = vec![];
for extension_type in mint_extension_types {
#[allow(clippy::single_match)]
match extension_type {
ExtensionType::TransferFeeConfig => {
account_extension_types.push(ExtensionType::TransferFeeAmount);
}
#[cfg(test)]
ExtensionType::MintPaddingTest => {
account_extension_types.push(ExtensionType::AccountPaddingTest);
}
_ => {}
}
}
account_extension_types
}
}
/// A base state (`Mint` or `Account`) that can be followed by TLV extensions.
pub trait BaseState: Pack + IsInitialized {
    /// Value written to the account-type byte for this base state
    const ACCOUNT_TYPE: AccountType;
}

impl BaseState for Mint {
    const ACCOUNT_TYPE: AccountType = AccountType::Mint;
}

impl BaseState for Account {
    const ACCOUNT_TYPE: AccountType = AccountType::Account;
}

/// A fixed-size, plain-old-data extension value stored in a TLV entry.
pub trait Extension: Pod + Default {
    /// Discriminant identifying this extension in the TLV type field
    const TYPE: ExtensionType;
}