// light_system_program/invoke/sum_check.rs

1use crate::{
2    errors::SystemProgramError, sdk::compressed_account::PackedCompressedAccountWithMerkleContext,
3    OutputCompressedAccountWithPackedContext,
4};
5use anchor_lang::solana_program::program_error::ProgramError;
6use anchor_lang::Result;
7use light_macros::heap_neutral;
8
9#[inline(never)]
10#[heap_neutral]
11pub fn sum_check(
12    input_compressed_accounts_with_merkle_context: &[PackedCompressedAccountWithMerkleContext],
13    output_compressed_accounts: &[OutputCompressedAccountWithPackedContext],
14    relay_fee: &Option<u64>,
15    compress_or_decompress_lamports: &Option<u64>,
16    is_compress: &bool,
17) -> Result<()> {
18    let mut sum: u64 = 0;
19    for compressed_account_with_context in input_compressed_accounts_with_merkle_context.iter() {
20        if compressed_account_with_context.read_only {
21            unimplemented!("read_only accounts are not supported. Set read_only to false.");
22        }
23        sum = sum
24            .checked_add(compressed_account_with_context.compressed_account.lamports)
25            .ok_or(ProgramError::ArithmeticOverflow)
26            .map_err(|_| SystemProgramError::ComputeInputSumFailed)?;
27    }
28
29    match compress_or_decompress_lamports {
30        Some(lamports) => {
31            if *is_compress {
32                sum = sum
33                    .checked_add(*lamports)
34                    .ok_or(ProgramError::ArithmeticOverflow)
35                    .map_err(|_| SystemProgramError::ComputeOutputSumFailed)?;
36            } else {
37                sum = sum
38                    .checked_sub(*lamports)
39                    .ok_or(ProgramError::ArithmeticOverflow)
40                    .map_err(|_| SystemProgramError::ComputeOutputSumFailed)?;
41            }
42        }
43        None => (),
44    }
45
46    for compressed_account in output_compressed_accounts.iter() {
47        sum = sum
48            .checked_sub(compressed_account.compressed_account.lamports)
49            .ok_or(ProgramError::ArithmeticOverflow)
50            .map_err(|_| SystemProgramError::ComputeOutputSumFailed)?;
51    }
52
53    if let Some(relay_fee) = relay_fee {
54        sum = sum
55            .checked_sub(*relay_fee)
56            .ok_or(ProgramError::ArithmeticOverflow)
57            .map_err(|_| SystemProgramError::ComputeRpcSumFailed)?;
58    }
59
60    if sum == 0 {
61        Ok(())
62    } else {
63        Err(SystemProgramError::SumCheckFailed.into())
64    }
65}
66
#[cfg(test)]
mod test {
    use solana_sdk::{signature::Keypair, signer::Signer};

    use super::*;
    use crate::sdk::compressed_account::{CompressedAccount, PackedMerkleContext};

    #[test]
    fn test_sum_check() {
        // SUCCEED: no relay fee, no compression / decompression
        sum_check_test(&[100, 50], &[150], None, None, false).unwrap();
        sum_check_test(&[75, 25, 25], &[25, 25, 25, 25, 12, 13], None, None, false).unwrap();

        // FAIL: no relay fee, no compression / decompression
        sum_check_test(&[100, 50], &[150 + 1], None, None, false).unwrap_err();
        sum_check_test(&[100, 50], &[150 - 1], None, None, false).unwrap_err();
        sum_check_test(&[100, 50], &[], None, None, false).unwrap_err();
        sum_check_test(&[], &[100, 50], None, None, false).unwrap_err();
        sum_check_test(&[100, 50], &[0], None, None, false).unwrap_err();
        sum_check_test(&[0], &[100, 50], None, None, false).unwrap_err();

        // SUCCEED: empty
        sum_check_test(&[], &[], None, None, true).unwrap();
        sum_check_test(&[], &[], None, None, false).unwrap();
        sum_check_test(&[0], &[0], None, None, true).unwrap();
        sum_check_test(&[0], &[0], None, None, false).unwrap();
        // FAIL: empty
        sum_check_test(&[], &[], Some(1), None, false).unwrap_err();
        sum_check_test(&[], &[], None, Some(1), false).unwrap_err();
        sum_check_test(&[], &[], None, Some(1), true).unwrap_err();

        // SUCCEED: with compress
        sum_check_test(&[100], &[123], None, Some(23), true).unwrap();
        sum_check_test(&[], &[150], None, Some(150), true).unwrap();
        // FAIL: compress
        sum_check_test(&[], &[150], None, Some(150 - 1), true).unwrap_err();
        sum_check_test(&[], &[150], None, Some(150 + 1), true).unwrap_err();

        // SUCCEED: with decompress
        sum_check_test(&[100, 50], &[100], None, Some(50), false).unwrap();
        sum_check_test(&[100, 50], &[], None, Some(150), false).unwrap();
        // FAIL: decompress
        sum_check_test(&[100, 50], &[], None, Some(150 - 1), false).unwrap_err();
        sum_check_test(&[100, 50], &[], None, Some(150 + 1), false).unwrap_err();

        // SUCCEED: with relay fee
        sum_check_test(&[100, 50], &[125], Some(25), None, false).unwrap();
        sum_check_test(&[100, 50], &[150], Some(25), Some(25), true).unwrap();
        sum_check_test(&[100, 50], &[100], Some(25), Some(25), false).unwrap();

        // FAIL: relay fee. The outputs leave exactly 25 lamports, so a fee one
        // lamport off in either direction must be rejected.
        sum_check_test(&[100, 50], &[125], Some(25 - 1), None, false).unwrap_err();
        sum_check_test(&[100, 50], &[125], Some(25 + 1), None, false).unwrap_err();
    }

    #[test]
    fn test_sum_check_overflow_and_underflow() {
        // FAIL: summing the inputs overflows u64
        sum_check_test(&[u64::MAX, 1], &[0], None, None, false).unwrap_err();
        // FAIL: compressing on top of a maximal input sum overflows u64
        sum_check_test(&[u64::MAX], &[u64::MAX], None, Some(1), true).unwrap_err();
        // SUCCEED: u64::MAX lamports balance end to end without overflow
        sum_check_test(&[u64::MAX], &[u64::MAX], None, None, false).unwrap();
        // FAIL: decompressing more than the inputs hold underflows
        sum_check_test(&[1], &[], None, Some(2), false).unwrap_err();
        // FAIL: a relay fee with nothing left to pay it underflows
        sum_check_test(&[10], &[10], Some(1), None, false).unwrap_err();
    }

    #[test]
    #[should_panic(expected = "read_only accounts are not supported")]
    fn test_sum_check_read_only_input_panics() {
        let inputs = build_inputs(&[100], true);
        let outputs = build_outputs(&[100]);
        let _ = sum_check(&inputs, &outputs, &None, &None, &false);
    }

    /// Runs `sum_check` on synthetic (non read-only) accounts holding the
    /// given lamport amounts.
    fn sum_check_test(
        input_amounts: &[u64],
        output_amounts: &[u64],
        relay_fee: Option<u64>,
        compress_or_decompress_lamports: Option<u64>,
        is_compress: bool,
    ) -> Result<()> {
        let inputs = build_inputs(input_amounts, false);
        let outputs = build_outputs(output_amounts);
        sum_check(
            &inputs,
            &outputs,
            &relay_fee,
            &compress_or_decompress_lamports,
            &is_compress,
        )
    }

    /// Builds one packed input account per amount, each with a fresh random
    /// owner and the given `read_only` flag.
    fn build_inputs(
        amounts: &[u64],
        read_only: bool,
    ) -> Vec<PackedCompressedAccountWithMerkleContext> {
        amounts
            .iter()
            .map(|&lamports| PackedCompressedAccountWithMerkleContext {
                compressed_account: CompressedAccount {
                    owner: Keypair::new().pubkey(),
                    lamports,
                    address: None,
                    data: None,
                },
                merkle_context: PackedMerkleContext {
                    merkle_tree_pubkey_index: 0,
                    nullifier_queue_pubkey_index: 0,
                    leaf_index: 0,
                    queue_index: None,
                },
                root_index: 1,
                read_only,
            })
            .collect()
    }

    /// Builds one output account per amount, each with a fresh random owner.
    fn build_outputs(amounts: &[u64]) -> Vec<OutputCompressedAccountWithPackedContext> {
        amounts
            .iter()
            .map(|&lamports| OutputCompressedAccountWithPackedContext {
                compressed_account: CompressedAccount {
                    owner: Keypair::new().pubkey(),
                    lamports,
                    address: None,
                    data: None,
                },
                merkle_tree_index: 0,
            })
            .collect()
    }
}