// smt_circom/lib.rs

//! Sparse Merkle Tree compatible with circom proofs.
//!
//! This crate implements a Poseidon-based Sparse Merkle Tree for generating
//! witnesses and proofs for circom circuits. Nodes are persisted through a
//! pluggable [`NodeStore`] under their Poseidon hash; keys, values and hashes
//! are 32-byte big-endian BN254 field elements.
7
8use ark_bn254::Fr;
9use core::array;
10use light_poseidon::{Poseidon, PoseidonBytesHasher};
11
12use crate::store::NodeStore;
13
14pub mod store;
15
16#[inline]
17fn poseidon_hash(inputs: &[&[u8]]) -> [u8; 32] {
18    let mut p = Poseidon::<Fr>::new_circom(inputs.len()).expect("poseidon init");
19    p.hash_bytes_be(inputs).expect("poseidon hash")
20}
21
22#[inline]
23fn leaf_key(k: [u8; 32], v: [u8; 32]) -> [u8; 32] {
24    let mut one = [0u8; 32];
25    one[31] = 1;
26    poseidon_hash(&[&k, &v, &one])
27}
28
29#[inline]
30fn mid_key(l: [u8; 32], r: [u8; 32]) -> [u8; 32] {
31    poseidon_hash(&[&l, &r])
32}
33
34/// Errors returned by SMT operations.
35#[derive(Debug, thiserror::Error)]
36pub enum Error<E> {
37    #[error("The key is already present")]
38    AlreadyPresent,
39    #[error("Key wasn't found")]
40    KeyNotFound,
41    #[error("Store error: {0}")]
42    Store(E),
43}
44
/// A Sparse Merkle Tree node.
///
/// Nodes are persisted under their own hash; see `Node::encode` for the
/// 65-byte storage layout. All fields are 32-byte big-endian values.
///
/// Equality and hashing are derived so nodes can be compared directly and
/// used as keys in standard collections, like the other plain-data types in
/// this crate.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Node {
    /// Internal node holding the hashes of its left (`l`) and right (`r`)
    /// children; an all-zero hash stands for an empty subtree.
    Middle { l: [u8; 32], r: [u8; 32] },
    /// Leaf holding a key `k` and its value `v`.
    Leaf { k: [u8; 32], v: [u8; 32] },
}
51
52impl Node {
53    /// Encode a node to a compact 65-byte form:
54    ///
55    /// - byte 0: `0` for `Middle`, `1` for `Leaf`
56    /// - bytes 1..33, 33..65: two 32-byte fields (`l|r` or `k|v`)
57    pub fn encode(&self) -> [u8; 65] {
58        let mut out = [0u8; 65];
59        match self {
60            Node::Middle { l: left, r: right } => {
61                out[0] = 0;
62                out[1..33].copy_from_slice(left);
63                out[33..65].copy_from_slice(right);
64            }
65            Node::Leaf { k: index, v: value } => {
66                out[0] = 1;
67                out[1..33].copy_from_slice(index);
68                out[33..65].copy_from_slice(value);
69            }
70        }
71        out
72    }
73
74    /// Decode a node from its 65-byte encoding.
75    ///
76    /// Returns `None` if the buffer is not exactly 65 bytes or the tag is
77    /// invalid.
78    pub fn decode(bs: &[u8]) -> Option<Self> {
79        if bs.len() != 65 {
80            return None;
81        }
82        let mut a = [0u8; 32];
83        let mut b = [0u8; 32];
84        a.copy_from_slice(&bs[1..33]);
85        b.copy_from_slice(&bs[33..65]);
86        Some(match bs[0] {
87            0 => Node::Middle { l: a, r: b },
88            1 => Node::Leaf { k: a, v: b },
89            _ => return None,
90        })
91    }
92
93    fn key(&self) -> [u8; 32] {
94        match *self {
95            Node::Leaf { k, v } => leaf_key(k, v),
96            Node::Middle { l, r } => mid_key(l, r),
97        }
98    }
99}
100
/// Expand `key` into its first `D` path bits, least-significant bit first.
///
/// `key` is read as a big-endian integer: entry `i` is bit `i % 8` of byte
/// `31 - i / 8`. A `true` entry means "descend to the right child".
///
/// Panics if `D > 256`, because the byte index then falls outside `key`.
#[inline]
fn get_path<const D: usize>(key: &[u8; 32]) -> [bool; D] {
    array::from_fn(|idx| {
        let byte = key[31 - idx / 8];
        ((byte >> (idx % 8)) & 1) == 1
    })
}
109
/// Proof object tailored for circom circuits.
///
/// Built by `SparseMerkleTree::get_proof` for either inclusion
/// (`membership == true`) or non-inclusion of a queried key.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CircomProof<const D: usize> {
    /// Sibling hashes along the key's path, starting at the root's children;
    /// entries below the point where the lookup stopped are zero.
    pub siblings: [[u8; 32]; D],
    /// Circom's `isOld0` input: `true` when the proof references no old
    /// leaf, i.e. the lookup ended in an empty subtree.
    pub is_old0: bool,
    /// Key of the leaf the lookup ended on (zero if it ended on an empty subtree).
    pub old_key: [u8; 32],
    /// Value of the leaf the lookup ended on (zero if it ended on an empty subtree).
    pub old_value: [u8; 32],
    /// Whether the leaf found is the queried key itself.
    pub membership: bool,
}

impl<const D: usize> CircomProof<D> {
    /// The stored value for the queried key, or `None` for a non-membership proof.
    pub fn get_leaf(&self) -> Option<&[u8; 32]> {
        self.membership.then_some(&self.old_value)
    }
}
130
131pub struct SparseMerkleTree<const D: usize, S: NodeStore> {
132    store: S,
133}
134
135impl<const D: usize, S: NodeStore> SparseMerkleTree<D, S> {
136    /// Construct a new tree from a store.
137    pub fn new(store: S) -> Result<Self, S::Error> {
138        Ok(Self { store })
139    }
140
141    /// Get the current root.
142    pub fn root(&self) -> Result<[u8; 32], S::Error> {
143        self.store.get_root()
144    }
145
146    fn put(&mut self, node: &Node) -> Result<[u8; 32], S::Error> {
147        let k = node.key();
148        self.store.put(k, node.encode())?;
149        Ok(k)
150    }
151
152    fn set_root(&mut self, root: [u8; 32]) -> Result<(), S::Error> {
153        self.store.set_root(root)
154    }
155
156    fn add_leaf(
157        &mut self,
158        new_leaf: Node,
159        cur_key: [u8; 32],
160        lvl: usize,
161        path_new: &[bool],
162    ) -> Result<[u8; 32], Error<S::Error>> {
163        let n = self.store.get(cur_key).expect("node exists");
164        match n {
165            None => Ok(self.put(&new_leaf).map_err(Error::Store)?),
166            Some(Node::Leaf { k: old_k, v: old_v }) => {
167                if let Node::Leaf { k: new_k, .. } = new_leaf {
168                    if new_k == old_k {
169                        return Err(Error::AlreadyPresent);
170                    }
171                } else {
172                    unreachable!();
173                }
174                let path_old = get_path::<D>(&old_k);
175                self.push_leaf(
176                    new_leaf,
177                    Node::Leaf { k: old_k, v: old_v },
178                    lvl,
179                    path_new,
180                    &path_old,
181                )
182                .map_err(Error::Store)
183            }
184            Some(Node::Middle { l, r }) => {
185                if path_new[lvl] {
186                    let next = self.add_leaf(new_leaf, r, lvl + 1, path_new)?;
187                    Ok(self
188                        .put(&Node::Middle { l, r: next })
189                        .map_err(Error::Store)?)
190                } else {
191                    let next = self.add_leaf(new_leaf, l, lvl + 1, path_new)?;
192                    Ok(self
193                        .put(&Node::Middle { l: next, r })
194                        .map_err(Error::Store)?)
195                }
196            }
197        }
198    }
199
200    fn push_leaf(
201        &mut self,
202        new_leaf: Node,
203        old_leaf: Node,
204        lvl: usize,
205        path_new: &[bool],
206        path_old: &[bool],
207    ) -> Result<[u8; 32], S::Error> {
208        if path_new[lvl] == path_old[lvl] {
209            let next_key = self.push_leaf(new_leaf, old_leaf, lvl + 1, path_new, path_old)?;
210            let mid = if path_new[lvl] {
211                Node::Middle {
212                    l: [0; 32],
213                    r: next_key,
214                }
215            } else {
216                Node::Middle {
217                    l: next_key,
218                    r: [0; 32],
219                }
220            };
221            return self.put(&mid);
222        }
223
224        let Node::Leaf { k: old_k, v: old_v } = old_leaf else {
225            unreachable!()
226        };
227
228        let new_leaf_key = self.put(&new_leaf)?;
229        let old_leaf_key = leaf_key(old_k, old_v);
230
231        let mid = if path_new[lvl] {
232            Node::Middle {
233                l: old_leaf_key,
234                r: new_leaf_key,
235            }
236        } else {
237            Node::Middle {
238                l: new_leaf_key,
239                r: old_leaf_key,
240            }
241        };
242        self.put(&mid)
243    }
244
245    /// Insert a new leaf.
246    pub fn add(&mut self, key: [u8; 32], val: [u8; 32]) -> Result<(), Error<S::Error>> {
247        let kh = key;
248        let vh = val;
249        let new_leaf = Node::Leaf { k: kh, v: vh };
250
251        let path_new = get_path::<D>(&kh);
252        let new_root = self.add_leaf(new_leaf, self.root().map_err(Error::Store)?, 0, &path_new)?;
253        self.set_root(new_root).map_err(Error::Store)?;
254        Ok(())
255    }
256
257    /// Update an existing leaf's value, returning the previous value.
258    pub fn update(&mut self, key: [u8; 32], val: [u8; 32]) -> Result<[u8; 32], Error<S::Error>> {
259        let kh = key;
260        let vh = val;
261        let mut cur = self.root().map_err(Error::Store)?;
262        let mut siblings = heapless::Vec::<[u8; 32], D>::new();
263        let path = get_path::<D>(&kh);
264        let old_v;
265
266        for go_right in path.iter().copied() {
267            match self.store.get(cur).expect("node exists") {
268                None => return Err(Error::KeyNotFound),
269                Some(Node::Leaf { k, v }) => {
270                    if k != kh {
271                        return Err(Error::KeyNotFound);
272                    }
273                    old_v = Some(v);
274
275                    let mut node = Node::Leaf { k: kh, v: vh };
276                    let mut node_h = self.put(&node).map_err(Error::Store)?;
277
278                    for (lvl, sib) in siblings.into_iter().enumerate().rev() {
279                        let bit = path[lvl];
280                        node = if bit {
281                            Node::Middle { l: sib, r: node_h }
282                        } else {
283                            Node::Middle { l: node_h, r: sib }
284                        };
285                        node_h = self.put(&node).map_err(Error::Store)?;
286                    }
287                    self.set_root(node_h).map_err(Error::Store)?;
288                    return Ok(old_v.unwrap());
289                }
290                Some(Node::Middle { l, r }) => {
291                    if go_right {
292                        siblings.push(l).unwrap();
293                        cur = r;
294                    } else {
295                        siblings.push(r).unwrap();
296                        cur = l;
297                    }
298                }
299            }
300        }
301        Err(Error::KeyNotFound)
302    }
303
304    /// Build a circom-compatible proof for `key` inclusion or non-inclusion.
305    pub fn get_proof(&self, key: [u8; 32]) -> Result<CircomProof<D>, S::Error> {
306        let k = key;
307        let mut siblings = [[0; 32]; D];
308        let mut sibling_i = 0;
309        let mut cur = self.root()?;
310
311        for (i, go_right) in get_path::<D>(&k).into_iter().enumerate() {
312            match self.store.get(cur).expect("node exists") {
313                None => {
314                    return Ok(CircomProof {
315                        old_key: [0; 32],
316                        old_value: [0; 32],
317                        is_old0: true,
318                        siblings,
319                        membership: false,
320                    });
321                }
322                Some(Node::Leaf {
323                    k: leaf_k,
324                    v: leaf_v,
325                }) => {
326                    return Ok(CircomProof {
327                        old_key: leaf_k,
328                        old_value: leaf_v,
329                        is_old0: leaf_k == [0; 32],
330                        siblings,
331                        membership: leaf_k == k,
332                    });
333                }
334                Some(Node::Middle { l, r }) => {
335                    if go_right {
336                        siblings[sibling_i] = l;
337                        cur = r;
338                    } else {
339                        siblings[sibling_i] = r;
340                        cur = l;
341                    }
342                    sibling_i += 1;
343                }
344            }
345            if i == D - 1 {
346                return Ok(CircomProof {
347                    old_key: [0; 32],
348                    old_value: [0; 32],
349                    is_old0: true,
350                    siblings,
351                    membership: false,
352                });
353            }
354        }
355        unreachable!();
356    }
357
358    /// Insert or update (insert if missing, otherwise replace the value).
359    pub fn add_or_update(&mut self, key: [u8; 32], val: [u8; 32]) -> Result<(), Error<S::Error>> {
360        match self.add(key, val) {
361            Err(Error::AlreadyPresent) => self.update(key, val).map(|_| ()),
362            x => x,
363        }
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::MemStore;

    const DEPTH: usize = 64;
    const ZERO: [u8; 32] = [0; 32];

    // Fixture keys and values. `ROOT1_JS` and `K3_SIBLING1_JS` are reference
    // hashes from a JavaScript implementation (presumably circomlibjs).
    const K1: [u8; 32] = [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 43, 127, 78,
        51, 93, 159, 92, 71,
    ];
    const V1: [u8; 32] = [
        16, 232, 248, 117, 61, 208, 169, 22, 163, 170, 44, 57, 210, 21, 42, 219, 91, 147, 79,
        94, 181, 31, 210, 205, 159, 82, 222, 81, 110, 255, 37, 198,
    ];
    const ROOT1_JS: [u8; 32] = [
        37, 18, 9, 85, 224, 252, 133, 154, 45, 120, 67, 166, 143, 180, 254, 196, 219, 139, 9,
        229, 191, 47, 36, 89, 138, 111, 104, 170, 242, 127, 191, 38,
    ];
    const K2: [u8; 32] = [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 211, 160, 91,
        130, 253, 193, 133, 52,
    ];
    const V2: [u8; 32] = [
        2, 135, 56, 32, 251, 187, 59, 31, 232, 236, 204, 116, 101, 171, 47, 15, 159, 138, 139,
        231, 61, 78, 108, 10, 70, 133, 200, 198, 187, 100, 85, 178,
    ];
    const K3: [u8; 32] = [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 74, 181, 123,
        89, 155, 208, 255, 114,
    ];
    const V3: [u8; 32] = [
        16, 46, 63, 228, 134, 35, 92, 132, 114, 153, 57, 23, 154, 224, 217, 112, 131, 208, 134,
        232, 218, 170, 173, 245, 178, 128, 151, 223, 2, 64, 114, 19,
    ];
    const V4: [u8; 32] = [
        34, 105, 95, 86, 39, 160, 123, 45, 219, 68, 91, 94, 55, 161, 223, 203, 206, 164, 203,
        253, 33, 59, 150, 111, 108, 74, 20, 17, 62, 214, 104, 58,
    ];
    const K3_SIBLING1_JS: [u8; 32] = [
        39, 2, 121, 120, 126, 69, 90, 96, 220, 95, 224, 252, 255, 197, 106, 214, 4, 22, 155,
        164, 67, 176, 180, 82, 34, 37, 226, 17, 201, 250, 187, 58,
    ];

    /// End-to-end walk: three inserts, one in-place update, then two extra keys,
    /// checking each pre-operation proof and the known reference hashes.
    #[test]
    fn test_smt() {
        let mut tree = SparseMerkleTree::<DEPTH, _>::new(MemStore::new()).unwrap();
        assert_eq!(tree.root().unwrap(), ZERO);

        // Step 1: K1 into the empty tree; the lookup ends on an empty slot.
        let before = tree.get_proof(K1).unwrap();
        assert!(before.get_leaf().is_none());
        tree.add_or_update(K1, V1).unwrap();
        assert_eq!(tree.get_proof(K1).unwrap().get_leaf(), Some(&V1));
        assert!(!before.membership);
        assert!(before.is_old0);
        assert_eq!(before.old_key, ZERO);
        assert_eq!(before.old_value, ZERO);
        assert_eq!(before.siblings.len(), DEPTH);
        assert!(before.siblings.iter().all(|s| *s == ZERO));
        assert_eq!(tree.root().unwrap(), ROOT1_JS);

        // Step 2: K2; the lookup ends on K1's leaf, which is the whole tree.
        let before = tree.get_proof(K2).unwrap();
        assert!(before.get_leaf().is_none());
        tree.add_or_update(K2, V2).unwrap();
        assert_eq!(tree.get_proof(K2).unwrap().get_leaf(), Some(&V2));
        assert!(!before.membership);
        assert!(!before.is_old0);
        assert_eq!(before.old_key, K1);
        assert_eq!(before.old_value, V1);
        assert_eq!(before.siblings.len(), DEPTH);
        assert!(before.siblings.iter().all(|s| *s == ZERO));

        // Step 3: K3; the lookup ends on K2's leaf, and the root-level sibling is
        // K1's leaf, whose hash was the step-1 root.
        let before = tree.get_proof(K3).unwrap();
        assert!(before.get_leaf().is_none());
        tree.add_or_update(K3, V3).unwrap();
        assert_eq!(tree.get_proof(K3).unwrap().get_leaf(), Some(&V3));
        assert!(!before.membership);
        assert!(!before.is_old0);
        assert_eq!(before.old_key, K2);
        assert_eq!(before.old_value, V2);
        assert_eq!(before.siblings.len(), DEPTH);
        assert_eq!(before.siblings[0], ROOT1_JS);
        assert!(before.siblings[1..].iter().all(|s| *s == ZERO));

        // Step 4: overwrite K3's value; the prior proof is a membership proof.
        let before = tree.get_proof(K3).unwrap();
        tree.add_or_update(K3, V4).unwrap();
        assert_eq!(tree.get_proof(K3).unwrap().get_leaf(), Some(&V4));
        assert!(before.membership);
        assert!(!before.is_old0);
        assert_eq!(before.old_key, K3);
        assert_eq!(before.old_value, V3);
        assert_eq!(before.siblings.len(), DEPTH);
        assert_eq!(before.siblings[0], ROOT1_JS);
        assert_eq!(before.siblings[1], K3_SIBLING1_JS);
        assert!(before.siblings[2..].iter().all(|s| *s == ZERO));

        // Step 5: plain `add` for the all-zero and all-one keys.
        for (key, val) in [(ZERO, ZERO), ([1; 32], [1; 32])] {
            assert!(tree.get_proof(key).unwrap().get_leaf().is_none());
            tree.add(key, val).unwrap();
            assert_eq!(tree.get_proof(key).unwrap().get_leaf(), Some(&val));
        }
    }
}