use std::fmt::{Debug, Formatter};
use async_trait::async_trait;
use borsh::BorshDeserialize;
use solana_banks_client::BanksClientError;
use solana_program_test::ProgramTestContext;
use solana_sdk::{
account::{Account, AccountSharedData},
clock::Slot,
commitment_config::CommitmentConfig,
epoch_info::EpochInfo,
hash::Hash,
instruction::{Instruction, InstructionError},
pubkey::Pubkey,
signature::{Keypair, Signature, Signer},
system_instruction,
transaction::{Transaction, TransactionError},
};
use crate::transaction_params::TransactionParams;
use super::{merkle_tree::MerkleTreeExt, RpcConnection, RpcError};
pub struct ProgramTestRpcConnection {
pub context: ProgramTestContext,
}
impl Debug for ProgramTestRpcConnection {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "ProgramTestRpcConnection")
}
}
#[async_trait]
impl RpcConnection for ProgramTestRpcConnection {
fn new<U: ToString>(_url: U, _commitment_config: Option<CommitmentConfig>) -> Self
where
Self: Sized,
{
unimplemented!()
}
fn get_payer(&self) -> &Keypair {
&self.context.payer
}
fn get_url(&self) -> String {
unimplemented!("get_url doesn't make sense for ProgramTestRpcConnection")
}
async fn health(&self) -> Result<(), RpcError> {
unimplemented!()
}
async fn get_block_time(&self, _slot: u64) -> Result<i64, RpcError> {
unimplemented!()
}
async fn get_epoch_info(&self) -> Result<EpochInfo, RpcError> {
unimplemented!()
}
async fn get_program_accounts(
&self,
_program_id: &Pubkey,
) -> Result<Vec<(Pubkey, Account)>, RpcError> {
unimplemented!("get_program_accounts")
}
async fn process_transaction(
&mut self,
transaction: Transaction,
) -> Result<Signature, RpcError> {
let sig = *transaction.signatures.first().unwrap();
let result = self
.context
.banks_client
.process_transaction_with_metadata(transaction)
.await
.map_err(RpcError::from)?;
result.result.map_err(RpcError::TransactionError)?;
Ok(sig)
}
async fn process_transaction_with_context(
&mut self,
transaction: Transaction,
) -> Result<(Signature, Slot), RpcError> {
let sig = *transaction.signatures.first().unwrap();
let result = self
.context
.banks_client
.process_transaction_with_metadata(transaction)
.await
.map_err(RpcError::from)?;
result.result.map_err(RpcError::TransactionError)?;
let slot = self.context.banks_client.get_root_slot().await?;
Ok((sig, slot))
}
async fn create_and_send_transaction_with_event<T>(
&mut self,
instruction: &[Instruction],
payer: &Pubkey,
signers: &[&Keypair],
transaction_params: Option<TransactionParams>,
) -> Result<Option<(T, Signature, Slot)>, RpcError>
where
T: BorshDeserialize + Send + Debug,
{
let pre_balance = self
.context
.banks_client
.get_account(*payer)
.await?
.unwrap()
.lamports;
let transaction = Transaction::new_signed_with_payer(
instruction,
Some(payer),
signers,
self.context.get_new_latest_blockhash().await?,
);
let signature = transaction.signatures[0];
let simulation_result = self
.context
.banks_client
.simulate_transaction(transaction.clone())
.await?;
if let Some(Err(e)) = simulation_result.result {
let error = match e {
TransactionError::InstructionError(_, _) => RpcError::TransactionError(e),
_ => RpcError::from(BanksClientError::TransactionError(e)),
};
return Err(error);
}
let event = simulation_result
.simulation_details
.and_then(|details| details.inner_instructions)
.and_then(|instructions| {
instructions.iter().flatten().find_map(|inner_instruction| {
T::try_from_slice(inner_instruction.instruction.data.as_slice()).ok()
})
});
if let Some(Ok(())) = simulation_result.result {
let result = self
.context
.banks_client
.process_transaction(transaction)
.await;
if let Err(e) = result {
let error = RpcError::from(e);
return Err(error);
}
}
if let Some(transaction_params) = transaction_params {
let mut deduped_signers = signers.to_vec();
deduped_signers.dedup();
let post_balance = self.get_account(*payer).await?.unwrap().lamports;
let mut network_fee: i64 = 0;
if transaction_params.num_input_compressed_accounts != 0 {
network_fee += transaction_params.fee_config.network_fee as i64;
}
if transaction_params.num_new_addresses != 0 {
network_fee += transaction_params.fee_config.address_network_fee as i64;
}
let expected_post_balance = pre_balance as i64
- i64::from(transaction_params.num_new_addresses)
* transaction_params.fee_config.address_queue_rollover as i64
- i64::from(transaction_params.num_output_compressed_accounts)
* transaction_params.fee_config.state_merkle_tree_rollover as i64
- transaction_params.compress
- transaction_params.fee_config.solana_network_fee * deduped_signers.len() as i64
- network_fee;
if post_balance as i64 != expected_post_balance {
println!("transaction_params: {:?}", transaction_params);
println!("pre_balance: {}", pre_balance);
println!("post_balance: {}", post_balance);
println!("expected post_balance: {}", expected_post_balance);
println!(
"diff post_balance: {}",
post_balance as i64 - expected_post_balance
);
println!(
"rollover fee: {}",
transaction_params.fee_config.state_merkle_tree_rollover
);
println!(
"address_network_fee: {}",
transaction_params.fee_config.address_network_fee
);
println!("network_fee: {}", network_fee);
println!("num signers {}", deduped_signers.len());
return Err(RpcError::from(BanksClientError::TransactionError(
TransactionError::InstructionError(0, InstructionError::Custom(11111)),
)));
}
}
let slot = self.context.banks_client.get_root_slot().await?;
let result = event.map(|event| (event, signature, slot));
Ok(result)
}
async fn confirm_transaction(&self, _transaction: Signature) -> Result<bool, RpcError> {
Ok(true)
}
async fn get_account(&mut self, address: Pubkey) -> Result<Option<Account>, RpcError> {
self.context
.banks_client
.get_account(address)
.await
.map_err(RpcError::from)
}
fn set_account(&mut self, address: &Pubkey, account: &AccountSharedData) {
self.context.set_account(address, account);
}
async fn get_minimum_balance_for_rent_exemption(
&mut self,
data_len: usize,
) -> Result<u64, RpcError> {
let rent = self
.context
.banks_client
.get_rent()
.await
.map_err(RpcError::from);
Ok(rent?.minimum_balance(data_len))
}
async fn airdrop_lamports(
&mut self,
to: &Pubkey,
lamports: u64,
) -> Result<Signature, RpcError> {
let transfer_instruction =
system_instruction::transfer(&self.context.payer.pubkey(), to, lamports);
let latest_blockhash = self.get_latest_blockhash().await.unwrap();
let transaction = Transaction::new_signed_with_payer(
&[transfer_instruction],
Some(&self.get_payer().pubkey()),
&vec![&self.get_payer()],
latest_blockhash,
);
let sig = *transaction.signatures.first().unwrap();
self.context
.banks_client
.process_transaction(transaction)
.await?;
Ok(sig)
}
async fn get_balance(&mut self, pubkey: &Pubkey) -> Result<u64, RpcError> {
self.context
.banks_client
.get_balance(*pubkey)
.await
.map_err(RpcError::from)
}
async fn get_latest_blockhash(&mut self) -> Result<Hash, RpcError> {
self.context
.get_new_latest_blockhash()
.await
.map_err(|e| RpcError::from(BanksClientError::from(e)))
}
async fn get_slot(&mut self) -> Result<u64, RpcError> {
self.context
.banks_client
.get_root_slot()
.await
.map_err(RpcError::from)
}
async fn warp_to_slot(&mut self, slot: Slot) -> Result<(), RpcError> {
self.context
.warp_to_slot(slot)
.map_err(|_| RpcError::InvalidWarpSlot)
}
async fn send_transaction(&self, _transaction: &Transaction) -> Result<Signature, RpcError> {
unimplemented!("send transaction is unimplemented for ProgramTestRpcConnection")
}
}
impl MerkleTreeExt for ProgramTestRpcConnection {}