1use std::fmt::{Debug, Formatter};
2
3use async_trait::async_trait;
4use borsh::BorshDeserialize;
5use solana_banks_client::BanksClientError;
6use solana_program_test::ProgramTestContext;
7use solana_sdk::{
8 account::{Account, AccountSharedData},
9 clock::Slot,
10 commitment_config::CommitmentConfig,
11 epoch_info::EpochInfo,
12 hash::Hash,
13 instruction::{Instruction, InstructionError},
14 pubkey::Pubkey,
15 signature::{Keypair, Signature, Signer},
16 system_instruction,
17 transaction::{Transaction, TransactionError},
18};
19
20use crate::transaction_params::TransactionParams;
21
22use super::{merkle_tree::MerkleTreeExt, RpcConnection, RpcError};
23
24pub struct ProgramTestRpcConnection {
25 pub context: ProgramTestContext,
26}
27
28impl Debug for ProgramTestRpcConnection {
29 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
30 write!(f, "ProgramTestRpcConnection")
31 }
32}
33
34#[async_trait]
35impl RpcConnection for ProgramTestRpcConnection {
36 fn new<U: ToString>(_url: U, _commitment_config: Option<CommitmentConfig>) -> Self
37 where
38 Self: Sized,
39 {
40 unimplemented!()
41 }
42
43 fn get_payer(&self) -> &Keypair {
44 &self.context.payer
45 }
46
47 fn get_url(&self) -> String {
48 unimplemented!("get_url doesn't make sense for ProgramTestRpcConnection")
49 }
50
51 async fn health(&self) -> Result<(), RpcError> {
52 unimplemented!()
53 }
54
55 async fn get_block_time(&self, _slot: u64) -> Result<i64, RpcError> {
56 unimplemented!()
57 }
58
59 async fn get_epoch_info(&self) -> Result<EpochInfo, RpcError> {
60 unimplemented!()
61 }
62
63 async fn get_program_accounts(
64 &self,
65 _program_id: &Pubkey,
66 ) -> Result<Vec<(Pubkey, Account)>, RpcError> {
67 unimplemented!("get_program_accounts")
68 }
69
70 async fn process_transaction(
71 &mut self,
72 transaction: Transaction,
73 ) -> Result<Signature, RpcError> {
74 let sig = *transaction.signatures.first().unwrap();
75 let result = self
76 .context
77 .banks_client
78 .process_transaction_with_metadata(transaction)
79 .await
80 .map_err(RpcError::from)?;
81 result.result.map_err(RpcError::TransactionError)?;
82 Ok(sig)
83 }
84
85 async fn process_transaction_with_context(
86 &mut self,
87 transaction: Transaction,
88 ) -> Result<(Signature, Slot), RpcError> {
89 let sig = *transaction.signatures.first().unwrap();
90 let result = self
91 .context
92 .banks_client
93 .process_transaction_with_metadata(transaction)
94 .await
95 .map_err(RpcError::from)?;
96 result.result.map_err(RpcError::TransactionError)?;
97 let slot = self.context.banks_client.get_root_slot().await?;
98 Ok((sig, slot))
99 }
100
101 async fn create_and_send_transaction_with_event<T>(
102 &mut self,
103 instruction: &[Instruction],
104 payer: &Pubkey,
105 signers: &[&Keypair],
106 transaction_params: Option<TransactionParams>,
107 ) -> Result<Option<(T, Signature, Slot)>, RpcError>
108 where
109 T: BorshDeserialize + Send + Debug,
110 {
111 let pre_balance = self
112 .context
113 .banks_client
114 .get_account(*payer)
115 .await?
116 .unwrap()
117 .lamports;
118
119 let transaction = Transaction::new_signed_with_payer(
120 instruction,
121 Some(payer),
122 signers,
123 self.context.get_new_latest_blockhash().await?,
124 );
125
126 let signature = transaction.signatures[0];
127 let simulation_result = self
131 .context
132 .banks_client
133 .simulate_transaction(transaction.clone())
134 .await?;
135 if let Some(Err(e)) = simulation_result.result {
137 let error = match e {
138 TransactionError::InstructionError(_, _) => RpcError::TransactionError(e),
139 _ => RpcError::from(BanksClientError::TransactionError(e)),
140 };
141 return Err(error);
142 }
143
144 let event = simulation_result
146 .simulation_details
147 .and_then(|details| details.inner_instructions)
148 .and_then(|instructions| {
149 instructions.iter().flatten().find_map(|inner_instruction| {
150 T::try_from_slice(inner_instruction.instruction.data.as_slice()).ok()
151 })
152 });
153 if let Some(Ok(())) = simulation_result.result {
155 let result = self
156 .context
157 .banks_client
158 .process_transaction(transaction)
159 .await;
160 if let Err(e) = result {
161 let error = RpcError::from(e);
162 return Err(error);
163 }
164 }
165
166 if let Some(transaction_params) = transaction_params {
168 let mut deduped_signers = signers.to_vec();
169 deduped_signers.dedup();
170 let post_balance = self.get_account(*payer).await?.unwrap().lamports;
171
172 let mut network_fee: i64 = 0;
174 if transaction_params.num_input_compressed_accounts != 0 {
175 network_fee += transaction_params.fee_config.network_fee as i64;
176 }
177 if transaction_params.num_new_addresses != 0 {
178 network_fee += transaction_params.fee_config.address_network_fee as i64;
179 }
180 let expected_post_balance = pre_balance as i64
181 - i64::from(transaction_params.num_new_addresses)
182 * transaction_params.fee_config.address_queue_rollover as i64
183 - i64::from(transaction_params.num_output_compressed_accounts)
184 * transaction_params.fee_config.state_merkle_tree_rollover as i64
185 - transaction_params.compress
186 - transaction_params.fee_config.solana_network_fee * deduped_signers.len() as i64
187 - network_fee;
188
189 if post_balance as i64 != expected_post_balance {
190 println!("transaction_params: {:?}", transaction_params);
191 println!("pre_balance: {}", pre_balance);
192 println!("post_balance: {}", post_balance);
193 println!("expected post_balance: {}", expected_post_balance);
194 println!(
195 "diff post_balance: {}",
196 post_balance as i64 - expected_post_balance
197 );
198 println!(
199 "rollover fee: {}",
200 transaction_params.fee_config.state_merkle_tree_rollover
201 );
202 println!(
203 "address_network_fee: {}",
204 transaction_params.fee_config.address_network_fee
205 );
206 println!("network_fee: {}", network_fee);
207 println!("num signers {}", deduped_signers.len());
208 return Err(RpcError::from(BanksClientError::TransactionError(
209 TransactionError::InstructionError(0, InstructionError::Custom(11111)),
210 )));
211 }
212 }
213
214 let slot = self.context.banks_client.get_root_slot().await?;
215 let result = event.map(|event| (event, signature, slot));
216 Ok(result)
217 }
218
219 async fn confirm_transaction(&self, _transaction: Signature) -> Result<bool, RpcError> {
220 Ok(true)
221 }
222
223 async fn get_account(&mut self, address: Pubkey) -> Result<Option<Account>, RpcError> {
224 self.context
225 .banks_client
226 .get_account(address)
227 .await
228 .map_err(RpcError::from)
229 }
230
231 fn set_account(&mut self, address: &Pubkey, account: &AccountSharedData) {
232 self.context.set_account(address, account);
233 }
234
235 async fn get_minimum_balance_for_rent_exemption(
236 &mut self,
237 data_len: usize,
238 ) -> Result<u64, RpcError> {
239 let rent = self
240 .context
241 .banks_client
242 .get_rent()
243 .await
244 .map_err(RpcError::from);
245
246 Ok(rent?.minimum_balance(data_len))
247 }
248
249 async fn airdrop_lamports(
250 &mut self,
251 to: &Pubkey,
252 lamports: u64,
253 ) -> Result<Signature, RpcError> {
254 let transfer_instruction =
256 system_instruction::transfer(&self.context.payer.pubkey(), to, lamports);
257 let latest_blockhash = self.get_latest_blockhash().await.unwrap();
258 let transaction = Transaction::new_signed_with_payer(
260 &[transfer_instruction],
261 Some(&self.get_payer().pubkey()),
262 &vec![&self.get_payer()],
263 latest_blockhash,
264 );
265 let sig = *transaction.signatures.first().unwrap();
266
267 self.context
269 .banks_client
270 .process_transaction(transaction)
271 .await?;
272
273 Ok(sig)
274 }
275
276 async fn get_balance(&mut self, pubkey: &Pubkey) -> Result<u64, RpcError> {
277 self.context
278 .banks_client
279 .get_balance(*pubkey)
280 .await
281 .map_err(RpcError::from)
282 }
283
284 async fn get_latest_blockhash(&mut self) -> Result<Hash, RpcError> {
285 self.context
286 .get_new_latest_blockhash()
287 .await
288 .map_err(|e| RpcError::from(BanksClientError::from(e)))
289 }
290
291 async fn get_slot(&mut self) -> Result<u64, RpcError> {
292 self.context
293 .banks_client
294 .get_root_slot()
295 .await
296 .map_err(RpcError::from)
297 }
298
299 async fn warp_to_slot(&mut self, slot: Slot) -> Result<(), RpcError> {
300 self.context
301 .warp_to_slot(slot)
302 .map_err(|_| RpcError::InvalidWarpSlot)
303 }
304
305 async fn send_transaction(&self, _transaction: &Transaction) -> Result<Signature, RpcError> {
306 unimplemented!("send transaction is unimplemented for ProgramTestRpcConnection")
307 }
308}
309
310impl MerkleTreeExt for ProgramTestRpcConnection {}