1use alloc::{
2 collections::{BTreeMap, BTreeSet},
3 vec::Vec,
4};
5
6use winter_utils::{Deserializable, Serializable};
7
8use super::{MmrDelta, MmrProof};
9use crate::{
10 Word,
11 merkle::{
12 InOrderIndex, InnerNodeInfo, MerklePath, MmrError, MmrPeaks, Rpo256, mmr::forest::Forest,
13 },
14};
15
/// Authentication nodes stored by a [`PartialMmr`], keyed by their in-order index.
type NodeMap = BTreeMap<InOrderIndex, Word>;
20
/// A partially materialized Merkle Mountain Range (MMR).
///
/// It stores the peaks of an MMR together with the authentication nodes of a subset of its
/// leaves (the "tracked" leaves). This allows proofs for those leaves to be opened without
/// storing the full MMR.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PartialMmr {
    /// The shape of the underlying MMR: its number of leaves and the trees they form.
    pub(crate) forest: Forest,

    /// The peaks of the MMR trees, ordered from the largest tree to the smallest.
    ///
    /// The most recently created peak is the last element.
    pub(crate) peaks: Vec<Word>,

    /// Authentication nodes for the tracked leaves, keyed by in-order index.
    ///
    /// The value of a tracked leaf can also appear here when it is the sibling of another tracked
    /// leaf.
    pub(crate) nodes: NodeMap,

    /// Whether the latest leaf is tracked when it forms a single-leaf tree.
    ///
    /// Such a leaf has an empty authentication path, so its tracking state cannot be recorded in
    /// `nodes`.
    pub(crate) track_latest: bool,
}
74
75impl Default for PartialMmr {
76 fn default() -> Self {
78 let forest = Forest::empty();
79 let peaks = Vec::new();
80 let nodes = BTreeMap::new();
81 let track_latest = false;
82
83 Self { forest, peaks, nodes, track_latest }
84 }
85}
86
87impl PartialMmr {
88 pub fn from_peaks(peaks: MmrPeaks) -> Self {
93 let forest = peaks.forest();
94 let peaks = peaks.into();
95 let nodes = BTreeMap::new();
96 let track_latest = false;
97
98 Self { forest, peaks, nodes, track_latest }
99 }
100
101 pub fn from_parts(peaks: MmrPeaks, nodes: NodeMap, track_latest: bool) -> Self {
106 let forest = peaks.forest();
107 let peaks = peaks.into();
108
109 Self { forest, peaks, nodes, track_latest }
110 }
111
112 pub fn forest(&self) -> Forest {
120 self.forest
121 }
122
123 pub fn num_leaves(&self) -> usize {
125 self.forest.num_leaves()
126 }
127
128 pub fn peaks(&self) -> MmrPeaks {
130 MmrPeaks::new(self.forest, self.peaks.clone()).expect("invalid MMR peaks")
133 }
134
135 pub fn is_tracked(&self, pos: usize) -> bool {
138 let leaves = self.forest.num_leaves();
139 if pos >= leaves {
140 return false;
141 } else if pos == leaves - 1 && self.forest.has_single_leaf_tree() {
142 return self.track_latest;
145 }
146
147 let leaf_index = InOrderIndex::from_leaf_pos(pos);
148 self.is_tracked_node(&leaf_index)
149 }
150
151 pub fn open(&self, pos: usize) -> Result<Option<MmrProof>, MmrError> {
162 let tree_bit = self
163 .forest
164 .leaf_to_corresponding_tree(pos)
165 .ok_or(MmrError::PositionNotFound(pos))?;
166 let depth = tree_bit as usize;
167
168 let mut nodes = Vec::with_capacity(depth);
169 let mut idx = InOrderIndex::from_leaf_pos(pos);
170
171 while let Some(node) = self.nodes.get(&idx.sibling()) {
172 nodes.push(*node);
173 idx = idx.parent();
174 }
175
176 debug_assert!(nodes.is_empty() || nodes.len() == depth);
178
179 if nodes.len() != depth {
180 Ok(None)
182 } else {
183 Ok(Some(MmrProof {
184 forest: self.forest,
185 position: pos,
186 merkle_path: MerklePath::new(nodes),
187 }))
188 }
189 }
190
191 pub fn nodes(&self) -> impl Iterator<Item = (&InOrderIndex, &Word)> {
196 self.nodes.iter()
197 }
198
199 pub fn inner_nodes<'a, I: Iterator<Item = (usize, Word)> + 'a>(
204 &'a self,
205 mut leaves: I,
206 ) -> impl Iterator<Item = InnerNodeInfo> + 'a {
207 let stack = if let Some((pos, leaf)) = leaves.next() {
208 let idx = InOrderIndex::from_leaf_pos(pos);
209 vec![(idx, leaf)]
210 } else {
211 Vec::new()
212 };
213
214 InnerNodeIterator {
215 nodes: &self.nodes,
216 leaves,
217 stack,
218 seen_nodes: BTreeSet::new(),
219 }
220 }
221
222 pub fn add(&mut self, leaf: Word, track: bool) -> Vec<(InOrderIndex, Word)> {
230 self.forest.append_leaf();
231 let merges = self.forest.smallest_tree_height_unchecked();
233 let mut new_nodes = Vec::with_capacity(merges);
234
235 let peak = if merges == 0 {
236 self.track_latest = track;
237 leaf
238 } else {
239 let mut track_right = track;
240 let mut track_left = self.track_latest;
241
242 let mut right = leaf;
243 let mut right_idx = self.forest.rightmost_in_order_index();
244
245 for _ in 0..merges {
246 let left = self.peaks.pop().expect("Missing peak");
247 let left_idx = right_idx.sibling();
248
249 if track_right {
250 let old = self.nodes.insert(left_idx, left);
251 new_nodes.push((left_idx, left));
252
253 debug_assert!(
254 old.is_none(),
255 "Idx {left_idx:?} already contained an element {old:?}",
256 );
257 };
258 if track_left {
259 let old = self.nodes.insert(right_idx, right);
260 new_nodes.push((right_idx, right));
261
262 debug_assert!(
263 old.is_none(),
264 "Idx {right_idx:?} already contained an element {old:?}",
265 );
266 };
267
268 right_idx = right_idx.parent();
273
274 right = Rpo256::merge(&[left, right]);
277
278 track_right = track_right || track_left;
282
283 track_left = self.is_tracked_node(&right_idx.sibling());
286 }
287 right
288 };
289
290 self.peaks.push(peak);
291
292 new_nodes
293 }
294
295 pub fn track(
306 &mut self,
307 leaf_pos: usize,
308 leaf: Word,
309 path: &MerklePath,
310 ) -> Result<(), MmrError> {
311 let tree = Forest::new(1 << path.depth());
314 if (tree & self.forest).is_empty() {
315 return Err(MmrError::UnknownPeak(path.depth()));
316 };
317
318 if leaf_pos + 1 == self.forest.num_leaves()
319 && path.depth() == 0
320 && self.peaks.last().is_some_and(|v| *v == leaf)
321 {
322 self.track_latest = true;
323 return Ok(());
324 }
325
326 let target_forest = self.forest ^ (self.forest & tree.all_smaller_trees_unchecked());
329 let peak_pos = target_forest.num_trees() - 1;
330
331 let path_idx = leaf_pos - (target_forest ^ tree).num_leaves();
333
334 let computed = path
337 .compute_root(path_idx as u64, leaf)
338 .map_err(MmrError::MerkleRootComputationFailed)?;
339 if self.peaks[peak_pos] != computed {
340 return Err(MmrError::PeakPathMismatch);
341 }
342
343 let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);
344 for leaf in path.nodes() {
345 self.nodes.insert(idx.sibling(), *leaf);
346 idx = idx.parent();
347 }
348
349 Ok(())
350 }
351
352 pub fn untrack(&mut self, leaf_pos: usize) {
356 let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);
357
358 while self.nodes.remove(&idx.sibling()).is_some() && !self.nodes.contains_key(&idx) {
364 idx = idx.parent();
365 }
366 }
367
368 pub fn apply(&mut self, delta: MmrDelta) -> Result<Vec<(InOrderIndex, Word)>, MmrError> {
371 if delta.forest < self.forest {
372 return Err(MmrError::InvalidPeaks(format!(
373 "forest of mmr delta {} is less than current forest {}",
374 delta.forest, self.forest
375 )));
376 }
377
378 let mut inserted_nodes = Vec::new();
379
380 if delta.forest == self.forest {
381 if !delta.data.is_empty() {
382 return Err(MmrError::InvalidUpdate);
383 }
384
385 return Ok(inserted_nodes);
386 }
387
388 let changes = self.forest ^ delta.forest;
390 let largest = changes.largest_tree_unchecked();
393 let merges = self.forest & largest.all_smaller_trees_unchecked();
395
396 debug_assert!(
397 !self.track_latest || merges.has_single_leaf_tree(),
398 "if there is an odd element, a merge is required"
399 );
400
401 let (merge_count, new_peaks) = if !merges.is_empty() {
403 let depth = largest.smallest_tree_height_unchecked();
404 let skipped = merges.smallest_tree_height_unchecked();
406 let computed = merges.num_trees() - 1;
407 let merge_count = depth - skipped - computed;
408
409 let new_peaks = delta.forest & largest.all_smaller_trees_unchecked();
410
411 (merge_count, new_peaks)
412 } else {
413 (0, changes)
414 };
415
416 if delta.data.len() != merge_count + new_peaks.num_trees() {
418 return Err(MmrError::InvalidUpdate);
419 }
420
421 let mut update_count = 0;
423
424 if !merges.is_empty() {
425 let mut peak_idx = self.forest.root_in_order_index();
427
428 self.peaks.reverse();
430
431 let mut track = self.track_latest;
433 self.track_latest = false;
434
435 let mut peak_count = 0;
436 let mut target = merges.smallest_tree_unchecked();
437 let mut new = delta.data[0];
438 update_count += 1;
439
440 while target < largest {
441 if target != Forest::new(1) && !track {
444 track = self.is_tracked_node(&peak_idx);
445 }
446
447 let (left, right) = if !(target & merges).is_empty() {
450 let peak = self.peaks[peak_count];
451 let sibling_idx = peak_idx.sibling();
452
453 if self.is_tracked_node(&sibling_idx) {
456 self.nodes.insert(peak_idx, new);
457 inserted_nodes.push((peak_idx, new));
458 }
459 peak_count += 1;
460 (peak, new)
461 } else {
462 let update = delta.data[update_count];
463 update_count += 1;
464 (new, update)
465 };
466
467 if track {
468 let sibling_idx = peak_idx.sibling();
469 if peak_idx.is_left_child() {
470 self.nodes.insert(sibling_idx, right);
471 inserted_nodes.push((sibling_idx, right));
472 } else {
473 self.nodes.insert(sibling_idx, left);
474 inserted_nodes.push((sibling_idx, left));
475 }
476 }
477
478 peak_idx = peak_idx.parent();
479 new = Rpo256::merge(&[left, right]);
480 target = target.next_larger_tree();
481 }
482
483 debug_assert!(peak_count == merges.num_trees());
484
485 self.peaks.reverse();
487 self.peaks.truncate(self.peaks.len() - peak_count);
489 self.peaks.push(new);
491 }
492
493 self.peaks.extend_from_slice(&delta.data[update_count..]);
497 self.forest = delta.forest;
498
499 debug_assert!(self.peaks.len() == self.forest.num_trees());
500
501 Ok(inserted_nodes)
502 }
503
504 fn is_tracked_node(&self, node_index: &InOrderIndex) -> bool {
510 if node_index.is_leaf() {
511 self.nodes.contains_key(&node_index.sibling())
512 } else {
513 let left_child = node_index.left_child();
514 let right_child = node_index.right_child();
515 self.nodes.contains_key(&left_child) | self.nodes.contains_key(&right_child)
516 }
517 }
518}
519
520impl From<MmrPeaks> for PartialMmr {
524 fn from(peaks: MmrPeaks) -> Self {
525 Self::from_peaks(peaks)
526 }
527}
528
529impl From<PartialMmr> for MmrPeaks {
530 fn from(partial_mmr: PartialMmr) -> Self {
531 MmrPeaks::new(partial_mmr.forest, partial_mmr.peaks).unwrap()
534 }
535}
536
537impl From<&MmrPeaks> for PartialMmr {
538 fn from(peaks: &MmrPeaks) -> Self {
539 Self::from_peaks(peaks.clone())
540 }
541}
542
543impl From<&PartialMmr> for MmrPeaks {
544 fn from(partial_mmr: &PartialMmr) -> Self {
545 MmrPeaks::new(partial_mmr.forest, partial_mmr.peaks.clone()).unwrap()
548 }
549}
550
/// An iterator over the inner nodes of a [`PartialMmr`].
///
/// The nodes are those that can be computed from a sequence of leaves and the stored
/// authentication nodes. Created by [`PartialMmr::inner_nodes`].
pub struct InnerNodeIterator<'a, I: Iterator<Item = (usize, Word)>> {
    /// Authentication nodes of the partial MMR, keyed by in-order index.
    nodes: &'a NodeMap,
    /// Remaining `(leaf position, leaf value)` pairs to process.
    leaves: I,
    /// Nodes (index and value) still to be climbed from.
    stack: Vec<(InOrderIndex, Word)>,
    /// Parent indices already processed, used to avoid yielding the same inner node twice.
    seen_nodes: BTreeSet<InOrderIndex>,
}
561
562impl<I: Iterator<Item = (usize, Word)>> Iterator for InnerNodeIterator<'_, I> {
563 type Item = InnerNodeInfo;
564
565 fn next(&mut self) -> Option<Self::Item> {
566 while let Some((idx, node)) = self.stack.pop() {
567 let parent_idx = idx.parent();
568 let new_node = self.seen_nodes.insert(parent_idx);
569
570 if new_node && let Some(sibling) = self.nodes.get(&idx.sibling()) {
573 let (left, right) = if parent_idx.left_child() == idx {
574 (node, *sibling)
575 } else {
576 (*sibling, node)
577 };
578 let parent = Rpo256::merge(&[left, right]);
579 let inner_node = InnerNodeInfo { value: parent, left, right };
580
581 self.stack.push((parent_idx, parent));
582 return Some(inner_node);
583 }
584
585 if let Some((pos, leaf)) = self.leaves.next() {
587 let idx = InOrderIndex::from_leaf_pos(pos);
588 self.stack.push((idx, leaf));
589 }
590 }
591
592 None
593 }
594}
595
596impl Serializable for PartialMmr {
597 fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
598 self.forest.num_leaves().write_into(target);
599 self.peaks.write_into(target);
600 self.nodes.write_into(target);
601 target.write_bool(self.track_latest);
602 }
603}
604
605impl Deserializable for PartialMmr {
606 fn read_from<R: winter_utils::ByteReader>(
607 source: &mut R,
608 ) -> Result<Self, winter_utils::DeserializationError> {
609 let forest = Forest::new(usize::read_from(source)?);
610 let peaks = Vec::<Word>::read_from(source)?;
611 let nodes = NodeMap::read_from(source)?;
612 let track_latest = source.read_bool()?;
613
614 Ok(Self { forest, peaks, nodes, track_latest })
615 }
616}
617
#[cfg(test)]
mod tests {
    use alloc::{collections::BTreeSet, vec::Vec};

    use winter_utils::{Deserializable, Serializable};

    use super::{MmrPeaks, PartialMmr};
    use crate::{
        Word,
        merkle::{MerkleStore, Mmr, NodeIndex, int_to_node, mmr::forest::Forest},
    };

    // Leaf values shared by several tests. With 7 leaves the MMR has a 4-leaf first tree and a
    // 2-leaf second tree (see the depths used with `NodeIndex` below) plus a single leaf.
    const LEAVES: [Word; 7] = [
        int_to_node(0),
        int_to_node(1),
        int_to_node(2),
        int_to_node(3),
        int_to_node(4),
        int_to_node(5),
        int_to_node(6),
    ];

    #[test]
    fn test_partial_mmr_apply_delta() {
        // build an MMR with 10 leaves and a partial MMR based on its peaks
        let mut mmr = Mmr::default();
        (0..10).for_each(|i| mmr.add(int_to_node(i)));
        let mut partial_mmr: PartialMmr = mmr.peaks().into();

        // track the leaf at position 1
        {
            let node = mmr.get(1).unwrap();
            let proof = mmr.open(1).unwrap();
            partial_mmr.track(1, node, &proof.merkle_path).unwrap();
        }

        // track the leaf at position 8
        {
            let node = mmr.get(8).unwrap();
            let proof = mmr.open(8).unwrap();
            partial_mmr.track(8, node, &proof.merkle_path).unwrap();
        }

        // add 2 more leaves to the MMR and validate `apply`
        (10..12).for_each(|i| mmr.add(int_to_node(i)));
        validate_apply_delta(&mmr, &mut partial_mmr);

        // add 1 more leaf, validate `apply`, then track that latest leaf
        mmr.add(int_to_node(12));
        validate_apply_delta(&mmr, &mut partial_mmr);
        {
            let node = mmr.get(12).unwrap();
            let proof = mmr.open(12).unwrap();
            partial_mmr.track(12, node, &proof.merkle_path).unwrap();
            // the latest leaf is a single-leaf tree, so it is tracked through the flag
            assert!(partial_mmr.track_latest);
        }

        // positions 1, 8 and 12 are now tracked; add 3 more leaves and validate `apply`
        (13..16).for_each(|i| mmr.add(int_to_node(i)));
        validate_apply_delta(&mmr, &mut partial_mmr);
    }

    /// Applies the delta between `partial` and `mmr` to `partial` and checks the peaks, the
    /// stored nodes and the proofs of every tracked leaf against `mmr`.
    fn validate_apply_delta(mmr: &Mmr, partial: &mut PartialMmr) {
        // leaves whose sibling is stored are tracked; remember them to check their proofs later
        let tracked_leaves = partial
            .nodes
            .iter()
            .filter(|&(index, _)| index.is_leaf())
            .map(|(index, _)| index.sibling())
            .collect::<Vec<_>>();
        let nodes_before = partial.nodes.clone();

        // compute and apply the delta
        let delta = mmr.get_delta(partial.forest(), mmr.forest()).unwrap();
        let nodes_delta = partial.apply(delta).unwrap();

        // the new peaks must match the full MMR
        assert_eq!(mmr.peaks(), partial.peaks());

        let mut expected_nodes = nodes_before;
        for (key, value) in nodes_delta {
            // a reported node must not already have been stored
            assert!(expected_nodes.insert(key, value).is_none());
        }

        // the stored nodes are exactly the original nodes plus the reported ones
        assert_eq!(expected_nodes, partial.nodes);

        // tracked leaves must open to the same proofs as in the full MMR
        for index in tracked_leaves {
            // converts the in-order index back to a leaf position; assumes leaves sit at odd
            // in-order indices `2 * pos + 1`
            let pos = index.inner() / 2;
            let proof1 = partial.open(pos).unwrap().unwrap();
            let proof2 = mmr.open(pos).unwrap();
            assert_eq!(proof1, proof2);
        }
    }

    #[test]
    fn test_partial_mmr_inner_nodes_iterator() {
        // build the MMR
        let mmr: Mmr = LEAVES.into();
        let first_peak = mmr.peaks().peaks()[0];

        // -- single tree --------------------------------------------------------------------

        // leaf and authentication path for position 1
        let node1 = mmr.get(1).unwrap();
        let proof1 = mmr.open(1).unwrap();

        // partial MMR tracking position 1
        let mut partial_mmr: PartialMmr = mmr.peaks().into();
        partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();

        // no leaves, no inner nodes
        assert_eq!(partial_mmr.inner_nodes([].iter().cloned()).next(), None);

        // a Merkle store built from the inner nodes must reproduce the authentication path
        let mut store: MerkleStore = MerkleStore::new();
        store.extend(partial_mmr.inner_nodes([(1, node1)].iter().cloned()));

        let index1 = NodeIndex::new(2, 1).unwrap();
        let path1 = store.get_path(first_peak, index1).unwrap().path;

        assert_eq!(path1, proof1.merkle_path);

        // -- no duplicates ------------------------------------------------------------------

        let mut partial_mmr: PartialMmr = mmr.peaks().into();

        let node0 = mmr.get(0).unwrap();
        let proof0 = mmr.open(0).unwrap();

        let node2 = mmr.get(2).unwrap();
        let proof2 = mmr.open(2).unwrap();

        partial_mmr.track(0, node0, &proof0.merkle_path).unwrap();
        partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
        partial_mmr.track(2, node2, &proof2.merkle_path).unwrap();

        // overlapping paths must not yield the same inner node twice
        let leaves = [(0, node0), (1, node1), (2, node2)];
        let mut nodes = BTreeSet::new();
        for node in partial_mmr.inner_nodes(leaves.iter().cloned()) {
            assert!(nodes.insert(node.value));
        }

        // the store must still reproduce every authentication path
        store.extend(partial_mmr.inner_nodes(leaves.iter().cloned()));

        let index0 = NodeIndex::new(2, 0).unwrap();
        let index1 = NodeIndex::new(2, 1).unwrap();
        let index2 = NodeIndex::new(2, 2).unwrap();

        let path0 = store.get_path(first_peak, index0).unwrap().path;
        let path1 = store.get_path(first_peak, index1).unwrap().path;
        let path2 = store.get_path(first_peak, index2).unwrap().path;

        assert_eq!(path0, proof0.merkle_path);
        assert_eq!(path1, proof1.merkle_path);
        assert_eq!(path2, proof2.merkle_path);

        // -- multiple trees -----------------------------------------------------------------

        let mut partial_mmr: PartialMmr = mmr.peaks().into();

        let node5 = mmr.get(5).unwrap();
        let proof5 = mmr.open(5).unwrap();

        partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
        partial_mmr.track(5, node5, &proof5.merkle_path).unwrap();

        // build a Merkle store from leaves in two different trees
        let mut store: MerkleStore = MerkleStore::new();
        store.extend(partial_mmr.inner_nodes([(1, node1), (5, node5)].iter().cloned()));

        let index1 = NodeIndex::new(2, 1).unwrap();
        let index5 = NodeIndex::new(1, 1).unwrap();

        let second_peak = mmr.peaks().peaks()[1];

        let path1 = store.get_path(first_peak, index1).unwrap().path;
        let path5 = store.get_path(second_peak, index5).unwrap().path;

        assert_eq!(path1, proof1.merkle_path);
        assert_eq!(path5, proof5.merkle_path);
    }

    #[test]
    fn test_partial_mmr_add_without_track() {
        let mut mmr = Mmr::default();
        let empty_peaks = MmrPeaks::new(Forest::empty(), vec![]).unwrap();
        let mut partial_mmr = PartialMmr::from_peaks(empty_peaks);

        // adding untracked leaves must keep peaks and forest in sync with the full MMR
        for el in (0..256).map(int_to_node) {
            mmr.add(el);
            partial_mmr.add(el, false);

            assert_eq!(mmr.peaks(), partial_mmr.peaks());
            assert_eq!(mmr.forest(), partial_mmr.forest());
        }
    }

    #[test]
    fn test_partial_mmr_add_with_track() {
        let mut mmr = Mmr::default();
        let empty_peaks = MmrPeaks::new(Forest::empty(), vec![]).unwrap();
        let mut partial_mmr = PartialMmr::from_peaks(empty_peaks);

        for i in 0..256 {
            let el = int_to_node(i as u64);
            mmr.add(el);
            partial_mmr.add(el, true);

            assert_eq!(mmr.peaks(), partial_mmr.peaks());
            assert_eq!(mmr.forest(), partial_mmr.forest());

            // every previously added leaf must open to the same proof as in the full MMR
            for pos in 0..i {
                let mmr_proof = mmr.open(pos).unwrap();
                let partialmmr_proof = partial_mmr.open(pos).unwrap().unwrap();
                assert_eq!(mmr_proof, partialmmr_proof);
            }
        }
    }

    #[test]
    fn test_partial_mmr_add_existing_track() {
        let mut mmr = Mmr::from((0..7).map(int_to_node));

        // partial MMR tracking the leaf at position 5
        let mut partial_mmr = PartialMmr::from_peaks(mmr.peaks());
        let path_to_5 = mmr.open(5).unwrap().merkle_path;
        let leaf_at_5 = mmr.get(5).unwrap();
        partial_mmr.track(5, leaf_at_5, &path_to_5).unwrap();

        // add an untracked leaf to both; this merges the tree containing position 5
        let leaf_at_7 = int_to_node(7);
        mmr.add(leaf_at_7);
        partial_mmr.add(leaf_at_7, false);

        // the extended path of position 5 must match the full MMR
        assert_eq!(mmr.open(5).unwrap(), partial_mmr.open(5).unwrap().unwrap());
    }

    #[test]
    fn test_partial_mmr_serialization() {
        let mmr = Mmr::from((0..7).map(int_to_node));
        let partial_mmr = PartialMmr::from_peaks(mmr.peaks());

        // round trip through bytes
        let bytes = partial_mmr.to_bytes();
        let decoded = PartialMmr::read_from_bytes(&bytes).unwrap();

        assert_eq!(partial_mmr, decoded);
    }

    #[test]
    fn test_partial_mmr_untrack() {
        // build the MMR
        let mmr: Mmr = LEAVES.into();

        // leaf and authentication path for position 1
        let node1 = mmr.get(1).unwrap();
        let proof1 = mmr.open(1).unwrap();

        // leaf and authentication path for position 2
        let node2 = mmr.get(2).unwrap();
        let proof2 = mmr.open(2).unwrap();

        // track positions 1 and 2
        let mut partial_mmr: PartialMmr = mmr.peaks().into();
        partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
        partial_mmr.track(2, node2, &proof2.merkle_path).unwrap();

        // untrack both positions
        partial_mmr.untrack(1);
        partial_mmr.untrack(2);

        // neither position is tracked and no authentication node is left behind
        assert!(!partial_mmr.is_tracked(1));
        assert!(!partial_mmr.is_tracked(2));
        assert_eq!(partial_mmr.nodes().count(), 0);
    }
}