Skip to main content

miden_crypto/merkle/smt/full/
leaf.rs

1use alloc::{string::ToString, vec::Vec};
2use core::cmp::Ordering;
3
4use super::EMPTY_WORD;
5use crate::{
6    Felt, Word,
7    hash::poseidon2::Poseidon2,
8    merkle::smt::{LeafIndex, MAX_LEAF_ENTRIES, SMT_DEPTH, SmtLeafError},
9    utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
10};
11
/// The number of field elements in a key-value pair (two Words, 4 Felts each).
///
/// `SmtLeaf::try_from_elements` uses this both to validate the length of its input and as the
/// stride when splitting the input into entries.
const DOUBLE_WORD_LEN: usize = 8;
14
/// Represents a leaf node in the Sparse Merkle Tree.
///
/// A leaf can be empty, hold a single key-value pair, or multiple key-value pairs.
///
/// Only the `Empty` variant stores its leaf index explicitly; for `Single` and `Multiple` the
/// index is derived from the stored key(s) (see [`SmtLeaf::index`]).
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum SmtLeaf {
    /// An empty leaf at the specified index.
    Empty(LeafIndex<SMT_DEPTH>),
    /// A leaf containing a single key-value pair, stored as `(key, value)`.
    Single((Word, Word)),
    /// A leaf containing multiple key-value pairs, each stored as `(key, value)`.
    ///
    /// All keys are expected to map to the same leaf index. `insert` and `remove` locate keys
    /// with a binary search ordered by `cmp_keys`, so they rely on the pairs being sorted by key.
    Multiple(Vec<(Word, Word)>),
}
28
29impl SmtLeaf {
30    // CONSTRUCTORS
31    // ---------------------------------------------------------------------------------------------
32
33    /// Returns a new leaf with the specified entries
34    ///
35    /// # Errors
36    ///   - Returns an error if 2 keys in `entries` map to a different leaf index
37    ///   - Returns an error if 1 or more keys in `entries` map to a leaf index different from
38    ///     `leaf_index`
39    pub fn new(
40        entries: Vec<(Word, Word)>,
41        leaf_index: LeafIndex<SMT_DEPTH>,
42    ) -> Result<Self, SmtLeafError> {
43        match entries.len() {
44            0 => Ok(Self::new_empty(leaf_index)),
45            1 => {
46                let (key, value) = entries[0];
47
48                let computed_index = LeafIndex::<SMT_DEPTH>::from(key);
49                if computed_index != leaf_index {
50                    return Err(SmtLeafError::InconsistentSingleLeafIndices {
51                        key,
52                        expected_leaf_index: leaf_index,
53                        actual_leaf_index: computed_index,
54                    });
55                }
56
57                Ok(Self::new_single(key, value))
58            },
59            _ => {
60                let leaf = Self::new_multiple(entries)?;
61
62                // `new_multiple()` checked that all keys map to the same leaf index. We still need
63                // to ensure that leaf index is `leaf_index`.
64                if leaf.index() != leaf_index {
65                    Err(SmtLeafError::InconsistentMultipleLeafIndices {
66                        leaf_index_from_keys: leaf.index(),
67                        leaf_index_supplied: leaf_index,
68                    })
69                } else {
70                    Ok(leaf)
71                }
72            },
73        }
74    }
75
76    /// Returns a new empty leaf with the specified leaf index
77    pub fn new_empty(leaf_index: LeafIndex<SMT_DEPTH>) -> Self {
78        Self::Empty(leaf_index)
79    }
80
81    /// Returns a new single leaf with the specified entry. The leaf index is derived from the
82    /// entry's key.
83    pub fn new_single(key: Word, value: Word) -> Self {
84        Self::Single((key, value))
85    }
86
87    /// Returns a new multiple leaf with the specified entries. The leaf index is derived from the
88    /// entries' keys.
89    ///
90    /// # Errors
91    ///   - Returns an error if 2 keys in `entries` map to a different leaf index
92    ///   - Returns an error if the number of entries exceeds [`MAX_LEAF_ENTRIES`]
93    pub fn new_multiple(entries: Vec<(Word, Word)>) -> Result<Self, SmtLeafError> {
94        if entries.len() < 2 {
95            return Err(SmtLeafError::MultipleLeafRequiresTwoEntries(entries.len()));
96        }
97
98        if entries.len() > MAX_LEAF_ENTRIES {
99            return Err(SmtLeafError::TooManyLeafEntries { actual: entries.len() });
100        }
101
102        // Check that all keys map to the same leaf index
103        {
104            let mut keys = entries.iter().map(|(key, _)| key);
105
106            let first_key = *keys.next().expect("ensured at least 2 entries");
107            let first_leaf_index: LeafIndex<SMT_DEPTH> = first_key.into();
108
109            for &next_key in keys {
110                let next_leaf_index: LeafIndex<SMT_DEPTH> = next_key.into();
111
112                if next_leaf_index != first_leaf_index {
113                    return Err(SmtLeafError::InconsistentMultipleLeafKeys {
114                        key_1: first_key,
115                        key_2: next_key,
116                    });
117                }
118            }
119        }
120
121        Ok(Self::Multiple(entries))
122    }
123
124    // PUBLIC ACCESSORS
125    // ---------------------------------------------------------------------------------------------
126
127    /// Returns the value associated with `key` in the leaf, or `None` if `key` maps to another
128    /// leaf.
129    pub fn get_value(&self, key: &Word) -> Option<Word> {
130        // Ensure that `key` maps to this leaf
131        if self.index() != (*key).into() {
132            return None;
133        }
134
135        match self {
136            SmtLeaf::Empty(_) => Some(EMPTY_WORD),
137            SmtLeaf::Single((key_in_leaf, value_in_leaf)) => {
138                if key == key_in_leaf {
139                    Some(*value_in_leaf)
140                } else {
141                    Some(EMPTY_WORD)
142                }
143            },
144            SmtLeaf::Multiple(kv_pairs) => {
145                for (key_in_leaf, value_in_leaf) in kv_pairs {
146                    if key == key_in_leaf {
147                        return Some(*value_in_leaf);
148                    }
149                }
150
151                Some(EMPTY_WORD)
152            },
153        }
154    }
155
156    /// Returns true if the leaf is empty
157    pub fn is_empty(&self) -> bool {
158        matches!(self, Self::Empty(_))
159    }
160
161    /// Returns the leaf's index in the [`super::Smt`]
162    pub fn index(&self) -> LeafIndex<SMT_DEPTH> {
163        match self {
164            SmtLeaf::Empty(leaf_index) => *leaf_index,
165            SmtLeaf::Single((key, _)) => (*key).into(),
166            SmtLeaf::Multiple(entries) => {
167                // Note: All keys are guaranteed to have the same leaf index
168                let (first_key, _) = entries[0];
169                first_key.into()
170            },
171        }
172    }
173
174    /// Returns the number of entries stored in the leaf
175    pub fn num_entries(&self) -> usize {
176        match self {
177            SmtLeaf::Empty(_) => 0,
178            SmtLeaf::Single(_) => 1,
179            SmtLeaf::Multiple(entries) => entries.len(),
180        }
181    }
182
183    /// Computes the hash of the leaf
184    pub fn hash(&self) -> Word {
185        match self {
186            SmtLeaf::Empty(_) => EMPTY_WORD,
187            SmtLeaf::Single((key, value)) => Poseidon2::merge(&[*key, *value]),
188            SmtLeaf::Multiple(kvs) => {
189                let elements: Vec<Felt> = kvs.iter().copied().flat_map(kv_to_elements).collect();
190                Poseidon2::hash_elements(&elements)
191            },
192        }
193    }
194
195    // ITERATORS
196    // ---------------------------------------------------------------------------------------------
197
198    /// Returns a slice with key-value pairs in the leaf.
199    pub fn entries(&self) -> &[(Word, Word)] {
200        match self {
201            SmtLeaf::Empty(_) => &[],
202            SmtLeaf::Single(kv_pair) => core::slice::from_ref(kv_pair),
203            SmtLeaf::Multiple(kv_pairs) => kv_pairs,
204        }
205    }
206
207    // CONVERSIONS
208    // ---------------------------------------------------------------------------------------------
209
210    /// Returns an iterator over the field elements representing this leaf.
211    pub fn to_elements(&self) -> impl Iterator<Item = Felt> + '_ {
212        self.entries().iter().copied().flat_map(kv_to_elements)
213    }
214
215    /// Returns an iterator over the key-value pairs in the leaf.
216    pub fn to_entries(&self) -> impl Iterator<Item = (&Word, &Word)> + '_ {
217        // Needed for type conversion from `&(T, T)` to `(&T, &T)`.
218        self.entries().iter().map(|(k, v)| (k, v))
219    }
220
221    /// Converts a leaf to a list of field elements.
222    pub fn into_elements(self) -> Vec<Felt> {
223        self.into_entries().into_iter().flat_map(kv_to_elements).collect()
224    }
225
226    /// Converts a leaf the key-value pairs in the leaf
227    pub fn into_entries(self) -> Vec<(Word, Word)> {
228        match self {
229            SmtLeaf::Empty(_) => Vec::new(),
230            SmtLeaf::Single(kv_pair) => vec![kv_pair],
231            SmtLeaf::Multiple(kv_pairs) => kv_pairs,
232        }
233    }
234
235    /// Converts a list of elements into a leaf
236    pub fn try_from_elements(
237        elements: &[Felt],
238        leaf_index: LeafIndex<SMT_DEPTH>,
239    ) -> Result<SmtLeaf, SmtLeafError> {
240        if elements.is_empty() {
241            return Ok(SmtLeaf::new_empty(leaf_index));
242        }
243
244        // Elements should be organized into a contiguous array of K/V Words (4 Felts each).
245        if !elements.len().is_multiple_of(DOUBLE_WORD_LEN) {
246            return Err(SmtLeafError::DecodingError(
247                "elements length is not a multiple of 8".into(),
248            ));
249        }
250
251        let num_entries = elements.len() / DOUBLE_WORD_LEN;
252
253        if num_entries == 1 {
254            // Single entry.
255            let key = Word::new([elements[0], elements[1], elements[2], elements[3]]);
256            let value = Word::new([elements[4], elements[5], elements[6], elements[7]]);
257            Ok(SmtLeaf::new_single(key, value))
258        } else {
259            // Multiple entries.
260            let mut entries = Vec::with_capacity(num_entries);
261            // Read k/v pairs from each entry.
262            for i in 0..num_entries {
263                let base_idx = i * DOUBLE_WORD_LEN;
264                let key = Word::new([
265                    elements[base_idx],
266                    elements[base_idx + 1],
267                    elements[base_idx + 2],
268                    elements[base_idx + 3],
269                ]);
270                let value = Word::new([
271                    elements[base_idx + 4],
272                    elements[base_idx + 5],
273                    elements[base_idx + 6],
274                    elements[base_idx + 7],
275                ]);
276                entries.push((key, value));
277            }
278            let leaf = SmtLeaf::new_multiple(entries)?;
279            Ok(leaf)
280        }
281    }
282
283    // HELPERS
284    // ---------------------------------------------------------------------------------------------
285
286    /// Inserts key-value pair into the leaf; returns the previous value associated with `key`, if
287    /// any.
288    ///
289    /// The caller needs to ensure that `key` has the same leaf index as all other keys in the leaf
290    ///
291    /// # Errors
292    /// Returns an error if inserting the key-value pair would exceed [`MAX_LEAF_ENTRIES`] (1024
293    /// entries) in the leaf.
294    pub(in crate::merkle::smt) fn insert(
295        &mut self,
296        key: Word,
297        value: Word,
298    ) -> Result<Option<Word>, SmtLeafError> {
299        match self {
300            SmtLeaf::Empty(_) => {
301                *self = SmtLeaf::new_single(key, value);
302                Ok(None)
303            },
304            SmtLeaf::Single(kv_pair) => {
305                if kv_pair.0 == key {
306                    // the key is already in this leaf. Update the value and return the previous
307                    // value
308                    let old_value = kv_pair.1;
309                    kv_pair.1 = value;
310                    Ok(Some(old_value))
311                } else {
312                    // Another entry is present in this leaf. Transform the entry into a list
313                    // entry, and make sure the key-value pairs are sorted by key
314                    // This stays within MAX_LEAF_ENTRIES limit. We're only adding one entry to a
315                    // single leaf
316                    let mut pairs = vec![*kv_pair, (key, value)];
317                    pairs.sort_by(|(key_1, _), (key_2, _)| cmp_keys(*key_1, *key_2));
318                    *self = SmtLeaf::Multiple(pairs);
319                    Ok(None)
320                }
321            },
322            SmtLeaf::Multiple(kv_pairs) => {
323                match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
324                    Ok(pos) => {
325                        let old_value = kv_pairs[pos].1;
326                        kv_pairs[pos].1 = value;
327                        Ok(Some(old_value))
328                    },
329                    Err(pos) => {
330                        if kv_pairs.len() >= MAX_LEAF_ENTRIES {
331                            return Err(SmtLeafError::TooManyLeafEntries {
332                                actual: kv_pairs.len() + 1,
333                            });
334                        }
335                        kv_pairs.insert(pos, (key, value));
336                        Ok(None)
337                    },
338                }
339            },
340        }
341    }
342
343    /// Removes key-value pair from the leaf stored at key; returns the previous value associated
344    /// with `key`, if any. Also returns an `is_empty` flag, indicating whether the leaf became
345    /// empty, and must be removed from the data structure it is contained in.
346    pub(in crate::merkle::smt) fn remove(&mut self, key: Word) -> (Option<Word>, bool) {
347        match self {
348            SmtLeaf::Empty(_) => (None, false),
349            SmtLeaf::Single((key_at_leaf, value_at_leaf)) => {
350                if *key_at_leaf == key {
351                    // our key was indeed stored in the leaf, so we return the value that was stored
352                    // in it, and indicate that the leaf should be removed
353                    let old_value = *value_at_leaf;
354
355                    // Note: this is not strictly needed, since the caller is expected to drop this
356                    // `SmtLeaf` object.
357                    *self = SmtLeaf::new_empty(key.into());
358
359                    (Some(old_value), true)
360                } else {
361                    // another key is stored at leaf; nothing to update
362                    (None, false)
363                }
364            },
365            SmtLeaf::Multiple(kv_pairs) => {
366                match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
367                    Ok(pos) => {
368                        let old_value = kv_pairs[pos].1;
369
370                        let _ = kv_pairs.remove(pos);
371                        debug_assert!(!kv_pairs.is_empty());
372
373                        if kv_pairs.len() == 1 {
374                            // convert the leaf into `Single`
375                            *self = SmtLeaf::Single(kv_pairs[0]);
376                        }
377
378                        (Some(old_value), false)
379                    },
380                    Err(_) => {
381                        // other keys are stored at leaf; nothing to update
382                        (None, false)
383                    },
384                }
385            },
386        }
387    }
388}
389
390impl Serializable for SmtLeaf {
391    fn write_into<W: ByteWriter>(&self, target: &mut W) {
392        // Write: num entries
393        self.num_entries().write_into(target);
394
395        // Write: leaf index
396        let leaf_index: u64 = self.index().position();
397        leaf_index.write_into(target);
398
399        // Write: entries
400        for (key, value) in self.entries() {
401            key.write_into(target);
402            value.write_into(target);
403        }
404    }
405}
406
407impl Deserializable for SmtLeaf {
408    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
409        // Read: num entries
410        let num_entries = source.read_usize()?;
411
412        // Read: leaf index
413        let leaf_index: LeafIndex<SMT_DEPTH> = {
414            let value = source.read_u64()?;
415            LeafIndex::new_max_depth(value)
416        };
417
418        // Read: entries using read_many_iter to avoid eager allocation
419        let entries: Vec<(Word, Word)> =
420            source.read_many_iter(num_entries)?.collect::<Result<_, _>>()?;
421
422        Self::new(entries, leaf_index)
423            .map_err(|err| DeserializationError::InvalidValue(err.to_string()))
424    }
425
426    /// Minimum serialized size: vint64 (num_entries) + u64 (leaf_index) with 0 entries.
427    fn min_serialized_size() -> usize {
428        1 + 8
429    }
430}
431
432// HELPER FUNCTIONS
433// ================================================================================================
434
435/// Converts a key-value tuple to an iterator of `Felt`s
436pub(crate) fn kv_to_elements((key, value): (Word, Word)) -> impl Iterator<Item = Felt> {
437    let key_elements = key.into_iter();
438    let value_elements = value.into_iter();
439
440    key_elements.chain(value_elements)
441}
442
443/// Compares two keys, compared element-by-element using their integer representations starting with
444/// the most significant element.
445pub(crate) fn cmp_keys(key_1: Word, key_2: Word) -> Ordering {
446    for (v1, v2) in key_1.iter().zip(key_2.iter()).rev() {
447        let v1 = (*v1).as_canonical_u64();
448        let v2 = (*v2).as_canonical_u64();
449        if v1 != v2 {
450            return v1.cmp(&v2);
451        }
452    }
453
454    Ordering::Equal
455}