use alloc::vec::Vec;

use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

use super::{LeafIndex, SMT_DEPTH};
use crate::{
    EMPTY_WORD, Word,
    merkle::{
        InnerNodeInfo, MerkleError, NodeIndex, SparseMerklePath,
        smt::{InnerNode, InnerNodes, Leaves, Smt, SmtLeaf, SmtProof, SparseMerkleTree},
    },
};
11
12#[derive(Debug, Clone, PartialEq, Eq)]
29#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
30pub struct PartialSmt(Smt);
31
32impl PartialSmt {
33 pub fn new(root: Word) -> Self {
40 let mut partial_smt = Self(Smt::default());
41
42 partial_smt.0.set_root(root);
43
44 partial_smt
45 }
46
47 pub fn from_proofs<I>(proofs: I) -> Result<Self, MerkleError>
57 where
58 I: IntoIterator<Item = SmtProof>,
59 {
60 let mut proofs = proofs.into_iter();
61
62 let Some(first_proof) = proofs.next() else {
63 return Ok(Self::default());
64 };
65
66 let mut partial_smt = Self::default();
70 let (path, leaf) = first_proof.into_parts();
71 let path_root = partial_smt.add_path_unchecked(leaf, path);
72 partial_smt.0.set_root(path_root);
73
74 for proof in proofs {
75 partial_smt.add_proof(proof)?;
76 }
77
78 Ok(partial_smt)
79 }
80
81 pub fn root(&self) -> Word {
86 self.0.root()
87 }
88
89 pub fn open(&self, key: &Word) -> Result<SmtProof, MerkleError> {
97 if !self.is_leaf_tracked(key) {
98 return Err(MerkleError::UntrackedKey(*key));
99 }
100
101 Ok(self.0.open(key))
102 }
103
104 pub fn get_leaf(&self, key: &Word) -> Result<SmtLeaf, MerkleError> {
111 if !self.is_leaf_tracked(key) {
112 return Err(MerkleError::UntrackedKey(*key));
113 }
114
115 Ok(self.0.get_leaf(key))
116 }
117
118 pub fn get_value(&self, key: &Word) -> Result<Word, MerkleError> {
125 if !self.is_leaf_tracked(key) {
126 return Err(MerkleError::UntrackedKey(*key));
127 }
128
129 Ok(self.0.get_value(key))
130 }
131
132 pub fn insert(&mut self, key: Word, value: Word) -> Result<Word, MerkleError> {
151 if !self.is_leaf_tracked(&key) {
152 return Err(MerkleError::UntrackedKey(key));
153 }
154
155 let previous_value = self.0.insert(key, value)?;
156
157 if value == EMPTY_WORD {
161 let leaf_index = Smt::key_to_leaf_index(&key);
162 self.0.leaves.insert(leaf_index.value(), SmtLeaf::Empty(leaf_index));
163 }
164
165 Ok(previous_value)
166 }
167
168 pub fn add_proof(&mut self, proof: SmtProof) -> Result<(), MerkleError> {
173 let (path, leaf) = proof.into_parts();
174 self.add_path(leaf, path)
175 }
176
177 pub fn add_path(&mut self, leaf: SmtLeaf, path: SparseMerklePath) -> Result<(), MerkleError> {
188 let path_root = self.add_path_unchecked(leaf, path);
189
190 if self.root() != path_root {
193 return Err(MerkleError::ConflictingRoots {
194 expected_root: self.root(),
195 actual_root: path_root,
196 });
197 }
198
199 Ok(())
200 }
201
202 pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
204 self.0.inner_nodes()
205 }
206
207 pub fn inner_node_indices(&self) -> impl Iterator<Item = (NodeIndex, InnerNode)> + '_ {
210 self.0.inner_node_indices()
211 }
212
213 pub fn leaves(&self) -> impl Iterator<Item = (LeafIndex<SMT_DEPTH>, &SmtLeaf)> {
216 self.0.leaves().filter_map(
218 |(leaf_idx, leaf)| {
219 if leaf.is_empty() { None } else { Some((leaf_idx, leaf)) }
220 },
221 )
222 }
223
224 pub fn tracked_leaves(&self) -> impl Iterator<Item = (LeafIndex<SMT_DEPTH>, &SmtLeaf)> {
228 self.0.leaves()
229 }
230
231 pub fn entries(&self) -> impl Iterator<Item = &(Word, Word)> {
234 self.0.entries()
235 }
236
237 pub fn num_leaves(&self) -> usize {
242 self.0.num_leaves()
243 }
244
245 pub fn num_entries(&self) -> usize {
250 self.0.num_entries()
251 }
252
253 pub fn tracks_leaves(&self) -> bool {
259 !self.0.leaves.is_empty()
260 }
261
262 fn add_path_unchecked(&mut self, leaf: SmtLeaf, path: SparseMerklePath) -> Word {
272 let mut current_index = leaf.index().index;
273
274 let mut node_hash_at_current_index = leaf.hash();
275
276 let prev_entries = self
285 .0
286 .leaves
287 .get(¤t_index.value())
288 .map(|leaf| leaf.num_entries())
289 .unwrap_or(0);
290 let current_entries = leaf.num_entries();
291 self.0.leaves.insert(current_index.value(), leaf);
292
293 self.0.num_entries = self.0.num_entries + current_entries - prev_entries;
295
296 for sibling_hash in path {
297 let is_sibling_right = current_index.sibling().is_value_odd();
299
300 current_index.move_up();
302
303 let new_parent_node = if is_sibling_right {
306 InnerNode {
307 left: node_hash_at_current_index,
308 right: sibling_hash,
309 }
310 } else {
311 InnerNode {
312 left: sibling_hash,
313 right: node_hash_at_current_index,
314 }
315 };
316
317 self.0.insert_inner_node(current_index, new_parent_node);
318
319 node_hash_at_current_index = self.0.get_inner_node(current_index).hash();
320 }
321
322 node_hash_at_current_index
323 }
324
325 fn is_leaf_tracked(&self, key: &Word) -> bool {
330 self.0.leaves.contains_key(&Smt::key_to_leaf_index(key).value())
331 }
332}
333
334impl Default for PartialSmt {
335 fn default() -> Self {
339 Self::new(Smt::EMPTY_ROOT)
340 }
341}
342
343impl From<Smt> for PartialSmt {
347 fn from(smt: Smt) -> Self {
348 PartialSmt(smt)
349 }
350}
351
352impl Serializable for PartialSmt {
356 fn write_into<W: ByteWriter>(&self, target: &mut W) {
357 target.write(self.root());
358 target.write_usize(self.0.leaves.len());
359 for (i, leaf) in &self.0.leaves {
360 target.write_u64(*i);
361 target.write(leaf);
362 }
363 target.write_usize(self.0.inner_nodes.len());
364 for (idx, node) in &self.0.inner_nodes {
365 target.write(idx);
366 target.write(node);
367 }
368 }
369}
370
371impl Deserializable for PartialSmt {
372 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
373 let root: Word = source.read()?;
374
375 let mut leaves = Leaves::default();
376 for _ in 0..source.read_usize()? {
377 let pos: u64 = source.read()?;
378 let leaf: SmtLeaf = source.read()?;
379 leaves.insert(pos, leaf);
380 }
381
382 let mut nodes = InnerNodes::default();
383 for _ in 0..source.read_usize()? {
384 let idx: NodeIndex = source.read()?;
385 let node: InnerNode = source.read()?;
386 nodes.insert(idx, node);
387 }
388
389 let smt = if leaves.is_empty() {
393 let inner_node_root =
394 nodes.get(&NodeIndex::root()).map(InnerNode::hash).unwrap_or(Smt::EMPTY_ROOT);
395 let mut smt = Smt::from_raw_parts(nodes, leaves, inner_node_root);
396 smt.set_root(root);
397 smt
398 } else {
399 Smt::from_raw_parts(nodes, leaves, root)
401 };
402
403 Ok(PartialSmt(smt))
404 }
405}
406
#[cfg(test)]
mod tests {

    use alloc::collections::{BTreeMap, BTreeSet};

    use assert_matches::assert_matches;
    use rand_utils::{rand_array, rand_value};
    use winter_math::fields::f64::BaseElement as Felt;

    use super::*;
    use crate::{EMPTY_WORD, ONE, ZERO, merkle::EmptySubtreeRoots};

    /// A partial SMT created from just a root commits to that root but tracks nothing.
    #[test]
    fn partial_smt_new_with_no_entries() {
        let key0 = Word::from(rand_array::<Felt, 4>());
        let value0 = Word::from(rand_array::<Felt, 4>());
        let full = Smt::with_entries([(key0, value0)]).unwrap();

        let partial_smt = PartialSmt::new(full.root());

        assert!(!partial_smt.tracks_leaves());
        assert_eq!(partial_smt.num_entries(), 0);
        assert_eq!(partial_smt.num_leaves(), 0);
        assert_eq!(partial_smt.entries().count(), 0);
        assert_eq!(partial_smt.tracked_leaves().count(), 0);
        assert_eq!(partial_smt.root(), full.root());
    }

    /// Updates and removals on tracked keys keep the partial root in lockstep with the full
    /// tree, while any access to an untracked key (`key1`) is rejected.
    #[test]
    fn partial_smt_insert_and_remove() {
        let key0 = Word::from(rand_array::<Felt, 4>());
        let key1 = Word::from(rand_array::<Felt, 4>());
        let key2 = Word::from(rand_array::<Felt, 4>());
        // Never inserted into `full`, so its opening is for an empty leaf.
        let key_empty = Word::from(rand_array::<Felt, 4>());

        let value0 = Word::from(rand_array::<Felt, 4>());
        let value1 = Word::from(rand_array::<Felt, 4>());
        let value2 = Word::from(rand_array::<Felt, 4>());

        let mut kv_pairs = vec![(key0, value0), (key1, value1), (key2, value2)];

        // Pad the tree with random entries so the proofs carry non-trivial paths.
        kv_pairs.reserve(1000);
        for _ in 0..1000 {
            let key = Word::from(rand_array::<Felt, 4>());
            let value = Word::from(rand_array::<Felt, 4>());
            kv_pairs.push((key, value));
        }

        let mut full = Smt::with_entries(kv_pairs).unwrap();

        let proof0 = full.open(&key0);
        let proof2 = full.open(&key2);
        let proof_empty = full.open(&key_empty);

        assert!(proof_empty.leaf().is_empty());

        let mut partial = PartialSmt::from_proofs([proof0, proof2, proof_empty]).unwrap();

        assert_eq!(full.root(), partial.root());
        assert_eq!(partial.get_value(&key0).unwrap(), value0);
        let error = partial.get_value(&key1).unwrap_err();
        assert_matches!(error, MerkleError::UntrackedKey(_));
        assert_eq!(partial.get_value(&key2).unwrap(), value2);

        let new_value0 = Word::from(rand_array::<Felt, 4>());
        let new_value2 = Word::from(rand_array::<Felt, 4>());
        let new_value_empty_key = Word::from(rand_array::<Felt, 4>());

        // Apply identical updates to both trees; their roots must stay equal.
        full.insert(key0, new_value0).unwrap();
        full.insert(key2, new_value2).unwrap();
        full.insert(key_empty, new_value_empty_key).unwrap();

        partial.insert(key0, new_value0).unwrap();
        partial.insert(key2, new_value2).unwrap();
        partial.insert(key_empty, new_value_empty_key).unwrap();

        assert_eq!(full.root(), partial.root());
        assert_eq!(partial.get_value(&key0).unwrap(), new_value0);
        assert_eq!(partial.get_value(&key2).unwrap(), new_value2);
        assert_eq!(partial.get_value(&key_empty).unwrap(), new_value_empty_key);

        // Removal: inserting EMPTY_WORD must keep key0 readable (tracked) in the partial tree.
        full.insert(key0, EMPTY_WORD).unwrap();
        partial.insert(key0, EMPTY_WORD).unwrap();

        assert_eq!(full.root(), partial.root());
        assert_eq!(partial.get_value(&key0).unwrap(), EMPTY_WORD);

        assert_eq!(full.open(&key0), partial.open(&key0).unwrap());
        assert_eq!(full.open(&key2), partial.open(&key2).unwrap());

        // Both non-empty and empty inserts into an untracked key are rejected.
        let error = partial.clone().insert(key1, Word::from(rand_array::<Felt, 4>())).unwrap_err();
        assert_matches!(error, MerkleError::UntrackedKey(_));

        let error = partial.insert(key1, EMPTY_WORD).unwrap_err();
        assert_matches!(error, MerkleError::UntrackedKey(_));
    }

    /// Tracking a multi-entry leaf through one of its keys makes every key in that leaf
    /// readable, including `key1`, for which no proof was added.
    #[test]
    fn partial_smt_multiple_leaf_success() {
        // key0 and key1 share their last element; the let-else below requires that they end up
        // in one `Multiple` leaf.
        let key0 = Word::from([ZERO, ZERO, ZERO, ONE]);
        let key1 = Word::from([ONE, ONE, ONE, ONE]);
        let key2 = Word::from(rand_array::<Felt, 4>());

        let value0 = Word::from(rand_array::<Felt, 4>());
        let value1 = Word::from(rand_array::<Felt, 4>());
        let value2 = Word::from(rand_array::<Felt, 4>());

        let full = Smt::with_entries([(key0, value0), (key1, value1), (key2, value2)]).unwrap();

        let SmtLeaf::Multiple(_) = full.get_leaf(&key0) else {
            panic!("expected full tree to produce multiple leaf")
        };

        let proof0 = full.open(&key0);
        let proof2 = full.open(&key2);

        let partial = PartialSmt::from_proofs([proof0, proof2]).unwrap();

        assert_eq!(partial.root(), full.root());

        assert_eq!(partial.get_leaf(&key0).unwrap(), full.get_leaf(&key0));
        assert_eq!(partial.get_leaf(&key1).unwrap(), full.get_leaf(&key1));
        assert_eq!(partial.get_leaf(&key2).unwrap(), full.get_leaf(&key2));
    }

    /// A proof for an empty leaf taken before the tree changed no longer matches the new root
    /// and must be rejected by `add_proof`.
    #[test]
    fn partial_smt_root_mismatch_on_empty_values() {
        let key0 = Word::from(rand_array::<Felt, 4>());
        let key1 = Word::from(rand_array::<Felt, 4>());
        let key2 = Word::from(rand_array::<Felt, 4>());

        let value0 = EMPTY_WORD;
        let value1 = Word::from(rand_array::<Felt, 4>());
        let value2 = EMPTY_WORD;

        let kv_pairs = vec![(key0, value0)];

        let mut full = Smt::with_entries(kv_pairs).unwrap();

        // Opened against the old root.
        let stale_proof = full.open(&key2);

        // Inserting the non-empty value1 changes the root and makes `stale_proof` stale.
        full.insert(key1, value1).unwrap();
        full.insert(key2, value2).unwrap();

        let mut partial = PartialSmt::new(full.root());

        let err = partial.add_proof(stale_proof).unwrap_err();
        assert_matches!(err, MerkleError::ConflictingRoots { .. });
    }

    /// Same as above, but the stale proof is for a leaf holding a non-empty value.
    #[test]
    fn partial_smt_root_mismatch_on_non_empty_values() {
        let key0 = Word::new(rand_array());
        let key1 = Word::new(rand_array());
        let key2 = Word::new(rand_array());

        let value0 = Word::new(rand_array());
        let value1 = Word::new(rand_array());
        let value2 = Word::new(rand_array());

        let kv_pairs = vec![(key0, value0), (key1, value1)];

        let mut full = Smt::with_entries(kv_pairs).unwrap();

        // Opened against the old root.
        let stale_proof = full.open(&key0);

        full.insert(key2, value2).unwrap();

        let mut partial = PartialSmt::new(full.root());

        let err = partial.add_proof(stale_proof).unwrap_err();
        assert_matches!(err, MerkleError::ConflictingRoots { .. });
    }

    /// `from_proofs` takes its root from the first (fresh) proof and rejects the stale second
    /// one.
    #[test]
    fn partial_smt_from_proofs_fails_on_root_mismatch() {
        let key0 = Word::new(rand_array());
        let key1 = Word::new(rand_array());

        let value0 = Word::new(rand_array());
        let value1 = Word::new(rand_array());

        let mut full = Smt::with_entries([(key0, value0)]).unwrap();

        let stale_proof = full.open(&key0);

        full.insert(key1, value1).unwrap();

        let err = PartialSmt::from_proofs([full.open(&key1), stale_proof]).unwrap_err();
        assert_matches!(err, MerkleError::ConflictingRoots { .. });
    }

    /// Checks the counting and iteration APIs: `leaves` skips empty leaves, `tracked_leaves`
    /// includes them, and every sibling digest from the proofs is either a stored inner-node
    /// child or an empty-subtree root.
    #[test]
    fn partial_smt_iterator_apis() {
        let key0 = Word::new(rand_array());
        let key1 = Word::new(rand_array());
        let key2 = Word::new(rand_array());
        let key_empty = Word::new(rand_array());

        let value0 = Word::new(rand_array());
        let value1 = Word::new(rand_array());
        let value2 = Word::new(rand_array());

        let mut kv_pairs = vec![(key0, value0), (key1, value1), (key2, value2)];

        kv_pairs.reserve(1000);
        for _ in 0..1000 {
            let key = Word::new(rand_array());
            let value = Word::new(rand_array());
            kv_pairs.push((key, value));
        }

        let full = Smt::with_entries(kv_pairs).unwrap();

        let proof0 = full.open(&key0);
        let proof2 = full.open(&key2);
        let proof_empty = full.open(&key_empty);

        assert!(proof_empty.leaf().is_empty());

        let proofs = [proof0, proof2, proof_empty];
        let partial = PartialSmt::from_proofs(proofs.clone()).unwrap();

        // Three leaves are tracked (one empty), but only two hold entries.
        assert!(partial.tracks_leaves());
        assert_eq!(full.root(), partial.root());
        assert_eq!(partial.num_entries(), 2);
        assert_eq!(partial.num_leaves(), 3);

        let expected_leaves: BTreeMap<_, _> =
            [SmtLeaf::new_single(key0, value0), SmtLeaf::new_single(key2, value2)]
                .into_iter()
                .map(|leaf| (leaf.index(), leaf))
                .collect();

        let actual_leaves = partial
            .leaves()
            .map(|(idx, leaf)| (idx, leaf.clone()))
            .collect::<BTreeMap<_, _>>();

        assert_eq!(actual_leaves.len(), expected_leaves.len());
        assert_eq!(actual_leaves, expected_leaves);

        // `tracked_leaves` additionally yields the empty leaf for `key_empty`.
        let mut expected_tracked_leaves = expected_leaves;
        let empty_leaf = SmtLeaf::new_empty(LeafIndex::from(key_empty));
        expected_tracked_leaves.insert(empty_leaf.index(), empty_leaf);

        let actual_tracked_leaves = partial
            .tracked_leaves()
            .map(|(idx, leaf)| (idx, leaf.clone()))
            .collect::<BTreeMap<_, _>>();

        assert_eq!(actual_tracked_leaves.len(), expected_tracked_leaves.len());
        assert_eq!(actual_tracked_leaves, expected_tracked_leaves);

        // Every sibling digest from the proofs should be a stored child, unless it is an
        // empty-subtree root.
        let partial_inner_nodes: BTreeSet<_> =
            partial.inner_nodes().flat_map(|node| [node.left, node.right]).collect();
        let empty_subtree_roots: BTreeSet<_> = (0..SMT_DEPTH)
            .map(|depth| *EmptySubtreeRoots::entry(SMT_DEPTH, depth))
            .collect();

        for merkle_path in proofs.into_iter().map(|proof| proof.into_parts().0) {
            for (idx, digest) in merkle_path.into_iter().enumerate() {
                assert!(
                    partial_inner_nodes.contains(&digest) || empty_subtree_roots.contains(&digest),
                    "failed at idx {idx}"
                );
            }
        }
    }

    /// The default partial SMT tracks no leaves.
    #[test]
    fn partial_smt_tracks_leaves() {
        assert!(!PartialSmt::default().tracks_leaves());
    }

    /// A partial SMT with a root but no leaves survives a serialization roundtrip.
    #[test]
    fn partial_smt_with_empty_leaves_serialization_roundtrip() {
        let partial_smt = PartialSmt::new(rand_value());
        assert_eq!(partial_smt, PartialSmt::read_from_bytes(&partial_smt.to_bytes()).unwrap());
    }

    /// A partial SMT tracking one leaf survives a serialization roundtrip.
    #[test]
    fn partial_smt_serialization_roundtrip() {
        let key = rand_value();
        let val = rand_value();

        let key_1 = rand_value();
        let val_1 = rand_value();

        let key_2 = rand_value();
        let val_2 = rand_value();

        let smt: Smt = Smt::with_entries([(key, val), (key_1, val_1), (key_2, val_2)]).unwrap();

        let partial_smt = PartialSmt::from_proofs([smt.open(&key)]).unwrap();

        assert_eq!(partial_smt.root(), smt.root());
        assert_matches!(partial_smt.open(&key_1), Err(MerkleError::UntrackedKey(_)));
        assert_matches!(partial_smt.open(&key), Ok(_));

        let bytes = partial_smt.to_bytes();
        let decoded = PartialSmt::read_from_bytes(&bytes).unwrap();

        assert_eq!(partial_smt, decoded);
    }

    /// The entry counter follows added proofs and removals.
    #[test]
    fn partial_smt_add_proof_num_entries() {
        // key0 and key1 share their last element. The count of 2 after adding only key0's proof
        // shows they land in the same leaf.
        let key0 = Word::from([ZERO, ZERO, ZERO, ONE]);
        let key1 = Word::from([ONE, ONE, ONE, ONE]);
        let key2 = Word::from([ONE, ONE, ONE, Felt::new(5)]);
        let value0 = Word::from(rand_array::<Felt, 4>());
        let value1 = Word::from(rand_array::<Felt, 4>());
        let value2 = Word::from(rand_array::<Felt, 4>());

        let full = Smt::with_entries([(key0, value0), (key1, value1), (key2, value2)]).unwrap();
        let mut partial = PartialSmt::new(full.root());

        partial.add_proof(full.open(&key0)).unwrap();
        assert_eq!(partial.num_entries(), 2);

        partial.add_proof(full.open(&key2)).unwrap();
        assert_eq!(partial.num_entries(), 3);

        // NOTE(review): this only checks the counter. It does not check that key1's value is
        // still present in the shared leaf after key0 is removed.
        partial.insert(key0, Word::empty()).unwrap();
        assert_eq!(partial.num_entries(), 2);
    }
}