#[cfg(feature = "dev-context-only-utils")]
use qualifier_attr::field_qualifiers;
use {
crate::{
account_loader::AccountLoader,
transaction_processing_callback::TransactionProcessingCallback,
},
solana_account::{AccountSharedData, ReadableAccount},
solana_pubkey::Pubkey,
solana_svm_transaction::svm_transaction::SVMTransaction,
spl_generic_token::{generic_token, is_known_spl_token_id},
};
type TxNativeBalances = Vec<u64>;
type TxTokenBalances = Vec<SvmTokenInfo>;
type BatchNativeBalances = Vec<TxNativeBalances>;
type BatchTokenBalances = Vec<TxTokenBalances>;
pub(crate) trait BalanceCollectionRoutines {
fn collect_pre_balances<CB: TransactionProcessingCallback>(
&mut self,
account_loader: &mut AccountLoader<CB>,
transaction: &impl SVMTransaction,
);
fn collect_post_balances<CB: TransactionProcessingCallback>(
&mut self,
account_loader: &mut AccountLoader<CB>,
transaction: &impl SVMTransaction,
);
}
#[derive(Debug, Default)]
#[cfg_attr(
feature = "dev-context-only-utils",
field_qualifiers(native_pre(pub), native_post(pub), token_pre(pub), token_post(pub),)
)]
pub struct BalanceCollector {
native_pre: BatchNativeBalances,
native_post: BatchNativeBalances,
token_pre: BatchTokenBalances,
token_post: BatchTokenBalances,
}
impl BalanceCollector {
pub(crate) fn new_with_transaction_count(transaction_count: usize) -> Self {
Self {
native_pre: Vec::with_capacity(transaction_count),
native_post: Vec::with_capacity(transaction_count),
token_pre: Vec::with_capacity(transaction_count),
token_post: Vec::with_capacity(transaction_count),
}
}
pub fn into_vecs(
self,
) -> (
BatchNativeBalances,
BatchNativeBalances,
BatchTokenBalances,
BatchTokenBalances,
) {
(
self.native_pre,
self.native_post,
self.token_pre,
self.token_post,
)
}
fn collect_balances<CB: TransactionProcessingCallback>(
&mut self,
account_loader: &mut AccountLoader<CB>,
transaction: &impl SVMTransaction,
) -> (TxNativeBalances, TxTokenBalances) {
let mut native_balances = Vec::with_capacity(transaction.account_keys().len());
let mut token_balances = vec![];
let has_token_program = transaction.account_keys().iter().any(is_known_spl_token_id);
for (index, key) in transaction.account_keys().iter().enumerate() {
let Some(account) = account_loader.load_account(key) else {
native_balances.push(0);
continue;
};
native_balances.push(account.lamports());
if has_token_program
&& !transaction.is_invoked(index)
&& !is_known_spl_token_id(key)
&& is_known_spl_token_id(account.owner())
{
if let Some(token_info) =
SvmTokenInfo::unpack_token_account(account_loader, &account, index)
{
token_balances.push(token_info);
}
}
}
(native_balances, token_balances)
}
pub(crate) fn lengths_match_expected(&self, expected_len: usize) -> bool {
self.native_pre.len() == expected_len
&& self.native_post.len() == expected_len
&& self.token_pre.len() == expected_len
&& self.token_post.len() == expected_len
}
}
impl BalanceCollectionRoutines for BalanceCollector {
fn collect_pre_balances<CB: TransactionProcessingCallback>(
&mut self,
account_loader: &mut AccountLoader<CB>,
transaction: &impl SVMTransaction,
) {
let (native_balances, token_balances) = self.collect_balances(account_loader, transaction);
self.native_pre.push(native_balances);
self.token_pre.push(token_balances);
}
fn collect_post_balances<CB: TransactionProcessingCallback>(
&mut self,
account_loader: &mut AccountLoader<CB>,
transaction: &impl SVMTransaction,
) {
let (native_balances, token_balances) = self.collect_balances(account_loader, transaction);
self.native_post.push(native_balances);
self.token_post.push(token_balances);
}
}
impl BalanceCollectionRoutines for Option<BalanceCollector> {
fn collect_pre_balances<CB: TransactionProcessingCallback>(
&mut self,
account_loader: &mut AccountLoader<CB>,
transaction: &impl SVMTransaction,
) {
if let Some(inner) = self {
inner.collect_pre_balances(account_loader, transaction)
}
}
fn collect_post_balances<CB: TransactionProcessingCallback>(
&mut self,
account_loader: &mut AccountLoader<CB>,
transaction: &impl SVMTransaction,
) {
if let Some(inner) = self {
inner.collect_post_balances(account_loader, transaction)
}
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct SvmTokenInfo {
pub account_index: u8,
pub mint: Pubkey,
pub amount: u64,
pub owner: Pubkey,
pub program_id: Pubkey,
pub decimals: u8,
}
impl SvmTokenInfo {
fn unpack_token_account<CB: TransactionProcessingCallback>(
account_loader: &mut AccountLoader<CB>,
account: &AccountSharedData,
index: usize,
) -> Option<Self> {
let program_id = *account.owner();
let generic_token::Account {
mint,
owner,
amount,
} = generic_token::Account::unpack(account.data(), &program_id)?;
let mint_account = account_loader.load_account(&mint)?;
if *mint_account.owner() != program_id {
return None;
}
let generic_token::Mint { decimals, .. } =
generic_token::Mint::unpack(mint_account.data(), &program_id)?;
Some(Self {
account_index: index.try_into().ok()?,
mint,
amount,
owner,
program_id,
decimals,
})
}
}