// darkpool_client/merkle_tree.rs

//! 32-level Lean IMT (incremental Merkle tree) with Poseidon2 hashing, mirroring the on-chain
//! `DarkPool` commitment tree.
//!
//! Lean IMT rule: a zero sibling means "empty", so the node passes upward without being
//! hashed (this matches the Noir circuit).
3
4use ethers::types::U256;
5use std::collections::HashMap;
6use tracing::debug;
7
8use crate::crypto_helpers::poseidon_hash;
9
/// Number of levels in the commitment tree; every [`MerklePath`] carries one sibling and one
/// direction bit per level.
pub const TREE_DEPTH: usize = 32;
11
12#[derive(Debug, Clone)]
13pub struct MerklePath {
14    /// Empty siblings are 0 (Lean IMT), not pre-computed zero hashes
15    pub siblings: [U256; TREE_DEPTH],
16    pub indices: [u8; TREE_DEPTH],
17}
18
19impl MerklePath {
20    #[allow(clippy::must_use_candidate)]
21    pub fn siblings_vec(&self) -> Vec<U256> {
22        self.siblings.to_vec()
23    }
24}
25
/// Client-side mirror of the commitment tree, built by appending leaves in order.
#[derive(Debug)]
pub struct LocalMerkleTree {
    /// Leaves in insertion order; a leaf's position in this vector is its leaf index.
    leaves: Vec<U256>,
    /// Non-zero internal nodes keyed by (level, index), counting the leaf level as level 0.
    /// Leaves themselves are read from `leaves` and never stored here; a missing entry
    /// means the node is empty (0).
    nodes: HashMap<(u8, u64), U256>,
}
32
33impl LocalMerkleTree {
34    #[must_use]
35    pub fn new() -> Self {
36        Self {
37            leaves: Vec::new(),
38            nodes: HashMap::new(),
39        }
40    }
41
42    #[allow(clippy::must_use_candidate)]
43    pub fn size(&self) -> u64 {
44        self.leaves.len() as u64
45    }
46
47    #[allow(clippy::must_use_candidate)]
48    pub fn root(&self) -> U256 {
49        if self.leaves.is_empty() {
50            return U256::zero();
51        }
52        self.compute_root()
53    }
54
55    pub fn insert(&mut self, commitment: U256) -> u64 {
56        let index = self.leaves.len() as u64;
57        self.leaves.push(commitment);
58        self.update_path(index);
59
60        debug!(
61            "Inserted leaf {} at index {}. New root: {:?}",
62            commitment,
63            index,
64            self.root()
65        );
66
67        index
68    }
69
70    #[must_use]
71    pub fn get_path(&self, index: u64) -> MerklePath {
72        let mut siblings = [U256::zero(); TREE_DEPTH];
73        let mut indices = [0u8; TREE_DEPTH];
74
75        let mut current_index = index;
76
77        for level in 0..TREE_DEPTH {
78            let sibling_index = if current_index.is_multiple_of(2) {
79                current_index + 1
80            } else {
81                current_index - 1
82            };
83
84            indices[level] = (current_index % 2) as u8;
85            siblings[level] = self.get_node_lean(level as u8, sibling_index);
86            current_index /= 2;
87        }
88
89        MerklePath { siblings, indices }
90    }
91
92    #[must_use]
93    pub fn verify_path(&self, leaf: U256, _index: u64, path: &MerklePath) -> bool {
94        let mut current = leaf;
95
96        for level in 0..TREE_DEPTH {
97            let sibling = path.siblings[level];
98
99            if sibling.is_zero() {
100            } else {
101                let is_right = path.indices[level] == 1;
102                current = if is_right {
103                    poseidon_hash(&[sibling, current])
104                } else {
105                    poseidon_hash(&[current, sibling])
106                };
107            }
108        }
109
110        current == self.root()
111    }
112
113    fn get_node_lean(&self, level: u8, index: u64) -> U256 {
114        if level == 0 {
115            return self
116                .leaves
117                .get(index as usize)
118                .copied()
119                .unwrap_or(U256::zero());
120        }
121
122        self.nodes
123            .get(&(level, index))
124            .copied()
125            .unwrap_or(U256::zero())
126    }
127
128    fn update_path(&mut self, leaf_index: u64) {
129        let mut current_index = leaf_index;
130
131        for level in 0..(TREE_DEPTH - 1) {
132            let parent_index = current_index / 2;
133            let left_child_index = parent_index * 2;
134            let right_child_index = left_child_index + 1;
135
136            let left = self.get_node_lean(level as u8, left_child_index);
137            let right = self.get_node_lean(level as u8, right_child_index);
138
139            let parent = if left.is_zero() && right.is_zero() {
140                U256::zero()
141            } else if right.is_zero() {
142                left
143            } else if left.is_zero() {
144                right
145            } else {
146                poseidon_hash(&[left, right])
147            };
148
149            if !parent.is_zero() {
150                self.nodes.insert(((level + 1) as u8, parent_index), parent);
151            }
152
153            current_index = parent_index;
154        }
155    }
156
157    fn compute_root(&self) -> U256 {
158        self.get_node_lean((TREE_DEPTH - 1) as u8, 0)
159    }
160
161    #[allow(clippy::must_use_candidate)]
162    pub fn leaves(&self) -> &[U256] {
163        &self.leaves
164    }
165
166    pub fn clear(&mut self) {
167        self.leaves.clear();
168        self.nodes.clear();
169    }
170
171    pub fn load_from_leaves(&mut self, leaves: &[U256]) {
172        self.clear();
173        for leaf in leaves {
174            self.insert(*leaf);
175        }
176    }
177}
178
179impl Default for LocalMerkleTree {
180    fn default() -> Self {
181        Self::new()
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_tree_root() {
        let tree = LocalMerkleTree::new();
        // An empty Lean IMT has no leaves and a zero root.
        assert_eq!(tree.size(), 0);
        assert!(tree.root().is_zero());
    }

    #[test]
    fn test_single_leaf_root_is_leaf() {
        // With every sibling empty, the lone leaf propagates unhashed all the way up.
        let mut tree = LocalMerkleTree::new();
        let leaf = U256::from(42u64);
        tree.insert(leaf);
        assert_eq!(tree.root(), leaf);
    }

    #[test]
    fn test_two_leaf_root_is_hash_of_pair() {
        let mut tree = LocalMerkleTree::new();
        let (a, b) = (U256::from(1u64), U256::from(2u64));
        tree.insert(a);
        tree.insert(b);
        assert_eq!(tree.root(), poseidon_hash(&[a, b]));
    }

    #[test]
    fn test_insert_and_root_changes() {
        let mut tree = LocalMerkleTree::new();
        let root0 = tree.root();

        tree.insert(U256::from(1));
        let root1 = tree.root();

        tree.insert(U256::from(2));
        let root2 = tree.root();

        assert_ne!(root0, root1);
        assert_ne!(root1, root2);
    }

    #[test]
    fn test_deterministic_root() {
        let mut tree1 = LocalMerkleTree::new();
        let mut tree2 = LocalMerkleTree::new();

        tree1.insert(U256::from(100));
        tree1.insert(U256::from(200));

        tree2.insert(U256::from(100));
        tree2.insert(U256::from(200));

        assert_eq!(tree1.root(), tree2.root());
    }

    #[test]
    fn test_get_path() {
        let mut tree = LocalMerkleTree::new();

        let leaf = U256::from(12345);
        let index = tree.insert(leaf);

        let path = tree.get_path(index);

        assert_eq!(path.siblings.len(), TREE_DEPTH);
        assert_eq!(path.indices.len(), TREE_DEPTH);
        assert!(tree.verify_path(leaf, index, &path));
    }

    #[test]
    fn test_path_shape_for_partial_subtree() {
        // Leaves [1, 2, 3]: leaf 2 has no right neighbour, so its level-0 sibling is empty,
        // and its level-1 sibling is the hash of leaves 0 and 1.
        let mut tree = LocalMerkleTree::new();
        for v in 1..=3u64 {
            tree.insert(U256::from(v));
        }

        let path = tree.get_path(2);
        assert!(path.siblings[0].is_zero());
        assert_eq!(path.indices[0], 0);
        assert_eq!(
            path.siblings[1],
            poseidon_hash(&[U256::from(1u64), U256::from(2u64)])
        );
        assert_eq!(path.indices[1], 1);
        assert!(path.siblings[2..].iter().all(|s| s.is_zero()));
        assert!(path.indices[2..].iter().all(|&d| d == 0));
        assert!(tree.verify_path(U256::from(3u64), 2, &path));
    }

    #[test]
    fn test_verify_path_fails_for_wrong_leaf() {
        let mut tree = LocalMerkleTree::new();

        let leaf = U256::from(12345);
        let index = tree.insert(leaf);

        let path = tree.get_path(index);

        let wrong_leaf = U256::from(99999);
        assert!(!tree.verify_path(wrong_leaf, index, &path));
    }

    #[test]
    fn test_verify_path_fails_with_flipped_direction() {
        let mut tree = LocalMerkleTree::new();
        tree.insert(U256::from(1u64));
        tree.insert(U256::from(2u64));

        // The level-0 sibling is non-zero, so the direction bit there affects the hash.
        let mut path = tree.get_path(0);
        path.indices[0] = 1;
        assert!(!tree.verify_path(U256::from(1u64), 0, &path));
    }

    #[test]
    fn test_multiple_inserts_and_paths() {
        let mut tree = LocalMerkleTree::new();
        let mut leaves_and_indices = Vec::new();

        for i in 0..10 {
            let leaf = U256::from(i * 1000 + 1);
            let index = tree.insert(leaf);
            leaves_and_indices.push((leaf, index));
        }

        for (leaf, index) in &leaves_and_indices {
            let path = tree.get_path(*index);
            assert!(
                tree.verify_path(*leaf, *index, &path),
                "Path verification failed for leaf at index {}",
                index
            );
        }
    }

    #[test]
    fn test_load_from_leaves() {
        let leaves = vec![U256::from(1), U256::from(2), U256::from(3)];

        let mut tree1 = LocalMerkleTree::new();
        for leaf in &leaves {
            tree1.insert(*leaf);
        }

        let mut tree2 = LocalMerkleTree::new();
        tree2.load_from_leaves(&leaves);

        assert_eq!(tree1.root(), tree2.root());
    }

    #[test]
    fn test_load_from_leaves_replaces_existing_contents() {
        let mut tree = LocalMerkleTree::new();
        tree.insert(U256::from(7u64));

        tree.load_from_leaves(&[U256::from(1u64)]);

        assert_eq!(tree.size(), 1);
        assert_eq!(tree.root(), U256::from(1u64));
        assert_eq!(tree.leaves(), [U256::from(1u64)].as_slice());
    }
}