Skip to main content

miden_crypto/merkle/smt/full/
leaf.rs

1use alloc::{string::ToString, vec::Vec};
2
3use super::EMPTY_WORD;
4use crate::{
5    Felt, Word,
6    hash::poseidon2::Poseidon2,
7    merkle::smt::{LEAF_DOMAIN, LeafIndex, MAX_LEAF_ENTRIES, SMT_DEPTH, SmtLeafError},
8    utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
9};
10
/// Field elements occupied by one key-value pair: a key Word plus a value Word, 4 Felts apiece.
const DOUBLE_WORD_LEN: usize = 2 * 4;
13
14/// Represents a leaf node in the Sparse Merkle Tree.
15///
16/// A leaf can be empty, hold a single key-value pair, or multiple key-value pairs.
17#[derive(Clone, Debug, PartialEq, Eq)]
18#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
19pub enum SmtLeaf {
20    /// An empty leaf at the specified index.
21    Empty(LeafIndex<SMT_DEPTH>),
22    /// A leaf containing a single key-value pair.
23    Single((Word, Word)),
24    /// A leaf containing multiple key-value pairs.
25    Multiple(Vec<(Word, Word)>),
26}
27
28impl SmtLeaf {
29    // CONSTRUCTORS
30    // ---------------------------------------------------------------------------------------------
31
32    /// Returns a new leaf with the specified entries
33    ///
34    /// # Errors
35    ///   - Returns an error if 2 keys in `entries` map to a different leaf index
36    ///   - Returns an error if 1 or more keys in `entries` map to a leaf index different from
37    ///     `leaf_index`
38    pub fn new(
39        entries: Vec<(Word, Word)>,
40        leaf_index: LeafIndex<SMT_DEPTH>,
41    ) -> Result<Self, SmtLeafError> {
42        match entries.len() {
43            0 => Ok(Self::new_empty(leaf_index)),
44            1 => {
45                let (key, value) = entries[0];
46
47                let computed_index = LeafIndex::<SMT_DEPTH>::from(key);
48                if computed_index != leaf_index {
49                    return Err(SmtLeafError::InconsistentSingleLeafIndices {
50                        key,
51                        expected_leaf_index: leaf_index,
52                        actual_leaf_index: computed_index,
53                    });
54                }
55
56                Ok(Self::new_single(key, value))
57            },
58            _ => {
59                let leaf = Self::new_multiple(entries)?;
60
61                // `new_multiple()` checked that all keys map to the same leaf index. We still need
62                // to ensure that leaf index is `leaf_index`.
63                if leaf.index() != leaf_index {
64                    Err(SmtLeafError::InconsistentMultipleLeafIndices {
65                        leaf_index_from_keys: leaf.index(),
66                        leaf_index_supplied: leaf_index,
67                    })
68                } else {
69                    Ok(leaf)
70                }
71            },
72        }
73    }
74
75    /// Returns a new empty leaf with the specified leaf index
76    pub fn new_empty(leaf_index: LeafIndex<SMT_DEPTH>) -> Self {
77        Self::Empty(leaf_index)
78    }
79
80    /// Returns a new single leaf with the specified entry. The leaf index is derived from the
81    /// entry's key.
82    pub fn new_single(key: Word, value: Word) -> Self {
83        Self::Single((key, value))
84    }
85
86    /// Returns a new multiple leaf with the specified entries. The leaf index is derived from the
87    /// entries' keys.
88    ///
89    /// # Errors
90    ///   - Returns an error if 2 keys in `entries` map to a different leaf index
91    ///   - Returns an error if the number of entries exceeds [`MAX_LEAF_ENTRIES`]
92    pub fn new_multiple(entries: Vec<(Word, Word)>) -> Result<Self, SmtLeafError> {
93        if entries.len() < 2 {
94            return Err(SmtLeafError::MultipleLeafRequiresTwoEntries(entries.len()));
95        }
96
97        if entries.len() > MAX_LEAF_ENTRIES {
98            return Err(SmtLeafError::TooManyLeafEntries { actual: entries.len() });
99        }
100
101        // Check that all keys map to the same leaf index
102        {
103            let mut keys = entries.iter().map(|(key, _)| key);
104
105            let first_key = *keys.next().expect("ensured at least 2 entries");
106            let first_leaf_index: LeafIndex<SMT_DEPTH> = first_key.into();
107
108            for &next_key in keys {
109                let next_leaf_index: LeafIndex<SMT_DEPTH> = next_key.into();
110
111                if next_leaf_index != first_leaf_index {
112                    return Err(SmtLeafError::InconsistentMultipleLeafKeys {
113                        key_1: first_key,
114                        key_2: next_key,
115                    });
116                }
117            }
118        }
119
120        Ok(Self::Multiple(entries))
121    }
122
123    // PUBLIC ACCESSORS
124    // ---------------------------------------------------------------------------------------------
125
126    /// Returns the value associated with `key` in the leaf, or `None` if `key` maps to another
127    /// leaf.
128    pub fn get_value(&self, key: &Word) -> Option<Word> {
129        // Ensure that `key` maps to this leaf
130        if self.index() != (*key).into() {
131            return None;
132        }
133
134        match self {
135            SmtLeaf::Empty(_) => Some(EMPTY_WORD),
136            SmtLeaf::Single((key_in_leaf, value_in_leaf)) => {
137                if key == key_in_leaf {
138                    Some(*value_in_leaf)
139                } else {
140                    Some(EMPTY_WORD)
141                }
142            },
143            SmtLeaf::Multiple(kv_pairs) => {
144                for (key_in_leaf, value_in_leaf) in kv_pairs {
145                    if key == key_in_leaf {
146                        return Some(*value_in_leaf);
147                    }
148                }
149
150                Some(EMPTY_WORD)
151            },
152        }
153    }
154
155    /// Returns true if the leaf is empty
156    pub fn is_empty(&self) -> bool {
157        matches!(self, Self::Empty(_))
158    }
159
160    /// Returns the leaf's index in the [`super::Smt`]
161    pub fn index(&self) -> LeafIndex<SMT_DEPTH> {
162        match self {
163            SmtLeaf::Empty(leaf_index) => *leaf_index,
164            SmtLeaf::Single((key, _)) => (*key).into(),
165            SmtLeaf::Multiple(entries) => {
166                // Note: All keys are guaranteed to have the same leaf index
167                let (first_key, _) = entries[0];
168                first_key.into()
169            },
170        }
171    }
172
173    /// Returns the number of entries stored in the leaf
174    pub fn num_entries(&self) -> usize {
175        match self {
176            SmtLeaf::Empty(_) => 0,
177            SmtLeaf::Single(_) => 1,
178            SmtLeaf::Multiple(entries) => entries.len(),
179        }
180    }
181
182    /// Computes the hash of the leaf
183    pub fn hash(&self) -> Word {
184        match self {
185            SmtLeaf::Empty(_) => EMPTY_WORD,
186            SmtLeaf::Single((key, value)) => {
187                Poseidon2::merge_in_domain(&[*key, *value], LEAF_DOMAIN)
188            },
189            SmtLeaf::Multiple(kvs) => {
190                let elements: Vec<Felt> = kvs.iter().copied().flat_map(kv_to_elements).collect();
191                Poseidon2::hash_elements_in_domain(&elements, LEAF_DOMAIN)
192            },
193        }
194    }
195
196    // ITERATORS
197    // ---------------------------------------------------------------------------------------------
198
199    /// Returns a slice with key-value pairs in the leaf.
200    pub fn entries(&self) -> &[(Word, Word)] {
201        match self {
202            SmtLeaf::Empty(_) => &[],
203            SmtLeaf::Single(kv_pair) => core::slice::from_ref(kv_pair),
204            SmtLeaf::Multiple(kv_pairs) => kv_pairs,
205        }
206    }
207
208    // CONVERSIONS
209    // ---------------------------------------------------------------------------------------------
210
211    /// Returns an iterator over the field elements representing this leaf.
212    pub fn to_elements(&self) -> impl Iterator<Item = Felt> + '_ {
213        self.entries().iter().copied().flat_map(kv_to_elements)
214    }
215
216    /// Returns an iterator over the key-value pairs in the leaf.
217    pub fn to_entries(&self) -> impl Iterator<Item = (&Word, &Word)> + '_ {
218        // Needed for type conversion from `&(T, T)` to `(&T, &T)`.
219        self.entries().iter().map(|(k, v)| (k, v))
220    }
221
222    /// Converts a leaf to a list of field elements.
223    pub fn into_elements(self) -> Vec<Felt> {
224        self.into_entries().into_iter().flat_map(kv_to_elements).collect()
225    }
226
227    /// Converts a leaf the key-value pairs in the leaf
228    pub fn into_entries(self) -> Vec<(Word, Word)> {
229        match self {
230            SmtLeaf::Empty(_) => Vec::new(),
231            SmtLeaf::Single(kv_pair) => vec![kv_pair],
232            SmtLeaf::Multiple(kv_pairs) => kv_pairs,
233        }
234    }
235
236    /// Converts a list of elements into a leaf
237    pub fn try_from_elements(
238        elements: &[Felt],
239        leaf_index: LeafIndex<SMT_DEPTH>,
240    ) -> Result<SmtLeaf, SmtLeafError> {
241        if elements.is_empty() {
242            return Ok(SmtLeaf::new_empty(leaf_index));
243        }
244
245        // Elements should be organized into a contiguous array of K/V Words (4 Felts each).
246        if !elements.len().is_multiple_of(DOUBLE_WORD_LEN) {
247            return Err(SmtLeafError::DecodingError(
248                "elements length is not a multiple of 8".into(),
249            ));
250        }
251
252        let num_entries = elements.len() / DOUBLE_WORD_LEN;
253
254        if num_entries == 1 {
255            // Single entry.
256            let key = Word::new([elements[0], elements[1], elements[2], elements[3]]);
257            let value = Word::new([elements[4], elements[5], elements[6], elements[7]]);
258            Ok(SmtLeaf::new_single(key, value))
259        } else {
260            // Multiple entries.
261            let mut entries = Vec::with_capacity(num_entries);
262            // Read k/v pairs from each entry.
263            for i in 0..num_entries {
264                let base_idx = i * DOUBLE_WORD_LEN;
265                let key = Word::new([
266                    elements[base_idx],
267                    elements[base_idx + 1],
268                    elements[base_idx + 2],
269                    elements[base_idx + 3],
270                ]);
271                let value = Word::new([
272                    elements[base_idx + 4],
273                    elements[base_idx + 5],
274                    elements[base_idx + 6],
275                    elements[base_idx + 7],
276                ]);
277                entries.push((key, value));
278            }
279            let leaf = SmtLeaf::new_multiple(entries)?;
280            Ok(leaf)
281        }
282    }
283
284    // HELPERS
285    // ---------------------------------------------------------------------------------------------
286
287    /// Inserts key-value pair into the leaf; returns the previous value associated with `key`, if
288    /// any.
289    ///
290    /// The caller needs to ensure that `key` has the same leaf index as all other keys in the leaf
291    ///
292    /// # Errors
293    /// Returns an error if inserting the key-value pair would exceed [`MAX_LEAF_ENTRIES`] (1024
294    /// entries) in the leaf.
295    pub(in crate::merkle::smt) fn insert(
296        &mut self,
297        key: Word,
298        value: Word,
299    ) -> Result<Option<Word>, SmtLeafError> {
300        match self {
301            SmtLeaf::Empty(_) => {
302                *self = SmtLeaf::new_single(key, value);
303                Ok(None)
304            },
305            SmtLeaf::Single(kv_pair) => {
306                if kv_pair.0 == key {
307                    // the key is already in this leaf. Update the value and return the previous
308                    // value
309                    let old_value = kv_pair.1;
310                    kv_pair.1 = value;
311                    Ok(Some(old_value))
312                } else {
313                    // Another entry is present in this leaf. Transform the entry into a list
314                    // entry, and make sure the key-value pairs are sorted by key
315                    // This stays within MAX_LEAF_ENTRIES limit. We're only adding one entry to a
316                    // single leaf
317                    let mut pairs = vec![*kv_pair, (key, value)];
318                    pairs.sort_by(|(key_1, _), (key_2, _)| key_1.cmp(key_2));
319                    *self = SmtLeaf::Multiple(pairs);
320                    Ok(None)
321                }
322            },
323            SmtLeaf::Multiple(kv_pairs) => {
324                match kv_pairs.binary_search_by(|kv_pair| kv_pair.0.cmp(&key)) {
325                    Ok(pos) => {
326                        let old_value = kv_pairs[pos].1;
327                        kv_pairs[pos].1 = value;
328                        Ok(Some(old_value))
329                    },
330                    Err(pos) => {
331                        if kv_pairs.len() >= MAX_LEAF_ENTRIES {
332                            return Err(SmtLeafError::TooManyLeafEntries {
333                                actual: kv_pairs.len() + 1,
334                            });
335                        }
336                        kv_pairs.insert(pos, (key, value));
337                        Ok(None)
338                    },
339                }
340            },
341        }
342    }
343
344    /// Removes key-value pair from the leaf stored at key; returns the previous value associated
345    /// with `key`, if any. Also returns an `is_empty` flag, indicating whether the leaf became
346    /// empty, and must be removed from the data structure it is contained in.
347    pub(in crate::merkle::smt) fn remove(&mut self, key: Word) -> (Option<Word>, bool) {
348        match self {
349            SmtLeaf::Empty(_) => (None, false),
350            SmtLeaf::Single((key_at_leaf, value_at_leaf)) => {
351                if *key_at_leaf == key {
352                    // our key was indeed stored in the leaf, so we return the value that was stored
353                    // in it, and indicate that the leaf should be removed
354                    let old_value = *value_at_leaf;
355
356                    // Note: this is not strictly needed, since the caller is expected to drop this
357                    // `SmtLeaf` object.
358                    *self = SmtLeaf::new_empty(key.into());
359
360                    (Some(old_value), true)
361                } else {
362                    // another key is stored at leaf; nothing to update
363                    (None, false)
364                }
365            },
366            SmtLeaf::Multiple(kv_pairs) => {
367                match kv_pairs.binary_search_by(|kv_pair| kv_pair.0.cmp(&key)) {
368                    Ok(pos) => {
369                        let old_value = kv_pairs[pos].1;
370
371                        let _ = kv_pairs.remove(pos);
372                        debug_assert!(!kv_pairs.is_empty());
373
374                        if kv_pairs.len() == 1 {
375                            // convert the leaf into `Single`
376                            *self = SmtLeaf::Single(kv_pairs[0]);
377                        }
378
379                        (Some(old_value), false)
380                    },
381                    Err(_) => {
382                        // other keys are stored at leaf; nothing to update
383                        (None, false)
384                    },
385                }
386            },
387        }
388    }
389}
390
391impl Serializable for SmtLeaf {
392    fn write_into<W: ByteWriter>(&self, target: &mut W) {
393        // Write: num entries
394        self.num_entries().write_into(target);
395
396        // Write: leaf index
397        let leaf_index: u64 = self.index().position();
398        leaf_index.write_into(target);
399
400        // Write: entries
401        for (key, value) in self.entries() {
402            key.write_into(target);
403            value.write_into(target);
404        }
405    }
406}
407
408impl Deserializable for SmtLeaf {
409    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
410        // Read: num entries
411        let num_entries = source.read_usize()?;
412
413        // Read: leaf index
414        let leaf_index: LeafIndex<SMT_DEPTH> = {
415            let value = source.read_u64()?;
416            LeafIndex::new_max_depth(value)
417        };
418
419        // Read: entries using read_many_iter to avoid eager allocation
420        let entries: Vec<(Word, Word)> =
421            source.read_many_iter(num_entries)?.collect::<Result<_, _>>()?;
422
423        Self::new(entries, leaf_index)
424            .map_err(|err| DeserializationError::InvalidValue(err.to_string()))
425    }
426
427    /// Minimum serialized size: vint64 (num_entries) + u64 (leaf_index) with 0 entries.
428    fn min_serialized_size() -> usize {
429        1 + 8
430    }
431}
432
433// HELPER FUNCTIONS
434// ================================================================================================
435
436/// Converts a key-value tuple to an iterator of `Felt`s
437pub(crate) fn kv_to_elements((key, value): (Word, Word)) -> impl Iterator<Item = Felt> {
438    let key_elements = key.into_iter();
439    let value_elements = value.into_iter();
440
441    key_elements.chain(value_elements)
442}