1use blake3::Hasher;
2use rootchain_core::types::Hash;
3use std::collections::HashMap;
4
/// Number of levels between a leaf and the root: one level per bit of a 32-byte key.
pub const SMT_DEPTH: usize = 32 * 8;
6
7lazy_static::lazy_static! {
8 pub static ref DEFAULT_HASHES: [Hash; SMT_DEPTH + 1] = {
9 let mut hashes = [Hash::zero(); SMT_DEPTH + 1];
10 let mut current = Hash::zero();
11 hashes[0] = current;
12 for hash in hashes.iter_mut().take(SMT_DEPTH + 1).skip(1) {
13 let mut hasher = Hasher::new();
14 hasher.update(¤t.0);
15 hasher.update(¤t.0);
16 current = Hash::from_bytes(hasher.finalize().into());
17 *hash = current;
18 }
19 hashes
20 };
21}
22
23use serde::{Deserialize, Serialize};
24
25#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
26pub struct SparseMerkleTree {
27 pub nodes: HashMap<(usize, Hash), Hash>,
31}
32
33impl SparseMerkleTree {
34 pub fn new() -> Self {
35 Self {
36 nodes: HashMap::new(),
37 }
38 }
39
40 pub fn root(&self) -> Hash {
41 self.nodes
42 .get(&(SMT_DEPTH, Hash::zero()))
43 .cloned()
44 .unwrap_or(DEFAULT_HASHES[SMT_DEPTH])
45 }
46
47 pub fn update(&mut self, key: Hash, value: Hash) {
48 let mut current_hash = value;
49 let mut current_index = key;
50
51 self.nodes.insert((0, current_index), current_hash);
52
53 for level in 0..SMT_DEPTH {
54 let is_right = (current_index.0[level / 8] >> (7 - (level % 8))) & 1 == 1;
55 let sibling_index = self.get_sibling_index(current_index, level);
56
57 let sibling_hash = self
58 .nodes
59 .get(&(level, sibling_index))
60 .cloned()
61 .unwrap_or(DEFAULT_HASHES[level]);
62
63 let mut hasher = Hasher::new();
64 if is_right {
65 hasher.update(&sibling_hash.0);
66 hasher.update(¤t_hash.0);
67 } else {
68 hasher.update(¤t_hash.0);
69 hasher.update(&sibling_hash.0);
70 }
71 current_hash = Hash::from_bytes(hasher.finalize().into());
72
73 current_index = self.get_parent_index(current_index, level);
76 self.nodes.insert((level + 1, current_index), current_hash);
77 }
78 }
79
80 pub fn get_proof(&self, key: Hash) -> Vec<Hash> {
81 let mut proof = Vec::with_capacity(SMT_DEPTH);
82 let mut current_index = key;
83
84 for level in 0..SMT_DEPTH {
85 let sibling_index = self.get_sibling_index(current_index, level);
86 let sibling_hash = self
87 .nodes
88 .get(&(level, sibling_index))
89 .cloned()
90 .unwrap_or(DEFAULT_HASHES[level]);
91 proof.push(sibling_hash);
92 current_index = self.get_parent_index(current_index, level);
93 }
94 proof
95 }
96
97 fn get_sibling_index(&self, index: Hash, level: usize) -> Hash {
98 let mut sibling = index;
99 let byte_idx = level / 8;
100 let bit_idx = 7 - (level % 8);
101 sibling.0[byte_idx] ^= 1 << bit_idx;
102 sibling
103 }
104
105 fn get_parent_index(&self, index: Hash, level: usize) -> Hash {
106 let mut parent = index;
107 let byte_idx = level / 8;
108 let bit_idx = 7 - (level % 8);
109 parent.0[byte_idx] &= !(1 << bit_idx);
110 parent
111 }
112}
113
114pub fn verify_smt_proof(root: Hash, key: Hash, value: Hash, proof: &[Hash]) -> bool {
115 if proof.len() != SMT_DEPTH {
116 return false;
117 }
118
119 let mut current_hash = value;
120 let mut current_index = key;
121
122 for (level, &sibling_hash) in proof.iter().enumerate().take(SMT_DEPTH) {
123 let is_right = (current_index.0[level / 8] >> (7 - (level % 8))) & 1 == 1;
124
125 let mut hasher = Hasher::new();
126 if is_right {
127 hasher.update(&sibling_hash.0);
128 hasher.update(¤t_hash.0);
129 } else {
130 hasher.update(¤t_hash.0);
131 hasher.update(&sibling_hash.0);
132 }
133 current_hash = Hash::from_bytes(hasher.finalize().into());
134
135 let byte_idx = level / 8;
136 let bit_idx = 7 - (level % 8);
137 current_index.0[byte_idx] &= !(1 << bit_idx);
138 }
139
140 current_hash == root
141}