1use alloc::{string::ToString, vec::Vec};
2
3use super::{
4 EMPTY_WORD, EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex,
5 MerkleError, MerklePath, MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word,
6};
7
8mod error;
9pub use error::{SmtLeafError, SmtProofError};
10
11mod leaf;
12pub use leaf::SmtLeaf;
13
14mod proof;
15pub use proof::SmtProof;
16use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
17
18#[cfg(feature = "concurrent")]
20mod concurrent;
21#[cfg(feature = "internal")]
22pub use concurrent::{SubtreeLeaf, build_subtree_for_bench};
23
24#[cfg(test)]
25mod tests;
26
/// Depth of an [`Smt`]: the number of levels between the root and the leaf layer.
pub const SMT_DEPTH: u8 = 64;
31
32type Leaves = super::Leaves<SmtLeaf>;
36
/// A sparse Merkle tree of depth [`SMT_DEPTH`] mapping [`RpoDigest`] keys to [`Word`] values.
///
/// The leaf a key belongs to is selected from the key's most significant element (`key[3]`),
/// so several keys may share one leaf; such a leaf then holds multiple entries.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Smt {
    /// Current root hash, returned by `root()` and updated through `set_root()`.
    root: RpoDigest,
    /// Non-empty leaves keyed by `LeafIndex::value()`; leaves that become empty are removed.
    pub(super) leaves: Leaves,
    /// Stored inner nodes keyed by `NodeIndex`; indices not present here are treated as empty
    /// subtrees by `get_inner_node`.
    inner_nodes: InnerNodes,
}
54
55impl Smt {
56 pub const EMPTY_VALUE: Word = <Self as SparseMerkleTree<SMT_DEPTH>>::EMPTY_VALUE;
60
61 pub fn new() -> Self {
68 let root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
69
70 Self {
71 root,
72 inner_nodes: Default::default(),
73 leaves: Default::default(),
74 }
75 }
76
77 pub fn with_entries(
87 entries: impl IntoIterator<Item = (RpoDigest, Word)>,
88 ) -> Result<Self, MerkleError> {
89 #[cfg(feature = "concurrent")]
90 {
91 Self::with_entries_concurrent(entries)
92 }
93 #[cfg(not(feature = "concurrent"))]
94 {
95 Self::with_entries_sequential(entries)
96 }
97 }
98
99 #[cfg(any(not(feature = "concurrent"), fuzzing, test))]
107 fn with_entries_sequential(
108 entries: impl IntoIterator<Item = (RpoDigest, Word)>,
109 ) -> Result<Self, MerkleError> {
110 use alloc::collections::BTreeSet;
111
112 let mut tree = Self::new();
114
115 let mut key_set_to_zero = BTreeSet::new();
118
119 for (key, value) in entries {
120 let old_value = tree.insert(key, value);
121
122 if old_value != EMPTY_WORD || key_set_to_zero.contains(&key) {
123 return Err(MerkleError::DuplicateValuesForIndex(
124 LeafIndex::<SMT_DEPTH>::from(key).value(),
125 ));
126 }
127
128 if value == EMPTY_WORD {
129 key_set_to_zero.insert(key);
130 };
131 }
132 Ok(tree)
133 }
134
135 pub fn from_raw_parts(inner_nodes: InnerNodes, leaves: Leaves, root: RpoDigest) -> Self {
144 <Self as SparseMerkleTree<SMT_DEPTH>>::from_raw_parts(inner_nodes, leaves, root).unwrap()
146 }
147
148 pub const fn depth(&self) -> u8 {
153 SMT_DEPTH
154 }
155
156 pub fn root(&self) -> RpoDigest {
158 <Self as SparseMerkleTree<SMT_DEPTH>>::root(self)
159 }
160
161 pub fn num_leaves(&self) -> usize {
166 self.leaves.len()
167 }
168
169 pub fn num_entries(&self) -> usize {
177 self.entries().count()
178 }
179
180 pub fn get_leaf(&self, key: &RpoDigest) -> SmtLeaf {
182 <Self as SparseMerkleTree<SMT_DEPTH>>::get_leaf(self, key)
183 }
184
185 pub fn get_value(&self, key: &RpoDigest) -> Word {
187 <Self as SparseMerkleTree<SMT_DEPTH>>::get_value(self, key)
188 }
189
190 pub fn open(&self, key: &RpoDigest) -> SmtProof {
193 <Self as SparseMerkleTree<SMT_DEPTH>>::open(self, key)
194 }
195
196 pub fn is_empty(&self) -> bool {
198 debug_assert_eq!(self.leaves.is_empty(), self.root == Self::EMPTY_ROOT);
199 self.root == Self::EMPTY_ROOT
200 }
201
202 pub fn leaves(&self) -> impl Iterator<Item = (LeafIndex<SMT_DEPTH>, &SmtLeaf)> {
207 self.leaves
208 .iter()
209 .map(|(leaf_index, leaf)| (LeafIndex::new_max_depth(*leaf_index), leaf))
210 }
211
212 pub fn entries(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
214 self.leaves().flat_map(|(_, leaf)| leaf.entries())
215 }
216
217 pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
219 self.inner_nodes.values().map(|e| InnerNodeInfo {
220 value: e.hash(),
221 left: e.left,
222 right: e.right,
223 })
224 }
225
226 pub fn insert(&mut self, key: RpoDigest, value: Word) -> Word {
236 <Self as SparseMerkleTree<SMT_DEPTH>>::insert(self, key, value)
237 }
238
239 pub fn compute_mutations(
260 &self,
261 kv_pairs: impl IntoIterator<Item = (RpoDigest, Word)>,
262 ) -> MutationSet<SMT_DEPTH, RpoDigest, Word> {
263 #[cfg(feature = "concurrent")]
264 {
265 self.compute_mutations_concurrent(kv_pairs)
266 }
267 #[cfg(not(feature = "concurrent"))]
268 {
269 <Self as SparseMerkleTree<SMT_DEPTH>>::compute_mutations(self, kv_pairs)
270 }
271 }
272
273 pub fn apply_mutations(
281 &mut self,
282 mutations: MutationSet<SMT_DEPTH, RpoDigest, Word>,
283 ) -> Result<(), MerkleError> {
284 <Self as SparseMerkleTree<SMT_DEPTH>>::apply_mutations(self, mutations)
285 }
286
287 pub fn apply_mutations_with_reversion(
298 &mut self,
299 mutations: MutationSet<SMT_DEPTH, RpoDigest, Word>,
300 ) -> Result<MutationSet<SMT_DEPTH, RpoDigest, Word>, MerkleError> {
301 <Self as SparseMerkleTree<SMT_DEPTH>>::apply_mutations_with_reversion(self, mutations)
302 }
303
304 fn perform_insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
310 debug_assert_ne!(value, Self::EMPTY_VALUE);
311
312 let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);
313
314 match self.leaves.get_mut(&leaf_index.value()) {
315 Some(leaf) => leaf.insert(key, value),
316 None => {
317 self.leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value)));
318
319 None
320 },
321 }
322 }
323
324 fn perform_remove(&mut self, key: RpoDigest) -> Option<Word> {
326 let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);
327
328 if let Some(leaf) = self.leaves.get_mut(&leaf_index.value()) {
329 let (old_value, is_empty) = leaf.remove(key);
330 if is_empty {
331 self.leaves.remove(&leaf_index.value());
332 }
333 old_value
334 } else {
335 None
337 }
338 }
339}
340
341impl SparseMerkleTree<SMT_DEPTH> for Smt {
342 type Key = RpoDigest;
343 type Value = Word;
344 type Leaf = SmtLeaf;
345 type Opening = SmtProof;
346
347 const EMPTY_VALUE: Self::Value = EMPTY_WORD;
348 const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
349
350 fn from_raw_parts(
351 inner_nodes: InnerNodes,
352 leaves: Leaves,
353 root: RpoDigest,
354 ) -> Result<Self, MerkleError> {
355 if cfg!(debug_assertions) {
356 let root_node = inner_nodes.get(&NodeIndex::root()).unwrap();
357 assert_eq!(root_node.hash(), root);
358 }
359
360 Ok(Self { root, inner_nodes, leaves })
361 }
362
363 fn root(&self) -> RpoDigest {
364 self.root
365 }
366
367 fn set_root(&mut self, root: RpoDigest) {
368 self.root = root;
369 }
370
371 fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
372 self.inner_nodes
373 .get(&index)
374 .cloned()
375 .unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth()))
376 }
377
378 fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode> {
379 self.inner_nodes.insert(index, inner_node)
380 }
381
382 fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode> {
383 self.inner_nodes.remove(&index)
384 }
385
386 fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value> {
387 if value != Self::EMPTY_VALUE {
389 self.perform_insert(key, value)
390 } else {
391 self.perform_remove(key)
392 }
393 }
394
395 fn get_value(&self, key: &Self::Key) -> Self::Value {
396 let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();
397
398 match self.leaves.get(&leaf_pos) {
399 Some(leaf) => leaf.get_value(key).unwrap_or_default(),
400 None => EMPTY_WORD,
401 }
402 }
403
404 fn get_leaf(&self, key: &RpoDigest) -> Self::Leaf {
405 let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();
406
407 match self.leaves.get(&leaf_pos) {
408 Some(leaf) => leaf.clone(),
409 None => SmtLeaf::new_empty(key.into()),
410 }
411 }
412
413 fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest {
414 leaf.hash()
415 }
416
417 fn construct_prospective_leaf(
418 &self,
419 mut existing_leaf: SmtLeaf,
420 key: &RpoDigest,
421 value: &Word,
422 ) -> SmtLeaf {
423 debug_assert_eq!(existing_leaf.index(), Self::key_to_leaf_index(key));
424
425 match existing_leaf {
426 SmtLeaf::Empty(_) => SmtLeaf::new_single(*key, *value),
427 _ => {
428 if *value != EMPTY_WORD {
429 existing_leaf.insert(*key, *value);
430 } else {
431 existing_leaf.remove(*key);
432 }
433
434 existing_leaf
435 },
436 }
437 }
438
439 fn key_to_leaf_index(key: &RpoDigest) -> LeafIndex<SMT_DEPTH> {
440 let most_significant_felt = key[3];
441 LeafIndex::new_max_depth(most_significant_felt.as_int())
442 }
443
444 fn path_and_leaf_to_opening(path: MerklePath, leaf: SmtLeaf) -> SmtProof {
445 SmtProof::new_unchecked(path, leaf)
446 }
447}
448
449impl Default for Smt {
450 fn default() -> Self {
451 Self::new()
452 }
453}
454
455impl From<Word> for LeafIndex<SMT_DEPTH> {
459 fn from(value: Word) -> Self {
460 Self::new_max_depth(value[3].as_int())
462 }
463}
464
465impl From<RpoDigest> for LeafIndex<SMT_DEPTH> {
466 fn from(value: RpoDigest) -> Self {
467 Word::from(value).into()
468 }
469}
470
471impl From<&RpoDigest> for LeafIndex<SMT_DEPTH> {
472 fn from(value: &RpoDigest) -> Self {
473 Word::from(value).into()
474 }
475}
476
477impl Serializable for Smt {
481 fn write_into<W: ByteWriter>(&self, target: &mut W) {
482 target.write_usize(self.entries().count());
484
485 for (key, value) in self.entries() {
487 target.write(key);
488 target.write(value);
489 }
490 }
491
492 fn get_size_hint(&self) -> usize {
493 let entries_count = self.entries().count();
494
495 entries_count.get_size_hint()
497 + entries_count * (RpoDigest::SERIALIZED_SIZE + EMPTY_WORD.get_size_hint())
498 }
499}
500
501impl Deserializable for Smt {
502 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
503 let num_filled_leaves = source.read_usize()?;
505 let mut entries = Vec::with_capacity(num_filled_leaves);
506
507 for _ in 0..num_filled_leaves {
508 let key = source.read()?;
509 let value = source.read()?;
510 entries.push((key, value));
511 }
512
513 Self::with_entries(entries)
514 .map_err(|err| DeserializationError::InvalidValue(err.to_string()))
515 }
516}
517
#[cfg(fuzzing)]
impl Smt {
    /// Fuzzing hook: always builds the tree with the sequential algorithm, whether or not the
    /// `concurrent` feature is enabled.
    pub fn fuzz_with_entries_sequential(
        entries: impl IntoIterator<Item = (RpoDigest, Word)>,
    ) -> Result<Smt, MerkleError> {
        Smt::with_entries_sequential(entries)
    }

    /// Fuzzing hook: computes mutations through the `SparseMerkleTree` trait method, bypassing
    /// the inherent `compute_mutations` and its `concurrent` dispatch.
    pub fn fuzz_compute_mutations_sequential(
        &self,
        kv_pairs: impl IntoIterator<Item = (RpoDigest, Word)>,
    ) -> MutationSet<SMT_DEPTH, RpoDigest, Word> {
        SparseMerkleTree::<SMT_DEPTH>::compute_mutations(self, kv_pairs)
    }
}
536
#[test]
fn test_smt_serialization_deserialization() {
    // Serializes `tree` and checks two things: it deserializes back to an equal tree, and the
    // encoded length matches the size hint.
    fn assert_round_trip(tree: &Smt) {
        let encoded = tree.to_bytes();
        assert_eq!(*tree, Smt::read_from_bytes(&encoded).unwrap());
        assert_eq!(encoded.len(), tree.get_size_hint());
    }

    // The empty tree must round-trip as well.
    assert_round_trip(&Smt::default());

    let felts = |a: u64, b: u64, c: u64, d: u64| {
        [Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)]
    };
    let entries: [(RpoDigest, Word); 2] = [
        (RpoDigest::new(felts(101, 102, 103, 104)), felts(1, 2, 3, 4)),
        (RpoDigest::new(felts(105, 106, 107, 108)), felts(5, 6, 7, 8)),
    ];

    assert_round_trip(&Smt::with_entries(entries).unwrap());
}