1use ark_bn254::Fr;
9use core::array;
10use light_poseidon::{Poseidon, PoseidonBytesHasher};
11
12use crate::store::NodeStore;
13
14pub mod store;
15
16#[inline]
17fn poseidon_hash(inputs: &[&[u8]]) -> [u8; 32] {
18 let mut p = Poseidon::<Fr>::new_circom(inputs.len()).expect("poseidon init");
19 p.hash_bytes_be(inputs).expect("poseidon hash")
20}
21
22#[inline]
23fn leaf_key(k: [u8; 32], v: [u8; 32]) -> [u8; 32] {
24 let mut one = [0u8; 32];
25 one[31] = 1;
26 poseidon_hash(&[&k, &v, &one])
27}
28
29#[inline]
30fn mid_key(l: [u8; 32], r: [u8; 32]) -> [u8; 32] {
31 poseidon_hash(&[&l, &r])
32}
33
34#[derive(Debug, thiserror::Error)]
36pub enum Error<E> {
37 #[error("The key is already present")]
38 AlreadyPresent,
39 #[error("Key wasn't found")]
40 KeyNotFound,
41 #[error("Store error: {0}")]
42 Store(E),
43}
44
/// A node of the sparse Merkle tree.
///
/// Child hashes of all zeroes denote an empty subtree.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Node {
    /// Internal node holding the hashes of its left (`l`) and right (`r`) children.
    Middle { l: [u8; 32], r: [u8; 32] },
    /// Leaf holding a key (`k`) and its value (`v`).
    Leaf { k: [u8; 32], v: [u8; 32] },
}
51
52impl Node {
53 pub fn encode(&self) -> [u8; 65] {
58 let mut out = [0u8; 65];
59 match self {
60 Node::Middle { l: left, r: right } => {
61 out[0] = 0;
62 out[1..33].copy_from_slice(left);
63 out[33..65].copy_from_slice(right);
64 }
65 Node::Leaf { k: index, v: value } => {
66 out[0] = 1;
67 out[1..33].copy_from_slice(index);
68 out[33..65].copy_from_slice(value);
69 }
70 }
71 out
72 }
73
74 pub fn decode(bs: &[u8]) -> Option<Self> {
79 if bs.len() != 65 {
80 return None;
81 }
82 let mut a = [0u8; 32];
83 let mut b = [0u8; 32];
84 a.copy_from_slice(&bs[1..33]);
85 b.copy_from_slice(&bs[33..65]);
86 Some(match bs[0] {
87 0 => Node::Middle { l: a, r: b },
88 1 => Node::Leaf { k: a, v: b },
89 _ => return None,
90 })
91 }
92
93 fn key(&self) -> [u8; 32] {
94 match *self {
95 Node::Leaf { k, v } => leaf_key(k, v),
96 Node::Middle { l, r } => mid_key(l, r),
97 }
98 }
99}
100
/// Returns the first `D` bits of `key`, least-significant first.
///
/// `key` is read as a big-endian number, so bit `i` comes from byte
/// `31 - i / 8`, bit position `i % 8`. `true` means "go right".
/// Indexing panics if `D > 256`.
#[inline]
fn get_path<const D: usize>(key: &[u8; 32]) -> [bool; D] {
    let mut bits = [false; D];
    for (i, slot) in bits.iter_mut().enumerate() {
        let byte = key[31 - i / 8];
        *slot = ((byte >> (i % 8)) & 1) == 1;
    }
    bits
}
109
/// Inclusion / exclusion proof for one key, laid out as the inputs of a circom
/// SMT verifier (field names mirror its signals — presumably `SMTVerifier`;
/// confirm against the consuming circuit).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CircomProof<const D: usize> {
    /// Sibling hash at each depth along the key's path; unused trailing
    /// entries are all zeroes.
    pub siblings: [[u8; 32]; D],
    /// Whether the slot reached by the lookup is reported as empty.
    pub is_old0: bool,
    /// Key of the leaf reached by the lookup, or zero when none was reached.
    pub old_key: [u8; 32],
    /// Value of the leaf reached by the lookup, or zero when none was reached.
    pub old_value: [u8; 32],
    /// Whether the reached leaf holds exactly the queried key.
    pub membership: bool,
}
119
120impl<const D: usize> CircomProof<D> {
121 pub fn get_leaf(&self) -> Option<&[u8; 32]> {
123 if self.membership {
124 Some(&self.old_value)
125 } else {
126 None
127 }
128 }
129}
130
131pub struct SparseMerkleTree<const D: usize, S: NodeStore> {
132 store: S,
133}
134
135impl<const D: usize, S: NodeStore> SparseMerkleTree<D, S> {
136 pub fn new(store: S) -> Result<Self, S::Error> {
138 Ok(Self { store })
139 }
140
141 pub fn root(&self) -> Result<[u8; 32], S::Error> {
143 self.store.get_root()
144 }
145
146 fn put(&mut self, node: &Node) -> Result<[u8; 32], S::Error> {
147 let k = node.key();
148 self.store.put(k, node.encode())?;
149 Ok(k)
150 }
151
152 fn set_root(&mut self, root: [u8; 32]) -> Result<(), S::Error> {
153 self.store.set_root(root)
154 }
155
156 fn add_leaf(
157 &mut self,
158 new_leaf: Node,
159 cur_key: [u8; 32],
160 lvl: usize,
161 path_new: &[bool],
162 ) -> Result<[u8; 32], Error<S::Error>> {
163 let n = self.store.get(cur_key).expect("node exists");
164 match n {
165 None => Ok(self.put(&new_leaf).map_err(Error::Store)?),
166 Some(Node::Leaf { k: old_k, v: old_v }) => {
167 if let Node::Leaf { k: new_k, .. } = new_leaf {
168 if new_k == old_k {
169 return Err(Error::AlreadyPresent);
170 }
171 } else {
172 unreachable!();
173 }
174 let path_old = get_path::<D>(&old_k);
175 self.push_leaf(
176 new_leaf,
177 Node::Leaf { k: old_k, v: old_v },
178 lvl,
179 path_new,
180 &path_old,
181 )
182 .map_err(Error::Store)
183 }
184 Some(Node::Middle { l, r }) => {
185 if path_new[lvl] {
186 let next = self.add_leaf(new_leaf, r, lvl + 1, path_new)?;
187 Ok(self
188 .put(&Node::Middle { l, r: next })
189 .map_err(Error::Store)?)
190 } else {
191 let next = self.add_leaf(new_leaf, l, lvl + 1, path_new)?;
192 Ok(self
193 .put(&Node::Middle { l: next, r })
194 .map_err(Error::Store)?)
195 }
196 }
197 }
198 }
199
200 fn push_leaf(
201 &mut self,
202 new_leaf: Node,
203 old_leaf: Node,
204 lvl: usize,
205 path_new: &[bool],
206 path_old: &[bool],
207 ) -> Result<[u8; 32], S::Error> {
208 if path_new[lvl] == path_old[lvl] {
209 let next_key = self.push_leaf(new_leaf, old_leaf, lvl + 1, path_new, path_old)?;
210 let mid = if path_new[lvl] {
211 Node::Middle {
212 l: [0; 32],
213 r: next_key,
214 }
215 } else {
216 Node::Middle {
217 l: next_key,
218 r: [0; 32],
219 }
220 };
221 return self.put(&mid);
222 }
223
224 let Node::Leaf { k: old_k, v: old_v } = old_leaf else {
225 unreachable!()
226 };
227
228 let new_leaf_key = self.put(&new_leaf)?;
229 let old_leaf_key = leaf_key(old_k, old_v);
230
231 let mid = if path_new[lvl] {
232 Node::Middle {
233 l: old_leaf_key,
234 r: new_leaf_key,
235 }
236 } else {
237 Node::Middle {
238 l: new_leaf_key,
239 r: old_leaf_key,
240 }
241 };
242 self.put(&mid)
243 }
244
245 pub fn add(&mut self, key: [u8; 32], val: [u8; 32]) -> Result<(), Error<S::Error>> {
247 let kh = key;
248 let vh = val;
249 let new_leaf = Node::Leaf { k: kh, v: vh };
250
251 let path_new = get_path::<D>(&kh);
252 let new_root = self.add_leaf(new_leaf, self.root().map_err(Error::Store)?, 0, &path_new)?;
253 self.set_root(new_root).map_err(Error::Store)?;
254 Ok(())
255 }
256
257 pub fn update(&mut self, key: [u8; 32], val: [u8; 32]) -> Result<[u8; 32], Error<S::Error>> {
259 let kh = key;
260 let vh = val;
261 let mut cur = self.root().map_err(Error::Store)?;
262 let mut siblings = heapless::Vec::<[u8; 32], D>::new();
263 let path = get_path::<D>(&kh);
264 let old_v;
265
266 for go_right in path.iter().copied() {
267 match self.store.get(cur).expect("node exists") {
268 None => return Err(Error::KeyNotFound),
269 Some(Node::Leaf { k, v }) => {
270 if k != kh {
271 return Err(Error::KeyNotFound);
272 }
273 old_v = Some(v);
274
275 let mut node = Node::Leaf { k: kh, v: vh };
276 let mut node_h = self.put(&node).map_err(Error::Store)?;
277
278 for (lvl, sib) in siblings.into_iter().enumerate().rev() {
279 let bit = path[lvl];
280 node = if bit {
281 Node::Middle { l: sib, r: node_h }
282 } else {
283 Node::Middle { l: node_h, r: sib }
284 };
285 node_h = self.put(&node).map_err(Error::Store)?;
286 }
287 self.set_root(node_h).map_err(Error::Store)?;
288 return Ok(old_v.unwrap());
289 }
290 Some(Node::Middle { l, r }) => {
291 if go_right {
292 siblings.push(l).unwrap();
293 cur = r;
294 } else {
295 siblings.push(r).unwrap();
296 cur = l;
297 }
298 }
299 }
300 }
301 Err(Error::KeyNotFound)
302 }
303
304 pub fn get_proof(&self, key: [u8; 32]) -> Result<CircomProof<D>, S::Error> {
306 let k = key;
307 let mut siblings = [[0; 32]; D];
308 let mut sibling_i = 0;
309 let mut cur = self.root()?;
310
311 for (i, go_right) in get_path::<D>(&k).into_iter().enumerate() {
312 match self.store.get(cur).expect("node exists") {
313 None => {
314 return Ok(CircomProof {
315 old_key: [0; 32],
316 old_value: [0; 32],
317 is_old0: true,
318 siblings,
319 membership: false,
320 });
321 }
322 Some(Node::Leaf {
323 k: leaf_k,
324 v: leaf_v,
325 }) => {
326 return Ok(CircomProof {
327 old_key: leaf_k,
328 old_value: leaf_v,
329 is_old0: leaf_k == [0; 32],
330 siblings,
331 membership: leaf_k == k,
332 });
333 }
334 Some(Node::Middle { l, r }) => {
335 if go_right {
336 siblings[sibling_i] = l;
337 cur = r;
338 } else {
339 siblings[sibling_i] = r;
340 cur = l;
341 }
342 sibling_i += 1;
343 }
344 }
345 if i == D - 1 {
346 return Ok(CircomProof {
347 old_key: [0; 32],
348 old_value: [0; 32],
349 is_old0: true,
350 siblings,
351 membership: false,
352 });
353 }
354 }
355 unreachable!();
356 }
357
358 pub fn add_or_update(&mut self, key: [u8; 32], val: [u8; 32]) -> Result<(), Error<S::Error>> {
360 match self.add(key, val) {
361 Err(Error::AlreadyPresent) => self.update(key, val).map(|_| ()),
362 x => x,
363 }
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::MemStore;

    // Depth of the tree under test; every proof carries this many siblings.
    const DEPTH: usize = 64;

    // Inserts three keys, updates one, then adds the all-zero and all-one keys,
    // checking roots and proofs along the way. Proofs `p1`..`p4` are taken
    // *before* the corresponding mutation, so they describe the previous tree.
    // NOTE(review): the `_js` suffix suggests the reference hashes were produced
    // by a JavaScript implementation — confirm their source.
    #[test]
    fn test_smt() {
        // A fresh `MemStore` reports an all-zero root (empty tree).
        let mut t = SparseMerkleTree::<DEPTH, _>::new(MemStore::new()).unwrap();
        assert_eq!(t.root().unwrap(), [0; 32]);

        // Step 1: insert k1 into the empty tree. The pre-insert proof is an
        // exclusion proof that ends at an empty slot with no siblings.
        let k1 = [
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 43, 127, 78,
            51, 93, 159, 92, 71,
        ];
        let v1 = [
            16, 232, 248, 117, 61, 208, 169, 22, 163, 170, 44, 57, 210, 21, 42, 219, 91, 147, 79,
            94, 181, 31, 210, 205, 159, 82, 222, 81, 110, 255, 37, 198,
        ];
        let p1 = t.get_proof(k1).unwrap();
        assert!(p1.get_leaf().is_none());
        t.add_or_update(k1, v1).unwrap();
        assert_eq!(t.get_proof(k1).unwrap().get_leaf(), Some(&v1));
        assert!(!p1.membership);
        assert!(p1.is_old0);
        assert_eq!(p1.old_key, [0; 32]);
        assert_eq!(p1.old_value, [0; 32]);
        assert_eq!(p1.siblings.len(), DEPTH);
        assert!(p1.siblings.iter().all(|&b| b == [0; 32]));

        // With k1 as the only leaf, the root is k1's leaf hash; compare it with
        // the reference value.
        let root1 = t.root().unwrap();
        let root1_js = [
            37, 18, 9, 85, 224, 252, 133, 154, 45, 120, 67, 166, 143, 180, 254, 196, 219, 139, 9,
            229, 191, 47, 36, 89, 138, 111, 104, 170, 242, 127, 191, 38,
        ];
        assert_eq!(root1, root1_js);

        // Step 2: insert k2. Its pre-insert proof lands on k1's leaf (the only
        // node), so it reports k1/v1 as the old leaf and has no siblings.
        let k2 = [
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 211, 160, 91,
            130, 253, 193, 133, 52,
        ];
        let v2 = [
            2, 135, 56, 32, 251, 187, 59, 31, 232, 236, 204, 116, 101, 171, 47, 15, 159, 138, 139,
            231, 61, 78, 108, 10, 70, 133, 200, 198, 187, 100, 85, 178,
        ];
        let p2 = t.get_proof(k2).unwrap();
        assert!(p2.get_leaf().is_none());
        t.add_or_update(k2, v2).unwrap();
        assert_eq!(t.get_proof(k2).unwrap().get_leaf(), Some(&v2));
        assert!(!p2.membership);
        assert!(!p2.is_old0);
        assert_eq!(p2.old_key, k1);
        assert_eq!(p2.old_value, v1);
        assert_eq!(p2.siblings.len(), DEPTH);
        assert!(p2.siblings.iter().all(|&b| b == [0; 32]));

        // Step 3: insert k3. Its pre-insert proof lands on k2's leaf; the
        // depth-0 sibling is k1's leaf, whose hash is `root1`.
        let k3 = [
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 74, 181, 123,
            89, 155, 208, 255, 114,
        ];
        let v3 = [
            16, 46, 63, 228, 134, 35, 92, 132, 114, 153, 57, 23, 154, 224, 217, 112, 131, 208, 134,
            232, 218, 170, 173, 245, 178, 128, 151, 223, 2, 64, 114, 19,
        ];
        let p3 = t.get_proof(k3).unwrap();
        assert!(p3.get_leaf().is_none());
        t.add_or_update(k3, v3).unwrap();
        assert_eq!(t.get_proof(k3).unwrap().get_leaf(), Some(&v3));
        assert!(!p3.membership);
        assert!(!p3.is_old0);
        assert_eq!(p3.old_key, k2);
        assert_eq!(p3.old_value, v2);
        assert_eq!(p3.siblings.len(), DEPTH);
        assert_eq!(p3.siblings[0], root1_js);
        assert!(p3.siblings[1..].iter().all(|&b| b == [0; 32]));

        // Step 4: overwrite k3's value through `add_or_update`. The pre-update
        // proof is an inclusion proof for k3/v3 with two non-empty siblings.
        let v4 = [
            34, 105, 95, 86, 39, 160, 123, 45, 219, 68, 91, 94, 55, 161, 223, 203, 206, 164, 203,
            253, 33, 59, 150, 111, 108, 74, 20, 17, 62, 214, 104, 58,
        ];
        let p4 = t.get_proof(k3).unwrap();
        t.add_or_update(k3, v4).unwrap();
        assert_eq!(t.get_proof(k3).unwrap().get_leaf(), Some(&v4));
        assert!(p4.membership);
        assert!(!p4.is_old0);
        assert_eq!(p4.old_key, k3);
        assert_eq!(p4.old_value, v3);
        assert_eq!(p4.siblings.len(), DEPTH);
        assert_eq!(p4.siblings[0], root1_js);
        assert_eq!(
            p4.siblings[1],
            [
                39, 2, 121, 120, 126, 69, 90, 96, 220, 95, 224, 252, 255, 197, 106, 214, 4, 22,
                155, 164, 67, 176, 180, 82, 34, 37, 226, 17, 201, 250, 187, 58
            ],
        );
        assert!(p4.siblings[2..].iter().all(|&b| b == [0; 32]));

        // Step 5: edge-case keys — the all-zero key and the all-one-bytes key —
        // can be inserted with `add` and then found.
        assert!(t.get_proof([0; 32]).unwrap().get_leaf().is_none());
        t.add([0; 32], [0; 32]).unwrap();
        assert_eq!(t.get_proof([0; 32]).unwrap().get_leaf(), Some(&[0; 32]));

        assert!(t.get_proof([1; 32]).unwrap().get_leaf().is_none());
        t.add([1; 32], [1; 32]).unwrap();
        assert_eq!(t.get_proof([1; 32]).unwrap().get_leaf(), Some(&[1; 32]));
    }
}