1use alloc::vec::Vec;
2use core::hash::Hash;
3
4use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
5
6use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex};
7use crate::{
8 EMPTY_WORD, Felt, Word,
9 hash::rpo::{Rpo256, RpoDigest},
10};
11
12mod full;
13pub use full::{SMT_DEPTH, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError};
14#[cfg(feature = "internal")]
15pub use full::{SubtreeLeaf, build_subtree_for_bench};
16
17mod simple;
18pub use simple::SimpleSmt;
19
20mod partial;
21pub use partial::PartialSmt;
22
/// Smallest tree depth accepted by [`LeafIndex::new`]; shallower depths are rejected with
/// `MerkleError::DepthTooSmall`.
pub const SMT_MIN_DEPTH: u8 = 1;

/// Largest sparse Merkle tree depth; leaf positions are `u64` values, and
/// `LeafIndex::new_max_depth` builds leaf indices at this depth.
pub const SMT_MAX_DEPTH: u8 = 64;

// Map type backing the tree's internal collections. With the `smt_hashmaps` feature it is a
// `hashbrown` hash map, otherwise a `BTreeMap`; code using it therefore must not depend on a
// particular iteration order.
#[cfg(feature = "smt_hashmaps")]
type UnorderedMap<K, V> = hashbrown::HashMap<K, V>;
#[cfg(not(feature = "smt_hashmaps"))]
type UnorderedMap<K, V> = alloc::collections::BTreeMap<K, V>;
/// Stored inner nodes, keyed by their node index.
type InnerNodes = UnorderedMap<NodeIndex, InnerNode>;
/// Stored leaves keyed by a `u64` (presumably the leaf position, i.e. `LeafIndex::value`).
type Leaves<T> = UnorderedMap<u64, T>;
/// Pending inner-node changes, keyed by node index.
type NodeMutations = UnorderedMap<NodeIndex, NodeMutation>;
43
44pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
64 type Key: Clone + Ord + Eq + Hash;
66 type Value: Clone + PartialEq;
68 type Leaf: Clone;
70 type Opening;
72
73 const EMPTY_VALUE: Self::Value;
75
76 const EMPTY_ROOT: RpoDigest;
78
79 fn open(&self, key: &Self::Key) -> Self::Opening {
85 let leaf = self.get_leaf(key);
86
87 let mut index: NodeIndex = {
88 let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(key);
89 leaf_index.into()
90 };
91
92 let merkle_path = {
93 let mut path = Vec::with_capacity(index.depth() as usize);
94 for _ in 0..index.depth() {
95 let is_right = index.is_value_odd();
96 index.move_up();
97 let InnerNode { left, right } = self.get_inner_node(index);
98 let value = if is_right { left } else { right };
99 path.push(value);
100 }
101
102 MerklePath::new(path)
103 };
104
105 Self::path_and_leaf_to_opening(merkle_path, leaf)
106 }
107
108 fn insert(&mut self, key: Self::Key, value: Self::Value) -> Self::Value {
115 let old_value = self.insert_value(key.clone(), value.clone()).unwrap_or(Self::EMPTY_VALUE);
116
117 if value == old_value {
119 return value;
120 }
121
122 let leaf = self.get_leaf(&key);
123 let node_index = {
124 let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(&key);
125 leaf_index.into()
126 };
127
128 self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf));
129
130 old_value
131 }
132
133 fn recompute_nodes_from_index_to_root(
136 &mut self,
137 mut index: NodeIndex,
138 node_hash_at_index: RpoDigest,
139 ) {
140 let mut node_hash = node_hash_at_index;
141 for node_depth in (0..index.depth()).rev() {
142 let is_right = index.is_value_odd();
143 index.move_up();
144 let InnerNode { left, right } = self.get_inner_node(index);
145 let (left, right) = if is_right {
146 (left, node_hash)
147 } else {
148 (node_hash, right)
149 };
150 node_hash = Rpo256::merge(&[left, right]);
151
152 if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) {
153 self.remove_inner_node(index);
156 } else {
157 self.insert_inner_node(index, InnerNode { left, right });
158 }
159 }
160 self.set_root(node_hash);
161 }
162
163 fn compute_mutations(
172 &self,
173 kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
174 ) -> MutationSet<DEPTH, Self::Key, Self::Value> {
175 self.compute_mutations_sequential(kv_pairs)
176 }
177
178 fn compute_mutations_sequential(
181 &self,
182 kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
183 ) -> MutationSet<DEPTH, Self::Key, Self::Value> {
184 use NodeMutation::*;
185
186 let mut new_root = self.root();
187 let mut new_pairs: UnorderedMap<Self::Key, Self::Value> = Default::default();
188 let mut node_mutations: NodeMutations = Default::default();
189
190 for (key, value) in kv_pairs {
191 let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key));
195 if value == old_value {
196 continue;
197 }
198
199 let leaf_index = Self::key_to_leaf_index(&key);
200 let mut node_index = NodeIndex::from(leaf_index);
201
202 let old_leaf = {
206 let pairs_at_index = new_pairs
207 .iter()
208 .filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index);
209
210 pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| {
211 let existing_leaf = acc.clone();
214 self.construct_prospective_leaf(existing_leaf, k, v)
215 })
216 };
217
218 let new_leaf = self.construct_prospective_leaf(old_leaf, &key, &value);
219
220 let mut new_child_hash = Self::hash_leaf(&new_leaf);
221
222 for node_depth in (0..node_index.depth()).rev() {
223 let is_right = node_index.is_value_odd();
225 node_index.move_up();
226
227 let old_node = node_mutations
228 .get(&node_index)
229 .map(|mutation| match mutation {
230 Addition(node) => node.clone(),
231 Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth),
232 })
233 .unwrap_or_else(|| self.get_inner_node(node_index));
234
235 let new_node = if is_right {
236 InnerNode {
237 left: old_node.left,
238 right: new_child_hash,
239 }
240 } else {
241 InnerNode {
242 left: new_child_hash,
243 right: old_node.right,
244 }
245 };
246
247 new_child_hash = new_node.hash();
249
250 let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth);
251 let is_removal = new_child_hash == equivalent_empty_hash;
252 let new_entry = if is_removal { Removal } else { Addition(new_node) };
253 node_mutations.insert(node_index, new_entry);
254 }
255
256 new_root = new_child_hash;
258 new_pairs.insert(key, value);
260 }
261
262 MutationSet {
263 old_root: self.root(),
264 new_root,
265 node_mutations,
266 new_pairs,
267 }
268 }
269
270 fn apply_mutations(
279 &mut self,
280 mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
281 ) -> Result<(), MerkleError>
282 where
283 Self: Sized,
284 {
285 use NodeMutation::*;
286 let MutationSet {
287 old_root,
288 node_mutations,
289 new_pairs,
290 new_root,
291 } = mutations;
292
293 if old_root != self.root() {
296 return Err(MerkleError::ConflictingRoots {
297 expected_root: self.root(),
298 actual_root: old_root,
299 });
300 }
301
302 for (index, mutation) in node_mutations {
303 match mutation {
304 Removal => {
305 self.remove_inner_node(index);
306 },
307 Addition(node) => {
308 self.insert_inner_node(index, node);
309 },
310 }
311 }
312
313 for (key, value) in new_pairs {
314 self.insert_value(key, value);
315 }
316
317 self.set_root(new_root);
318
319 Ok(())
320 }
321
322 fn apply_mutations_with_reversion(
332 &mut self,
333 mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
334 ) -> Result<MutationSet<DEPTH, Self::Key, Self::Value>, MerkleError>
335 where
336 Self: Sized,
337 {
338 use NodeMutation::*;
339 let MutationSet {
340 old_root,
341 node_mutations,
342 new_pairs,
343 new_root,
344 } = mutations;
345
346 if old_root != self.root() {
349 return Err(MerkleError::ConflictingRoots {
350 expected_root: self.root(),
351 actual_root: old_root,
352 });
353 }
354
355 let mut reverse_mutations = NodeMutations::new();
356 for (index, mutation) in node_mutations {
357 match mutation {
358 Removal => {
359 if let Some(node) = self.remove_inner_node(index) {
360 reverse_mutations.insert(index, Addition(node));
361 }
362 },
363 Addition(node) => {
364 if let Some(old_node) = self.insert_inner_node(index, node) {
365 reverse_mutations.insert(index, Addition(old_node));
366 } else {
367 reverse_mutations.insert(index, Removal);
368 }
369 },
370 }
371 }
372
373 let mut reverse_pairs = UnorderedMap::new();
374 for (key, value) in new_pairs {
375 if let Some(old_value) = self.insert_value(key.clone(), value) {
376 reverse_pairs.insert(key, old_value);
377 } else {
378 reverse_pairs.insert(key, Self::EMPTY_VALUE);
379 }
380 }
381
382 self.set_root(new_root);
383
384 Ok(MutationSet {
385 old_root: new_root,
386 node_mutations: reverse_mutations,
387 new_pairs: reverse_pairs,
388 new_root: old_root,
389 })
390 }
391
392 fn from_raw_parts(
398 inner_nodes: InnerNodes,
399 leaves: Leaves<Self::Leaf>,
400 root: RpoDigest,
401 ) -> Result<Self, MerkleError>
402 where
403 Self: Sized;
404
405 fn root(&self) -> RpoDigest;
407
408 fn set_root(&mut self, root: RpoDigest);
410
411 fn get_inner_node(&self, index: NodeIndex) -> InnerNode;
413
414 fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode>;
416
417 fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode>;
419
420 fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value>;
422
423 fn get_value(&self, key: &Self::Key) -> Self::Value;
426
427 fn get_leaf(&self, key: &Self::Key) -> Self::Leaf;
429
430 fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest;
432
433 fn construct_prospective_leaf(
445 &self,
446 existing_leaf: Self::Leaf,
447 key: &Self::Key,
448 value: &Self::Value,
449 ) -> Self::Leaf;
450
451 fn key_to_leaf_index(key: &Self::Key) -> LeafIndex<DEPTH>;
453
454 fn path_and_leaf_to_opening(path: MerklePath, leaf: Self::Leaf) -> Self::Opening;
458}
459
460#[doc(hidden)]
466#[derive(Debug, Default, Clone, PartialEq, Eq)]
467#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
468pub struct InnerNode {
469 pub left: RpoDigest,
470 pub right: RpoDigest,
471}
472
473impl InnerNode {
474 pub fn hash(&self) -> RpoDigest {
475 Rpo256::merge(&[self.left, self.right])
476 }
477}
478
479#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
484#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
485pub struct LeafIndex<const DEPTH: u8> {
486 index: NodeIndex,
487}
488
489impl<const DEPTH: u8> LeafIndex<DEPTH> {
490 pub fn new(value: u64) -> Result<Self, MerkleError> {
491 if DEPTH < SMT_MIN_DEPTH {
492 return Err(MerkleError::DepthTooSmall(DEPTH));
493 }
494
495 Ok(LeafIndex { index: NodeIndex::new(DEPTH, value)? })
496 }
497
498 pub fn value(&self) -> u64 {
499 self.index.value()
500 }
501}
502
503impl LeafIndex<SMT_MAX_DEPTH> {
504 pub const fn new_max_depth(value: u64) -> Self {
505 LeafIndex {
506 index: NodeIndex::new_unchecked(SMT_MAX_DEPTH, value),
507 }
508 }
509}
510
511impl<const DEPTH: u8> From<LeafIndex<DEPTH>> for NodeIndex {
512 fn from(value: LeafIndex<DEPTH>) -> Self {
513 value.index
514 }
515}
516
517impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
518 type Error = MerkleError;
519
520 fn try_from(node_index: NodeIndex) -> Result<Self, Self::Error> {
521 if node_index.depth() != DEPTH {
522 return Err(MerkleError::InvalidNodeIndexDepth {
523 expected: DEPTH,
524 provided: node_index.depth(),
525 });
526 }
527
528 Self::new(node_index.value())
529 }
530}
531
532impl<const DEPTH: u8> Serializable for LeafIndex<DEPTH> {
533 fn write_into<W: ByteWriter>(&self, target: &mut W) {
534 self.index.write_into(target);
535 }
536}
537
538impl<const DEPTH: u8> Deserializable for LeafIndex<DEPTH> {
539 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
540 Ok(Self { index: source.read()? })
541 }
542}
543
/// A pending change to a single inner node, as recorded in a [`MutationSet`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NodeMutation {
    /// Remove the node from storage. `SparseMerkleTree::compute_mutations_sequential` emits this
    /// when the node's new hash equals the empty-subtree root at its depth.
    Removal,
    /// Store this node at the index, replacing any node already stored there.
    Addition(InnerNode),
}
557
/// A batch of changes computed against a tree by `SparseMerkleTree::compute_mutations`, to be
/// applied later with `SparseMerkleTree::apply_mutations`.
///
/// The set remembers the root it was computed against, so applying it to a tree with a
/// different root is rejected.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MutationSet<const DEPTH: u8, K: Eq + Hash, V> {
    /// Root of the tree these mutations were computed against.
    old_root: RpoDigest,
    /// Inner-node changes, keyed by node index.
    node_mutations: NodeMutations,
    /// Key-value pairs to write into the tree; each key appears once with its final value.
    new_pairs: UnorderedMap<K, V>,
    /// Root the tree will have once the mutations are applied.
    new_root: RpoDigest,
}
582
583impl<const DEPTH: u8, K: Eq + Hash, V> MutationSet<DEPTH, K, V> {
584 pub fn root(&self) -> RpoDigest {
587 self.new_root
588 }
589
590 pub fn old_root(&self) -> RpoDigest {
592 self.old_root
593 }
594
595 pub fn node_mutations(&self) -> &NodeMutations {
597 &self.node_mutations
598 }
599
600 pub fn new_pairs(&self) -> &UnorderedMap<K, V> {
603 &self.new_pairs
604 }
605}
606
607impl Serializable for InnerNode {
611 fn write_into<W: ByteWriter>(&self, target: &mut W) {
612 target.write(self.left);
613 target.write(self.right);
614 }
615}
616
617impl Deserializable for InnerNode {
618 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
619 let left = source.read()?;
620 let right = source.read()?;
621
622 Ok(Self { left, right })
623 }
624}
625
626impl Serializable for NodeMutation {
627 fn write_into<W: ByteWriter>(&self, target: &mut W) {
628 match self {
629 NodeMutation::Removal => target.write_bool(false),
630 NodeMutation::Addition(inner_node) => {
631 target.write_bool(true);
632 inner_node.write_into(target);
633 },
634 }
635 }
636}
637
638impl Deserializable for NodeMutation {
639 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
640 if source.read_bool()? {
641 let inner_node = source.read()?;
642 return Ok(NodeMutation::Addition(inner_node));
643 }
644
645 Ok(NodeMutation::Removal)
646 }
647}
648
649impl<const DEPTH: u8, K: Serializable + Eq + Hash, V: Serializable> Serializable
650 for MutationSet<DEPTH, K, V>
651{
652 fn write_into<W: ByteWriter>(&self, target: &mut W) {
653 target.write(self.old_root);
654 target.write(self.new_root);
655
656 let inner_removals: Vec<_> = self
657 .node_mutations
658 .iter()
659 .filter(|(_, value)| matches!(value, NodeMutation::Removal))
660 .map(|(key, _)| key)
661 .collect();
662 let inner_additions: Vec<_> = self
663 .node_mutations
664 .iter()
665 .filter_map(|(key, value)| match value {
666 NodeMutation::Addition(node) => Some((key, node)),
667 _ => None,
668 })
669 .collect();
670
671 target.write(inner_removals);
672 target.write(inner_additions);
673
674 target.write_usize(self.new_pairs.len());
675 target.write_many(&self.new_pairs);
676 }
677}
678
679impl<const DEPTH: u8, K: Deserializable + Ord + Eq + Hash, V: Deserializable> Deserializable
680 for MutationSet<DEPTH, K, V>
681{
682 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
683 let old_root = source.read()?;
684 let new_root = source.read()?;
685
686 let inner_removals: Vec<NodeIndex> = source.read()?;
687 let inner_additions: Vec<(NodeIndex, InnerNode)> = source.read()?;
688
689 let node_mutations = NodeMutations::from_iter(
690 inner_removals.into_iter().map(|index| (index, NodeMutation::Removal)).chain(
691 inner_additions
692 .into_iter()
693 .map(|(index, node)| (index, NodeMutation::Addition(node))),
694 ),
695 );
696
697 let num_new_pairs = source.read_usize()?;
698 let new_pairs = source.read_many(num_new_pairs)?;
699 let new_pairs = UnorderedMap::from_iter(new_pairs);
700
701 Ok(Self {
702 old_root,
703 node_mutations,
704 new_pairs,
705 new_root,
706 })
707 }
708}