1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
use std::mem;

use async_trait::async_trait;
use light_concurrent_merkle_tree::{
    copy::ConcurrentMerkleTreeCopy, errors::ConcurrentMerkleTreeError, light_hasher::Poseidon,
};
use light_indexed_merkle_tree::{copy::IndexedMerkleTreeCopy, errors::IndexedMerkleTreeError};
use light_sdk::state::MerkleTreeMetadata;
use solana_sdk::pubkey::Pubkey;
use thiserror::Error;

use super::{RpcConnection, RpcError};

#[derive(Error, Debug)]
pub enum MerkleTreeExtError {
    #[error(transparent)]
    Rpc(#[from] RpcError),

    #[error(transparent)]
    ConcurrentMerkleTree(#[from] ConcurrentMerkleTreeError),

    #[error(transparent)]
    IndexedMerkleTree(#[from] IndexedMerkleTreeError),
}

/// Extension to the RPC connection which provides convenience utilities for
/// fetching Merkle trees.
#[async_trait]
pub trait MerkleTreeExt: RpcConnection {
    async fn get_state_merkle_tree(
        &mut self,
        pubkey: Pubkey,
    ) -> Result<ConcurrentMerkleTreeCopy<Poseidon, 26>, MerkleTreeExtError> {
        let account = self.get_account(pubkey).await?.unwrap();
        let tree = ConcurrentMerkleTreeCopy::from_bytes_copy(
            &account.data[8 + mem::size_of::<MerkleTreeMetadata>()..],
        )?;

        Ok(tree)
    }

    async fn get_address_merkle_tree(
        &mut self,
        pubkey: Pubkey,
    ) -> Result<IndexedMerkleTreeCopy<Poseidon, usize, 26, 16>, MerkleTreeExtError> {
        let account = self.get_account(pubkey).await?.unwrap();
        let tree = IndexedMerkleTreeCopy::from_bytes_copy(
            &account.data[8 + mem::size_of::<MerkleTreeMetadata>()..],
        )?;

        Ok(tree)
    }
}