1use alloc::{borrow::Cow, vec::Vec};
2use core::{
3 iter::{self, FusedIterator},
4 num::NonZero,
5};
6
7use winter_utils::{Deserializable, DeserializationError, Serializable};
8
9use super::{
10 EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, SMT_MAX_DEPTH, ValuePath,
11 Word,
12};
13use crate::hash::rpo::Rpo256;
14
#[derive(Clone, Debug, Default, PartialEq, Eq)]
/// A compact representation of a Merkle path in which nodes equal to the empty subtree root for
/// their depth are not stored.
///
/// Empty nodes are tracked by `empty_nodes_mask`: bit `d - 1` set means the node at depth `d` is
/// the empty subtree root for that depth. All other nodes are kept in `nodes`, deepest first. The
/// depth of the path is therefore `nodes.len()` plus the number of set bits in the mask.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct SparseMerklePath {
    /// Bit `d - 1` is set iff the node at depth `d` was compressed away as an empty subtree root.
    empty_nodes_mask: u64,
    /// The non-empty nodes of the path, ordered from the deepest to the shallowest.
    nodes: Vec<Word>,
}
36
37impl SparseMerklePath {
38 pub fn from_sized_iter<I>(iterator: I) -> Result<Self, MerkleError>
48 where
49 I: IntoIterator<IntoIter: ExactSizeIterator, Item = Word>,
50 {
51 let iterator = iterator.into_iter();
52 let tree_depth = iterator.len() as u8;
53
54 if tree_depth > SMT_MAX_DEPTH {
55 return Err(MerkleError::DepthTooBig(tree_depth as u64));
56 }
57
58 let mut empty_nodes_mask: u64 = 0;
59 let mut nodes: Vec<Word> = Default::default();
60
61 for (depth, node) in iter::zip(path_depth_iter(tree_depth), iterator) {
62 let &equivalent_empty_node = EmptySubtreeRoots::entry(tree_depth, depth.get());
63 let is_empty = node == equivalent_empty_node;
64 let node = if is_empty { None } else { Some(node) };
65
66 match node {
67 Some(node) => nodes.push(node),
68 None => empty_nodes_mask |= Self::bitmask_for_depth(depth),
69 }
70 }
71
72 Ok(SparseMerklePath { nodes, empty_nodes_mask })
73 }
74
75 pub fn depth(&self) -> u8 {
77 (self.nodes.len() + self.empty_nodes_mask.count_ones() as usize) as u8
78 }
79
80 pub fn at_depth(&self, node_depth: NonZero<u8>) -> Result<Word, MerkleError> {
90 if node_depth.get() > self.depth() {
91 return Err(MerkleError::DepthTooBig(node_depth.get().into()));
92 }
93
94 let node = if let Some(nonempty_index) = self.get_nonempty_index(node_depth) {
95 self.nodes[nonempty_index]
96 } else {
97 *EmptySubtreeRoots::entry(self.depth(), node_depth.get())
98 };
99
100 Ok(node)
101 }
102
103 pub fn iter(&self) -> impl ExactSizeIterator<Item = Word> {
109 self.into_iter()
110 }
111
112 pub fn compute_root(&self, index: u64, node_to_prove: Word) -> Result<Word, MerkleError> {
114 let mut index = NodeIndex::new(self.depth(), index)?;
115 let root = self.iter().fold(node_to_prove, |node, sibling| {
116 let children = index.build_node(node, sibling);
118 index.move_up();
119 Rpo256::merge(&children)
120 });
121
122 Ok(root)
123 }
124
125 pub fn verify(&self, index: u64, node: Word, &expected_root: &Word) -> Result<(), MerkleError> {
132 let computed_root = self.compute_root(index, node)?;
133 if computed_root != expected_root {
134 return Err(MerkleError::ConflictingRoots {
135 expected_root,
136 actual_root: computed_root,
137 });
138 }
139
140 Ok(())
141 }
142
143 pub fn authenticated_nodes(
159 &self,
160 index: u64,
161 node_to_prove: Word,
162 ) -> Result<InnerNodeIterator<'_>, MerkleError> {
163 let index = NodeIndex::new(self.depth(), index)?;
164 Ok(InnerNodeIterator { path: self, index, value: node_to_prove })
165 }
166
167 const fn bitmask_for_depth(node_depth: NonZero<u8>) -> u64 {
171 1 << (node_depth.get() - 1)
173 }
174
175 const fn is_depth_empty(&self, node_depth: NonZero<u8>) -> bool {
176 (self.empty_nodes_mask & Self::bitmask_for_depth(node_depth)) != 0
177 }
178
179 fn get_nonempty_index(&self, node_depth: NonZero<u8>) -> Option<usize> {
182 if self.is_depth_empty(node_depth) {
183 return None;
184 }
185
186 let bit_index = node_depth.get() - 1;
187 let without_shallower = self.empty_nodes_mask >> bit_index;
188 let empty_deeper = without_shallower.count_ones() as usize;
189 let normal_index = (self.depth() - node_depth.get()) as usize;
191 Some(normal_index - empty_deeper)
193 }
194}
195
196impl Serializable for SparseMerklePath {
200 fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
201 target.write_u8(self.depth());
202 target.write_u64(self.empty_nodes_mask);
203 target.write_many(&self.nodes);
204 }
205}
206
207impl Deserializable for SparseMerklePath {
208 fn read_from<R: winter_utils::ByteReader>(
209 source: &mut R,
210 ) -> Result<Self, DeserializationError> {
211 let depth = source.read_u8()?;
212 if depth > SMT_MAX_DEPTH {
213 return Err(DeserializationError::InvalidValue(format!(
214 "SparseMerklePath max depth exceeded ({depth} > {SMT_MAX_DEPTH})",
215 )));
216 }
217 let empty_nodes_mask = source.read_u64()?;
218 let empty_nodes_count = empty_nodes_mask.count_ones();
219 if empty_nodes_count > depth as u32 {
220 return Err(DeserializationError::InvalidValue(format!(
221 "SparseMerklePath has more empty nodes ({empty_nodes_count}) than its full length ({depth})",
222 )));
223 }
224 let count = depth as u32 - empty_nodes_count;
225 let nodes = source.read_many::<Word>(count as usize)?;
226 Ok(Self { empty_nodes_mask, nodes })
227 }
228}
229
230impl From<SparseMerklePath> for MerklePath {
234 fn from(sparse_path: SparseMerklePath) -> Self {
235 MerklePath::from_iter(sparse_path)
236 }
237}
238
239impl TryFrom<MerklePath> for SparseMerklePath {
240 type Error = MerkleError;
241
242 fn try_from(path: MerklePath) -> Result<Self, MerkleError> {
247 SparseMerklePath::from_sized_iter(path)
248 }
249}
250
251impl From<SparseMerklePath> for Vec<Word> {
252 fn from(path: SparseMerklePath) -> Self {
253 Vec::from_iter(path)
254 }
255}
256
/// Iterator over the nodes of a [`SparseMerklePath`], from the deepest node up to depth 1.
///
/// It can iterate over either a borrowed or an owned path, hence the [`Cow`].
pub struct SparseMerklePathIter<'p> {
    /// The path being iterated.
    path: Cow<'p, SparseMerklePath>,

    /// Depth of the node the next call to `next` will yield; `0` once the iterator is exhausted.
    next_depth: u8,
}
270
271impl Iterator for SparseMerklePathIter<'_> {
272 type Item = Word;
273
274 fn next(&mut self) -> Option<Word> {
275 let this_depth = self.next_depth;
276 let this_depth = NonZero::new(this_depth)?;
278 self.next_depth = this_depth.get() - 1;
279
280 let node = self
282 .path
283 .at_depth(this_depth)
284 .expect("current depth should never exceed the path depth");
285 Some(node)
286 }
287
288 fn size_hint(&self) -> (usize, Option<usize>) {
290 let remaining = ExactSizeIterator::len(self);
291 (remaining, Some(remaining))
292 }
293}
294
295impl ExactSizeIterator for SparseMerklePathIter<'_> {
296 fn len(&self) -> usize {
297 self.next_depth as usize
298 }
299}
300
301impl FusedIterator for SparseMerklePathIter<'_> {}
302
303impl IntoIterator for SparseMerklePath {
306 type IntoIter = SparseMerklePathIter<'static>;
307 type Item = <Self::IntoIter as Iterator>::Item;
308
309 fn into_iter(self) -> SparseMerklePathIter<'static> {
310 let tree_depth = self.depth();
311 SparseMerklePathIter {
312 path: Cow::Owned(self),
313 next_depth: tree_depth,
314 }
315 }
316}
317
318impl<'p> IntoIterator for &'p SparseMerklePath {
319 type Item = <SparseMerklePathIter<'p> as Iterator>::Item;
320 type IntoIter = SparseMerklePathIter<'p>;
321
322 fn into_iter(self) -> SparseMerklePathIter<'p> {
323 let tree_depth = self.depth();
324 SparseMerklePathIter {
325 path: Cow::Borrowed(self),
326 next_depth: tree_depth,
327 }
328 }
329}
330
/// Iterator over the inner nodes computed while hashing a node up to the root along a
/// [`SparseMerklePath`].
///
/// Created by [`SparseMerklePath::authenticated_nodes`].
pub struct InnerNodeIterator<'p> {
    /// The path supplying the sibling node at each level.
    path: &'p SparseMerklePath,
    /// Index of the node currently held in `value`; moves up one level per step.
    index: NodeIndex,
    /// Value of the node at `index`: initially the node being proven, then each computed parent.
    value: Word,
}
338
339impl Iterator for InnerNodeIterator<'_> {
340 type Item = InnerNodeInfo;
341
342 fn next(&mut self) -> Option<Self::Item> {
343 if self.index.is_root() {
344 return None;
345 }
346
347 let index_depth = NonZero::new(self.index.depth()).expect("non-root depth cannot be 0");
348 let path_node = self.path.at_depth(index_depth).unwrap();
349
350 let children = self.index.build_node(self.value, path_node);
351 self.value = Rpo256::merge(&children);
352 self.index.move_up();
353
354 Some(InnerNodeInfo {
355 value: self.value,
356 left: children[0],
357 right: children[1],
358 })
359 }
360}
361
362impl PartialEq<MerklePath> for SparseMerklePath {
365 fn eq(&self, rhs: &MerklePath) -> bool {
366 if self.depth() != rhs.depth() {
367 return false;
368 }
369
370 for (node, &rhs_node) in iter::zip(self, rhs.iter()) {
371 if node != rhs_node {
372 return false;
373 }
374 }
375
376 true
377 }
378}
379
380impl PartialEq<SparseMerklePath> for MerklePath {
381 fn eq(&self, rhs: &SparseMerklePath) -> bool {
382 rhs == self
383 }
384}
385
#[derive(Clone, Debug, Default, PartialEq, Eq)]
/// A value together with a [`SparseMerklePath`]; the sparse counterpart of [`ValuePath`].
pub struct SparseValuePath {
    /// The value associated with the path.
    pub value: Word,
    /// The sparse Merkle path associated with `value`.
    pub path: SparseMerklePath,
}
397
398impl SparseValuePath {
399 pub fn new(value: Word, path: SparseMerklePath) -> Self {
403 Self { value, path }
404 }
405}
406
407impl From<(SparseMerklePath, Word)> for SparseValuePath {
408 fn from((path, value): (SparseMerklePath, Word)) -> Self {
409 SparseValuePath::new(value, path)
410 }
411}
412
413impl TryFrom<ValuePath> for SparseValuePath {
414 type Error = MerkleError;
415
416 fn try_from(other: ValuePath) -> Result<Self, MerkleError> {
421 let ValuePath { value, path } = other;
422 let path = SparseMerklePath::try_from(path)?;
423 Ok(SparseValuePath { value, path })
424 }
425}
426
427impl From<SparseValuePath> for ValuePath {
428 fn from(other: SparseValuePath) -> Self {
429 let SparseValuePath { value, path } = other;
430 ValuePath { value, path: path.into() }
431 }
432}
433
434impl PartialEq<ValuePath> for SparseValuePath {
435 fn eq(&self, rhs: &ValuePath) -> bool {
436 self.value == rhs.value && self.path == rhs.path
437 }
438}
439
440impl PartialEq<SparseValuePath> for ValuePath {
441 fn eq(&self, rhs: &SparseValuePath) -> bool {
442 rhs == self
443 }
444}
445
/// Returns an iterator over the depths `tree_depth, tree_depth - 1, ..., 1`, i.e. the depth of
/// each node of a path of length `tree_depth`, from the deepest to the shallowest.
///
/// Yields nothing when `tree_depth` is `0`.
fn path_depth_iter(tree_depth: u8) -> impl ExactSizeIterator<Item = NonZero<u8>> {
    // The range starts at 1, so `NonZero::new` can never fail. The checked constructor keeps
    // this function free of `unsafe` while preserving `ExactSizeIterator`.
    (1..=tree_depth)
        .rev()
        .map(|depth| NonZero::new(depth).expect("depths in 1..=tree_depth are never zero"))
}
462
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use core::num::NonZero;

    use assert_matches::assert_matches;

    use super::SparseMerklePath;
    use crate::{
        Felt, ONE, Word,
        merkle::{
            EmptySubtreeRoots, MerkleError, MerklePath, NodeIndex, SMT_DEPTH, Smt,
            smt::SparseMerkleTree, sparse_path::path_depth_iter,
        },
    };

    /// Builds an [`Smt`] with `pair_count` entries, all holding the value `[ONE; 4]`.
    ///
    /// The last key element is spread over `0..255`; presumably this spreads the entries across
    /// leaves via `Smt::key_to_leaf_index` — TODO confirm.
    fn make_smt(pair_count: u64) -> Smt {
        let entries: Vec<(Word, Word)> = (0..pair_count)
            .map(|n| {
                let leaf_index = ((n as f64 / pair_count as f64) * 255.0) as u64;
                let key = Word::new([ONE, ONE, Felt::new(n), Felt::new(leaf_index)]);
                let value = Word::new([ONE, ONE, ONE, ONE]);
                (key, value)
            })
            .collect();

        Smt::with_entries(entries).unwrap()
    }

    /// A `MerklePath` survives a round trip through `SparseMerklePath`.
    #[test]
    fn test_roundtrip() {
        let tree = make_smt(8192);

        for (key, _value) in tree.entries() {
            let (control_path, _) = tree.open(key).into_parts();
            assert_eq!(control_path.len(), tree.depth() as usize);

            let sparse_path = SparseMerklePath::try_from(control_path.clone()).unwrap();
            assert_eq!(control_path.depth(), sparse_path.depth());
            assert_eq!(sparse_path.depth(), SMT_DEPTH);
            let test_path = MerklePath::from_iter(sparse_path.clone().into_iter());

            assert_eq!(control_path, test_path);
        }
    }

    /// The empty-nodes mask and the stored non-empty nodes match a hand-built path.
    #[test]
    fn test_sparse_bits() {
        const DEPTH: u8 = 8;
        let raw_nodes: [Word; DEPTH as usize] = [
            // Depth 8.
            [8u8, 8, 8, 8].into(),
            // Depth 7.
            *EmptySubtreeRoots::entry(DEPTH, 7),
            // Depth 6.
            *EmptySubtreeRoots::entry(DEPTH, 6),
            // Depth 5.
            [5u8, 5, 5, 5].into(),
            // Depth 4.
            [4u8, 4, 4, 4].into(),
            // Depth 3.
            *EmptySubtreeRoots::entry(DEPTH, 3),
            // Depth 2.
            *EmptySubtreeRoots::entry(DEPTH, 2),
            // Depth 1.
            *EmptySubtreeRoots::entry(DEPTH, 1),
        ];

        let sparse_nodes: [Option<Word>; DEPTH as usize] = [
            // Depth 8.
            Some([8u8, 8, 8, 8].into()),
            // Depth 7.
            None,
            // Depth 6.
            None,
            // Depth 5.
            Some([5u8, 5, 5, 5].into()),
            // Depth 4.
            Some([4u8, 4, 4, 4].into()),
            // Depth 3.
            None,
            // Depth 2.
            None,
            // Depth 1.
            None,
        ];

        // Depth `d` is tracked by bit `d - 1`: depths 7, 6, 3, 2 and 1 are empty.
        const EMPTY_BITS: u64 = 0b0110_0111;

        let sparse_path = SparseMerklePath::from_sized_iter(raw_nodes).unwrap();

        assert_eq!(sparse_path.empty_nodes_mask, EMPTY_BITS);

        // Position in `sparse_path.nodes` of the next non-empty node we expect to see.
        let mut nonempty_idx = 0;

        for depth in (1..=8).rev() {
            // Index into the dense arrays above, which are ordered deepest first.
            let idx = (sparse_path.depth() - depth) as usize;
            let bit = 1 << (depth - 1);

            let is_set = (sparse_path.empty_nodes_mask & bit) != 0;
            assert_eq!(is_set, sparse_nodes.get(idx).unwrap().is_none());

            if is_set {
                let &test_node = sparse_nodes.get(idx).unwrap();
                assert_eq!(test_node, None);
            } else {
                let control_node = raw_nodes.get(idx).unwrap();
                assert_eq!(
                    sparse_path.get_nonempty_index(NonZero::new(depth).unwrap()).unwrap(),
                    nonempty_idx
                );
                let test_node = sparse_path.nodes.get(nonempty_idx).unwrap();
                assert_eq!(test_node, control_node);

                nonempty_idx += 1;
            }
        }
    }

    /// Paths built with `from_sized_iter` yield the same nodes as the tree's proof indices.
    #[test]
    fn from_sized_iter() {
        let tree = make_smt(8192);

        for (key, _value) in tree.entries() {
            let index = NodeIndex::from(Smt::key_to_leaf_index(key));

            let control_path = tree.get_path(key);
            for (&control_node, proof_index) in
                itertools::zip_eq(&*control_path, index.proof_indices())
            {
                let proof_node = tree.get_node_hash(proof_index);
                assert_eq!(control_node, proof_node);
            }

            let sparse_path =
                SparseMerklePath::from_sized_iter(control_path.clone().into_iter()).unwrap();
            for (sparse_node, proof_idx) in
                itertools::zip_eq(sparse_path.clone(), index.proof_indices())
            {
                let proof_node = tree.get_node_hash(proof_idx);
                assert_eq!(sparse_node, proof_node);
            }

            assert_eq!(control_path.depth(), sparse_path.depth());
            for (control, sparse) in itertools::zip_eq(control_path, sparse_path) {
                assert_eq!(control, sparse);
            }
        }
    }

    /// `at_depth` agrees with `MerklePath::at_depth` at every depth.
    #[test]
    fn test_random_access() {
        let tree = make_smt(8192);

        for (i, (key, _value)) in tree.entries().enumerate() {
            let control_path = tree.get_path(key);
            let sparse_path = SparseMerklePath::try_from(control_path.clone()).unwrap();
            assert_eq!(control_path.depth(), sparse_path.depth());
            assert_eq!(sparse_path.depth(), SMT_DEPTH);

            for depth in path_depth_iter(control_path.depth()) {
                let control_node = control_path.at_depth(depth).unwrap();
                let sparse_node = sparse_path.at_depth(depth).unwrap();
                assert_eq!(control_node, sparse_node, "at depth {depth} for entry {i}");
            }
        }
    }

    /// The borrowing iterator yields exactly the nodes of the dense path.
    #[test]
    fn test_borrowing_iterator() {
        let tree = make_smt(8192);

        for (key, _value) in tree.entries() {
            let control_path = tree.get_path(key);
            let sparse_path = SparseMerklePath::try_from(control_path.clone()).unwrap();
            assert_eq!(control_path.depth(), sparse_path.depth());
            assert_eq!(sparse_path.depth(), SMT_DEPTH);

            let mut count: u64 = 0;
            for (&control_node, sparse_node) in
                itertools::zip_eq(control_path.iter(), sparse_path.iter())
            {
                count += 1;
                assert_eq!(control_node, sparse_node);
            }
            assert_eq!(count, control_path.depth() as u64);
        }
    }

    /// The owning iterator yields exactly the nodes of the dense path.
    #[test]
    fn test_owning_iterator() {
        let tree = make_smt(8192);

        for (key, _value) in tree.entries() {
            let control_path = tree.get_path(key);
            let path_depth = control_path.depth();
            let sparse_path = SparseMerklePath::try_from(control_path.clone()).unwrap();
            assert_eq!(control_path.depth(), sparse_path.depth());
            assert_eq!(sparse_path.depth(), SMT_DEPTH);

            let mut count: u64 = 0;
            for (control_node, sparse_node) in itertools::zip_eq(control_path, sparse_path) {
                count += 1;
                assert_eq!(control_node, sparse_node);
            }
            assert_eq!(count, path_depth as u64);
        }
    }

    /// An empty path has depth 0, rejects every depth, and yields nothing.
    #[test]
    fn test_zero_sized() {
        let nodes: Vec<Word> = Default::default();

        let sparse_path = SparseMerklePath::from_sized_iter(nodes).unwrap();
        assert_eq!(sparse_path.depth(), 0);
        assert_matches!(
            sparse_path.at_depth(NonZero::new(1).unwrap()),
            Err(MerkleError::DepthTooBig(1))
        );
        assert_eq!(sparse_path.iter().next(), None);
        assert_eq!(sparse_path.into_iter().next(), None);
    }

    /// Root computation and authenticated nodes agree with the dense path and the tree.
    #[test]
    fn test_root() {
        let tree = make_smt(100);

        for (key, _value) in tree.entries() {
            let leaf = tree.get_leaf(key);
            let leaf_node = leaf.hash();
            let index: NodeIndex = Smt::key_to_leaf_index(key).into();
            let control_path = tree.get_path(key);
            let sparse_path = SparseMerklePath::try_from(control_path.clone()).unwrap();

            let authed_nodes: Vec<_> =
                sparse_path.authenticated_nodes(index.value(), leaf_node).unwrap().collect();
            let authed_root = authed_nodes.last().unwrap().value;

            let control_root = control_path.compute_root(index.value(), leaf_node).unwrap();
            let sparse_root = sparse_path.compute_root(index.value(), leaf_node).unwrap();
            assert_eq!(control_root, sparse_root);
            assert_eq!(authed_root, control_root);
            assert_eq!(authed_root, tree.root());

            let index = index.value();
            let control_auth_nodes = control_path.authenticated_nodes(index, leaf_node).unwrap();
            let sparse_auth_nodes = sparse_path.authenticated_nodes(index, leaf_node).unwrap();
            // `zip_eq` (unlike `zip`) also fails the test if the two iterators differ in length.
            for (a, b) in itertools::zip_eq(control_auth_nodes, sparse_auth_nodes) {
                assert_eq!(a, b);
            }
        }
    }
}