1use crate::zk::error::ZKError;
8use alloc::collections::BTreeMap;
9use alloc::vec::Vec;
10use soroban_sdk::{BytesN, Env};
11
12use super::proof::OnChainMerkleProof;
13
/// Number of hashing levels between the leaves (level 0) and the root
/// (level `SMT_DEPTH`). Leaf indices are reduced modulo `1 << SMT_DEPTH`.
const SMT_DEPTH: u32 = 16;

/// Fixed-depth sparse Merkle tree whose node hashes are 32-byte arrays.
///
/// Only nodes touched by an insertion are stored; every other position is
/// resolved through the per-level default hash table.
pub struct SparseMerkleTree {
    /// Current root hash (the node at level `SMT_DEPTH`, index 0).
    root: [u8; 32],
    /// Explicitly stored nodes, keyed by `(level, index)`; level 0 holds leaf hashes.
    nodes: BTreeMap<(u32, u32), [u8; 32]>,
    /// Default hash per level, indexed by level (`0..=SMT_DEPTH`); used for any
    /// position that has no entry in `nodes`.
    defaults: Vec<[u8; 32]>,
}
28
29impl SparseMerkleTree {
30 pub fn new(env: &Env) -> Self {
32 let defaults = precompute_defaults(env);
33 let root = defaults[SMT_DEPTH as usize];
34
35 Self {
36 root,
37 nodes: BTreeMap::new(),
38 defaults,
39 }
40 }
41
42 pub fn root(&self) -> [u8; 32] {
44 self.root
45 }
46
47 pub fn root_bytes(&self, env: &Env) -> BytesN<32> {
49 BytesN::from_array(env, &self.root)
50 }
51
52 pub fn insert(&mut self, env: &Env, key: &[u8; 32], value: &[u8; 32]) -> Result<(), ZKError> {
56 let leaf_index = key_to_index(key);
57 let leaf_hash = hash_leaf(env, value);
58
59 self.nodes.insert((0, leaf_index), leaf_hash);
61
62 let mut idx = leaf_index;
64 for level in 0..SMT_DEPTH {
65 let sibling_idx = if idx % 2 == 0 { idx + 1 } else { idx - 1 };
66 let left_idx = if idx % 2 == 0 { idx } else { sibling_idx };
67 let right_idx = if idx % 2 == 0 { sibling_idx } else { idx };
68
69 let left = self.get_node(level, left_idx);
70 let right = self.get_node(level, right_idx);
71 let parent = hash_pair(env, &left, &right);
72
73 idx /= 2;
74 self.nodes.insert((level + 1, idx), parent);
75 }
76
77 self.root = self.get_node(SMT_DEPTH, 0);
78 Ok(())
79 }
80
81 pub fn get(&self, key: &[u8; 32]) -> Option<[u8; 32]> {
83 let leaf_index = key_to_index(key);
84 self.nodes.get(&(0, leaf_index)).copied()
85 }
86
87 pub fn prove(&self, env: &Env, key: &[u8; 32]) -> OnChainMerkleProof {
89 let leaf_index = key_to_index(key);
90 let leaf = self.get_node(0, leaf_index);
91
92 let mut siblings: soroban_sdk::Vec<BytesN<32>> = soroban_sdk::Vec::new(env);
93 let mut path_bits: u32 = 0;
94 let mut idx = leaf_index;
95
96 for level in 0..SMT_DEPTH {
97 let sibling_idx = if idx % 2 == 0 { idx + 1 } else { idx - 1 };
98 let sibling = self.get_node(level, sibling_idx);
99 siblings.push_back(BytesN::from_array(env, &sibling));
100
101 if idx % 2 != 0 {
102 path_bits |= 1 << level;
103 }
104 idx /= 2;
105 }
106
107 OnChainMerkleProof {
108 siblings,
109 path_bits,
110 leaf: BytesN::from_array(env, &leaf),
111 leaf_index,
112 depth: SMT_DEPTH,
113 }
114 }
115
116 fn get_node(&self, level: u32, index: u32) -> [u8; 32] {
118 self.nodes
119 .get(&(level, index))
120 .copied()
121 .unwrap_or(self.defaults[level as usize])
122 }
123}
124
125fn key_to_index(key: &[u8; 32]) -> u32 {
127 let b0 = key[0] as u32;
128 let b1 = key[1] as u32;
129 (b0 | (b1 << 8)) % (1 << SMT_DEPTH)
130}
131
132fn precompute_defaults(env: &Env) -> Vec<[u8; 32]> {
136 let mut defaults = Vec::with_capacity(SMT_DEPTH as usize + 1);
137 defaults.push([0u8; 32]); for _ in 0..SMT_DEPTH {
140 let prev = defaults.last().unwrap();
141 defaults.push(hash_pair(env, prev, prev));
142 }
143
144 defaults
145}
146
147fn hash_leaf(env: &Env, data: &[u8; 32]) -> [u8; 32] {
149 let mut input = [0u8; 33];
150 input[0] = 0x00;
151 input[1..].copy_from_slice(data);
152 let bytes = soroban_sdk::Bytes::from_slice(env, &input);
153 env.crypto().sha256(&bytes).to_array()
154}
155
156fn hash_pair(env: &Env, left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
158 let mut input = [0u8; 65];
159 input[0] = 0x01;
160 input[1..33].copy_from_slice(left);
161 input[33..65].copy_from_slice(right);
162 let bytes = soroban_sdk::Bytes::from_slice(env, &input);
163 env.crypto().sha256(&bytes).to_array()
164}
165
/// Sparse Merkle tree with the same layout as `SparseMerkleTree`, but with
/// `U256` nodes combined through `poseidon2_hash`.
///
/// Only compiled when the `hazmat-crypto` feature is enabled.
#[cfg(feature = "hazmat-crypto")]
pub struct PoseidonSparseMerkleTree {
    /// Current root (the node at level `SMT_DEPTH`, index 0).
    root: soroban_sdk::U256,
    /// Explicitly stored nodes, keyed by `(level, index)`; level 0 holds leaf hashes.
    nodes: BTreeMap<(u32, u32), soroban_sdk::U256>,
    /// Default node per level, indexed by level (`0..=SMT_DEPTH`); used for any
    /// position that has no entry in `nodes`.
    defaults: Vec<soroban_sdk::U256>,
}
177
#[cfg(feature = "hazmat-crypto")]
impl PoseidonSparseMerkleTree {
    /// Creates an empty tree.
    ///
    /// The per-level defaults are precomputed with `params` and the root
    /// starts out as the default for the top level (`SMT_DEPTH`).
    pub fn new(env: &Env, params: &crate::zk::crypto::Poseidon2Params) -> Self {
        let defaults = precompute_poseidon_defaults(env, params);

        Self {
            root: defaults[SMT_DEPTH as usize].clone(),
            nodes: BTreeMap::new(),
            defaults,
        }
    }

    /// Returns a copy of the current root.
    pub fn root(&self) -> soroban_sdk::U256 {
        self.root.clone()
    }

    /// Stores `value` at the leaf selected by `key` and recomputes every
    /// ancestor up to the root.
    ///
    /// The leaf slot holds `poseidon2_hash(value, 0)`. An existing entry at
    /// the same leaf index is overwritten.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok(())`.
    pub fn insert(
        &mut self,
        env: &Env,
        params: &crate::zk::crypto::Poseidon2Params,
        key: &[u8; 32],
        value: &soroban_sdk::U256,
    ) -> Result<(), crate::zk::error::ZKError> {
        let mut position = key_to_index(key);

        let zero = soroban_sdk::U256::from_u32(env, 0);
        let leaf_hash = crate::zk::crypto::poseidon2_hash(env, params, value, &zero);
        self.nodes.insert((0, position), leaf_hash);

        for level in 0..SMT_DEPTH {
            // Flipping the lowest bit yields the other child of the same parent.
            let current = self.get_node(level, position);
            let sibling = self.get_node(level, position ^ 1);

            // Even positions are left children, odd positions are right children.
            let (left, right) = if position & 1 == 0 {
                (current, sibling)
            } else {
                (sibling, current)
            };
            let parent = crate::zk::crypto::poseidon2_hash(env, params, &left, &right);

            position >>= 1;
            self.nodes.insert((level + 1, position), parent);
        }

        self.root = self.get_node(SMT_DEPTH, 0);
        Ok(())
    }

    /// Returns the node stored at `(level, index)`, falling back to the
    /// default for `level` when the position has never been written.
    fn get_node(&self, level: u32, index: u32) -> soroban_sdk::U256 {
        match self.nodes.get(&(level, index)) {
            Some(stored) => stored.clone(),
            None => self.defaults[level as usize].clone(),
        }
    }
}
237
/// Computes the default (empty-subtree) node for every level of the Poseidon tree.
///
/// Entry 0 is the `U256` zero; each following entry is `poseidon2_hash` of the
/// previous entry with itself, giving `SMT_DEPTH + 1` entries in total.
#[cfg(feature = "hazmat-crypto")]
fn precompute_poseidon_defaults(
    env: &Env,
    params: &crate::zk::crypto::Poseidon2Params,
) -> Vec<soroban_sdk::U256> {
    let levels = SMT_DEPTH as usize;
    let mut defaults: Vec<soroban_sdk::U256> = Vec::with_capacity(levels + 1);
    defaults.push(soroban_sdk::U256::from_u32(env, 0));

    for level in 0..levels {
        // `defaults[level]` is always the most recently pushed entry here.
        let next =
            crate::zk::crypto::poseidon2_hash(env, params, &defaults[level], &defaults[level]);
        defaults.push(next);
    }

    defaults
}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::zk::merkle::proof::verify_inclusion;

    /// A fresh tree has a non-zero root (it is a hash chain over the empty
    /// leaf) and no stored leaves.
    #[test]
    fn test_empty_smt() {
        let env = Env::default();
        let smt = SparseMerkleTree::new(&env);
        let root = smt.root();
        assert_ne!(root, [0u8; 32]);
        assert!(smt.get(&[0u8; 32]).is_none());
    }

    /// `get` returns the leaf hash of the inserted value, not just "something".
    #[test]
    fn test_insert_and_get() {
        let env = Env::default();
        let mut smt = SparseMerkleTree::new(&env);

        let key = [1u8; 32];
        let value = [42u8; 32];

        // Nothing is stored for the key before the insertion.
        assert!(smt.get(&key).is_none());

        smt.insert(&env, &key, &value).unwrap();
        let stored = smt.get(&key);
        assert!(stored.is_some());
        // The tree stores `hash_leaf(value)` in the leaf slot.
        assert_eq!(stored, Some(hash_leaf(&env, &value)));
    }

    /// Inserting a leaf must move the root away from the empty-tree root.
    #[test]
    fn test_insert_changes_root() {
        let env = Env::default();
        let mut smt = SparseMerkleTree::new(&env);
        let initial_root = smt.root();

        smt.insert(&env, &[1u8; 32], &[42u8; 32]).unwrap();
        assert_ne!(smt.root(), initial_root);
    }

    /// The same value stored under keys with different leaf indices yields
    /// different roots.
    #[test]
    fn test_different_keys_different_roots() {
        let env = Env::default();

        let mut smt1 = SparseMerkleTree::new(&env);
        smt1.insert(&env, &[1u8; 32], &[42u8; 32]).unwrap();

        let mut smt2 = SparseMerkleTree::new(&env);
        smt2.insert(&env, &[2u8; 32], &[42u8; 32]).unwrap();

        assert_ne!(smt1.root(), smt2.root());
    }

    /// A proof generated for an inserted key verifies against the current root.
    #[test]
    fn test_prove_and_verify() {
        let env = Env::default();
        let mut smt = SparseMerkleTree::new(&env);

        let key = [5u8; 32];
        let value = [99u8; 32];
        smt.insert(&env, &key, &value).unwrap();

        let root = smt.root_bytes(&env);
        let proof = smt.prove(&env, &key);

        let result = verify_inclusion(&env, &proof, &root).unwrap();
        assert!(result);
    }

    /// A proof for an untouched slot (default leaf) verifies against the
    /// empty-tree root.
    #[test]
    fn test_prove_empty_key() {
        let env = Env::default();
        let smt = SparseMerkleTree::new(&env);

        let key = [0u8; 32];
        let root = smt.root_bytes(&env);
        let proof = smt.prove(&env, &key);

        let result = verify_inclusion(&env, &proof, &root).unwrap();
        assert!(result);
    }

    /// After several insertions every inserted key is retrievable and
    /// provable against the final root.
    #[test]
    fn test_multiple_inserts() {
        let env = Env::default();
        let mut smt = SparseMerkleTree::new(&env);

        for i in 0..10u8 {
            let mut key = [0u8; 32];
            key[0] = i;
            let mut value = [0u8; 32];
            value[0] = i + 100;
            smt.insert(&env, &key, &value).unwrap();
        }

        let root = smt.root_bytes(&env);
        for i in 0..10u8 {
            let mut key = [0u8; 32];
            key[0] = i;
            let mut value = [0u8; 32];
            value[0] = i + 100;

            // Later insertions must not disturb earlier leaves.
            assert_eq!(smt.get(&key), Some(hash_leaf(&env, &value)));

            let proof = smt.prove(&env, &key);
            let result = verify_inclusion(&env, &proof, &root).unwrap();
            assert!(result, "Proof failed for key {}", i);
        }
    }

    /// Re-inserting under the same key overwrites the leaf, changes the root
    /// and still yields a valid proof.
    #[test]
    fn test_update_existing_key() {
        let env = Env::default();
        let mut smt = SparseMerkleTree::new(&env);

        let key = [1u8; 32];
        smt.insert(&env, &key, &[10u8; 32]).unwrap();
        let root1 = smt.root();

        smt.insert(&env, &key, &[20u8; 32]).unwrap();
        let root2 = smt.root();

        assert_ne!(root1, root2);
        // The slot now holds the hash of the second value.
        assert_eq!(smt.get(&key), Some(hash_leaf(&env, &[20u8; 32])));

        let proof = smt.prove(&env, &key);
        let root = smt.root_bytes(&env);
        assert!(verify_inclusion(&env, &proof, &root).unwrap());
    }
}