light_merkle_tree/
lib.rs

1use std::marker::PhantomData;
2
3#[cfg(feature = "solana")]
4use anchor_lang::prelude::*;
5
6use bytemuck::{Pod, Zeroable};
7use config::MerkleTreeConfig;
8use hasher::{Hash, Hasher};
9
10pub mod config;
11pub mod constants;
12pub mod hasher;
13
/// Size in bytes of a single leaf passed to [`MerkleTree::insert`].
pub const DATA_LEN: usize = 32;
/// Size in bytes of one hash / tree node.
pub const HASH_LEN: usize = 32;
/// Largest tree height accepted when creating or initializing a tree.
pub const MAX_HEIGHT: usize = 18;
/// Number of roots retained in [`MerkleTree::roots`].
pub const MERKLE_TREE_HISTORY_SIZE: usize = 20;
18
/// Hash function selected for a tree; stored in `MerkleTree::hash_function`
/// by `MerkleTree::init`.
#[cfg(feature = "solana")]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[derive(AnchorSerialize, AnchorDeserialize)]
pub enum HashFunction {
    /// SHA-256.
    Sha256,
    /// Poseidon.
    Poseidon,
}
25
26// TODO(vadorovsky): Teach Anchor to accept `usize`, constants and const
27// generics when generating IDL.
28#[cfg_attr(feature = "solana", derive(AnchorSerialize, AnchorDeserialize))]
29#[derive(PartialEq, Eq, Debug, Clone, Copy)]
30#[repr(C)]
31pub struct MerkleTree<H, C>
32where
33    H: Hasher,
34    C: MerkleTreeConfig,
35{
36    /// Height of the Merkle tree.
37    pub height: u64,
38    /// Subtree hashes.
39    pub filled_subtrees: [[u8; 32]; 18],
40    /// Full history of roots of the Merkle tree (the last one is the current
41    /// one).
42    pub roots: [[u8; 32]; 20],
43    /// Next index to insert a leaf.
44    pub next_index: u64,
45    /// Current index of the root.
46    pub current_root_index: u64,
47
48    /// Hash implementation used on the Merkle tree.
49    #[cfg(feature = "solana")]
50    pub hash_function: HashFunction,
51
52    hasher: PhantomData<H>,
53    config: PhantomData<C>,
54}
55
56impl<H, C> MerkleTree<H, C>
57where
58    H: Hasher,
59    C: MerkleTreeConfig,
60{
61    fn check_height(height: usize) {
62        assert!(height > 0);
63        assert!(height <= MAX_HEIGHT);
64    }
65
66    fn new_filled_subtrees(height: usize) -> [[u8; HASH_LEN]; MAX_HEIGHT] {
67        let mut filled_subtrees = [[0; HASH_LEN]; MAX_HEIGHT];
68
69        for i in 0..height {
70            filled_subtrees[i] = C::ZERO_BYTES[i];
71        }
72
73        filled_subtrees
74    }
75
76    fn new_roots(height: usize) -> [[u8; HASH_LEN]; MERKLE_TREE_HISTORY_SIZE] {
77        let mut roots = [[0; HASH_LEN]; MERKLE_TREE_HISTORY_SIZE];
78        roots[0] = C::ZERO_BYTES[height - 1];
79
80        roots
81    }
82
83    /// Create a new Merkle tree with the given height.
84    #[cfg(not(feature = "solana"))]
85    pub fn new(height: usize, #[cfg(feature = "solana")] hash_function: HashFunction) -> Self {
86        Self::check_height(height);
87
88        let filled_subtrees = Self::new_filled_subtrees(height);
89        let roots = Self::new_roots(height);
90
91        MerkleTree {
92            height: height as u64,
93            filled_subtrees,
94            roots,
95            next_index: 0,
96            current_root_index: 0,
97            #[cfg(feature = "solana")]
98            hash_function,
99            hasher: PhantomData,
100            config: PhantomData,
101        }
102    }
103
104    /// Initialize the Merkle tree with subtrees and roots based on the given
105    /// height.
106    #[cfg(feature = "solana")]
107    pub fn init(&mut self, height: usize, hash_function: HashFunction) {
108        Self::check_height(height);
109
110        self.height = height as u64;
111        self.filled_subtrees = Self::new_filled_subtrees(height);
112        self.roots = Self::new_roots(height);
113        self.hash_function = hash_function;
114    }
115
116    pub fn hash(&mut self, leaf1: [u8; DATA_LEN], leaf2: [u8; DATA_LEN]) -> Hash {
117        H::hashv(&[&leaf1, &leaf2])
118    }
119
120    pub fn insert(&mut self, leaf1: [u8; DATA_LEN], leaf2: [u8; DATA_LEN]) {
121        // Check if next index doesn't exceed the Merkle tree capacity.
122        assert_ne!(self.next_index, 2u64.pow(self.height as u32));
123
124        let mut current_index = self.next_index / 2;
125        let mut current_level_hash = self.hash(leaf1, leaf2);
126
127        for i in 1..self.height as usize {
128            let (left, right) = if current_index % 2 == 0 {
129                self.filled_subtrees[i] = current_level_hash;
130                (current_level_hash, C::ZERO_BYTES[i])
131            } else {
132                (self.filled_subtrees[i], current_level_hash)
133            };
134
135            current_index /= 2;
136            current_level_hash = self.hash(left, right);
137        }
138
139        self.current_root_index = (self.current_root_index + 1) % MERKLE_TREE_HISTORY_SIZE as u64;
140        self.roots[self.current_root_index as usize] = current_level_hash;
141        self.next_index += 2;
142    }
143
144    pub fn is_known_root(&self, root: [u8; HASH_LEN]) -> bool {
145        for i in (0..(self.current_root_index as usize + 1)).rev() {
146            if self.roots[i] == root {
147                return true;
148            }
149        }
150        return false;
151    }
152
153    pub fn last_root(&self) -> [u8; HASH_LEN] {
154        self.roots[self.current_root_index as usize]
155    }
156}
157
158/// The [`Pod`](bytemuck::Pod) trait is used under the hood by the
159/// [`zero_copy`](anchor_lang::zero_copy) attribute macro and is required for
160/// usage in zero-copy Solana accounts.
161///
162/// SAFETY: Generic parameters are used only as `PhantomData` and they don't
163/// affect the layout of the struct nor its size or padding. The only reason
164/// why we can't `#[derive(Pod)]` is because bytemuck is not aware of that and
165/// it doesn't allow to derive `Pod` for structs with generic parameters.
166/// Would be nice to fix that upstream:
167/// https://github.com/Lokathor/bytemuck/issues/191
168unsafe impl<H, C> Pod for MerkleTree<H, C>
169where
170    H: Hasher + Copy + 'static,
171    C: MerkleTreeConfig + Copy + 'static,
172{
173}
174
175/// The [`Zeroable`](bytemuck::Zeroable) trait is used under the hood by the
176/// [`zero_copy`](anchor_lang::zero_copy) attribute macro and is required for
177/// usage in zero-copy Solana accounts.
178///
179/// SAFETY: Generic parameters are used only as `PhantomData` and they don't
180/// affect the layout of the struct nor its size or padding. The only reason
181/// why we can't `#[derive(Zeroable)]` is because bytemuck is not aware of that
182/// and it doesn't allow to derive `Zeroable` for structs with generic
183/// parameters.
184/// Would be nice to fix that upstream:
185/// https://github.com/Lokathor/bytemuck/issues/191
186unsafe impl<H, C> Zeroable for MerkleTree<H, C>
187where
188    H: Hasher,
189    C: MerkleTreeConfig,
190{
191}
192
/// Anchor account ownership: a `MerkleTree` account belongs to the program
/// whose ID is supplied by the configuration `C`.
#[cfg(feature = "solana")]
impl<H: Hasher, C: MerkleTreeConfig> Owner for MerkleTree<H, C> {
    fn owner() -> Pubkey {
        <C as MerkleTreeConfig>::PROGRAM_ID
    }
}