1use alloc::{
2 collections::{BTreeMap, BTreeSet},
3 vec::Vec,
4};
5
6use winter_utils::{Deserializable, Serializable};
7
8use super::{MmrDelta, MmrProof};
9use crate::{
10 Word,
11 merkle::{
12 InnerNodeInfo, MerklePath, Rpo256,
13 mmr::{InOrderIndex, MmrError, MmrPeaks, forest::Forest},
14 },
15};
16
/// Authentication nodes stored by a [PartialMmr], keyed by their in-order index.
type NodeMap = BTreeMap<InOrderIndex, Word>;

/// A partial Merkle Mountain Range (MMR).
///
/// Stores every peak of an MMR together with the authentication nodes required to open a
/// subset of its leaves (the "tracked" leaves). Leaves can be appended with [PartialMmr::add],
/// tracked and untracked with [PartialMmr::track] / [PartialMmr::untrack], and the structure can
/// be brought up to date with an [MmrDelta] via [PartialMmr::apply].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PartialMmr {
    /// The forest of the MMR. Its value is the number of leaves, and each set bit corresponds
    /// to one tree (and therefore one peak) of the MMR.
    pub(crate) forest: Forest,

    /// The peaks of the MMR, one per tree, ordered from the largest tree to the smallest one
    /// (new leaves are merged with the peak at the end of the vector).
    pub(crate) peaks: Vec<Word>,

    /// Authentication nodes needed to open the tracked leaves, keyed by in-order index.
    pub(crate) nodes: NodeMap,

    /// Whether the latest leaf is tracked when it forms a single-leaf tree. Such a leaf is its
    /// own peak and has no authentication nodes, so its tracking state cannot be derived from
    /// `nodes`.
    pub(crate) track_latest: bool,
}
75
76impl Default for PartialMmr {
77 fn default() -> Self {
79 let forest = Forest::empty();
80 let peaks = Vec::new();
81 let nodes = BTreeMap::new();
82 let track_latest = false;
83
84 Self { forest, peaks, nodes, track_latest }
85 }
86}
87
88impl PartialMmr {
89 pub fn from_peaks(peaks: MmrPeaks) -> Self {
94 let forest = peaks.forest();
95 let peaks = peaks.into();
96 let nodes = BTreeMap::new();
97 let track_latest = false;
98
99 Self { forest, peaks, nodes, track_latest }
100 }
101
102 pub fn from_parts(peaks: MmrPeaks, nodes: NodeMap, track_latest: bool) -> Self {
107 let forest = peaks.forest();
108 let peaks = peaks.into();
109
110 Self { forest, peaks, nodes, track_latest }
111 }
112
113 pub fn forest(&self) -> Forest {
121 self.forest
122 }
123
124 pub fn num_leaves(&self) -> usize {
126 self.forest.num_leaves()
127 }
128
129 pub fn peaks(&self) -> MmrPeaks {
131 MmrPeaks::new(self.forest, self.peaks.clone()).expect("invalid MMR peaks")
134 }
135
136 pub fn is_tracked(&self, pos: usize) -> bool {
139 let leaves = self.forest.num_leaves();
140 if pos >= leaves {
141 return false;
142 } else if pos == leaves - 1 && self.forest.has_single_leaf_tree() {
143 return self.track_latest;
146 }
147
148 let leaf_index = InOrderIndex::from_leaf_pos(pos);
149 self.is_tracked_node(&leaf_index)
150 }
151
152 pub fn open(&self, pos: usize) -> Result<Option<MmrProof>, MmrError> {
163 let tree_bit = self
164 .forest
165 .leaf_to_corresponding_tree(pos)
166 .ok_or(MmrError::PositionNotFound(pos))?;
167 let depth = tree_bit as usize;
168
169 let mut nodes = Vec::with_capacity(depth);
170 let mut idx = InOrderIndex::from_leaf_pos(pos);
171
172 while let Some(node) = self.nodes.get(&idx.sibling()) {
173 nodes.push(*node);
174 idx = idx.parent();
175 }
176
177 debug_assert!(nodes.is_empty() || nodes.len() == depth);
179
180 if nodes.len() != depth {
181 Ok(None)
183 } else {
184 Ok(Some(MmrProof {
185 forest: self.forest,
186 position: pos,
187 merkle_path: MerklePath::new(nodes),
188 }))
189 }
190 }
191
192 pub fn nodes(&self) -> impl Iterator<Item = (&InOrderIndex, &Word)> {
197 self.nodes.iter()
198 }
199
200 pub fn inner_nodes<'a, I: Iterator<Item = (usize, Word)> + 'a>(
205 &'a self,
206 mut leaves: I,
207 ) -> impl Iterator<Item = InnerNodeInfo> + 'a {
208 let stack = if let Some((pos, leaf)) = leaves.next() {
209 let idx = InOrderIndex::from_leaf_pos(pos);
210 vec![(idx, leaf)]
211 } else {
212 Vec::new()
213 };
214
215 InnerNodeIterator {
216 nodes: &self.nodes,
217 leaves,
218 stack,
219 seen_nodes: BTreeSet::new(),
220 }
221 }
222
223 pub fn add(&mut self, leaf: Word, track: bool) -> Vec<(InOrderIndex, Word)> {
231 self.forest.append_leaf();
232 let merges = self.forest.smallest_tree_height_unchecked();
234 let mut new_nodes = Vec::with_capacity(merges);
235
236 let peak = if merges == 0 {
237 self.track_latest = track;
238 leaf
239 } else {
240 let mut track_right = track;
241 let mut track_left = self.track_latest;
242
243 let mut right = leaf;
244 let mut right_idx = self.forest.rightmost_in_order_index();
245
246 for _ in 0..merges {
247 let left = self.peaks.pop().expect("Missing peak");
248 let left_idx = right_idx.sibling();
249
250 if track_right {
251 let old = self.nodes.insert(left_idx, left);
252 new_nodes.push((left_idx, left));
253
254 debug_assert!(
255 old.is_none(),
256 "Idx {left_idx:?} already contained an element {old:?}",
257 );
258 };
259 if track_left {
260 let old = self.nodes.insert(right_idx, right);
261 new_nodes.push((right_idx, right));
262
263 debug_assert!(
264 old.is_none(),
265 "Idx {right_idx:?} already contained an element {old:?}",
266 );
267 };
268
269 right_idx = right_idx.parent();
274
275 right = Rpo256::merge(&[left, right]);
278
279 track_right = track_right || track_left;
283
284 track_left = self.is_tracked_node(&right_idx.sibling());
287 }
288 right
289 };
290
291 self.peaks.push(peak);
292
293 new_nodes
294 }
295
296 pub fn track(
307 &mut self,
308 leaf_pos: usize,
309 leaf: Word,
310 path: &MerklePath,
311 ) -> Result<(), MmrError> {
312 let tree = Forest::new(1 << path.depth());
315 if (tree & self.forest).is_empty() {
316 return Err(MmrError::UnknownPeak(path.depth()));
317 };
318
319 if leaf_pos + 1 == self.forest.num_leaves()
320 && path.depth() == 0
321 && self.peaks.last().is_some_and(|v| *v == leaf)
322 {
323 self.track_latest = true;
324 return Ok(());
325 }
326
327 let target_forest = self.forest ^ (self.forest & tree.all_smaller_trees_unchecked());
330 let peak_pos = target_forest.num_trees() - 1;
331
332 let path_idx = leaf_pos - (target_forest ^ tree).num_leaves();
334
335 let computed = path
338 .compute_root(path_idx as u64, leaf)
339 .map_err(MmrError::MerkleRootComputationFailed)?;
340 if self.peaks[peak_pos] != computed {
341 return Err(MmrError::PeakPathMismatch);
342 }
343
344 let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);
345 for leaf in path.nodes() {
346 self.nodes.insert(idx.sibling(), *leaf);
347 idx = idx.parent();
348 }
349
350 Ok(())
351 }
352
353 pub fn untrack(&mut self, leaf_pos: usize) {
357 let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);
358
359 while self.nodes.remove(&idx.sibling()).is_some() && !self.nodes.contains_key(&idx) {
365 idx = idx.parent();
366 }
367 }
368
369 pub fn apply(&mut self, delta: MmrDelta) -> Result<Vec<(InOrderIndex, Word)>, MmrError> {
372 if delta.forest < self.forest {
373 return Err(MmrError::InvalidPeaks(format!(
374 "forest of mmr delta {} is less than current forest {}",
375 delta.forest, self.forest
376 )));
377 }
378
379 let mut inserted_nodes = Vec::new();
380
381 if delta.forest == self.forest {
382 if !delta.data.is_empty() {
383 return Err(MmrError::InvalidUpdate);
384 }
385
386 return Ok(inserted_nodes);
387 }
388
389 let changes = self.forest ^ delta.forest;
391 let largest = changes.largest_tree_unchecked();
394 let merges = self.forest & largest.all_smaller_trees_unchecked();
396
397 debug_assert!(
398 !self.track_latest || merges.has_single_leaf_tree(),
399 "if there is an odd element, a merge is required"
400 );
401
402 let (merge_count, new_peaks) = if !merges.is_empty() {
404 let depth = largest.smallest_tree_height_unchecked();
405 let skipped = merges.smallest_tree_height_unchecked();
407 let computed = merges.num_trees() - 1;
408 let merge_count = depth - skipped - computed;
409
410 let new_peaks = delta.forest & largest.all_smaller_trees_unchecked();
411
412 (merge_count, new_peaks)
413 } else {
414 (0, changes)
415 };
416
417 if delta.data.len() != merge_count + new_peaks.num_trees() {
419 return Err(MmrError::InvalidUpdate);
420 }
421
422 let mut update_count = 0;
424
425 if !merges.is_empty() {
426 let mut peak_idx = self.forest.root_in_order_index();
428
429 self.peaks.reverse();
431
432 let mut track = self.track_latest;
434 self.track_latest = false;
435
436 let mut peak_count = 0;
437 let mut target = merges.smallest_tree_unchecked();
438 let mut new = delta.data[0];
439 update_count += 1;
440
441 while target < largest {
442 if target != Forest::new(1) && !track {
445 track = self.is_tracked_node(&peak_idx);
446 }
447
448 let (left, right) = if !(target & merges).is_empty() {
451 let peak = self.peaks[peak_count];
452 let sibling_idx = peak_idx.sibling();
453
454 if self.is_tracked_node(&sibling_idx) {
457 self.nodes.insert(peak_idx, new);
458 inserted_nodes.push((peak_idx, new));
459 }
460 peak_count += 1;
461 (peak, new)
462 } else {
463 let update = delta.data[update_count];
464 update_count += 1;
465 (new, update)
466 };
467
468 if track {
469 let sibling_idx = peak_idx.sibling();
470 if peak_idx.is_left_child() {
471 self.nodes.insert(sibling_idx, right);
472 inserted_nodes.push((sibling_idx, right));
473 } else {
474 self.nodes.insert(sibling_idx, left);
475 inserted_nodes.push((sibling_idx, left));
476 }
477 }
478
479 peak_idx = peak_idx.parent();
480 new = Rpo256::merge(&[left, right]);
481 target = target.next_larger_tree();
482 }
483
484 debug_assert!(peak_count == merges.num_trees());
485
486 self.peaks.reverse();
488 self.peaks.truncate(self.peaks.len() - peak_count);
490 self.peaks.push(new);
492 }
493
494 self.peaks.extend_from_slice(&delta.data[update_count..]);
498 self.forest = delta.forest;
499
500 debug_assert!(self.peaks.len() == self.forest.num_trees());
501
502 Ok(inserted_nodes)
503 }
504
505 fn is_tracked_node(&self, node_index: &InOrderIndex) -> bool {
511 if node_index.is_leaf() {
512 self.nodes.contains_key(&node_index.sibling())
513 } else {
514 let left_child = node_index.left_child();
515 let right_child = node_index.right_child();
516 self.nodes.contains_key(&left_child) | self.nodes.contains_key(&right_child)
517 }
518 }
519}
520
521impl From<MmrPeaks> for PartialMmr {
525 fn from(peaks: MmrPeaks) -> Self {
526 Self::from_peaks(peaks)
527 }
528}
529
530impl From<PartialMmr> for MmrPeaks {
531 fn from(partial_mmr: PartialMmr) -> Self {
532 MmrPeaks::new(partial_mmr.forest, partial_mmr.peaks).unwrap()
535 }
536}
537
538impl From<&MmrPeaks> for PartialMmr {
539 fn from(peaks: &MmrPeaks) -> Self {
540 Self::from_peaks(peaks.clone())
541 }
542}
543
544impl From<&PartialMmr> for MmrPeaks {
545 fn from(partial_mmr: &PartialMmr) -> Self {
546 MmrPeaks::new(partial_mmr.forest, partial_mmr.peaks.clone()).unwrap()
549 }
550}
551
/// An iterator over the inner nodes of a [PartialMmr] that can be computed from a sequence of
/// leaves and the stored authentication nodes.
///
/// Returned by [PartialMmr::inner_nodes].
pub struct InnerNodeIterator<'a, I: Iterator<Item = (usize, Word)>> {
    /// Authentication nodes of the partial MMR.
    nodes: &'a NodeMap,
    /// Remaining `(position, value)` leaves to process.
    leaves: I,
    /// Nodes (with their in-order index) whose parent still has to be considered.
    stack: Vec<(InOrderIndex, Word)>,
    /// Parent indices already visited, used so that no inner node is emitted twice.
    seen_nodes: BTreeSet<InOrderIndex>,
}
562
563impl<I: Iterator<Item = (usize, Word)>> Iterator for InnerNodeIterator<'_, I> {
564 type Item = InnerNodeInfo;
565
566 fn next(&mut self) -> Option<Self::Item> {
567 while let Some((idx, node)) = self.stack.pop() {
568 let parent_idx = idx.parent();
569 let new_node = self.seen_nodes.insert(parent_idx);
570
571 if new_node && let Some(sibling) = self.nodes.get(&idx.sibling()) {
574 let (left, right) = if parent_idx.left_child() == idx {
575 (node, *sibling)
576 } else {
577 (*sibling, node)
578 };
579 let parent = Rpo256::merge(&[left, right]);
580 let inner_node = InnerNodeInfo { value: parent, left, right };
581
582 self.stack.push((parent_idx, parent));
583 return Some(inner_node);
584 }
585
586 if let Some((pos, leaf)) = self.leaves.next() {
588 let idx = InOrderIndex::from_leaf_pos(pos);
589 self.stack.push((idx, leaf));
590 }
591 }
592
593 None
594 }
595}
596
597impl Serializable for PartialMmr {
598 fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
599 self.forest.num_leaves().write_into(target);
600 self.peaks.write_into(target);
601 self.nodes.write_into(target);
602 target.write_bool(self.track_latest);
603 }
604}
605
606impl Deserializable for PartialMmr {
607 fn read_from<R: winter_utils::ByteReader>(
608 source: &mut R,
609 ) -> Result<Self, winter_utils::DeserializationError> {
610 let forest = Forest::new(usize::read_from(source)?);
611 let peaks = Vec::<Word>::read_from(source)?;
612 let nodes = NodeMap::read_from(source)?;
613 let track_latest = source.read_bool()?;
614
615 Ok(Self { forest, peaks, nodes, track_latest })
616 }
617}
618
#[cfg(test)]
mod tests {
    use alloc::{collections::BTreeSet, vec::Vec};

    use winter_utils::{Deserializable, Serializable};

    use super::{MmrPeaks, PartialMmr};
    use crate::{
        Word,
        merkle::{
            NodeIndex, int_to_node,
            mmr::{Mmr, forest::Forest},
            store::MerkleStore,
        },
    };

    // Seven leaf values; an MMR built from them has trees of 4, 2 and 1 leaves.
    const LEAVES: [Word; 7] = [
        int_to_node(0),
        int_to_node(1),
        int_to_node(2),
        int_to_node(3),
        int_to_node(4),
        int_to_node(5),
        int_to_node(6),
    ];

    #[test]
    fn test_partial_mmr_apply_delta() {
        // Build an MMR with 10 leaves (trees of 8 and 2 leaves) and a partial MMR from its peaks.
        let mut mmr = Mmr::default();
        (0..10).for_each(|i| mmr.add(int_to_node(i)));
        let mut partial_mmr: PartialMmr = mmr.peaks().into();

        // Track the leaf at position 1.
        {
            let node = mmr.get(1).unwrap();
            let proof = mmr.open(1).unwrap();
            partial_mmr.track(1, node, &proof.merkle_path).unwrap();
        }

        // Track the leaf at position 8.
        {
            let node = mmr.get(8).unwrap();
            let proof = mmr.open(8).unwrap();
            partial_mmr.track(8, node, &proof.merkle_path).unwrap();
        }

        // Add 2 leaves to the MMR and bring the partial MMR up to date.
        (10..12).for_each(|i| mmr.add(int_to_node(i)));
        validate_apply_delta(&mmr, &mut partial_mmr);

        // Add 1 leaf, which forms a single-leaf tree, update, then start tracking it.
        mmr.add(int_to_node(12));
        validate_apply_delta(&mmr, &mut partial_mmr);
        {
            let node = mmr.get(12).unwrap();
            let proof = mmr.open(12).unwrap();
            partial_mmr.track(12, node, &proof.merkle_path).unwrap();
            // A single-leaf tree has no authentication nodes, so only the flag is set.
            assert!(partial_mmr.track_latest);
        }

        // Positions 1, 8 and 12 are tracked now. Adding 3 more leaves merges everything into a
        // single 16-leaf tree.
        (13..16).for_each(|i| mmr.add(int_to_node(i)));
        validate_apply_delta(&mmr, &mut partial_mmr);
    }

    /// Applies the delta between `partial` and `mmr` to `partial` and checks the result: equal
    /// peaks, returned nodes are exactly the newly stored ones, and every leaf tracked through a
    /// stored sibling still opens to the same proof as in `mmr`.
    fn validate_apply_delta(mmr: &Mmr, partial: &mut PartialMmr) {
        // Leaves whose sibling is stored, i.e. the tracked leaves (a tracked single-leaf tree
        // has no stored sibling and is not included).
        let tracked_leaves = partial
            .nodes
            .iter()
            .filter(|&(index, _)| index.is_leaf())
            .map(|(index, _)| index.sibling())
            .collect::<Vec<_>>();
        let nodes_before = partial.nodes.clone();

        // Compute the delta from the partial MMR's forest to the full MMR's forest and apply it.
        let delta = mmr.get_delta(partial.forest(), mmr.forest()).unwrap();
        let nodes_delta = partial.apply(delta).unwrap();

        // The peaks must match after the update.
        assert_eq!(mmr.peaks(), partial.peaks());

        // Every returned node must be new; old plus returned nodes must be the full node set.
        let mut expected_nodes = nodes_before;
        for (key, value) in nodes_delta {
            assert!(expected_nodes.insert(key, value).is_none());
        }

        assert_eq!(expected_nodes, partial.nodes);

        // Every tracked leaf must still open to the same proof as in the full MMR.
        for index in tracked_leaves {
            // In-order leaf indices are odd (presumably `2 * pos + 1`), so halving gives `pos`.
            let pos = index.inner() / 2;
            let proof1 = partial.open(pos).unwrap().unwrap();
            let proof2 = mmr.open(pos).unwrap();
            assert_eq!(proof1, proof2);
        }
    }

    #[test]
    fn test_partial_mmr_inner_nodes_iterator() {
        // Build an MMR with 7 leaves; the first peak is the root of the 4-leaf tree.
        let mmr: Mmr = LEAVES.into();
        let first_peak = mmr.peaks().peaks()[0];

        // -- case 1: a single tracked leaf --------------------------------------------------

        let node1 = mmr.get(1).unwrap();
        let proof1 = mmr.open(1).unwrap();

        let mut partial_mmr: PartialMmr = mmr.peaks().into();
        partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();

        // Without leaves, no inner node can be computed.
        assert_eq!(partial_mmr.inner_nodes([].iter().cloned()).next(), None);

        // The emitted inner nodes must reproduce the leaf's authentication path.
        let mut store: MerkleStore = MerkleStore::new();
        store.extend(partial_mmr.inner_nodes([(1, node1)].iter().cloned()));

        let index1 = NodeIndex::new(2, 1).unwrap();
        let path1 = store.get_path(first_peak, index1).unwrap().path;

        assert_eq!(path1, proof1.merkle_path);

        // -- case 2: several tracked leaves in the same tree -------------------------------

        let mut partial_mmr: PartialMmr = mmr.peaks().into();

        let node0 = mmr.get(0).unwrap();
        let proof0 = mmr.open(0).unwrap();

        let node2 = mmr.get(2).unwrap();
        let proof2 = mmr.open(2).unwrap();

        partial_mmr.track(0, node0, &proof0.merkle_path).unwrap();
        partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
        partial_mmr.track(2, node2, &proof2.merkle_path).unwrap();

        // Shared ancestors must be emitted only once.
        let leaves = [(0, node0), (1, node1), (2, node2)];
        let mut nodes = BTreeSet::new();
        for node in partial_mmr.inner_nodes(leaves.iter().cloned()) {
            assert!(nodes.insert(node.value));
        }

        store.extend(partial_mmr.inner_nodes(leaves.iter().cloned()));

        let index0 = NodeIndex::new(2, 0).unwrap();
        let index1 = NodeIndex::new(2, 1).unwrap();
        let index2 = NodeIndex::new(2, 2).unwrap();

        let path0 = store.get_path(first_peak, index0).unwrap().path;
        let path1 = store.get_path(first_peak, index1).unwrap().path;
        let path2 = store.get_path(first_peak, index2).unwrap().path;

        assert_eq!(path0, proof0.merkle_path);
        assert_eq!(path1, proof1.merkle_path);
        assert_eq!(path2, proof2.merkle_path);

        // -- case 3: tracked leaves in different trees -------------------------------------

        let mut partial_mmr: PartialMmr = mmr.peaks().into();

        let node5 = mmr.get(5).unwrap();
        let proof5 = mmr.open(5).unwrap();

        partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
        partial_mmr.track(5, node5, &proof5.merkle_path).unwrap();

        let mut store: MerkleStore = MerkleStore::new();
        store.extend(partial_mmr.inner_nodes([(1, node1), (5, node5)].iter().cloned()));

        // Leaf 5 is the second leaf of the 2-leaf tree, authenticated against the second peak.
        let index1 = NodeIndex::new(2, 1).unwrap();
        let index5 = NodeIndex::new(1, 1).unwrap();

        let second_peak = mmr.peaks().peaks()[1];

        let path1 = store.get_path(first_peak, index1).unwrap().path;
        let path5 = store.get_path(second_peak, index5).unwrap().path;

        assert_eq!(path1, proof1.merkle_path);
        assert_eq!(path5, proof5.merkle_path);
    }

    #[test]
    fn test_partial_mmr_add_without_track() {
        // Adding untracked leaves must keep peaks and forest in sync with a full MMR.
        let mut mmr = Mmr::default();
        let empty_peaks = MmrPeaks::new(Forest::empty(), vec![]).unwrap();
        let mut partial_mmr = PartialMmr::from_peaks(empty_peaks);

        for el in (0..256).map(int_to_node) {
            mmr.add(el);
            partial_mmr.add(el, false);

            assert_eq!(mmr.peaks(), partial_mmr.peaks());
            assert_eq!(mmr.forest(), partial_mmr.forest());
        }
    }

    #[test]
    fn test_partial_mmr_add_with_track() {
        // Adding tracked leaves must keep the partial MMR in sync with a full MMR, and every
        // previously added leaf must open to the same proof.
        let mut mmr = Mmr::default();
        let empty_peaks = MmrPeaks::new(Forest::empty(), vec![]).unwrap();
        let mut partial_mmr = PartialMmr::from_peaks(empty_peaks);

        for i in 0..256 {
            let el = int_to_node(i as u64);
            mmr.add(el);
            partial_mmr.add(el, true);

            assert_eq!(mmr.peaks(), partial_mmr.peaks());
            assert_eq!(mmr.forest(), partial_mmr.forest());

            for pos in 0..i {
                let mmr_proof = mmr.open(pos).unwrap();
                let partialmmr_proof = partial_mmr.open(pos).unwrap().unwrap();
                assert_eq!(mmr_proof, partialmmr_proof);
            }
        }
    }

    #[test]
    fn test_partial_mmr_add_existing_track() {
        // 7 leaves: trees of 4, 2 and 1 leaves.
        let mut mmr = Mmr::from((0..7).map(int_to_node));

        // Track leaf 5, which lives in the 2-leaf tree.
        let mut partial_mmr = PartialMmr::from_peaks(mmr.peaks());
        let path_to_5 = mmr.open(5).unwrap().merkle_path;
        let leaf_at_5 = mmr.get(5).unwrap();
        partial_mmr.track(5, leaf_at_5, &path_to_5).unwrap();

        // Adding an untracked leaf merges all trees; the existing path must be extended.
        let leaf_at_7 = int_to_node(7);
        mmr.add(leaf_at_7);
        partial_mmr.add(leaf_at_7, false);

        assert_eq!(mmr.open(5).unwrap(), partial_mmr.open(5).unwrap().unwrap());
    }

    #[test]
    fn test_partial_mmr_serialization() {
        // A serialization round trip must yield an equal partial MMR.
        let mmr = Mmr::from((0..7).map(int_to_node));
        let partial_mmr = PartialMmr::from_peaks(mmr.peaks());

        let bytes = partial_mmr.to_bytes();
        let decoded = PartialMmr::read_from_bytes(&bytes).unwrap();

        assert_eq!(partial_mmr, decoded);
    }

    #[test]
    fn test_partial_mmr_untrack() {
        let mmr: Mmr = LEAVES.into();

        let node1 = mmr.get(1).unwrap();
        let proof1 = mmr.open(1).unwrap();

        let node2 = mmr.get(2).unwrap();
        let proof2 = mmr.open(2).unwrap();

        // Track two leaves whose paths share authentication nodes.
        let mut partial_mmr: PartialMmr = mmr.peaks().into();
        partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
        partial_mmr.track(2, node2, &proof2.merkle_path).unwrap();

        // Untracking both must leave no stored nodes behind.
        partial_mmr.untrack(1);
        partial_mmr.untrack(2);

        assert!(!partial_mmr.is_tracked(1));
        assert!(!partial_mmr.is_tracked(2));
        assert_eq!(partial_mmr.nodes().count(), 0);
    }
}