1use alloc::vec::Vec;
4use core::{
5 fmt::{self, Display},
6 hash::Hash,
7};
8
9use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
10
11use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, NodeIndex, SparseMerklePath};
12use crate::{EMPTY_WORD, Felt, Map, Word, hash::rpo::Rpo256};
13
14mod full;
15pub use full::{MAX_LEAF_ENTRIES, SMT_DEPTH, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError};
16
17#[cfg(feature = "concurrent")]
18mod large;
19#[cfg(feature = "internal")]
20pub use full::concurrent::{SubtreeLeaf, build_subtree_for_bench};
21#[cfg(feature = "concurrent")]
22pub use large::{
23 LargeSmt, LargeSmtError, MemoryStorage, SmtStorage, StorageUpdateParts, StorageUpdates,
24 Subtree, SubtreeError,
25};
26#[cfg(feature = "rocksdb")]
27pub use large::{RocksDbConfig, RocksDbStorage};
28
29mod large_forest;
30pub use large_forest::{
31 Backend, BackendError, Config as ForestConfig,
32 DEFAULT_MAX_HISTORY_VERSIONS as FOREST_DEFAULT_MAX_HISTORY_VERSIONS, ForestOperation,
33 InMemoryBackend as ForestInMemoryBackend, LargeSmtForest, LargeSmtForestError, LineageId,
34 MIN_HISTORY_VERSIONS as FOREST_MIN_HISTORY_VERSIONS, RootInfo, SmtForestUpdateBatch,
35 SmtUpdateBatch, TreeEntry, TreeId, TreeWithRoot, VersionId,
36};
37
38mod simple;
39pub use simple::{SimpleSmt, SimpleSmtProof};
40
41mod partial;
42pub use partial::PartialSmt;
43
44mod forest;
45pub use forest::SmtForest;
46
/// Smallest tree depth accepted by [`LeafIndex::new`]; shallower depths are rejected with
/// `MerkleError::DepthTooSmall`.
pub const SMT_MIN_DEPTH: u8 = 1;

/// Largest supported tree depth. Leaf positions are `u64` values, so a tree cannot usefully be
/// deeper than the number of bits in a `u64`.
// `u64::BITS` is 64, which always fits in a `u8`.
pub const SMT_MAX_DEPTH: u8 = u64::BITS as u8;
55
/// Explicitly stored inner nodes of a tree, keyed by their index in the tree.
type InnerNodes = Map<NodeIndex, InnerNode>;

/// Leaves of a tree keyed by a `u64` leaf index (presumably the leaf position at the tree's
/// depth — NOTE(review): confirm against the `from_raw_parts` implementors).
type Leaves<T> = Map<u64, T>;
/// Pending inner-node changes (see `NodeMutation`), keyed by the index of the affected node.
type NodeMutations = Map<NodeIndex, NodeMutation>;
62
63pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
83 type Key: Clone + Ord + Eq + Hash;
85 type Value: Clone + PartialEq;
87 type Leaf: Clone;
89 type Opening;
91
92 const EMPTY_VALUE: Self::Value;
94
95 const EMPTY_ROOT: Word;
97
98 fn get_path(&self, key: &Self::Key) -> SparseMerklePath {
105 let index = NodeIndex::from(Self::key_to_leaf_index(key));
106
107 SparseMerklePath::from_sized_iter(
109 index.proof_indices().map(|index| self.get_node_hash(index)),
110 )
111 .expect("failed to convert to SparseMerklePath")
112 }
113
114 fn get_node_hash(&self, index: NodeIndex) -> Word {
119 if index.is_root() {
120 return self.root();
121 }
122
123 let InnerNode { left, right } = self.get_inner_node(index.parent());
124
125 let index_is_right = index.is_value_odd();
126 if index_is_right { right } else { left }
127 }
128
129 fn open(&self, key: &Self::Key) -> Self::Opening {
132 let leaf = self.get_leaf(key);
133 let merkle_path = self.get_path(key);
134
135 Self::path_and_leaf_to_opening(merkle_path, leaf)
136 }
137
138 fn insert(&mut self, key: Self::Key, value: Self::Value) -> Result<Self::Value, MerkleError> {
145 let old_value = self.insert_value(key.clone(), value.clone())?.unwrap_or(Self::EMPTY_VALUE);
146
147 if value == old_value {
149 return Ok(value);
150 }
151
152 let leaf = self.get_leaf(&key);
153 let node_index = {
154 let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(&key);
155 leaf_index.into()
156 };
157
158 self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf));
159
160 Ok(old_value)
161 }
162
163 fn recompute_nodes_from_index_to_root(
166 &mut self,
167 mut index: NodeIndex,
168 node_hash_at_index: Word,
169 ) {
170 let mut node_hash = node_hash_at_index;
171 for node_depth in (0..index.depth()).rev() {
172 let is_right = index.is_value_odd();
173 index.move_up();
174 let InnerNode { left, right } = self.get_inner_node(index);
175 let (left, right) = if is_right {
176 (left, node_hash)
177 } else {
178 (node_hash, right)
179 };
180 node_hash = Rpo256::merge(&[left, right]);
181
182 if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) {
183 self.remove_inner_node(index);
186 } else {
187 self.insert_inner_node(index, InnerNode { left, right });
188 }
189 }
190 self.set_root(node_hash);
191 }
192
193 fn compute_mutations(
206 &self,
207 kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
208 ) -> Result<MutationSet<DEPTH, Self::Key, Self::Value>, MerkleError> {
209 self.compute_mutations_sequential(kv_pairs)
210 }
211
212 fn compute_mutations_sequential(
215 &self,
216 kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
217 ) -> Result<MutationSet<DEPTH, Self::Key, Self::Value>, MerkleError> {
218 use NodeMutation::*;
219
220 let mut new_root = self.root();
221 let mut new_pairs: Map<Self::Key, Self::Value> = Default::default();
222 let mut node_mutations: NodeMutations = Default::default();
223
224 for (key, value) in kv_pairs {
225 let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key));
229 if value == old_value {
230 continue;
231 }
232
233 let leaf_index = Self::key_to_leaf_index(&key);
234 let mut node_index = NodeIndex::from(leaf_index);
235
236 let old_leaf = {
240 let pairs_at_index = new_pairs
241 .iter()
242 .filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index);
243
244 pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| {
245 let existing_leaf = acc.clone();
248 self.construct_prospective_leaf(existing_leaf, k, v)
249 .expect("current leaf should be valid")
250 })
251 };
252
253 let new_leaf =
254 self.construct_prospective_leaf(old_leaf, &key, &value).map_err(|e| match e {
255 SmtLeafError::TooManyLeafEntries { actual } => {
256 MerkleError::TooManyLeafEntries { actual }
257 },
258 other => panic!("unexpected SmtLeaf::insert error: {:?}", other),
259 })?;
260
261 let mut new_child_hash = Self::hash_leaf(&new_leaf);
262
263 for node_depth in (0..node_index.depth()).rev() {
264 let is_right = node_index.is_value_odd();
266 node_index.move_up();
267
268 let old_node = node_mutations
269 .get(&node_index)
270 .map(|mutation| match mutation {
271 Addition(node) => node.clone(),
272 Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth),
273 })
274 .unwrap_or_else(|| self.get_inner_node(node_index));
275
276 let new_node = if is_right {
277 InnerNode {
278 left: old_node.left,
279 right: new_child_hash,
280 }
281 } else {
282 InnerNode {
283 left: new_child_hash,
284 right: old_node.right,
285 }
286 };
287
288 new_child_hash = new_node.hash();
290
291 let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth);
292 let is_removal = new_child_hash == equivalent_empty_hash;
293 let new_entry = if is_removal { Removal } else { Addition(new_node) };
294 node_mutations.insert(node_index, new_entry);
295 }
296
297 new_root = new_child_hash;
299 new_pairs.insert(key, value);
301 }
302
303 Ok(MutationSet {
304 old_root: self.root(),
305 new_root,
306 node_mutations,
307 new_pairs,
308 })
309 }
310
311 fn apply_mutations(
322 &mut self,
323 mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
324 ) -> Result<(), MerkleError>
325 where
326 Self: Sized,
327 {
328 use NodeMutation::*;
329 let MutationSet {
330 old_root,
331 node_mutations,
332 new_pairs,
333 new_root,
334 } = mutations;
335
336 if old_root != self.root() {
339 return Err(MerkleError::ConflictingRoots {
340 expected_root: self.root(),
341 actual_root: old_root,
342 });
343 }
344
345 for (index, mutation) in node_mutations {
346 match mutation {
347 Removal => {
348 self.remove_inner_node(index);
349 },
350 Addition(node) => {
351 self.insert_inner_node(index, node);
352 },
353 }
354 }
355
356 for (key, value) in new_pairs {
357 self.insert_value(key, value)?;
358 }
359
360 self.set_root(new_root);
361
362 Ok(())
363 }
364
365 fn apply_mutations_with_reversion(
375 &mut self,
376 mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
377 ) -> Result<MutationSet<DEPTH, Self::Key, Self::Value>, MerkleError>
378 where
379 Self: Sized,
380 {
381 use NodeMutation::*;
382 let MutationSet {
383 old_root,
384 node_mutations,
385 new_pairs,
386 new_root,
387 } = mutations;
388
389 if old_root != self.root() {
392 return Err(MerkleError::ConflictingRoots {
393 expected_root: self.root(),
394 actual_root: old_root,
395 });
396 }
397
398 let mut reverse_mutations = NodeMutations::new();
399 for (index, mutation) in node_mutations {
400 match mutation {
401 Removal => {
402 if let Some(node) = self.remove_inner_node(index) {
403 reverse_mutations.insert(index, Addition(node));
404 }
405 },
406 Addition(node) => {
407 if let Some(old_node) = self.insert_inner_node(index, node) {
408 reverse_mutations.insert(index, Addition(old_node));
409 } else {
410 reverse_mutations.insert(index, Removal);
411 }
412 },
413 }
414 }
415
416 let mut reverse_pairs = Map::new();
417 for (key, value) in new_pairs {
418 match self.insert_value(key.clone(), value)? {
419 Some(old_value) => {
420 reverse_pairs.insert(key, old_value);
421 },
422 None => {
423 reverse_pairs.insert(key, Self::EMPTY_VALUE);
424 },
425 }
426 }
427
428 self.set_root(new_root);
429
430 Ok(MutationSet {
431 old_root: new_root,
432 node_mutations: reverse_mutations,
433 new_pairs: reverse_pairs,
434 new_root: old_root,
435 })
436 }
437
438 fn from_raw_parts(
444 inner_nodes: InnerNodes,
445 leaves: Leaves<Self::Leaf>,
446 root: Word,
447 ) -> Result<Self, MerkleError>
448 where
449 Self: Sized;
450
451 fn root(&self) -> Word;
453
454 fn set_root(&mut self, root: Word);
456
457 fn get_inner_node(&self, index: NodeIndex) -> InnerNode;
459
460 fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode>;
462
463 fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode>;
465
466 fn insert_value(
468 &mut self,
469 key: Self::Key,
470 value: Self::Value,
471 ) -> Result<Option<Self::Value>, MerkleError>;
472
473 fn get_value(&self, key: &Self::Key) -> Self::Value;
476
477 fn get_leaf(&self, key: &Self::Key) -> Self::Leaf;
479
480 fn hash_leaf(leaf: &Self::Leaf) -> Word;
482
483 fn construct_prospective_leaf(
499 &self,
500 existing_leaf: Self::Leaf,
501 key: &Self::Key,
502 value: &Self::Value,
503 ) -> Result<Self::Leaf, SmtLeafError>;
504
505 fn key_to_leaf_index(key: &Self::Key) -> LeafIndex<DEPTH>;
507
508 fn path_and_leaf_to_opening(path: SparseMerklePath, leaf: Self::Leaf) -> Self::Opening;
512}
513
514#[doc(hidden)]
520#[derive(Debug, Default, Clone, PartialEq, Eq)]
521#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
522pub struct InnerNode {
523 pub left: Word,
524 pub right: Word,
525}
526
527impl InnerNode {
528 pub fn hash(&self) -> Word {
529 Rpo256::merge(&[self.left, self.right])
530 }
531}
532
533#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
538#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
539pub struct LeafIndex<const DEPTH: u8> {
540 index: NodeIndex,
541}
542
543impl<const DEPTH: u8> LeafIndex<DEPTH> {
544 pub fn new(value: u64) -> Result<Self, MerkleError> {
550 if DEPTH < SMT_MIN_DEPTH {
551 return Err(MerkleError::DepthTooSmall(DEPTH));
552 }
553
554 Ok(LeafIndex { index: NodeIndex::new(DEPTH, value)? })
555 }
556
557 pub fn position(&self) -> u64 {
559 self.index.value()
560 }
561
562 pub fn value(&self) -> u64 {
563 self.position()
564 }
565}
566
567impl LeafIndex<SMT_MAX_DEPTH> {
568 pub const fn new_max_depth(value: u64) -> Self {
570 LeafIndex {
571 index: NodeIndex::new_unchecked(SMT_MAX_DEPTH, value),
572 }
573 }
574}
575
576impl<const DEPTH: u8> From<LeafIndex<DEPTH>> for NodeIndex {
577 fn from(value: LeafIndex<DEPTH>) -> Self {
578 value.index
579 }
580}
581
582impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
583 type Error = MerkleError;
584
585 fn try_from(node_index: NodeIndex) -> Result<Self, Self::Error> {
586 if node_index.depth() != DEPTH {
587 return Err(MerkleError::InvalidNodeIndexDepth {
588 expected: DEPTH,
589 provided: node_index.depth(),
590 });
591 }
592
593 Self::new(node_index.value())
594 }
595}
596
597impl<const DEPTH: u8> Serializable for LeafIndex<DEPTH> {
598 fn write_into<W: ByteWriter>(&self, target: &mut W) {
599 self.index.write_into(target);
600 }
601}
602
603impl<const DEPTH: u8> Deserializable for LeafIndex<DEPTH> {
604 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
605 Ok(Self { index: source.read()? })
606 }
607}
608
609impl<const DEPTH: u8> Display for LeafIndex<DEPTH> {
610 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
611 write!(f, "DEPTH={}, position={}", DEPTH, self.position())
612 }
613}
614
/// A change to a single inner node of a sparse Merkle tree, stored in a `MutationSet` keyed by
/// the node's index.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NodeMutation {
    /// The inner node at this index is removed. `compute_mutations_sequential` records this when
    /// the recomputed node hashes to the empty-subtree root at its depth.
    Removal,
    /// The inner node at this index is set to the contained [`InnerNode`].
    Addition(InnerNode),
}
628
629#[derive(Debug, Clone, Default, PartialEq, Eq)]
633pub struct MutationSet<const DEPTH: u8, K: Eq + Hash, V> {
634 old_root: Word,
638 node_mutations: NodeMutations,
644 new_pairs: Map<K, V>,
649 new_root: Word,
652}
653
654impl<const DEPTH: u8, K: Eq + Hash, V> MutationSet<DEPTH, K, V> {
655 pub fn root(&self) -> Word {
658 self.new_root
659 }
660
661 pub fn old_root(&self) -> Word {
663 self.old_root
664 }
665
666 pub fn node_mutations(&self) -> &NodeMutations {
668 &self.node_mutations
669 }
670
671 pub fn new_pairs(&self) -> &Map<K, V> {
674 &self.new_pairs
675 }
676
677 pub fn is_empty(&self) -> bool {
678 self.node_mutations.is_empty()
679 && self.new_pairs.is_empty()
680 && self.old_root == self.new_root
681 }
682}
683
684impl Serializable for InnerNode {
688 fn write_into<W: ByteWriter>(&self, target: &mut W) {
689 target.write(self.left);
690 target.write(self.right);
691 }
692}
693
694impl Deserializable for InnerNode {
695 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
696 let left = source.read()?;
697 let right = source.read()?;
698
699 Ok(Self { left, right })
700 }
701}
702
703impl Serializable for NodeMutation {
704 fn write_into<W: ByteWriter>(&self, target: &mut W) {
705 match self {
706 NodeMutation::Removal => target.write_bool(false),
707 NodeMutation::Addition(inner_node) => {
708 target.write_bool(true);
709 inner_node.write_into(target);
710 },
711 }
712 }
713}
714
715impl Deserializable for NodeMutation {
716 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
717 if source.read_bool()? {
718 let inner_node = source.read()?;
719 return Ok(NodeMutation::Addition(inner_node));
720 }
721
722 Ok(NodeMutation::Removal)
723 }
724}
725
726impl<const DEPTH: u8, K: Serializable + Eq + Hash, V: Serializable> Serializable
727 for MutationSet<DEPTH, K, V>
728{
729 fn write_into<W: ByteWriter>(&self, target: &mut W) {
730 target.write(self.old_root);
731 target.write(self.new_root);
732
733 let inner_removals: Vec<_> = self
734 .node_mutations
735 .iter()
736 .filter(|(_, value)| matches!(value, NodeMutation::Removal))
737 .map(|(key, _)| key)
738 .collect();
739 let inner_additions: Vec<_> = self
740 .node_mutations
741 .iter()
742 .filter_map(|(key, value)| match value {
743 NodeMutation::Addition(node) => Some((key, node)),
744 _ => None,
745 })
746 .collect();
747
748 target.write(inner_removals);
749 target.write(inner_additions);
750
751 target.write_usize(self.new_pairs.len());
752 target.write_many(&self.new_pairs);
753 }
754}
755
756impl<const DEPTH: u8, K: Deserializable + Ord + Eq + Hash, V: Deserializable> Deserializable
757 for MutationSet<DEPTH, K, V>
758{
759 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
760 let old_root = source.read()?;
761 let new_root = source.read()?;
762
763 let inner_removals: Vec<NodeIndex> = source.read()?;
764 let inner_additions: Vec<(NodeIndex, InnerNode)> = source.read()?;
765
766 let node_mutations = NodeMutations::from_iter(
767 inner_removals.into_iter().map(|index| (index, NodeMutation::Removal)).chain(
768 inner_additions
769 .into_iter()
770 .map(|(index, node)| (index, NodeMutation::Addition(node))),
771 ),
772 );
773
774 let num_new_pairs = source.read_usize()?;
775 let new_pairs = source.read_many(num_new_pairs)?;
776 let new_pairs = Map::from_iter(new_pairs);
777
778 Ok(Self {
779 old_root,
780 node_mutations,
781 new_pairs,
782 new_root,
783 })
784 }
785}