use alloc::vec::Vec;

use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

use super::{LeafIndex, SMT_DEPTH};
use crate::{
    EMPTY_WORD, Word,
    merkle::{
        InnerNode, InnerNodeInfo, MerkleError, NodeIndex, Smt, SmtLeaf, SmtProof, SparseMerklePath,
        smt::{InnerNodes, Leaves, SparseMerkleTree},
    },
};
11
/// A partial view of a sparse Merkle tree ([`Smt`]).
///
/// It stores a root plus a subset of "tracked" leaves, together with the inner nodes needed to
/// authenticate them. Tracked leaves can be read, opened, and updated, and the root is recomputed
/// as if the update had been applied to the full tree. Operations on keys whose leaf is not
/// tracked fail with [`MerkleError::UntrackedKey`].
///
/// Leaves become tracked by adding Merkle paths or proofs that agree with the current root
/// (see `add_path` / `add_proof`), or by converting a full [`Smt`] via `From<Smt>`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct PartialSmt(Smt);
31
32impl PartialSmt {
33 pub fn new(root: Word) -> Self {
40 let mut partial_smt = Self(Smt::default());
41
42 partial_smt.0.set_root(root);
43
44 partial_smt
45 }
46
47 pub fn from_proofs<I>(proofs: I) -> Result<Self, MerkleError>
57 where
58 I: IntoIterator<Item = SmtProof>,
59 {
60 let mut proofs = proofs.into_iter();
61
62 let Some(first_proof) = proofs.next() else {
63 return Ok(Self::default());
64 };
65
66 let mut partial_smt = Self::default();
70 let (path, leaf) = first_proof.into_parts();
71 let path_root = partial_smt.add_path_unchecked(leaf, path);
72 partial_smt.0.set_root(path_root);
73
74 for proof in proofs {
75 partial_smt.add_proof(proof)?;
76 }
77
78 Ok(partial_smt)
79 }
80
81 pub fn root(&self) -> Word {
86 self.0.root()
87 }
88
89 pub fn open(&self, key: &Word) -> Result<SmtProof, MerkleError> {
97 if !self.is_leaf_tracked(key) {
98 return Err(MerkleError::UntrackedKey(*key));
99 }
100
101 Ok(self.0.open(key))
102 }
103
104 pub fn get_leaf(&self, key: &Word) -> Result<SmtLeaf, MerkleError> {
111 if !self.is_leaf_tracked(key) {
112 return Err(MerkleError::UntrackedKey(*key));
113 }
114
115 Ok(self.0.get_leaf(key))
116 }
117
118 pub fn get_value(&self, key: &Word) -> Result<Word, MerkleError> {
125 if !self.is_leaf_tracked(key) {
126 return Err(MerkleError::UntrackedKey(*key));
127 }
128
129 Ok(self.0.get_value(key))
130 }
131
132 pub fn insert(&mut self, key: Word, value: Word) -> Result<Word, MerkleError> {
151 if !self.is_leaf_tracked(&key) {
152 return Err(MerkleError::UntrackedKey(key));
153 }
154
155 let previous_value = self.0.insert(key, value)?;
156
157 if value == EMPTY_WORD {
161 let leaf_index = Smt::key_to_leaf_index(&key);
162 self.0.leaves.insert(leaf_index.value(), SmtLeaf::Empty(leaf_index));
163 }
164
165 Ok(previous_value)
166 }
167
168 pub fn add_proof(&mut self, proof: SmtProof) -> Result<(), MerkleError> {
173 let (path, leaf) = proof.into_parts();
174 self.add_path(leaf, path)
175 }
176
177 pub fn add_path(&mut self, leaf: SmtLeaf, path: SparseMerklePath) -> Result<(), MerkleError> {
188 let path_root = self.add_path_unchecked(leaf, path);
189
190 if self.root() != path_root {
193 return Err(MerkleError::ConflictingRoots {
194 expected_root: self.root(),
195 actual_root: path_root,
196 });
197 }
198
199 Ok(())
200 }
201
202 pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
204 self.0.inner_nodes()
205 }
206
207 pub fn inner_node_indices(&self) -> impl Iterator<Item = (NodeIndex, InnerNode)> + '_ {
210 self.0.inner_node_indices()
211 }
212
213 pub fn leaves(&self) -> impl Iterator<Item = (LeafIndex<SMT_DEPTH>, &SmtLeaf)> {
216 self.0.leaves().filter_map(
218 |(leaf_idx, leaf)| {
219 if leaf.is_empty() { None } else { Some((leaf_idx, leaf)) }
220 },
221 )
222 }
223
224 pub fn tracked_leaves(&self) -> impl Iterator<Item = (LeafIndex<SMT_DEPTH>, &SmtLeaf)> {
228 self.0.leaves()
229 }
230
231 pub fn entries(&self) -> impl Iterator<Item = &(Word, Word)> {
234 self.0.entries()
235 }
236
237 pub fn num_leaves(&self) -> usize {
242 self.0.num_leaves()
243 }
244
245 pub fn num_entries(&self) -> usize {
250 self.0.num_entries()
251 }
252
253 pub fn tracks_leaves(&self) -> bool {
259 !self.0.leaves.is_empty()
260 }
261
262 fn add_path_unchecked(&mut self, leaf: SmtLeaf, path: SparseMerklePath) -> Word {
272 let mut current_index = leaf.index().index;
273
274 let mut node_hash_at_current_index = leaf.hash();
275
276 let prev_entries = self
285 .0
286 .leaves
287 .get(¤t_index.value())
288 .map(|leaf| leaf.num_entries())
289 .unwrap_or(0);
290 let current_entries = leaf.num_entries();
291 self.0.leaves.insert(current_index.value(), leaf);
292
293 self.0.num_entries = self.0.num_entries + current_entries - prev_entries;
295
296 for sibling_hash in path {
297 let is_sibling_right = current_index.sibling().is_value_odd();
299
300 current_index.move_up();
302
303 let new_parent_node = if is_sibling_right {
306 InnerNode {
307 left: node_hash_at_current_index,
308 right: sibling_hash,
309 }
310 } else {
311 InnerNode {
312 left: sibling_hash,
313 right: node_hash_at_current_index,
314 }
315 };
316
317 self.0.insert_inner_node(current_index, new_parent_node);
318
319 node_hash_at_current_index = self.0.get_inner_node(current_index).hash();
320 }
321
322 node_hash_at_current_index
323 }
324
325 fn is_leaf_tracked(&self, key: &Word) -> bool {
330 self.0.leaves.contains_key(&Smt::key_to_leaf_index(key).value())
331 }
332}
333
334impl Default for PartialSmt {
335 fn default() -> Self {
339 Self::new(Smt::EMPTY_ROOT)
340 }
341}
342
343impl From<Smt> for PartialSmt {
347 fn from(smt: Smt) -> Self {
348 PartialSmt(smt)
349 }
350}
351
352impl Serializable for PartialSmt {
356 fn write_into<W: ByteWriter>(&self, target: &mut W) {
357 target.write(self.root());
358 target.write_usize(self.0.leaves.len());
359 for (i, leaf) in &self.0.leaves {
360 target.write_u64(*i);
361 target.write(leaf);
362 }
363 target.write_usize(self.0.inner_nodes.len());
364 for (idx, node) in &self.0.inner_nodes {
365 target.write(idx);
366 target.write(node);
367 }
368 }
369}
370
371impl Deserializable for PartialSmt {
372 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
373 let root: Word = source.read()?;
374
375 let mut leaves = Leaves::default();
376 for _ in 0..source.read_usize()? {
377 let pos: u64 = source.read()?;
378 let leaf: SmtLeaf = source.read()?;
379 leaves.insert(pos, leaf);
380 }
381
382 let mut nodes = InnerNodes::default();
383 for _ in 0..source.read_usize()? {
384 let idx: NodeIndex = source.read()?;
385 let node: InnerNode = source.read()?;
386 nodes.insert(idx, node);
387 }
388
389 let smt = Smt::from_raw_parts(nodes, leaves, root);
390 Ok(PartialSmt(smt))
391 }
392}
393
#[cfg(test)]
// Unit tests for `PartialSmt`. Most tests build a full `Smt`, derive proofs from it, and check
// that the partial tree stays consistent with the full tree.
mod tests {

    use alloc::collections::{BTreeMap, BTreeSet};

    use assert_matches::assert_matches;
    use rand_utils::{rand_array, rand_value};
    use winter_math::fields::f64::BaseElement as Felt;

    use super::*;
    use crate::{EMPTY_WORD, ONE, ZERO, merkle::EmptySubtreeRoots};

    /// A partial SMT created from a bare root tracks nothing but still reports that root.
    #[test]
    fn partial_smt_new_with_no_entries() {
        let key0 = Word::from(rand_array::<Felt, 4>());
        let value0 = Word::from(rand_array::<Felt, 4>());
        let full = Smt::with_entries([(key0, value0)]).unwrap();

        let partial_smt = PartialSmt::new(full.root());

        assert!(!partial_smt.tracks_leaves());
        assert_eq!(partial_smt.num_entries(), 0);
        assert_eq!(partial_smt.num_leaves(), 0);
        assert_eq!(partial_smt.entries().count(), 0);
        assert_eq!(partial_smt.tracked_leaves().count(), 0);
        assert_eq!(partial_smt.root(), full.root());
    }

    /// Inserting into and removing from tracked keys keeps the partial root in sync with the
    /// full tree. Untracked keys are rejected.
    #[test]
    fn partial_smt_insert_and_remove() {
        let key0 = Word::from(rand_array::<Felt, 4>());
        let key1 = Word::from(rand_array::<Felt, 4>());
        let key2 = Word::from(rand_array::<Felt, 4>());
        // Never inserted into the full tree. Its leaf is asserted to be empty below.
        let key_empty = Word::from(rand_array::<Felt, 4>());

        let value0 = Word::from(rand_array::<Felt, 4>());
        let value1 = Word::from(rand_array::<Felt, 4>());
        let value2 = Word::from(rand_array::<Felt, 4>());

        let mut kv_pairs = vec![(key0, value0), (key1, value1), (key2, value2)];

        // Pad the tree with random entries so that many inner nodes are populated.
        kv_pairs.reserve(1000);
        for _ in 0..1000 {
            let key = Word::from(rand_array::<Felt, 4>());
            let value = Word::from(rand_array::<Felt, 4>());
            kv_pairs.push((key, value));
        }

        let mut full = Smt::with_entries(kv_pairs).unwrap();

        // Track key0, key2 and the (empty) leaf of key_empty, but not key1.
        let proof0 = full.open(&key0);
        let proof2 = full.open(&key2);
        let proof_empty = full.open(&key_empty);

        assert!(proof_empty.leaf().is_empty());

        let mut partial = PartialSmt::from_proofs([proof0, proof2, proof_empty]).unwrap();

        assert_eq!(full.root(), partial.root());
        assert_eq!(partial.get_value(&key0).unwrap(), value0);
        let error = partial.get_value(&key1).unwrap_err();
        assert_matches!(error, MerkleError::UntrackedKey(_));
        assert_eq!(partial.get_value(&key2).unwrap(), value2);

        // Apply the same updates to both trees, including one into the previously empty leaf.
        let new_value0 = Word::from(rand_array::<Felt, 4>());
        let new_value2 = Word::from(rand_array::<Felt, 4>());
        let new_value_empty_key = Word::from(rand_array::<Felt, 4>());

        full.insert(key0, new_value0).unwrap();
        full.insert(key2, new_value2).unwrap();
        full.insert(key_empty, new_value_empty_key).unwrap();

        partial.insert(key0, new_value0).unwrap();
        partial.insert(key2, new_value2).unwrap();
        partial.insert(key_empty, new_value_empty_key).unwrap();

        assert_eq!(full.root(), partial.root());
        assert_eq!(partial.get_value(&key0).unwrap(), new_value0);
        assert_eq!(partial.get_value(&key2).unwrap(), new_value2);
        assert_eq!(partial.get_value(&key_empty).unwrap(), new_value_empty_key);

        // Remove key0 by writing the empty word.
        full.insert(key0, EMPTY_WORD).unwrap();
        partial.insert(key0, EMPTY_WORD).unwrap();

        assert_eq!(full.root(), partial.root());
        assert_eq!(partial.get_value(&key0).unwrap(), EMPTY_WORD);

        // Openings must match the full tree, also for key0 after its removal.
        assert_eq!(full.open(&key0), partial.open(&key0).unwrap());
        assert_eq!(full.open(&key2), partial.open(&key2).unwrap());

        // Untracked keys can be neither updated nor removed. The first attempt runs on a clone,
        // so `partial` itself is only touched by the second attempt.
        let error = partial.clone().insert(key1, Word::from(rand_array::<Felt, 4>())).unwrap_err();
        assert_matches!(error, MerkleError::UntrackedKey(_));

        let error = partial.insert(key1, EMPTY_WORD).unwrap_err();
        assert_matches!(error, MerkleError::UntrackedKey(_));
    }

    /// Tracking a leaf that holds several entries makes all of its keys readable, even those for
    /// which no proof was added.
    #[test]
    fn partial_smt_multiple_leaf_success() {
        // key0 and key1 are chosen so that they share a leaf. This is asserted below via the
        // `SmtLeaf::Multiple` check (presumably because both have the same last element).
        let key0 = Word::from([ZERO, ZERO, ZERO, ONE]);
        let key1 = Word::from([ONE, ONE, ONE, ONE]);
        let key2 = Word::from(rand_array::<Felt, 4>());

        let value0 = Word::from(rand_array::<Felt, 4>());
        let value1 = Word::from(rand_array::<Felt, 4>());
        let value2 = Word::from(rand_array::<Felt, 4>());

        let full = Smt::with_entries([(key0, value0), (key1, value1), (key2, value2)]).unwrap();

        let SmtLeaf::Multiple(_) = full.get_leaf(&key0) else {
            panic!("expected full tree to produce multiple leaf")
        };

        let proof0 = full.open(&key0);
        let proof2 = full.open(&key2);

        let partial = PartialSmt::from_proofs([proof0, proof2]).unwrap();

        assert_eq!(partial.root(), full.root());

        assert_eq!(partial.get_leaf(&key0).unwrap(), full.get_leaf(&key0));
        // key1 has no proof of its own. It is tracked because it lives in key0's leaf.
        assert_eq!(partial.get_leaf(&key1).unwrap(), full.get_leaf(&key1));
        assert_eq!(partial.get_leaf(&key2).unwrap(), full.get_leaf(&key2));
    }

    /// A proof for an (expectedly empty) leaf that was taken before the tree changed must be
    /// rejected because its root no longer matches.
    #[test]
    fn partial_smt_root_mismatch_on_empty_values() {
        let key0 = Word::from(rand_array::<Felt, 4>());
        let key1 = Word::from(rand_array::<Felt, 4>());
        let key2 = Word::from(rand_array::<Felt, 4>());

        let value0 = EMPTY_WORD;
        let value1 = Word::from(rand_array::<Felt, 4>());
        let value2 = EMPTY_WORD;

        let kv_pairs = vec![(key0, value0)];

        let mut full = Smt::with_entries(kv_pairs).unwrap();

        // Taken before the inserts below, so this proof commits to an outdated root.
        let stale_proof = full.open(&key2);

        // Only the insert of key1 changes the root. key2 is "inserted" with the empty word.
        full.insert(key1, value1).unwrap();
        full.insert(key2, value2).unwrap();

        let mut partial = PartialSmt::new(full.root());

        let err = partial.add_proof(stale_proof).unwrap_err();
        assert_matches!(err, MerkleError::ConflictingRoots { .. });
    }

    /// A proof for a non-empty leaf that was taken before the tree changed must be rejected
    /// because its root no longer matches.
    #[test]
    fn partial_smt_root_mismatch_on_non_empty_values() {
        let key0 = Word::new(rand_array());
        let key1 = Word::new(rand_array());
        let key2 = Word::new(rand_array());

        let value0 = Word::new(rand_array());
        let value1 = Word::new(rand_array());
        let value2 = Word::new(rand_array());

        let kv_pairs = vec![(key0, value0), (key1, value1)];

        let mut full = Smt::with_entries(kv_pairs).unwrap();

        // Taken before key2 is inserted, so this proof commits to an outdated root.
        let stale_proof = full.open(&key0);

        full.insert(key2, value2).unwrap();

        let mut partial = PartialSmt::new(full.root());

        let err = partial.add_proof(stale_proof).unwrap_err();
        assert_matches!(err, MerkleError::ConflictingRoots { .. });
    }

    /// `from_proofs` takes the root from the first proof and rejects a later proof that
    /// disagrees with it.
    #[test]
    fn partial_smt_from_proofs_fails_on_root_mismatch() {
        let key0 = Word::new(rand_array());
        let key1 = Word::new(rand_array());

        let value0 = Word::new(rand_array());
        let value1 = Word::new(rand_array());

        let mut full = Smt::with_entries([(key0, value0)]).unwrap();

        // Taken before key1 is inserted, so this proof commits to an outdated root.
        let stale_proof = full.open(&key0);

        full.insert(key1, value1).unwrap();

        // The fresh proof comes first and defines the root, so the stale proof conflicts with it.
        let err = PartialSmt::from_proofs([full.open(&key1), stale_proof]).unwrap_err();
        assert_matches!(err, MerkleError::ConflictingRoots { .. });
    }

    /// Checks the counting and iterator APIs (`num_*`, `leaves`, `tracked_leaves`,
    /// `inner_nodes`) against the proofs the partial tree was built from.
    #[test]
    fn partial_smt_iterator_apis() {
        let key0 = Word::new(rand_array());
        let key1 = Word::new(rand_array());
        let key2 = Word::new(rand_array());
        // Never inserted into the full tree. Its leaf is asserted to be empty below.
        let key_empty = Word::new(rand_array());

        let value0 = Word::new(rand_array());
        let value1 = Word::new(rand_array());
        let value2 = Word::new(rand_array());

        let mut kv_pairs = vec![(key0, value0), (key1, value1), (key2, value2)];

        // Pad the tree with random entries so that many inner nodes are populated.
        kv_pairs.reserve(1000);
        for _ in 0..1000 {
            let key = Word::new(rand_array());
            let value = Word::new(rand_array());
            kv_pairs.push((key, value));
        }

        let full = Smt::with_entries(kv_pairs).unwrap();

        // Track key0, key2 and the (empty) leaf of key_empty, but not key1.
        let proof0 = full.open(&key0);
        let proof2 = full.open(&key2);
        let proof_empty = full.open(&key_empty);

        assert!(proof_empty.leaf().is_empty());

        let proofs = [proof0, proof2, proof_empty];
        let partial = PartialSmt::from_proofs(proofs.clone()).unwrap();

        assert!(partial.tracks_leaves());
        assert_eq!(full.root(), partial.root());
        // Two entries (key0, key2) across three tracked leaves (one of them empty).
        assert_eq!(partial.num_entries(), 2);
        assert_eq!(partial.num_leaves(), 3);

        // `leaves()` yields only the non-empty tracked leaves.
        let expected_leaves: BTreeMap<_, _> =
            [SmtLeaf::new_single(key0, value0), SmtLeaf::new_single(key2, value2)]
                .into_iter()
                .map(|leaf| (leaf.index(), leaf))
                .collect();

        let actual_leaves = partial
            .leaves()
            .map(|(idx, leaf)| (idx, leaf.clone()))
            .collect::<BTreeMap<_, _>>();

        assert_eq!(actual_leaves.len(), expected_leaves.len());
        assert_eq!(actual_leaves, expected_leaves);

        // `tracked_leaves()` additionally yields the tracked empty leaf.
        let mut expected_tracked_leaves = expected_leaves;
        let empty_leaf = SmtLeaf::new_empty(LeafIndex::from(key_empty));
        expected_tracked_leaves.insert(empty_leaf.index(), empty_leaf);

        let actual_tracked_leaves = partial
            .tracked_leaves()
            .map(|(idx, leaf)| (idx, leaf.clone()))
            .collect::<BTreeMap<_, _>>();

        assert_eq!(actual_tracked_leaves.len(), expected_tracked_leaves.len());
        assert_eq!(actual_tracked_leaves, expected_tracked_leaves);

        // Each sibling on the proofs' paths must be a child hash of some stored inner node, or
        // the root of an empty subtree (which the tree may not store as an inner node).
        let partial_inner_nodes: BTreeSet<_> =
            partial.inner_nodes().flat_map(|node| [node.left, node.right]).collect();
        let empty_subtree_roots: BTreeSet<_> = (0..SMT_DEPTH)
            .map(|depth| *EmptySubtreeRoots::entry(SMT_DEPTH, depth))
            .collect();

        for merkle_path in proofs.into_iter().map(|proof| proof.into_parts().0) {
            for (idx, digest) in merkle_path.into_iter().enumerate() {
                assert!(
                    partial_inner_nodes.contains(&digest) || empty_subtree_roots.contains(&digest),
                    "failed at idx {idx}"
                );
            }
        }
    }

    /// The default partial SMT tracks no leaves.
    #[test]
    fn partial_smt_tracks_leaves() {
        assert!(!PartialSmt::default().tracks_leaves());
    }

    /// Serializing and deserializing a partial SMT yields an equal value.
    #[test]
    fn partial_smt_serialization_roundtrip() {
        let key = rand_value();
        let val = rand_value();

        let key_1 = rand_value();
        let val_1 = rand_value();

        let key_2 = rand_value();
        let val_2 = rand_value();

        let smt: Smt = Smt::with_entries([(key, val), (key_1, val_1), (key_2, val_2)]).unwrap();

        // Track only the leaf of `key`.
        let partial_smt = PartialSmt::from_proofs([smt.open(&key)]).unwrap();

        assert_eq!(partial_smt.root(), smt.root());
        assert_matches!(partial_smt.open(&key_1), Err(MerkleError::UntrackedKey(_)));
        assert_matches!(partial_smt.open(&key), Ok(_));

        let bytes = partial_smt.to_bytes();
        let decoded = PartialSmt::read_from_bytes(&bytes).unwrap();

        assert_eq!(partial_smt, decoded);
    }

    /// `num_entries` accounts for all entries of added leaves and is decremented on removal.
    #[test]
    fn partial_smt_add_proof_num_entries() {
        // key0 and key1 share a leaf (the first `assert_eq!` below relies on this). key2 is
        // chosen to land in a different leaf.
        let key0 = Word::from([ZERO, ZERO, ZERO, ONE]);
        let key1 = Word::from([ONE, ONE, ONE, ONE]);
        let key2 = Word::from([ONE, ONE, ONE, Felt::new(5)]);
        let value0 = Word::from(rand_array::<Felt, 4>());
        let value1 = Word::from(rand_array::<Felt, 4>());
        let value2 = Word::from(rand_array::<Felt, 4>());

        let full = Smt::with_entries([(key0, value0), (key1, value1), (key2, value2)]).unwrap();
        let mut partial = PartialSmt::new(full.root());

        // key0's proof tracks its whole leaf, i.e. both key0 and key1.
        partial.add_proof(full.open(&key0)).unwrap();
        assert_eq!(partial.num_entries(), 2);

        partial.add_proof(full.open(&key2)).unwrap();
        assert_eq!(partial.num_entries(), 3);

        // Removing key0 drops exactly one entry; key1 still counts.
        partial.insert(key0, Word::empty()).unwrap();
        assert_eq!(partial.num_entries(), 2);
    }
}