// rootchain_crypto/smt.rs

1use blake3::Hasher;
2use rootchain_core::types::Hash;
3use std::collections::HashMap;
4
5pub const SMT_DEPTH: usize = 256;
6
7lazy_static::lazy_static! {
8    pub static ref DEFAULT_HASHES: [Hash; SMT_DEPTH + 1] = {
9        let mut hashes = [Hash::zero(); SMT_DEPTH + 1];
10        let mut current = Hash::zero();
11        hashes[0] = current;
12        for hash in hashes.iter_mut().take(SMT_DEPTH + 1).skip(1) {
13            let mut hasher = Hasher::new();
14            hasher.update(&current.0);
15            hasher.update(&current.0);
16            current = Hash::from_bytes(hasher.finalize().into());
17            *hash = current;
18        }
19        hashes
20    };
21}
22
23use serde::{Deserialize, Serialize};
24
/// Fixed-depth sparse Merkle tree over 256-bit keys (`SMT_DEPTH` levels), hashed with BLAKE3.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct SparseMerkleTree {
    /// Materialised node hashes, keyed by `(level, index)`.
    ///
    /// Level 0 holds the leaves, keyed by the full key. Level `SMT_DEPTH` holds the root,
    /// keyed by `Hash::zero()`. At level `L` the index is the key with its first `L` path
    /// bits cleared, where path bits are read most-significant-bit first starting at byte 0.
    /// Two sibling nodes at level `L` therefore differ only in bit `L`.
    ///
    /// Any `(level, index)` missing from the map is treated as `DEFAULT_HASHES[level]`.
    /// Entries equal to their level's default carry no information and need not be present.
    ///
    /// NOTE(review): tuple map keys cannot be represented in serde formats that require
    /// string map keys (e.g. JSON). Confirm which serializer persists this type.
    pub nodes: HashMap<(usize, Hash), Hash>,
}
32
33impl SparseMerkleTree {
34    pub fn new() -> Self {
35        Self {
36            nodes: HashMap::new(),
37        }
38    }
39
40    pub fn root(&self) -> Hash {
41        self.nodes
42            .get(&(SMT_DEPTH, Hash::zero()))
43            .cloned()
44            .unwrap_or(DEFAULT_HASHES[SMT_DEPTH])
45    }
46
47    pub fn update(&mut self, key: Hash, value: Hash) {
48        let mut current_hash = value;
49        let mut current_index = key;
50
51        self.nodes.insert((0, current_index), current_hash);
52
53        for level in 0..SMT_DEPTH {
54            let is_right = (current_index.0[level / 8] >> (7 - (level % 8))) & 1 == 1;
55            let sibling_index = self.get_sibling_index(current_index, level);
56
57            let sibling_hash = self
58                .nodes
59                .get(&(level, sibling_index))
60                .cloned()
61                .unwrap_or(DEFAULT_HASHES[level]);
62
63            let mut hasher = Hasher::new();
64            if is_right {
65                hasher.update(&sibling_hash.0);
66                hasher.update(&current_hash.0);
67            } else {
68                hasher.update(&current_hash.0);
69                hasher.update(&sibling_hash.0);
70            }
71            current_hash = Hash::from_bytes(hasher.finalize().into());
72
73            // Move up: parent index is current index with the bit at 'level' cleared (if simplified)
74            // In a true SMT, parent index is just the path prefix.
75            current_index = self.get_parent_index(current_index, level);
76            self.nodes.insert((level + 1, current_index), current_hash);
77        }
78    }
79
80    pub fn get_proof(&self, key: Hash) -> Vec<Hash> {
81        let mut proof = Vec::with_capacity(SMT_DEPTH);
82        let mut current_index = key;
83
84        for level in 0..SMT_DEPTH {
85            let sibling_index = self.get_sibling_index(current_index, level);
86            let sibling_hash = self
87                .nodes
88                .get(&(level, sibling_index))
89                .cloned()
90                .unwrap_or(DEFAULT_HASHES[level]);
91            proof.push(sibling_hash);
92            current_index = self.get_parent_index(current_index, level);
93        }
94        proof
95    }
96
97    fn get_sibling_index(&self, index: Hash, level: usize) -> Hash {
98        let mut sibling = index;
99        let byte_idx = level / 8;
100        let bit_idx = 7 - (level % 8);
101        sibling.0[byte_idx] ^= 1 << bit_idx;
102        sibling
103    }
104
105    fn get_parent_index(&self, index: Hash, level: usize) -> Hash {
106        let mut parent = index;
107        let byte_idx = level / 8;
108        let bit_idx = 7 - (level % 8);
109        parent.0[byte_idx] &= !(1 << bit_idx);
110        parent
111    }
112}
113
114pub fn verify_smt_proof(root: Hash, key: Hash, value: Hash, proof: &[Hash]) -> bool {
115    if proof.len() != SMT_DEPTH {
116        return false;
117    }
118
119    let mut current_hash = value;
120    let mut current_index = key;
121
122    for (level, &sibling_hash) in proof.iter().enumerate().take(SMT_DEPTH) {
123        let is_right = (current_index.0[level / 8] >> (7 - (level % 8))) & 1 == 1;
124
125        let mut hasher = Hasher::new();
126        if is_right {
127            hasher.update(&sibling_hash.0);
128            hasher.update(&current_hash.0);
129        } else {
130            hasher.update(&current_hash.0);
131            hasher.update(&sibling_hash.0);
132        }
133        current_hash = Hash::from_bytes(hasher.finalize().into());
134
135        let byte_idx = level / 8;
136        let bit_idx = 7 - (level % 8);
137        current_index.0[byte_idx] &= !(1 << bit_idx);
138    }
139
140    current_hash == root
141}