Skip to main content

cougr_core/zk/merkle/
sparse.rs

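//! Sparse Merkle tree for key-value state.
//!
//! A sparse Merkle tree (SMT) represents a key-value map where keys are
//! 256-bit hashes. Most of the tree is "empty" (default values), and only
//! non-empty paths are stored.
//!
//! This module uses a reduced, fixed tree depth (`SMT_DEPTH`) instead of the
//! 256 levels a full 256-bit key space would need. Only a prefix of each key
//! selects its leaf, so distinct keys that share that prefix occupy the same
//! leaf and overwrite each other.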
6
7use crate::zk::error::ZKError;
8use alloc::collections::BTreeMap;
9use alloc::vec::Vec;
10use soroban_sdk::{BytesN, Env};
11
12use super::proof::OnChainMerkleProof;
13
/// Depth of the sparse Merkle tree: the number of hash levels between a leaf
/// and the root.
///
/// A "full" SMT keyed by SHA-256 outputs would be 256 levels deep. This tree
/// uses a much smaller depth so that each insert and each proof costs only
/// `SMT_DEPTH` hashes. With 16 levels there are 2^16 = 65,536 leaf slots.
const SMT_DEPTH: u32 = 16;

// Leaf indices and proof `path_bits` are `u32`, so the depth cannot exceed 32.
const _: () = assert!(SMT_DEPTH <= 32, "SMT_DEPTH must fit in a u32 leaf index");
17
18/// Sparse Merkle tree for key-value state (runtime-only).
19///
20/// Uses a fixed depth and precomputed default hashes for empty subtrees.
21/// Only non-default nodes are stored, keeping memory usage proportional
22/// to the number of actual entries.
23pub struct SparseMerkleTree {
24    root: [u8; 32],
25    nodes: BTreeMap<(u32, u32), [u8; 32]>, // (level, index) -> hash
26    defaults: Vec<[u8; 32]>,               // precomputed default hashes per level
27}
28
29impl SparseMerkleTree {
30    /// Create a new empty sparse Merkle tree.
31    pub fn new(env: &Env) -> Self {
32        let defaults = precompute_defaults(env);
33        let root = defaults[SMT_DEPTH as usize];
34
35        Self {
36            root,
37            nodes: BTreeMap::new(),
38            defaults,
39        }
40    }
41
42    /// Returns the root hash.
43    pub fn root(&self) -> [u8; 32] {
44        self.root
45    }
46
47    /// Returns the root hash as `BytesN<32>`.
48    pub fn root_bytes(&self, env: &Env) -> BytesN<32> {
49        BytesN::from_array(env, &self.root)
50    }
51
52    /// Insert or update a key-value pair and recompute the root.
53    ///
54    /// The key determines the leaf position (lower 16 bits used as index).
55    pub fn insert(&mut self, env: &Env, key: &[u8; 32], value: &[u8; 32]) -> Result<(), ZKError> {
56        let leaf_index = key_to_index(key);
57        let leaf_hash = hash_leaf(env, value);
58
59        // Set the leaf
60        self.nodes.insert((0, leaf_index), leaf_hash);
61
62        // Recompute path from leaf to root
63        let mut idx = leaf_index;
64        for level in 0..SMT_DEPTH {
65            let sibling_idx = if idx % 2 == 0 { idx + 1 } else { idx - 1 };
66            let left_idx = if idx % 2 == 0 { idx } else { sibling_idx };
67            let right_idx = if idx % 2 == 0 { sibling_idx } else { idx };
68
69            let left = self.get_node(level, left_idx);
70            let right = self.get_node(level, right_idx);
71            let parent = hash_pair(env, &left, &right);
72
73            idx /= 2;
74            self.nodes.insert((level + 1, idx), parent);
75        }
76
77        self.root = self.get_node(SMT_DEPTH, 0);
78        Ok(())
79    }
80
81    /// Get a value by key, if it exists.
82    pub fn get(&self, key: &[u8; 32]) -> Option<[u8; 32]> {
83        let leaf_index = key_to_index(key);
84        self.nodes.get(&(0, leaf_index)).copied()
85    }
86
87    /// Generate an inclusion proof for a key.
88    pub fn prove(&self, env: &Env, key: &[u8; 32]) -> OnChainMerkleProof {
89        let leaf_index = key_to_index(key);
90        let leaf = self.get_node(0, leaf_index);
91
92        let mut siblings: soroban_sdk::Vec<BytesN<32>> = soroban_sdk::Vec::new(env);
93        let mut path_bits: u32 = 0;
94        let mut idx = leaf_index;
95
96        for level in 0..SMT_DEPTH {
97            let sibling_idx = if idx % 2 == 0 { idx + 1 } else { idx - 1 };
98            let sibling = self.get_node(level, sibling_idx);
99            siblings.push_back(BytesN::from_array(env, &sibling));
100
101            if idx % 2 != 0 {
102                path_bits |= 1 << level;
103            }
104            idx /= 2;
105        }
106
107        OnChainMerkleProof {
108            siblings,
109            path_bits,
110            leaf: BytesN::from_array(env, &leaf),
111            leaf_index,
112            depth: SMT_DEPTH,
113        }
114    }
115
116    /// Get a node hash, falling back to the default for that level.
117    fn get_node(&self, level: u32, index: u32) -> [u8; 32] {
118        self.nodes
119            .get(&(level, index))
120            .copied()
121            .unwrap_or(self.defaults[level as usize])
122    }
123}
124
125/// Map a 32-byte key to a leaf index (lower bits).
126fn key_to_index(key: &[u8; 32]) -> u32 {
127    let b0 = key[0] as u32;
128    let b1 = key[1] as u32;
129    (b0 | (b1 << 8)) % (1 << SMT_DEPTH)
130}
131
132/// Precompute default hashes for each level of the tree.
133/// Level 0 default = all zeros (empty leaf).
134/// Level n default = H(default[n-1], default[n-1]).
135fn precompute_defaults(env: &Env) -> Vec<[u8; 32]> {
136    let mut defaults = Vec::with_capacity(SMT_DEPTH as usize + 1);
137    defaults.push([0u8; 32]); // level 0: empty leaf
138
139    for _ in 0..SMT_DEPTH {
140        let prev = defaults.last().unwrap();
141        defaults.push(hash_pair(env, prev, prev));
142    }
143
144    defaults
145}
146
147/// Hash a leaf: SHA256(0x00 || data).
148fn hash_leaf(env: &Env, data: &[u8; 32]) -> [u8; 32] {
149    let mut input = [0u8; 33];
150    input[0] = 0x00;
151    input[1..].copy_from_slice(data);
152    let bytes = soroban_sdk::Bytes::from_slice(env, &input);
153    env.crypto().sha256(&bytes).to_array()
154}
155
156/// Hash two children: SHA256(0x01 || left || right).
157fn hash_pair(env: &Env, left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
158    let mut input = [0u8; 65];
159    input[0] = 0x01;
160    input[1..33].copy_from_slice(left);
161    input[33..65].copy_from_slice(right);
162    let bytes = soroban_sdk::Bytes::from_slice(env, &input);
163    env.crypto().sha256(&bytes).to_array()
164}
165
166// ─── Poseidon2-based Sparse Merkle Tree ─────────────────────────────
167
/// Poseidon2-based sparse Merkle tree for ZK-friendly key-value state.
///
/// Uses the same depth (`SMT_DEPTH`) and slot mapping (`key_to_index`) as
/// [`SparseMerkleTree`], but hashes `U256` values with Poseidon2. A leaf is
/// `poseidon2_hash(value, 0)` and an inner node is `poseidon2_hash(left, right)`.
///
/// Requires the `hazmat-crypto` feature.
#[cfg(feature = "hazmat-crypto")]
pub struct PoseidonSparseMerkleTree {
    root: soroban_sdk::U256,
    /// Non-default node hashes keyed by `(level, index)`; level 0 holds the leaves.
    nodes: BTreeMap<(u32, u32), soroban_sdk::U256>,
    /// `defaults[level]` is the Poseidon2 hash of an empty subtree at `level`.
    defaults: Vec<soroban_sdk::U256>,
}

#[cfg(feature = "hazmat-crypto")]
impl PoseidonSparseMerkleTree {
    /// Build an empty tree whose root is the top-level empty-subtree hash.
    pub fn new(env: &Env, params: &crate::zk::crypto::Poseidon2Params) -> Self {
        let defaults = precompute_poseidon_defaults(env, params);
        Self {
            root: defaults[SMT_DEPTH as usize].clone(),
            nodes: BTreeMap::new(),
            defaults,
        }
    }

    /// Returns a copy of the current root hash.
    pub fn root(&self) -> soroban_sdk::U256 {
        self.root.clone()
    }

    /// Store `value` in the slot selected by `key` and rehash the leaf-to-root path.
    ///
    /// The slot comes from `key_to_index`, so any earlier entry in the same slot
    /// is replaced, including one inserted under a colliding key. The current
    /// implementation always returns `Ok(())`.
    pub fn insert(
        &mut self,
        env: &Env,
        params: &crate::zk::crypto::Poseidon2Params,
        key: &[u8; 32],
        value: &soroban_sdk::U256,
    ) -> Result<(), crate::zk::error::ZKError> {
        let mut position = key_to_index(key);
        let padding = soroban_sdk::U256::from_u32(env, 0);
        self.nodes.insert(
            (0, position),
            crate::zk::crypto::poseidon2_hash(env, params, value, &padding),
        );

        for level in 0..SMT_DEPTH {
            let here = self.get_node(level, position);
            let beside = self.get_node(level, position ^ 1);
            // An even position is the left child of its pair.
            let parent = if position & 1 == 0 {
                crate::zk::crypto::poseidon2_hash(env, params, &here, &beside)
            } else {
                crate::zk::crypto::poseidon2_hash(env, params, &beside, &here)
            };
            position >>= 1;
            self.nodes.insert((level + 1, position), parent);
        }

        self.root = self.get_node(SMT_DEPTH, 0);
        Ok(())
    }

    /// Returns the stored hash at `(level, index)`, or the empty-subtree default
    /// for `level` when nothing has been stored there.
    fn get_node(&self, level: u32, index: u32) -> soroban_sdk::U256 {
        match self.nodes.get(&(level, index)) {
            Some(hash) => hash.clone(),
            None => self.defaults[level as usize].clone(),
        }
    }
}
237
/// Build the per-level table of empty-subtree Poseidon2 hashes.
///
/// Entry 0 is the empty leaf (`0`). Entry `n` is
/// `poseidon2_hash(entry[n - 1], entry[n - 1])`, for `n` up to `SMT_DEPTH`.
#[cfg(feature = "hazmat-crypto")]
fn precompute_poseidon_defaults(
    env: &Env,
    params: &crate::zk::crypto::Poseidon2Params,
) -> Vec<soroban_sdk::U256> {
    let mut table: Vec<soroban_sdk::U256> = Vec::with_capacity(SMT_DEPTH as usize + 1);
    let mut current = soroban_sdk::U256::from_u32(env, 0);

    for _ in 0..SMT_DEPTH {
        let next = crate::zk::crypto::poseidon2_hash(env, params, &current, &current);
        table.push(current);
        current = next;
    }
    table.push(current);

    table
}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::zk::merkle::proof::verify_inclusion;

    /// A key whose first byte is `b` and whose remaining bytes are zero.
    fn key_with_first_byte(b: u8) -> [u8; 32] {
        let mut key = [0u8; 32];
        key[0] = b;
        key
    }

    #[test]
    fn test_empty_smt() {
        let env = Env::default();
        let smt = SparseMerkleTree::new(&env);
        let root = smt.root();
        // The root of an empty tree is the precomputed default at depth
        // SMT_DEPTH: H(H(H(...))) of the empty leaf, not raw zeros.
        assert_eq!(root, precompute_defaults(&env)[SMT_DEPTH as usize]);
        assert_ne!(root, [0u8; 32]);
    }

    #[test]
    fn test_insert_and_get() {
        let env = Env::default();
        let mut smt = SparseMerkleTree::new(&env);

        let key = [1u8; 32];
        let value = [42u8; 32];

        assert_eq!(smt.get(&key), None);
        smt.insert(&env, &key, &value).unwrap();
        // get() returns the leaf hash, not the raw value.
        assert_eq!(smt.get(&key), Some(hash_leaf(&env, &value)));
    }

    #[test]
    fn test_key_to_index_uses_key_prefix() {
        assert_eq!(key_to_index(&[0u8; 32]), 0);
        assert_eq!(key_to_index(&[1u8; 32]), 0x0101);

        let mut key = [0u8; 32];
        key[0] = 0x34;
        key[1] = 0x12;
        key[31] = 0xFF; // bytes past the prefix do not affect the slot
        assert_eq!(key_to_index(&key), 0x1234);
    }

    #[test]
    fn test_colliding_keys_share_a_slot() {
        // Known limitation: keys that differ only outside the index prefix land
        // in the same leaf, so the second insert replaces the first.
        let env = Env::default();
        let mut smt = SparseMerkleTree::new(&env);

        let key_a = [7u8; 32];
        let mut key_b = key_a;
        key_b[31] = 8;
        assert_eq!(key_to_index(&key_a), key_to_index(&key_b));

        smt.insert(&env, &key_a, &[1u8; 32]).unwrap();
        smt.insert(&env, &key_b, &[2u8; 32]).unwrap();
        assert_eq!(smt.get(&key_a), Some(hash_leaf(&env, &[2u8; 32])));
    }

    #[test]
    fn test_insert_changes_root() {
        let env = Env::default();
        let mut smt = SparseMerkleTree::new(&env);
        let initial_root = smt.root();

        smt.insert(&env, &[1u8; 32], &[42u8; 32]).unwrap();
        assert_ne!(smt.root(), initial_root);
    }

    #[test]
    fn test_different_keys_different_roots() {
        let env = Env::default();

        let mut smt1 = SparseMerkleTree::new(&env);
        smt1.insert(&env, &[1u8; 32], &[42u8; 32]).unwrap();

        let mut smt2 = SparseMerkleTree::new(&env);
        smt2.insert(&env, &[2u8; 32], &[42u8; 32]).unwrap();

        assert_ne!(smt1.root(), smt2.root());
    }

    #[test]
    fn test_prove_and_verify() {
        let env = Env::default();
        let mut smt = SparseMerkleTree::new(&env);

        let key = [5u8; 32];
        let value = [99u8; 32];
        smt.insert(&env, &key, &value).unwrap();

        let root = smt.root_bytes(&env);
        let proof = smt.prove(&env, &key);

        assert_eq!(proof.depth, SMT_DEPTH);
        assert_eq!(proof.leaf_index, key_to_index(&key));
        assert_eq!(proof.siblings.len(), SMT_DEPTH);
        assert_eq!(proof.leaf, BytesN::from_array(&env, &hash_leaf(&env, &value)));

        let result = verify_inclusion(&env, &proof, &root).unwrap();
        assert!(result);
    }

    #[test]
    fn test_prove_empty_key() {
        let env = Env::default();
        let smt = SparseMerkleTree::new(&env);

        let key = [0u8; 32];
        let root = smt.root_bytes(&env);
        let proof = smt.prove(&env, &key);

        // An untouched slot carries the all-zero default leaf.
        assert_eq!(proof.leaf.to_array(), [0u8; 32]);

        // A proof for an empty key should verify (it is a valid default path).
        let result = verify_inclusion(&env, &proof, &root).unwrap();
        assert!(result);
    }

    #[test]
    fn test_multiple_inserts() {
        let env = Env::default();
        let mut smt = SparseMerkleTree::new(&env);

        for i in 0..10u8 {
            let mut value = [0u8; 32];
            value[0] = i + 100;
            smt.insert(&env, &key_with_first_byte(i), &value).unwrap();
        }

        // Every proof must verify against the final root.
        let root = smt.root_bytes(&env);
        for i in 0..10u8 {
            let proof = smt.prove(&env, &key_with_first_byte(i));
            let result = verify_inclusion(&env, &proof, &root).unwrap();
            assert!(result, "Proof failed for key {}", i);
        }
    }

    #[test]
    fn test_update_existing_key() {
        let env = Env::default();
        let mut smt = SparseMerkleTree::new(&env);

        let key = [1u8; 32];
        smt.insert(&env, &key, &[10u8; 32]).unwrap();
        let root1 = smt.root();

        smt.insert(&env, &key, &[20u8; 32]).unwrap();
        let root2 = smt.root();

        // Updating the value must change the root and the stored leaf.
        assert_ne!(root1, root2);
        assert_eq!(smt.get(&key), Some(hash_leaf(&env, &[20u8; 32])));

        // The new proof should verify.
        let proof = smt.prove(&env, &key);
        let root = smt.root_bytes(&env);
        assert!(verify_inclusion(&env, &proof, &root).unwrap());
    }
}