light_client/
fee.rs

1use solana_keypair::Keypair;
2use solana_pubkey::Pubkey;
3
4use crate::rpc::{Rpc, RpcError};
5
/// Fee schedule used to predict how a transaction changes the payer's balance.
///
/// The amounts are compared against an account's `lamports` balance in
/// [`assert_transaction_params`], so they are expressed in lamports.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FeeConfig {
    /// Rollover fee charged once per output compressed account.
    pub state_merkle_tree_rollover: u64,
    /// Rollover fee charged once per newly created address.
    pub address_queue_rollover: u64,
    // TODO: refactor to allow multiple state and address tree configs
    // pub state_tree_configs: Vec<StateMerkleTreeConfig>,
    // pub address_tree_configs: Vec<AddressMerkleTreeConfig>,
    /// Flat fee charged once per transaction that has at least one input or
    /// output compressed account.
    pub network_fee: u64,
    /// Flat fee charged once per transaction that creates at least one new address.
    pub address_network_fee: u64,
    /// Fee charged once per distinct transaction signer. Stored as `i64` because
    /// the balance check is carried out in signed arithmetic.
    pub solana_network_fee: i64,
}
17
18impl Default for FeeConfig {
19    fn default() -> Self {
20        Self {
21            // rollover fee plus additional lamports for the cpi account
22            state_merkle_tree_rollover: 300,
23            address_queue_rollover: 392,
24            // TODO: refactor to allow multiple state and address tree configs
25            // state_tree_configs: vec![StateMerkleTreeConfig::default()],
26            // address_tree_configs: vec![AddressMerkleTreeConfig::default()],
27            network_fee: 5000,
28            address_network_fee: 5000,
29            solana_network_fee: 5000,
30        }
31    }
32}
33
34impl FeeConfig {
35    pub fn test_batched() -> Self {
36        Self {
37            // rollover fee plus additional lamports for the cpi account
38            state_merkle_tree_rollover: 1,
39            address_queue_rollover: 392, // not batched
40            network_fee: 5000,
41            address_network_fee: 5000,
42            solana_network_fee: 5000,
43        }
44    }
45}
46
47#[derive(Debug, Clone, PartialEq)]
48pub struct TransactionParams {
49    pub num_input_compressed_accounts: u8,
50    pub num_output_compressed_accounts: u8,
51    pub num_new_addresses: u8,
52    pub compress: i64,
53    pub fee_config: FeeConfig,
54}
55
56pub async fn assert_transaction_params(
57    rpc: &mut impl Rpc,
58    payer: &Pubkey,
59    signers: &[&Keypair],
60    pre_balance: u64,
61    params: Option<TransactionParams>,
62) -> Result<(), RpcError> {
63    if let Some(transaction_params) = params {
64        let mut deduped_signers = signers.to_vec();
65        deduped_signers.dedup();
66        let post_balance = rpc.get_account(*payer).await?.unwrap().lamports;
67
68        // a network_fee is charged if there are input compressed accounts or new addresses
69        let mut network_fee: i64 = 0;
70        if transaction_params.num_input_compressed_accounts != 0
71            || transaction_params.num_output_compressed_accounts != 0
72        {
73            network_fee += transaction_params.fee_config.network_fee as i64;
74        }
75        if transaction_params.num_new_addresses != 0 {
76            network_fee += transaction_params.fee_config.address_network_fee as i64;
77        }
78        let expected_post_balance = pre_balance as i64
79            - i64::from(transaction_params.num_new_addresses)
80                * transaction_params.fee_config.address_queue_rollover as i64
81            - i64::from(transaction_params.num_output_compressed_accounts)
82                * transaction_params.fee_config.state_merkle_tree_rollover as i64
83            - transaction_params.compress
84            - transaction_params.fee_config.solana_network_fee * deduped_signers.len() as i64
85            - network_fee;
86
87        if post_balance as i64 != expected_post_balance {
88            println!("transaction_params: {:?}", transaction_params);
89            println!("pre_balance: {}", pre_balance);
90            println!("post_balance: {}", post_balance);
91            println!("expected post_balance: {}", expected_post_balance);
92            println!(
93                "diff post_balance: {}",
94                post_balance as i64 - expected_post_balance
95            );
96            println!(
97                "rollover fee: {}",
98                transaction_params.fee_config.state_merkle_tree_rollover
99            );
100            println!(
101                "address_network_fee: {}",
102                transaction_params.fee_config.address_network_fee
103            );
104            println!("network_fee: {}", network_fee);
105            println!("num signers {}", deduped_signers.len());
106            return Err(RpcError::CustomError("Transaction fee error.".to_string()));
107        }
108    }
109    Ok(())
110}