1use alloc::{borrow::Cow, vec::Vec};
2use core::{
3 iter::{self, FusedIterator},
4 num::NonZero,
5};
6
7use winter_utils::{Deserializable, DeserializationError, Serializable};
8
9use super::{
10 EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, SMT_MAX_DEPTH, ValuePath,
11 Word,
12};
13use crate::hash::rpo::Rpo256;
14
/// A Merkle path that stores only its non-empty nodes.
///
/// A node that equals the empty subtree root for its depth (see [`EmptySubtreeRoots`]) is not
/// stored. Instead, bit `d - 1` of `empty_nodes_mask` is set for each such depth `d`. All other
/// nodes are kept in `nodes`, ordered from the deepest depth to the shallowest.
///
/// NOTE(review): with the `serde` feature enabled, the derived `Deserialize` impl does not run the
/// validation performed by `from_parts`. Serde-deserialized instances may therefore break the
/// invariants the accessors rely on; consider routing serde through `from_parts` via
/// `#[serde(try_from = ...)]`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct SparseMerklePath {
    /// Bit `d - 1` is set when the node at depth `d` is elided as an empty subtree root.
    empty_nodes_mask: u64,
    /// The non-elided nodes, ordered from the deepest depth to the shallowest.
    nodes: Vec<Word>,
}
36
37impl SparseMerklePath {
38 pub fn from_parts(empty_nodes_mask: u64, nodes: Vec<Word>) -> Result<Self, MerkleError> {
52 let min_path_len = u64::BITS - empty_nodes_mask.leading_zeros();
60 let empty_nodes_count = empty_nodes_mask.count_ones();
61 let min_non_empty_nodes = (min_path_len - empty_nodes_count) as usize;
62
63 if nodes.len() < min_non_empty_nodes {
64 return Err(MerkleError::InvalidPathLength(min_non_empty_nodes));
65 }
66
67 let depth = Self::depth_from_parts(empty_nodes_mask, &nodes) as u8;
68 if depth > SMT_MAX_DEPTH {
69 return Err(MerkleError::DepthTooBig(depth as u64));
70 }
71
72 Ok(Self { empty_nodes_mask, nodes })
73 }
74
75 pub fn from_sized_iter<I>(iterator: I) -> Result<Self, MerkleError>
85 where
86 I: IntoIterator<IntoIter: ExactSizeIterator, Item = Word>,
87 {
88 let iterator = iterator.into_iter();
89 let tree_depth = iterator.len() as u8;
90
91 if tree_depth > SMT_MAX_DEPTH {
92 return Err(MerkleError::DepthTooBig(tree_depth as u64));
93 }
94
95 let mut empty_nodes_mask: u64 = 0;
96 let mut nodes: Vec<Word> = Default::default();
97
98 for (depth, node) in iter::zip(path_depth_iter(tree_depth), iterator) {
99 let &equivalent_empty_node = EmptySubtreeRoots::entry(tree_depth, depth.get());
100 let is_empty = node == equivalent_empty_node;
101 let node = if is_empty { None } else { Some(node) };
102
103 match node {
104 Some(node) => nodes.push(node),
105 None => empty_nodes_mask |= Self::bitmask_for_depth(depth),
106 }
107 }
108
109 Ok(SparseMerklePath { nodes, empty_nodes_mask })
110 }
111
112 pub fn depth(&self) -> u8 {
114 Self::depth_from_parts(self.empty_nodes_mask, &self.nodes) as u8
115 }
116
117 pub fn at_depth(&self, node_depth: NonZero<u8>) -> Result<Word, MerkleError> {
127 if node_depth.get() > self.depth() {
128 return Err(MerkleError::DepthTooBig(node_depth.get().into()));
129 }
130
131 let node = if let Some(nonempty_index) = self.get_nonempty_index(node_depth) {
132 self.nodes[nonempty_index]
133 } else {
134 *EmptySubtreeRoots::entry(self.depth(), node_depth.get())
135 };
136
137 Ok(node)
138 }
139
140 pub fn into_parts(self) -> (u64, Vec<Word>) {
147 (self.empty_nodes_mask, self.nodes)
148 }
149
150 pub fn iter(&self) -> impl ExactSizeIterator<Item = Word> {
156 self.into_iter()
157 }
158
159 pub fn compute_root(&self, index: u64, node_to_prove: Word) -> Result<Word, MerkleError> {
161 let mut index = NodeIndex::new(self.depth(), index)?;
162 let root = self.iter().fold(node_to_prove, |node, sibling| {
163 let children = index.build_node(node, sibling);
165 index.move_up();
166 Rpo256::merge(&children)
167 });
168
169 Ok(root)
170 }
171
172 pub fn verify(&self, index: u64, node: Word, &expected_root: &Word) -> Result<(), MerkleError> {
179 let computed_root = self.compute_root(index, node)?;
180 if computed_root != expected_root {
181 return Err(MerkleError::ConflictingRoots {
182 expected_root,
183 actual_root: computed_root,
184 });
185 }
186
187 Ok(())
188 }
189
190 pub fn authenticated_nodes(
206 &self,
207 index: u64,
208 node_to_prove: Word,
209 ) -> Result<InnerNodeIterator<'_>, MerkleError> {
210 let index = NodeIndex::new(self.depth(), index)?;
211 Ok(InnerNodeIterator { path: self, index, value: node_to_prove })
212 }
213
214 const fn bitmask_for_depth(node_depth: NonZero<u8>) -> u64 {
218 1 << (node_depth.get() - 1)
220 }
221
222 const fn is_depth_empty(&self, node_depth: NonZero<u8>) -> bool {
223 (self.empty_nodes_mask & Self::bitmask_for_depth(node_depth)) != 0
224 }
225
226 fn get_nonempty_index(&self, node_depth: NonZero<u8>) -> Option<usize> {
229 if self.is_depth_empty(node_depth) {
230 return None;
231 }
232
233 let bit_index = node_depth.get() - 1;
234 let without_shallower = self.empty_nodes_mask >> bit_index;
235 let empty_deeper = without_shallower.count_ones() as usize;
236 let normal_index = (self.depth() - node_depth.get()) as usize;
238 Some(normal_index - empty_deeper)
240 }
241
242 fn depth_from_parts(empty_nodes_mask: u64, nodes: &[Word]) -> usize {
244 nodes.len() + empty_nodes_mask.count_ones() as usize
245 }
246}
247
248impl Serializable for SparseMerklePath {
252 fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
253 target.write_u8(self.depth());
254 target.write_u64(self.empty_nodes_mask);
255 target.write_many(&self.nodes);
256 }
257}
258
259impl Deserializable for SparseMerklePath {
260 fn read_from<R: winter_utils::ByteReader>(
261 source: &mut R,
262 ) -> Result<Self, DeserializationError> {
263 let depth = source.read_u8()?;
264 if depth > SMT_MAX_DEPTH {
265 return Err(DeserializationError::InvalidValue(format!(
266 "SparseMerklePath max depth exceeded ({depth} > {SMT_MAX_DEPTH})",
267 )));
268 }
269 let empty_nodes_mask = source.read_u64()?;
270 let empty_nodes_count = empty_nodes_mask.count_ones();
271 if empty_nodes_count > depth as u32 {
272 return Err(DeserializationError::InvalidValue(format!(
273 "SparseMerklePath has more empty nodes ({empty_nodes_count}) than its full length ({depth})",
274 )));
275 }
276 let count = depth as u32 - empty_nodes_count;
277 let nodes = source.read_many::<Word>(count as usize)?;
278 Ok(Self { empty_nodes_mask, nodes })
279 }
280}
281
282impl From<SparseMerklePath> for MerklePath {
286 fn from(sparse_path: SparseMerklePath) -> Self {
287 MerklePath::from_iter(sparse_path)
288 }
289}
290
291impl TryFrom<MerklePath> for SparseMerklePath {
292 type Error = MerkleError;
293
294 fn try_from(path: MerklePath) -> Result<Self, MerkleError> {
299 SparseMerklePath::from_sized_iter(path)
300 }
301}
302
303impl From<SparseMerklePath> for Vec<Word> {
304 fn from(path: SparseMerklePath) -> Self {
305 Vec::from_iter(path)
306 }
307}
308
/// Iterator over the nodes of a [`SparseMerklePath`]. It yields every node, including elided
/// empty ones, from the deepest depth to the shallowest.
pub struct SparseMerklePathIter<'p> {
    /// The path being iterated, either borrowed or owned.
    path: Cow<'p, SparseMerklePath>,

    /// The depth of the next node to yield; `0` means the iterator is exhausted.
    next_depth: u8,
}
322
323impl Iterator for SparseMerklePathIter<'_> {
324 type Item = Word;
325
326 fn next(&mut self) -> Option<Word> {
327 let this_depth = self.next_depth;
328 let this_depth = NonZero::new(this_depth)?;
330 self.next_depth = this_depth.get() - 1;
331
332 let node = self
334 .path
335 .at_depth(this_depth)
336 .expect("current depth should never exceed the path depth");
337 Some(node)
338 }
339
340 fn size_hint(&self) -> (usize, Option<usize>) {
342 let remaining = ExactSizeIterator::len(self);
343 (remaining, Some(remaining))
344 }
345}
346
347impl ExactSizeIterator for SparseMerklePathIter<'_> {
348 fn len(&self) -> usize {
349 self.next_depth as usize
350 }
351}
352
353impl FusedIterator for SparseMerklePathIter<'_> {}
354
355impl IntoIterator for SparseMerklePath {
358 type IntoIter = SparseMerklePathIter<'static>;
359 type Item = <Self::IntoIter as Iterator>::Item;
360
361 fn into_iter(self) -> SparseMerklePathIter<'static> {
362 let tree_depth = self.depth();
363 SparseMerklePathIter {
364 path: Cow::Owned(self),
365 next_depth: tree_depth,
366 }
367 }
368}
369
370impl<'p> IntoIterator for &'p SparseMerklePath {
371 type Item = <SparseMerklePathIter<'p> as Iterator>::Item;
372 type IntoIter = SparseMerklePathIter<'p>;
373
374 fn into_iter(self) -> SparseMerklePathIter<'p> {
375 let tree_depth = self.depth();
376 SparseMerklePathIter {
377 path: Cow::Borrowed(self),
378 next_depth: tree_depth,
379 }
380 }
381}
382
/// Iterator over the inner nodes computed while hashing a value up a [`SparseMerklePath`].
/// Created by [`SparseMerklePath::authenticated_nodes`].
pub struct InnerNodeIterator<'p> {
    /// The path that supplies the sibling node at each level.
    path: &'p SparseMerklePath,
    /// Index of the node currently held in `value`; moves up one level per iteration.
    index: NodeIndex,
    /// The hash at `index`: initially the node being proven, then each computed parent.
    value: Word,
}
390
391impl Iterator for InnerNodeIterator<'_> {
392 type Item = InnerNodeInfo;
393
394 fn next(&mut self) -> Option<Self::Item> {
395 if self.index.is_root() {
396 return None;
397 }
398
399 let index_depth = NonZero::new(self.index.depth()).expect("non-root depth cannot be 0");
400 let path_node = self.path.at_depth(index_depth).unwrap();
401
402 let children = self.index.build_node(self.value, path_node);
403 self.value = Rpo256::merge(&children);
404 self.index.move_up();
405
406 Some(InnerNodeInfo {
407 value: self.value,
408 left: children[0],
409 right: children[1],
410 })
411 }
412}
413
414impl PartialEq<MerklePath> for SparseMerklePath {
417 fn eq(&self, rhs: &MerklePath) -> bool {
418 if self.depth() != rhs.depth() {
419 return false;
420 }
421
422 for (node, &rhs_node) in iter::zip(self, rhs.iter()) {
423 if node != rhs_node {
424 return false;
425 }
426 }
427
428 true
429 }
430}
431
432impl PartialEq<SparseMerklePath> for MerklePath {
433 fn eq(&self, rhs: &SparseMerklePath) -> bool {
434 rhs == self
435 }
436}
437
/// A value paired with a [`SparseMerklePath`]. This is the sparse counterpart of [`ValuePath`],
/// and the two can be converted into each other.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct SparseValuePath {
    /// The value this path belongs to.
    pub value: Word,
    /// The sparse Merkle path for `value`.
    pub path: SparseMerklePath,
}
449
450impl SparseValuePath {
451 pub fn new(value: Word, path: SparseMerklePath) -> Self {
455 Self { value, path }
456 }
457}
458
459impl From<(SparseMerklePath, Word)> for SparseValuePath {
460 fn from((path, value): (SparseMerklePath, Word)) -> Self {
461 SparseValuePath::new(value, path)
462 }
463}
464
465impl TryFrom<ValuePath> for SparseValuePath {
466 type Error = MerkleError;
467
468 fn try_from(other: ValuePath) -> Result<Self, MerkleError> {
473 let ValuePath { value, path } = other;
474 let path = SparseMerklePath::try_from(path)?;
475 Ok(SparseValuePath { value, path })
476 }
477}
478
479impl From<SparseValuePath> for ValuePath {
480 fn from(other: SparseValuePath) -> Self {
481 let SparseValuePath { value, path } = other;
482 ValuePath { value, path: path.into() }
483 }
484}
485
486impl PartialEq<ValuePath> for SparseValuePath {
487 fn eq(&self, rhs: &ValuePath) -> bool {
488 self.value == rhs.value && self.path == rhs.path
489 }
490}
491
492impl PartialEq<SparseValuePath> for ValuePath {
493 fn eq(&self, rhs: &SparseValuePath) -> bool {
494 rhs == self
495 }
496}
497
/// Returns the depths of a path with `tree_depth` nodes in iteration order. The order runs from
/// the deepest depth (`tree_depth`) to the shallowest (`1`).
///
/// Yields nothing when `tree_depth` is `0`.
fn path_depth_iter(tree_depth: u8) -> impl ExactSizeIterator<Item = NonZero<u8>> {
    // The range starts at 1, so `NonZero::new` never returns `None`. Using the checked
    // constructor keeps this helper free of `unsafe`.
    (1..=tree_depth)
        .map(|depth| NonZero::new(depth).expect("depths start at 1 and are never zero"))
        .rev()
}
514
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use core::num::NonZero;

    use assert_matches::assert_matches;

    use super::SparseMerklePath;
    use crate::{
        Felt, ONE, Word,
        merkle::{
            EmptySubtreeRoots, MerkleError, MerklePath, NodeIndex, SMT_DEPTH, Smt,
            smt::SparseMerkleTree, sparse_path::path_depth_iter,
        },
    };

    /// Builds an `Smt` with `pair_count` entries spread across leaf indices 0..=255.
    fn make_smt(pair_count: u64) -> Smt {
        let entries: Vec<(Word, Word)> = (0..pair_count)
            .map(|n| {
                let leaf_index = ((n as f64 / pair_count as f64) * 255.0) as u64;
                let key = Word::new([ONE, ONE, Felt::new(n), Felt::new(leaf_index)]);
                let value = Word::new([ONE, ONE, ONE, ONE]);
                (key, value)
            })
            .collect();

        Smt::with_entries(entries).unwrap()
    }

    #[test]
    fn test_roundtrip() {
        let tree = make_smt(8192);

        for (key, _value) in tree.entries() {
            let (control_path, _) = tree.open(key).into_parts();
            assert_eq!(control_path.len(), tree.depth() as usize);

            let sparse_path = SparseMerklePath::try_from(control_path.clone()).unwrap();
            assert_eq!(control_path.depth(), sparse_path.depth());
            assert_eq!(sparse_path.depth(), SMT_DEPTH);
            let test_path = MerklePath::from_iter(sparse_path.clone().into_iter());

            assert_eq!(control_path, test_path);
        }
    }

    /// Checks that the empty-nodes mask and the stored nodes line up with a hand-built path.
    #[test]
    fn test_sparse_bits() {
        const DEPTH: u8 = 8;
        let raw_nodes: [Word; DEPTH as usize] = [
            // Depth 8.
            ([8u8, 8, 8, 8].into()),
            // Depth 7.
            *EmptySubtreeRoots::entry(DEPTH, 7),
            // Depth 6.
            *EmptySubtreeRoots::entry(DEPTH, 6),
            // Depth 5.
            [5u8, 5, 5, 5].into(),
            // Depth 4.
            [4u8, 4, 4, 4].into(),
            // Depth 3.
            *EmptySubtreeRoots::entry(DEPTH, 3),
            // Depth 2.
            *EmptySubtreeRoots::entry(DEPTH, 2),
            // Depth 1.
            *EmptySubtreeRoots::entry(DEPTH, 1),
        ];

        let sparse_nodes: [Option<Word>; DEPTH as usize] = [
            // Depth 8.
            Some([8u8, 8, 8, 8].into()),
            // Depth 7.
            None,
            // Depth 6.
            None,
            // Depth 5.
            Some([5u8, 5, 5, 5].into()),
            // Depth 4.
            Some([4u8, 4, 4, 4].into()),
            // Depth 3.
            None,
            // Depth 2.
            None,
            // Depth 1.
            None,
        ];

        const EMPTY_BITS: u64 = 0b0110_0111;

        let sparse_path = SparseMerklePath::from_sized_iter(raw_nodes).unwrap();

        assert_eq!(sparse_path.empty_nodes_mask, EMPTY_BITS);

        // Position of the next expected entry in `sparse_path.nodes`.
        let mut nonempty_idx = 0;

        // Visit depths in iteration order, deepest first.
        for depth in (1..=8).rev() {
            let idx = (sparse_path.depth() - depth) as usize;
            let bit = 1 << (depth - 1);

            // The depth's bit must be set exactly when the node is empty.
            let is_set = (sparse_path.empty_nodes_mask & bit) != 0;
            assert_eq!(is_set, sparse_nodes.get(idx).unwrap().is_none());

            if is_set {
                // Empty nodes are not stored.
                let &test_node = sparse_nodes.get(idx).unwrap();
                assert_eq!(test_node, None);
            } else {
                // Non-empty nodes are stored in order, deepest first.
                let control_node = raw_nodes.get(idx).unwrap();
                assert_eq!(
                    sparse_path.get_nonempty_index(NonZero::new(depth).unwrap()).unwrap(),
                    nonempty_idx
                );
                let test_node = sparse_path.nodes.get(nonempty_idx).unwrap();
                assert_eq!(test_node, control_node);

                nonempty_idx += 1;
            }
        }
    }

    #[test]
    fn from_parts() {
        const DEPTH: u8 = 8;
        let raw_nodes: [Word; DEPTH as usize] = [
            // Depth 8.
            ([8u8, 8, 8, 8].into()),
            // Depth 7.
            *EmptySubtreeRoots::entry(DEPTH, 7),
            // Depth 6.
            *EmptySubtreeRoots::entry(DEPTH, 6),
            // Depth 5.
            [5u8, 5, 5, 5].into(),
            // Depth 4.
            [4u8, 4, 4, 4].into(),
            // Depth 3.
            *EmptySubtreeRoots::entry(DEPTH, 3),
            // Depth 2.
            *EmptySubtreeRoots::entry(DEPTH, 2),
            // Depth 1.
            *EmptySubtreeRoots::entry(DEPTH, 1),
        ];

        let empty_nodes_mask = 0b0110_0111;
        let nodes = vec![[8u8, 8, 8, 8].into(), [5u8, 5, 5, 5].into(), [4u8, 4, 4, 4].into()];
        let insufficient_nodes = vec![[4u8, 4, 4, 4].into()];

        let error = SparseMerklePath::from_parts(empty_nodes_mask, insufficient_nodes).unwrap_err();
        assert_matches!(error, MerkleError::InvalidPathLength(2));

        let iter_sparse_path = SparseMerklePath::from_sized_iter(raw_nodes).unwrap();
        let sparse_path = SparseMerklePath::from_parts(empty_nodes_mask, nodes).unwrap();

        assert_eq!(sparse_path, iter_sparse_path);
    }

    #[test]
    fn from_sized_iter() {
        let tree = make_smt(8192);

        for (key, _value) in tree.entries() {
            let index = NodeIndex::from(Smt::key_to_leaf_index(key));

            let control_path = tree.get_path(key);
            for (&control_node, proof_index) in
                itertools::zip_eq(&*control_path, index.proof_indices())
            {
                let proof_node = tree.get_node_hash(proof_index);
                assert_eq!(control_node, proof_node);
            }

            let sparse_path =
                SparseMerklePath::from_sized_iter(control_path.clone().into_iter()).unwrap();
            for (sparse_node, proof_idx) in
                itertools::zip_eq(sparse_path.clone(), index.proof_indices())
            {
                let proof_node = tree.get_node_hash(proof_idx);
                assert_eq!(sparse_node, proof_node);
            }

            assert_eq!(control_path.depth(), sparse_path.depth());
            for (control, sparse) in itertools::zip_eq(control_path, sparse_path) {
                assert_eq!(control, sparse);
            }
        }
    }

    #[test]
    fn test_random_access() {
        let tree = make_smt(8192);

        for (i, (key, _value)) in tree.entries().enumerate() {
            let control_path = tree.get_path(key);
            let sparse_path = SparseMerklePath::try_from(control_path.clone()).unwrap();
            assert_eq!(control_path.depth(), sparse_path.depth());
            assert_eq!(sparse_path.depth(), SMT_DEPTH);

            for depth in path_depth_iter(control_path.depth()) {
                let control_node = control_path.at_depth(depth).unwrap();
                let sparse_node = sparse_path.at_depth(depth).unwrap();
                assert_eq!(control_node, sparse_node, "at depth {depth} for entry {i}");
            }
        }
    }

    #[test]
    fn test_borrowing_iterator() {
        let tree = make_smt(8192);

        for (key, _value) in tree.entries() {
            let control_path = tree.get_path(key);
            let sparse_path = SparseMerklePath::try_from(control_path.clone()).unwrap();
            assert_eq!(control_path.depth(), sparse_path.depth());
            assert_eq!(sparse_path.depth(), SMT_DEPTH);

            let mut count: u64 = 0;
            for (&control_node, sparse_node) in
                itertools::zip_eq(control_path.iter(), sparse_path.iter())
            {
                count += 1;
                assert_eq!(control_node, sparse_node);
            }
            assert_eq!(count, control_path.depth() as u64);
        }
    }

    #[test]
    fn test_owning_iterator() {
        let tree = make_smt(8192);

        for (key, _value) in tree.entries() {
            let control_path = tree.get_path(key);
            let path_depth = control_path.depth();
            let sparse_path = SparseMerklePath::try_from(control_path.clone()).unwrap();
            assert_eq!(control_path.depth(), sparse_path.depth());
            assert_eq!(sparse_path.depth(), SMT_DEPTH);

            let mut count: u64 = 0;
            for (control_node, sparse_node) in itertools::zip_eq(control_path, sparse_path) {
                count += 1;
                assert_eq!(control_node, sparse_node);
            }
            assert_eq!(count, path_depth as u64);
        }
    }

    #[test]
    fn test_zero_sized() {
        let nodes: Vec<Word> = Default::default();

        let sparse_path = SparseMerklePath::from_sized_iter(nodes).unwrap();
        assert_eq!(sparse_path.depth(), 0);
        assert_matches!(
            sparse_path.at_depth(NonZero::new(1).unwrap()),
            Err(MerkleError::DepthTooBig(1))
        );
        assert_eq!(sparse_path.iter().next(), None);
        assert_eq!(sparse_path.into_iter().next(), None);
    }

    #[test]
    fn test_root() {
        let tree = make_smt(100);

        for (key, _value) in tree.entries() {
            let leaf = tree.get_leaf(key);
            let leaf_node = leaf.hash();
            let index: NodeIndex = Smt::key_to_leaf_index(key).into();
            let control_path = tree.get_path(key);
            let sparse_path = SparseMerklePath::try_from(control_path.clone()).unwrap();

            let authed_nodes: Vec<_> =
                sparse_path.authenticated_nodes(index.value(), leaf_node).unwrap().collect();
            let authed_root = authed_nodes.last().unwrap().value;

            let control_root = control_path.compute_root(index.value(), leaf_node).unwrap();
            let sparse_root = sparse_path.compute_root(index.value(), leaf_node).unwrap();
            assert_eq!(control_root, sparse_root);
            assert_eq!(authed_root, control_root);
            assert_eq!(authed_root, tree.root());

            let index = index.value();
            let control_auth_nodes = control_path.authenticated_nodes(index, leaf_node).unwrap();
            let sparse_auth_nodes = sparse_path.authenticated_nodes(index, leaf_node).unwrap();
            // Unlike `Iterator::zip`, `zip_eq` panics on a length mismatch. A sparse iterator that
            // yields too few or too many nodes therefore cannot pass silently.
            for (a, b) in itertools::zip_eq(control_auth_nodes, sparse_auth_nodes) {
                assert_eq!(a, b);
            }
        }
    }
}